# Stateful Neuron Exploration

This notebook will test the performance of the stateful neural network PyTorch port, to verify that the method is 
implemented correctly and behaves as expected.

## Imports and Constants

In [None]:
# General imports
import os
import numpy as np
from tqdm.notebook import tqdm
import copy
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from transformers import GPT2Tokenizer

# Torch imports
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader

# Imports for the tokenizer, the dataset, and the model
from transformers import GPT2Tokenizer
from utils.datasets import TextDataLoader
from models.transformer_model import TransformerModel

# Seed PyTorch's global random number generator so runs are repeatable
torch.manual_seed(0)

# Hyperparameters for the MNIST experiments (Test 1).
# The layer sizes start at the flattened image size and end at the 10 output classes.
MNIST_INPUT_SIZE = 784
MNIST_LAYER_SIZES = [MNIST_INPUT_SIZE, 128, 64, 10]
MNIST_BATCH_SIZE = 1000

# Hyperparameters for the sunspot time-series experiments (Test 2).
# The layer sizes start at the input window length and end at a single regression output.
SUNSPOT_INPUT_SIZE = 12
SUNSPOT_LAYER_SIZES = [SUNSPOT_INPUT_SIZE, 128, 64, 1]
SUNSPOT_BATCH_SIZE = 100

# Number of epochs used by every training run in this notebook
NUM_EPOCHS = 5

# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
# The conditional expression short-circuits, so MPS is only probed when CUDA is absent.
DEVICE = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(f"Using device: {DEVICE}")

# Enable IPython's autoreload extension. Mode 2 reloads imported modules before each
# cell executes, so edits to the project's own .py files (e.g. `utils`, `models`) are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Test 1: Stateful Neuron vs. FCN

## MNIST Dataset Ingestion

In [None]:
# Preprocessing shared by both splits: convert each image to a tensor, then normalize
# it with mean 0.5 and standard deviation 0.5
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Create training set and loader (reshuffled on every pass over the data)
trainset = datasets.MNIST(root='data/test/', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=MNIST_BATCH_SIZE, shuffle=True)

# Create test set and loader. Evaluation does not benefit from shuffling, so the test
# batches are served in a fixed order and consume no random numbers.
testset = datasets.MNIST(root='data/test/', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=MNIST_BATCH_SIZE, shuffle=False)

## Helper Function: Train Network on MNIST

In [None]:
def train(model, device, train_data, test_data, criterion, opt, epochs):
    """Train a classifier on flattened inputs and evaluate it after every epoch.

    Each batch is moved to ``device`` and flattened to ``(batch, -1)`` before it is
    passed to ``model``. After every epoch one summary line with the mean training
    loss, mean test loss and test accuracy is printed.

    Args:
        model: Module called as ``model(inputs)``; its output is passed to
            ``criterion`` and its arg-max along dim 1 is used as the prediction.
        device: Device the batches are moved to.
        train_data: Iterable of ``(inputs, labels)`` batches used for optimization.
        test_data: Iterable of ``(inputs, labels)`` batches used for evaluation.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        opt: Optimizer stepping the parameters of ``model``.
        epochs: Number of passes over ``train_data``.

    Returns:
        None.
    """
    for epoch in range(epochs):

        # Training phase
        model.train()  # Set the model to training mode
        train_loss = 0.0
        train_batches = 0
        for inputs, labels in tqdm(train_data, desc=f'Training: Epoch {epoch+1}/{epochs}', unit='batch'):
            # Move the data to the device
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Flatten every sample to a 1-D feature vector
            inputs = inputs.view(inputs.shape[0], -1)

            # Zero the gradients
            opt.zero_grad()

            # Forward pass and loss calculation
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and weight update
            loss.backward()
            opt.step()

            # Accumulate the loss and count the batch
            train_loss += loss.item()
            train_batches += 1

        # Testing phase
        model.eval()  # Set the model to evaluation mode
        test_loss = 0.0
        test_batches = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_data:
                # Move the data to the device
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Flatten every sample to a 1-D feature vector
                inputs = inputs.view(inputs.shape[0], -1)

                # Forward pass and loss calculation
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                # Accumulate the loss and the number of correct predictions
                test_loss += loss.item()
                test_batches += 1
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Average over the batches that were actually iterated, i.e. over the loaders
        # passed in as arguments rather than any notebook-level globals. Empty
        # loaders report 0 instead of raising ZeroDivisionError.
        train_loss /= max(train_batches, 1)
        test_loss /= max(test_batches, 1)
        test_accuracy = 100 * correct / total if total else 0.0
        print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

## `StatefulNeuronNetwork`

Define the model:

In [None]:
class Neurons(nn.Module):
    """A layer of independent stateful neurons, each with its own 3x3 parameter matrix.

    On every call each neuron multiplies its matrix with the vector
    ``[input, hidden, 1]`` and applies ``tanh``. The first component of the result
    is the neuron's output; the last component becomes its next hidden state.
    """

    def __init__(self, n_neurons):
        """Create ``n_neurons`` neurons with parameters drawn uniformly from [-1, 1)."""
        super(Neurons, self).__init__()

        # Initialize matrix neuron parameters and number of neurons to create
        self.n_neurons = n_neurons
        self.params = nn.Parameter(torch.rand(n_neurons, 3, 3) * 2 - 1)

        # The hidden state is carried between calls but never trained, so it is kept
        # as a buffer. It starts with a batch dimension of 1 and is broadcast.
        self.register_buffer('hidden', torch.zeros(1, n_neurons, 1))

    def neuron_fn(self, inputs):
        """Advance every neuron by one step.

        Args:
            inputs: Tensor whose first dimension is the batch and which holds
                ``n_neurons`` values per sample.

        Returns:
            Tuple ``(output, dot)``: ``output`` has shape ``(batch, n_neurons)`` and is
            the first component of each neuron's result; ``dot`` has shape
            ``(batch, n_neurons, 3)`` and holds all three components.
        """
        batch_size = inputs.shape[0]

        # The stored state keeps the batch dimension of the previous call. If the
        # batch size changed it cannot be matched to the new samples, so start over
        # from a zero state instead of failing in expand().
        if self.hidden.shape[0] not in (1, batch_size):
            self.hidden = torch.zeros(
                1, self.n_neurons, 1, dtype=self.hidden.dtype, device=self.hidden.device
            )
        hidden_batch = self.hidden.expand(batch_size, -1, -1)

        # One scalar input per neuron: (batch_size, n_neurons, 1)
        inputs = inputs.reshape(batch_size, self.n_neurons, 1)
        ones = torch.ones_like(inputs)

        # Build [input, hidden, 1] per neuron by concatenating along the LAST
        # dimension: (batch_size, n_neurons, 3). Concatenating along dim 1 and then
        # reshaping would group unrelated values into each triple.
        stacked = torch.cat((inputs, hidden_batch, ones), dim=2)

        # Per-neuron matrix-vector product followed by tanh: (batch_size, n_neurons, 3)
        dot = torch.tanh(torch.matmul(self.params, stacked.unsqueeze(3)).squeeze(3))

        # Update hidden state with the last component, detached from the graph
        self.hidden = dot[:, :, -1:].detach()

        return dot[:, :, 0], dot

class NeuralDiverseNet(nn.Module):
    """Network that alternates ``Neurons`` layers with linear projections.

    For ``sizes = [s0, s1, ..., sk]`` the input passes through a ``Neurons(s0)`` layer,
    a ``Linear(s0, s1)`` projection, a ``Neurons(s1)`` layer, and so on, ending with
    the ``Neurons(sk)`` layer.
    """

    def __init__(self, sizes):
        """Build one ``Neurons`` layer per entry of ``sizes`` and the projections between them."""
        super(NeuralDiverseNet, self).__init__()
        self.neurons = nn.ModuleList([Neurons(size) for size in sizes])
        self.weights = nn.ModuleList([nn.Linear(sizes[i], sizes[i + 1]) for i in range(len(sizes) - 1)])

    def forward(self, x):
        """Run ``x`` through all layers and return a ``(batch_size, sizes[-1])`` tensor."""
        batch_size = x.shape[0]

        # `signal` starts as the raw input, so a network with a single layer (and
        # therefore no projections) still reaches the last layer with a defined value
        signal = x
        for neuron_layer, projection in zip(self.neurons[:-1], self.weights):
            activations, _ = neuron_layer.neuron_fn(signal)
            signal = projection(activations)

        # Process the last layer
        final_output, _ = self.neurons[-1].neuron_fn(signal)

        # Reshape the output to ensure it has the shape [batch_size, n_classes]
        return final_output.reshape(batch_size, -1)

Train the model:

In [None]:
# Build the stateful-neuron network and place it on the selected device
stateful_neuron_model = NeuralDiverseNet(MNIST_LAYER_SIZES)
stateful_neuron_model = stateful_neuron_model.to(DEVICE)

# Adam optimizer (learning rate 1e-3) and cross-entropy classification loss
optimizer = torch.optim.Adam(stateful_neuron_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Run the shared MNIST train/evaluate loop
train(
    model=stateful_neuron_model,
    device=DEVICE,
    train_data=trainloader,
    test_data=testloader,
    criterion=criterion,
    opt=optimizer,
    epochs=NUM_EPOCHS,
)

## `FeedForwardNetwork`

Define the model:

In [None]:
class FeedForwardNetwork(nn.Module):
    """Fully connected network: ReLU after every linear layer except the last."""

    def __init__(self, MNIST_LAYER_SIZES):
        """Create one linear layer for every pair of consecutive entries in the size list."""
        super().__init__()
        fan_in_out = zip(MNIST_LAYER_SIZES[:-1], MNIST_LAYER_SIZES[1:])
        self.layers = nn.ModuleList([nn.Linear(n_in, n_out) for n_in, n_out in fan_in_out])

    def forward(self, x):
        """Return the output of the final linear layer, with no activation applied to it."""
        output_layer = self.layers[-1]
        activations = x
        for hidden_layer in self.layers[:-1]:
            activations = torch.relu(hidden_layer(activations))
        return output_layer(activations)

Train the model:

In [None]:
# Build the baseline fully connected network on the selected device
ff_model = FeedForwardNetwork(MNIST_LAYER_SIZES)
ff_model = ff_model.to(DEVICE)

# Cross-entropy classification loss and Adam optimizer (learning rate 1e-3)
ff_criterion = nn.CrossEntropyLoss()
ff_optimizer = torch.optim.Adam(ff_model.parameters(), lr=0.001)

# Run the shared MNIST train/evaluate loop
train(
    model=ff_model,
    device=DEVICE,
    train_data=trainloader,
    test_data=testloader,
    criterion=ff_criterion,
    opt=ff_optimizer,
    epochs=NUM_EPOCHS,
)

## `SpikingNeuralNetwork`

Define the model:

In [None]:
def surrogate_gradient(x, alpha=10):
    """Return ``sigmoid(alpha * x)``, a smooth stand-in for a hard spike threshold.

    Args:
        x: Input tensor.
        alpha: Steepness of the sigmoid; larger values approach a step function.
            Defaults to 10.

    Returns:
        Tensor of the same shape as ``x`` with values in (0, 1).
    """
    return torch.sigmoid(alpha * x)

class SpikingNeuronLayer(nn.Module):
    """Layer that applies a weight matrix followed by a soft threshold at 1."""

    def __init__(self, size_in, size_out, device):
        """Create a ``(size_in, size_out)`` weight matrix on ``device``, scaled by 0.01."""
        super(SpikingNeuronLayer, self).__init__()
        # Kept for callers that read the attribute; forward() does not rely on it
        self.device = device
        self.synaptic_weights = nn.Parameter(torch.randn(size_in, size_out, device=device) * 0.01)

    def forward(self, x):
        """Return ``surrogate_gradient(x @ synaptic_weights - 1)``."""
        # Follow the weights' current device rather than the one cached at
        # construction, which goes stale once the module is moved with .to()
        x = x.to(self.synaptic_weights.device)
        pre_synaptic = torch.matmul(x, self.synaptic_weights)
        post_synaptic = surrogate_gradient(pre_synaptic - 1)
        return post_synaptic

class SpikingNeuralNetwork(nn.Module):
    """Stack of ``SpikingNeuronLayer`` modules applied one after another."""

    def __init__(self, MNIST_LAYER_SIZES, device):
        """Create one spiking layer for every pair of consecutive entries in the size list."""
        super().__init__()
        self.device = device
        size_pairs = zip(MNIST_LAYER_SIZES[:-1], MNIST_LAYER_SIZES[1:])
        self.layers = nn.ModuleList(
            [SpikingNeuronLayer(n_in, n_out, device) for n_in, n_out in size_pairs]
        )

    def forward(self, x):
        """Move ``x`` to the stored device and feed it through every layer in order."""
        activations = x.to(self.device)
        for spiking_layer in self.layers:
            activations = spiking_layer(activations)
        return activations

Train the model:

In [None]:
# Create the network with the same layer sizes as the other MNIST models, taken
# from the shared constant instead of a duplicated literal
snn_model = SpikingNeuralNetwork(MNIST_LAYER_SIZES, device=DEVICE).to(DEVICE)
snn_optimizer = torch.optim.Adam(snn_model.parameters(), lr=0.001)
snn_criterion = nn.CrossEntropyLoss()

# Train the model
train(model=snn_model, device=DEVICE, train_data=trainloader, test_data=testloader, 
      criterion=snn_criterion, opt=snn_optimizer, epochs=NUM_EPOCHS)

# Test 2: Stateful Neurons vs. RNN
Do stateful neurons perform similarly to recurrent neural networks on a simple time series task?

## Dataset Ingestion

In [None]:
# Load the dataset: a single column of monthly mean sunspot numbers
df = pd.read_csv('data/test/Sunspots.csv', usecols=['Monthly Mean Total Sunspot Number'])
data = df.values.astype(float)

# Normalize the data into the range [0, 1]
scaler = MinMaxScaler(feature_range=(0, 1))
data_normalized = scaler.fit_transform(data)

# Convert data to a flat 1-D PyTorch tensor
data_normalized = torch.FloatTensor(data_normalized).view(-1)

# Create sequences and corresponding labels: every window of `sequence_length`
# consecutive months is paired with the month that immediately follows it
sequence_length = 12  # For example, use 12 months to predict the next month

# unfold() yields every run of (sequence_length + 1) consecutive values. The first
# `sequence_length` entries of each run are the inputs, the final entry is the target.
windows = data_normalized.unfold(0, sequence_length + 1, 1)
sequences = windows[:, :sequence_length].contiguous()
labels = windows[:, sequence_length].contiguous()

# Split the data into training and testing sets
train_sequences, test_sequences, train_labels, test_labels = train_test_split(
    sequences, labels, test_size=0.25, random_state=42
)

# Trim the dataset to an easily divisible length, so every batch has exactly
# SUNSPOT_BATCH_SIZE samples
train_sequences = train_sequences[:2400]
train_labels = train_labels[:2400]
test_sequences = test_sequences[:800]
test_labels = test_labels[:800]

# Create DataLoaders for each set
sunspot_train_loader = DataLoader(TensorDataset(train_sequences, train_labels), shuffle=True, 
                                  batch_size=SUNSPOT_BATCH_SIZE)
sunspot_test_loader = DataLoader(TensorDataset(test_sequences, test_labels), shuffle=False, 
                                 batch_size=SUNSPOT_BATCH_SIZE)

# Print the length of each set
print(f'Training set length: {len(train_sequences)}')
print(f'Testing set length: {len(test_sequences)}')

## Helper Function: Train Network on Sunspot Dataset

In [None]:
def train_time_series(model, train_loader, val_loader, criterion, opt, epochs, device=DEVICE, model_type=None):
    """Train a sequence regressor and report the validation loss after every epoch.

    Every batch of sequences is moved to ``device`` and reshaped to
    ``(batch, sequence_length, 1)``; the labels are reshaped to ``(batch, 1)``.
    After each epoch the mean training and validation losses are printed.

    Args:
        model: Module called as ``model(sequences)``.
        train_loader: Sized iterable of ``(sequences, labels)`` batches used for
            optimization.
        val_loader: Sized iterable of ``(sequences, labels)`` batches used for
            evaluation.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        opt: Optimizer stepping the parameters of ``model``.
        epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to.
        model_type: Accepted for call-site compatibility; not used by this function.

    Returns:
        None.
    """
    # Perform training
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for sequences, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}', unit='batch'):

            # Move the data to the device. The sequence length is inferred from the
            # batch itself, so the function is not tied to one window size.
            sequences = sequences.to(device)
            sequences = sequences.view(sequences.shape[0], -1, 1)
            labels = labels.to(device).unsqueeze(1)  # Reshape targets to (batch, 1)

            # Zero the gradients
            opt.zero_grad()

            # Forward pass and loss calculation
            outputs = model(sequences)
            loss = criterion(outputs, labels)

            # Backward pass and weight update
            loss.backward()
            opt.step()

            # Logging the loss
            train_loss += loss.item()

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for sequences, labels in val_loader:
                # Same device move and reshaping as in the training loop
                sequences = sequences.to(device)
                sequences = sequences.view(sequences.shape[0], -1, 1)
                labels = labels.to(device).unsqueeze(1)

                # Forward pass and loss calculation
                outputs = model(sequences)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

        # Mean loss per batch; an empty loader reports 0 instead of dividing by zero
        train_loss /= max(len(train_loader), 1)
        val_loss /= max(len(val_loader), 1)
        print(f"Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

## `RecurrentNeuralNetwork`

Define the model:

In [None]:
class VanillaRNN(nn.Module):
    """An ``nn.RNN`` (batch-first) followed by a linear read-out of the last time step."""

    def __init__(self, input_size=1, hidden_size=12, output_size=1):
        """Create the recurrent layer and the linear head that maps its state to the output."""
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the linear head applied to the RNN output at the final time step."""
        hidden_states, _final_state = self.rnn(x)
        last_step = hidden_states[:, -1, :]
        return self.fc(last_step)

Train the model:

In [None]:
# Build the baseline recurrent model on the selected device
rnn_model = VanillaRNN(input_size=1, hidden_size=1, output_size=1)
rnn_model = rnn_model.to(DEVICE)

# Adam optimizer (learning rate 1e-3) and mean-squared-error regression loss
rnn_optimizer = torch.optim.Adam(rnn_model.parameters(), lr=0.001)
rnn_criterion = nn.MSELoss()

# Run the shared time-series train/validate loop
train_time_series(
    model=rnn_model,
    model_type='RNN',
    device=DEVICE,
    train_loader=sunspot_train_loader,
    val_loader=sunspot_test_loader,
    criterion=rnn_criterion,
    opt=rnn_optimizer,
    epochs=NUM_EPOCHS,
)

## `RecurrentStatefulNeuron`

Define the model:

In [None]:
# NOTE: a class named `Neurons` is also defined earlier in this notebook; this
# definition replaces it from this cell onwards.
class Neurons(nn.Module):
    """A layer of independent stateful neurons, each with its own 3x3 parameter matrix.

    On every call each neuron multiplies its matrix with the vector
    ``[input, hidden, 1]`` and applies ``tanh``. The first component of the result
    is the neuron's output; the last component becomes its next hidden state.
    """

    def __init__(self, n_neurons):
        """Create ``n_neurons`` neurons with parameters drawn uniformly from [-1, 1)."""
        super(Neurons, self).__init__()

        # Initialize matrix neuron parameters and number of neurons to create
        self.n_neurons = n_neurons
        self.params = nn.Parameter(torch.rand(n_neurons, 3, 3) * 2 - 1)

        # The hidden state is carried between calls but never trained, so it is kept
        # as a buffer. It starts with a batch dimension of 1 and is broadcast.
        self.register_buffer('hidden', torch.zeros(1, n_neurons, 1))

    def neuron_fn(self, inputs):
        """Advance every neuron by one step.

        Args:
            inputs: Tensor whose first dimension is the batch and which holds
                ``n_neurons`` values per sample.

        Returns:
            Tuple ``(output, dot)``: ``output`` has shape ``(batch, n_neurons)`` and is
            the first component of each neuron's result; ``dot`` has shape
            ``(batch, n_neurons, 3)`` and holds all three components.
        """
        batch_size = inputs.shape[0]

        # The stored state keeps the batch dimension of the previous call. If the
        # batch size changed it cannot be matched to the new samples, so start over
        # from a zero state instead of failing in expand().
        if self.hidden.shape[0] not in (1, batch_size):
            self.hidden = torch.zeros(
                1, self.n_neurons, 1, dtype=self.hidden.dtype, device=self.hidden.device
            )
        hidden_batch = self.hidden.expand(batch_size, -1, -1)

        # One scalar input per neuron: (batch_size, n_neurons, 1)
        inputs = inputs.reshape(batch_size, self.n_neurons, 1)
        ones = torch.ones_like(inputs)

        # Build [input, hidden, 1] per neuron by concatenating along the LAST
        # dimension: (batch_size, n_neurons, 3). Concatenating along dim 1 and then
        # reshaping would group unrelated values into each triple.
        stacked = torch.cat((inputs, hidden_batch, ones), dim=2)

        # Per-neuron matrix-vector product followed by tanh: (batch_size, n_neurons, 3)
        dot = torch.tanh(torch.matmul(self.params, stacked.unsqueeze(3)).squeeze(3))

        # Update hidden state with the last component, detached from the graph
        self.hidden = dot[:, :, -1:].detach()

        return dot[:, :, 0], dot

class SequentialNeuralDiverseNet(nn.Module):
    """Network that alternates ``Neurons`` layers with linear projections.

    For ``sizes = [s0, s1, ..., sk]`` the input passes through a ``Neurons(s0)`` layer,
    a ``Linear(s0, s1)`` projection, a ``Neurons(s1)`` layer, and so on, ending with
    the ``Neurons(sk)`` layer.
    """

    def __init__(self, sizes):
        """Build one ``Neurons`` layer per entry of ``sizes`` and the projections between them."""
        super(SequentialNeuralDiverseNet, self).__init__()
        self.neurons = nn.ModuleList([Neurons(size) for size in sizes])
        self.weights = nn.ModuleList([nn.Linear(sizes[i], sizes[i + 1]) for i in range(len(sizes) - 1)])

    def forward(self, x):
        """Run ``x`` through all layers and return a ``(batch_size, sizes[-1])`` tensor."""
        batch_size = x.shape[0]

        # `signal` starts as the raw input, so a network with a single layer (and
        # therefore no projections) still reaches the last layer with a defined value
        signal = x
        for neuron_layer, projection in zip(self.neurons[:-1], self.weights):
            activations, _ = neuron_layer.neuron_fn(signal)
            signal = projection(activations)

        final_output, _ = self.neurons[-1].neuron_fn(signal)

        # Reshape the output to [batch_size, sizes[-1]]
        return final_output.reshape(batch_size, -1)

Train the model:

In [None]:
# Build the stateful-neuron sequence model on the selected device
sequential_stateful_neuron_model = SequentialNeuralDiverseNet(SUNSPOT_LAYER_SIZES)
sequential_stateful_neuron_model = sequential_stateful_neuron_model.to(DEVICE)

# Adam optimizer (learning rate 1e-3) and mean-squared-error regression loss
seq_optimizer = torch.optim.Adam(sequential_stateful_neuron_model.parameters(), lr=0.001)
seq_criterion = nn.MSELoss()

# Run the shared time-series train/validate loop
train_time_series(
    model=sequential_stateful_neuron_model,
    device=DEVICE,
    train_loader=sunspot_train_loader,
    val_loader=sunspot_test_loader,
    criterion=seq_criterion,
    opt=seq_optimizer,
    epochs=NUM_EPOCHS,
)

# Test 3: Train Transformer On Shakespeare

Define a helper function that trains the transformer to reconstruct the Shakespeare dataset:

In [None]:
# Define function to mask the target tokens
# TODO: Using a mask breaks the training process due to a shape error. Needs to be fixed
def create_look_ahead_mask(size):
    """Build a ``(size, size)`` additive mask: 0.0 on and below the diagonal, -inf above it."""
    # Start fully blocked, then let triu() zero out the diagonal and everything below
    fully_blocked = torch.full((size, size), float('-inf'))
    return torch.triu(fully_blocked, diagonal=1)

def train_shakespeare_transformer(model, data_loader, optimizer, num_epochs, device=DEVICE, mask=False):
    """Train a sequence-to-sequence model with cross-entropy over the target tokens.

    Args:
        model: Module called as ``model(input_seq, target_seq, tgt_mask=...)``; its
            output is flattened to ``(-1, outputs.size(-1))`` before the loss.
        data_loader: Sized iterable of ``(inputs, labels)`` batches.
        optimizer: Optimizer stepping the parameters of ``model``.
        num_epochs: Number of passes over ``data_loader``.
        device: Device the batches (and the mask, if any) are moved to.
        mask: When truthy, a look-ahead mask sized by the target's second dimension
            is passed as ``tgt_mask``; otherwise ``tgt_mask`` is ``None``.

    Returns:
        None. The mean loss is printed after every epoch.
    """
    # The loss module holds no per-batch state, so it is built once up front
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        for inputs, labels in tqdm(data_loader, desc=f'Training: Epoch {epoch+1}/{num_epochs}', unit='batch'):
            # Move the data to the device
            input_seq = inputs.to(device)
            target_seq = labels.to(device)

            # Optionally create a mask for the target sequence
            if mask:
                target_seq_mask = create_look_ahead_mask(target_seq.size(1)).to(device)
                # NOTE(review): this adds a leading dimension, giving (1, T, T). The
                # shape error mentioned in the TODO above may come from here if the
                # model expects a 2-D (T, T) mask — confirm against TransformerModel.
                target_seq_mask = target_seq_mask.unsqueeze(0)
            else:
                target_seq_mask = None

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass; flatten predictions and targets to one row per token
            outputs = model(input_seq, target_seq, tgt_mask=target_seq_mask)
            outputs = outputs.view(-1, outputs.size(-1))
            target_seq = target_seq.view(-1)

            # Calculate loss and backpropagate
            loss = criterion(outputs, target_seq)
            loss.backward()
            optimizer.step()

            # Accumulate the loss for the epoch summary
            epoch_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs} completed. Loss: {epoch_loss/len(data_loader)}")

Create the Shakespeare dataset:

In [None]:
# Define dataset parameters
seq_length = 512
batch_size = 5
# The path is resolved against the current working directory, so the notebook must be
# run from the directory that contains `data/`
file_path = os.path.join(os.getcwd(), 'data/shakespeare', 'tinyshakespeare_100_lines.txt')
bpe_tokenizer = 'gpt2'
# NOTE(review): presumably the vocabulary size of the 'gpt2' tokenizer — confirm
vocab_size = 50257

# Create the data loader (project helper from utils.datasets) and split it into
# training and test loaders
data_loader = TextDataLoader(file_path, seq_length, bpe_tokenizer, batch_size, vocab_size)
train_loader, test_loader = data_loader.create_loaders()


Create the transformer model:

In [None]:
# Define the context window size k (defaulting to chunk_length / 2)
# NOTE(review): 256 is hard-coded. It equals seq_length / 2 only while seq_length
# stays at 512 — presumably "chunk_length" refers to seq_length; confirm and keep in sync.
context_window = 256

# Define the model (project class from models.transformer_model) and move it to the
# selected device
transformer_model = TransformerModel(vocab_size=vocab_size, max_seq_length=context_window).to(DEVICE)

# Define optimizer: Adam with learning rate 1e-3 over all model parameters
transformer_optimizer = torch.optim.Adam(transformer_model.parameters(), lr=0.001)

Train the model:

In [None]:
# Define the step size to use for the sliding window
# NOTE(review): step_size is not passed to the training call below, so it has no
# effect in this cell — confirm whether it is still needed.
step_size = 16

# Train the model. `mask` is left at its default (False), so no look-ahead mask is
# applied to the target sequence.
train_shakespeare_transformer(transformer_model, train_loader, optimizer=transformer_optimizer, num_epochs=NUM_EPOCHS)