In [1]:
import os
import torch

import matplotlib.pyplot as plt
from tqdm import trange
from models.ESN import ESNModel
from dataset.data_loaders import load_dataset, generate_datasets
from models.early_stopping import EarlyStopping


os.environ["CUDA_VISIBLE_DEVICES"]='0'

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
data, L, F = load_dataset('spain')

Xtr, Ytr, Xval, Yval, Xte, Yte, diffXte, diffYte = generate_datasets(data, L, F, device, test_percent = 0.25, val_percent = 0.25)

In [3]:
batch = 1
steps = Xtr[:,0].shape[0]

X = torch.stack([Xtr[i:i+batch,0] for i in range(steps-batch+1)])[...,None,None]
Xv = torch.stack([Xval[i:i+batch,0] for i in range(Xval.shape[0]-batch+1)])[...,None,None]
Xt = torch.stack([Xte[i:i+batch,0] for i in range(Xte.shape[0]-batch+1)])[...,None,None]

Y = Ytr[batch-F:,].squeeze()
Yv = Yval[batch-F:,].squeeze()
Yt = Yte[batch-F:,].squeeze()

In [4]:
esn = ESNModel(1, 500, 1, 0, 1, F).to(device)

In [5]:
torch_optimizer = torch.optim.Adam(esn.parameters(), lr=0.01, weight_decay=1e-5)
epochs = 100
torch_loss = torch.nn.MSELoss()

# initialize the early_stopping object
checkpoint_path = "./checkpoints/esn_spain/"
early_stopping = EarlyStopping(patience=20, verbose=False, path=checkpoint_path)

with trange(epochs) as t:
    for epoch in t:
        esn.train()
        torch_optimizer.zero_grad()
        loss = torch_loss(esn(X).squeeze(), Y.squeeze())
        loss.backward()
        torch_optimizer.step()

        # display progress bar
        t.set_description(f"Epoch {epoch+1}")
        t.set_postfix({"loss":float(loss / Y.shape[0])})

        # Early stopping
        esn.eval()
        valid_loss = torch_loss(esn(Xv).squeeze(), Yv.squeeze()).item()

        # early_stopping needs the validation loss to check if it has decresed, 
        # and if it has, it will make a checkpoint of the current model
        early_stopping(valid_loss, esn)
        
        if early_stopping.early_stop:
            print("Early stopping")
            break

# load the last checkpoint with the best model
esn.load_state_dict(torch.load(checkpoint_path + "checkpoint.pt"))

Epoch 71:  70%|███████   | 70/100 [00:02<00:01, 28.25it/s, loss=0.000519]

Early stopping





<All keys matched successfully>

Reservoir states

In [9]:
h = esn.reservoir(Xt, return_last_state=False).squeeze().cpu()
h.shape

torch.Size([455, 500])

Delay coordinates

In [21]:
stack = 100
h_delay = torch.zeros((stack,h.shape[0]-stack))
for i in range(stack):
    h_delay[i,:] = h[i:-stack+i,0]  # NOTE: only the first state component! What about the others?

In [22]:
h_delay.shape

torch.Size([100, 355])

SVD

In [24]:
U, S, Vh = torch.linalg.svd(h_delay)
V = Vh.T

Finding the optimal for the threshold r

https://github.com/erichson/optht