# Import libraries and define globals

In [None]:
%pip install pyrosm tqdm folium
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

In [None]:
import torch
!pip install -q torch-scatter~=2.1.0 torch-sparse~=0.6.16 torch-cluster~=1.6.0 torch-spline-conv~=1.2.1 torch-geometric==2.2.0 -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install -q torch-geometric-temporal==0.54.0

# Fix PyTorch's CPU and CUDA random seeds and force deterministic cuDNN kernels so that
# repeated runs produce comparable results.
# NOTE(review): Python's `random` and NumPy are not seeded in this cell.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
import math
import os
import pickle
import random

import boto3
import folium
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import pyrosm
from sklearn.metrics import mean_squared_error
from tqdm import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv
from torch_geometric_temporal.nn.recurrent import A3TGCN
from torch_geometric_temporal.signal import StaticGraphTemporalSignal, temporal_signal_split

In [None]:
import warnings

# Suppress every RuntimeWarning for the rest of the session.
# NOTE(review): this also hides genuine numerical issues (e.g. NumPy's mean-of-empty-slice
# warning) - confirm that is intended.
warnings.filterwarnings("ignore", category=RuntimeWarning)

In [None]:
# --- Data location ------------------------------------------------------------
CITY_ID = 1_000_000
MAP_FILE = f"{CITY_ID}-latest.osm.pbf"
S3 = boto3.client("s3")
S3_BUCKET = "some_bucket"
S3_SUBDIR = "subdir_path"
S3_DATA = "data_path"
S3_PREDS = f"{S3_SUBDIR}/model_preds"
S3_FILENAME = "edge_time_aggregated_4_lags.parquet"

# --- Dataset layout -----------------------------------------------------------
LABEL = "speed_kmh"
DATA_SPLITS = ["train", "valid", "test"]
N_WEEKS = 4
N_WEEKS_TRAINING = 2
N_WEEKS_VALIDATION = 1
# Fraction of the snapshots that goes to the training split.
TRAIN_RATIO = N_WEEKS_TRAINING / N_WEEKS

# --- Model and optimisation hyper-parameters (later cells override some) ------
EPOCHS = 100
HEADS = 8
DROPOUT = 0
LEARNING_RATE = 1e-3
HIDDEN_CHANNELS = 8
OUT_CHANNELS = 1

# --- Training loop behaviour --------------------------------------------------
LOG_FREQ = 1  # print metrics every LOG_FREQ epochs
EARLY_STOP_THRESHOLD = 5  # epochs without improvement tolerated before stopping

In [None]:
# Fetch and load the pickled collection of road edges; each entry becomes one row/column of
# the adjacency matrix built by compute_adjacency_matrix().
# NOTE(review): pickle.load can execute arbitrary code from the file - only safe while the
# bucket contents are trusted.
S3.download_file(S3_BUCKET, f"{S3_SUBDIR}/unique_edges.pickle", "unique_edges.pickle")
with open("unique_edges.pickle", "rb") as f:
    UNIQUE_EDGES = pickle.load(f)
len(UNIQUE_EDGES)

In [None]:
# MODEL_NAME is assigned per experiment further down; the old template is kept for reference.
# MODEL_NAME = f"gnn_2_gats_{EPOCHS}_hidden_channels_{HIDDEN_CHANNELS}_epochs_{len(UNIQUE_EDGES)}_edges_{N_WEEKS}_weeks"
# File stem of the cached, imputed dataset; it encodes the edge count and the number of weeks.
GNN_DATASET_NAME = f"gnn_dataset_{len(UNIQUE_EDGES)}_edges_{N_WEEKS}_weeks_normalised"

In [None]:
def compute_adjacency_matrix():
    """Build the edge-to-edge adjacency matrix of the road graph.

    Two entries of ``UNIQUE_EDGES`` are adjacent when they share at least one
    endpoint. Every edge therefore is adjacent to itself, so the diagonal is 1.

    Returns:
        tuple: ``(adjacency_matrix, edge_index)`` where ``adjacency_matrix`` is a
        float32 array of shape ``(n_edges, n_edges)`` holding 0/1, and
        ``edge_index`` is the ``(rows, cols)`` tuple returned by ``nonzero()``.
    """
    from collections import defaultdict

    n_edges = len(UNIQUE_EDGES)
    adjacency_matrix = np.zeros((n_edges, n_edges), dtype=np.float32)

    # Index edges by endpoint so only edges that actually share a node are
    # paired, instead of testing all n_edges**2 combinations.
    edges_by_node = defaultdict(list)
    for idx, edge in enumerate(UNIQUE_EDGES):
        for node in set(edge):
            edges_by_node[node].append(idx)

    # All edges incident to the same node are mutually adjacent.
    for incident in edges_by_node.values():
        adjacency_matrix[np.ix_(incident, incident)] = 1

    edge_index = (adjacency_matrix > 0).nonzero()
    return adjacency_matrix, edge_index

# Data imputation methods

In [None]:
def fallback_to_past(edge, minute_bucket, fallback_horizon, unit='m'):
    """Return the speed stored for ``edge`` ``fallback_horizon`` ``unit``s before ``minute_bucket``.

    The value comes from ``DATASET_DICT``; ``None`` is returned when that
    (edge, timestamp) pair has no entry.
    """
    lookback = pd.Timedelta(fallback_horizon, unit=unit)
    past_bucket = minute_bucket - lookback
    return DATASET_DICT.get((edge, past_bucket))


def neighbour_average(edge, minute_bucket):
    """Return the mean recorded speed of the edges adjacent to ``edge`` at ``minute_bucket``.

    Adjacency is read from the row of ``ADJACENCY_MATRIX`` belonging to ``edge``;
    that row also marks ``edge`` itself, so its own value takes part when present.
    Edges without a value (missing key or NaN) at ``minute_bucket`` are skipped.

    Returns:
        The mean of the available speeds, or ``math.nan`` when none is available.
    """
    neighbour_indices = np.nonzero(ADJACENCY_MATRIX[EDGE_IDX_MAP[edge]])[0]
    neighbour_speeds = []
    for idx in neighbour_indices:
        # Query the neighbour's own measurement; looking up `edge` here would
        # just repeat the query edge's value once per neighbour.
        neighbour = IDX_EDGE_MAP[int(idx)]
        speed = DATASET_DICT.get((neighbour, minute_bucket))
        if speed is None or math.isnan(speed):
            continue
        neighbour_speeds.append(speed)
    if not neighbour_speeds:
        # Explicit NaN instead of relying on np.mean([]) and its RuntimeWarning.
        return math.nan
    return np.mean(neighbour_speeds)


def expand_edge_time_series(edge_df):
    """Re-index one edge's rows onto every timestamp of ``DATASET_RANGE_DF``.

    Timestamps with no row for the edge appear with missing values; the ``edge``
    column is then filled forward and backward so every row carries the edge id.
    Returns the frame with ``minute_bucket`` back as a regular column.
    """
    by_bucket = edge_df.reset_index().set_index("minute_bucket")
    # Right join: keep every timestamp of the full range.
    expanded = by_bucket.join(DATASET_RANGE_DF, how="right", lsuffix='l')
    # Drop the two helper index columns produced by the reset_index calls.
    expanded = expanded.drop(["index", "indexl"], axis=1)
    expanded["edge"] = expanded.edge.ffill().bfill()
    return expanded.reset_index()
    

def compute_rolling_mean(speeds_df, window):
    """Compute a per-edge rolling mean of ``speed_kmh`` over a time window.

    Args:
        speeds_df: frame with at least ``edge``, ``minute_bucket`` and ``speed_kmh``.
        window: rolling window passed to pandas ``rolling`` (e.g. ``"1h"``).

    Returns:
        dict mapping ``(edge, minute_bucket)`` to the rolling mean; entries whose
        mean is NaN are omitted.
    """
    # Use the frame that was passed in rather than the notebook-global one.
    per_edge_frames = [
        expand_edge_time_series(edge_rows)
        for _, edge_rows in speeds_df[["edge", "minute_bucket", "speed_kmh"]].groupby("edge")
    ]
    rolling_window_speed_avg_df = (
        pd.concat(per_edge_frames)
        .set_index("minute_bucket")
        .groupby("edge")
        .rolling(window)
        .mean()
        .dropna()
    )
    return rolling_window_speed_avg_df.to_dict()["speed_kmh"]


def impute_nan(edge, minute_bucket):
    """Impute a missing speed for ``edge`` at ``minute_bucket``.

    Fallbacks are tried in order and the first usable value (neither ``None``
    nor NaN) is returned:
        1. Same edge, same time 1 week ago
        2. Same edge, same time 2 weeks ago
        3. Neighbour average 1 week ago
        4. Neighbour average 2 weeks ago
        5. Same edge 15 minutes ago
        6. Neighbour average 15 minutes ago
        7. Rolling mean of the edge over 1h, 2h, 3h, then 4h windows
        8. Expanding mean of the edge up to ``minute_bucket``
        9. Training mean of the edge for this (weekday, hour, minute) slot
        10. Training mean of the edge
        11. Global training mean speed
    """
    def is_missing(value):
        return value is None or math.isnan(value)

    one_week = pd.Timedelta(1, unit='W')
    two_weeks = pd.Timedelta(2, unit='W')
    quarter_hour = pd.Timedelta(15, unit='m')
    # EDGE_15_MIN_BUCKET_DICT is keyed by (edge, weekday, hour, minute), not by timestamp.
    slot_key = (edge, minute_bucket.weekday(), minute_bucket.hour, minute_bucket.minute)

    # Lazily evaluated so later (more expensive or coarser) fallbacks only run when needed.
    fallbacks = (
        lambda: fallback_to_past(edge, minute_bucket, 1, 'W'),
        lambda: fallback_to_past(edge, minute_bucket, 2, 'W'),
        lambda: neighbour_average(edge, minute_bucket - one_week),
        lambda: neighbour_average(edge, minute_bucket - two_weeks),
        lambda: fallback_to_past(edge, minute_bucket, 15, 'm'),
        lambda: neighbour_average(edge, minute_bucket - quarter_hour),
        lambda: ROLLING_1H_WINDOW_EDGE_TIME_AVG_DICT.get((edge, minute_bucket)),
        lambda: ROLLING_2H_WINDOW_EDGE_TIME_AVG_DICT.get((edge, minute_bucket)),
        lambda: ROLLING_3H_WINDOW_EDGE_TIME_AVG_DICT.get((edge, minute_bucket)),
        lambda: ROLLING_4H_WINDOW_EDGE_TIME_AVG_DICT.get((edge, minute_bucket)),
        lambda: ROLLING_EDGE_TIME_AVG_DICT.get((edge, minute_bucket)),
        lambda: EDGE_15_MIN_BUCKET_DICT.get(slot_key),
        # EDGE_AVG_DICT is keyed by edge alone.
        lambda: EDGE_AVG_DICT.get(edge),
    )
    for fallback in fallbacks:
        speed = fallback()
        if not is_missing(speed):
            return speed
    return MEAN_SPEED


def impute_dataset(speeds_df, imputation_method):
    """Turn a speeds frame into per-timestamp snapshots with all gaps filled.

    Rows are grouped by ``minute_bucket``. For every group and every edge in
    ``UNIQUE_EDGES`` the four lagged speeds (15/30/45/60 minutes) become the
    feature row and ``speed_kmh`` becomes the target. Missing values are filled
    with ``imputation_method(edge, timestamp)``.

    Returns:
        tuple: ``(xs, ys, target_mask)`` - float32 features, float32 targets, and
        an int mask that is 0 where the target was imputed and 1 elsewhere.
    """
    def is_missing(value):
        return value is None or math.isnan(value)

    lag_minutes = [15, 30, 45, 60]
    wanted_columns = ["edge", "speed_kmh"] + SPEED_FEATURES
    target_mask = np.ones((len(DATASET_DATE_RANGE), len(UNIQUE_EDGES)), dtype=int)
    imputed_features = 0
    imputed_targets = 0
    feature_snapshots = []
    target_snapshots = []

    buckets = tqdm(speeds_df.groupby("minute_bucket"))
    for snapshot_idx, (minute_bucket, bucket_rows) in enumerate(buckets):
        # {column -> {edge -> value}} for the rows observed in this bucket.
        per_edge = bucket_rows[wanted_columns].set_index("edge").to_dict()
        lagged_buckets = [(lag, minute_bucket - pd.to_timedelta(lag, unit='m')) for lag in lag_minutes]
        # NOTE(review): targets are imputed as of the *following* bucket - confirm this offset is intended.
        following_bucket = minute_bucket + pd.to_timedelta(15, unit='m')

        snapshot_features = []
        snapshot_targets = []
        for edge_pos, edge in enumerate(UNIQUE_EDGES):
            lag_speeds = []
            for lag, lag_bucket in lagged_buckets:
                value = per_edge[f"speed_kmh_lag_{lag}_m"].get(edge)
                if is_missing(value):
                    value = imputation_method(edge, lag_bucket)
                    imputed_features += 1
                lag_speeds.append(value)
            snapshot_features.append(lag_speeds)

            target = per_edge["speed_kmh"].get(edge)
            if is_missing(target):
                # TODO: not the most efficient way of skipping unpopular segments
                # These are the segments that linear regression couldn't be trained on due to insufficient amount of data
                target = imputation_method(edge, following_bucket)
                imputed_targets += 1
                target_mask[snapshot_idx, edge_pos] = 0
            snapshot_targets.append(target)

        feature_snapshots.append(snapshot_features)
        target_snapshots.append(snapshot_targets)

    xs = np.array(feature_snapshots, dtype=np.float32)
    ys = np.array(target_snapshots, dtype=np.float32)

    print(f"Feature imputation count: {imputed_features}")
    print(f"Target imputation count: {imputed_targets}")
    print(f"Total number of values: {len(UNIQUE_EDGES) * len(DATASET_DATE_RANGE) * 5}")
    print()

    return xs, ys, target_mask

# Baselines

In [None]:
def evaluate_global_mean_baseline(dataset):
    """Score the "always predict MEAN_SPEED" baseline.

    For each snapshot the masked squared and absolute errors are summed and
    divided by the number of unmasked targets; the per-snapshot scores are then
    averaged over ``dataset.snapshot_count``.

    Returns:
        tuple: ``(mse, mae)``.
    """
    total_mse = 0
    total_mae = 0
    for snapshot in dataset:
        observed = snapshot.mask.sum()
        # Multiplying by the mask zeroes the error of imputed targets.
        masked_error = (MEAN_SPEED - snapshot.y) * snapshot.mask
        total_mse += (masked_error**2).sum() / observed
        total_mae += (np.abs(masked_error)).sum() / observed
    n_snapshots = dataset.snapshot_count
    return total_mse / n_snapshots, total_mae / n_snapshots


def evaluate_edge_average_baseline(dataset):
    """Score the baseline that predicts each edge's training mean speed.

    Edges missing from ``EDGE_AVG_DICT`` fall back to ``MEAN_SPEED``. Errors of
    masked-out targets are excluded; each snapshot's error is divided by its
    number of unmasked targets and the results are averaged over
    ``dataset.snapshot_count``.

    Returns:
        tuple: ``(mse, mae)``.
    """
    # The prediction depends only on the edge, so resolve it once rather than
    # twice per edge for every snapshot.
    edge_predictions = [EDGE_AVG_DICT.get(edge, MEAN_SPEED) for edge in UNIQUE_EDGES]

    mse = 0
    mae = 0
    for snapshot in dataset:
        observed = snapshot.mask.sum()
        snapshot_mse = 0
        snapshot_mae = 0
        for i, prediction in enumerate(edge_predictions):
            error = prediction - snapshot.y[i]
            snapshot_mse += snapshot.mask[i] * error**2
            snapshot_mae += snapshot.mask[i] * np.abs(error)
        mse += snapshot_mse / observed
        mae += snapshot_mae / observed
    mse /= dataset.snapshot_count
    mae /= dataset.snapshot_count
    return mse, mae


def edge_time_naive(edge, timestamp):
    """Predict the edge's training mean for the (weekday, hour, minute) slot of ``timestamp``.

    Falls back to the edge's overall training mean, and then to ``MEAN_SPEED``,
    when no slot-specific value exists.
    """
    slot_key = (edge, timestamp.weekday(), timestamp.hour, timestamp.minute)
    edge_fallback = EDGE_AVG_DICT.get(edge, MEAN_SPEED)
    return EDGE_15_MIN_BUCKET_DICT.get(slot_key, edge_fallback)


def rolling_edge_time_avg_naive(edge, minute_bucket):
    """Predict the edge's expanding mean at ``minute_bucket``.

    When the edge has no entry, fall back to the all-edge mean of the previous
    15-minute bucket, and finally to ``MEAN_SPEED``.
    """
    previous_bucket = minute_bucket - pd.Timedelta(15, unit='m')
    network_fallback = MINUTE_BUCKET_AVG_DICT.get((previous_bucket), MEAN_SPEED)
    return ROLLING_EDGE_TIME_AVG_DICT.get((edge, minute_bucket), network_fallback)


def evaluate_edge_time_average_baseline(dataset, date_range, naive):
    """Score a time-aware naive predictor against the masked targets.

    Args:
        dataset: iterable of snapshots exposing ``y``, ``mask`` and, on the
            dataset itself, ``snapshot_count``.
        date_range: timestamps paired positionally with the snapshots.
        naive: callable ``naive(edge, timestamp)`` returning the predicted speed.

    Returns:
        tuple: ``(mse, mae)`` - per-snapshot masked errors averaged over
        ``dataset.snapshot_count``.

    Note:
        ``zip`` stops at the shorter of ``date_range`` and ``dataset``, while the
        final average still divides by ``dataset.snapshot_count``.
    """
    mse = 0
    mae = 0
    for timestamp, snapshot in zip(date_range, dataset):
        observed = snapshot.mask.sum()
        snapshot_mse = 0
        snapshot_mae = 0
        for i, edge in enumerate(UNIQUE_EDGES):
            # One prediction per edge feeds both metrics.
            error = naive(edge, timestamp) - snapshot.y[i]
            snapshot_mse += snapshot.mask[i] * error**2
            snapshot_mae += snapshot.mask[i] * np.abs(error)
        mse += snapshot_mse / observed
        mae += snapshot_mae / observed
    mse /= dataset.snapshot_count
    mae /= dataset.snapshot_count
    return mse, mae

# GNN training and evaluation code

In [None]:
class TemporalGNN(torch.nn.Module):
    """A3TGCN layer followed by ReLU and a linear read-out.

    ``periods`` is passed to A3TGCN and is also the width of the linear output.
    """

    def __init__(self, dim_in, hidden_channels, periods):
        super().__init__()
        self.tgnn = A3TGCN(in_channels=dim_in, out_channels=hidden_channels, periods=periods)
        self.linear = torch.nn.Linear(in_features=hidden_channels, out_features=periods)

    def forward(self, x, edge_index):
        """Return the read-out for node features ``x`` on the graph ``edge_index``."""
        hidden = torch.relu(self.tgnn(x, edge_index))
        return self.linear(hidden)


def plot_curves(losses):
    """Plot train/validation learning curves: the full history and its second half.

    ``losses`` holds one ``(train, validation)`` pair per epoch.
    """
    n_epochs = len(losses)
    halfway = n_epochs // 2
    fig, (full_ax, tail_ax) = plt.subplots(1, 2, figsize=(20, 5))

    full_ax.plot(range(n_epochs), losses, label=["Train", "Validation"])
    # Plot the second half of the learning curve to avoid the huge spikes at the beginning of training
    tail_ax.plot(range(halfway, n_epochs, 1), losses[halfway:], label=["Train", "Validation"])

    panels = (
        (full_ax, "Learning curves for a simple GNN"),
        (tail_ax, f"Learning curves for a simple GNN starting from epoch {halfway}"),
    )
    for axis, title in panels:
        axis.legend()
        axis.set_xlabel("Epochs")
        axis.set_ylabel("Mean Squared Error")
        axis.set_title(title)
    plt.show()


def _average_masked_losses(model, dataset, mse, mae):
    """Return ``(mse, mae)`` averaged over snapshots, using only unmasked targets."""
    mse_total = 0
    mae_total = 0
    for snapshot in dataset:
        mask = snapshot.mask == 1
        y_pred = model(snapshot.x.unsqueeze(2), snapshot.edge_index).flatten()
        mse_total += mse(y_pred[mask], snapshot.y[mask])
        mae_total += mae(y_pred[mask], snapshot.y[mask])
    return mse_total / dataset.snapshot_count, mae_total / dataset.snapshot_count


def train(train_dataset, valid_dataset, epochs=10):
    """Train a ``TemporalGNN`` with full-batch updates and early stopping.

    Each epoch accumulates the masked MSE over every training snapshot, takes one
    optimiser step on the snapshot-averaged loss, and then evaluates on the
    validation set. The model is saved (via ``save_model``) whenever the
    validation MSE improves; training stops once more than
    ``EARLY_STOP_THRESHOLD`` epochs pass without improvement.

    Args:
        train_dataset: training snapshots.
        valid_dataset: validation snapshots.
        epochs: maximum number of epochs.

    Returns:
        The model in its state after the last executed epoch. The best
        checkpoint is the one written by ``save_model``.
    """
    model = TemporalGNN(len(SPEED_FEATURES), HIDDEN_CHANNELS, OUT_CHANNELS)
    optimiser = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-3)
    mse = torch.nn.MSELoss()
    mae = torch.nn.L1Loss()

    best_val_mse = float("inf")
    best_epoch = -1
    mse_losses = []
    for epoch in range(1, epochs + 1):
        model.train()
        mse_loss, mae_loss = _average_masked_losses(model, train_dataset, mse, mae)
        mse_loss.backward()
        optimiser.step()
        optimiser.zero_grad()

        model.eval()
        # No gradients are needed for validation; skipping them avoids building
        # an autograd graph over every validation snapshot.
        with torch.no_grad():
            val_mse_loss, val_mae_loss = _average_masked_losses(model, valid_dataset, mse, mae)

        if epoch % LOG_FREQ == 0:
            print(f"Epoch {epoch:>2} | Train MSE: {mse_loss:.4f} | Train MAE: {mae_loss:.4f} | Valid MSE: {val_mse_loss:.4f} | Valid MAE: {val_mae_loss:.4f}")

        mse_losses.append((mse_loss.detach().numpy(), val_mse_loss.detach().numpy()))

        if val_mse_loss < best_val_mse:
            # Record the new best score; without this every epoch would count as
            # an improvement and early stopping could never trigger.
            best_val_mse = val_mse_loss.item()
            best_epoch = epoch
            save_model(model, MODEL_NAME)
        elif epoch - best_epoch > EARLY_STOP_THRESHOLD:
            print(f"Early stopped training at epoch {epoch}")
            break

    plot_curves(mse_losses)
    return model

    
def save_model(model, model_name):
    """Write the model's state dict to ``<model_name>.pt`` and upload that file to S3."""
    local_path = f"{model_name}.pt"
    remote_key = f"{S3_SUBDIR}/models/gnn/{model_name}.pt"
    torch.save(model.state_dict(), local_path)
    S3.upload_file(local_path, S3_BUCKET, remote_key)


def inference(model, test_dataset, aggregate_by_snapshot=True):
    """Evaluate ``model`` on ``test_dataset`` in denormalised units.

    Only unmasked targets contribute. Predictions and targets are passed through
    ``denormalise`` before the errors are computed.

    Args:
        model: trained model called as ``model(x.unsqueeze(2), edge_index)``.
        test_dataset: snapshots to evaluate on.
        aggregate_by_snapshot: when True, average the per-snapshot mean errors
            over the number of snapshots; when False, sum all errors and divide
            by the total number of unmasked targets.

    Returns:
        tuple: ``(mse, mae)`` as Python floats.
    """
    reduction = "mean" if aggregate_by_snapshot else "sum"
    mse = torch.nn.MSELoss(reduction=reduction)
    mae = torch.nn.L1Loss(reduction=reduction)

    mse_loss = 0
    mae_loss = 0
    sample_cnt = 0
    model.eval()
    with torch.no_grad():
        for snapshot in test_dataset:
            mask = snapshot.mask == 1
            sample_cnt += mask.sum()
            y_pred = denormalise(model(snapshot.x.unsqueeze(2), snapshot.edge_index).flatten()[mask])
            y_true = denormalise(snapshot.y[mask])
            mse_loss += mse(y_pred, y_true)
            mae_loss += mae(y_pred, y_true)

    if aggregate_by_snapshot:
        # Normalise by the dataset being evaluated, not by the global training set.
        mse_loss /= test_dataset.snapshot_count
        mae_loss /= test_dataset.snapshot_count
    else:
        mse_loss /= sample_cnt
        mae_loss /= sample_cnt

    return mse_loss.item(), mae_loss.item()


def model_predict(model, dataset, edge):
    """Return the model's prediction for ``edge`` in every snapshot as a flat array."""
    model.eval()
    per_snapshot = [
        model(snapshot.x.unsqueeze(2), snapshot.edge_index).detach().numpy()[EDGE_IDX_MAP[edge]]
        for snapshot in dataset
    ]
    return np.array(per_snapshot).reshape(-1)

# Visualisation code

In [None]:
def plot_random_edge_and_neighbours_time_series(speeds_df, dataset, model, nodes):
    """Pick a random edge and plot its time series together with its neighbours'.

    Delegates to ``plot_edge_and_neighbours_time_series`` so both entry points
    share one implementation.

    Returns:
        The folium map produced for the edge and its neighbours.
    """
    edge = random.choice(UNIQUE_EDGES)
    return plot_edge_and_neighbours_time_series(edge, speeds_df, dataset, model, nodes)
    

def plot_edge_and_neighbours_time_series(edge, speeds_df, dataset, model, nodes):
    """Plot the time series of ``edge`` and of each adjacent edge, then map them.

    Returns:
        The folium map drawn by ``plot_edges`` for the edge and its neighbours.
    """
    adjacent_positions = np.nonzero(ADJACENCY_MATRIX[EDGE_IDX_MAP[edge]])[0]
    neighbours = [IDX_EDGE_MAP[position] for position in adjacent_positions]
    # Take the edge out of its own neighbourhood so it can lead the list.
    neighbours.remove(edge)
    to_plot = [edge] + neighbours
    for current in to_plot:
        plot_edge_time_series(current, speeds_df, dataset, model)
    return plot_edges(nodes, to_plot)
    

def plot_edge_time_series(edge, speeds_df, dataset, model):
    """Show an interactive scatter plot comparing four series for one edge.

    The series are the recorded speeds from ``speeds_df``, the dataset targets,
    the model's predictions, and the ``edge_time_naive`` predictions.
    """
    edge_rows = speeds_df[speeds_df.edge == edge]
    observed = edge_rows[["minute_bucket", "speed_kmh"]].sort_values("minute_bucket")
    targets = dataset.targets

    # (x values, y values, legend name), in drawing order.
    series = [
        (observed.minute_bucket, observed.speed_kmh, 'Ground Truth'),
        (DATASET_DATE_RANGE, [snapshot_targets[EDGE_IDX_MAP[edge]] for snapshot_targets in targets], 'Imputed'),
        (DATASET_DATE_RANGE, model_predict(model, dataset, edge), 'GNN predictions'),
        (DATASET_DATE_RANGE, [edge_time_naive(edge, ts) for ts in DATASET_DATE_RANGE], 'Naive predictions'),
    ]

    fig = go.Figure()
    for x_values, y_values, label in series:
        fig.add_trace(go.Scatter(x=x_values, y=y_values, mode='markers', name=label))

    fig.update_layout(
        title=f"Time series for edge {edge}",
        title_x=0.5,
        xaxis=dict(title="Time [15-minute bucket]"),
        yaxis=dict(title="Speed [km/h]"),
        # Horizontal legend placed just above the plot area, right-aligned.
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    fig.show()



def plot_edges(nodes, edges):
    """Draw ``edges`` as blue lines and their endpoints as red markers on a folium map.

    ``nodes`` must have ``id``, ``lat`` and ``lon`` columns; each edge is a
    ``(u, v)`` pair of node ids.
    """
    def coordinates(node_id):
        # First matching row of the node table.
        lat, lon = nodes[nodes["id"] == node_id][["lat", "lon"]].iloc[0]
        return lat, lon

    road_map = folium.Map(location=[44.435608, 26.102297], zoom_start=15)
    endpoint_ids = [node_id for edge in edges for node_id in edge]

    # Add edges to the map
    for u, v in edges:
        folium.PolyLine(
            locations=[coordinates(u), coordinates(v)],
            color='blue',
            weight=5,
            tooltip=f"{u, v}",
        ).add_to(road_map)

    # Add nodes to the map
    for node_id in endpoint_ids:
        folium.CircleMarker(
            location=coordinates(node_id),
            radius=5,
            color='red',
            fill=True,
            fill_color='red',
        ).add_to(road_map)

    return road_map

# Data preprocessing

In [None]:
def normalise(x):
    """Standardise ``x`` with the global ``MEAN`` and ``STD``."""
    centred = x - MEAN
    return centred / STD


def denormalise(x):
    """Invert ``normalise``: map a standardised value back to the original scale."""
    rescaled = x * STD
    return rescaled + MEAN
    

def extract_city_graph():
    """Load the OSM road network and the speed measurements for the edges of interest.

    Downloads the city map from S3, parses its driving and service network with
    pyrosm, then reads the speeds parquet (downloading it first if it is not
    present locally). Speeds are restricted to ``UNIQUE_EDGES`` and gain
    ``edge``, ``day`` (weekday), ``hour`` and ``minute`` columns.

    Returns:
        tuple: ``(speeds_df, nodes, edges)`` with ``speeds_df`` sorted by edge
        and ``minute_bucket``.
    """
    S3.download_file(S3_BUCKET, f"{S3_SUBDIR}/{MAP_FILE}", "bucharest.pbf")

    osm = pyrosm.OSM("bucharest.pbf")
    nodes, edges = osm.get_network(nodes=True, network_type="driving+service")
    edges["edge"] = list(zip(edges.u, edges.v))
    print(f"Unique OSM nodes: {nodes.id.nunique()}, unique OSM edges: {edges.id.nunique()}")

    if not os.path.isfile(S3_FILENAME):
        S3.download_file(S3_BUCKET, f"{S3_SUBDIR}/{S3_DATA}/{S3_FILENAME}", S3_FILENAME)

    speeds_df = pd.read_parquet(S3_FILENAME)

    print(f"Dataset time boundaries: {speeds_df.minute_bucket.min(), speeds_df.minute_bucket.max()}")
    print(f"Initial dataset shape: {speeds_df.shape}")

    speeds_df["edge"] = list(zip(speeds_df.start_node, speeds_df.end_node))

    # Copy the filtered rows so the column assignments below write to an
    # independent frame instead of a slice of the original.
    speeds_df = speeds_df[speeds_df.edge.isin(UNIQUE_EDGES)].copy()

    print(f"Dataset shape after filtering edges of interest: {speeds_df.shape}")

    speeds_df["day"] = speeds_df.minute_bucket.dt.weekday
    speeds_df["hour"] = speeds_df.minute_bucket.dt.hour
    speeds_df["minute"] = speeds_df.minute_bucket.dt.minute
    speeds_df = speeds_df.sort_values(["edge", "minute_bucket"])

    return speeds_df, nodes, edges

# Experimentation setup

In [None]:
def prepare_dataset(speeds_df, imputation_method):
    """Return the imputed temporal graph dataset, using a local pickle as cache.

    If ``<GNN_DATASET_NAME>.pickle`` exists it is loaded as is. Otherwise the
    dataset is built with ``impute_dataset``, pickled locally and uploaded to S3.
    """
    cache_path = f"{GNN_DATASET_NAME}.pickle"

    if os.path.isfile(cache_path):
        with open(cache_path, "rb") as cache_file:
            cached = pickle.load(cache_file)
        print("Loadeded imputed data")
        return cached

    print("Running data imputation ...")
    features, targets, target_mask = impute_dataset(speeds_df, imputation_method)
    # One weight per non-zero adjacency entry.
    edge_weights = ADJACENCY_MATRIX[ADJACENCY_MATRIX>0]
    dataset = StaticGraphTemporalSignal(EDGE_INDEX, edge_weights, features, targets, mask=target_mask)
    with open(cache_path, "wb") as cache_file:
        pickle.dump(dataset, cache_file)
    S3.upload_file(cache_path, S3_BUCKET, f"{S3_SUBDIR}/{S3_DATA}/gnn/{GNN_DATASET_NAME}.pickle")
    return dataset


def evaluate_baselines(train_dataset, valid_dataset, test_dataset):
    """Print MSE, RMSE and MAE of every naive baseline on the three splits.

    The global-mean and edge-mean baselines need only the dataset; the two
    time-aware baselines also need the timestamps of each split.
    """
    def report(name, split, mse, mae):
        print(f"\t {name} {split} MSE {mse:.2f}")
        print(f"\t {name} {split} RMSE {np.sqrt(mse):.2f}")
        print(f"\t {name} {split} MAE {mae:.2f}")

    splits = ["train", "valid", "test"]
    datasets = [train_dataset, valid_dataset, test_dataset]
    # Each split is paired with its own timestamps; the test split must use
    # TEST_DATE_RANGE, not the validation range.
    date_ranges = [TRAIN_DATE_RANGE, VALID_DATE_RANGE, TEST_DATE_RANGE]

    static_baselines = [
        ("Global mean", evaluate_global_mean_baseline),
        ("Edge mean", evaluate_edge_average_baseline),
    ]
    for naive_name, naive_method in static_baselines:
        for split, ds in zip(splits, datasets):
            mse, mae = naive_method(ds)
            report(naive_name, split, mse, mae)

    time_baselines = [
        ("Edge time naive", edge_time_naive),
        ("Edge time rolling", rolling_edge_time_avg_naive),
    ]
    for naive_name, naive_method in time_baselines:
        for split, date_range, ds in zip(splits, date_ranges, datasets):
            mse, mae = evaluate_edge_time_average_baseline(ds, date_range, naive_method)
            report(naive_name, split, mse, mae)


def split_dataset(dataset):
    """Split ``dataset`` chronologically into train, validation and test parts.

    The first ``TRAIN_RATIO`` of the snapshots is the training set; the remainder
    is halved into validation and test.
    """
    train_part, held_out = temporal_signal_split(dataset, train_ratio=TRAIN_RATIO)
    # Assume valid and test dataset are of equal length
    valid_part, test_part = temporal_signal_split(held_out, train_ratio=1/2)
    return train_part, valid_part, test_part


def calc_model_params(model):
    """Count the scalar entries of all parameters of ``model`` that require gradients."""
    return sum([
        np.prod(parameter.size())
        for parameter in model.parameters()
        if parameter.requires_grad
    ])


def run_experiment(dataset, run_baselines=True):
    """Run one end-to-end experiment: split, optional baselines, training, test metrics.

    Args:
        dataset: the full temporal graph dataset.
        run_baselines: when True, evaluate and print the naive baselines first.

    Returns:
        The trained model.
    """
    print("Example StaticGraphTemporalSignal snapshot:")
    print(dataset[0], '\n')

    splits = split_dataset(dataset)
    train_dataset, valid_dataset, test_dataset = splits
    snapshot_counts = [split.snapshot_count for split in splits]
    print("Train, valid and test set snapshot counts respectively: ", *snapshot_counts, '\n')

    if run_baselines:
        print("Evaluating baselines:")
        evaluate_baselines(train_dataset, valid_dataset, test_dataset)
        print()

    print("Training a GNN ...")
    model = train(train_dataset, valid_dataset, epochs=EPOCHS)
    print()

    print("Number of parameters in the model: ", calc_model_params(model), '\n')

    test_mse, test_mae = inference(model, test_dataset)
    print("Test MSE: ", test_mse)
    print("Test RMSE: ", np.sqrt(test_mse))
    print("Test MAE: ", test_mae)

    return model


def generate_predictions(dataset, model, date_range):
    """Collect denormalised predictions and targets for every unmasked edge.

    Args:
        dataset: snapshots to predict on.
        model: trained model called as ``model(x.unsqueeze(2), edge_index)``.
        date_range: timestamps paired positionally with the snapshots.

    Returns:
        DataFrame with ``start_node``, ``end_node``, ``minute_bucket``, the
        model's speed column (``<MODEL_NAME>_speed``) and ``speed_kmh``.
    """
    # The edge list does not change between snapshots, so build the array once.
    all_edges = np.array(list(UNIQUE_EDGES))

    edge_predictions = []
    model.eval()
    with torch.no_grad():
        for minute_bucket, snapshot in zip(date_range, dataset):
            mask = snapshot.mask == 1
            used_edges = all_edges[np.asarray(mask)]
            # Same input layout as in train()/inference(): add the trailing axis.
            y_pred = model(snapshot.x.unsqueeze(2), snapshot.edge_index).flatten()[mask]
            edge_predictions.append(pd.DataFrame({
                "start_node": used_edges[:, 0],
                "end_node": used_edges[:, 1],
                "minute_bucket": np.repeat(minute_bucket, int(mask.sum())),
                f"{MODEL_NAME}_speed": denormalise(y_pred.detach().numpy()),
                "speed_kmh": denormalise(snapshot.y[mask].detach().numpy())
            }))
    return pd.concat(edge_predictions)

# Experiments

In [None]:
# Load the speed measurements for the edges of interest plus the OSM nodes and edges.
subgraph_speeds_df, nodes, edges = extract_city_graph()

In [None]:
# Consecutive 15-minute date ranges: N_WEEKS_TRAINING weeks of training, then
# N_WEEKS_VALIDATION weeks each for validation and test.
DATASET_START_DATE = subgraph_speeds_df.minute_bucket.min()
DATASET_END_DATE = subgraph_speeds_df.minute_bucket.max()
# inclusive="left": the training range stops one bucket before the end of its window.
TRAIN_DATE_RANGE = pd.date_range(DATASET_START_DATE, DATASET_START_DATE + pd.Timedelta(N_WEEKS_TRAINING, 'W'), freq="15min", inclusive="left")
# inclusive="right": each later range starts one bucket after the last bucket of the previous range.
VALID_DATE_RANGE = pd.date_range(TRAIN_DATE_RANGE[-1], TRAIN_DATE_RANGE[-1] + pd.Timedelta(N_WEEKS_VALIDATION, 'W'), freq="15min", inclusive="right")
TEST_DATE_RANGE = pd.date_range(VALID_DATE_RANGE[-1], VALID_DATE_RANGE[-1] + pd.Timedelta(N_WEEKS_VALIDATION, 'W'), freq="15min", inclusive="right")
# All timestamps of the three ranges as one Series.
DATASET_DATE_RANGE = pd.concat([TRAIN_DATE_RANGE.to_series(), VALID_DATE_RANGE.to_series(), TEST_DATE_RANGE.to_series()])
# Frame indexed by timestamp; expand_edge_time_series() right-joins onto it.
DATASET_RANGE_DF = pd.DataFrame(DATASET_DATE_RANGE, columns=["minute_bucket"]).reset_index().set_index("minute_bucket")

# Every column whose name contains "lag" is treated as a model input feature.
SPEED_FEATURES = [col_name for col_name in subgraph_speeds_df.columns if "lag" in col_name]

In [None]:
# Drop measurements recorded after the end of the test range.
subgraph_speeds_df = subgraph_speeds_df[subgraph_speeds_df.minute_bucket <= TEST_DATE_RANGE[-1]]

In [None]:
subgraph_speeds_df.shape

In [None]:
# Normalisation statistics are computed on the training window only.
train_subgraph_speeds_df = subgraph_speeds_df[subgraph_speeds_df.minute_bucket < DATASET_START_DATE + pd.Timedelta(N_WEEKS_TRAINING, 'W')]
MEAN = train_subgraph_speeds_df.speed_kmh.mean()
STD = train_subgraph_speeds_df.speed_kmh.std()
MEAN, STD

In [None]:
# Standardise the target and every lag feature with the training MEAN/STD.
# normalise() is plain arithmetic, so it is applied to whole columns at once
# instead of through a per-row lambda.
subgraph_speeds_df["speed_kmh"] = normalise(subgraph_speeds_df.speed_kmh)
for feat in SPEED_FEATURES:
    subgraph_speeds_df[feat] = normalise(subgraph_speeds_df[feat])

In [None]:
# Re-slice the training window so it picks up the normalised speed columns.
train_subgraph_speeds_df = subgraph_speeds_df[subgraph_speeds_df.minute_bucket < DATASET_START_DATE + pd.Timedelta(N_WEEKS_TRAINING, 'W')]

In [None]:
# Lookups in both directions between an edge and its position in UNIQUE_EDGES.
EDGE_IDX_MAP = {edge: i for i, edge in enumerate(UNIQUE_EDGES)}
IDX_EDGE_MAP = {i: edge for i, edge in enumerate(UNIQUE_EDGES)}

# Training-only statistics used by the baselines and as imputation fallbacks.
MEAN_SPEED = train_subgraph_speeds_df.speed_kmh.mean()
# Speeds are already standardised at this point, so the per-edge means are kept
# as floats: casting them to int would truncate almost all of them to 0.
EDGE_AVG_DICT = train_subgraph_speeds_df[["speed_kmh", "edge"]].groupby("edge").mean().to_dict()["speed_kmh"]
# Keyed by (edge, weekday, hour, minute).
EDGE_15_MIN_BUCKET_DICT = train_subgraph_speeds_df.groupby(["edge", "day", "hour", "minute"])["speed_kmh"].mean().to_dict()

with open("edge_15min_dict.pickle", "wb") as f:
    pickle.dump(EDGE_15_MIN_BUCKET_DICT, f)

S3.upload_file("edge_15min_dict.pickle", S3_BUCKET, f"{S3_SUBDIR}/models/edge_15min_dict.pickle")

ADJACENCY_MATRIX, EDGE_INDEX = compute_adjacency_matrix()

# Expanding (cumulative) per-edge mean over the full timestamp range, keyed by (edge, minute_bucket).
rolling_speed_avg_df = (pd.concat([expand_edge_time_series(g)
    for _, g in subgraph_speeds_df[["edge", "minute_bucket", "speed_kmh"]]
    .groupby("edge")]).set_index("minute_bucket").groupby("edge").expanding().mean())
rolling_speed_avg_df = rolling_speed_avg_df.dropna()
ROLLING_EDGE_TIME_AVG_DICT = rolling_speed_avg_df.to_dict()["speed_kmh"]
# TODO: Move these to data imputation methods
ROLLING_1H_WINDOW_EDGE_TIME_AVG_DICT = compute_rolling_mean(subgraph_speeds_df, "1h")
ROLLING_2H_WINDOW_EDGE_TIME_AVG_DICT = compute_rolling_mean(subgraph_speeds_df, "2h")
ROLLING_3H_WINDOW_EDGE_TIME_AVG_DICT = compute_rolling_mean(subgraph_speeds_df, "3h")
ROLLING_4H_WINDOW_EDGE_TIME_AVG_DICT = compute_rolling_mean(subgraph_speeds_df, "4h")

# Raw lookups: speed per (edge, minute_bucket) and mean speed over all edges per minute_bucket.
DATASET_DICT = subgraph_speeds_df[["edge", "minute_bucket", "speed_kmh"]].set_index(["edge", "minute_bucket"]).to_dict()["speed_kmh"]
MINUTE_BUCKET_AVG_DICT = subgraph_speeds_df[["minute_bucket", "speed_kmh"]].groupby("minute_bucket").mean().to_dict()["speed_kmh"]

In [None]:
GNN_DATASET_NAME

In [None]:
# Fetch a previously imputed dataset from S3 so prepare_dataset() can load it from disk.
S3.download_file(S3_BUCKET, f"{S3_SUBDIR}/{S3_DATA}/gnn/{GNN_DATASET_NAME}.pickle", f"{GNN_DATASET_NAME}.pickle")

In [None]:
# Delete the local cache file; prepare_dataset() then re-runs the imputation.
# NOTE(review): this undoes the download in the previous cell - presumably only one of the two is meant to be run.
rm "{GNN_DATASET_NAME}.pickle"

In [None]:
# Distribution of the (normalised) speeds.
subgraph_speeds_df.speed_kmh.hist()

In [None]:
# Load the cached dataset if the pickle exists, otherwise impute and cache it.
dataset = prepare_dataset(subgraph_speeds_df, impute_nan)

In [None]:
# Module-level splits, used by the metric and prediction cells below.
train_dataset, valid_dataset, test_dataset = split_dataset(dataset)

In [None]:
dataset[0]

## Loss average over snapshots, 100 epochs, 32 hidden channels, 2 weeks of training data

In [None]:
# Hyper-parameters for this run; train() reads HIDDEN_CHANNELS and LEARNING_RATE as globals
# at call time, and MODEL_NAME names the saved checkpoint.
# NOTE(review): DROPOUT is set here but TemporalGNN takes no dropout argument, so it appears unused.
HIDDEN_CHANNELS = 32
LEARNING_RATE = 0.001
EPOCHS = 100
DROPOUT = 0.2
MODEL_NAME = f"at3gcn_{EPOCHS}_epochs_{HIDDEN_CHANNELS}_hidden_channels_{len(UNIQUE_EDGES)}_edges_{N_WEEKS}_weeks"

In [None]:
# Split, train with per-epoch validation, and print the test metrics (baselines skipped).
model = run_experiment(dataset, run_baselines=False)

In [None]:
# These metrics contain masking
# (MSE, MAE) per split with the default aggregate_by_snapshot=True.
print("Train: ", inference(model, train_dataset))
print("Valid: ", inference(model, valid_dataset))
print("Test: ", inference(model, test_dataset))

In [None]:
# These metrics contain masking
# (MSE, MAE) per split, pooled over all unmasked targets instead of per snapshot.
print("Train: ", inference(model, train_dataset, aggregate_by_snapshot=False))
print("Valid: ", inference(model, valid_dataset, aggregate_by_snapshot=False))
print("Test: ", inference(model, test_dataset, aggregate_by_snapshot=False))

## Loss average over snapshots, 100 epochs, 32 hidden channels, 2 weeks of training data, with target masking

In [None]:
# Same hyper-parameter values as the previous experiment; only the MODEL_NAME suffix differs,
# so this run is saved under a separate checkpoint name.
HIDDEN_CHANNELS = 32
LEARNING_RATE = 0.001
EPOCHS = 100
DROPOUT = 0.2
MODEL_NAME = f"at3gcn_{EPOCHS}_epochs_{HIDDEN_CHANNELS}_hidden_channels_{len(UNIQUE_EDGES)}_edges_{N_WEEKS}_weeks_with_target_masking"

In [None]:
# Split, train with per-epoch validation, and print the test metrics (baselines skipped).
model = run_experiment(dataset, run_baselines=False)

In [None]:
# These metrics contain masking
# (MSE, MAE) per split with the default aggregate_by_snapshot=True.
print("Train: ", inference(model, train_dataset))
print("Valid: ", inference(model, valid_dataset))
print("Test: ", inference(model, test_dataset))

In [None]:
# These metrics contain masking
# (MSE, MAE) per split, pooled over all unmasked targets instead of per snapshot.
print("Train: ", inference(model, train_dataset, aggregate_by_snapshot=False))
print("Valid: ", inference(model, valid_dataset, aggregate_by_snapshot=False))
print("Test: ", inference(model, test_dataset, aggregate_by_snapshot=False))

## Generate predictions

In [None]:
# S3.download_file(S3_BUCKET, f"{S3_SUBDIR}/{S3_DATA}/gnn/{GNN_DATASET_NAME}.pickle", f"{GNN_DATASET_NAME}.pickle")
# S3.download_file(S3_BUCKET, f"{S3_SUBDIR}/models/gnn/{MODEL_NAME}.pt", f"{MODEL_NAME}.pt")

# with open(f"{GNN_DATASET_NAME}.pickle", "rb") as file:
#     dataset = pickle.load(file)

In [None]:
dataset[0]

In [None]:
# Rebuild the architecture used by train() (same class and sizes) so the saved
# state dict matches, then load the checkpoint written by save_model().
model = TemporalGNN(len(SPEED_FEATURES), HIDDEN_CHANNELS, OUT_CHANNELS)
model.load_state_dict(torch.load(f"{MODEL_NAME}.pt"))
model.eval()
train_dataset, valid_dataset, test_dataset = split_dataset(dataset)

In [None]:
# Sanity-check the reloaded model on the training split.
print(inference(model, train_dataset))
# print(inference(model, valid_dataset))
# print(inference(model, test_dataset))

In [None]:
# Write one parquet file of predictions per split and upload each to S3.
# train_dataset, valid_dataset, test_dataset = split_dataset(dataset)
os.makedirs("gnn", exist_ok=True)
for ds, split, date_range in zip([train_dataset, valid_dataset, test_dataset], DATA_SPLITS, [TRAIN_DATE_RANGE, VALID_DATE_RANGE, TEST_DATE_RANGE]):
    preds_df = generate_predictions(ds, model, date_range)
    # MSE between the denormalised targets and predictions returned by generate_predictions().
    print(mean_squared_error(preds_df.speed_kmh, preds_df[f"{MODEL_NAME}_speed"]))
    preds_df.to_parquet(f"gnn/{split}.parquet")
    # NOTE(review): the upload prefix is "model_predictions", while the S3_PREDS constant says "model_preds" - confirm which is intended.
    S3.upload_file(f"gnn/{split}.parquet", S3_BUCKET, f"{S3_SUBDIR}/model_predictions/{MODEL_NAME}/{split}.parquet")
    print(f"Saved {MODEL_NAME}")

# Time series visualisations

In [None]:
# Each call picks a random edge, so the three cells below show different neighbourhoods.
plot_random_edge_and_neighbours_time_series(subgraph_speeds_df, dataset, model, nodes)

In [None]:
plot_random_edge_and_neighbours_time_series(subgraph_speeds_df, dataset, model, nodes)

In [None]:
plot_random_edge_and_neighbours_time_series(subgraph_speeds_df, dataset, model, nodes)

In [None]:
# Same plots for one specific edge, given as a (start node id, end node id) pair.
plot_edge_and_neighbours_time_series((248729659, 6258431109), subgraph_speeds_df, dataset, model, nodes)