In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

# Make the project root (the parent of this notebook's working directory)
# importable so the local `generator` and `models` modules below resolve.
# The entry is only appended when absent, so re-running the cell does not
# keep growing sys.path.
module_path = os.path.abspath(os.pardir)
if all(entry != module_path for entry in sys.path):
    sys.path.append(module_path)

from generator import RoadNetwork
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
import networkx as nx
import numpy as np
from tqdm import tqdm

from models import Node2VecModel

In [2]:
# Build a RoadNetwork and populate it from the saved Porto data under
# ../osm_data/porto (on-disk format is defined by generator.RoadNetwork.load,
# which is not visible here).
network = RoadNetwork()
network.load("../osm_data/porto")

In [3]:
LG = network.line_graph

# Give every line-graph node a contiguous integer id in LG.nodes iteration
# order. Node2Vec works on these ids, so (presumably) embedding row i
# corresponds to the i-th node of LG.nodes.
map_id = {node: idx for idx, node in enumerate(LG.nodes)}

# One row per line-graph edge ("source"/"target" node labels), extended with
# the integer ids of both endpoints.
edge_list = nx.to_pandas_edgelist(LG)
edge_list = edge_list.assign(
    sidx=edge_list["source"].map(map_id),
    tidx=edge_list["target"].map(map_id),
)

# PyG expects a (2, num_edges) long tensor: row 0 = source ids, row 1 = target ids.
edge_index = torch.tensor(
    edge_list[["sidx", "tidx"]].to_numpy().T, dtype=torch.long
).contiguous()

# Sanity check: the tensor should hold exactly one column per graph edge.
print(edge_index.shape, LG.number_of_edges())

torch.Size([2, 26617]) 26617


In [None]:
# Quick check for a visible CUDA device; the training cell picks 'cuda' when
# this is True and falls back to 'cpu' otherwise.
torch.cuda.is_available()

In [4]:
from torch_geometric.data import Data
import torch_geometric.transforms as T

# Seed the torch and numpy RNGs so the training run is repeatable (presumably
# Node2VecModel draws its random walks / negative samples from these RNGs —
# confirm in models.py). CUDA kernels may still be non-deterministic.
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)

# Pass num_nodes explicitly: without it PyG infers the node count as
# edge_index.max() + 1, which silently drops any trailing line-graph nodes that
# have no edges and would misalign embedding rows with LG.nodes / the labels.
# NOTE(review): assumes Node2VecModel sizes its embedding from data.num_nodes.
data = Data(edge_index=edge_index, num_nodes=LG.number_of_nodes())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = T.Compose([
    T.ToDevice(device),
])
data = transform(data)
model = Node2VecModel(data, device=device)
model.train(epochs=200)

Epoch: 10, avg_loss: 1.7161345858922168
Epoch: 20, avg_loss: 1.22438044450926
Epoch: 30, avg_loss: 1.0582666016026832
Epoch: 40, avg_loss: 0.9751250269205384
Epoch: 50, avg_loss: 0.9253085849258339
Epoch: 60, avg_loss: 0.8921480135524767
Epoch: 70, avg_loss: 0.868477387805238
Epoch: 80, avg_loss: 0.85072138729725
Epoch: 90, avg_loss: 0.8369138277202064
Epoch: 100, avg_loss: 0.8258651511387879
Epoch: 110, avg_loss: 0.8168230262547882
Epoch: 120, avg_loss: 0.8092834363325259
Epoch: 130, avg_loss: 0.8029041622616448
Epoch: 140, avg_loss: 0.7974384912470565
Epoch: 150, avg_loss: 0.7927000573288637
Epoch: 160, avg_loss: 0.788551892972227
Epoch: 170, avg_loss: 0.7848923380337088
Epoch: 180, avg_loss: 0.781640858950091
Epoch: 190, avg_loss: 0.778729697517359


In [5]:
# Persist the trained model and its node embeddings side by side.
NODE2VEC_DIR = "../model_states/node2vec/"
# save_model/save_emb are project helpers; create the target directory up
# front so saving does not fail on a fresh checkout if they do not create it.
os.makedirs(NODE2VEC_DIR, exist_ok=True)
model.save_model(save_name="node2vec_200e.pt", path=NODE2VEC_DIR)
model.save_emb(path=NODE2VEC_DIR)

In [9]:
from sklearn import model_selection
from sklearn import linear_model
from sklearn import metrics
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

# Features: the learned embedding, one row per line-graph node (presumably in
# the same id order as map_id / edge_index).
X = model.load_emb()
# Labels: the encoded `highway` tag of each line-graph node, taken in
# network.line_graph.nodes order — the same order used to build map_id.
y = np.array([network.gdf_edges.loc[n]["highway_enc"] for n in network.line_graph.nodes])
# Rows of X and entries of y must correspond one-to-one; a mismatch means some
# nodes were dropped or reordered between embedding and label extraction.
assert X.shape[0] == len(y), f"{X.shape[0]} embedding rows vs {len(y)} labels"

# Optionally drop the rare highway tags (the low-count classes in the
# distribution printed below). Off by default.
DROP_RARE_TAGS = False
RARE_TAGS = [1, 2, 4, 7, 9, 10, 11, 12]
if DROP_RARE_TAGS:
    keep = ~np.isin(y, RARE_TAGS)
    X, y = X[keep, :], y[keep]
print(np.unique(y, return_counts=True))

# Single hold-out split: 80% train / 20% test (no cross-validation here).
# NOTE(review): classes are highly imbalanced; consider stratify=y so rare tags
# appear in both splits proportionally (changes the reported numbers).
X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size= 0.2, random_state = 1)

print('X_train dimension= ', X_train.shape)
print('X_test dimension= ', X_test.shape)
print('y_train dimension= ', y_train.shape)
print('y_test dimension= ', y_test.shape)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12]), array([1680,   90,  144,  577,   76, 5832, 1451,   98, 1213,   20,   13,
         11,  126]))
X_train dimension=  (9064, 128)
X_test dimension=  (2267, 128)
y_train dimension=  (9064,)
y_test dimension=  (2267,)


In [10]:
# Logistic-regression baseline on the node embeddings. With the default lbfgs
# solver and more than two classes scikit-learn already fits a multinomial
# model, so the deprecated `multi_class="multinomial"` argument is not passed.
lm = linear_model.LogisticRegression(max_iter=1000)
lm.fit(X_train, y_train)
# Some rare classes receive no predictions at all; zero_division=0 reports
# their precision as 0.0 explicitly instead of emitting UndefinedMetricWarning
# (same numbers as the default "warn" behaviour).
print(metrics.classification_report(y_test, lm.predict(X_test), zero_division=0))

              precision    recall  f1-score   support

           0       0.54      0.40      0.46       309
           1       0.57      0.44      0.50        18
           2       0.52      0.48      0.50        31
           3       0.47      0.32      0.38       133
           4       0.43      0.23      0.30        13
           5       0.62      0.85      0.72      1175
           6       0.44      0.27      0.33       301
           7       0.00      0.00      0.00        20
           8       0.35      0.16      0.22       221
           9       0.00      0.00      0.00         5
          10       0.50      0.50      0.50         4
          11       0.00      0.00      0.00         3
          12       0.29      0.06      0.10        34

    accuracy                           0.58      2267
   macro avg       0.36      0.29      0.31      2267
weighted avg       0.54      0.58      0.54      2267



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
