In [13]:
%load_ext autoreload
%autoreload 2

import os
import sys

# Resolve the directory two levels above the current working directory and
# put it on the import path exactly once (the guard keeps re-running this
# cell from appending duplicates).
# NOTE(review): presumably this is the repository root that provides the
# `generator` and `models` modules imported below — confirm.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

from generator import RoadNetwork
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
import networkx as nx
import numpy as np
from tqdm import tqdm

from models import PCAModel

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
# Location handed to RoadNetwork.load(); relative to the notebook's working directory.
PORTO_OSM_DIR = "../../osm_data/porto"

# Create an empty RoadNetwork and populate it from the path above.
# NOTE(review): the on-disk format expected by load() is not visible here.
network = RoadNetwork()
network.load(PORTO_OSM_DIR)

In [15]:
from torch_geometric.data import Data
import torch_geometric.transforms as T

# Tunables for the PCA baseline.
EMB_DIM = 4                                # embedding width requested from PCAModel
PCA_STATE_DIR = "../../model_states/pca/"  # where the fitted embedding is written

# Build the PyTorch Geometric dataset of road segments from the loaded network.
data = network.generate_road_segment_pyg_dataset()

# Fit the PCA model on that dataset and persist the resulting embedding.
# `data` and `model` keep their names because later cells read them.
model = PCAModel(data, emb_dim=EMB_DIM)
model.train()
model.save_emb(path=PCA_STATE_DIR)

In [16]:
"""
Evaluate on road type classification
"""
from sklearn import model_selection
from sklearn import linear_model
from sklearn import metrics
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

# Downstream evaluation: predict each road segment's encoded highway tag
# (`highway_enc`) from its embedding using a plain logistic regression.
X = model.emb  # embedding for each node
# One label per line-graph node, looked up in the edge table in node order.
# NOTE(review): assumes the rows of model.emb follow network.line_graph.nodes
# order — confirm against PCAModel.
y = np.array([network.gdf_edges.loc[n]["highway_enc"] for n in network.line_graph.nodes])

# Optional filter for rare classes, left disabled:
# mask = ((y==11) | (y==10) | (y==9) | (y==4) | (y==1) | (y==2) | (y==12) | (y==7)) # remove uncommon tags
# X = X[~mask, :]
# y = y[~mask]
# print(np.unique(y, return_counts=True))

# Single 80/20 hold-out split (this is not cross-validation); the fixed seed
# keeps the split reproducible across runs.
X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2, random_state=1)

print('X_train dimension= ', X_train.shape)
print('X_test dimension= ', X_test.shape)
print('y_train dimension= ', y_train.shape)
print('y_test dimension= ', y_test.shape)

# `multi_class="multinomial"` is omitted: it is deprecated in recent
# scikit-learn releases, and the default lbfgs solver already fits a
# multinomial model when the target has more than two classes.
lm = linear_model.LogisticRegression(max_iter=1000)
lm.fit(X_train, y_train)

# Several classes are never predicted, so their precision is 0/0.
# zero_division=0 reports those as 0.0 (the same numbers as before) without
# emitting an UndefinedMetricWarning for every affected average.
print(metrics.classification_report(y_test, lm.predict(X_test), zero_division=0))

X_train dimension=  (9064, 4)
X_test dimension=  (2267, 4)
y_train dimension=  (9064,)
y_test dimension=  (2267,)
              precision    recall  f1-score   support

           0       0.99      0.76      0.86       309
           1       0.00      0.00      0.00        18
           2       0.00      0.00      0.00        31
           3       0.00      0.00      0.00       133
           4       0.00      0.00      0.00        13
           5       0.58      1.00      0.73      1175
           6       0.00      0.00      0.00       301
           7       0.00      0.00      0.00        20
           8       1.00      0.00      0.01       221
           9       0.00      0.00      0.00         5
          10       0.00      0.00      0.00         4
          11       0.00      0.00      0.00         3
          12       0.00      0.00      0.00        34

    accuracy                           0.62      2267
   macro avg       0.20      0.14      0.12      2267
weighted avg       0

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
