In [57]:
import networkx as nx
import osmnx as ox
import os.path as osp
import glob
import numpy as np
import random
from collections import deque
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.nn as gnn
from torch_geometric.utils.convert import from_networkx
from torch_geometric.data import Data, Dataset, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv

In [64]:
%load_ext autoreload
%autoreload 2
from trans_infra.trans_infra.simulator import TransInfraNetworkModel

import logging
logging.basicConfig(level=logging.ERROR)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
def load_network(graph_file) -> nx.Graph:
    """loads in modified OSM graph"""
    G_trans = ox.load_graphml(
                graph_file,
                node_dtypes={'index':int, 'x':float, 'y':float, 'general0':float, 'general1':float, 
                            'general2':float, 'general3':float, 'general4':float},    
                edge_dtypes={'u':int, 'v':int, 'speed':float, 'capacity':float, 'general0':float,
                            'general1':float, 'general2':float, 'general3':float})
    G_trans = G_trans.to_undirected()                                   # make undirected
    return G_trans

In [84]:

# * set num GNN layers to be slightly larger than necessary receptive field for problem
# * add MLP at end of GNN layers for post-processing

In [65]:
class EdgeConditionedConvolution(nn.Module):
    """GNN block: updates node and edge embeddings"""
    def __init__(self, node_dim, edge_dim, hidden_dim, out_dim, dropout_p=0.2):
        super(EdgeConditionedConvolution, self).__init__()
        self.node_lin = nn.Linear(node_dim, hidden_dim)
        self.edge_lin = nn.Linear(edge_dim, hidden_dim)
        self.conv = gnn.NNConv(hidden_dim, out_dim, 
                               nn=nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                                nn.BatchNorm1d(hidden_dim),
                                                nn.Dropout(p=dropout_p),
                                                nn.ReLU(),
                                                nn.Linear(hidden_dim, hidden_dim*out_dim),
                                                nn.BatchNorm1d(hidden_dim*out_dim),
                                                nn.Dropout(p=dropout_p),
                                                nn.ReLU(),
                                                ))

    def forward(self, x, edge_index, edge_attr, batch) -> torch.Tensor:
        x = self.node_lin(x)
        edge_attr = self.edge_lin(edge_attr)
        x = self.conv(x, edge_index, edge_attr)
        return x

class EdgeRegressionModel(nn.Module):
    def __init__(self, node_dim, edge_dim, hidden_dim, out_dim):
        super(EdgeRegressionModel, self).__init__()
        self.conv1 = EdgeConditionedConvolution(node_dim, edge_dim, hidden_dim, hidden_dim)
        self.conv2 = EdgeConditionedConvolution(hidden_dim, edge_dim, hidden_dim, hidden_dim)
        self.conv3 = EdgeConditionedConvolution(hidden_dim, edge_dim, hidden_dim, hidden_dim)
        self.regressor = nn.Linear(2*hidden_dim, out_dim)
        self.pool = gnn.global_mean_pool

    def forward(self, x, edge_index, edge_attr, batch) -> tuple[torch.Tensor, torch.Tensor]:
        x = self.conv1(x, edge_index, edge_attr, batch)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_attr, batch)
        x = F.relu(x)
        x = self.conv3(x, edge_index, edge_attr, batch)
        # edge predictions
        x_i = x[edge_index[0]]
        x_j = x[edge_index[1]]
        x_edge = torch.cat([x_i, x_j], dim=1)
        edge_scores = self.regressor(x_edge)
        # graph predictions
        graph_embed = self.pool(x, batch)
        
        return edge_scores, graph_embed

In [85]:

class DQN(nn.Module):
    """f(s) -> q : takes concat graph embeddings of states and predicts their q-vals"""
    def __init__(self, graph_dim, hidden_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(graph_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class DQNAgent:
    # ? Add episode_length and budget as extra hyperparameters?
    # ? precalculate betweenness?
    """planning agent"""
    def __init__(self, node_dim, edge_dim, erm_hidden_dim,
                 graph_dim, hidden_dim, num_actions, learning_rate, gamma, 
                 epsilon, epsilon_decay, epsilon_min, batch_size, memory_size,
                 pop_size=300, episode_len=24):
        # init GNN, Q NN, T NN, opt, loss, and memory buffer
        self.graph_dim = graph_dim
        self.num_actions = num_actions
        self.edge_model = EdgeRegressionModel(node_dim, edge_dim, 
                                              erm_hidden_dim, self.num_actions)
        self.q_net = DQN(graph_dim, hidden_dim)
        self.t_net = DQN(graph_dim, hidden_dim)
        self.optimizer = optim.Adam(list(self.edge_model.parameters()) + list(self.q_net.parameters()), 
                                    lr=learning_rate)
        self.criterion = nn.MSELoss()
        self.memory = deque(maxlen=memory_size)
        # learning hyperparams
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.batch_size = batch_size
        # init sim
        self.idx2gen_edges = {0: 'car', 1: 'bus', 2: 'pedbike', 3: 'other'}         # TODO: make arg
        self.pop_size = pop_size
        self.episode_len = episode_len
        self.path = "./osm_dataset/raw/copenhagen.osm"
        self.load_sim()
    
    def load_sim(self):
        self.sim = TransInfraNetworkModel(self.pop_size, self.episode_len, self.path)
        self.data2nx = dict(zip(range(self.sim.G_trans.number_of_nodes()),
                                self.sim.G_trans.nodes()))
         
    def sample_env(self, data, isNew=False):
        """get (s, a, r, s_+1) observation and save to replay"""
        print(f"isNew: {isNew}")
        # refresh sim if new
        if isNew:
            print("new env")
            self.path = data.path[0]
            self.load_sim()
        # run sim and get reward in curr state
        print("running sim")
        for _ in range(self.episode_len):
            self.sim.step()
        curr_score = self.sim.datacollector.get_model_vars_dataframe().iloc[-1][0]
        
        # get state from graph_embedding
        edge_scores, curr_state = self.edge_model(data.x, data.edge_index, 
                                                  data.edge_attr, data.batch)
        
        # get actions from edge_scores
        edge_idxs, action_idxs = self.actions_from_edge_scores(edge_scores)
        
        # modify edge attrs by actions
        graph_embeds = None
        new_data_list = []
        for i in range(len(action_idxs)):
            new_data = self.take_actions(data, edge_idxs[i], action_idxs[i])
            new_data_list.append(new_data)
            # get graph embeddings of modified graphs
            _, graph_embed = self.edge_model(new_data.x, new_data.edge_index, 
                                             new_data.edge_attr, new_data.batch)
            if graph_embeds is None:    
                graph_embeds = graph_embed
            else:                     
                graph_embeds = torch.cat((graph_embeds, graph_embed), dim=0)
        
        # pass concat graph embeds to Q NN and get Q vals
        q_vals = self.q_net(graph_embeds)
        # choose actions corresponding with max Q val
        max_q_idx = torch.max(q_vals, dim=0).indices
        max_edge_idx = edge_idxs[max_q_idx]
        max_q_action_vec = edge_scores[max_edge_idx]
        max_q_action_idx = torch.argmax(max_q_action_vec).item()
        next_state = new_data_list[max_q_idx]
        
        # modify networkx graph
        u, v = data.edge_index[0][max_edge_idx].item(), data.edge_index[1][max_edge_idx].item()
        u_p, v_p = self.data2nx[u], self.data2nx[v]
        
        self.sim.G_trans.edges[u_p, v_p, 0]["general"] = self.idx2gen_edges[max_q_action_idx]
        self.sim.reset()
        print("reset and running sim again")
        # run sim and get reward in curr state
        for _ in range(self.episode_len):
            self.sim.step()
        next_score = self.sim.datacollector.get_model_vars_dataframe().iloc[-1][0]
        # print(f"next_score: {next_score}")
        reward = next_score - curr_score
        # print(f"reward: {reward}")
        # ! need to pass action in a smarter way, incorporate the index of the edge?
        # ? concat edge embedding and action
        self.store_transition(data, max_q_action_idx, reward, next_state, False)     # ? action saved is idx

    def actions_from_edge_scores(self, edge_scores) -> tuple[torch.Tensor, torch.Tensor]:
        """return best actions given edge_scores"""
        # if explore randomly choose k edges
        if np.random.rand() <= self.epsilon:
            edge_idxs = np.random.choice(edge_scores.shape[0], 
                                       self.num_actions, replace=False)
            action_vecs = edge_scores[edge_idxs]
            actions = torch.max(action_vecs, 1)
            return edge_idxs, actions.indices
        # if exploit choose k edges with highest max element
        max_elements = torch.max(edge_scores, axis=1)
        edge_idxs = torch.argsort(max_elements.values)[-self.num_actions:]
        action_vecs = edge_scores[edge_idxs]
        actions = torch.max(action_vecs, 1)
        
        return edge_idxs, actions.indices

    def take_actions(self, data, edge_idx, action_idx):
        """modify data object to reflect action"""
        # TODO: make indexing not hardcoded
        u, v = data.edge_index[0][edge_idx].item(), data.edge_index[1][edge_idx].item()
        i = -1
        for i in range(data.edge_index.shape[1]):
            u2, v2 = data.edge_index[0][i].item(), data.edge_index[1][i].item()
            if u == v2 and v == u2:
                break

        sym_edge_index = i

        new_data = copy.deepcopy(data)
        new_data.edge_attr[edge_idx, 3:7] = 0
        new_data.edge_attr[edge_idx][action_idx] = 1
        new_data.edge_attr[sym_edge_index, 3:7] = 0
        new_data.edge_attr[sym_edge_index][action_idx] = 1
        
        return new_data
    
    def update_target_model(self):
        self.t_net.load_state_dict(self.q_net.state_dict())

    def store_transition(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def get_state_embedding(self, graph) -> torch.Tensor:
        x, edge_index, edge_attr, batch = graph.x, graph.edge_index, graph.edge_attr, graph.batch
        _, graph_emb = self.edge_model(x, edge_index, edge_attr, batch)
        return graph_emb

    def train(self, data):
        if len(self.memory) < self.batch_size:
            isNew = False if data.path[0] == self.path else True
            self.sample_env(data, isNew)
            return

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        
        state_data = Batch.from_data_list(states)
        next_state_data = Batch.from_data_list(next_states)
        # state_embeds = torch.squeeze(torch.stack(states, 1))
        # next_state_embeds = torch.squeeze(torch.stack(next_states, 1))
        actions = torch.LongTensor(actions).unsqueeze(1)
        rewards = torch.Tensor(rewards).unsqueeze(1)
        dones = torch.Tensor(dones).unsqueeze(1)
        
        state_edge_scores, state_graph_embeds = self.edge_model(state_data.x, state_data.edge_index, 
                                                                state_data.edge_attr, state_data.batch)
        next_state_edge_scores, next_state_graph_embeds = self.edge_model(next_state_data.x, next_state_data.edge_index, 
                                                                          next_state_data.edge_attr, next_state_data.batch)

        q_values = self.q_net(state_graph_embeds)
        with torch.no_grad():
            next_q_values = self.t_net(next_state_graph_embeds)
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        loss = self.criterion(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

In [8]:
data_list = []
for osm_path in glob.glob('./osm_dataset/raw/*'):
    print(osm_path)
    osm_graph = load_network(osm_path)
    data = from_networkx(osm_graph, 
                        group_node_attrs=["index", "general0", "general1", "general2",
                                        "general3",  "general4", "x", "y"], 
                        group_edge_attrs=["u", "v", "osmid", "general0", "general1", 
                                        "general2", "general3", "length", "speed"])
    data.path = osm_path
    data_list.append(data)

./osm_dataset/raw/east.osm
./osm_dataset/raw/copenhagen.osm
./osm_dataset/raw/nairobi.osm
./osm_dataset/raw/melbourne.osm
./osm_dataset/raw/durham.osm
./osm_dataset/raw/calgary.osm
./osm_dataset/raw/la.osm
./osm_dataset/raw/tehran.osm
./osm_dataset/raw/hanoi.osm
./osm_dataset/raw/seattle.osm
./osm_dataset/raw/west.osm
./osm_dataset/raw/kobe.osm
./osm_dataset/raw/bradenton.osm
./osm_dataset/raw/delft.osm
./osm_dataset/raw/taipei.osm
./osm_dataset/raw/vienna.osm
./osm_dataset/raw/tunis.osm


In [9]:
data = data_list[0]
num_nodes, num_node_features = data.x.shape
num_edges, num_edge_features = data.edge_attr.shape
hidden_dim = 32
out_dim = 4

print(f"|N| = {num_nodes}, |E| = {num_edges}")
print(f"|F_n| = {num_node_features}, |F_e| = {num_edge_features}")

|N| = 1148, |E| = 2890
|F_n| = 8, |F_e| = 9


In [10]:
# Test planner for all instances of osm
loader = DataLoader(data_list, batch_size=1)

In [86]:
planner = DQNAgent(num_node_features, num_edge_features, 
                    erm_hidden_dim=32,
                    graph_dim=32, hidden_dim=32, num_actions=3, 
                    learning_rate=1e-4, gamma=0.99, 
                    epsilon=0.9, epsilon_decay=1000, epsilon_min=0.05, 
                    batch_size=5, memory_size=100,
                    pop_size=100, episode_len=12)

  scale_dist_cost = (np.array(self.dist_costs) / max(self.dist_costs)) * 8


In [None]:
avoid = {"./osm_dataset/raw/copenhagen.osm",
         "./osm_dataset/raw/seattle.osm",
         "./osm_dataset/raw/hanoi.osm",
         "./osm_dataset/raw/tehran.osm",
         "./osm_dataset/raw/durham.osm",
         "./osm_dataset/raw/la.osm"             # ?
         }        

In [87]:

# train loop:
# TODO: save replay buffer to disc
for epoch in range(2):
    for batch in loader:
        print(batch.path)
        for i in range(7):
            planner.train(batch)

['./osm_dataset/raw/east.osm']
isNew: True
new env
running sim
reset and running sim again
isNew: False
running sim
reset and running sim again
isNew: False
running sim
reset and running sim again
isNew: False
running sim
reset and running sim again
isNew: False
running sim
reset and running sim again
['./osm_dataset/raw/copenhagen.osm']
['./osm_dataset/raw/nairobi.osm']


KeyboardInterrupt: 

In [None]:
num_episodes = 1000
update_frequency = 10
target_update_frequency = 100

for episode in range(num_episodes):
    # Reset the environment and get the initial state
    state = env.reset()
    done = False
    
    while not done:
        # Generate training data
        for _ in range(update_frequency):
            # Select an action using the current Q-network and an exploration strategy (e.g., epsilon-greedy)
            action = select_action(q_net, state)
            
            # Execute the action in the environment
            next_state, reward, done = env.step(action)
            
            # Store the transition in the replay memory
            replay_memory.append((state, action, reward, next_state, done))
            
            # Update the current state
            state = next_state
            
            if done:
                break
        
        # Train the Q-network
        if len(replay_memory) >= batch_size:
            # Sample a batch of transitions from the replay memory
            transitions = random.sample(replay_memory, batch_size)
            
            # Extract the states, actions, rewards, next states, and dones from the transitions
            states, actions, rewards, next_states, dones = zip(*transitions)
            
            # Convert the data to PyTorch tensors
            states = torch.tensor(states, dtype=torch.float32)
            actions = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
            rewards = torch.tensor(rewards, dtype=torch.float32).unsqueeze(1)
            next_states = torch.tensor(next_states, dtype=torch.float32)
            dones = torch.tensor(dones, dtype=torch.float32).unsqueeze(1)
            
            # Compute the predicted Q-values and the target Q-values
            predicted_q_values = q_net(states).gather(1, actions)
            with torch.no_grad():
                next_q_values = target_q_net(next_states).max(1)[0].unsqueeze(1)
                target_q_values = rewards + (1 - dones) * gamma * next_q_values
            
            # Compute the loss and perform optimization
            loss = criterion(predicted_q_values, target_q_values)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # Update the target Q-network
        if episode % target_update_frequency == 0:
            target_q_net.load_state_dict(q_net.state_dict())

In [306]:
def train(model, dataloader, criterion, optimizer, device, epochs):
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        for batch in dataloader:
            batch = batch.to(device)
            optimizer.zero_grad()
            
            # Forward pass
            x, edge_index, edge_attr = batch.x, batch.edge_index, batch.edge_attr
            edge_scores, graph_emb = model(x, edge_index, edge_attr, batch.batch)
            print(f"graph emb shape: {graph_emb.shape}")
            
            # Compute loss
            #! delete
            dummy_y = torch.from_numpy(np.random.rand(edge_scores.shape[0], edge_scores.shape[1]).astype('float32')).to(device)
            loss = criterion(edge_scores, dummy_y)
            epoch_loss += loss.item()
            
            # Backward pass and optimization
            loss.backward()
            optimizer.step()
        
        epoch_loss /= len(dataloader)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}")

In [237]:
# Set up the device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Instantiate the EdgeRegressionModel
edge_model = EdgeRegressionModel(num_node_features, num_edge_features,
                                 hidden_dim, out_dim).to(device)

# Define the loss function and optimizer
criterion = nn.MSELoss()  # Example loss function (mean squared error)
optimizer = optim.Adam(edge_model.parameters(), lr=0.001)  # Example optimizer (Adam)

In [307]:
# Train the model
epochs = 100  # Number of training epochs
train(edge_model, loader, criterion, optimizer, device, epochs)

graph emb shape: torch.Size([4, 32])
graph emb shape: torch.Size([4, 32])
graph emb shape: torch.Size([3, 32])
Epoch [1/100], Loss: 29622866892117.3320
graph emb shape: torch.Size([4, 32])
graph emb shape: torch.Size([4, 32])


KeyboardInterrupt: 

In [None]:
# pass network through GAT A

# threshold attention weights from GAT A 
# to get K most important edges to modify

# iterate over K edges, passing the modified
# graph through GAT B to get a graph embedding

# feed embedding to DQN, predict Q value

# choose action with highest Q value

# simulate action and return score
    # save (s_t, a, r, s_t+1) as training data

# backprop on GAT A, GAT B, DQN

# ? GAT A and B can be same?
# ? feed hyperparam of episode length into networks