# Graph Construction

In [1]:
from config import CONFIG
from metro_model import MetroModel
from metro_dataset import MetroDataset
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
import os
from pathlib import Path
import networkx as nx
import numpy as np

torch.manual_seed(0)

  from .autonotebook import tqdm as notebook_tqdm

import os
os.environ['USE_PYGEOS'] = '0'
import geopandas

In a future release, GeoPandas will switch to using Shapely by default. If you are using PyGEOS directly (calling PyGEOS functions on geometries from GeoPandas), this will then stop working and you are encouraged to migrate from PyGEOS to Shapely 2.0 (https://shapely.readthedocs.io/en/latest/migration_pygeos.html).
  import geopandas as gpd


<torch._C.Generator at 0x1d11f657bd0>

In [2]:
dataset = MetroDataset(CONFIG.lines, init_nb=20)
num_nodes = dataset.cg.num_nodes()
dataset.cg.altair_graph(dataset.cg.graph, 10)



## Model

In [3]:
model = MetroModel(embedding_size=8, num_nodes=num_nodes, neighbor_nb=1, input_size=1, gsl_mode="matrix")

## Training

In [22]:
train_loader = torch.utils.data.DataLoader(dataset, batch_size=8)
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/", save_top_k=1000 ,monitor="loss") # , every_n_train_steps=100, 
trainer = Trainer(max_epochs=50, callbacks=[checkpoint_callback])
trainer.fit(model, train_loader)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name                   | Type      | Params
-----------------------------------------------------
0 | node_embeddings_start  | Embedding | 552   
1 | node_embeddings_target | Embedding | 552   
2 | graph_layer            | GraphConv | 3     
3 | _linear1               | Linear    | 72    
4 | _linear2               | Linear    | 72    
5 | linear                 | Linear    | 4.8 K 
6 | softmax                | Softmax   | 0     
-----------------------------------------------------
10.8 K    Trainable params
0         Non-trainable params
10.8 K    Total params
0.043     Total estimated model params size (MB)


Epoch 49: 100%|██████████| 248/248 [00:01<00:00, 171.18it/s, loss=0.00379, v_num=19]

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 49: 100%|██████████| 248/248 [00:01<00:00, 169.47it/s, loss=0.00379, v_num=19]


## Visualization of results

In [23]:
def load_model(path):
    checkpoint = torch.load(path)
    model = MetroModel(embedding_size=8, num_nodes=num_nodes, neighbor_nb=1, input_size=1, gsl_mode="matrix")
    model.load_state_dict(checkpoint['state_dict'])
    return model

def viz(checkpoint):
    model = load_model(checkpoint)

    A = model.graph_matrix_learning().detach().numpy()
    A = np.array(A)

    learned_graph = nx.from_numpy_array(A, nx.DiGraph)
    return dataset.cg.altair_graph(learned_graph, 10)

for path in os.listdir('checkpoints'):
    checkpoint = Path('checkpoints') / path
    print(checkpoint)

checkpoints\epoch=0-step=248.ckpt
checkpoints\epoch=1-step=496.ckpt
checkpoints\epoch=10-step=2728.ckpt
checkpoints\epoch=11-step=2976.ckpt
checkpoints\epoch=12-step=3224.ckpt
checkpoints\epoch=13-step=3472.ckpt
checkpoints\epoch=14-step=3720.ckpt
checkpoints\epoch=15-step=3968.ckpt
checkpoints\epoch=16-step=4216.ckpt
checkpoints\epoch=17-step=4464.ckpt
checkpoints\epoch=18-step=4712.ckpt
checkpoints\epoch=19-step=4960.ckpt
checkpoints\epoch=2-step=744.ckpt
checkpoints\epoch=20-step=5208.ckpt
checkpoints\epoch=21-step=5456.ckpt
checkpoints\epoch=22-step=5704.ckpt
checkpoints\epoch=23-step=5952.ckpt
checkpoints\epoch=24-step=6200.ckpt
checkpoints\epoch=25-step=6448.ckpt
checkpoints\epoch=26-step=6696.ckpt
checkpoints\epoch=27-step=6944.ckpt
checkpoints\epoch=28-step=7192.ckpt
checkpoints\epoch=29-step=7440.ckpt
checkpoints\epoch=3-step=992.ckpt
checkpoints\epoch=30-step=7688.ckpt
checkpoints\epoch=31-step=7936.ckpt
checkpoints\epoch=32-step=8184.ckpt
checkpoints\epoch=33-step=8432.ckpt


In [24]:
viz("checkpoints\epoch=30-step=7688.ckpt")

Epoch 2:  46%|████▌     | 113/248 [01:58<02:21,  1.05s/it, loss=0.261, v_num=18] 
