In [51]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import DatasetFolder
from tqdm import tqdm

In [52]:
transform_train_mnist = transforms.Compose([
    transforms.Resize((32, 32)),           # Resize to 32x32
    transforms.RandomCrop(32, padding=4),  # Same augmentation as CIFAR-10
    transforms.RandomHorizontalFlip(),
    transforms.Grayscale(num_output_channels=3),  # Convert to 3-channel grayscale
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST specific normalization
])

transform_test_mnist = transforms.Compose([
    transforms.Resize((32, 32)),  # Resize to 32x32
    transforms.Grayscale(num_output_channels=3),  # Convert to 3-channel grayscale
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST specific normalization
])

trainset = torchvision.datasets.MNIST(root="./", train=True, download=True, transform=transform_train_mnist)
testset = torchvision.datasets.MNIST(root="./", train=False, download=True, transform=transform_test_mnist)

In [103]:
class Spike(torch.autograd.Function):

  @staticmethod
  def forward(ctx, input, threshold):
      # your code starts here
      ctx.save_for_backward(input)
      # ctx.threshold = threshold
      return (input > threshold).float()


  @staticmethod
  def backward(ctx, grad_output):
      # your code starts here
      input, = ctx.saved_tensors

      grad_input = grad_output

      surrogate_grad = (1 / (1 + torch.abs(input)) ** 2)
      return grad_input * surrogate_grad, None

class LIF(nn.Module):
    def __init__(self, thre:float = 1.0, tau:float = 0.5, delta_th: float = 1.0):
        super(LIF, self).__init__()
        self.thre = thre
        self.tau = tau
        self.delta_th = delta_th  # Increment to threshold after each spike
        self.membrane_potential = None

    def forward(self, x):
        device = x.device
        batch_size, time, channels, height, width = x.size()
        # Initialize membrane potential and threshold
        membrane_potential = torch.zeros_like(x[:, 0, :, :, :])
        threshold = torch.full_like(membrane_potential, self.thre)  # Starting threshold for all neurons

        spikes = torch.zeros_like(x)
        for t in range(x.size(1)):
            membrane_potential = membrane_potential * self.tau + x[:, t, :, :, :]
            spikes[:, t, :, :, :] = Spike.apply(membrane_potential, threshold)
            membrane_potential = membrane_potential * (1 - spikes[:, t, :, :, :])

            # Increase the threshold where spikes occur and reset slowly
            threshold += self.delta_th * spikes[:, t, :, :, :]  # Increment where spikes occurred
            threshold = torch.max(threshold * 0.99, torch.full_like(threshold, self.thre))  # Decay toward base threshold

        return spikes


class ConvLIF(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ConvLIF, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.lif = LIF()

    def forward(self, x):
        batch_size, time, channels, height, width = x.size()
        reshape_x = x.view(batch_size * time, channels, height, width)
        conv_out = self.conv(reshape_x)
        _, _, new_height, new_width = conv_out.size()
        new_out_channels = self.conv.weight.shape[0]
        reshape_conv = conv_out.view(batch_size, time, new_out_channels, new_height, new_width)
        return self.lif(reshape_conv)

In [104]:
class SAvgPool2d(nn.Module):
  def __init__(self, *args, **kwargs):
      super().__init__(*args, **kwargs)
      self.module = nn.AvgPool2d(2)

  def forward(self, x_seq: torch.Tensor):
      y_shape = [x_seq.shape[0], x_seq.shape[1]]
      y_seq = self.module(x_seq.flatten(0, 1).contiguous())
      y_shape.extend(y_seq.shape[1:])
      return y_seq.view(y_shape)

In [105]:
class SVGG(nn.Module):
  def __init__(self, *args, **kwargs) -> None:
     super().__init__(*args, **kwargs)

     pool = SAvgPool2d()
     self.classifier = nn.Linear(2048 * T, 10)

     self.conv1 = ConvLIF(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
     self.conv2 = ConvLIF(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)
     self.conv4 = ConvLIF(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
     self.conv5 = ConvLIF(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
     self.conv7 = ConvLIF(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
     self.conv8 = ConvLIF(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)

     # your code starts here (convolutional layers)
     self.model = nn.Sequential(
         # Layer 1: ConvLIF
         self.conv1,
         # Layer 2: ConvLIF
         self.conv2,
         # Layer 3: Average Pool
         pool,
         # Layer 4: ConvLIF
         self.conv4,
         # Layer 5: ConvLIF
         self.conv5,
         # Layer 6: Average Pool
         pool,
         # Layer 7: ConvLIF
         self.conv7,
         # Layer 8: ConvLIF
         self.conv8,
         # Layer 9: Average Pool
         pool,
         nn.Flatten(),
         self.classifier
     )
  def forward(self, x):
    x = self.model(x)
    return x

In [106]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [107]:
# Configuration
lr = 1e-3
batch_size = 128
epochs = 5

# Time steps
T = 4

# Define the model
model = SVGG().cuda()

# Define the data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size, num_workers=2, pin_memory=True, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size, num_workers=2, pin_memory=True, shuffle=False)
print(f"Number of batches of training: {len(trainloader)} | number of batches of test: {len(testloader)}")

# Define the loss function
loss_fn = torch.nn.CrossEntropyLoss()

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

Number of batches of training: 469 | number of batches of test: 79


In [108]:
for e in range(epochs):
  print("Training")
  model.train()
  train_loss = 0.0
  correct = 0
  total = 0
  for i, (images, target) in enumerate(tqdm(trainloader)):
    images = images.cuda()
    target = target.cuda()

    # Step 1: Adding a new time dimension and repeat the image by T times
    images = images.unsqueeze(1).repeat(1,T,1,1,1)

    # Step 2: Send the reshaped image into the model
    output = model(images)

    # Step 3: Compute the loss
    loss = loss_fn(output, target)

    # Step 4: Backward propagation + Update the model
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = torch.max(output.data, 1)
    total += target.size(0)
    correct += (predicted == target).sum().item()

  train_accuracy = 100 * correct / total
  # print(f"Epoch {e+1}, Loss: {train_loss/len(trainloader)}, Accuracy: {train_accuracy}%")


  print("Validation")
  model.eval()
  val_correct = 0
  val_total = 0
  for i, (images, target) in enumerate(tqdm(testloader)):
    images = images.cuda()
    target = target.cuda()

    # Step 1: Adding a new time dimension and repeat the image by T times
    images = images.unsqueeze(1).repeat(1,T,1,1,1)

    # Step 2: Send the reshaped image into the model
    output = model(images)

    # Step 3: Compute the accuracy based on the accuracy function
    _, predicted = torch.max(output, 1)  # Get the index of the max logit
    val_total += target.size(0)
    val_correct += (predicted == target).sum().item()

  top1 = 100 * val_correct / val_total
  print(f"[{e}]/[{epochs}] | Test accuracy = {top1}%")


Training


100%|██████████| 469/469 [00:18<00:00, 25.88it/s]


Validation


100%|██████████| 79/79 [00:02<00:00, 38.45it/s]


[0]/[5] | Test accuracy = 87.61%
Training


100%|██████████| 469/469 [00:17<00:00, 26.11it/s]


Validation


100%|██████████| 79/79 [00:02<00:00, 39.04it/s]


[1]/[5] | Test accuracy = 90.34%
Training


100%|██████████| 469/469 [00:18<00:00, 25.80it/s]


Validation


100%|██████████| 79/79 [00:02<00:00, 39.12it/s]


[2]/[5] | Test accuracy = 89.9%
Training


100%|██████████| 469/469 [00:17<00:00, 26.16it/s]


Validation


100%|██████████| 79/79 [00:02<00:00, 37.30it/s]


[3]/[5] | Test accuracy = 94.18%
Training


100%|██████████| 469/469 [00:18<00:00, 25.86it/s]


Validation


100%|██████████| 79/79 [00:02<00:00, 37.95it/s]

[4]/[5] | Test accuracy = 95.82%



