# Assignment 3_2: Echo State Networks

In [1]:
import pandas as pd
import numpy as np
from sklearn.linear_model import Ridge
import matplotlib.pyplot as plt
import torch
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim

from esn import *
from learning import *

# Select the compute device: use the GPU when PyTorch reports CUDA as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [2]:
# Load the NARMA10 dataset; the CSV has no header row, so columns are numbered 0..N-1.
narma_df = pd.read_csv(filepath_or_buffer='../NARMA10.csv', header=None)
# Preview only the first 20 columns of every row to keep the output readable.
narma_df.iloc[:, 0:20]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,0.083964,0.48934,0.35635,0.25024,0.23554,0.029809,0.34099,0.021216,0.035723,0.26082,0.048365,0.40907,0.40877,0.36122,0.074933,0.3298,0.2593,0.48649,0.3245,0.40017
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.13285,0.17536,0.37127,0.36481,0.33707,0.20447,0.33003,0.20726,0.18825,0.28343


In [3]:
# Row 0 of the CSV becomes x_data and row 1 becomes y_data.
# float32 is used for better memory efficiency.
x_data, y_data = (
    torch.tensor(narma_df.iloc[row].values, dtype=torch.float32)
    for row in (0, 1)
)

# Split: 4000 train, 1000 validation, remaining samples (from index 5000 on) test.
# WARNING: the entire dataset is loaded in memory ONLY because it is small
# (and the network is quite small as well).

# Development set (train + validation, only used for retraining) vs. test set.
dev_x, test_x = x_data[:5000], x_data[5000:]
dev_y, test_y = y_data[:5000], y_data[5000:]

# Development set further divided into training and validation parts.
train_x, val_x = dev_x[:4000], dev_x[4000:]
train_y, val_y = dev_y[:4000], dev_y[4000:]

In [4]:
# Bring the training inputs to the 3-D layout needed by the model: (time steps, 1, 1).
# The last axis is the input feature size (grid_search reads it from shape[2]);
# the middle axis is presumably the batch dimension -- TODO confirm against esn.py.
# reshape(-1, 1, 1) is used instead of chained unsqueeze calls so the cell is
# idempotent: re-running it keeps the shape at (N, 1, 1) instead of adding axes.
train_x = train_x.reshape(-1, 1, 1)
train_x.shape

torch.Size([4000, 1, 1])

In [5]:
# Bring the validation inputs to the 3-D layout needed by the model: (time steps, 1, 1).
# The last axis is the input feature size (grid_search reads it from shape[2]);
# the middle axis is presumably the batch dimension -- TODO confirm against esn.py.
# reshape(-1, 1, 1) is used instead of chained unsqueeze calls so the cell is
# idempotent: re-running it keeps the shape at (N, 1, 1) instead of adding axes.
val_x = val_x.reshape(-1, 1, 1)
val_x.shape

torch.Size([1000, 1, 1])

In [6]:
import itertools
import torch
import statistics as st

from esn import *

def _sample_variance(scores):
    """Return the sample variance of `scores`, or NaN when fewer than two values exist."""
    # statistics.variance raises StatisticsError for a single data point.
    return st.variance(scores) if len(scores) > 1 else float('nan')


def grid_search(hyperparameters:dict, train_x, train_y, val_x, val_y, n_iter:int = 5, verbose:bool = False):
    """Evaluate a RegressorESN on every combination of the given hyperparameter values.

    For each configuration, `n_iter` freshly constructed networks are fitted on the
    training data and scored with MSE on both the training and the validation data;
    the mean and sample variance of those scores are recorded.

    Args:
        hyperparameters: maps each hyperparameter name to the list of values to try.
            Must contain the keys 'hidden_size', 'ridge_regression', 'omhega_in',
            'omhega_b', 'rho' and 'washout'.
        train_x: training inputs; the input size is read from `train_x.shape[2]`.
        train_y: training targets.
        val_x: validation inputs.
        val_y: validation targets.
        n_iter: number of independently constructed networks per configuration (>= 1).
        verbose: when True, print a line after each configuration has been evaluated.

    Returns:
        dict mapping 'config_<i>' to the configuration values plus
        'train_mse_mean', 'train_mse_var', 'val_mse_mean' and 'val_mse_var'.
        The variances are NaN when `n_iter == 1`.

    Raises:
        ValueError: if `n_iter` is smaller than 1.
    """
    if n_iter < 1:
        raise ValueError(f'n_iter must be at least 1, got {n_iter}')

    # Cartesian product of all value lists -> one dict per configuration.
    names = list(hyperparameters.keys())
    all_config = [dict(zip(names, values)) for values in itertools.product(*hyperparameters.values())]

    model_selection_history = {}
    mse = torch.nn.MSELoss()
    input_size = train_x.shape[2]  # loop-invariant, so computed once

    for i, config in enumerate(all_config):
        train_mse = []
        val_mse = []
        washout = config['washout']

        # Repeat with newly constructed networks to average over the runs.
        for _ in range(n_iter):
            esn = RegressorESN(input_size=input_size, hidden_size=config['hidden_size'], ridge_regression=config['ridge_regression'],
                            omhega_in=config['omhega_in'], omhega_b=config['omhega_b'], rho=config['rho'], density=1)

            esn.train()
            h_last = esn.fit(train_x, train_y, washout)
            # NOTE(review): the training score is computed on the whole sequence starting
            # from a None initial state, so for washout > 0 it presumably includes the
            # transient steps that fit() was asked to discard -- confirm whether
            # train_mse should skip the first `washout` steps.
            train_pred = esn(train_x, None)
            train_mse.append(mse(train_pred, train_y).item())

            esn.eval()
            # Validation continues from the value returned by fit().
            val_pred = esn(val_x, h_init=h_last)
            val_mse.append(mse(val_pred, val_y).item())

        model_selection_history[f'config_{i}'] = {**config,
                                                  'train_mse_mean': st.mean(train_mse), 'train_mse_var': _sample_variance(train_mse),
                                                  'val_mse_mean': st.mean(val_mse), 'val_mse_var': _sample_variance(val_mse)}
        if verbose:
            print(f'Configuration {i}')

    return model_selection_history

In [7]:
# Search space: grid_search evaluates every combination of the values below
# (3 * 1 * 2 * 2 * 3 * 2 = 72 configurations).
hyperparams = dict(
    hidden_size=[256, 512, 1024],
    ridge_regression=[1e-6],
    omhega_in=[1, 2],
    omhega_b=[0.5, 1],
    rho=[0.7, 0.8, 0.9],
    washout=[0, 100],
)

# Each configuration is evaluated 5 times.
model_selection_history = grid_search(
    hyperparameters=hyperparams,
    train_x=train_x,
    train_y=train_y,
    val_x=val_x,
    val_y=val_y,
    n_iter=5,
)

# One row per configuration ('config_<i>' as index); saved to disk as CSV.
df = pd.DataFrame.from_dict(data=model_selection_history, orient='index')
df.to_csv('esn_grid_search.csv')

In [11]:
# Reload the saved grid-search results; the first CSV column holds the configuration names.
df_results = pd.read_csv('esn_grid_search.csv', index_col=0)
# Show the 10 configurations with the lowest mean validation MSE (ascending order).
df_results.sort_values('val_mse_mean').iloc[:10]

Unnamed: 0,hidden_size,ridge_regression,omhega_in,omhega_b,rho,washout,train_mse_mean,train_mse_var,val_mse_mean,val_mse_var
config_59,1024,1e-06,1,1.0,0.9,100,0.044263,0.0004076876,3e-06,1.063745e-13
config_58,1024,1e-06,1,1.0,0.9,0,2e-06,1.473181e-13,4e-06,2.718549e-13
config_50,1024,1e-06,1,0.5,0.8,0,2e-06,3.656955e-14,4e-06,7.393062e-13
config_51,1024,1e-06,1,0.5,0.8,100,0.006309,1.21608e-05,4e-06,2.177277e-13
config_57,1024,1e-06,1,1.0,0.8,100,0.22907,0.0875152,5e-06,3.283548e-12
config_56,1024,1e-06,1,1.0,0.8,0,3e-06,5.526449e-13,5e-06,8.380498e-13
config_49,1024,1e-06,1,0.5,0.7,100,0.039802,0.001711976,5e-06,4.796634e-13
config_52,1024,1e-06,1,0.5,0.9,0,3e-06,7.460215e-14,6e-06,4.262545e-13
config_53,1024,1e-06,1,0.5,0.9,100,0.002312,1.799049e-06,6e-06,6.084754e-13
config_48,1024,1e-06,1,0.5,0.7,0,3e-06,1.06175e-13,6e-06,3.592116e-13
