In [None]:
# Move the working directory up one level so the relative paths used later in
# this notebook (e.g. './configs/env/...') resolve from the parent directory.
# NOTE(review): `%cd ..` is not idempotent -- re-running this cell moves up one
# more directory each time, so run it exactly once per kernel session.
%cd ..
# Load the autoreload extension and use mode 2, so edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions.one_hot_categorical as one_hot_sample

import queuetorch.env as env
from queuetorch.env import QueuingNetwork
import yaml
from tqdm import trange

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Pick the environment by name and read its YAML description into a dict-like
# config object (safe_load: no arbitrary Python object construction).
name = 'multiclass'
config_path = f'./configs/env/{name}.yaml'
with open(config_path, 'r') as f:
    env_config = yaml.safe_load(f)

In [None]:
# Build the queueing environment from the loaded config (single sample path, CPU).
dq = env.load_env(
    env_config,
    temp=0.1,
    batch=1,
    seed=23,
    device='cpu',
)

In [None]:
# Override two environment attributes after construction.
# NOTE(review): presumably `h` are per-queue holding costs and `mu` per-queue
# service rates -- confirm against queuetorch.env.
dq.h = torch.ones(5)

# Rates 1.1, 1.2, ..., 1.5 wrapped into a (1, 1, 5) tensor.
rate_values = [1. + 0.1 * idx for idx in range(1, 6)]
dq.mu = torch.tensor([[rate_values]])

In [None]:
# Sanity check: show the overridden attributes, one per line.
print(dq.h, dq.mu, sep='\n')

In [None]:
# ---------------------------------------------------------------------------
# Pathwise-gradient training of the priority scores.
#
# `priority` holds one learnable score per queue (shape (1, dq.q)). Each outer
# iteration:
#   1. rebuilds the environment with a very small `temp` and the same h / mu
#      overrides as above,
#   2. rolls the simulator forward `sim_steps` events under the softmax policy
#      induced by `priority`,
#   3. back-propagates the time-averaged cost through the rollout,
#   4. takes a normalized gradient step of length `alpha`,
#   5. records both the raw iterate and the running average of iterates.
# ---------------------------------------------------------------------------

priority = torch.zeros((1,dq.q)).float()
# Cloned *before* requires_grad is switched on, so the running sum never
# participates in autograd.
sum_priority = priority.clone()

train_seed = 42

priority.requires_grad = True
alpha = 0.1          # step length of the normalized gradient update
num_iter = 50        # number of outer gradient steps
sim_steps = 1000     # simulated events per gradient estimate
grad_eps = 1e-12     # floor on the gradient norm, avoids 0/0 -> NaN
st_steps = [priority.detach()]
avg_iterate = [sum_priority.clone()]
num = 1              # number of iterates accumulated in sum_priority


for i in range(num_iter):
    # Fresh environment per gradient step; overrides match the exploratory cell.
    # NOTE(review): presumably `temp` is the smoothing temperature of the
    # differentiable simulator -- confirm against queuetorch.env.load_env.
    dq = env.load_env(env_config, temp = 0.00001, batch = 1, seed = 23, device = 'cpu')
    dq.h = torch.tensor([1.]*5)
    dq.mu = torch.tensor([[[1. + 0.1*i for i in range(1,6)]]])

    # After the first iteration, warm-start from the queue lengths reached at
    # the end of the previous rollout.
    if i > 0:
        obs, state = dq.reset(seed = train_seed, init_queues = init_queues)
    else:
        obs, state = dq.reset(seed = train_seed)
    total_cost = torch.tensor([[0.]]*dq.batch)

    for _ in trange(sim_steps):
        queues, _event_clock = obs

        # Softmax over the scores, broadcast to (batch, dq.s, dq.q) and
        # multiplied by dq.network.
        # NOTE(review): dq.network looks like a server-by-queue mask -- confirm.
        pr = F.softmax(priority.repeat(dq.batch,dq.s,1), -1) * dq.network
        # Cap each entry by the current queue length.
        pr = torch.minimum(pr, queues.unsqueeze(1).repeat(1, dq.s, 1))
        # Rows that ended up all-zero fall back to dq.network so the
        # normalization below does not divide by zero.
        pr += 1*torch.all(pr == 0., dim = 2).reshape(dq.batch,dq.s,1) * dq.network
        # Normalize each row to sum to one.
        pr /= torch.sum(pr, dim = -1).reshape(dq.batch, dq.s, 1)

        action = pr
        obs, state, cost, event_time = dq.step(state, action)
        total_cost += cost

    # Use the observation returned by the *last* step. The `queues` variable
    # unpacked inside the loop is one transition stale.
    final_queues, _ = obs
    init_queues = final_queues.detach()

    # Time-averaged cost over the rollout, averaged across the batch.
    avg_cost = torch.mean(total_cost / state.time)
    avg_cost.backward()

    print(f'priority: {priority}')
    print(f'avg_cost: {avg_cost}')

    # Unit-norm gradient direction. The clamp only matters when the gradient is
    # (numerically) zero, where it yields a zero step instead of NaN.
    grad = priority.grad
    grad_norm = torch.clamp(torch.linalg.norm(grad), min = grad_eps)
    normalized_grad = grad / grad_norm
    #normalized_grad = priority.grad

    # Build a fresh leaf tensor for the next iteration (drops the old graph and
    # the accumulated .grad).
    priority = priority.detach() - alpha * normalized_grad
    print(f'grad: {normalized_grad}')
    print()

    # Bookkeeping: raw iterate and running average of all iterates so far.
    st_steps.append(priority.detach())
    sum_priority += priority.detach()
    num += 1
    avg_iterate.append(sum_priority.clone() / num)

    priority.requires_grad = True
    

In [None]:
# Collect the running-average iterates into one 2-D tensor: take row 0 of each
# recorded iterate and stack them, giving one row per recorded iterate.
st_steps_l = torch.stack([avg_scores[0] for avg_scores in avg_iterate])

In [None]:
# Bar chart of the averaged policy scores at one recorded iterate.
k = 30  # row of st_steps_l (built from avg_iterate) to display
iterate_scores = st_steps_l[k]
queue_labels = [str(j) for j in range(1, len(iterate_scores) + 1)]

plt.style.use('seaborn-v0_8-white')
fig, ax = plt.subplots()
ax.bar(queue_labels, iterate_scores, color='orangered', label='Pathwise (B = 1)')
ax.axhline(0, color='black')  # zero reference line
ax.set_ylabel(r'Policy Score $\theta_{j}$', fontsize=20)
ax.set_xlabel('Queue', fontsize=20)
ax.tick_params(axis='both', labelsize=20)
ax.legend(fontsize=15)

fig.tight_layout()
#plt.savefig('./plot/cmu_bar_q_5_value.png',dpi = 300)
plt.show()