# 01 - Create an Echo State Network (ESN) using PyTorch

In [1]:
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as P
import copy
from tqdm import tqdm

In [2]:
def createSparseMatrix(nReservoir, sparsity):
    '''
    Utility function: creates Sparse Matrix
    Returns:
            W (np.array): sparse matrix of size (**nReservoir**, **nReservoir**).
    '''
    rows, cols = nReservoir, nReservoir
    W = np.random.uniform(-1, 1, (rows, cols)) # randomly chosen matrix, entries in range [-1,1]
    num_zeros = np.ceil(sparsity * rows).astype(np.int) # number of zeros to set
    for iCol in range(cols):
        row_indices  = np.random.permutation(rows) # choose random row indicies
        zero_indices = row_indices[:num_zeros]     # get the num_zeros of them
        W[zero_indices, iCol] = 0                  # set (zero_indicies, iCol) to 0
    return W

In [3]:
class ESN(torch.nn.Module):
    def __init__(self,nInput, nTemporal, nOutput, nReservoir, **kwargs):
        super().__init__()
    
        self.nTemporal  = nTemporal
        self.nReservoir = nReservoir
        self.alpha     = kwargs.get('alpha'     , 0.5)
        phi            = kwargs.get('phi'       , 0.9)
        rho            = kwargs.get('rho'       , 0.99)
        gamma          = kwargs.get('gamma'     , 0.1)
        sparsity       = kwargs.get('sparsity'  , 0.95)
        randomSeed     = kwargs.get('randomSeed', 1)

        np.random.seed(randomSeed)
        torch.manual_seed(randomSeed)

        self.dtype  = torch.float

        self.Win    = torch.tensor(gamma * np.random.randn(nInput, nReservoir), dtype = torch.float)
        W           = createSparseMatrix(nReservoir, sparsity)
        self.W      = torch.tensor(rho * W / (np.max(np.absolute(np.linalg.eigvals(W)))), dtype = torch.float)
        
        Wout        = torch.tensor(np.random.uniform(-phi, phi, [nOutput, nReservoir * nTemporal]) / nReservoir, dtype = torch.float)
        self.Wout   = torch.nn.Parameter(Wout, requires_grad = True)
    
    def forward(self, u):
        
        x = torch.zeros([1, self.nReservoir], dtype = torch.float)
        X = []
        for t in range(self.nTemporal):
            x_prev = copy.copy(x)
            x = (1 - self.alpha) * x_prev + self.alpha * torch.tanh(torch.matmul(x_prev, self.W) + torch.matmul(u[:,t], self.Win))
            X.append(x.float())
        X = torch.stack(X)
        Y = torch.matmul(self.Wout, X.flatten())
        return torch.nn.DoY

## Get Dataset

In [6]:
trainData = datasets.FashionMNIST("../Datasets/", train=True, transform= transforms.ToTensor(), download=True) # download is one-time thing

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../Datasets/FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Extracting ../Datasets/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../Datasets/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../Datasets/FashionMNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Extracting ../Datasets/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../Datasets/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../Datasets/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Extracting ../Datasets/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../Datasets/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../Datasets/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Extracting ../Datasets/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../Datasets/FashionMNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [7]:
trainLoader = torch.utils.data.DataLoader(trainData, batch_size = 4, shuffle = True, num_workers = 2)

In [24]:
x = next(iter(trainLoader))[0]
print(x.shape)

torch.Size([4, 1, 28, 28])


In [133]:
esn = ESN(28, 28, 10, 500)

In [37]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(esn.parameters())
activation = torch.nn.Softmax(dim=1)

In [134]:
for d, (x,y) in enumerate(trainLoader):
    optimizer.zero_grad()
    y_hat = activation(torch.stack([esn(u) for u in x]))
    loss = criterion(y_hat, y)
    loss.backward()
    optimizer.step()
    if d % 1000 == 0:
        print('[{}/{}] loss:{:.3f}'.format(d, len(trainLoader), loss.item()))

[0/15000] loss:2.302


KeyboardInterrupt: 

In [62]:
y, activation(torch.stack([esn(u) for u in x])).argmax(dim=1) # not too bad

(tensor([8, 2, 8, 8]), tensor([8, 6, 8, 8]))

In [63]:
testData = datasets.FashionMNIST("../Datasets/", train=False, transform= transforms.ToTensor(), download=True) # download is one-time thing
testLoader = torch.utils.data.DataLoader(testData, batch_size = 4, shuffle = True, num_workers = 2)

In [103]:
Predictions = torch.zeros((len(testLoader), 4), dtype=torch.bool)
with torch.no_grad():
    for d, (x,y) in enumerate(testLoader):
        y_hat = activation(torch.stack([esn(u) for u in x]))
        Predictions[d] = (y == y_hat.argmax(dim=1))

In [119]:
print('Prediction accuracy: {}%'.format(100 * Predictions.sum().item()/(len(testLoader)* 4)))

Prediction accuracy: 69.5%


# Multiple Epochs

In [135]:
for nEpoch in range(10):
    Loss = []
    for d, (x,y) in enumerate(tqdm(trainLoader)):
        optimizer.zero_grad()
        y_hat = activation(torch.stack([esn(u) for u in x]))
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        Loss.append(loss.item())
    print('[{}/{}] loss:{:.3f}'.format(nEpoch, 10, sum(Loss)/len(trainLoader)))

 38%|███▊      | 5758/15000 [01:16<02:03, 75.09it/s]


KeyboardInterrupt: 

1.7550476873954137