In [None]:
# Pin norse to the version this notebook was executed with (see install log) so the LConv2d / LILinearCell API stays stable.
%pip install torch torchvision norse==1.1.0

Collecting norse
  Downloading norse-1.1.0.tar.gz (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting nir (from norse)
  Downloading nir-1.0.4-py3-none-any.whl.metadata (5.8 kB)
Collecting nirtorch (from norse)
  Downloading nirtorch-1.0-py3-none-any.whl.metadata (3.6 kB)
Downloading nir-1.0.4-py3-none-any.whl (18 kB)
Downloading nirtorch-1.0-py3-none-any.whl (13 kB)
Building wheels for collected packages: norse
  Building wheel for norse (pyproject.toml) ... [?25l[?25hdone
  Created wheel for norse: filename=norse-1.1.0-py3-none-any.whl size=1539018 sha256=9fad2ecb8f0fa0670bf30b5fd3826ae3bb56027b94ae516e6e8e040c9b6690a7
  Stored in directory: /root/.cache/pip/wheels/16/fc/0d/4cbb14992b7e5bb35482df57e887a2ab55cad9ea890501cf6

In [None]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import norse

class SNNNetwork(nn.Module):
    """MNIST classifier built from Norse time-lifted convolutions and leaky-integrator cells.

    Each static image is binarised with a fixed threshold and presented for
    ``time_steps`` identical steps. Two time-lifted 5x5 convolutions (each
    followed by 2x2 max pooling) feed two ``LILinearCell`` layers. Those layers
    are simulated step by step, with their state carried across time. The
    output layer's values are averaged over time and returned as class scores.

    NOTE(review): no layer after the input encoding applies a spiking
    threshold. ``LConv2d`` outputs go straight to pooling, and ``LILinearCell``
    outputs are passed on directly, so intermediate activations are real-valued
    rather than spikes. Insert spiking cells (e.g. Norse LIF cells) between
    layers if a true SNN is intended.

    Args:
        time_steps: number of simulation steps each image is presented for
            (default 50, the value ``forward`` previously hard-coded).
        threshold: input values strictly above this become 1.0, others 0.0.
            It is compared against ``forward``'s input as given, i.e. against
            normalised values when the data pipeline normalises.
    """

    def __init__(self, time_steps: int = 50, threshold: float = 0.5):
        super().__init__()
        self.time_steps = time_steps
        self.threshold = threshold

        # Convolutions lifted over time: they take (time, batch, C, H, W) tensors.
        self.conv1 = norse.torch.LConv2d(
            in_channels=1,
            out_channels=20,
            kernel_size=5,
            stride=1
        )
        self.conv2 = norse.torch.LConv2d(
            in_channels=20,
            out_channels=50,
            kernel_size=5,
            stride=1
        )
        # For 28x28 inputs: 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4, with 50 channels.
        self.fc1 = norse.torch.LILinearCell(4*4*50, 500)
        self.fc2 = norse.torch.LILinearCell(500, 10)
        self.max_pool = torch.nn.MaxPool2d(2)

    def _pool(self, x: torch.Tensor) -> torch.Tensor:
        """Apply 2D max pooling independently at every (time, batch) position of a 5D tensor."""
        # MaxPool2d expects 4D input, so fold time into the batch axis and split it back afterwards.
        pooled = self.max_pool(x.flatten(0, 1))
        return pooled.unflatten(0, x.shape[:2])

    def forward(self, x):
        """Simulate the network for ``self.time_steps`` steps on an image batch.

        Args:
            x: image batch; assumed (batch, 1, 28, 28), given conv1's single
                input channel and the 4*4*50 flatten size.

        Returns:
            Tensor of shape (batch, 10): fc2's outputs averaged over time.
        """
        # Binarise once, then repeat the frame over time: (batch, C, H, W) -> (time, batch, C, H, W).
        frame = (x > self.threshold).float()
        spikes = frame.unsqueeze(0).repeat(self.time_steps, 1, 1, 1, 1)

        spikes = self._pool(self.conv1(spikes))
        spikes = self._pool(self.conv2(spikes))
        spikes = spikes.flatten(2)  # (time, batch, 4*4*50)

        # Norse *Cell modules advance one time step per call and return the
        # updated state. Feed the whole batch for one step at a time and thread
        # each layer's state through, so the LI dynamics integrate over time
        # instead of restarting from a fresh state at every step.
        state1 = state2 = None
        outputs = []
        for step_input in spikes:
            hidden, state1 = self.fc1(step_input, state1)
            out, state2 = self.fc2(hidden, state2)
            outputs.append(out)

        # (time, batch, 10) -> (batch, 10)
        return torch.stack(outputs).mean(0)

def train(model, device, train_loader, optimizer, epoch):
    """Run one epoch of optimisation and return the mean per-batch training loss.

    Args:
        model: module mapping an input batch to class logits.
        device: device every batch is moved to before the forward pass.
        train_loader: iterable of (inputs, labels) batches; ``len()`` and
            ``.dataset`` are read when printing progress.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: epoch number, used only in the progress messages.

    Returns:
        0-dim tensor: unweighted mean of the per-batch cross-entropy losses.
    """
    model.train()
    batch_losses = []

    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = torch.nn.functional.cross_entropy(logits, labels)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        batch_losses.append(loss_value)

        # Progress report on the first batch and every 100th after it.
        if not step % 100:
            seen = step * len(inputs)
            print(f'Train Epoch: {epoch} [{seen}/{len(train_loader.dataset)} '
                  f'({100. * step / len(train_loader):.0f}%)]\tLoss: {loss_value:.6f}')

    return torch.tensor(batch_losses).mean()

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            test_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)

    print(f'\nTest set: Average loss: {test_loss:.4f}, '
          f'Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n')

    return test_loss, accuracy

In [None]:
# Training settings
batch_size = 64
epochs = 10
learning_rate = 1e-3
seed = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs on all devices so weight initialisation and the training
# shuffle order repeat across Restart & Run All. (Some cuDNN kernels can still
# be non-deterministic on GPU.)
torch.manual_seed(seed)

# Load MNIST dataset.
# NOTE: SNNNetwork binarises its input with a fixed threshold, so that
# threshold is applied to these normalised values, not to raw [0, 1] pixels.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(
    '../data',
    train=True,
    download=True,
    transform=transform
)

# download=True as well, so this split does not depend on the train-split
# download having already fetched the files (a no-op when they exist).
test_dataset = torchvision.datasets.MNIST(
    '../data',
    train=False,
    download=True,
    transform=transform
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False
)

# Initialize model and optimizer
model = SNNNetwork().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop: one optimisation epoch, then a full test-set evaluation.
for epoch in range(1, epochs + 1):
    train_loss = train(model, device, train_loader, optimizer, epoch)
    test_loss, accuracy = test(model, device, test_loader)

    print(f"Epoch: {epoch}, Train Loss: {train_loss:.4f}, "
          f"Test Loss: {test_loss:.4f}, Test Accuracy: {accuracy:.2f}%")


Test set: Average loss: 0.1072, Accuracy: 9657/10000 (96.57%)

Epoch: 1, Train Loss: 0.2984, Test Loss: 0.1072, Test Accuracy: 96.57%

Test set: Average loss: 0.0694, Accuracy: 9779/10000 (97.79%)

Epoch: 2, Train Loss: 0.0978, Test Loss: 0.0694, Test Accuracy: 97.79%

Test set: Average loss: 0.0535, Accuracy: 9830/10000 (98.30%)

Epoch: 3, Train Loss: 0.0728, Test Loss: 0.0535, Test Accuracy: 98.30%

Test set: Average loss: 0.0509, Accuracy: 9833/10000 (98.33%)

Epoch: 4, Train Loss: 0.0602, Test Loss: 0.0509, Test Accuracy: 98.33%

Test set: Average loss: 0.0478, Accuracy: 9853/10000 (98.53%)

Epoch: 5, Train Loss: 0.0532, Test Loss: 0.0478, Test Accuracy: 98.53%

Test set: Average loss: 0.0524, Accuracy: 9850/10000 (98.50%)

Epoch: 6, Train Loss: 0.0471, Test Loss: 0.0524, Test Accuracy: 98.50%

Test set: Average loss: 0.0531, Accuracy: 9836/10000 (98.36%)

Epoch: 7, Train Loss: 0.0412, Test Loss: 0.0531, Test Accuracy: 98.36%

Test set: Average loss: 0.0579, Accuracy: 9815/10000 (