## Construct line graph from toy example
- Check python kernel and import graph construction libraries

In [1]:
from src.graph import SNAPSHOTS
import networkx as nx
import math
import csv

In [2]:
WEATHER_FEATURES = ['wind', 'wind_max', 'temperature', 'rain', 'rain_duration', 'fog', 'snow', 'thunder', 'ice']
NODE_TYPES = ['TRAIN', 'WEATHER']
EDGE_TYPES = ['TRAIN', 'WEATHER', 'DISRUPTION']

In [3]:
G = SNAPSHOTS[0]

In [4]:
def build_minimal_graph(graph):
    H = nx.Graph()
    H.graph = graph.graph.copy()

    for node, attrs in graph.nodes(data=True):
        if attrs.get("type") != "TRAIN":
            continue

        ws_code = attrs.get("weather_station")
        ws_attrs = graph.nodes[ws_code]
        
        node_features = {}
        for feat in WEATHER_FEATURES:
            node_features[feat] = float(ws_attrs.get(feat, 0.0))
        H.add_node(node, **node_features)

    for u, v, eattrs in graph.edges(data=True):
        if eattrs.get("type") != "WEATHER":
            duration = eattrs.get("duration", 0)
            H.add_edge(u, v, duration=duration)
    
    return H

In [5]:
def make_linegraph(G):
    LG = nx.line_graph(G)
    LG.graph = G.graph.copy()
    for (u,v) in LG.nodes:
        LG.nodes[u,v]['duration'] = G.edges[u,v]['duration']
    for node_u, node_v in LG.edges:
        s = (set(node_u) & set(node_v)).pop()
        for feat in WEATHER_FEATURES:
            LG.edges[node_u, node_v][feat] = G.nodes[s][feat]
    return LG

In [6]:
LG = make_linegraph(build_minimal_graph(G))
print(LG)

Graph with 431 nodes and 581 edges


## GCN + RNN


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import random, numpy as np
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

  from .autonotebook import tqdm as notebook_tqdm


- Data preparation for GCN + RNN
- Mean temperature as feature vector

In [8]:
node2idx = {n:i for i,n in enumerate(LG.nodes())}
idx2node = {i:n for i,n in enumerate(LG.nodes())}
print(node2idx)
print(idx2node)

{('NSCH', 'WS'): 0, ('DZ', 'DZW'): 1, ('HDR', 'HDRZ'): 2, ('EEM', 'RD'): 3, ('EDN', 'MTR'): 4, ('EMN', 'EMNZ'): 5, ('BKF', 'EKZ'): 6, ('EGHM', 'LG'): 7, ('ESE', 'GBR'): 8, ('HLG', 'HLGH'): 9, ('KPN', 'ZLSH'): 10, ('CVM', 'KRD'): 11, ('HGLO', 'ODZ'): 12, ('RHN', 'VNDC'): 13, ('KMW', 'STV'): 14, ('UTM', 'UTO'): 15, ('VDM', 'ZB'): 16, ('VS', 'VSS'): 17, ('OVN', 'ZVT'): 18, ('ATN', 'VSV'): 19, ('TBG', 'VSV'): 20, ('BGN', 'RB'): 21, ('KBD', 'RB'): 22, ('DL', 'ZL'): 23, ('WH', 'ZL'): 24, ('GN', 'GNN'): 25, ('GNN', 'SWD'): 26, ('AC', 'ASHD'): 27, ('ASB', 'ASHD'): 28, ('HGZ', 'MTH'): 29, ('HGZ', 'ZB'): 30, ('HVSP', 'HOR'): 31, ('HOR', 'UTO'): 32, ('BTL', 'OT'): 33, ('BTL', 'VG'): 34, ('BLL', 'HLM'): 35, ('HLM', 'HLMS'): 36, ('DT', 'DTCP'): 37, ('DTCP', 'SDM'): 38, ('PMR', 'PMW'): 39, ('PMW', 'ZDK'): 40, ('HNP', 'WK'): 41, ('IJT', 'WK'): 42, ('BHV', 'UTO'): 43, ('UT', 'UTO'): 44, ('BMR', 'CK'): 45, ('CK', 'MMLH'): 46, ('HT', 'HTO'): 47, ('HT', 'TB'): 48, ('BSMZ', 'HVSM'): 49, ('HVS', 'HVSM'): 5

In [9]:
incident_edges = {n: [] for n in LG.nodes()}
for u, v in LG.edges():
    incident_edges[u].append((u, v))
    incident_edges[v].append((u, v))

In [10]:
def linegraph_to_pyg(LG, default_edge_weight = 1.0):
    X = []
    Y = []
    for node in node2idx:
        feature_vector = []
        for feat in WEATHER_FEATURES:
            mean_value = 0
            for (u, v) in incident_edges[node]:
                mean_value = mean_value + LG.edges[(u,v)][feat]
            mean_value = mean_value / len(incident_edges[node])
            feature_vector.append(mean_value)
        X.append(feature_vector)
        Y.append(LG.nodes[node]["duration"])
    X = torch.tensor(X, dtype=torch.float32)
    Y = torch.tensor(Y, dtype=torch.float32)

    return X, Y

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [12]:
X_seq = []
Y_seq = []
edge_index = []
edge_weights = []

for t in SNAPSHOTS:
    graph = SNAPSHOTS[t]
    G = build_minimal_graph(graph)
    LG = make_linegraph(G)
    X, Y = linegraph_to_pyg(LG)
    X_seq.append(X)
    Y_seq.append(Y)
X_seq = torch.stack(X_seq)
Y_seq = torch.stack(Y_seq)
print(X_seq.shape)
print(Y_seq.shape)

torch.Size([744, 431, 9])
torch.Size([744, 431])


In [13]:
default_edge_weight = 1
idx_pairs = []
edge_weights = []
for (u, v) in LG.edges():
    idx_pairs.append((node2idx[u], node2idx[v]))
    edge_weights.append(default_edge_weight)
    idx_pairs.append((node2idx[v], node2idx[u]))
    edge_weights.append(default_edge_weight)
edge_index = torch.tensor(idx_pairs, dtype=torch.long).t().contiguous()
edge_weights = torch.full((edge_index.size(1),), 1.0, dtype=torch.float32)
edge_index = edge_index.to(dtype=torch.long, device=device)
edge_weights = edge_weights.to(dtype=torch.float32, device=device)
print(edge_index.shape)
print(edge_weights.shape)

torch.Size([2, 1162])
torch.Size([1162])


In [14]:
def split_dataset(X_seq, Y_seq, train_ratio=0.6, val_ratio=0.2):
    total = len(X_seq)
    n_train = math.floor(total * train_ratio)
    n_val   = math.floor(total * val_ratio)

    X_train = X_seq[:n_train]
    Y_train = Y_seq[:n_train]

    X_val   = X_seq[n_train:n_train+n_val]
    Y_val   = Y_seq[n_train:n_train+n_val]

    X_test  = X_seq[n_train+n_val:]
    Y_test  = Y_seq[n_train+n_val:]


    return X_train, Y_train, X_val, Y_val, X_test, Y_test

In [15]:
def slide_window(X_seq, Y_seq, window_sizes):
    T, N, F = X_seq.shape

    X_list = []
    Y_list = []
    for W in window_sizes:
        for t in range(W, T):
            X_window = X_seq[t-W:t]
            Y_next = Y_seq[t]
            X_list.append(X_window)
            Y_list.append(Y_next)

    return X_list, Y_list

In [16]:
X_train, Y_train, X_val, Y_val, X_test, Y_test = split_dataset(X_seq, Y_seq)
print(X_train.shape)
print(Y_train.shape)
print(X_val.shape)
print(Y_val.shape)
print(X_test.shape)
print(Y_test.shape)

torch.Size([446, 431, 9])
torch.Size([446, 431])
torch.Size([148, 431, 9])
torch.Size([148, 431])
torch.Size([150, 431, 9])
torch.Size([150, 431])


In [17]:
window_sizes = [48, 24, 8, 4, 2]
X_train_window, Y_train_window = slide_window(X_train, Y_train, window_sizes)
X_val_window, Y_val_window = slide_window(X_val, Y_val, window_sizes)
X_test_window, Y_test_window = slide_window(X_test, Y_test, window_sizes)
print(len(X_train_window))
print(len(Y_train_window))
print(len(X_val_window))
print(len(Y_val_window))
print(len(X_test_window))
print(len(Y_test_window))

2144
2144
654
654
664
664


- Very simple GCN encoder

In [18]:
class GCNEncoder(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index, edge_weight):
        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_weight)
        return x

- GCN + RNN Class

In [19]:
class GCN_RNN(nn.Module):
    def __init__(self, in_channels, gcn_hidden, rnn_hidden=64):
        super().__init__()
        self.gcn = GCNEncoder(in_channels, gcn_hidden)
        self.gru = nn.GRU(input_size=gcn_hidden, hidden_size=rnn_hidden, num_layers=1)
        self.out = nn.Linear(rnn_hidden, 1)

    def forward(self, X_seq, edge_index, edge_weight):
        # for all graphs_0 ... graph_t
        H_t = []
        for x in X_seq:
            # get the embeddings at t
            h = self.gcn(x, edge_index, edge_weight)
            # stack all the embeddings 
            H_t.append(h)
        H = torch.stack(H_t, dim=0)
        
        # RNN on embedding sequence
        H_gru, _ = self.gru(H)
        # summarized state i.e the last state
        last = H_gru[-1]
        # from embeddings to prediction
        y = self.out(last)
        return y

In [20]:
def train_one_epoch(model, optim, X_windows, Y_windows):
    model.train()

    for X_seq, Y_true in zip(X_windows, Y_windows):

        X_seq = X_seq.to(torch.float32).to(device)
        Y_true = Y_true.to(torch.float32).to(device)

        y_pred = model(X_seq, edge_index, edge_weights)

        loss = F.mse_loss(y_pred, Y_true.unsqueeze(-1))

        optim.zero_grad()
        loss.backward()
        optim.step()

In [21]:
@torch.no_grad()
def evaluate(model, X_windows, Y_windows):
    model.eval()
    eval_loss = 0.0

    for X_seq, Y_true in zip(X_windows, Y_windows):
        X_seq = X_seq.to(torch.float32).to(device)
        Y_true = Y_true.to(torch.float32).to(device)

        y_pred = model(X_seq, edge_index, edge_weights)  # [N,1]
        loss = F.mse_loss(y_pred, Y_true.unsqueeze(-1))
        eval_loss += loss.item()

    avg_loss = eval_loss / len(X_windows)
    return avg_loss

In [22]:
def run_training_for_config(in_channels, gcn_hidden, rnn_hidden, lr, num_epochs,):
    model = GCN_RNN(in_channels, gcn_hidden, rnn_hidden).to(device)

    optim = torch.optim.Adam(model.parameters(), lr=lr)

    best_val = float("inf")
    best_state = None
    best_epochs = 0

    for epoch in range(num_epochs):
        train_one_epoch(model, optim, X_train_window, Y_train_window,)
        val_loss = evaluate(model, X_val_window, Y_val_window)

        # keep best
        if val_loss < best_val:
            best_val = val_loss
            best_epochs = epoch

    test_loss = evaluate(model, X_test_window, Y_test_window)

    return {
        "best_epoch": best_epochs,
        "best_val_loss": best_val,
        "test_loss": test_loss,
        "gcn_hidden": gcn_hidden,
        "rnn_hidden": rnn_hidden,
        "lr": lr,
    }

In [None]:
in_channels = len(WEATHER_FEATURES)

configs = [
    # (gcn_hidden, rnn_hidden, lr)
    (32, 32, 1e-3),
    (64, 64, 1e-3),
    (128, 128, 1e-3),
    (32, 32, 5e-3),
    (64, 64, 5e-3),
    (128, 128, 5e-3),
    (32, 32, 1e-2),
    (64, 64, 1e-2),
    (128, 128, 1e-2),
]

results = []
for gcn_h, rnn_h, lr in configs:
    out = run_training_for_config(len(WEATHER_FEATURES), gcn_hidden=gcn_h, rnn_hidden=rnn_h, lr=lr, num_epochs=50)
    results.append(out)
    print(gcn_h, rnn_h, lr, "Finished")
    print(out)

results = sorted(results, key=lambda r: r["best_val_loss"])

fields = ["gcn_hidden", "rnn_hidden", "lr", "best_val_loss", "test_loss", "best_epoch"]

csv_path = "experiment_results.csv"

with open(csv_path, mode="w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=fields)
    writer.writeheader()
    for r in results:
        row = {k: r[k] for k in fields}
        writer.writerow(row)

print(f"Wrote {len(results)} rows to {csv_path}")

best = results[0]
print("Best config:", {k: best[k] for k in fields})

KeyboardInterrupt: 