In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import os
import pandas as pd
import torch.nn as nn
from torch_geometric.data import Data
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split, Dataset

from tqdm import tqdm

from utils.london.link_loads import get_graph_attributes, df_to_graph, build_quarter_hour_data, add_missing_nodes
# import networkx as nx

In [3]:
folder_path = "data/london/"
# Get graph attributes, create dfs from csv, process Link column and ordered dfs by time
num_nodes, edge_index, node_mapping, dfs = get_graph_attributes(folder_path)

In [4]:
graph_data = []
# each df should have the same dimension and same nodes at the same columns
for filename, df in dfs.items():
    df = add_missing_nodes(df, node_mapping, num_nodes) # add zeros row for missing nodes
    df_qhrs = build_quarter_hour_data(df, filename, num_nodes) # retourne 24*4 df avec ses paramètres temporel et le flow
    graph_data.extend(df_qhrs)
    
graphs = [df_to_graph(df, edge_index) for df in graph_data]  # Un graphe par quart d'heure

In [5]:
class GraphSequenceDataset(Dataset):
    def __init__(self, graphs, window_size=4):
        self.graphs = graphs
        self.window_size = window_size
        if len(graphs) <= window_size:
            raise ValueError("Error : window_size should be higher than the number of graphs")
        
    def __len__(self):
        return len(self.graphs) - self.window_size

    def __getitem__(self, idx):
        # input_graphs prend en compte les window_size graphe précédent l'instant idx + le graphe à prèdire à l'instant idx
        input_graphs = self.graphs[idx : idx + self.window_size + 1]
        target = self.graphs[idx + self.window_size].y  # On prédit y du dernier graph
        return input_graphs, target 

In [6]:
# Définition des loaders
window_size = 4  # Nombre de pas de temps en entrée
train_size = int(0.8 * len(graphs))

train_dataset = GraphSequenceDataset(graphs[:train_size], window_size)  
test_dataset = GraphSequenceDataset(graphs[train_size:], window_size)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [14]:
class TGCN(nn.Module):
    def __init__(self, node_features, hidden_dim, gru_hidden_dim):
        super(TGCN, self).__init__()
        self.gcn = GCNConv(node_features, hidden_dim)  # GCN
        self.gru = nn.GRU(hidden_dim, gru_hidden_dim, batch_first=True)  # GRU
        self.fc = nn.Linear(gru_hidden_dim, 1)  # Prédiction finale

    def forward(self, graph_seq):
        # window_size = len(graph_seq)  # Nombre de pas de temps
        # batch_size = graph_seq[0].x.shape[0]  # Nombre de nœuds

        spatial_features = []
        for graph in graph_seq:
            x = self.gcn(graph.x, graph.edge_index)  # GCN
            x = F.relu(x)
            spatial_features.append(x)

        spatial_features = torch.stack(spatial_features, dim=1)  # (batch, time, hidden_dim)

        _ , final_state = self.gru(spatial_features)  # prédiction sur le dernier (1, gru_hidden_dim)
        final_state = final_state.squeeze() # (gru_hidden_dim)
        final_out = self.fc(final_state) # Prédiction sur le dernier état
        final_out = F.relu(final_out)
        return final_out

In [15]:
class CustomMAELoss(nn.Module):
    def __init__(self):
        super(CustomMAELoss, self).__init__()

    def forward(self, pred, target):
        return torch.mean(torch.abs(pred - target))

# Instanciation de la loss
MAE = CustomMAELoss()

In [None]:
model = TGCN(node_features=10, hidden_dim=32, gru_hidden_dim=64)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = MAE

# Boucle d'entraînement
for epoch in tqdm(range(50)):
    model.train()
    total_loss = 0
    for graph_seq, target in train_loader:
        optimizer.zero_grad()
        output = model(graph_seq)

        target = target.squeeze()
        output = output.reshape(target.shape)


        print(output.shape)
        print(target.shape)
    
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")


  2%|▏         | 1/50 [01:30<1:13:54, 90.51s/it]

Epoch 1, Loss: 243.2198999875875


  4%|▍         | 2/50 [03:00<1:12:21, 90.45s/it]

Epoch 2, Loss: 228.43867672306067


  6%|▌         | 3/50 [04:25<1:08:55, 87.98s/it]

Epoch 3, Loss: 216.81931867135106


  8%|▊         | 4/50 [05:54<1:07:29, 88.03s/it]

Epoch 4, Loss: 212.18706189455196


 10%|█         | 5/50 [07:24<1:06:35, 88.78s/it]

Epoch 5, Loss: 210.82059863291335


 12%|█▏        | 6/50 [08:49<1:04:13, 87.58s/it]

Epoch 6, Loss: 209.96486645076547


 14%|█▍        | 7/50 [10:13<1:01:55, 86.41s/it]

Epoch 7, Loss: 210.25340612492244


 16%|█▌        | 8/50 [11:37<1:00:01, 85.75s/it]

Epoch 8, Loss: 209.63629927232734


 18%|█▊        | 9/50 [13:03<58:29, 85.61s/it]  

Epoch 9, Loss: 208.92310597453948


 20%|██        | 10/50 [14:27<56:44, 85.10s/it]

Epoch 10, Loss: 208.43291046971544


 22%|██▏       | 11/50 [15:54<55:50, 85.92s/it]

Epoch 11, Loss: 208.26697427719697


 24%|██▍       | 12/50 [17:22<54:40, 86.34s/it]

Epoch 12, Loss: 208.11164522180948


 26%|██▌       | 13/50 [18:53<54:09, 87.84s/it]

Epoch 13, Loss: 208.49031959404437


 28%|██▊       | 14/50 [20:24<53:17, 88.83s/it]

Epoch 14, Loss: 207.6980908533258


 30%|███       | 15/50 [21:54<51:57, 89.08s/it]

Epoch 15, Loss: 207.68055955245177


 32%|███▏      | 16/50 [23:24<50:38, 89.37s/it]

Epoch 16, Loss: 207.32717494064786


 34%|███▍      | 17/50 [24:56<49:41, 90.36s/it]

Epoch 17, Loss: 207.18748548426643


 36%|███▌      | 18/50 [26:22<47:27, 88.99s/it]

Epoch 18, Loss: 206.93548003119946


 38%|███▊      | 19/50 [27:48<45:31, 88.13s/it]

Epoch 19, Loss: 206.60354017842235


 40%|████      | 20/50 [29:17<44:09, 88.32s/it]

Epoch 20, Loss: 206.60852854630159


 42%|████▏     | 21/50 [30:45<42:41, 88.34s/it]

Epoch 21, Loss: 206.85906348294756


 44%|████▍     | 22/50 [32:10<40:38, 87.10s/it]

Epoch 22, Loss: 206.7270019196338


 46%|████▌     | 23/50 [33:34<38:53, 86.42s/it]

Epoch 23, Loss: 206.64893615027339


 48%|████▊     | 24/50 [35:06<38:08, 88.01s/it]

Epoch 24, Loss: 206.66083501331198


 50%|█████     | 25/50 [36:34<36:36, 87.86s/it]

Epoch 25, Loss: 206.60395793384322


 52%|█████▏    | 26/50 [38:02<35:14, 88.10s/it]

Epoch 26, Loss: 206.47932126548517


 54%|█████▍    | 27/50 [39:30<33:43, 88.00s/it]

Epoch 27, Loss: 207.04883310574408


 56%|█████▌    | 28/50 [41:03<32:45, 89.35s/it]

Epoch 28, Loss: 206.89123925752767


 58%|█████▊    | 29/50 [42:36<31:41, 90.54s/it]

Epoch 29, Loss: 206.69313646626594


 60%|██████    | 30/50 [44:04<29:53, 89.66s/it]

Epoch 30, Loss: 206.49768830186798


 62%|██████▏   | 31/50 [44:59<25:06, 79.26s/it]

Epoch 31, Loss: 206.46059983913057


 64%|██████▍   | 32/50 [45:43<20:41, 68.96s/it]

Epoch 32, Loss: 206.38410073137473


 66%|██████▌   | 33/50 [46:28<17:29, 61.72s/it]

Epoch 33, Loss: 206.80647178334033


 68%|██████▊   | 34/50 [47:16<15:21, 57.60s/it]

Epoch 34, Loss: 206.6641405279979


 70%|███████   | 35/50 [48:04<13:41, 54.74s/it]

Epoch 35, Loss: 206.62936461987726


 72%|███████▏  | 36/50 [48:52<12:17, 52.71s/it]

Epoch 36, Loss: 206.81808008940828


 74%|███████▍  | 37/50 [49:42<11:13, 51.80s/it]

Epoch 37, Loss: 206.67207974927078


 76%|███████▌  | 38/50 [50:34<10:22, 51.89s/it]

Epoch 38, Loss: 206.42452382059923


 78%|███████▊  | 39/50 [51:25<09:26, 51.51s/it]

Epoch 39, Loss: 206.31405715062007


 80%|████████  | 40/50 [52:15<08:30, 51.04s/it]

Epoch 40, Loss: 206.36833346556037


 82%|████████▏ | 41/50 [53:04<07:35, 50.66s/it]

Epoch 41, Loss: 206.36231648722926


 82%|████████▏ | 41/50 [53:22<11:42, 78.10s/it]


KeyboardInterrupt: 

In [19]:
model.eval()
test_loss = 0
with torch.no_grad():
    for graph_seq, target in test_loader:

        output = model(graph_seq)
        loss = criterion(output, target)
        test_loss += loss.item()

print(f"Test MAE: {test_loss / len(test_loader)}")


Test MAE: 204.33262727453592


In [22]:
39e6/(96*1206)

336.8573797678275

The average flow per inter_station per 15min is 337. So 204 MAE error is not a valid result.