In [None]:
from datamodule.datamodule import *
from distributions.distributions import *
from utils import *
# MODEL:
from model.deepGARv2 import DeepGAR
# PYTORCH LIGHTNING:
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
#

In [None]:

# Dataset / split configuration for this notebook.
# `name` is passed to DataModule and is also read by run_model to label the reported metrics.
name = 'electric'  # dataset identifier
# Train / validation / test sizes, passed to DataModule together as one tuple.
train_size = 0.7
validation_size = 0.15
test_size = 0.15
batch_size = 32
# Data module from which the loaders, channel count and node count are obtained in later cells.
custom_datamodule = DataModule(name, (train_size, validation_size, test_size), batch_size=batch_size)

In [None]:
# Windowing parameters forwarded to custom_datamodule.get_all below.
train_window = 168
test_window = 168
# Also passed to get_metrics(horizon=...) inside run_model.
test_horizon = 24
test_stride = 24 # should be equal to the time horizon
test_delay = 0
# get_all returns three loaders, unpacked as train / validation / test; run_model reads them as globals.
train_loader, val_loader, test_loader = custom_datamodule.get_all(
    window=train_window,
    test_window=test_window,
    test_horizon=test_horizon,
    test_stride=test_stride,
    test_delay=test_delay
)

In [None]:
# Model constructor arguments; run_model reads all four names below as globals.
input_size = custom_datamodule.get_channels()
n_nodes = custom_datamodule.get_number_of_nodes()

# Distribution object handed to DeepGAR(distribution=...).
distribution = GaussianDistribution()  # or StudentTDistribution()

# Forwarded unchanged to DeepGAR(perform_scaling=...).
perform_scaling = False

In [None]:

def run_model(run, train=False, testing=True, checkpoint_callback_path:str=None, tuning=False):
    """Build a DeepGAR model and optionally train and/or evaluate it.

    The function reads several notebook-level globals: ``input_size``,
    ``n_nodes``, ``distribution``, ``perform_scaling``, ``train_loader``,
    ``val_loader``, ``test_loader``, ``test_horizon`` and ``name``.

    Args:
        run: Run identifier; used in the checkpoint directory name and in
            the printed / raised messages.
        train: When true, fit the model on ``train_loader`` with
            ``val_loader`` for validation.
        testing: When true, load a checkpoint, run ``trainer.predict`` on
            ``test_loader`` and report RMSE / ND through ``get_metrics`` and
            ``add_metrics``.
        checkpoint_callback_path: Fallback checkpoint path, used only when
            the ModelCheckpoint callback created by this call has no best
            model path (for example when ``train`` is false).
        tuning: Unused; kept so existing callers that pass it keep working.

    Returns:
        The value returned by ``trainer.predict`` when ``testing`` is true,
        otherwise ``None``.

    Raises:
        RuntimeError: If ``testing`` is true and neither the checkpoint
            callback nor ``checkpoint_callback_path`` provides a checkpoint.
    """
    model = DeepGAR(
        input_size=input_size,
        n_nodes=n_nodes,
        distribution=distribution,
        perform_scaling=perform_scaling
    )
    print(model)
    print_model_size(model)

    early_stopping_callback = EarlyStopping(
        monitor='val_loss',
        min_delta=0.00,
        patience=20,
        mode='min'
    )
    checkpoint_callback = ModelCheckpoint(
        dirpath=f'electricity_logs_v2_{run}',
        save_top_k=1,
        monitor='val_loss',
        mode='min',
    )
    trainer = pl.Trainer(
        max_epochs=1000,
        accelerator='cpu',
        devices = 1,
        callbacks=[checkpoint_callback, early_stopping_callback],
        log_every_n_steps=100,
    )

    # Bound up front so that `return res` is valid when testing is disabled;
    # previously that path raised UnboundLocalError.
    res = None

    if train:
        trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    if testing:
        # Prefer the checkpoint tracked by this call's callback; fall back to
        # the caller-supplied path when the callback has nothing recorded.
        best_model_path = checkpoint_callback.best_model_path
        if not best_model_path:
            if checkpoint_callback_path is None:
                raise RuntimeError(f'failed to locate best model checkpoint path. Stopping at run {run}')
            best_model_path = checkpoint_callback_path

        best_model = DeepGAR.load_from_checkpoint(
            checkpoint_path=best_model_path
        )
        print(f'Using checkpoint {best_model_path}')
        best_model.eval()
        res = trainer.predict(model=best_model, dataloaders=test_loader)
        rmse_loss, nd_loss = get_metrics(res, horizon=test_horizon, n_nodes=n_nodes)
        print(f'Run {run} on {name}: RMSE: {rmse_loss}, ND: {nd_loss}')
        add_metrics(f'{name}_v2', rmse_loss, nd_loss)

    # No explicit `del` of the locals: they are released when the function
    # returns, and the former bare `except: pass` around them only hid errors.
    return res
    


# list_available_models()

In [None]:
# Number of train-and-evaluate runs to perform. The loop index is passed as the
# run id, which run_model places in the checkpoint directory name and in its messages.
train_n_times = 1
for i in range(train_n_times):
    run_model(i, train=True, testing=True)