In [114]:
import sys
import random
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import sumolib
import traci
from sumolib import checkBinary
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, Batch
import sys
import io
from contextlib import redirect_stdout
import matplotlib.pyplot as plt
import pandas as pd
import os
import math
from collections import namedtuple, deque
import gym
from torch_geometric.utils import dense_to_sparse
import copy
from itertools import count

if 'SUMO_HOME' in os.environ:
    print('SUMO_HOME found')
    sys.path.append(os.path.join(os.environ['SUMO_HOME'], 'tools'))

# sumoBinary = checkBinary('sumo-gui')
sumoBinary = checkBinary('sumo')
roadNetwork = "./config/osm.sumocfg"
sumoCmd = [sumoBinary, "-c", roadNetwork, "--start", "--quit-on-end"]
# use gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
print("Using device: " + str(device))

SUMO_HOME found
Using device: cuda


In [115]:
def intervehicleConnectivity(threshold = None):
    xs = []
    ys = []
    for vehicle in traci.vehicle.getIDList():
        x, y = traci.vehicle.getPosition(vehicle)
        xs.append(x)
        ys.append(y)
    xs = torch.tensor(xs, dtype=torch.float32).view(-1,1)
    ys = torch.tensor(ys, dtype=torch.float32).view(-1,1)
    intervehicle_distances = torch.sqrt((xs - xs.t())**2 + (ys - ys.t())**2)
    if threshold is not None:
        # make the distances 1 if less than the threshold, 0 otherwise
        connectivity = torch.where(intervehicle_distances < threshold, torch.ones_like(intervehicle_distances), torch.zeros_like(intervehicle_distances))
    return connectivity

In [116]:
def randomTrips(dur=1000, density=12):
    os.system("python $SUMO_HOME/tools/randomTrips.py -n config/osm.net.xml.gz -r config/osm.passenger.trips.xml -e " + str(dur) + " -l --insertion-density=" + str(density))

def shouldContinueSim():
    numVehicles = traci.simulation.getMinExpectedNumber()
    return True if numVehicles > 0 else False

def restart(sumoCmd):
    with io.StringIO() as buf, redirect_stdout(buf):
        try:
            traci.close()
        except:
            pass
        traci.start(sumoCmd)

def close():
    traci.close()

randomTrips(800, 1.5)

Success.


In [117]:
def bfs_distance(adj_matrix):
    n_hop_matrix = torch.zeros_like(adj_matrix)
    for start_node in range(adj_matrix.size(0)):
        visited = [0] * adj_matrix.size(0)
        queue = deque([(start_node, 0)])
        visited[start_node] = True
        
        while queue:
            current_node, current_dist = queue.popleft()
            
            for neighbor, connected in enumerate(adj_matrix[current_node]):
                if connected and not visited[neighbor]:
                    queue.append((neighbor, current_dist + 1))
                    visited[neighbor] = True
                    n_hop_matrix[start_node, neighbor] = current_dist + 1
    
    return n_hop_matrix

In [158]:
class RoutingGym(gym.Env):
    def __init__(self, sumoCmd, max_steps=1100, n_nodes=57, max_routing_steps=30):
        self.sumoCmd = sumoCmd
        self.step_counter = 0
        self.max_steps = max_steps
        self.n_nodes = n_nodes
        self.vehicle_ids = None
        self.start_node = None
        self.end_node = None
        self.current_node = None
        self.node_features = None
        self.adj_matrix = None
        self.edge_index = None
        self.hop_thresh = None
        self.routing_done = False
        self.routing_steps = 0
        self.min_n_hops = None
        self.end_node_indicator = torch.zeros(n_nodes)
        self.max_routing_steps = max_routing_steps
        self.n_hop_matrix = None
        self.neighbors_indicator = None
    
    def render(self):
        self.show_gui = True

    def reset(self):
        try:
            traci.close()
        except:
            pass
        traci.start(sumoCmd)
        self.step_counter = 0

        while self.step_counter < 400:
            traci.simulationStep()
            self.step_counter += 1

    def step(self):
        traci.simulationStep()
        self.end_node_indicator = torch.zeros(self.n_nodes)
        self.routing_done = False
        self.routing_steps = 0
        self.step_counter += 1
        self.vehicle_ids = traci.vehicle.getIDList()
        self.adj_matrix = intervehicleConnectivity(800)
        # we subtract the diagonal to avoid self loops
        self.adj_matrix = self.adj_matrix - torch.eye(self.adj_matrix.size(0))
        self.select_start_end_nodes()
        self.adj_matrix = F.pad(self.adj_matrix, (0, self.n_nodes - self.adj_matrix.size(0), 0, self.n_nodes - self.adj_matrix.size(1)), "constant", 0)
        self.n_hop_matrix = F.pad(self.n_hop_matrix, (0, self.n_nodes - self.n_hop_matrix.size(0), 0, self.n_nodes - self.n_hop_matrix.size(1)), "constant", 100)
        self.edge_index, _ = dense_to_sparse(self.adj_matrix)
        current_node_indicators = torch.zeros(self.n_nodes)
        current_node_indicators[self.current_node] = 1
        self.end_node_indicator[self.end_node] = 1
        self.neighbors_indicator = self.adj_matrix[self.current_node]
        self.node_features = torch.stack((current_node_indicators, 
                                          self.n_hop_matrix[self.current_node], self.end_node_indicator, self.neighbors_indicator)).T
        
        return self.node_features.to(device)

    def select_start_end_nodes(self):
        self.n_hop_matrix = bfs_distance(self.adj_matrix)
        self.hop_thresh = min(self.n_hop_matrix.max(), 1)
        starts, ends = torch.where(self.n_hop_matrix >= self.hop_thresh)
        starts = starts.tolist()
        ends = ends.tolist()
        self.start_node, self.end_node = random.choice(list(zip(starts, ends)))
        # minimal number of hops between start and end nodes
        self.min_n_hops = self.n_hop_matrix[self.start_node, self.end_node]
        self.current_node = self.start_node

    def act(self, neighbor_index):
        self.routing_steps += 1
        neighbors = torch.where(self.adj_matrix[self.current_node] == 1)[0]
        valid_action_size = len(neighbors)
        if valid_action_size <= neighbor_index:
            if self.node_features.device != device:
                self.node_features = self.node_features.to(device)
            return self.node_features, torch.tensor(0).to(device), False
        # check if the next hop is reachable
        else:
            next_hop = neighbors[neighbor_index]
            reward = self.compute_reward(next_hop)
            self.current_node = next_hop
            curr_node_indicators = torch.zeros(self.n_nodes)
            curr_node_indicators[self.current_node] = 1
            self.neighbors_indicator = self.adj_matrix[self.current_node]
            self.node_features = torch.stack((curr_node_indicators, 
                                              self.n_hop_matrix[self.current_node], self.end_node_indicator, self.neighbors_indicator)).T
            return self.node_features.to(device), torch.tensor(reward).to(device), self.routing_done

    
    def get_action_mask(self):
        action_mask = copy.deepcopy(self.adj_matrix[self.current_node])
        action_mask = F.pad(action_mask, (0, self.n_nodes - action_mask.size(0)), "constant", 0).to(device)
        return action_mask

    def get_adj_matrix(self):
        return copy.deepcopy(self.adj_matrix).to(device)
    
    def get_edge_index(self):
        return copy.deepcopy(self.edge_index).to(device)
        
    def compute_reward(self, next_hop):
        if self.routing_steps >= self.max_routing_steps:
            print("Failed, ", self.min_n_hops)
            self.routing_done = True
            return -2
        elif next_hop == self.end_node:
            print("Routing done, number of hops: ", self.routing_steps, " minimum number of hops: ", self.min_n_hops)
            self.routing_done = True
            return (self.min_n_hops / self.routing_steps) * 5 + 1
        elif self.n_hop_matrix[self.current_node, self.end_node] > self.n_hop_matrix[next_hop, self.end_node]:
            return 0.1
        else:
            return 0
        
    def sim_done(self):
        """
        function: get the done state of simulation.
        """
        return not (shouldContinueSim() and self.step_counter <= self.max_steps)

In [159]:
class GDQN(nn.Module):
    def __init__(self, in_channels=4, n_nodes=57, hidden_dim=64, dropout=0.1, max_n_neighbors=15):
        super(GDQN, self).__init__()
        self.n_nodes = n_nodes
        self.hidden_dim = hidden_dim
        self.convs1 = GCNConv(in_channels, hidden_dim)
        self.convs2 = GCNConv(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_dim * n_nodes, n_nodes)
        self.fc2 = nn.Linear(n_nodes, max_n_neighbors)
        self.softmax = nn.Softmax(dim=1)
        self.selu = nn.SELU()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.convs1(x, edge_index)
        x = self.selu(x)
        x = self.convs2(x, edge_index)
        x = self.selu(x)
        x = x.view(-1, self.n_nodes * self.hidden_dim)
        x = self.fc1(x)
        x = self.selu(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x



Transition = namedtuple('Transition',
                        ('data', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


In [160]:
BATCH_SIZE = 128
GAMMA = 0.98
EPS_START = 0.9
EPS_END = 0.00
EPS_DECAY = 10000
TAU = 0.005
LR = 0.01


# Get number of actions from gym action space
n_nodes = 57
env = RoutingGym(sumoCmd, 1100, n_nodes)
# Get the number of state observations

policy_net = GDQN().to(device)
target_net = GDQN().to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(1000)


steps_done = 0


def select_action(data, action_mask):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(data).max(1).indices.view(-1)
    else:
        valid_size = len(torch.where(action_mask == 1)[0])
        return torch.randint(0, valid_size, (1,), device=device)

In [161]:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = Batch.from_data_list([s for s in batch.next_state
                                                if s is not None])
    data_batch = Batch.from_data_list(batch.data)
    action_batch = torch.stack(batch.action)
    reward_batch = torch.concat(batch.reward)
  

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(data_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1).values
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()


In [162]:
# Initialize the environment and get its state
n_epoch = 10
for e in range(n_epoch):
    env.reset()
    done = False
    episode_num = 0
    while not done:
        done = env.sim_done()
        state = env.step()
        episode_num += 1
        edge_index = env.get_edge_index()
        routing_done = False
        accumulated_reward = 0
        while not routing_done:
            action_mask = env.get_action_mask()
            data = Data(x=state, edge_index=edge_index)
            action = select_action(data, action_mask)
            node_features, reward, routing_done = env.act(action.item())
            reward = torch.tensor([reward], device=device)
            accumulated_reward += reward.item()

            routing_done = routing_done

            if routing_done:
                next_state = None
                memory.push(data, action, None, reward)
            else:
                next_state = node_features
                memory.push(data, action, Data(x=next_state, edge_index=edge_index), reward)

            # Move to the next state
            state = next_state

            # Perform one step of the optimization (on the policy network)
            optimize_model()

            # Soft update of the target network's weights
            # θ′ ← τ θ + (1 −τ )θ′
            target_net_state_dict = target_net.state_dict()
            policy_net_state_dict = policy_net.state_dict()
            for key in policy_net_state_dict:
                target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
            target_net.load_state_dict(target_net_state_dict)
        print(f"Episode: {episode_num}, Accumulated reward: {accumulated_reward}")

print('Complete')
plt.ioff()
plt.show()


Simulation ended at time: 546.00
Reason: TraCI requested termination.
Performance: 
 Duration: 154.10s
 TraCI-Duration: 71.31s
 Real time factor: 3.54318
 UPS: 95.931836
Vehicles: 
 Inserted: 60 (Loaded: 70)
 Running: 46
 Waiting: 0
Statistics (avg of 14):
 RouteLength: 2551.36
 Speed: 9.78
 Duration: 269.00
 WaitingTime: 9.86
 TimeLoss: 37.25
 DepartDelay: 0.40

 Retrying in 1 seconds
***Starting server on port 54061 ***
Loading net-file from './config/osm.net.xml.gz' ... done (107ms).
Loading done.
Simulation version 1.20.0 started with time: 0.00.
Routing done, number of hops:  5  minimum number of hops:  tensor(1.)
Episode: 1, Accumulated reward: 2.100000001490116


  return self.node_features.to(device), torch.tensor(reward).to(device), self.routing_done


Routing done, number of hops:  3  minimum number of hops:  tensor(1.)
Episode: 2, Accumulated reward: 2.7666667476296425
Routing done, number of hops:  11  minimum number of hops:  tensor(3.)
Episode: 3, Accumulated reward: 2.963636502623558
Routing done, number of hops:  6  minimum number of hops:  tensor(2.)
Episode: 4, Accumulated reward: 2.8666667491197586
Routing done, number of hops:  4  minimum number of hops:  tensor(2.)
Episode: 5, Accumulated reward: 3.7000000029802322
Failed
Episode: 6, Accumulated reward: 2.100000061094761
Routing done, number of hops:  14  minimum number of hops:  tensor(1.)
Episode: 7, Accumulated reward: 1.7571429312229156
Routing done, number of hops:  13  minimum number of hops:  tensor(3.)
Episode: 8, Accumulated reward: 2.6538462713360786
Routing done, number of hops:  8  minimum number of hops:  tensor(1.)
Episode: 9, Accumulated reward: 1.7250000014901161
Routing done, number of hops:  1  minimum number of hops:  tensor(1.)
Episode: 10, Accumulated

KeyboardInterrupt: 