S4 implementation using PyTorch to train on the MNIST dataset.

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [23]:
class S4Layer(nn.Module):
    """Recurrent state-space-style layer evaluated one time step at a time.

    Starting from a zero hidden state, each step ``t`` computes::

        h_t = tanh(h_{t-1} @ A.T + u_t @ B.T)
        y_t = h_t @ C.T + u_t @ D

    NOTE(review): the ``tanh`` makes the recurrence non-linear, so this is
    closer to an Elman RNN than to the linear state-space model of the S4
    paper. The class name is kept so existing callers keep working.

    Args:
        input_size: Number of features in each time step of the input.
        hidden_size: Size of the hidden state. The per-step output has the
            same size because ``C`` is ``(hidden_size, hidden_size)`` and
            ``D`` is ``(input_size, hidden_size)``.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Parameters are created in the order A, B, C, D so that seeded
        # initialisation stays identical to earlier versions of this class.
        self.A = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.1)  # Transition matrix
        self.B = nn.Parameter(torch.randn(hidden_size, input_size) * 0.1)   # Input matrix
        self.C = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.1)  # Output matrix
        self.D = nn.Parameter(torch.randn(input_size, hidden_size) * 0.1)   # Direct (skip) term

    def forward(self, x):
        """Run the recurrence over a batch of sequences.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, hidden_size)`` holding the
            output ``y_t`` for every time step. A zero-length sequence yields
            an empty ``(batch_size, 0, hidden_size)`` tensor.
        """
        batch_size, seq_len, _ = x.size()

        # Nothing to iterate over: stacking an empty list would raise, so
        # return an explicitly empty result with the right trailing shape.
        if seq_len == 0:
            return x.new_zeros(batch_size, 0, self.hidden_size)

        # new_zeros inherits both device and dtype from x, so the layer also
        # works after .double() / .half() instead of being pinned to float32.
        h = x.new_zeros(batch_size, self.hidden_size)

        # The transposed views do not change inside the loop; take them once.
        A_t, B_t, C_t = self.A.T, self.B.T, self.C.T

        outputs = []
        for t in range(seq_len):
            u_t = x[:, t, :]  # Input at time t: (batch_size, input_size)
            h = torch.tanh(h @ A_t + u_t @ B_t)
            outputs.append(h @ C_t + u_t @ self.D)

        # Stack along a new time axis -> (batch_size, seq_len, hidden_size).
        return torch.stack(outputs, dim=1)


In [26]:
class S4Classifier(nn.Module):
    """Sequence classifier: an ``S4Layer`` followed by a linear read-out.

    Args:
        input_size: Number of features per time step fed to the S4 layer.
        hidden_size: Hidden/output width of the S4 layer and input width of
            the final linear layer.
        num_classes: Number of logits produced per sequence.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.s4_layer = S4Layer(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits computed from the last time step's features."""
        sequence_features = self.s4_layer(x)
        # Only the final time step is passed to the classifier head.
        final_step = sequence_features[:, -1, :]
        return self.fc(final_step)



In [34]:

# Convert images to tensors, then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Training split first, then test split; both are downloaded to ./data if missing.
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=256, shuffle=should_shuffle)
    for dataset, should_shuffle in ((train_dataset, True), (test_dataset, False))
)


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = S4Classifier(input_size=28, hidden_size=128, num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(model.parameters(), lr=0.0005)

# Single source of truth for the epoch count: the loop bound and the progress
# message previously disagreed (10 iterations reported as "/5").
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    for images, labels in train_loader:
        # Drop dim 1 (only removed if it has size 1) and swap the last two
        # dims to get (batch_size, seq_len, input_size).
        # Assumes batches arrive as (batch, 1, 28, 28) — TODO confirm.
        images = images.squeeze(1).transpose(1, 2).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

    # `loss` is the loss of the last mini-batch of the epoch, not an epoch average.
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')


Epoch [1/5], Loss: 0.3586
Epoch [2/5], Loss: 0.3275
Epoch [3/5], Loss: 0.1513
Epoch [4/5], Loss: 0.1352
Epoch [5/5], Loss: 0.1489
Epoch [6/5], Loss: 0.1166
Epoch [7/5], Loss: 0.2430
Epoch [8/5], Loss: 0.1188
Epoch [9/5], Loss: 0.0581
Epoch [10/5], Loss: 0.0534


In [36]:
# Measure top-1 accuracy on the test loader with gradients disabled.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        # Same preprocessing as in training, then move both tensors to the device.
        images, labels = images.squeeze(1).transpose(1, 2).to(device), labels.to(device)

        outputs = model(images)
        # Predicted class = index of the largest logit in each row.
        predicted = outputs.argmax(dim=1)
        correct += int((predicted == labels).sum())
        total += len(labels)

print(f'Test Accuracy: {100 * correct / total:.2f}%')


Test Accuracy: 96.68%


In [21]:
# NOTE(review): as far as is visible here, this loop only rebinds `images` and
# `labels` for every training batch and never uses them; once it finishes they
# hold the last batch yielded by `train_loader`. It looks like a leftover
# scratch cell (or the body continues beyond this view) — confirm and remove
# it if nothing downstream relies on it.
for images, labels in train_loader:
    # Same preprocessing as the training loop: squeeze dim 1 (a no-op unless
    # it has size 1), swap dims 1 and 2, then move the batch to `device`.
    images = images.squeeze(1).transpose(1, 2).to(device)
    labels = labels.to(device)