<a href="https://colab.research.google.com/github/Maya7991/gsc_classification/blob/main/snn_conv2d_mfcc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install snntorch --quiet
!pip install torchaudio --quiet

In [None]:
import os
import csv
import numpy as np
from sklearn.preprocessing import LabelEncoder

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchaudio
from torchaudio.datasets import SPEECHCOMMANDS
from torch.utils.data import DataLoader, Dataset
import torchaudio.transforms as T

from snntorch import spikegen, surrogate, functional as SF
import snntorch as snn

In [None]:


# --- Configuration ---
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Spike-train settings.
num_steps = 20
encoding_type = "rate"  # <<< "rate" or "latency"

# Output locations: per-epoch metrics file and checkpoint directory.
csv_log = "training_log.csv"
save_dir = "checkpoints"
os.makedirs(save_dir, exist_ok=True)

# --- Data ---
# MFCC front end: 40 coefficients from 40 mel bands; at 16 kHz an n_fft of 400
# samples is a 25 ms window and a hop_length of 160 samples is a 10 ms step.
transform = torchaudio.transforms.MFCC(
    sample_rate=16000,
    n_mfcc=40,
    melkwargs={'n_fft': 400, 'hop_length': 160, 'n_mels': 40}
)

train_dataset = SPEECHCOMMANDS("./", download=True, subset="training")
val_dataset = SPEECHCOMMANDS("./", download=True, subset="validation")
test_dataset = SPEECHCOMMANDS("./", download=True, subset="testing")


def _split_labels(dataset):
    """Return the set of label strings that occur in ``dataset``.

    Indexing a dataset item decodes its audio file, which is wasted work when
    only the label (field 2) is needed. When the installed torchaudio exposes
    ``get_metadata`` the labels are read through it instead; otherwise every
    item is indexed as before.
    """
    if hasattr(dataset, "get_metadata"):
        return {dataset.get_metadata(i)[2] for i in range(len(dataset))}
    return {dataset[i][2] for i in range(len(dataset))}


# Build label encoder over the labels of all three splits so that every split
# shares one label -> index mapping.
all_labels = sorted(
    _split_labels(train_dataset)
    | _split_labels(val_dataset)
    | _split_labels(test_dataset)
)
label_encoder = LabelEncoder()
label_encoder.fit(all_labels)

# Collate function
def collate_fn(batch):
    """Turn a list of dataset items into a padded MFCC batch plus labels.

    Args:
        batch: Iterable of dataset items whose first three fields are
            ``(waveform, sample_rate, label)``; further fields are ignored.

    Returns:
        A pair ``(X, y)``. ``X`` stacks one ``[1, n_mfcc, time]`` tensor per
        item, zero-padded at the end of the time axis to the longest item in
        the batch. ``y`` is an int64 tensor of encoded labels.
    """
    mfccs, labels = [], []
    for waveform, _sample_rate, label, *_ in batch:
        # [n_mfcc, time] -- assumes a mono waveform so squeeze(0) drops the
        # channel axis.
        mfccs.append(transform(waveform).squeeze(0))
        labels.append(label)

    max_len = max((mfcc.shape[1] for mfcc in mfccs), default=0)
    X = [
        F.pad(mfcc, (0, max_len - mfcc.shape[1])).unsqueeze(0)  # [1, n_mfcc, time]
        for mfcc in mfccs
    ]

    # Encode the whole batch with a single call instead of one call per sample.
    y = torch.as_tensor(label_encoder.transform(labels), dtype=torch.long)

    return torch.stack(X), y

# Batch size and collate function are shared by all splits; only the training
# split is shuffled.
_loader_kwargs = {"batch_size": 32, "collate_fn": collate_fn}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

# --- Spike Encoding ---
def encode_input(batch, num_steps, encoding="rate"):
    """Min-max scale ``batch`` and convert it to a spike train.

    Args:
        batch: Tensor with at least four dimensions; the minimum and maximum
            are taken jointly over dimensions 2 and 3.
        num_steps: Number of time steps handed to the spike generator.
        encoding: ``"rate"`` or ``"latency"``.

    Returns:
        The tensor produced by ``spikegen.rate`` or ``spikegen.latency``
        (laid out as [T, B, C, H, W] for a [B, C, H, W] input).

    Raises:
        ValueError: If ``encoding`` is neither ``"rate"`` nor ``"latency"``.
    """
    lo = batch.amin(dim=[2, 3], keepdim=True)
    hi = batch.amax(dim=[2, 3], keepdim=True)
    # The small constant keeps the division finite when a map is constant.
    span = hi - lo + 1e-7
    scaled = (batch - lo) / span

    if encoding not in ("rate", "latency"):
        raise ValueError("Encoding must be 'rate' or 'latency'")

    if encoding == "rate":
        return spikegen.rate(scaled, num_steps=num_steps)
    return spikegen.latency(scaled, num_steps=num_steps, normalize=True, linear=True)

# --- Model ---
class SNN_Conv2D(nn.Module):
    """Three-stage convolutional spiking network with a linear read-out.

    Each stage is a 3x3 convolution followed by a leaky integrate-and-fire
    layer. The spikes of the last stage are pooled to a fixed grid, flattened
    and mapped to class scores; the scores are averaged over the time steps.

    Args:
        num_classes: Number of output scores per sample.
        num_steps: Number of time steps read from the input in ``forward``.
        pooled_hw: ``(height, width)`` of the grid the last spike map is
            pooled to before the linear layer. The default keeps the linear
            layer at its original ``128 * 10 * 13`` inputs.
    """

    def __init__(self, num_classes, num_steps, pooled_hw=(10, 13)):
        super().__init__()
        self.num_steps = num_steps

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.lif1 = snn.Leaky(beta=0.9, spike_grad=surrogate.fast_sigmoid())

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.lif2 = snn.Leaky(beta=0.9, spike_grad=surrogate.fast_sigmoid())

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.lif3 = snn.Leaky(beta=0.9, spike_grad=surrogate.fast_sigmoid())

        # The spatial size after conv3 follows the input size (for example a
        # 40 x 101 input becomes 20 x 51 and then 10 x 26), so a linear layer
        # sized for one fixed map fails on any other input width. Pooling to a
        # fixed grid makes the read-out independent of the input size. The
        # pooling layer has no parameters, so state_dict keys are unchanged.
        self.pool = nn.AdaptiveAvgPool2d(pooled_hw)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128 * pooled_hw[0] * pooled_hw[1], num_classes)

    def forward(self, x):
        """Return class scores averaged over ``self.num_steps`` time steps.

        Args:
            x: Spike tensor indexed by time step on its first dimension; each
                ``x[step]`` is fed to a one-channel convolution, i.e. it is
                expected to be ``[B, 1, H, W]``.

        Returns:
            Tensor of shape ``[B, num_classes]``.
        """
        # Fresh membrane state for every forward pass.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        score_sum = 0

        for step in range(self.num_steps):
            input_t = x[step]

            cur1 = self.conv1(input_t)
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = self.conv2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = self.conv3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)

            out = self.fc(self.flatten(self.pool(spk3)))

            # Out-of-place add so no layer output is modified in place.
            score_sum = score_sum + out

        return score_sum / self.num_steps

# --- Training & Evaluation ---
def train_epoch(model, loader, optimizer, loss_fn):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; must expose ``num_steps``.
        loader: Iterable of ``(x, y)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_fn: Callable ``loss_fn(out, y)`` returning a scalar tensor. The
            per-sample average below assumes it returns the batch mean, as
            ``nn.CrossEntropyLoss`` does by default.

    Returns:
        ``(mean_loss, accuracy)`` over the samples actually processed.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0

    for x, y in loader:
        spk_x = encode_input(x.to(device), model.num_steps, encoding=encoding_type)
        y = y.to(device)

        optimizer.zero_grad()
        out = model(spk_x)
        loss = loss_fn(out, y)
        loss.backward()
        optimizer.step()

        # Weight each batch by its size so a short final batch does not count
        # as much as a full one, and count the samples that were really seen
        # instead of assuming the loader covers the whole dataset.
        batch_size = y.size(0)
        loss_sum += loss.item() * batch_size
        correct += (out.argmax(dim=1) == y).sum().item()
        seen += batch_size

    return loss_sum / seen, correct / seen

def evaluate(model, loader, loss_fn):
    """Compute loss and accuracy of ``model`` on ``loader`` without training.

    Args:
        model: Network to evaluate; must expose ``num_steps``.
        loader: Iterable of ``(x, y)`` batches.
        loss_fn: Callable ``loss_fn(out, y)`` returning a scalar tensor. The
            per-sample average below assumes it returns the batch mean, as
            ``nn.CrossEntropyLoss`` does by default.

    Returns:
        ``(mean_loss, accuracy)`` over the samples actually processed.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0

    with torch.no_grad():
        for x, y in loader:
            spk_x = encode_input(x.to(device), model.num_steps, encoding=encoding_type)
            y = y.to(device)

            out = model(spk_x)
            loss = loss_fn(out, y)

            # Weight each batch by its size so a short final batch does not
            # count as much as a full one, and count the samples that were
            # really seen instead of assuming full dataset coverage.
            batch_size = y.size(0)
            loss_sum += loss.item() * batch_size
            correct += (out.argmax(dim=1) == y).sum().item()
            seen += batch_size

    return loss_sum / seen, correct / seen

# --- Instantiate Model ---
# One output per class known to the label encoder.
num_classes = len(label_encoder.classes_)

model = SNN_Conv2D(num_classes=num_classes, num_steps=num_steps)
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-4)

# --- Training Loop ---
# Start a fresh metrics file containing only the column names.
with open(csv_log, mode='w', newline='') as f:
    csv.writer(f).writerow(["epoch", "train_loss", "train_acc", "val_loss", "val_acc"])

num_epochs = 30
best_val_acc = 0.0
best_model_path = os.path.join(save_dir, "best_model.pth")

print("Start training!")

for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, loss_fn)
    val_loss, val_acc = evaluate(model, val_loader, loss_fn)

    # Append this epoch's metrics; the file is reopened every epoch so rows
    # already written are on disk if a later epoch fails.
    with open(csv_log, mode='a', newline='') as f:
        csv.writer(f).writerow([epoch, train_loss, train_acc, val_loss, val_acc])

    print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Acc={train_acc*100:.2f}% | Val Loss={val_loss:.4f}, Acc={val_acc*100:.2f}%")

    # Keep only the weights with the highest validation accuracy so far.
    improved = val_acc > best_val_acc
    if improved:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"Saved model at epoch {epoch}!")

# --- Final Test Evaluation ---
# Report test metrics for the weights with the best validation accuracy rather
# than whatever the last epoch produced. A checkpoint is written exactly when
# validation accuracy rises above its initial 0.0, so best_val_acc > 0 means the
# file on disk belongs to this run and not to an earlier one.
best_model_path = os.path.join(save_dir, "best_model.pth")
if best_val_acc > 0.0 and os.path.isfile(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=device))

test_loss, test_acc = evaluate(model, test_loader, loss_fn)
print(f"Test Loss={test_loss:.4f}, Acc={test_acc*100:.2f}%")
