<a href="https://colab.research.google.com/github/Maya7991/gsc_classification/blob/main/fmnist_snn_multi_encoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pure SNN convolutional architecture for Fashion-MNIST classification with TTFS/rate encoding

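*   Choose either rate or latency (TTFS) encoding of the input images
*   The loss function is chosen to match the encoding scheme (rate → `SF.ce_rate_loss`, TTFS → `SF.ce_temporal_loss`)
*   The model with the lowest validation loss is saved as a checkpoint
*   Per-epoch training/validation loss and accuracy are saved to a CSV file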

* The architecture/model is not yet stable. This notebook is mainly a framework for training with different encoding schemes; the model can likely be improved with some hyperparameter tuning.




In [38]:
!pip install snntorch --quiet

In [39]:
import torch
from torchvision import transforms,datasets
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader

import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikegen

import os
import pandas as pd
import numpy as np
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [40]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# FashionMNIST images are already 28x28 single-channel, so Resize/Grayscale are
# effectively no-ops here. Normalize((0,), (1,)) also leaves ToTensor's [0, 1]
# range unchanged, which suits spikegen.rate (it treats pixel values as
# firing probabilities).
transform = transforms.Compose([
            transforms.Resize((28,28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])


train_dataset = datasets.FashionMNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST('data', train=False, download=True, transform=transform)

# Split training into train/val (90/10). A seeded generator makes the split
# identical across runs, so validation metrics and checkpoints are comparable.
split_seed = 42
train_len = int(0.9 * len(train_dataset))
val_len = len(train_dataset) - train_len
train_data, val_data = random_split(
    train_dataset,
    [train_len, val_len],
    generator=torch.Generator().manual_seed(split_seed),
)

batch_size=50
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

print(f"Training data size : {len(train_data)}, Validation data size : {len(val_data)}, Test data size : {len(test_dataset)}")

Training data size : 54000, Validation data size : 6000, Test data size : 10000


In [41]:
# # Define spiking CNN (earlier draft, kept for reference only — unused; the model actually trained is `Net`)
# class SNNConvNet(nn.Module):
#     def __init__(self, num_steps, encoding):
#         super().__init__()
#         beta=0.9
#         spike_grad = surrogate.fast_sigmoid()
#         self.num_steps = num_steps
#         self.encoding = encoding

#         self.conv1 = nn.Conv2d(1, 8, 5)
#         self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
#         self.pool1 = nn.MaxPool2d(2)

#         self.conv2 = nn.Conv2d(8, 16, 3)
#         self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
#         self.pool2 = nn.MaxPool2d(2)

#         self.fc1 = nn.Linear(16 * 25, 128)
#         self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
#         self.fc2 = nn.Linear(128, 10)
#         self.lif4 = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)

#     def forward(self, x):
#         spk4_rec = []

#         if encoding == "rate":
#             x = spikegen.rate(x, num_steps=num_steps)
#         elif encoding == "ttfs":
#             # x = spikegen.ttfs(x, num_steps=num_steps)
#             spikegen.latency(x, num_steps=num_steps, normalize=True, linear=True)
#         else:
#             raise ValueError("Encoding must be 'rate' or 'ttfs'")

#         self.lif1.init_leaky()
#         self.lif2.init_leaky()
#         self.lif3.init_leaky()
#         self.lif4.init_leaky()

#         for step in range(num_steps):
#             cur1 = self.conv1(x[step])
#             spk1 = self.lif1(cur1)
#             smp1 = self.pool1(spk1)

#             cur2 = self.conv2(smp1)
#             spk2= self.lif2(cur2)
#             smp2 = self.pool2(spk2)

#             flat = smp2.view(smp2.size(0), -1)
#             cur3 = self.fc1(flat)
#             spk3= self.lif3(cur3)

#             cur4 = self.fc2(spk3)
#             spk4= self.lif4(cur4)

#             spk4_rec.append(spk4)

#         return torch.stack(spk4_rec, dim=0)

In [42]:
# Define Network
class Net(nn.Module):
    """Convolutional spiking network with 10 output neurons.

    Pipeline per time step:
    conv(1->12, 5x5) -> maxpool(2) -> LIF -> conv(12->32, 5x5) -> maxpool(2)
    -> LIF -> linear(512->10) -> LIF.
    With 28x28 inputs the conv2 output after pooling is 32x4x4 = 512 features,
    which matches fc1's input size.

    Args:
        num_steps: number of simulation time steps.
        encoding: "rate" or "ttfs"; selects how images are turned into spikes.

    Raises:
        ValueError: if ``encoding`` is not "rate" or "ttfs".
    """

    def __init__(self, num_steps, encoding):
        super().__init__()
        # Fail at construction time rather than on the first forward pass.
        if encoding not in ("rate", "ttfs"):
            raise ValueError("Encoding must be 'rate' or 'ttfs'")

        beta = 0.9  # membrane decay factor shared by all LIF layers
        spike_grad = surrogate.fast_sigmoid(slope=25)
        self.num_steps = num_steps
        self.encoding = encoding

        # Initialize layers
        self.conv1 = nn.Conv2d(1, 12, 5)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(12, 32, 5)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.fc1 = nn.Linear(512, 10)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x):
        """Encode ``x`` into a spike train and simulate the network.

        Args:
            x: image batch; presumably (batch, 1, 28, 28) with values in [0, 1]
               (ToTensor output) — TODO confirm for other datasets.

        Returns:
            Output spikes stacked along a leading time axis:
            (num_steps, batch, 10).
        """
        # Use this instance's configuration (not notebook globals), so the
        # encoder length always matches the simulation loop below.
        if self.encoding == "rate":
            x = spikegen.rate(x, num_steps=self.num_steps)
        elif self.encoding == "ttfs":
            x = spikegen.latency(x, num_steps=self.num_steps, normalize=True, linear=True)
        else:
            raise ValueError("Encoding must be 'rate' or 'ttfs'")

        # Reset membrane potentials at t=0 so no state leaks between batches.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        spk_rec = []
        for step in range(self.num_steps):
            # x is time-first after encoding, so x[step] is one time slice.
            cur1 = F.max_pool2d(self.conv1(x[step]), 2)
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = F.max_pool2d(self.conv2(spk1), 2)
            spk2, mem2 = self.lif2(cur2, mem2)

            # Flatten per sample; using size(0) keeps partial batches working.
            cur3 = self.fc1(spk2.view(spk2.size(0), -1))
            spk3, mem3 = self.lif3(cur3, mem3)
            spk_rec.append(spk3)

        return torch.stack(spk_rec, dim=0)

In [43]:
base_path = '/content/drive/My Drive/thesis_apr'
num_epochs = 15
learning_rate = 1e-4
num_steps = 20
encoding = "ttfs"  # "rate" or "ttfs"
seed = 42          # fixes weight init, shuffling and stochastic rate encoding
save_model_name = f"fmnist_snn_conv_{encoding}.pth"

# Pair each input encoding with the loss that interprets output spikes the
# same way: rate coding -> SF.ce_rate_loss, TTFS -> SF.ce_temporal_loss.
loss_factories = {"rate": SF.ce_rate_loss, "ttfs": SF.ce_temporal_loss}
if encoding not in loss_factories:
    raise ValueError("Encoding must be 'rate' or 'ttfs'")
loss_fn = loss_factories[encoding]()

torch.manual_seed(seed)
net = Net(num_steps=num_steps, encoding=encoding).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

def trainNet(model, optimizer, encoding, epochs):
    """Train ``model`` and checkpoint the weights with the lowest validation loss.

    Uses notebook-level globals: ``train_loader``, ``val_loader``, ``device``,
    ``loss_fn``, ``base_path`` and ``save_model_name``.

    Args:
        model: spiking network returning a time-first output spike record.
        optimizer: optimizer over ``model``'s parameters.
        encoding: "rate" or "ttfs"; selects the accuracy metric that matches
            how the loss for that encoding interprets output spikes.
        epochs: number of training epochs.

    Returns:
        dict with per-epoch lists 'train_loss', 'valid_loss', 'train_acc' and
        'valid_acc' (accuracies in percent).

    Raises:
        ValueError: if ``encoding`` is not "rate" or "ttfs".
    """
    accuracy_fns = {"rate": SF.accuracy_rate, "ttfs": SF.accuracy_temporal}
    if encoding not in accuracy_fns:
        raise ValueError("Encoding must be 'rate' or 'ttfs'")
    accuracy_fn = accuracy_fns[encoding]

    # Make sure the checkpoint directory exists before the first torch.save.
    checkpoint_dir = os.path.join(base_path, 'model', 'fmnist')
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, save_model_name)

    print(loss_fn)
    loss_keeper = {'train_loss': [], 'valid_loss': [], 'train_acc': [], 'valid_acc': []}
    val_loss_min = np.inf

    for epoch in range(epochs):
        train_loss = 0.0
        train_acc = 0.0
        val_loss = 0.0
        val_acc = 0.0

        # --- Training phase ---
        model.train()
        for data, targets in train_loader:
            data, targets = data.to(device), targets.to(device)

            optimizer.zero_grad()
            spk_out = model(data)  # the model passed in, not the global `net`

            loss = loss_fn(spk_out, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += float(accuracy_fn(spk_out, targets))

        # --- Validation phase ---
        model.eval()
        with torch.no_grad():
            for data, targets in val_loader:
                data, targets = data.to(device), targets.to(device)
                spk_out = model(data)
                loss = loss_fn(spk_out, targets)

                val_loss += loss.item()
                val_acc += float(accuracy_fn(spk_out, targets))

        # Accumulators hold per-batch means; average them over batches.
        train_loss = train_loss / len(train_loader)
        train_acc = 100 * train_acc / len(train_loader)
        val_loss = val_loss / len(val_loader)
        val_acc = 100 * val_acc / len(val_loader)
        loss_keeper['train_loss'].append(train_loss)
        loss_keeper['valid_loss'].append(val_loss)
        loss_keeper['train_acc'].append(train_acc)
        loss_keeper['valid_acc'].append(val_acc)

        print(f"\nEpoch : {epoch+1}\tTraining Loss : {train_loss:.4f}\t Validation Loss : {val_loss:.4f}\t Training Acc : {train_acc:.4f}% \tValidation Acc: {val_acc:.4f}%")

        if val_loss <= val_loss_min:
            print(f"Validation loss decreased from : {val_loss_min} ----> {val_loss} ----> Saving Model.......")
            # Save only the state_dict: pickling the entire model causes a PicklingError.
            torch.save(model.state_dict(), checkpoint_path)
            val_loss_min = val_loss

    return loss_keeper

def test(model):
    """Evaluate ``model`` on ``test_loader`` and report accuracy in percent.

    The metric follows the notebook-level ``encoding`` setting
    (``SF.accuracy_rate`` for "rate", ``SF.accuracy_temporal`` otherwise), so
    it matches the one used during training and validation.
    Uses globals ``test_loader``, ``device`` and ``encoding``.

    Returns:
        Test accuracy in percent (also printed).
    """
    accuracy_fn = SF.accuracy_rate if encoding == "rate" else SF.accuracy_temporal

    model.eval()
    total_acc = 0.0
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)

            spk_rec = model(data)
            # float() accepts either a numpy scalar or a 0-d tensor.
            total_acc += float(accuracy_fn(spk_rec, targets))

    test_acc = 100 * total_acc / len(test_loader)
    print(f"Test Accuracy: {test_acc:.4f}")
    return test_acc

if __name__ == '__main__':
    # Train, then score the weights currently held in `net` (the final epoch's).
    # NOTE(review): the lowest-val-loss checkpoint saved during training is not
    # reloaded before testing — confirm which model should be reported.
    loss_keeper = trainNet(net, optimizer, encoding, num_epochs)
    test(net)

    # Persist the per-epoch loss/accuracy history as CSV under <base_path>/logs.
    log_dir = os.path.join(base_path, 'logs')
    csv_save_path = os.path.join(log_dir, f'loss_acc_{encoding}.csv')
    results_df = pd.DataFrame(loss_keeper)
    os.makedirs(log_dir, exist_ok=True)
    results_df.to_csv(csv_save_path, index=False)
    print(f"Training log saved to: {csv_save_path}")


Epoch : 1	Training Loss : 1.7867	 Validation Loss : 1.4944	 Training Acc : 44.9167% 	Validation Acc: 53.7333%
Validation loss decreased from : inf ----> 1.4943573450048764 ----> Saving Model.......

Epoch : 2	Training Loss : 1.5101	 Validation Loss : 1.4246	 Training Acc : 48.1500% 	Validation Acc: 48.3167%
Validation loss decreased from : 1.4943573450048764 ----> 1.424613740046819 ----> Saving Model.......

Epoch : 3	Training Loss : 1.4249	 Validation Loss : 1.3365	 Training Acc : 50.1574% 	Validation Acc: 50.4667%
Validation loss decreased from : 1.424613740046819 ----> 1.3364959860841432 ----> Saving Model.......

Epoch : 4	Training Loss : 1.2328	 Validation Loss : 1.2273	 Training Acc : 56.8500% 	Validation Acc: 58.1000%
Validation loss decreased from : 1.3364959860841432 ----> 1.2272872666517893 ----> Saving Model.......

Epoch : 5	Training Loss : 1.0522	 Validation Loss : 1.0769	 Training Acc : 65.8259% 	Validation Acc: 66.6500%
Validation loss decreased from : 1.227287266651789