# Assignment 3_2: Echo State Networks

In [6]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn  # used by Reservoir below; previously only reachable through the esn star import
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

from esn import *

# Seed every RNG used in this notebook. The reservoir weights and the demo inputs are drawn
# at random, so without a fixed seed a Restart & Run All produces different numbers each time.
SEED = 42
torch.manual_seed(SEED)  # also seeds all CUDA devices
np.random.seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [7]:
# The CSV has no header row, so pandas labels the columns 0, 1, 2, ...
# From the preview, row 0 looks like the input sequence and row 1 the target
# (its first 10 values are 0) — TODO confirm against the dataset description.
narma_df = pd.read_csv('../NARMA10.csv', header=None)
narma_df[narma_df.columns[:20]]  # preview only the first 20 time steps

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,0.083964,0.48934,0.35635,0.25024,0.23554,0.029809,0.34099,0.021216,0.035723,0.26082,0.048365,0.40907,0.40877,0.36122,0.074933,0.3298,0.2593,0.48649,0.3245,0.40017
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.13285,0.17536,0.37127,0.36481,0.33707,0.20447,0.33003,0.20726,0.18825,0.28343


In [105]:
class Reservoir(nn.Module):
    def __init__(self, input_size:int, hidden_size:int, omhega_in:float, omhega_b:float, rho:float, density:float = 1):
        '''
        Untrained recurrent reservoir of an Echo State Network (ESN).

        All weights are drawn once at construction and frozen (requires_grad=False).

        Parameters:
        ----------
        input_size: int
            Size of the input
        hidden_size: int
            Number of reservoir units
        omhega_in: float
            Input scaling: W_in entries are drawn uniformly from [-omhega_in, omhega_in)
        omhega_b: float
            Bias scaling: bias entries are drawn uniformly from [-omhega_b, omhega_b)
        rho: float
            Desired spectral radius of the recurrent weight matrix W_h
        density: float
            Fraction of non-zero recurrent connections, in (0, 1]. The default 1 means a
            fully connected W_h. Smaller values keep each connection independently with
            probability `density`.

        Raises:
        ------
        ValueError
            If density is not in (0, 1].
        '''
        super(Reservoir, self).__init__()
        if not 0 < density <= 1:
            raise ValueError(f'density must be in (0, 1], got {density}')

        self.input_scaling = nn.Parameter(torch.tensor(omhega_in), requires_grad=False)
        self.rho = nn.Parameter(torch.tensor(rho), requires_grad=False)
        self.hidden_size = nn.Parameter(torch.tensor(hidden_size), requires_grad=False)

        self.W_in = nn.Parameter(nn.init.uniform_(torch.empty(hidden_size, input_size), -omhega_in, omhega_in), requires_grad=False)
        self.bias = nn.Parameter(nn.init.uniform_(torch.empty(hidden_size), -omhega_b, omhega_b), requires_grad=False)

        W_h = nn.init.uniform_(torch.empty(hidden_size, hidden_size), -1, 1)
        if density < 1:
            # Sparsify only when requested, so density=1 consumes the RNG exactly as before.
            W_h.mul_((torch.rand(hidden_size, hidden_size) < density).to(W_h.dtype))
        radius = torch.linalg.eigvals(W_h).abs().max()
        if radius > 0:  # an all-zero/nilpotent matrix has radius 0; dividing would give NaNs
            W_h.div_(radius).mul_(rho)  # in-place ops (div_, mul_) avoid an extra HxH copy

        self.W_h = nn.Parameter(W_h.float(), requires_grad=False)

    @torch.no_grad()
    def forward(self, input:torch.Tensor, h_init:torch.Tensor = None, washout:int = 0) -> torch.Tensor:
        '''
        Run the reservoir over a sequence and collect its hidden states.

        Update rule: h_t = tanh(W_in x_t + bias + W_h h_{t-1}).

        Parameters:
        ----------
        input: torch.Tensor
            Input of shape (L, N, input_size), or (L, input_size) for a single unbatched
            sequence (L is the length of the sequence, N is the batch size).
        h_init: torch.Tensor, optional
            Initial hidden state of shape (N, hidden_size), or (hidden_size,) for unbatched
            input. Zeros when None.
        washout: int
            Number of leading time steps whose states are discarded (must be >= 0).

        Returns:
        -------
        return: torch.Tensor
            States of shape (L - washout, N, hidden_size), or (L - washout, hidden_size) for
            unbatched input. The first dimension is empty when washout >= L.

        Raises:
        ------
        ValueError
            If washout is negative.
        '''
        if washout < 0:
            raise ValueError(f'washout must be non-negative, got {washout}')

        unbatched = input.dim() == 2
        if unbatched:
            # Treat (L, input_size) as a batch of one: (L, 1, input_size).
            input = input.unsqueeze(1)
            if h_init is not None and h_init.dim() == 1:
                h_init = h_init.unsqueeze(0)

        timesteps, batch_size, _ = input.shape
        n_units = self.W_h.shape[0]
        # Allocate alongside the weights so the module keeps working after .to(device).
        like = dict(dtype=self.W_h.dtype, device=self.W_h.device)

        # Tensors have no .copy(); clone() guarantees the caller's tensor is never aliased.
        h = torch.zeros(batch_size, n_units, **like) if h_init is None else h_init.clone()
        states = torch.empty(timesteps, batch_size, n_units, **like)

        for t in range(timesteps):
            h = torch.tanh(F.linear(input[t], self.W_in, self.bias) + F.linear(h, self.W_h))
            states[t] = h

        states = states[washout:]
        return states.squeeze(1) if unbatched else states

In [115]:
# Smoke-test the reservoir on a random sequence: 10 time steps, batch of 1, 1 input feature.
cc = torch.rand(10, 1, 1)
r = Reservoir(input_size=1, hidden_size=11, omhega_in=0.4, omhega_b=0.5, rho=0.6)
r(cc, h_init=None)

tensor([[[ 0.2926,  0.2173,  0.3035, -0.1396,  0.0698,  0.0783, -0.2219,
           0.0921, -0.3867, -0.1676, -0.2341]],

        [[-0.0449, -0.0657,  0.4548,  0.1277,  0.1643,  0.0532, -0.2058,
           0.0210, -0.4133, -0.2410, -0.4792]],

        [[-0.0091, -0.1591,  0.4855,  0.2849,  0.4492, -0.0099, -0.1439,
           0.0308, -0.3426, -0.2826, -0.4234]],

        [[ 0.0810,  0.0130,  0.4519, -0.0856,  0.2899, -0.2181, -0.4297,
          -0.0913, -0.4189, -0.3270, -0.3094]],

        [[ 0.0988,  0.0239,  0.4912,  0.2029,  0.2416, -0.0098, -0.3685,
          -0.1483, -0.4116, -0.2166, -0.4392]],

        [[ 0.0437, -0.0909,  0.4634,  0.1583,  0.3612, -0.0756, -0.3403,
          -0.1830, -0.4476, -0.2549, -0.3907]],

        [[ 0.0949, -0.0193,  0.4543,  0.0918,  0.3277, -0.1202, -0.3882,
          -0.1632, -0.4760, -0.2520, -0.3508]],

        [[ 0.0882,  0.0043,  0.4729,  0.0808,  0.2542, -0.1085, -0.4287,
          -0.1956, -0.4996, -0.2342, -0.3743]],

        [[ 0.0676, -0.05