In [1]:
import os

# The notebook presumably sits one directory below the project root; move there so the `src.*`
# imports and the relative `models/`, `hyperparams/` and `plots/` paths used below resolve.
# A bare `%cd ..` climbs one more level on every re-run of this cell, so only move when `src/`
# is not in the current directory but is present in the parent.
if not os.path.isdir('src') and os.path.isdir(os.path.join('..', 'src')):
    os.chdir('..')

In [2]:
import os

import matplotlib.pyplot as plt
import torch

from src.dataset import create_sifim_datasets
from src.model.esn import ESN
from src.model.lstm import LSTM
from src.trainer.bptt_trainer import BPTTTrainer
from src.trainer.model_selection import retraining
from src.trainer.ridge_regression_trainer import RidgeRegressionTrainer

In [3]:
# DataLoader settings shared by the LSTM and ESN experiments below.
batch_size = 16
shuffle = True

# Seed torch's global RNG so the shuffled DataLoaders (and, presumably, model initialisation)
# are repeatable under Restart & Run All.
# NOTE(review): if `create_sifim_datasets` draws its noise from numpy or `random`, seed those too.
SEED = 42
torch.manual_seed(SEED)

# Presumably 20% validation / 20% test with the remainder for training, noise level 0.005 —
# confirm in `create_sifim_datasets`. `vl_dataset` is not used by the cells below.
tr_dataset, vl_dataset, ts_dataset = create_sifim_datasets(vl_perc=0.2, ts_perc=0.2, noise=0.005)

In [4]:
def format_result(result):
    """Flatten a trainer's ``test()`` output into a single dict for plotting.

    Args:
        result: indexable whose items 0, 1 and 2 are stored under ``'mse'``, ``'time'`` and
            ``'emissions'``, and whose item 3 is a mapping merged into the output.
            Any items after index 3 are ignored.

    Returns:
        dict with keys ``mse``, ``time``, ``emissions`` followed by every key of ``result[3]``.

    Raises:
        TypeError: if ``result[3]`` repeats one of the three fixed keys.
    """
    mse, elapsed, emissions, extras = result[0], result[1], result[2], result[3]
    return dict(mse=mse, time=elapsed, emissions=emissions, **extras)

In [None]:
# LSTM test metrics collected by the training callback below. The plots label the x-axis
# 'epochs', so this is presumably one entry per epoch. The list is re-created here so re-running
# the cell starts from an empty list instead of appending to a stale one.
lstm_results = []


def store_lstm_accuracy(trainer):
    """Training callback: run ``trainer.test()`` and append the formatted result to ``lstm_results``.

    Despite the name, the stored entry is whatever ``format_result`` builds
    (``mse``, ``time``, ``emissions`` plus any extra entries), not an accuracy score.
    """
    test_outcome = trainer.test()
    lstm_results.append(format_result(test_outcome))


def _make_bptt_trainer(*args, **kwargs):
    """Build a ``BPTTTrainer`` wired to ``store_lstm_accuracy``; every other argument passes through."""
    return BPTTTrainer(*args, callback=store_lstm_accuracy, **kwargs)


retraining(
    model_constructor=LSTM,
    trainer_constructor=_make_bptt_trainer,
    tr_dataset=tr_dataset,
    ts_dataset=ts_dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    hyperparams_path='hyperparams/LSTM_hyperparams.json',
    # Persisting is switched off for this comparison run
    # (previously 'models/LSTM.torch' and 'history/history.json').
    model_path=None,
    history_path=None,
)

In [None]:
# Reference point for the plots below: a previously saved ESN is loaded from disk and only the
# trainer's `.test()` is called, yielding a single result dict.
# `models/ESN.torch` holds a pickled model object (it is passed straight in as the model). That
# requires `weights_only=False` on torch >= 2.6, where `torch.load` defaults to weights-only
# loading; the argument exists since torch 1.13. Unpickling can execute arbitrary code, so only
# load checkpoints you produced yourself.
esn_model = torch.load('models/ESN.torch', weights_only=False)
esn_tr_loader = torch.utils.data.DataLoader(tr_dataset, batch_size=batch_size, shuffle=shuffle)
# Evaluation does not need a random order. A fixed order avoids run-to-run variation in how
# test samples are grouped into batches.
esn_ts_loader = torch.utils.data.DataLoader(ts_dataset, batch_size=batch_size, shuffle=False)
esn_results = format_result(RidgeRegressionTrainer(esn_model, esn_tr_loader, esn_ts_loader).test())

In [None]:
def plot(key):
    """Plot the LSTM curve for metric ``key`` against the ESN value, save it, then show it.

    Reads the module-level ``lstm_results`` (list of dicts) and ``esn_results`` (dict) built by
    earlier cells. The figure is written to ``plots/development_<key>``; the file extension comes
    from matplotlib's ``savefig.format`` setting. The ``plots/`` directory is created if missing.

    Args:
        key: metric name present in ``esn_results`` and in every entry of ``lstm_results``.
    """
    fig, ax = plt.subplots(figsize=(12, 7))
    ax.set_title(key)
    ax.plot([r[key] for r in lstm_results], label='LSTM')
    # The ESN contributes one value, so draw it as a horizontal reference line. A constant list
    # of len(lstm_results) points has fewer than two points when there are 0 or 1 LSTM entries
    # and then renders nothing. 'C1' keeps the ESN in the cycle's second colour, as before.
    ax.axhline(esn_results[key], color='C1', linestyle='--', label='ESN')
    ax.set_xlabel('epochs')
    ax.set_ylabel(key)
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    # savefig raises if the target directory does not exist.
    os.makedirs('plots', exist_ok=True)
    fig.savefig(f'plots/development_{key}')
    plt.show()
    # Release the figure so a loop over many metrics does not accumulate open figures.
    plt.close(fig)

In [None]:
# One figure per metric recorded for the ESN; assumes every LSTM entry carries the same keys.
for metric_name in esn_results:
    plot(metric_name)