In [1]:
import sys
import random
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import sumolib
import traci
from sumolib import checkBinary
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric.data import Data, Batch
import sys
import io
from contextlib import redirect_stdout
import matplotlib.pyplot as plt
import pandas as pd
import os
import math
from collections import namedtuple, deque
import gym
from torch_geometric.utils import dense_to_sparse
import copy
from itertools import count

if 'SUMO_HOME' in os.environ:
    print('SUMO_HOME found')
    sys.path.append(os.path.join(os.environ['SUMO_HOME'], 'tools'))

# sumoBinary = checkBinary('sumo-gui')
sumoBinary = checkBinary('sumo')
roadNetwork = "./config/osm.sumocfg"
sumoCmd = [sumoBinary, "-c", roadNetwork, "--start", "--quit-on-end"]
# use gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
print("Using device: " + str(device))

SUMO_HOME found
Using device: cuda


In [2]:
def intervehicleConnectivity(threshold = None):
    xs = []
    ys = []
    for vehicle in traci.vehicle.getIDList():
        x, y = traci.vehicle.getPosition(vehicle)
        xs.append(x)
        ys.append(y)
    xs = torch.tensor(xs, dtype=torch.float32).view(-1,1)
    ys = torch.tensor(ys, dtype=torch.float32).view(-1,1)
    intervehicle_distances = torch.sqrt((xs - xs.t())**2 + (ys - ys.t())**2)
    if threshold is not None:
        # make the distances 1 if less than the threshold, 0 otherwise
        connectivity = torch.where(intervehicle_distances < threshold, torch.ones_like(intervehicle_distances), torch.zeros_like(intervehicle_distances))
    return connectivity, xs, ys

In [3]:
def randomTrips(dur=1000, density=12):
    os.system("python $SUMO_HOME/tools/randomTrips.py -n config/osm.net.xml.gz -r config/osm.passenger.trips.xml -e " + str(dur) + " -l --insertion-density=" + str(density))

def shouldContinueSim():
    numVehicles = traci.simulation.getMinExpectedNumber()
    return True if numVehicles > 0 else False

def restart(sumoCmd):
    with io.StringIO() as buf, redirect_stdout(buf):
        try:
            traci.close()
        except:
            pass
        traci.start(sumoCmd)

def close():
    traci.close()

# randomTrips(800, 1.5)

In [4]:
def simplify_graph(adj_matrix):
    adj_matrix = adj_matrix - torch.eye(adj_matrix.size(0))
    degrees = torch.sum(adj_matrix, axis=0)
    nodes_to_keep = np.where(degrees > 0)[0]
    new_adj_matrix = adj_matrix[np.ix_(nodes_to_keep, nodes_to_keep)]
    return new_adj_matrix

def bfs_distance(adj_matrix):
    n_hop_matrix = torch.ones_like(adj_matrix) * (-100)
    for start_node in range(adj_matrix.size(0)):
        visited = [0] * adj_matrix.size(0)
        queue = deque([(start_node, 0)])
        visited[start_node] = True
        
        while queue:
            current_node, current_dist = queue.popleft()
            
            for neighbor, connected in enumerate(adj_matrix[current_node]):
                if connected and not visited[neighbor]:
                    queue.append((neighbor, current_dist + 1))
                    visited[neighbor] = True
                    n_hop_matrix[start_node, neighbor] = current_dist + 1

    return n_hop_matrix

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(CrossAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        
        # Make sure that the d_model is divisible by the number of heads
        assert self.head_dim * num_heads == d_model
        
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value):
        batch_size = query.size(0)

        Q = self.query_linear(query)  # (batch_size, seq_len, d_model)
        K = self.key_linear(key)      # (batch_size, seq_len, d_model)
        V = self.value_linear(value)  # (batch_size, seq_len, d_model)
        
        # Reshape and transpose for multi-head attention
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        
        attn_weights = F.softmax(scores, dim=-1)
        
        # Weighted sum
        attn_output = torch.matmul(attn_weights, V)
        
        # Concatenate heads and put through final linear layer
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.out_linear(attn_output)
        
        return output

In [6]:
Transition = namedtuple('Transition',
                        ('data', 'action', 'next_state', 'reward', 'shuffle_indices'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


In [7]:
class GDQN_Attention(nn.Module):
    def __init__(self, in_channels=13, n_nodes=55, hidden_dim=32, max_n_neighbors=6):
        super(GDQN_Attention, self).__init__()
        self.n_nodes = n_nodes
        self.hidden_dim = hidden_dim
        self.convs1 = SAGEConv(in_channels, hidden_dim)
        self.convs2 = SAGEConv(hidden_dim, hidden_dim)
        self.fc1 = nn.Linear(hidden_dim * max_n_neighbors, max_n_neighbors)
        self.selu = nn.SELU()
        self.cross_attn = CrossAttention(hidden_dim, 1)
        self.max_n_neighbors = max_n_neighbors

    def forward(self, data, shuffle_indices=None):
        x, edge_index = data.x, data.edge_index
        batch_neighbor_mask = x[:, -3].reshape(-1, self.n_nodes)
        x = self.convs1(x, edge_index)
        x = self.selu(x)
        x = self.convs2(x, edge_index)
        x = self.selu(x)

        keys = x.reshape(-1, self.n_nodes, self.hidden_dim)
        values = x.reshape(-1, self.n_nodes, self.hidden_dim)
        batch_size = keys.size(0)
        querys = []
        for i in range(batch_size):
            neighbor_indices = torch.where(batch_neighbor_mask[i])
            query = keys[i][neighbor_indices]
            if query.size(0) < self.max_n_neighbors:
                query = F.pad(query, (0, 0, 0, self.max_n_neighbors - query.size(0)), "constant", 0)
            querys.append(query)
        querys = torch.stack(querys)

        if batch_size == 1:
            shuffle_indices = shuffle_indices.unsqueeze(0)
        if shuffle_indices is not None:
            batch_indices = torch.arange(batch_size).unsqueeze(1).expand(-1, shuffle_indices.size(1))
            shuffled_query = querys[batch_indices, shuffle_indices]
        else:
            shuffled_query = querys
        x = self.cross_attn(shuffled_query, keys, values)
        x = x.reshape(-1, self.hidden_dim * self.max_n_neighbors)
        x = self.selu(x)
        x = self.fc1(x)
        return x

In [8]:
class RoutingGym(gym.Env):
    def __init__(self, sumoCmd, max_steps=1100, n_nodes=55, max_routing_steps=30, max_n_neighbors=6):
        self.sumoCmd = sumoCmd
        self.step_counter = 0
        self.max_steps = max_steps
        self.n_nodes = n_nodes
        self.start_node = None
        self.end_node = None
        self.current_node = None
        self.node_features = None
        self.adj_matrix = None
        self.edge_index = None
        self.hop_thresh = None
        self.routing_done = False
        self.routing_steps = 0
        self.min_n_hops = None
        self.end_node_indicator = torch.zeros(n_nodes)
        self.max_routing_steps = max_routing_steps
        self.n_hop_matrix = None

        self.to_remove_indices = None
        self.prunned_adj_matrix = None
        self.prunned_n_hop_matrix = None
        self.state = None
        self.max_n_neighbors = max_n_neighbors
        self.one_hot_distances = None
        self.curr_node_indicator = None
        self.neighbour_indicator = None

        self.xs = None
        self.ys = None
        self.norm_x = None
        self.norm_y = None

    def reset(self):
        try:
            traci.close()
        except:
            pass
        traci.start(sumoCmd)
        self.step_counter = 0

        while self.step_counter < 400:
            traci.simulationStep()
            self.step_counter += 1

    def node_pruning(self):
        self.prunned_adj_matrix = copy.deepcopy(self.adj_matrix)
        self.prunned_n_hop_matrix = copy.deepcopy(self.n_hop_matrix)
        neighbor_indices = np.where(self.adj_matrix[self.current_node] == 1)[0]
        if len(neighbor_indices) >= self.max_n_neighbors:
            two_hop_neighbours_indices = np.where(self.n_hop_matrix[self.current_node] == 2)[0]
            two_hop_neighbours_mask = (self.n_hop_matrix[self.current_node] == 2).type(torch.int)
            # direct neighbours connectivities with two hop neighbours
            neighbour_dict = {}
            for neighbour_index in neighbor_indices:
                neighbour_dict[neighbour_index] = two_hop_neighbours_indices[np.where(self.adj_matrix[neighbour_index][two_hop_neighbours_indices] == 1)[0]]
            # sort by the number of two hop neighbours
            neighbour_dict = dict(sorted(neighbour_dict.items(), key=lambda item: len(item[1]), reverse=True))

            self.to_remove_indices = []
            action_space = 0
            for neighbour_index, two_hop_neighbours_indices in neighbour_dict.items():
                mask_sum_before = torch.sum(two_hop_neighbours_mask)
                two_hop_neighbours_mask[two_hop_neighbours_indices] = 0
                mask_sum_after = torch.sum(two_hop_neighbours_mask)
                if mask_sum_after < mask_sum_before:
                    action_space += 1
                else:
                    self.to_remove_indices.append(neighbour_index)
            if action_space < self.max_n_neighbors:
                self.to_remove_indices = random.sample(self.to_remove_indices, len(self.to_remove_indices) - (self.max_n_neighbors - action_space))
            self.prunned_adj_matrix[self.to_remove_indices, :] = 0
            self.prunned_adj_matrix[:, self.to_remove_indices] = 0
            self.prunned_n_hop_matrix[self.to_remove_indices, :] = 0
            self.prunned_n_hop_matrix[:, self.to_remove_indices] = 0

    def step(self):
        traci.simulationStep()
        self.routing_done = False
        self.routing_steps = 0
        self.step_counter += 1
        self.adj_matrix, self.xs, self.ys = intervehicleConnectivity(800)
        self.adj_matrix = simplify_graph(self.adj_matrix)
        self.select_start_end_nodes()
        self.current_node = self.start_node
        self.adj_matrix = F.pad(self.adj_matrix, (0, self.n_nodes - self.adj_matrix.size(0), 
                                                  0, self.n_nodes - self.adj_matrix.size(1)), "constant", 0)
        self.n_hop_matrix = F.pad(self.n_hop_matrix, (0, self.n_nodes - self.n_hop_matrix.size(0), 
                                                      0, self.n_nodes - self.n_hop_matrix.size(1)), "constant", -100)
        self.node_pruning()
        # set diagonal to 0
        self.prunned_n_hop_matrix = self.prunned_n_hop_matrix - torch.diag(torch.diag(self.prunned_n_hop_matrix))
        self.curr_node_indicator = torch.zeros(self.n_nodes)
        self.curr_node_indicator[self.current_node] = 1
        self.end_node_indicator = self.prunned_adj_matrix[self.end_node]
        distances = self.prunned_n_hop_matrix[self.current_node]
        distances[distances == -100] = 8
        distances[distances > 7] = 7
        self.one_hot_distances = F.one_hot(distances.long(), num_classes=8).type(torch.float32)
        self.neighbour_indicator = self.prunned_adj_matrix[self.current_node]
        self.node_features = torch.cat((self.one_hot_distances, self.end_node_indicator.unsqueeze(1), 
                                        self.curr_node_indicator.unsqueeze(1), self.neighbour_indicator.unsqueeze(1), 
                                        self.norm_x, self.norm_y), dim=1).to(device)
        self.state = Data(x=self.node_features, edge_index=self.get_edge_index())
        return self.state
    
    def refresh(self):
        self.select_start_end_nodes()
        self.routing_done = False
        self.routing_steps = 0
        self.current_node = self.start_node
        self.node_pruning()
        self.prunned_n_hop_matrix = self.prunned_n_hop_matrix - torch.diag(torch.diag(self.prunned_n_hop_matrix))
        self.curr_node_indicator = torch.zeros(self.n_nodes)
        self.curr_node_indicator[self.current_node] = 1
        self.end_node_indicator = self.prunned_adj_matrix[self.end_node]
        self.neighbour_indicator = self.prunned_adj_matrix[self.current_node]
        self.neighbour_indicator = self.prunned_adj_matrix[self.current_node]
        self.node_features = torch.cat((self.one_hot_distances, self.end_node_indicator.unsqueeze(1),
                                            self.curr_node_indicator.unsqueeze(1), self.neighbour_indicator.unsqueeze(1), 
                                            self.norm_x, self.norm_y), dim=1).to(device)
        self.state = Data(x=self.node_features, edge_index=self.get_edge_index())
        return self.state


    def select_start_end_nodes(self):
        self.n_hop_matrix = bfs_distance(self.adj_matrix)
        self.hop_thresh = min(self.n_hop_matrix.max(), 5)
        starts, ends = torch.where(self.hop_thresh == self.n_hop_matrix)
        starts = starts.tolist()
        ends = ends.tolist()
        self.start_node, self.end_node = random.choice(list(zip(starts, ends)))
        # minimal number of hops between start and end nodes
        self.min_n_hops = self.n_hop_matrix[self.start_node, self.end_node]

        start_x = self.xs[self.start_node]
        start_y = self.ys[self.start_node]
        end_x = self.xs[self.end_node]
        end_y = self.ys[self.end_node]
        self.norm_x = (self.xs - end_x) / (start_x - end_x)
        self.norm_y = (self.ys - end_y) / (start_y - end_y)
        self.norm_x = F.pad(self.norm_x, (0, 0, 0, self.n_nodes - self.norm_x.size(0)), "constant", 0)
        self.norm_y = F.pad(self.norm_y, (0, 0, 0, self.n_nodes - self.norm_y.size(0)), "constant", 0)

    def act(self, neighbor_index):
        self.routing_steps += 1
        neighbors = torch.where(self.prunned_adj_matrix[self.current_node] == 1)[0]
        valid_action_size = len(neighbors)
        if valid_action_size <= neighbor_index:
            self.routing_done = self.routing_steps >= self.max_routing_steps
            if self.routing_done:
                print("Failed, ", self.min_n_hops)
                return self.state, torch.tensor(-1).to(device), self.routing_done
            self.node_features = torch.cat((self.one_hot_distances, self.end_node_indicator.unsqueeze(1),
                                             self.curr_node_indicator.unsqueeze(1), self.neighbour_indicator.unsqueeze(1),
                                               self.norm_x, self.norm_y), dim=1).to(device)
            return self.state, torch.tensor(-0.15).to(device), self.routing_done
        else:
            next_hop = neighbors[neighbor_index]
            reward = self.compute_reward(next_hop)
            self.current_node = next_hop
            self.node_pruning()
            self.curr_node_indicator = torch.zeros(self.n_nodes)
            self.curr_node_indicator[self.current_node] = 1
            self.neighbour_indicator = self.prunned_adj_matrix[self.current_node]
            self.node_features = torch.cat((self.one_hot_distances, self.end_node_indicator.unsqueeze(1),
                                             self.curr_node_indicator.unsqueeze(1), self.neighbour_indicator.unsqueeze(1),
                                               self.norm_x, self.norm_y), dim=1).to(device)
            self.state = Data(x=self.node_features, edge_index=self.get_edge_index())
            return self.state, torch.tensor(reward).to(device), self.routing_done

    def get_adj_matrix(self):
        return copy.deepcopy(self.adj_matrix).to(device)
    
    def get_edge_index(self):
        self.edge_index, _ = dense_to_sparse(self.prunned_adj_matrix)
        return copy.deepcopy(self.edge_index).to(device)

    def compute_reward(self, next_hop):
        if self.routing_steps >= self.max_routing_steps:
            print("Failed, ", self.min_n_hops)
            self.routing_done = True
            return -1
        elif self.adj_matrix[self.current_node, self.end_node] == 1:
            print("Routing done, number of hops: ", self.routing_steps, " minimum number of hops: ", self.min_n_hops)
            self.routing_done = True
            return (self.min_n_hops / self.routing_steps)
        elif self.n_hop_matrix[self.current_node, self.end_node] > self.n_hop_matrix[next_hop, self.end_node]:
            return 0.1
        else:
            return -0.15

    def get_action_mask(self):
        action_mask = copy.deepcopy(self.prunned_adj_matrix[self.current_node])
        action_mask = F.pad(action_mask, (0, self.n_nodes - action_mask.size(0)), "constant", 0).to(device)
        return action_mask
        
    def sim_done(self):
        """
        function: get the done state of simulation.
        """
        return not (shouldContinueSim() and self.step_counter <= self.max_steps)

In [9]:
BATCH_SIZE = 128
GAMMA = 0.95
EPS_START = 0.2
EPS_END = 0.001
EPS_DECAY = 2000
TAU = 0.05
LR = 0.001

n_nodes = 55
env = RoutingGym(sumoCmd, 1100, n_nodes, max_n_neighbors=6)
# env = DummyEnv()
max_n_neighbors = 6

policy_net = GDQN_Attention(n_nodes=n_nodes, max_n_neighbors=max_n_neighbors).to(device)
target_net = GDQN_Attention(n_nodes=n_nodes, max_n_neighbors=max_n_neighbors).to(device)
target_net.load_state_dict(policy_net.state_dict())
# policy_net.load_state_dict(torch.load("policy_net.pth"))
# target_net.load_state_dict(torch.load("policy_net.pth"))

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(1000000)


steps_done = 0


def select_action(data, action_mask, shuffle_indices):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return torch.tensor([policy_net(data, shuffle_indices).max(1).indices.item()], device=device), True
    else:
        valid_size = len(torch.where(action_mask == 1)[0])
        return torch.randint(0, valid_size, (1,), device=device), False

In [10]:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = Batch.from_data_list([s for s in batch.next_state
                                                if s is not None])
    data_batch = Batch.from_data_list(batch.data)
    action_batch = torch.stack(batch.action)
    reward_batch = torch.concat(batch.reward)
    shuffle_indices_batch = torch.stack(batch.shuffle_indices)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(data_batch, shuffle_indices_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1).values
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        # we do not need to shuffle the indices for the target network
        # firstly, we don't train the target network
        # secondly, we only need the max value anyway
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()


In [11]:
# Initialize the environment and get its state
n_epoch = 10
for e in range(n_epoch):
    env.reset()
    done = False
    step_num = 0
    while not done:
        step_num += 1
        done = env.sim_done()
        state = env.step()
        for i in range(200):
            state = env.refresh()
            accumulated_reward = 0
            routing_done = False
            while not routing_done:
                action_mask = env.get_action_mask()

                # shuffle_indices = torch.randperm(max_n_neighbors)
                shuffle_indices = torch.tensor([0, 1, 2, 3, 4, 5])
                action, use_policy = select_action(state, action_mask, shuffle_indices)

                if use_policy:
                    next_state, reward, routing_done = env.act(shuffle_indices[action.item()])
                else:
                    next_state, reward, routing_done = env.act(action.item())
                reward = torch.tensor([reward], device=device)
                accumulated_reward += reward.item()

                routing_done = routing_done

                if routing_done:
                    memory.push(state, action, None, reward, shuffle_indices)
                else:
                    memory.push(state, action, next_state, reward, shuffle_indices)

                # Move to the next state
                state = next_state

                optimize_model()
                if steps_done % 20 == 0:
                    target_net_state_dict = target_net.state_dict()
                    policy_net_state_dict = policy_net.state_dict()
                    for key in policy_net_state_dict:
                        target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
                    target_net.load_state_dict(target_net_state_dict)
            print(f"Step: {step_num}, Iteration: {i}, Accumulated reward: {accumulated_reward}")

print('Complete')
plt.ioff()
plt.show()

 Retrying in 1 seconds
***Starting server on port 46029 ***
Loading net-file from './config/osm.net.xml.gz' ... done (104ms).
Loading done.
Simulation version 1.20.0 started with time: 0.00.
Failed,  tensor(5.)
Step: 1, Iteration: 0, Accumulated reward: -2.100000075995922
Failed,  tensor(5.)
Step: 1, Iteration: 1, Accumulated reward: -4.850000157952309
Failed,  tensor(5.)
Step: 1, Iteration: 2, Accumulated reward: -4.3500001430511475
Failed,  tensor(5.)
Step: 1, Iteration: 3, Accumulated reward: -2.100000075995922
Failed,  tensor(5.)
Step: 1, Iteration: 4, Accumulated reward: -2.100000075995922
Routing done, number of hops:  5  minimum number of hops:  tensor(5.)
Step: 1, Iteration: 5, Accumulated reward: 1.4000000059604645


  return self.state, torch.tensor(reward).to(device), self.routing_done


Routing done, number of hops:  29  minimum number of hops:  tensor(5.)
Step: 1, Iteration: 6, Accumulated reward: -0.5275862663984299
Routing done, number of hops:  11  minimum number of hops:  tensor(5.)
Step: 1, Iteration: 7, Accumulated reward: 0.7045454606413841
Routing done, number of hops:  15  minimum number of hops:  tensor(5.)
Step: 1, Iteration: 8, Accumulated reward: 0.4833333268761635
Routing done, number of hops:  7  minimum number of hops:  tensor(5.)
Step: 1, Iteration: 9, Accumulated reward: 1.064285732805729
Failed,  tensor(5.)
Step: 1, Iteration: 10, Accumulated reward: -1.600000061094761
Failed,  tensor(5.)
Step: 1, Iteration: 11, Accumulated reward: -2.6000000908970833
Failed,  tensor(5.)
Step: 1, Iteration: 12, Accumulated reward: -4.600000150501728
Failed,  tensor(5.)
Step: 1, Iteration: 13, Accumulated reward: -3.1000001057982445
Routing done, number of hops:  16  minimum number of hops:  tensor(5.)
Step: 1, Iteration: 14, Accumulated reward: -0.18750003725290298

KeyboardInterrupt: 

In [None]:
# # save policy net and target net
# torch.save(policy_net.state_dict(), 'policy_net.pth')
# torch.save(target_net.state_dict(), 'target_net.pth')

In [None]:
# # Initialize the environment and get its state
# n_epoch = 10
# for e in range(n_epoch):
#     # env.reset()
#     done = False
#     episode_num = 0
#     state = env.step()
#     while not done:
#         # done = env.sim_done()
#         # done = False
#         state = env.refresh()
#         episode_num += 1
#         routing_done = False
#         accumulated_reward = 0
#         while not routing_done:
#             action_mask = env.get_action_mask()

#             # shuffle_indices = torch.randperm(max_n_neighbors)
#             shuffle_indices = torch.tensor([0, 1, 2, 3, 4, 5])
#             action, use_policy = select_action(state, action_mask, shuffle_indices)

#             if use_policy:
#                 next_state, reward, routing_done = env.act(shuffle_indices[action.item()])
#             else:
#                 next_state, reward, routing_done = env.act(action.item())
#             reward = torch.tensor([reward], device=device)
#             accumulated_reward += reward.item()

#             routing_done = routing_done

#             if routing_done:
#                 memory.push(state, action, None, reward, shuffle_indices)
#             else:
#                 memory.push(state, action, next_state, reward, shuffle_indices)

#             # Move to the next state
#             state = next_state

#             optimize_model()
#             if steps_done % 20 == 0:
#                 target_net_state_dict = target_net.state_dict()
#                 policy_net_state_dict = policy_net.state_dict()
#                 for key in policy_net_state_dict:
#                     target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
#                 target_net.load_state_dict(target_net_state_dict)
#         print(f"Episode: {episode_num}, Accumulated reward: {accumulated_reward}")

# print('Complete')
# plt.ioff()
# plt.show()


In [28]:
# class DummyEnv(gym.Env):
#     def __init__(self, max_n_neighbors=2):
#         super(DummyEnv, self).__init__()
#         self.adj_matrix = torch.tensor([[1, 1, 0, 0, 1, 1, 0, 0],
#                                         [1, 1, 1, 0, 1, 1, 0, 0],
#                                         [0, 1, 1, 1, 0, 0, 1, 1],
#                                         [0, 0, 1, 1, 0, 0, 0, 0],
#                                         [1, 1, 0, 0, 1, 0, 1, 0],
#                                         [1, 1, 0, 0, 0, 1, 0, 0],
#                                         [0, 0, 1, 0, 1, 0, 1, 0],
#                                         [0, 0, 1, 0, 0, 0, 0, 1]])
#         self.adj_matrix = self.adj_matrix - torch.eye(self.adj_matrix.size(0))
#         self.adj_matrix = F.pad(self.adj_matrix, (0, 55 - self.adj_matrix.size(0), 0, 55 - self.adj_matrix.size(1)), "constant", 0)
#         self.n_hop_matrix = bfs_distance(self.adj_matrix)
#         self.n_hop_matrix = self.n_hop_matrix - torch.diag(torch.diag(self.n_hop_matrix))
#         self.prunned_adj_matrix = None
#         self.prunned_n_hop_matrix = None
#         self.num_nodes = self.adj_matrix.size(0)
#         self.start_node = 0
#         self.end_node = 3
#         self.current_node = self.start_node
#         self.edge_index = dense_to_sparse(self.adj_matrix)[0]
#         self.routing_steps = 0
#         self.end_node_indicator = torch.zeros(self.num_nodes)
#         self.end_node_indicator[self.end_node] = 1
#         self.routing_done = False
#         self.max_routing_steps = 30
#         self.min_n_hops = 3
#         self.n_nodes = 55
#         self.to_remove_indices = []
#         self.state = None
#         self.max_n_neighbors = max_n_neighbors

#     def reset(self):
#         pass

#     def node_pruning(self):
#         self.prunned_adj_matrix = copy.deepcopy(self.adj_matrix)
#         self.prunned_n_hop_matrix = copy.deepcopy(self.n_hop_matrix)
#         neighbor_indices = np.where(self.adj_matrix[self.current_node] == 1)[0]
#         if len(neighbor_indices) >= self.max_n_neighbors:
#             two_hop_neighbours_indices = np.where(self.n_hop_matrix[self.current_node] == 2)[0]
#             two_hop_neighbours_mask = (self.n_hop_matrix[self.current_node] == 2).type(torch.int)
#             # direct neighbours connectivities with two hop neighbours
#             neighbour_dict = {}
#             for neighbour_index in neighbor_indices:
#                 neighbour_dict[neighbour_index] = two_hop_neighbours_indices[np.where(self.adj_matrix[neighbour_index][two_hop_neighbours_indices] == 1)[0]]
#             # sort by the number of two hop neighbours
#             neighbour_dict = dict(sorted(neighbour_dict.items(), key=lambda item: len(item[1]), reverse=True))

#             self.to_remove_indices = []
#             action_space = 0
#             for neighbour_index, two_hop_neighbours_indices in neighbour_dict.items():
#                 mask_sum_before = torch.sum(two_hop_neighbours_mask)
#                 two_hop_neighbours_mask[two_hop_neighbours_indices] = 0
#                 mask_sum_after = torch.sum(two_hop_neighbours_mask)
#                 if mask_sum_after < mask_sum_before:
#                     action_space += 1
#                 else:
#                     self.to_remove_indices.append(neighbour_index)
#             if action_space < self.max_n_neighbors:
#                 self.to_remove_indices = random.sample(self.to_remove_indices, len(self.to_remove_indices) - (self.max_n_neighbors - action_space))
#             self.prunned_adj_matrix[self.to_remove_indices, :] = 0
#             self.prunned_adj_matrix[:, self.to_remove_indices] = 0
#             self.prunned_n_hop_matrix[self.to_remove_indices, :] = 0
#             self.prunned_n_hop_matrix[:, self.to_remove_indices] = 0

#     def get_edge_index(self):
#         self.edge_index, _ = dense_to_sparse(self.prunned_adj_matrix)
#         return copy.deepcopy(self.edge_index).to(device)
    
#     def step(self):
#         self.routing_done = False
#         self.select_start_end_nodes()
#         self.current_node = self.start_node
#         self.routing_steps = 0
#         self.node_pruning()
#         self.current_node_indicator = torch.zeros(self.num_nodes)
#         self.current_node_indicator[self.current_node] = 1
#         self.neighbors_indicator = self.prunned_adj_matrix[self.current_node]
#         self.node_features = torch.stack((self.current_node_indicator, self.prunned_n_hop_matrix[self.current_node], 
#                                           self.end_node_indicator, self.neighbors_indicator)).T.to(device)
#         self.state = Data(x=self.node_features, edge_index=self.get_edge_index())
#         return self.state
    
#     def select_start_end_nodes(self):
#         self.hop_thresh = min(self.n_hop_matrix.max(), 2)
#         starts, ends = torch.where(self.hop_thresh <= self.n_hop_matrix)
#         starts = starts.tolist()
#         ends = ends.tolist()
#         self.start_node, self.end_node = random.choice(list(zip(starts, ends)))
#         self.end_node_indicator = torch.zeros(self.num_nodes)
#         self.end_node_indicator[self.end_node] = 1
#         # minimal number of hops between start and end nodes
#         self.min_n_hops = self.n_hop_matrix[self.start_node, self.end_node]
    
#     def act(self, neighbor_index):
#         self.routing_steps += 1
#         neighbors = torch.where(self.prunned_adj_matrix[self.current_node] == 1)[0]
#         valid_action_size = len(neighbors)
#         if valid_action_size <= neighbor_index:
#             return self.state, torch.tensor(0).to(device), False
#         else:
#             next_hop = neighbors[neighbor_index]
#             reward = self.compute_reward(next_hop)
#             self.current_node = next_hop
#             self.node_pruning()
#             curr_node_indicators = torch.zeros(self.n_nodes)
#             curr_node_indicators[self.current_node] = 1
#             self.neighbors_indicator = self.prunned_adj_matrix[self.current_node]
#             self.node_features = torch.stack((curr_node_indicators, 
#                                               self.prunned_n_hop_matrix[self.current_node], self.end_node_indicator, self.neighbors_indicator)).T.to(device)
#             self.state = Data(x=self.node_features, edge_index=self.get_edge_index())
#             return self.state, torch.tensor(reward).to(device), self.routing_done
    
#     def compute_reward(self, next_hop):
#         if self.routing_steps >= self.max_routing_steps:
#             print("Failed, ", self.start_node, self.end_node)
#             self.routing_done = True
#             return -1
#         elif self.adj_matrix[self.current_node, self.end_node] == 1:
#             print("Routing done, number of hops: ", self.routing_steps, " minimum number of hops: ", self.min_n_hops)
#             self.routing_done = True
#             return (self.min_n_hops / self.routing_steps) + 1
#         elif self.n_hop_matrix[self.current_node, self.end_node] > self.n_hop_matrix[next_hop, self.end_node]:
#             return 0.1
#         else:
#             return 0

#     def get_action_mask(self):
#         action_mask = copy.deepcopy(self.prunned_adj_matrix[self.current_node])
#         action_mask = F.pad(action_mask, (0, self.n_nodes - action_mask.size(0)), "constant", 0).to(device)
#         return action_mask

In [None]:


# class SumTree:
#     def __init__(self, capacity):
#         self.capacity = capacity
#         self.tree = np.zeros(2 * capacity - 1)
#         self.data = np.zeros(capacity, dtype=object)
#         self.data_pointer = 0

#     def add(self, priority, data):
#         tree_index = self.data_pointer + self.capacity - 1
#         self.data[self.data_pointer] = data
#         self.update(tree_index, priority)
#         self.data_pointer += 1
#         if self.data_pointer >= self.capacity:
#             self.data_pointer = 0

#     def update(self, tree_index, priority):
#         change = priority - self.tree[tree_index]
#         self.tree[tree_index] = priority
#         while tree_index != 0:
#             tree_index = (tree_index - 1) // 2
#             self.tree[tree_index] += change

#     def get_leaf(self, v):
#         parent_index = 0
#         while True:
#             left_child_index = 2 * parent_index + 1
#             right_child_index = left_child_index + 1
#             if left_child_index >= len(self.tree):
#                 leaf_index = parent_index
#                 break
#             else:
#                 if v <= self.tree[left_child_index]:
#                     parent_index = left_child_index
#                 else:
#                     v -= self.tree[left_child_index]
#                     parent_index = right_child_index
#         data_index = leaf_index - self.capacity + 1
#         return leaf_index, self.tree[leaf_index], self.data[data_index]

#     @property
#     def total_priority(self):
#         return self.tree[0]


# class PrioritizedReplayBuffer:
#     def __init__(self, capacity, alpha):
#         self.tree = SumTree(capacity)
#         self.alpha = alpha
        
#     def push(self, *args):
#         max_priority = np.max(self.tree.tree[-self.tree.capacity:])
#         if max_priority == 0:
#             max_priority = 1.0
#         experience = Transition(*args)
#         self.tree.add(max_priority, experience)

#     def sample(self, batch_size, beta):
#         batch = []
#         indices = []
#         priorities = []
#         segment = self.tree.total_priority / batch_size

#         for i in range(batch_size):
#             a = segment * i
#             b = segment * (i + 1)
#             s = random.uniform(a, b)
#             index, priority, data = self.tree.get_leaf(s)
#             batch.append(data)
#             indices.append(index)
#             priorities.append(priority)

#         sampling_probabilities = priorities / self.tree.total_priority
#         is_weights = np.power(self.tree.capacity * sampling_probabilities, -beta)
#         is_weights /= is_weights.max()
#         states = np.array([e.state for e in batch])
#         actions = np.array([e.action for e in batch])
#         rewards = np.array([e.reward for e in batch])
#         next_states = np.array([e.next_state for e in batch])
#         dones = np.array([e.done for e in batch])

#         return (states, actions, rewards, next_states, dones), indices, is_weights

#     def update_priorities(self, indices, errors):
#         priorities = np.abs(errors) + 1e-5
#         for index, priority in zip(indices, priorities):
#             self.tree.update(index, priority)

#     def __len__(self):
#         return len(self.tree.data)

In [None]:
# # Initialize the environment and get its state
# n_epoch = 10
# for e in range(n_epoch):
#     env.reset()
#     done = False
#     step_num = 0
#     while not done:
#         step_num += 1
#         done = env.sim_done()
#         state = env.step()
#         for i in range(200):
#             state = env.refresh()
#             accumulated_reward = 0
#             routing_done = False
#             while not routing_done:
#                 action_mask = env.get_action_mask()

#                 # shuffle_indices = torch.randperm(max_n_neighbors)
#                 shuffle_indices = torch.tensor([0, 1, 2, 3, 4, 5])
#                 action, use_policy = select_action(state, action_mask, shuffle_indices)

#                 if use_policy:
#                     next_state, reward, routing_done = env.act(shuffle_indices[action.item()])
#                 else:
#                     next_state, reward, routing_done = env.act(action.item())
#                 reward = torch.tensor([reward], device=device)
#                 accumulated_reward += reward.item()

#                 routing_done = routing_done

#                 if routing_done:
#                     memory.push(state, action, None, reward, shuffle_indices)
#                 else:
#                     memory.push(state, action, next_state, reward, shuffle_indices)

#                 # Move to the next state
#                 state = next_state

#                 # optimize_model()
#                 # if steps_done % 20 == 0:
#                 #     target_net_state_dict = target_net.state_dict()
#                 #     policy_net_state_dict = policy_net.state_dict()
#                 #     for key in policy_net_state_dict:
#                 #         target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
#                 #     target_net.load_state_dict(target_net_state_dict)
#             print(f"Step: {step_num}, Iteration: {i}, Accumulated reward: {accumulated_reward}")

# print('Complete')
# plt.ioff()
# plt.show()