# MNIST_STP review


#### This Jupyter notebook walks through the MNIST_STP code and divides it into several key sections.

## 1. Import

In [5]:
import torch
import math
import matplotlib.pyplot as plt
import numpy as np 
import os
# NOTE(review): PYTORCH_CUDA_ALLOC_CONF expects comma-separated "<option>:<value>" pairs
# (e.g. "max_split_size_mb:128"). A bare '4' is not a recognised option and may make the
# CUDA caching allocator raise when it first initialises. Confirm the intended setting.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = '4'

## 2. Device configuration

In [4]:
# Device configuration: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# The CUDA queries below raise on CPU-only machines, so only run them when CUDA is in use.
if device.type == 'cuda':
    # index of the currently selected device
    print(torch.cuda.current_device())
    # number of GPUs available
    print(torch.cuda.device_count())
    # name of GPU 0
    print(torch.cuda.get_device_name(0))

Using device: cuda
0
1
NVIDIA GeForce RTX 4070 Laptop GPU


## 3. Dataset and dataloaders

In [3]:
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

# ToTensor() converts each image into a float tensor.
# download=True on both splits lets this cell also work on a fresh machine;
# files that are already present under root are reused.
train_data = datasets.MNIST(
    root='data',
    train=True,
    transform=ToTensor(),
    download=True,
)
test_data = datasets.MNIST(
    root='data',
    train=False,
    transform=ToTensor(),
    download=True,
)

# Only the training data needs shuffling. A fixed evaluation order does not change
# the accuracy and keeps evaluation runs reproducible.
loaders = {
    'train': DataLoader(train_data, batch_size=100, shuffle=True, num_workers=0),
    'test': DataLoader(test_data, batch_size=100, shuffle=False, num_workers=0),
}
loaders


1.0%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting data\MNIST\raw\train-images-idx3-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz


100.0%

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data\MNIST\raw\train-labels-idx1-ubyte.gz





Extracting data\MNIST\raw\train-labels-idx1-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data\MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting data\MNIST\raw\t10k-images-idx3-ubyte.gz to data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz


100.0%

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting data\MNIST\raw\t10k-labels-idx1-ubyte.gz to data\MNIST\raw






{'train': <torch.utils.data.dataloader.DataLoader at 0x21d94086790>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x21d8c503410>}

### 4. Set hyperparameters (model structure and training settings)

In [4]:
from torch import nn
import torch.nn.functional as F

# Each 28 x 28 MNIST image is fed to the recurrent model as `sequence_length`
# time steps of `input_size` pixels each (train() reshapes the images accordingly).
sequence_length = 196
input_size = 4
hidden_size = 144
num_layers = 1
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.01

# Guard the reshape in train()/evaluate(): every pixel must land in exactly one time step.
assert sequence_length * input_size == 28 * 28, \
    "sequence_length * input_size must cover all 784 MNIST pixels"


### 5. Class STP

In [5]:
## Illustrated (annotated) version of the STP cell. Running this cell is harmless, but the
## next cell redefines STPCell and STP, and those later definitions are the ones used.
class STPCell(nn.Module):
    """Short-term-plasticity (STP) recurrent cell -- annotated reference copy.

    Args:
        input_size: number of input features per time step (columns of ``p``).
        hidden_size: number of hidden units.
        complexity: ``"rich"`` (plasticity parameters of shape (hidden, hidden)) or
            ``"poor"`` (plasticity parameters of shape (hidden, 1)). Any other value
            creates no parameters, and ``forward`` then returns None.
        e_h: scale of the hidden update rate ``z_h = e_h * sigmoid(c_h)`` (rich branch only).
        alpha: stored, and used to compute ``e_ux = alpha * e_h`` (not read by ``forward``).

    NOTE(review): ``Ucap``/``Ucapclone`` are computed *before* the uniform
    re-initialisation loop at the end of ``__init__``. Until the first ``forward``
    call, ``Ucapclone`` therefore reflects the discarded ``torch.rand`` values of ``c_U``.
    NOTE(review): reads the notebook-level globals ``batch_size`` (in ``__init__``) and
    ``device`` (in ``forward``).
    """
    def __init__(self, input_size, hidden_size, complexity, e_h, alpha):
        super(STPCell, self).__init__()
        ## store the constructor arguments as attributes of the cell
        self.input_size = input_size        
        self.hidden_size = hidden_size
        self.complexity = complexity 
        sigmoid = nn.Sigmoid() 
        ## only referenced by commented-out debug prints in forward()
        self.ones = torch.ones(self.hidden_size, self.hidden_size)
        ## NOTE(review): reads the notebook-level global `batch_size`; not used elsewhere in this class
        self.batch_size = batch_size 
        ## buffers for the (commented-out) state-trace plots in forward()
        self.forprintingX = []
        self.forprintingU = []
        self.forprintingh = []

        if self.complexity == "rich":
            # System variables 
            self.e_h = e_h

            # Short term Plasticity variables 
            self.delta_t = 1
            self.alpha = alpha
            self.e_ux = self.alpha * self.e_h
            ## z_min/z_max bound the depression/facilitation rates z_x and z_u computed in forward()
            self.z_min = 0.001
            self.z_max = 0.1

            # Short term Depression parameters  
            ## Parameters initialised in this way are all trainable parameters
            ## (their torch.rand values are overwritten by the uniform init loop at the end of __init__)
            self.c_x = torch.nn.Parameter(torch.rand(self.hidden_size, self.hidden_size))

            # Short term Facilitation parameters
            self.c_u = torch.nn.Parameter(torch.rand(self.hidden_size, self.hidden_size))
            self.c_U = torch.nn.Parameter(torch.rand(self.hidden_size, self.hidden_size))
            
            # System parameters            
            self.c_h = torch.nn.Parameter(torch.rand(self.hidden_size, 1))
            self.w = torch.nn.Parameter(torch.rand(self.hidden_size, self.hidden_size))
            self.p = torch.nn.Parameter(torch.rand(self.hidden_size, self.input_size))  
            self.b = torch.nn.Parameter(torch.rand(self.hidden_size, 1))
            
            # State initialisations
            ## h_t is stored as a row tensor (1, hidden_size), not a column vector;
            ## forward() transposes it to (hidden_size, batch) before using it
            """"
            h_t starts with 0, with a clean memory state
            X starts with 1, as available resourse is 1 at first
            U follows the same principle, as baseline utilisation factor is 0.9
            Afterall it depends on the biological models
            """
            self.h_t = torch.zeros(1, self.hidden_size, dtype=torch.float32)
            self.X = torch.ones(self.hidden_size, self.hidden_size, dtype=torch.float32)     
            self.U = torch.full((self.hidden_size, self.hidden_size), 0.9, dtype=torch.float32)         
            ## NOTE(review): computed before the uniform re-initialisation below, so this uses the
            ## torch.rand values of c_U that are about to be overwritten
            self.Ucap = 0.9 * sigmoid(self.c_U)
            ## detach() -- don't want to include it in the gradient computation
            self.Ucapclone = self.Ucap.clone().detach()
        if self.complexity == "poor":
            # System variables 
            self.e_h = e_h

            # Short term Plasticity variables 
            self.delta_t = 1
            self.alpha = alpha
            self.e_ux = self.alpha * self.e_h
            self.z_min = 0.001
            self.z_max = 0.1

            # Short term Depression parameters  
            ## "poor": one depression/facilitation parameter per hidden unit, shape (hidden, 1)
            self.c_x = torch.nn.Parameter(torch.rand(self.hidden_size, 1))

            # Short term Facilitation parameters
            self.c_u = torch.nn.Parameter(torch.rand(self.hidden_size, 1))
            self.c_U = torch.nn.Parameter(torch.rand(self.hidden_size, 1))
            
            # System parameters
            self.c_h = torch.nn.Parameter(torch.rand(self.hidden_size, 1))
            self.w = torch.nn.Parameter(torch.rand(self.hidden_size, self.hidden_size))
            self.p = torch.nn.Parameter(torch.rand(self.hidden_size, self.input_size))
            self.b = torch.nn.Parameter(torch.rand(self.hidden_size, 1))
            
            # State initialisations
            self.h_t = torch.zeros(1, self.hidden_size, dtype=torch.float32)
            self.X = torch.ones(self.hidden_size, 1, dtype=torch.float32)
            self.U = torch.full((self.hidden_size, 1), 0.9, dtype=torch.float32)   
            ## NOTE(review): same stale-initialisation issue as in the rich branch
            self.Ucap = 0.9 * sigmoid(self.c_U)
            self.Ucapclone = self.Ucap.clone().detach()
        # Re-initialise every parameter uniformly in [-1/sqrt(hidden), 1/sqrt(hidden)],
        # overwriting the torch.rand values assigned above.
        for name, param in self.named_parameters():
            #print(name, param.size(), param)
            nn.init.uniform_(param, a=-(1/math.sqrt(hidden_size)), b=(1/math.sqrt(hidden_size))) 

    def forward(self, x):                 
        """Advance the cell by one time step and return the new hidden state.

        ``x`` must have ``input_size`` rows, because it is multiplied as ``p @ x``
        (``STP.forward`` passes it as (input_size, batch)). The returned hidden
        state is the transpose of the (hidden_size, batch) update, i.e.
        (batch, hidden_size). It is also stored in ``self.h_t``.
        """
        if self.complexity == "rich":
            # Drop the leading num_layers axis when the caller stored h_t as 3-D.
            if self.h_t.dim() == 3:
                self.h_t = self.h_t[0]
            self.h_t = torch.transpose(self.h_t, 0, 1)
            # x is transposed here and transposed back before use below (net no-op in this branch).
            x = torch.transpose(x, 0, 1)
            sigmoid = nn.Sigmoid()
            
            # graph plotting 
            '''self.forprintingX.append(self.X[20,11,24].item())
            self.forprintingU.append(self.U[20,11,24].item())
            self.forprintingh.append(self.h_t[11, 20].item())
            if len(self.forprintingX) % (196*5) == 0:
                self.forprintingX = []
                self.forprintingU = []
                self.forprintingh = []'''   

            # Short term Depression 
            # z_x lies in [z_min, z_max]; X relaxes towards 1 at rate z_x and is depleted by U * X * h.
            self.z_x = self.z_min + (self.z_max - self.z_min) * sigmoid(self.c_x)
            #print("z_x", self.z_x.size())
            #print("self.X", self.X.size())
            #print("self.ones", self.ones.size())
            #print("h_t", self.h_t.size())
            #print("self.U", self.U.size())
            #a = self.delta_t * self.U * torch.einsum("ijk, ji  -> ijk", self.X, self.h_t)
            #print("a", a)
            #print("a size", a.size())
            # einsum: X[i, j, k] * h_t[j, i] -- i indexes the batch, j the hidden unit (row of X).
            self.X = self.z_x + torch.mul((1 - self.z_x), self.X) - self.delta_t * self.U * torch.einsum("ijk, ji  -> ijk", self.X, self.h_t)

            # Short term Facilitation 
            # U relaxes towards Ucap at rate z_u and grows by Ucap * (1 - U) * h.
            self.z_u = self.z_min + (self.z_max - self.z_min) * sigmoid(self.c_u)    
            self.Ucap = 0.9 * sigmoid(self.c_U)
            self.U = self.Ucap * self.z_u + torch.mul((1 - self.z_u), self.U) + self.delta_t * self.Ucap * torch.einsum("ijk, ji  -> ijk", (1 - self.U), self.h_t)
            self.Ucapclone = self.Ucap.clone().detach() 
            # Keep U within [Ucap, 1] element-wise; uses the global `device`.
            self.U = torch.clamp(self.U, min=self.Ucapclone.repeat(self.U.size(0), 1, 1).to(device), max=torch.ones_like(self.Ucapclone.repeat(self.U.size(0), 1, 1).to(device)))

            # System Equations 
            # z_h = e_h * sigmoid(c_h) is the per-unit leak/update rate of the hidden state.
            self.z_h = self.e_h * sigmoid(self.c_h) 
            #   a = self.w * self.U * self.X
            #print("size of a", a.size())
            #print("size of h_t", self.h_t.size())
            #print("size of a * h_t", torch.matmul(a, self.h_t).size())
            #print("size of x", x.size())
            x = torch.transpose(x, 0, 1)
            # einsum: sum_k (w * U * X)[i, j, k] * h_t[k, i] -> (hidden, batch) recurrent drive.
            self.h_t = torch.mul((1 - self.z_h), self.h_t) + self.z_h * sigmoid(torch.einsum("ijk, ki  -> ji", (self.w * self.U * self.X), self.h_t) + torch.matmul(self.p, x) + self.b)
            #self.h_t = torch.matmul(self.w, self.h_t) + torch.matmul(self.p, x) + self.b
            self.h_t = torch.transpose(self.h_t, 0, 1)
            return self.h_t   

        if self.complexity == "poor":
            if self.h_t.dim() == 3:
                self.h_t = self.h_t[0]
            self.h_t = torch.transpose(self.h_t, 0, 1)
            # After this transpose x.size(0) is the batch size (used by the clamp below).
            x = torch.transpose(x, 0, 1)
            sigmoid = nn.Sigmoid()
            
            # Short term Depression 
            self.z_x = self.z_min + (self.z_max - self.z_min) * sigmoid(self.c_x)
            #print("z_x", self.z_x.size())
            #print("self.X", self.X.size())
            #print("self.ones", self.ones.size())
            #print("self.U", self.U.size())
            #print("h_t", self.h_t.size())
            # debug-only value; not used afterwards
            a = self.delta_t * self.U * self.X * self.h_t
            #print("a", a)
            #print("a size", a.size())
        
            self.X = self.z_x + torch.mul((1 - self.z_x), self.X) - self.delta_t * self.U * self.X * self.h_t

            # Short term Facilitation 
            self.z_u = self.z_min + (self.z_max - self.z_min) * sigmoid(self.c_u)    
            self.Ucap = 0.9 * sigmoid(self.c_U)
            self.U = self.Ucap * self.z_u + torch.mul((1 - self.z_u), self.U) + self.delta_t * self.Ucap * (1 - self.U) * self.h_t
            self.Ucapclone = self.Ucap.clone().detach()
            # Keep U within [Ucap, 1]; the (hidden, 1) bound is tiled across the batch columns.
            self.U = torch.clamp(self.U, min=self.Ucapclone.repeat(1, x.size(0)).to(device), max=torch.ones_like(self.Ucapclone.repeat(1, x.size(0)).to(device)))
            # graph plotting 
            '''self.forprintingX.append(self.X[20,5].item())
            self.forprintingU.append(self.U[20,5].item())
            if len(self.forprintingX) % 140 == 0:
                self.forprintingX = []
                self.forprintingU = []'''

            # System Equations 
            # NOTE(review): unlike the rich branch, the raw parameter c_h is used as the
            # update rate here (the e_h * sigmoid(c_h) form is commented out) -- confirm intent.
            # self.z_h = self.e_h * sigmoid(self.c_h) 
            #a = self.w * self.U * self.X
            #print("size of a", a.size())
            #print("size of h_t", self.h_t.size())
            #print("size of a * h_t", torch.matmul(a, self.h_t).size())
            #print("size of x", x.size())
            x = torch.transpose(x, 0, 1)
            self.h_t = torch.mul((1 - self.c_h), self.h_t) + self.c_h * sigmoid(torch.matmul(self.w, (self.U * self.X * self.h_t)) + torch.matmul(self.p, x) + self.b)
            #self.h_t = torch.matmul(self.w, self.h_t) + torch.matmul(self.p, x) + self.b
            self.h_t = torch.transpose(self.h_t, 0, 1)
            return self.h_t

class STP(nn.Module):
    """Unroll a single STPCell over the time axis of a batch.

    Dimension 1 of the input is iterated as the time axis. Each slice
    ``x[:, step, :]`` is transposed before it is handed to the cell (for the
    (batch, seq_len, features) input that train() builds, the cell receives
    (features, batch)). The cell's final hidden state is returned.
    """

    def __init__(self, input_size, hidden_size, complexity, e_h, alpha):
        super().__init__()
        cell = STPCell(input_size, hidden_size, complexity, e_h, alpha)
        self.stpcell = cell.to(device)

    def forward(self, x):
        for step in torch.unbind(x, dim=1):
            self.stpcell(torch.transpose(step, 0, 1))
        return self.stpcell.h_t

In [5]:

class STPCell(nn.Module):
    """Short-term-plasticity (STP) recurrent cell.

    The cell keeps three state tensors between calls:

    * ``h_t`` -- hidden state,
    * ``X``   -- available resources, depleted by activity (short-term depression),
    * ``U``   -- utilisation factor, boosted by activity (short-term facilitation).

    With ``complexity="rich"`` the plasticity parameters ``c_x``, ``c_u`` and ``c_U``
    have shape (hidden, hidden). With ``complexity="poor"`` they have shape
    (hidden, 1), i.e. one value per hidden unit.

    Args:
        input_size: number of input features per time step.
        hidden_size: number of hidden units.
        complexity: ``"rich"`` or ``"poor"``.
        e_h: scale of the hidden update rate ``z_h = e_h * sigmoid(c_h)`` (rich only).
        alpha: stored together with ``e_ux = alpha * e_h`` (not read by ``forward``).

    Raises:
        ValueError: if ``complexity`` is neither ``"rich"`` nor ``"poor"``.

    The state tensors are expected to be replaced with batch-shaped tensors by the
    caller before ``forward`` (the RNN wrapper in this notebook does this per batch).
    """

    def __init__(self, input_size, hidden_size, complexity, e_h, alpha):
        super().__init__()
        if complexity not in ("rich", "poor"):
            raise ValueError(f"complexity must be 'rich' or 'poor', got {complexity!r}")
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.complexity = complexity
        self.ones = torch.ones(self.hidden_size, self.hidden_size)
        self.batch_size = batch_size  # NOTE: reads the notebook-level hyperparameter
        # Buffers for optional state-trace plots.
        self.forprintingX = []
        self.forprintingU = []
        self.forprintingh = []

        # System and short-term-plasticity constants.
        self.e_h = e_h
        self.delta_t = 1
        self.alpha = alpha
        self.e_ux = self.alpha * self.e_h
        self.z_min = 0.001  # lower bound of the plasticity rates z_x, z_u
        self.z_max = 0.1    # upper bound of the plasticity rates z_x, z_u

        # "rich": one plasticity parameter per (hidden, hidden) entry; "poor": one per unit.
        plastic_shape = (hidden_size, hidden_size) if complexity == "rich" else (hidden_size, 1)
        # The torch.rand values are overwritten below; they are kept (in the original
        # registration order) so the random-number stream is unchanged.
        self.c_x = nn.Parameter(torch.rand(*plastic_shape))  # short-term depression
        self.c_u = nn.Parameter(torch.rand(*plastic_shape))  # short-term facilitation
        self.c_U = nn.Parameter(torch.rand(*plastic_shape))  # facilitation baseline
        self.c_h = nn.Parameter(torch.rand(hidden_size, 1))
        self.w = nn.Parameter(torch.rand(hidden_size, hidden_size))
        self.p = nn.Parameter(torch.rand(hidden_size, input_size))
        self.b = nn.Parameter(torch.rand(hidden_size, 1))

        bound = 1 / math.sqrt(hidden_size)
        for param in self.parameters():
            nn.init.uniform_(param, a=-bound, b=bound)

        # Initial state: clean memory (h = 0), full resources (X = 1), utilisation 0.9.
        self.h_t = torch.zeros(1, hidden_size, dtype=torch.float32)
        self.X = torch.ones(*plastic_shape, dtype=torch.float32)
        self.U = torch.full(plastic_shape, 0.9, dtype=torch.float32)
        # Computed *after* the re-initialisation so Ucap/Ucapclone reflect the parameters
        # that are actually trained (previously they were derived from the discarded values).
        self.Ucap = 0.9 * torch.sigmoid(self.c_U)
        self.Ucapclone = self.Ucap.clone().detach()  # detached: no gradient through the bound

    def forward(self, x):
        """Advance the cell by one time step.

        Args:
            x: input with ``input_size`` rows, i.e. (input_size, batch), since it is
               multiplied as ``p @ x``.

        Returns:
            The new hidden state, shape (batch, hidden_size); also stored in ``self.h_t``.
        """
        if self.h_t.dim() == 3:
            self.h_t = self.h_t[0]  # drop the leading num_layers axis
        h = torch.transpose(self.h_t, 0, 1)  # (hidden, batch)
        if self.complexity == "rich":
            h_new = self._step_rich(x, h)
        else:
            h_new = self._step_poor(x, h)
        self.h_t = torch.transpose(h_new, 0, 1)
        return self.h_t

    def _clamp_U(self, lower):
        """Clamp ``self.U`` element-wise into [lower, 1] on U's own device."""
        lower = lower.to(self.U.device)
        return torch.clamp(self.U, min=lower, max=torch.ones_like(lower))

    def _step_rich(self, x, h):
        """One update of the per-synapse dynamics; ``h`` is (hidden, batch), X/U are (batch, hidden, hidden)."""
        # Short-term depression: X relaxes towards 1 at rate z_x and is depleted by U * X * h.
        # einsum "ijk, ji -> ijk": X[i, j, k] * h[j, i] (i = batch, j = hidden row of X).
        self.z_x = self.z_min + (self.z_max - self.z_min) * torch.sigmoid(self.c_x)
        self.X = self.z_x + (1 - self.z_x) * self.X - self.delta_t * self.U * torch.einsum("ijk, ji  -> ijk", self.X, h)

        # Short-term facilitation: U relaxes towards Ucap at rate z_u and grows with activity.
        self.z_u = self.z_min + (self.z_max - self.z_min) * torch.sigmoid(self.c_u)
        self.Ucap = 0.9 * torch.sigmoid(self.c_U)
        self.U = self.Ucap * self.z_u + (1 - self.z_u) * self.U + self.delta_t * self.Ucap * torch.einsum("ijk, ji  -> ijk", (1 - self.U), h)
        self.Ucapclone = self.Ucap.clone().detach()
        self.U = self._clamp_U(self.Ucapclone.repeat(self.U.size(0), 1, 1))

        # Hidden-state update with leak/update rate z_h = e_h * sigmoid(c_h).
        # einsum "ijk, ki -> ji": sum_k (w * U * X)[i, j, k] * h[k, i] -> (hidden, batch).
        self.z_h = self.e_h * torch.sigmoid(self.c_h)
        recurrent = torch.einsum("ijk, ki  -> ji", (self.w * self.U * self.X), h)
        return (1 - self.z_h) * h + self.z_h * torch.sigmoid(recurrent + torch.matmul(self.p, x) + self.b)

    def _step_poor(self, x, h):
        """One update of the per-unit dynamics; ``h``, X and U are all (hidden, batch)."""
        # Short-term depression.
        self.z_x = self.z_min + (self.z_max - self.z_min) * torch.sigmoid(self.c_x)
        self.X = self.z_x + (1 - self.z_x) * self.X - self.delta_t * self.U * self.X * h

        # Short-term facilitation.
        self.z_u = self.z_min + (self.z_max - self.z_min) * torch.sigmoid(self.c_u)
        self.Ucap = 0.9 * torch.sigmoid(self.c_U)
        self.U = self.Ucap * self.z_u + (1 - self.z_u) * self.U + self.delta_t * self.Ucap * (1 - self.U) * h
        self.Ucapclone = self.Ucap.clone().detach()
        # The (hidden, 1) lower bound is tiled across the batch columns of x.
        self.U = self._clamp_U(self.Ucapclone.repeat(1, x.size(1)))

        # NOTE(review): unlike the rich branch, the raw parameter c_h is used as the update
        # rate (the e_h * sigmoid(c_h) form was commented out in the original) -- confirm intent.
        recurrent = torch.matmul(self.w, (self.U * self.X * h))
        return (1 - self.c_h) * h + self.c_h * torch.sigmoid(recurrent + torch.matmul(self.p, x) + self.b)

class STP(nn.Module):
    """Unroll a single STPCell over the time axis of a batch.

    Dimension 1 of the input is iterated as the time axis. Each slice
    ``x[:, step, :]`` is transposed before it is handed to the cell (for the
    (batch, seq_len, features) input that train() builds, the cell receives
    (features, batch)). The cell's final hidden state is returned.
    """

    def __init__(self, input_size, hidden_size, complexity, e_h, alpha):
        super().__init__()
        cell = STPCell(input_size, hidden_size, complexity, e_h, alpha)
        self.stpcell = cell.to(device)

    def forward(self, x):
        for step in torch.unbind(x, dim=1):
            self.stpcell(torch.transpose(step, 0, 1))
        return self.stpcell.h_t
            

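### Class Hierarchy: STPCell, STP and RNN

Class STP carries most of the characteristics of the STPCell class (it holds one as `stpcell`). The only difference is its forward() function, which feeds every time step of the input sequence through the cell and returns the final hidden state.
RNN, in contrast, wraps the STP cell in a classical RNN structure. It resets the hidden state and the plasticity variables (X, U) for every batch and maps the final hidden state to class scores with a fully connected layer.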

In [1]:
# Class hierarchy: STPCell (one time step) -> STP (unrolls the cell) -> RNN (classifier)
class STP(nn.Module):
    """Unroll a single STPCell over the time axis of a batch.

    Dimension 1 of the input is iterated as the time axis. Each slice
    ``x[:, step, :]`` is transposed before it is handed to the cell (for the
    (batch, seq_len, features) input that train() builds, the cell receives
    (features, batch)). The cell's final hidden state is returned.
    """

    def __init__(self, input_size, hidden_size, complexity, e_h, alpha):
        super().__init__()
        cell = STPCell(input_size, hidden_size, complexity, e_h, alpha)
        self.stpcell = cell.to(device)

    def forward(self, x):
        for step in torch.unbind(x, dim=1):
            self.stpcell(torch.transpose(step, 0, 1))
        return self.stpcell.h_t
            
class RNN(nn.Module):
    """Classifier that runs an STP recurrent layer over a sequence and maps the
    final hidden state to class scores with a linear layer.

    Args:
        input_size: features per time step.
        hidden_size: hidden units of the STP cell.
        num_layers: leading dimension of the initial hidden state (the STP cell in
            this notebook reads only ``h_t[0]``).
        num_classes: number of output scores.
        complexity, e_h, alpha: forwarded to STP. The defaults reproduce the
            previously hard-coded configuration ("rich", 0.9, 0.1).
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes,
                 complexity="rich", e_h=0.9, alpha=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Kept under the historical attribute name `lstm` so callers and saved state_dicts match.
        self.lstm = STP(input_size, hidden_size, complexity, e_h, alpha).to(device)
        # Maps the final hidden state to one unnormalised score (logit) per class.
        self.fc = nn.Linear(hidden_size, num_classes).to(device)
        self.update_number = 0

    def forward(self, x):
        """Classify a batch of sequences.

        Args:
            x: batch of sequences; dim 0 is the batch (train() passes
               (batch, sequence_length, input_size)).

        Returns:
            Class scores of shape (batch, num_classes).
        """
        cell = self.lstm.stpcell
        batch = x.size(0)
        # Start every batch from a clean state: zero hidden state, fully available
        # resources (X = 1) and utilisation at its current baseline Ucap.
        # Ucap is recomputed from c_U with the cell's own formula (0.9 * sigmoid(c_U)):
        # cell.Ucapclone is only refreshed inside the cell's forward, so it lags one
        # optimizer step behind at the start of each batch.
        if cell.complexity == "rich":
            ucap = (0.9 * torch.sigmoid(cell.c_U)).detach()
            cell.h_t = torch.zeros(self.num_layers, batch, self.hidden_size).to(device)
            cell.X = torch.ones(batch, self.hidden_size, self.hidden_size, dtype=torch.float32).to(device)
            cell.U = ucap.repeat(batch, 1, 1).to(device)
        elif cell.complexity == "poor":
            ucap = (0.9 * torch.sigmoid(cell.c_U)).detach()
            cell.h_t = torch.zeros(self.num_layers, batch, self.hidden_size).to(device)
            cell.X = torch.ones(self.hidden_size, batch, dtype=torch.float32).to(device)
            cell.U = ucap.repeat(1, batch).to(device)
        # The STP layer returns the cell's final hidden state, shape (batch, hidden_size).
        out = self.lstm(x)
        return self.fc(out)

### Baseindexing
This function converts the 2D image to a sequential input.

In [None]:
def baseindexing(time_gap, input_size, stride, a):
    """Turn an image into a sequence of delayed pixel taps.

    The image is flattened to ``n`` pixels and left-padded with zeros. Row ``t`` of
    the result holds, in column ``k``, the pixel ``t - (input_size - 1 - k) * time_gap``
    (zero when that index falls before the first pixel). Column ``input_size - 1``
    is therefore pixel ``t`` itself.

    Args:
        time_gap: spacing, in pixels, between consecutive taps.
        input_size: number of taps per time step (must be >= 1).
        stride: unused; kept for backward compatibility.
        a: image tensor of any shape (it is flattened).

    Returns:
        Float tensor of shape (n, input_size) on the notebook-level ``device``
        (n = 784 for a 28 x 28 MNIST image).
    """
    if input_size < 1:
        raise ValueError("input_size must be at least 1")
    # numpy only works on CPU tensors without autograd history.
    pixels = a.detach().flatten().cpu().numpy()
    n_steps = pixels.size
    offsets = np.arange(0, time_gap * input_size, time_gap)
    # Zero padding so the earliest taps of the first rows read zeros.
    padded = np.pad(pixels, (offsets[-1], 0), 'constant')
    # rows[t, k] = padded[t + offsets[k]]; broadcasting builds every row at once.
    rows = padded[np.arange(n_steps)[:, None] + offsets]
    return torch.tensor(rows, dtype=torch.float).to(device)

### Train function
1. Load batches of images and labels from `loaders['train']` (the training dataset).
2. Reshape each batch to (batch, sequence_length, input_size) to fit the model.
3. Feed the batch into the model and get the outputs.
4. Compare the outputs with the actual labels to compute the loss.
5. Use the optimizer to backpropagate the loss and update the model parameters.


In [None]:
def train(num_epochs, model, loaders):
    """Train ``model`` on ``loaders['train']`` for ``num_epochs`` epochs.

    Relies on the notebook-level ``loss_func``, ``optimizer``, ``sequence_length``,
    ``input_size`` and ``device``. Each image batch is reshaped to
    (batch, sequence_length, input_size) before the forward pass, and the loss
    is printed every 100 steps.
    """
    # evaluate() switches the model to eval mode; switch back before training.
    model.train()
    total_step = len(loaders['train'])
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(loaders['train']):
            images = images.reshape(-1, sequence_length, input_size).to(device)
            labels = labels.to(device)
            # Forward pass
            outputs = model(images)
            loss = loss_func(outputs, labels)
            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

### Test the model
Evaluate the model by comparing the predicted labels with the true labels, much as the loss function does during training, but...
1. Using *test data* instead
2. Translated into percentage

In [None]:
### Test the model
def evaluate(mymodel, loader=None):
    """Return the accuracy (in percent) of ``mymodel`` on a data loader.

    Args:
        mymodel: model mapping (batch, sequence_length, input_size) inputs to class scores.
        loader: iterable of (images, labels) batches; defaults to the notebook's
            ``loaders['test']``.

    Returns:
        ``100 * correct / total`` as a float. The model is left in eval mode.

    Raises:
        ValueError: if the loader yields no samples.
    """
    if loader is None:
        loader = loaders['test']
    mymodel.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.reshape(-1, sequence_length, input_size).to(device)
            labels = labels.to(device)
            predicted = mymodel(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluate() received a loader with no samples")
    accuracy = 100 * correct / total
    print('Test Accuracy of the model on the {} test images: {} %'.format(total, accuracy))
    return accuracy

### Spiralisers and spiralised train and test function
From Claude 2:

The spiraliser function takes a 2D tensor 'a' and returns a new tensor containing the values of 'a' in spiral order.

Here is what it is doing step-by-step:

Convert 'a' to cpu and numpy array for easier manipulation.
Initialize loop variables k, l for starting row/col indices.
Initialize empty list 'spiral' to store values.
Loop through the tensor in spiral order:
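Append the first remaining row (left to right)
Append the last remaining column (top to bottom)
Append the last remaining row (right to left)
Append the first remaining column (bottom to top)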
This is done by looping through the appropriate indices and appending to spiral list.

Convert spiral list back to tensor with shape (28,28).
So in summary, it takes a 28x28 2D tensor, iterates through it in spiral order appending values to a list, then converts back to 28x28 tensor containing values in spiral order.

This is used to transform the MNIST images to introduce a temporal pattern that the RNN can learn. The idea is that scanning the image in spiral order creates a sequence rather than static image.



In [None]:
def spiraliser(m, n, a):
    """Read the top-left ``m`` x ``n`` block of ``a`` in clockwise spiral order.

    Starting at the top-left corner, the walk goes right along the first row, down
    the last column, left along the last row and up the first column, then repeats
    on the inner block. The visited values are written row by row into a float
    tensor of shape (m, n) on the notebook-level ``device``.

    Args:
        m: number of rows to traverse.
        n: number of columns to traverse.
        a: 2-D tensor indexed as ``a[row][col]``.
    """
    grid = a.detach().cpu().numpy()
    out_rows, out_cols = m, n
    top, left = 0, 0          # first unvisited row / column
    bottom, right = m, n      # one past the last unvisited row / column
    spiral = []
    while top < bottom and left < right:
        # first remaining row, left to right
        spiral.extend(grid[top][c] for c in range(left, right))
        top += 1
        # last remaining column, top to bottom
        spiral.extend(grid[r][right - 1] for r in range(top, bottom))
        right -= 1
        if top < bottom:
            # last remaining row, right to left
            spiral.extend(grid[bottom - 1][c] for c in range(right - 1, left - 1, -1))
            bottom -= 1
        if left < right:
            # first remaining column, bottom to top
            spiral.extend(grid[r][left] for r in range(bottom - 1, top - 1, -1))
            left += 1

    # Reshape to the traversed size (previously hard-coded to 28 x 28).
    spiraltensor = torch.tensor(np.asarray(spiral), dtype=torch.float)
    return spiraltensor.reshape(out_rows, out_cols).to(device)


In [None]:
def trainspiral(num_epochs, model, loaders):
    """Train ``model`` like train(), but read every image in spiral order first.

    Each (batch, 1, H, W) image batch has channel 0 of every image reordered by
    ``spiraliser(H, W, image)``. The batch returned by the loader is not modified.
    The result is then reshaped to (batch, sequence_length, input_size).
    Relies on the notebook-level ``spiraliser``, ``loss_func``, ``optimizer``,
    ``sequence_length``, ``input_size`` and ``device``; prints the loss every 100 steps.
    """
    # evaluate() switches the model to eval mode; switch back before training.
    model.train()
    total_step = len(loaders['train'])
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(loaders['train']):
            height, width = images.shape[-2], images.shape[-1]
            spiral_images = images.clone()
            for idx in range(images.size(0)):
                spiral_images[idx, 0] = spiraliser(height, width, images[idx, 0])
            images = spiral_images.reshape(-1, sequence_length, input_size).to(device)
            labels = labels.to(device)
            # Forward pass
            outputs = model(images)
            loss = loss_func(outputs, labels)
            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))