In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import gc

import logging
logging.basicConfig(level=logging.ERROR)

In [2]:
import networkx as nx
import osmnx as ox

def load_network(graph_file) -> nx.Graph:
    """loads in modified OSM graph"""
    G_trans = ox.load_graphml(
                graph_file,
                node_dtypes={'idx':int, 'x':float, 'y':float, 'general0':float, 'general1':float, 
                            'general2':float, 'general3':float, 'general4':float},    
                edge_dtypes={'u':int, 'v':int, 'speed':float, 'capacity':float, 'length':float,
                             'general0':float, 'general1':float, 'general2':float, 'general3':float})
    G_trans = G_trans.to_undirected()                                   # make undirected
    return G_trans

In [3]:
import glob
from torch_geometric.utils.convert import from_networkx

data_list = []
for osm_path in glob.glob('./osm_dataset/raw/*'):
    print(osm_path)
    osm_graph = load_network(osm_path)
    data = from_networkx(osm_graph, 
                        group_node_attrs=["idx", "general0", "general1", "general2",
                                        "general3",  "general4", "x", "y"], 
                        group_edge_attrs=["u", "v", "osmid", "general0", "general1", 
                                        "general2", "general3", "length", "speed", "capacity"])
    data.path = osm_path
    data_list.append(data)

./osm_dataset/raw/east.osm
./osm_dataset/raw/copenhagen.osm
./osm_dataset/raw/nairobi.osm
./osm_dataset/raw/melbourne.osm
./osm_dataset/raw/durham.osm
./osm_dataset/raw/calgary.osm
./osm_dataset/raw/jakarta.osm
./osm_dataset/raw/manila.osm
./osm_dataset/raw/la.osm
./osm_dataset/raw/tehran.osm
./osm_dataset/raw/hanoi.osm
./osm_dataset/raw/seattle.osm
./osm_dataset/raw/west.osm
./osm_dataset/raw/kobe.osm
./osm_dataset/raw/rio.osm
./osm_dataset/raw/beirut.osm
./osm_dataset/raw/istanbul.osm
./osm_dataset/raw/delft.osm
./osm_dataset/raw/taipei.osm
./osm_dataset/raw/vienna.osm
./osm_dataset/raw/bogota.osm
./osm_dataset/raw/suwon.osm


In [4]:
# GNN config
node_dim = 8
edge_dim = 10
erm_hidden_dim = 32
action_dim = 4
num_sample_actions = 4

# DQN
graph_dim  = 32
hidden_dim = 32

# Planner
gamma = 0.99
epsilon_start = 1.0
epsilon_decay = 1000
epsilon_min = 0.05
batch_size = 48
memory_size = 2000
pop_size = 300
episode_len = 24

# Overall
criterion = nn.MSELoss()
learning_rate = 1e-4
num_epochs = 1000
num_episodes = 6
num_runs = 10

# Saving
every_n_train_steps = 18
val_every_k_epochs = 3
log_every_k_steps = 4

In [5]:
from trans_infra.trans_infra.planner3 import DQNLightning
%load_ext autoreload
%autoreload 2

In [6]:
# Create the DQNLightning module
model = DQNLightning(node_dim, edge_dim, erm_hidden_dim, graph_dim, hidden_dim, action_dim, 
                     num_sample_actions, criterion, gamma, epsilon_start, epsilon_decay, epsilon_min, 
                     batch_size, memory_size, learning_rate, data_list, 
                     pop_size, episode_len, num_episodes, num_runs)

# Configure the logger
name = "planner3"
version = 1
logger = TensorBoardLogger('./logs/', name=name, version=version)

checkpoint_callback = ModelCheckpoint(
    dirpath=f'./checkpoints/{name}_v{version}',
    filename=f'{name}_v{version}'+'-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3, monitor='val_loss',
    mode='min', verbose=True,
    every_n_epochs=1, 
    save_last=True
)

# Train the model
trainer = pl.Trainer(logger=logger, max_epochs=1000, 
                    check_val_every_n_epoch=val_every_k_epochs,
                    callbacks=[checkpoint_callback],
                    log_every_n_steps=log_every_k_steps)

trainer.fit(model, ckpt_path=f"./checkpoints/{name}_v{version}/last.ckpt")


./osm_dataset/raw/copenhagen.osm


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/alexanderkumar/miniconda3/envs/graphs/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/alexanderkumar/Desktop/transpoPlanner/transpoPlanner/checkpoints/planner3_v1 exists and is not empty.
Restoring states from the checkpoint path at ./checkpoints/planner3_v1/last.ckpt

  | Name       | Type                | Params
---------------------------------------------------
0 | edge_model | EdgeRegressionModel | 86.8 K
1 | q_net      | DQN                 | 4.8 K 
2 | t_net      | DQN                 | 4.8 K 
3 | criterion  | MSELoss             | 0     
---------------------------------------------------
96.3 K    Trainable params
0         Non-trainable params
96.3 K    Total params
0.385     Total estimated model params size (MB)
Restored all states from the checkpoint at ./ch

loaded replay memory


/Users/alexanderkumar/miniconda3/envs/graphs/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/alexanderkumar/miniconda3/envs/graphs/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

./osm_dataset/raw/tehran.osm


/Users/alexanderkumar/miniconda3/envs/graphs/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 2. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
Epoch 29, global step 540: 'val_loss' reached 21.05246 (best 14.03883), saving model to '/Users/alexanderkumar/Desktop/transpoPlanner/transpoPlanner/checkpoints/planner3_v1/planner3_v1-epoch=29-val_loss=21.05.ckpt' as top 3


done with val step
saving replay
saving replay
done with train step
done with train step
done with train step
done with train step
getting data





./osm_dataset/raw/kobe.osm


100%|██████████| 6/6 [02:53<00:00, 28.86s/it]


done with train step
done with train step
done with train step
done with train step
done with train step
done with train step
updated target model
 updated target model
done with train step
done with train step
getting data





./osm_dataset/raw/beirut.osm


100%|██████████| 6/6 [02:54<00:00, 29.12s/it]


done with train step
done with train step
done with train step
done with train step
done with train step
done with train step
done with train step
done with train step
getting data





./osm_dataset/raw/seattle.osm


100%|██████████| 6/6 [03:31<00:00, 35.28s/it]


In [None]:
%reload_ext tensorboard
%tensorboard --logdir='./logs/'