# Graph Construction

In this notebook, we are exploring the standard construction of the Graph using a GNN + GSL model.

In [1]:
from config import CONFIG
from model.metro_model import MetroModel
from data.metro_dataset import MetroDataset

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
import os
from pathlib import Path
import networkx as nx
import numpy as np

torch.manual_seed(0)

  from .autonotebook import tqdm as notebook_tqdm

import os
os.environ['USE_PYGEOS'] = '0'
import geopandas

In a future release, GeoPandas will switch to using Shapely by default. If you are using PyGEOS directly (calling PyGEOS functions on geometries from GeoPandas), this will then stop working and you are encouraged to migrate from PyGEOS to Shapely 2.0 (https://shapely.readthedocs.io/en/latest/migration_pygeos.html).
  import geopandas as gpd


<torch._C.Generator at 0x295e4bb7bb0>

In [2]:
dataset = MetroDataset(CONFIG.lines, init_nb=20)
num_nodes = dataset.cg.num_nodes()
dataset.cg.altair_graph(dataset.cg.graph, 10)



## Model

In [3]:
model = MetroModel(embedding_size=8, num_nodes=num_nodes, neighbor_nb=1, input_size=1, gsl_mode="matrix")

## Training

In [4]:
train_loader = torch.utils.data.DataLoader(dataset, batch_size=8)
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/", save_top_k=1000 ,monitor="loss") # , every_n_train_steps=100, 
trainer = Trainer(max_epochs=50, callbacks=[checkpoint_callback])
trainer.fit(model, train_loader)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name                   | Type      | Params
-----------------------------------------------------
0 | node_embeddings_start  | Embedding | 296   
1 | node_embeddings_target | Embedding | 296   
2 | _linear1               | Linear    | 72    
3 | _linear2               | Linear    | 72    
4 | graph_layer            | GraphConv | 3     
5 | linear                 | Linear    | 1.4 K 
6 | softmax                | Softmax   | 0     
-----------------------------------------------------
3.5 K     Trainable params
0         Non-trainable params
3.5 K     Total params
0.014     Total estimated model params size (MB)
  rank_zero_warn(


Epoch 49: 100%|██████████| 248/248 [00:01<00:00, 212.82it/s, loss=17, v_num=0]     

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 49: 100%|██████████| 248/248 [00:01<00:00, 211.37it/s, loss=17, v_num=0]


## Visualization of results

In [5]:
def load_model(path):
    checkpoint = torch.load(path)
    model = MetroModel(embedding_size=8, num_nodes=num_nodes, neighbor_nb=1, input_size=1, gsl_mode="matrix")
    model.load_state_dict(checkpoint['state_dict'])
    return model

def viz(checkpoint):
    model = load_model(checkpoint)

    A = model.graph_matrix_learning().detach().numpy()
    A = np.array(A)

    learned_graph = nx.from_numpy_array(A, nx.DiGraph)
    return dataset.cg.altair_graph(learned_graph, 10)

for path in os.listdir('checkpoints'):
    checkpoint = Path('checkpoints') / path
    print(checkpoint)

checkpoints\epoch=0-step=248-v1.ckpt
checkpoints\epoch=0-step=248.ckpt
checkpoints\epoch=1-step=496-v1.ckpt
checkpoints\epoch=1-step=496.ckpt
checkpoints\epoch=10-step=2728-v1.ckpt
checkpoints\epoch=10-step=2728.ckpt
checkpoints\epoch=11-step=2976-v1.ckpt
checkpoints\epoch=11-step=2976.ckpt
checkpoints\epoch=12-step=3224-v1.ckpt
checkpoints\epoch=12-step=3224.ckpt
checkpoints\epoch=13-step=3472-v1.ckpt
checkpoints\epoch=13-step=3472.ckpt
checkpoints\epoch=14-step=3720-v1.ckpt
checkpoints\epoch=14-step=3720.ckpt
checkpoints\epoch=15-step=3968-v1.ckpt
checkpoints\epoch=15-step=3968.ckpt
checkpoints\epoch=16-step=4216-v1.ckpt
checkpoints\epoch=16-step=4216.ckpt
checkpoints\epoch=17-step=4464-v1.ckpt
checkpoints\epoch=17-step=4464.ckpt
checkpoints\epoch=18-step=4712-v1.ckpt
checkpoints\epoch=18-step=4712.ckpt
checkpoints\epoch=19-step=4960-v1.ckpt
checkpoints\epoch=19-step=4960.ckpt
checkpoints\epoch=2-step=744-v1.ckpt
checkpoints\epoch=2-step=744.ckpt
checkpoints\epoch=20-step=5208-v1.ckp

In [7]:
viz("checkpoints\epoch=30-step=7688-v1.ckpt")