In [1]:
%%capture
%cd ..

In [2]:
from sklearn.metrics import root_mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
from src.utils import load_raw, train_model, create_loader, Evaluator
from src.models import BaselineGCN


# --- Experiment constants ----------------------------------------------------
# NOTE(review): the training config defined further down hard-codes
# window_size=12 and granularity=10, which differ from WINDOW_SIZE and
# GRANULARITY here — confirm which values are intended.
WINDOW_SIZE, TARGET_SIZE = 24, 1
GRANULARITY = 12  # original note: "5 minutes" — exact unit/meaning not shown here, TODO confirm
EPOCHS = 100

# Metric label -> scikit-learn scoring function; passed to Evaluator later on.
METRICS = dict(
    RMSE=root_mean_squared_error,
    MAE=mean_absolute_error,
    MAPE=mean_absolute_percentage_error,
)

# --- Raw datasets --------------------------------------------------------------
# load_raw takes the .npz archive path and the distance CSV path; for PEMS08 its
# result is unpacked into three objects (data, adjacency, distance).
pems8_data, pems8_adj, pems8_dist = load_raw(
    "data/PEMS08/pems08.npz",
    "data/PEMS08/distance.csv",
)
# PEMS04 is kept as the un-unpacked return value of load_raw.
# NOTE(review): `pems4` is not referenced by any other visible cell.
pems4 = load_raw(
    "data/PEMS04/pems04.npz",
    "data/PEMS04/distance.csv",
)

In [3]:
# Quick look at the adjacency object returned by load_raw for PEMS08.
# NOTE(review): the recorded output is (295, 2) — two columns rather than a
# square matrix, so this may be an edge list; confirm it is the format expected
# by the "adj_matrix" entry of the training config.
pems8_adj.shape

(295, 2)

In [4]:
import torch

# Seed torch's global RNG before the model is built so that torch-driven
# randomness (e.g. weight initialisation) is repeatable across
# Restart & Run All. Full determinism on CUDA is not guaranteed by this alone.
SEED = 42
torch.manual_seed(SEED)

# Positional arguments are kept exactly as in the recorded run; their meaning is
# defined by BaselineGCN's signature, which is not visible in this notebook.
model_to_train = BaselineGCN(12, 1, 512)

# Keyword arguments forwarded to run_experiment via ** unpacking.
training_config = {
    # Data produced by load_raw for PEMS08.
    "raw_data": pems8_data,
    "adj_matrix": pems8_adj,
    "cost_matrix": pems8_dist,
    # NOTE(review): window_size and granularity are left as the literals this
    # run was recorded with; they differ from WINDOW_SIZE (24) and
    # GRANULARITY (12) in the config cell — TODO reconcile.
    "window_size": 12,
    "target_size": TARGET_SIZE,
    "granularity": 10,
    "epochs": EPOCHS,
    "batch_size": 64,
    "standardize": True,
    "shuffle": True,
    "criterion": torch.nn.MSELoss(),
    # Use the GPU when one is available, otherwise fall back to the CPU.
    "device": torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    "learning_rate": 1e-3,
    # Presumably early-stopping controls (the recorded run ends with
    # "Early stopping") — exact semantics live in train_model, TODO confirm.
    "patience": 5,
    "min_delta": 0.0005,
    "reset_weights": True,
}

In [5]:
from src.experiments.run import run_experiment

# Evaluator configured with the metric table from the config cell and with
# history saving switched on.
evaluator = Evaluator(METRICS, save_history=True)

# Hand the loader factory, training routine, model, evaluator and the training
# configuration to run_experiment; it returns two values, unpacked here as the
# model and its metrics.
# NOTE(review): the config sets standardize=True, so the reported metrics may be
# in standardized units (MAPE is unstable near zero) — confirm whether the
# Evaluator inverse-transforms predictions before scoring.
model, metr = run_experiment(
    create_loader,
    train_model,
    model_to_train,
    evaluator,
    **training_config,
)



Epoch 1/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 2/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 3/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 4/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 5/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 6/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 7/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 8/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 9/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 10/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 11/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 12/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 13/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 14/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 15/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 16/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 17/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 18/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 19/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 20/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 21/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 22/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 23/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 24/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 25/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 26/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 27/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 28/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 29/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 30/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 31/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 32/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 33/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 34/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 35/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 36/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 37/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 38/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 39/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 40/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 41/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 42/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 43/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 44/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 45/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 46/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 47/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 48/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 49/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 50/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 51/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 52/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 53/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 54/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 55/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 56/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 57/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 58/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 59/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 60/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 61/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 62/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 63/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 64/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 65/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 66/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 67/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Epoch 68/100:   0%|          | 0/20 [00:00<?, ?batch/s]

Early stopping
{'RMSE': 0.814967, 'MAE': 0.5912433, 'MAPE': 1.2185074}
