<a href="https://colab.research.google.com/github/Maya7991/gsc_classification/blob/main/snn_conv1d_mfcc_val.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SNN conv1D on MFCC

### Imports

In [None]:
!pip install snntorch --quiet
!pip install torchaudio --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m86.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m70.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m42.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import os
import csv
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchaudio
from torchaudio.datasets import SPEECHCOMMANDS
from torch.utils.data import DataLoader, Dataset
import torchaudio.transforms as T

from snntorch import spikegen, surrogate, functional as SF
import snntorch as snn

### Load & Preprocess the Speech Command Dataset

In [None]:
# MFCC front-end: 40 coefficients computed from a 40-band mel spectrogram
# (n_fft=400, hop_length=160) with the sample rate set to 16000.
transform = T.MFCC(
    sample_rate=16000,
    n_mfcc=40,
    melkwargs=dict(n_fft=400, hop_length=160, n_mels=40),
)

# Build the three Speech Commands subsets, all rooted at "./" with
# download=True, in the order training -> validation -> testing.
train_dataset, val_dataset, test_dataset = (
    SPEECHCOMMANDS("./", download=True, subset=split)
    for split in ("training", "validation", "testing")
)

# Only these ten command words are classified; a word's position in the list
# is its integer class index.
keywords = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go']
label_dict = dict(zip(keywords, range(len(keywords))))

def collate_fn(batch):
    """Turn a list of raw dataset items into a padded MFCC batch.

    Items whose label is not in ``keywords`` are dropped. For the remaining
    items the module-level ``transform`` is applied to the waveform, the
    leading dimension is squeezed away, and every feature map is zero-padded
    on the right of its last (time) dimension up to the longest one in this
    batch.

    Args:
        batch: sequence of tuples whose first three fields are
            ``(waveform, sample_rate, label)``; any further fields are ignored.

    Returns:
        ``(features, targets)``: the stacked padded feature maps and a tensor
        of class indices taken from ``label_dict``.

    Note:
        If no item in ``batch`` has a keyword label, ``torch.stack`` is called
        on an empty list and raises.
    """
    # (feature map, class index) for every item that survives the keyword filter.
    kept = [
        (transform(waveform).squeeze(0), label_dict[label])  # [n_mfcc, time]
        for waveform, _sample_rate, label, *_ in batch
        if label in keywords
    ]

    # Longest time axis among the kept items; 0 when nothing was kept.
    widest = max((feat.shape[1] for feat, _ in kept), default=0)

    X = [F.pad(feat, (0, widest - feat.shape[1])) for feat, _ in kept]
    y = [class_idx for _, class_idx in kept]

    return torch.stack(X), torch.tensor(y)

# One loader per dataset split, all sharing the same batch size and collate
# function. Only the training loader shuffles; validation and test iterate in
# dataset order. `val_loader` is required by the training loop, which
# evaluates on it after every epoch.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)


100%|██████████| 2.26G/2.26G [00:23<00:00, 104MB/s]


### Latency encoding

Encode the MFCC features as spike trains using latency (time-to-first-spike) coding.

In [None]:
# def encode_input(mfcc_batch, num_steps=100):
#     # Normalize to [0, 1]
#     data = (mfcc_batch - mfcc_batch.min()) / (mfcc_batch.max() - mfcc_batch.min())
#     # [B x C x L] → [B x L] if needed
#     data = data.mean(dim=1) if data.ndim == 3 else data
#     # Apply latency encoding
#     spk_data = spikegen.latency(data, num_steps=num_steps, normalize=True, linear=True)
#     return spk_data  # shape: [T x B x L]

def encode_input(mfcc_batch, num_steps=15):
    """Latency-encode a batch of MFCC feature maps into spike trains.

    Each sample is min-max scaled using the minimum and maximum taken over
    dims 1 and 2, then handed to ``spikegen.latency`` with linear, normalized
    spike timing.

    Args:
        mfcc_batch: 3-D feature tensor; the original comments describe it as
            ``[B, C, L]``.
        num_steps: number of time steps requested from ``spikegen.latency``.

    Returns:
        The tensor produced by ``spikegen.latency``; per the original comment
        its layout is ``[T, B, C, L]``.
    """
    per_sample_dims = (1, 2)
    lowest = mfcc_batch.amin(dim=per_sample_dims, keepdim=True)
    highest = mfcc_batch.amax(dim=per_sample_dims, keepdim=True)

    # The small epsilon keeps the division finite when a sample is constant.
    value_range = highest - lowest + 1e-7
    scaled = (mfcc_batch - lowest) / value_range

    return spikegen.latency(scaled, num_steps=num_steps, normalize=True, linear=True)

### Conv1D SNN Architecture

In [None]:
class SNNConv1D(nn.Module):
    """Two-layer 1-D convolutional spiking network with a linear readout.

    Each time step runs Conv1d -> leaky integrate-and-fire (LIF) twice, then
    flattens the second layer's spikes and feeds them to a linear layer. The
    readout is averaged over all time steps and returned as class scores.

    Args:
        num_classes: size of the output layer.
        num_steps: number of simulation time steps consumed from the input.
        seq_len: length of the time axis of each input frame. Both
            convolutions use kernel 5 / padding 2 / stride 1 and therefore
            keep this length, so the readout expects ``64 * seq_len`` features.
    """

    def __init__(self, num_classes=10, num_steps=100, seq_len=101):
        super().__init__()
        beta = 0.9  # LIF membrane decay constant

        self.conv1 = nn.Conv1d(40, 32, kernel_size=5, stride=1, padding=2)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid())

        self.conv2 = nn.Conv1d(32, 64, kernel_size=5, stride=1, padding=2)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid())

        self.fc1 = nn.Linear(64 * seq_len, num_classes)
        self.num_steps = num_steps

    def forward(self, x):
        """Simulate ``self.num_steps`` steps and return time-averaged scores.

        Args:
            x: spike tensor indexed by time step on dim 0; each ``x[step]`` is
                passed to a ``Conv1d`` with 40 input channels, so it is
                expected to be ``[B, 40, seq_len]``.

        Returns:
            Tensor of shape ``[B, num_classes]``: the mean of the linear
            readout over the simulated steps.
        """
        # Fresh membrane potentials for every forward pass.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        readout_sum = 0

        for step in range(self.num_steps):
            input_t = x[step]

            # Pass the membrane state in and take the updated state back out,
            # so integration across time steps is explicit and does not rely
            # on state kept inside the neuron modules.
            cur1 = self.conv1(input_t)
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = self.conv2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            readout_sum = readout_sum + self.fc1(spk2.flatten(start_dim=1))

        return readout_sum / self.num_steps  # soft output averaged across time


### Training Loop

In [None]:
# Run on CUDA when it is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# One output unit per keyword; the model is moved to the chosen device before
# the optimizer collects its parameters.
model = SNNConv1D(num_classes=len(keywords))
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train_epoch(model, loader):
    """Train ``model`` for one pass over ``loader``.

    Uses the module-level ``encode_input``, ``device``, ``optimizer`` and
    ``loss_fn``.

    Args:
        model: network exposing ``num_steps``; called on the encoded spikes.
        loader: iterable yielding ``(features, targets)`` batches.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over batches and the
        fraction of correctly classified samples among the samples actually
        yielded by ``loader``. Both are ``0.0`` if the loader yields nothing.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    seen = 0
    num_batches = 0

    for x, y in loader:
        spk_x = encode_input(x, num_steps=model.num_steps).to(device)  # latency (TTFS) encoding
        y = y.to(device)

        optimizer.zero_grad()
        out = model(spk_x)
        loss = loss_fn(out, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (out.argmax(dim=1) == y).sum().item()
        # Count what the loader really produced: its collate function may drop
        # samples, so len(loader.dataset) would over-count and deflate accuracy.
        seen += y.size(0)
        num_batches += 1

    mean_loss = total_loss / num_batches if num_batches else 0.0
    acc = correct / seen if seen else 0.0
    return mean_loss, acc


### Train & Evaluate

In [None]:
def evaluate(model, loader):
    """Compute loss and accuracy of ``model`` on ``loader`` without training.

    Uses the module-level ``encode_input``, ``device`` and ``loss_fn``; runs
    under ``torch.no_grad()`` with the model in eval mode.

    Args:
        model: network exposing ``num_steps``; called on the encoded spikes.
        loader: iterable yielding ``(features, targets)`` batches.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over batches and the
        fraction of correctly classified samples among the samples actually
        yielded by ``loader``. Both are ``0.0`` if the loader yields nothing.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    num_batches = 0

    with torch.no_grad():
        for x, y in loader:
            spk_x = encode_input(x, num_steps=model.num_steps).to(device)
            y = y.to(device)

            out = model(spk_x)
            loss = loss_fn(out, y)

            total_loss += loss.item()
            correct += (out.argmax(dim=1) == y).sum().item()
            # Count what the loader really produced: its collate function may
            # drop samples, so len(loader.dataset) would deflate accuracy.
            seen += y.size(0)
            num_batches += 1

    mean_loss = total_loss / num_batches if num_batches else 0.0
    acc = correct / seen if seen else 0.0
    return mean_loss, acc

In [None]:
csv_filename = "snn_mfcc_log.csv"
model_dir = "checkpoints"
os.makedirs(model_dir, exist_ok=True)

# Only the best-so-far weights are kept, always under the same file name.
save_model_path = os.path.join(model_dir, "snn_conv1d_mfcc_model.pth")

# Start a fresh log: mode 'w' truncates any previous file before the header.
with open(csv_filename, mode='w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(["epoch", "train_loss", "train_acc", "val_loss", "val_acc"])

best_val_acc = 0.0
for epoch in range(10):
    train_loss, train_acc = train_epoch(model, train_loader)
    val_loss, val_acc = evaluate(model, val_loader)

    # Append one row per epoch; reopening the file each time means completed
    # epochs are already written out if a later epoch is interrupted.
    with open(csv_filename, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([epoch, train_loss, train_acc, val_loss, val_acc])

    # Single per-epoch summary, using the same 0-based epoch index as the CSV.
    print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Acc={train_acc*100:.2f}% | Val Loss={val_loss:.4f}, Acc={val_acc*100:.2f}%")

    # Checkpoint only when validation accuracy improves on the best seen so far.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), save_model_path)
        print(f"Model saved at epoch {epoch} with Val Acc={val_acc*100:.2f}%")

    # Stop early once training accuracy exceeds 90%.
    if train_acc > 0.90:
        print("Target accuracy reached!")
        break
