In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to project source files take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions.one_hot_categorical as one_hot_sample

import queuetorch.env as env
from queuetorch.env import QueuingNetwork
import yaml
from tqdm import trange

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Notebook-wide matplotlib defaults: animations are embedded as "jshtml" and
# figures are rendered at 150 DPI.
plt.rcParams.update({
    "animation.html": "jshtml",
    "figure.dpi": 150,
})

In [None]:
def policy_cost(dq, buffer, T, init_queues = None, seed = 42):
    """Roll out the priority policy for ``T`` steps and return its time-averaged cost.

    Args:
        dq: Environment object. This function reads ``dq.batch``, ``dq.s``,
            ``dq.q``, ``dq.mu``, ``dq.h`` and ``dq.network`` and drives the
            simulation through ``dq.reset`` and ``dq.step``.
        buffer: Buffer tensor, forwarded unchanged to ``dq.reset`` and to every
            ``dq.step`` call.
        T: Number of ``dq.step`` calls to simulate.
        init_queues: Optional initial queue state forwarded to ``dq.reset``.
        seed: Seed passed to ``dq.reset`` and used to seed torch's global RNG
            before any action is sampled.

    Returns:
        Tuple ``(avg_cost, queues)``. ``avg_cost`` is the accumulated
        ``cost + buffer_cost`` (shape ``(dq.batch, 1)``) divided by
        ``state.time`` after the last step. ``queues`` is the queue
        observation the last action was chosen from, i.e. the observation
        *before* the final step (the initial observation when ``T == 0``).
    """
    obs, state = dq.reset(seed = seed, init_queues = init_queues, buffer = buffer)

    # One running total per batch element.
    total_cost = torch.zeros((dq.batch, 1))
    torch.manual_seed(seed)

    # Bind `queues` before the loop so the return value exists even when T == 0
    # (previously this raised UnboundLocalError).
    queues, _time = obs

    for _ in trange(T):
        queues, _time = obs

        # Hard priority rule (an argmax, not a softmax): along dim 1, one-hot the
        # non-empty queue with the largest mu * h index.
        pr = F.one_hot(torch.argmax(dq.mu*dq.h * 1.*(queues > 0.).unsqueeze(1), dim = 2), num_classes = dq.q)
        # Mask by dq.network and cap each entry by the current queue length.
        pr = torch.minimum((pr * dq.network), queues.unsqueeze(1).repeat(1, dq.s, 1))
        # Rows left with no mass fall back to their dq.network row so that the
        # normalisation below never divides an all-zero row.
        pr += 1*torch.all(pr == 0., dim = 2).reshape(dq.batch,dq.s,1) * dq.network
        pr /= torch.sum(pr, dim = -1).reshape(dq.batch, dq.s, 1)

        action = one_hot_sample.OneHotCategorical(probs = pr).sample()

        obs, state, cost, buffer_cost, event_time = dq.step(state, action, buffer)
        total_cost += cost + buffer_cost

    return total_cost / state.time, queues

def fd_grad(dq, buffer, T = 10000, init_queues = None, sigma = 1, seed = 42):
    """Estimate the gradient of the average cost w.r.t. ``buffer`` by finite differences.

    Each batch element perturbs ``buffer`` by ``+/- sigma`` along an independent
    random sign vector. The perturbed buffers are clipped at zero and rounded,
    both are rolled out with ``policy_cost`` under the same ``seed``, and the
    two-point difference is averaged over the batch.

    Args:
        dq: Environment object; ``dq.batch`` sets the number of perturbations.
        buffer: Tensor repeated ``dq.batch`` times along dim 0 before perturbing.
        T: Number of simulation steps per rollout.
        init_queues: Optional initial queue state forwarded to ``policy_cost``.
        sigma: Perturbation size.
        seed: Seed used for both rollouts.

    Returns:
        Tuple ``(two_point_grad, queues)``: the batch-averaged gradient estimate
        and the detached queues returned by the "minus" rollout.
    """
    # Finite differences need no autograd, so detach to keep the two rollouts
    # from recording a graph when `buffer` requires grad.
    batch_buffer = buffer.detach().repeat(dq.batch,1)
    # NOTE: the signs are drawn before any seeding below, so they follow
    # whatever state torch's global RNG is in when this function is called.
    rand_signs = torch.sign(torch.randn(batch_buffer.size()))

    # NOTE(review): relu + round can make the realised perturbation differ from
    # sigma (e.g. near zero), while the estimate below still divides by 2*sigma.
    plus_buffer = torch.round(F.relu(batch_buffer + sigma * rand_signs))
    minus_buffer = torch.round(F.relu(batch_buffer - sigma * rand_signs))

    # Plus rollout
    torch.manual_seed(seed)
    print('plus')
    plus_cost, _ = policy_cost(dq, plus_buffer, T, init_queues, seed)

    # Minus rollout. policy_cost resets the environment and keeps its own cost
    # accumulator, so no separate dq.reset / total_cost is needed here.
    torch.manual_seed(seed)
    print('minus')
    minus_cost, queues = policy_cost(dq, minus_buffer, T, init_queues, seed)

    # Two-point estimate, averaged over the batch of perturbations.
    two_point_grad = torch.mean(((plus_cost - minus_cost) / (2*sigma)) * rand_signs, dim = 0)

    return two_point_grad, queues.detach()

def pathwise_grad(dq, buffer, T = 10000, init_queues = None, seed = 42):
    """Compute the pathwise (autograd) gradient of the average cost w.r.t. ``buffer``.

    Args:
        dq: Environment object; ``dq.batch`` sets how many times ``buffer`` is
            repeated along dim 0 for the rollout.
        buffer: Tensor to differentiate with respect to. Its ``.grad`` is
            cleared and then populated by this call, so it is expected to be a
            leaf tensor with ``requires_grad=True``.
        T: Number of simulation steps.
        init_queues: Optional initial queue state forwarded to ``policy_cost``.
        seed: Seed forwarded to ``policy_cost``.

    Returns:
        Tuple ``(grad, queues)``: ``buffer.grad`` after backpropagating the
        batch-mean average cost, and the detached queues from the rollout.
    """
    # backward() accumulates into .grad; clear any value left by an earlier
    # call so the result reflects this rollout only.
    buffer.grad = None

    # Cost and grad
    avg_cost, queues = policy_cost(dq, buffer.repeat(dq.batch,1), T, init_queues, seed)

    # avg_cost holds one entry per batch element and backward() needs a scalar.
    # Averaging is the identity for dq.batch == 1 and gives the batch-mean
    # gradient otherwise (a bare backward() raises for dq.batch > 1).
    avg_cost.mean().backward()

    grad = buffer.grad

    return grad, queues.detach()


# Experiments

In [None]:
# Select the environment by name and parse its YAML configuration file.
name = 'mm1'
config_path = f'../configs/env/{name}.yaml'
with open(config_path, 'r') as config_file:
    env_config = yaml.safe_load(config_file)

In [None]:
# Initial environment instance. The training and evaluation cells below rebuild
# `dq` with their own temp / batch / seed settings.
dq = env.load_env(
    env_config,
    temp=0.5,
    batch=100,
    seed=42,
    device='cpu',
)

In [None]:
# Horizon: number of environment steps per rollout (passed as T to
# pathwise_grad and used as the evaluation loop length).
H = 10_000
# Value written into every entry of dq.b by the cells below.
# NOTE(review): presumably the per-queue buffer cost coefficient; confirm
# against the environment implementation.
b = 100.0

# Gradient Descent Loop

In [None]:
# Sign-gradient descent on the buffer size.
# `buffer_float` is the continuous iterate; the environment always receives its
# rounded value `buffer`, which is the leaf tensor the gradient is taken at.
buffer_float = 1. * torch.ones((1,dq.q)).float()
buffer = torch.round(buffer_float)
buffer.requires_grad = True
num_iter = 30
alpha = 1.0  # step size applied to the sign of the gradient
buffers = [buffer]  # every iterate visited, starting with the initial one

train_seed = 42

# The first iteration starts from the environment's default initial state;
# later iterations warm-start from the queues returned by the previous rollout.
init_queues = None

for i in range(num_iter):

    # Fresh single-sample environment with a different seed per iteration.
    dq = env.load_env(env_config, temp = 0.1, batch = 1, seed = train_seed + i, device = 'cpu')
    dq.buffer_control = True
    dq.b = torch.tensor([b]*dq.q)

    # Calculate gradient (for i == 0 the seed is just train_seed).
    grad, queues = pathwise_grad(dq, buffer, H, init_queues = init_queues, seed = train_seed + i)

    # sign gd: move by alpha against the gradient sign, clipped at zero.
    sign_grad = torch.sign(grad)
    buffer_float = F.relu(buffer_float.detach() - alpha * sign_grad)
    print(f'grad: {grad}')
    print(f'buffer_float: {buffer_float}')
    # Prints the buffer the gradient was evaluated at (not the updated one).
    print(f'buffer: {buffer}')

    init_queues = queues.detach()

    # New leaf tensor for the next iteration.
    buffer = torch.round(buffer_float.detach())
    buffers.append(buffer)
    buffer.requires_grad = True
    

# Evaluate

In [None]:
## Evaluate along the trajectory
# Re-simulate every buffer visited during training on a common test seed and
# record its batch-mean time-averaged cost.
batch = 100
test_seed = 90000342
buffer_costs = []

for buffer in buffers:
    print(buffer)

    dq = env.load_env(env_config, temp = 0.1, batch = batch, seed = test_seed, device = 'cpu')
    dq.buffer_control = True
    dq.b = torch.tensor([b]*dq.q)

    torch.manual_seed(test_seed)
    # NOTE(review): unlike policy_cost and the grid evaluation, `buffer` is not
    # passed to dq.reset here; confirm whether that is intentional.
    obs, state = dq.reset(seed = test_seed)
    total_cost = torch.tensor([[0.]]*batch)

    # The tensors in `buffers` have requires_grad=True. Evaluation needs no
    # gradients, so disable autograd rather than recording an H-step graph.
    with torch.no_grad():
        for _ in trange(H):

            queues, time = obs

            # C-mu rule: one-hot on the non-empty queue with the largest mu * h.
            pr = F.one_hot(torch.argmax(dq.mu*dq.h * 1.*(queues > 0.).unsqueeze(1), dim = 2), num_classes = dq.q)
            action = one_hot_sample.OneHotCategorical(probs = pr).sample()

            obs, state, cost, buffer_cost, event_time = dq.step(state, action, buffer)
            total_cost += cost + buffer_cost

    mean_cost = float(torch.mean(total_cost / state.time))
    print(mean_cost)
    buffer_costs.append(mean_cost)

In [None]:
## Evaluate for all points in a grid
# Simulate every integer buffer size from 1 to 29 on the test seed and record
# the batch-mean time-averaged cost for each.
all_buffers = list(range(1, 30))
all_buffer_costs = []

for buffer_size in all_buffers:
    buffer = torch.tensor([buffer_size])

    dq = env.load_env(env_config, temp = 0.1, batch = batch, seed = test_seed, device = 'cpu')
    dq.buffer_control = True
    dq.b = torch.tensor([b]*dq.q)

    torch.manual_seed(test_seed)
    obs, state = dq.reset(seed = test_seed, buffer = buffer)
    total_cost = torch.tensor([[0.]]*batch)

    for _ in trange(H):
        queues, time = obs

        # One-hot on the non-empty queue with the largest mu * h index.
        priority = torch.argmax(dq.mu*dq.h * 1.*(queues > 0.).unsqueeze(1), dim = 2)
        pr = F.one_hot(priority, num_classes = dq.q)
        action = one_hot_sample.OneHotCategorical(probs = pr).sample()

        obs, state, cost, buffer_cost, event_time = dq.step(state, action, buffer)
        total_cost += cost + buffer_cost

    # Already a scalar tensor: mean over the batch of time-averaged costs.
    avg_cost = torch.mean(total_cost / state.time)
    print(buffer, avg_cost)
    all_buffer_costs.append(float(avg_cost))

In [None]:
# Grid buffer size -> evaluated average cost.
buffer_cost_dict = dict(zip(all_buffers, all_buffer_costs))

# Integer buffer size of every iterate visited during training.
list_buffers = [int(visited) for visited in buffers]

In [None]:
# Cost curve over the grid, overlaid with the buffer sizes visited in training.
# Trajectory costs are looked up from the grid evaluation, so a visited size
# outside `all_buffers` raises KeyError here.
fig, ax = plt.subplots()
trajectory_costs = [buffer_cost_dict[int(size)] for size in list_buffers]

ax.plot(all_buffers, all_buffer_costs, zorder = 0, linewidth = 4, color = 'darkgreen', label = 'MM1 Holding Cost')
ax.scatter(list_buffers, trajectory_costs, color = 'orangered', zorder = 1, linewidth = 2)
ax.plot(list_buffers, trajectory_costs, color = 'orangered', zorder = 3, linewidth = 2, label = r'Pathwise ($B = 1$)')

ax.legend(fontsize = 20)
ax.set_ylabel('Average Cost', fontsize = 20)
ax.set_xlabel('Buffer', fontsize = 20)
plt.yticks(fontsize = 20)
plt.xticks(fontsize = 20)
fig.tight_layout()
plt.show()