In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from generator import RoadNetwork
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
import networkx as nx
import numpy as np
from tqdm import tqdm

from models import Node2VecModel

In [9]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=1,2,3

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1,2,3


In [2]:
network = RoadNetwork()
network.load("../osm_data/porto")

In [8]:
print(torch.cuda.is_available())
print(torch.cuda.current_device())
print(torch.cuda.device_count())
torch.cuda.set_device(1)
print(torch.cuda.current_device())
print(torch.cuda.device_count())

True
0
4
1
4


In [15]:
from torch_geometric.data import Data
import torch_geometric.transforms as T

data = network.generate_road_segment_pyg_dataset()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = T.Compose([
    T.ToDevice(device),
])
data = transform(data)
model = Node2VecModel(data, device=device, q=4, p=1)
model.train(epochs=200)

Epoch: 20, avg_loss: 1.1357365068090095
Epoch: 40, avg_loss: 0.9282183041398445
Epoch: 60, avg_loss: 0.8587074920032799
Epoch: 80, avg_loss: 0.8239938451751563
Epoch: 100, avg_loss: 0.8031852890266462
Epoch: 120, avg_loss: 0.7893166749106812
Epoch: 140, avg_loss: 0.779409574070291
Epoch: 160, avg_loss: 0.7719806158433806
Epoch: 180, avg_loss: 0.7662018026565941


In [17]:
model.save_model(path="../model_states/node2vec/")
model.save_emb(path="../model_states/node2vec/")

In [18]:
from sklearn import model_selection
from sklearn import linear_model
from sklearn import metrics
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

X = model.load_emb() # embedding for each node
# train simple classifier on 80% of data with cross validation
y = np.array([network.gdf_edges.loc[n]["highway_enc"] for n in network.line_graph.nodes])

#mask = ((y==11) | (y==10) | (y==9) | (y==4) | (y==1) | (y==2) | (y==12) | (y==7)) # remove uncommon tags
#X = X[~mask, :]
#y = y[~mask]
print(np.unique(y, return_counts=True))

X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size= 0.2, random_state = 1)

print('X_train dimension= ', X_train.shape)
print('X_test dimension= ', X_test.shape)
print('y_train dimension= ', y_train.shape)
print('y_test dimension= ', y_test.shape)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12]), array([1680,   90,  144,  577,   76, 5832, 1451,   98, 1213,   20,   13,
         11,  126]))
X_train dimension=  (9064, 128)
X_test dimension=  (2267, 128)
y_train dimension=  (9064,)
y_test dimension=  (2267,)


In [19]:
lm = linear_model.LogisticRegression(multi_class="multinomial", max_iter=1000)
lm.fit(X_train, y_train)
print(metrics.classification_report(y_test, lm.predict(X_test)))

              precision    recall  f1-score   support

           0       0.57      0.44      0.49       309
           1       0.40      0.33      0.36        18
           2       0.47      0.52      0.49        31
           3       0.62      0.35      0.44       133
           4       1.00      0.15      0.27        13
           5       0.64      0.86      0.73      1175
           6       0.45      0.30      0.36       301
           7       0.00      0.00      0.00        20
           8       0.42      0.20      0.27       221
           9       0.00      0.00      0.00         5
          10       0.67      0.50      0.57         4
          11       0.00      0.00      0.00         3
          12       0.44      0.12      0.19        34

    accuracy                           0.60      2267
   macro avg       0.44      0.29      0.32      2267
weighted avg       0.57      0.60      0.56      2267



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
