In [None]:
from datamodule.datamodule_old import DataModule
from distributions.distributions import *
from utils import print_model_size, list_available_models, draw, draw_single

# Dataset
- Specify every dataset parameter: the dataset name, the train/validation/test split, and the windowing settings.

In [None]:
# Dataset to load (the author notes "exchange" as an alternative).
name = "electric"

# Split proportions handed to DataModule as a (train, validation, test) tuple.
# NOTE(review): these presumably must sum to 1 — confirm in DataModule.
train_size, validation_size, test_size = 0.7, 0.15, 0.15

custom_datamodule = DataModule(
    name,
    (train_size, validation_size, test_size),
)

In [None]:
# Input window length for training samples (the `window` argument of get_all).
train_window = 24
# Input window length and horizon used when building test samples.
test_window = 24
test_horizon = 4
# The stride must equal the horizon, so derive it instead of repeating the
# literal; the two values can then never drift apart. (Presumably this makes
# consecutive test forecasts tile the series without overlap — confirm
# against DataModule.get_all.)
test_stride = test_horizon
# NOTE(review): the meaning of the delay is defined in DataModule.get_all.
test_delay = 0

# get_all yields the training datamodule (later passed to trainer.fit) and the
# test dataloader (later passed to trainer.predict).
train_dm, test_dataloader = custom_datamodule.get_all(
    window=train_window,
    test_window=test_window,
    test_horizon=test_horizon,
    test_stride=test_stride,
    test_delay=test_delay,
)

# Model
You have to specify:
- model parameters, such as the number of nodes and the number of features
- the output distribution: Gaussian or Student's t
- the test loss (`'rmse'` or `'mae'`)
- whether to perform scaling
- any additional parameters, if needed

In [None]:
# Model dimensions are read from the datamodule, so they always match the data.
input_size, n_nodes = (
    custom_datamodule.get_channels(),
    custom_datamodule.get_number_of_nodes(),
)

# Output distribution; StudentTDistribution() is the alternative offered here.
distribution = GaussianDistribution()

# Test loss ('mae' is the other option) and the model's scaling switch.
test_loss, perform_scaling = 'rmse', True


In [None]:
from model.net import BaseModelDeepGar  # noqa: F401 — kept as in the original cell
from model.deepGARv2 import DeepGAR

# Hyperparameters gathered from the configuration cell.
deepgar_kwargs = dict(
    input_size=input_size,
    n_nodes=n_nodes,
    distribution=distribution,
    test_loss=test_loss,
    perform_scaling=perform_scaling,
)
model = DeepGAR(**deepgar_kwargs)

# Optional: uncomment to call utils.list_available_models()
# (presumably lists other selectable models — confirm in utils).
# list_available_models()

In [None]:
# Inspect the instantiated network before training.
print(model)  # module repr: layer-by-layer structure
print_model_size(model)  # size report; format is defined in utils

# Trainer

In [None]:
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Stop after 20 validation checks in a row without a val_loss improvement.
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    mode='min',
    min_delta=0.0,
    patience=20,
)

# Keep only the single best checkpoint (lowest val_loss) under logs2/.
checkpoint_callback = ModelCheckpoint(
    dirpath='logs2',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
)

# One CUDA device when available, otherwise CPU.
accelerator = 'cuda' if torch.cuda.is_available() else 'cpu'
trainer = pl.Trainer(
    max_epochs=400,
    accelerator=accelerator,
    devices=1,
    callbacks=[checkpoint_callback, early_stopping_callback],
)

# Training

In [None]:
# Named toggle instead of a bare `if True:`. Set to False to skip training and
# evaluate a checkpoint already on disk in the Testing section.
RUN_TRAINING = True

if RUN_TRAINING:
    trainer.fit(model=model, datamodule=train_dm)

# Testing

In [None]:
# Choose the checkpoint to evaluate. ModelCheckpoint records the file it kept
# in `best_model_path` (an empty string until it has saved one). Prefer that
# over a hard-coded filename, which goes stale after every new training run.
# The fixed path is only a fallback for sessions where training was skipped.
FALLBACK_CHECKPOINT = 'logs2/epoch=1-step=1184.ckpt'
best_model_path = checkpoint_callback.best_model_path or FALLBACK_CHECKPOINT

best_model = DeepGAR.load_from_checkpoint(checkpoint_path=best_model_path)
best_model.eval()
# With a single dataloader, trainer.predict returns a list of per-batch
# prediction outputs.
res = trainer.predict(model=best_model, dataloaders=test_dataloader)

# Draw + errors

In [None]:
# Plot all test-set predictions collected in `res` (plotting logic lives in utils.draw).
draw(res)

In [None]:
# Zoomed-in view: plot only the leading slice of the prediction list.
# NOTE(review): 100 / 4 looks like "100 time steps / test_stride"; confirm and
# derive it from test_stride if so.
interval = 100 // 4
draw(res[:interval])