In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

module_path = os.path.abspath(os.path.join('../..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from generator import RoadNetwork, Trajectory
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
import networkx as nx
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch_geometric.transforms as T

from models import CLMModel
from models.utils import generate_trajid_to_nodeid

In [2]:
city = "sf"

In [3]:
network = RoadNetwork()
network.load(f"../../osm_data/{city}")
trajectory = pd.read_pickle(
    f"../../datasets/trajectories/{city}/traj_train_test_split/train_69.pkl"
)
trajectory["seg_seq"] = trajectory["seg_seq"].map(np.array)
data = network.generate_road_segment_pyg_dataset(include_coords=True, dataset=city)

In [21]:
# calculate transition matrix 
traj_map = generate_trajid_to_nodeid(network)
trans_mat = np.zeros((data.x.shape[0], data.x.shape[0]))
for seq in tqdm(trajectory.seg_seq):
    for i, id1 in enumerate(seq):
        for id2 in seq[i:]:
            node_id1, node_id2 = traj_map[id1], traj_map[id2]
            trans_mat[node_id1, node_id2] += 1

trans_mat = trans_mat / (trans_mat.max(axis=1, keepdims=True, initial=0.) + 1e-9)
row, col = np.diag_indices_from(trans_mat)
trans_mat[row, col] = 0

100%|██████████| 1080963/1080963 [16:02<00:00, 1123.30it/s]


In [8]:
np.save("clm_trans_mat_porto.npy", trans_mat)

In [4]:
trans_mat = np.load(f"./clm_trans_mat_{city}.npy")

In [5]:
trans_mat_b = (trans_mat > 0.6)
aug_edges = [(i // trans_mat.shape[0] , i % trans_mat.shape[0]) for i, n in enumerate(trans_mat_b.flatten()) if n]
aug_edge_index = torch.tensor(np.array(aug_edges).transpose()).cuda()

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trajectory.rename({"seg_seq": "path"}, inplace=True, axis=1)
model = CLMModel(data, device, network, trans_adj=aug_edge_index, traj_data=trajectory, batch_size=64, emb_dim=128)

In [12]:
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
count_parameters(model.model)

+---------------------------------------------------------------------------+------------+
|                                  Modules                                  | Parameters |
+---------------------------------------------------------------------------+------------+
|                        module.node_embedding.weight                       |  1450368   |
|                   module.graph_encoder1.layers.0.att_src                  |    128     |
|                   module.graph_encoder1.layers.0.att_dst                  |    128     |
|                    module.graph_encoder1.layers.0.bias                    |    128     |
|               module.graph_encoder1.layers.0.lin_src.weight               |   16384    |
|                   module.graph_encoder1.layers.1.att_src                  |    128     |
|                   module.graph_encoder1.layers.1.att_dst                  |    128     |
|                    module.graph_encoder1.layers.1.bias                    |    128     |

1716608

In [22]:
torch.cuda.empty_cache()
del model
1/0

ZeroDivisionError: division by zero

In [9]:
model.train(epochs=5)

100%|██████████| 250000/250000 [00:47<00:00, 5218.32it/s]


12-04 14:14:33 | (Train) | Epoch=0, batch=200 loss=-0.6331, loss_ss=0.0100,  loss_tt=-0.9375,  loss_st1=-0.6732, loss_st2=-0.6776
12-04 14:17:38 | (Train) | Epoch=0, batch=400 loss=-0.7810, loss_ss=-0.0114,  loss_tt=-1.0704,  loss_st1=-0.8403, loss_st2=-0.8417
12-04 14:20:47 | (Train) | Epoch=0, batch=600 loss=-0.9322, loss_ss=-0.0335,  loss_tt=-1.1812,  loss_st1=-1.0277, loss_st2=-0.9991
12-04 14:24:08 | (Train) | Epoch=0, batch=800 loss=-0.8848, loss_ss=-0.0542,  loss_tt=-1.2448,  loss_st1=-0.9420, loss_st2=-0.9453
12-04 14:27:46 | (Train) | Epoch=0, batch=1000 loss=-0.9238, loss_ss=-0.0723,  loss_tt=-1.2390,  loss_st1=-1.0071, loss_st2=-0.9745
12-04 14:32:06 | (Train) | Epoch=0, batch=1200 loss=-0.9867, loss_ss=-0.0932,  loss_tt=-1.2338,  loss_st1=-1.0817, loss_st2=-1.0533
12-04 14:37:16 | (Train) | Epoch=0, batch=1400 loss=-1.0298, loss_ss=-0.1132,  loss_tt=-1.3352,  loss_st1=-1.1229, loss_st2=-1.0896
12-04 14:41:43 | (Train) | Epoch=0, batch=1600 loss=-1.0406, loss_ss=-0.1328,  lo

100%|██████████| 250000/250000 [00:47<00:00, 5292.56it/s]


12-04 15:30:32 | (Train) | Epoch=1, batch=200 loss=-1.1813, loss_ss=-0.3231,  loss_tt=-1.3382,  loss_st1=-1.2720, loss_st2=-1.2660
12-04 15:33:47 | (Train) | Epoch=1, batch=400 loss=-1.1622, loss_ss=-0.3330,  loss_tt=-1.3550,  loss_st1=-1.2470, loss_st2=-1.2366
12-04 15:36:55 | (Train) | Epoch=1, batch=600 loss=-1.1728, loss_ss=-0.3430,  loss_tt=-1.3487,  loss_st1=-1.2565, loss_st2=-1.2525
12-04 15:40:03 | (Train) | Epoch=1, batch=800 loss=-1.1821, loss_ss=-0.3531,  loss_tt=-1.3605,  loss_st1=-1.2642, loss_st2=-1.2626
12-04 15:43:11 | (Train) | Epoch=1, batch=1000 loss=-1.1752, loss_ss=-0.3612,  loss_tt=-1.3520,  loss_st1=-1.2584, loss_st2=-1.2513
12-04 15:46:20 | (Train) | Epoch=1, batch=1200 loss=-1.1677, loss_ss=-0.3696,  loss_tt=-1.3545,  loss_st1=-1.2443, loss_st2=-1.2438
12-04 15:49:31 | (Train) | Epoch=1, batch=1400 loss=-1.1810, loss_ss=-0.3768,  loss_tt=-1.3618,  loss_st1=-1.2589, loss_st2=-1.2590
12-04 15:52:37 | (Train) | Epoch=1, batch=1600 loss=-1.2031, loss_ss=-0.3838,  l

100%|██████████| 250000/250000 [00:46<00:00, 5430.99it/s]


12-04 16:32:47 | (Train) | Epoch=2, batch=200 loss=-1.2068, loss_ss=-0.4388,  loss_tt=-1.3628,  loss_st1=-1.2812, loss_st2=-1.2854
12-04 16:35:53 | (Train) | Epoch=2, batch=400 loss=-1.2109, loss_ss=-0.4426,  loss_tt=-1.3510,  loss_st1=-1.2904, loss_st2=-1.2885
12-04 16:38:59 | (Train) | Epoch=2, batch=600 loss=-1.2096, loss_ss=-0.4449,  loss_tt=-1.3748,  loss_st1=-1.2818, loss_st2=-1.2872
12-04 16:42:06 | (Train) | Epoch=2, batch=800 loss=-1.2165, loss_ss=-0.4481,  loss_tt=-1.3567,  loss_st1=-1.2963, loss_st2=-1.2938
12-04 16:45:13 | (Train) | Epoch=2, batch=1000 loss=-1.2225, loss_ss=-0.4501,  loss_tt=-1.3707,  loss_st1=-1.2995, loss_st2=-1.3016
12-04 16:48:20 | (Train) | Epoch=2, batch=1200 loss=-1.2126, loss_ss=-0.4528,  loss_tt=-1.3603,  loss_st1=-1.2905, loss_st2=-1.2878
12-04 16:51:26 | (Train) | Epoch=2, batch=1400 loss=-1.2197, loss_ss=-0.4547,  loss_tt=-1.3616,  loss_st1=-1.2981, loss_st2=-1.2971
12-04 16:54:33 | (Train) | Epoch=2, batch=1600 loss=-1.2112, loss_ss=-0.4582,  l

100%|██████████| 250000/250000 [00:46<00:00, 5340.41it/s]


12-04 17:34:40 | (Train) | Epoch=3, batch=200 loss=-1.2198, loss_ss=-0.4820,  loss_tt=-1.3676,  loss_st1=-1.2917, loss_st2=-1.2955
12-04 17:37:45 | (Train) | Epoch=3, batch=400 loss=-1.2245, loss_ss=-0.4830,  loss_tt=-1.3732,  loss_st1=-1.2971, loss_st2=-1.3000
12-04 17:40:52 | (Train) | Epoch=3, batch=600 loss=-1.2340, loss_ss=-0.4849,  loss_tt=-1.3690,  loss_st1=-1.3110, loss_st2=-1.3105
12-04 17:43:57 | (Train) | Epoch=3, batch=800 loss=-1.2341, loss_ss=-0.4860,  loss_tt=-1.3693,  loss_st1=-1.3098, loss_st2=-1.3116
12-04 17:47:04 | (Train) | Epoch=3, batch=1000 loss=-1.2249, loss_ss=-0.4872,  loss_tt=-1.3681,  loss_st1=-1.2967, loss_st2=-1.3018
12-04 17:50:09 | (Train) | Epoch=3, batch=1200 loss=-1.2318, loss_ss=-0.4878,  loss_tt=-1.3768,  loss_st1=-1.3076, loss_st2=-1.3059
12-04 17:53:16 | (Train) | Epoch=3, batch=1400 loss=-1.2281, loss_ss=-0.4899,  loss_tt=-1.3760,  loss_st1=-1.2996, loss_st2=-1.3042
12-04 17:56:24 | (Train) | Epoch=3, batch=1600 loss=-1.2302, loss_ss=-0.4916,  l

100%|██████████| 250000/250000 [00:47<00:00, 5281.13it/s]


12-04 18:36:26 | (Train) | Epoch=4, batch=200 loss=-1.2278, loss_ss=-0.5062,  loss_tt=-1.3751,  loss_st1=-1.2982, loss_st2=-1.3009
12-04 18:39:32 | (Train) | Epoch=4, batch=400 loss=-1.2453, loss_ss=-0.5063,  loss_tt=-1.3825,  loss_st1=-1.3161, loss_st2=-1.3251
12-04 18:42:38 | (Train) | Epoch=4, batch=600 loss=-1.2403, loss_ss=-0.5076,  loss_tt=-1.3728,  loss_st1=-1.3133, loss_st2=-1.3173
12-04 18:45:46 | (Train) | Epoch=4, batch=800 loss=-1.2402, loss_ss=-0.5079,  loss_tt=-1.3651,  loss_st1=-1.3130, loss_st2=-1.3193
12-04 18:48:56 | (Train) | Epoch=4, batch=1000 loss=-1.2298, loss_ss=-0.5082,  loss_tt=-1.3723,  loss_st1=-1.3033, loss_st2=-1.3011
12-04 18:52:06 | (Train) | Epoch=4, batch=1200 loss=-1.2355, loss_ss=-0.5102,  loss_tt=-1.3720,  loss_st1=-1.3078, loss_st2=-1.3103
12-04 18:55:12 | (Train) | Epoch=4, batch=1400 loss=-1.2293, loss_ss=-0.5120,  loss_tt=-1.3660,  loss_st1=-1.2995, loss_st2=-1.3043
12-04 18:58:17 | (Train) | Epoch=4, batch=1600 loss=-1.2322, loss_ss=-0.5125,  l

In [10]:
torch.save(model.model.state_dict(), os.path.join("./clm_sf.pt"))

In [11]:
model.load_model("clm_sf.pt")

In [12]:
z = model.load_emb()

In [14]:
from sklearn import model_selection
from sklearn import linear_model
from sklearn import metrics
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.metrics import make_scorer

# n2v = models[-1]
idxs = np.arange(len(network.line_graph.nodes))
train_idx, test_idx = model_selection.train_test_split(idxs, test_size=0.2, random_state=69)
y = np.array([network.gdf_edges.loc[n]["highway_enc"] for n in network.line_graph.nodes])

# for m, e in models:
    # m.train(epochs=e)
    # zn = m.load_emb()
    # zcn = np.concatenate((zn, z2), axis=1)
    # zct = np.concatenate((zn, z3), axis=1)
    # zcnn = np.concatenate((zn, z4), axis=1)
    # zctn = np.concatenate((zn, z5), axis=1)
    # X = z # embedding for each node
eva = [z] # gtc.load_emb(), gae_emb, rand_emb
for X in eva:
    X_train, X_test, y_train, y_test = X[train_idx], X[test_idx], y[train_idx], y[test_idx]

    lm = linear_model.LogisticRegression(multi_class="multinomial", max_iter=1000)
    # lm.fit(X_train, y_train)
    scorer = make_scorer(metrics.f1_score, average="macro")
    print(np.mean(cross_val_score(estimator=lm, X=X, y=y, scoring=scorer, cv=5)))
    #print(metrics.classification_report(y_test, lm.predict(X_test)))

0.341889072799386
