<a href="https://colab.research.google.com/github/Maya7991/gsc_classification/blob/main/gsc_mel_snn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

import torchaudio
from torchaudio.datasets import SPEECHCOMMANDS
import torchaudio.transforms as T

import torch
import torch.nn as nn

import snntorch as snn
from snntorch import surrogate
import snntorch.functional as SF


In [None]:


class SubsetSC(SPEECHCOMMANDS):
    """Speech Commands dataset restricted to one of its official splits.

    The dataset is downloaded into the current directory ("./") if missing.
    The validation and testing splits are read from ``validation_list.txt``
    and ``testing_list.txt`` in the dataset folder; the training split is
    every clip of the parent dataset listed in neither file.

    Args:
        subset: ``"training"``, ``"validation"`` or ``"testing"``. Any other
            value, including the default ``None``, keeps every clip the
            parent dataset lists.
    """

    def __init__(self, subset: str = None):
        super().__init__("./", download=True)

        def load_list(filename):
            """Return normalised absolute-ish paths for every entry in a split file."""
            list_path = os.path.join(self._path, filename)
            with open(list_path, encoding="utf-8") as fileobj:
                # Entries are relative to the dataset root. normpath puts them
                # in canonical form (e.g. drops a leading "./" coming from the
                # "./" root) so they compare equal to the parent's file list.
                return [
                    os.path.normpath(os.path.join(self._path, line.strip()))
                    for line in fileobj
                    if line.strip()  # ignore blank lines
                ]

        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = set(
                load_list("validation_list.txt") + load_list("testing_list.txt")
            )
            # Normalise the parent's entries too; a format mismatch here would
            # otherwise exclude nothing and leak held-out clips into training.
            self._walker = [
                w for w in self._walker if os.path.normpath(w) not in excludes
            ]


In [None]:

# Mel front-end: 16 kHz audio -> 64 mel bands; every other MelSpectrogram
# setting (n_fft, hop_length, ...) is left at torchaudio's default.
transform = T.MelSpectrogram(sample_rate=16000, n_mels=64)


def preprocess(waveform):
    """Convert mono waveform(s) to a mel spectrogram without the channel axis.

    Args:
        waveform: ``[1, samples]`` for one mono clip or ``[batch, 1, samples]``
            for a batch of mono clips (assumed layouts — confirm with caller).

    Returns:
        ``[n_mels, frames]`` for a single clip, ``[batch, n_mels, frames]``
        for a batch.
    """
    spec = transform(waveform)  # [..., channel, n_mels, frames]
    # The mono channel axis sits just before the mel axis. squeeze(0) would
    # only remove it when the batch size is 1, so squeeze dim -3 instead.
    if spec.dim() >= 3 and spec.size(-3) == 1:
        spec = spec.squeeze(-3)
    return spec


In [None]:


def encode_input(mel_spec, time_steps=100, normalize=True):
    """Rate-encode a spectrogram into a Bernoulli spike train.

    Each element is treated as a firing probability and, at every time step,
    an independent Bernoulli sample is drawn (the same scheme as
    ``snntorch.spikegen.rate``).

    Args:
        mel_spec: Non-negative spectrogram tensor of any shape, e.g.
            ``[n_mels, frames]`` or ``[batch, n_mels, frames]``.
        time_steps: Number of simulation steps to generate.
        normalize: If True, log-compress ``mel_spec`` and min-max scale it
            into ``[0, 1]`` before sampling. Pass False if the input is
            already a probability tensor.

    Returns:
        Tensor of 0/1 spikes with shape ``[time_steps, *mel_spec.shape]``.
    """
    probs = mel_spec if mel_spec.is_floating_point() else mel_spec.float()
    if normalize and probs.numel() > 0:
        # Mel power spans many orders of magnitude; log1p compresses it so
        # quieter bins still get a usable firing rate after scaling.
        probs = torch.log1p(probs.clamp_min(0))
        lo, hi = probs.min(), probs.max()
        # NOTE: one min/max for the whole tensor, so clips in a batch share a
        # scale. clamp_min avoids 0/0 for constant input (-> no spikes).
        probs = (probs - lo) / (hi - lo).clamp_min(1e-12)
    # Bernoulli sampling requires probabilities in [0, 1].
    probs = probs.clamp(0.0, 1.0)
    return torch.bernoulli(probs.expand(time_steps, *probs.shape).contiguous())


In [None]:


class SNNNet(nn.Module):
    """Two-layer fully connected spiking network of leaky integrate-and-fire neurons.

    Args:
        num_inputs: Flattened feature count of one time step of input. The
            default 64 * 32 assumes a 64 x 32 spectrogram per step.
        num_hidden: Width of the hidden spiking layer.
        num_outputs: Number of output neurons (35 Speech Commands classes).
        beta: Membrane decay factor shared by both ``snn.Leaky`` layers.

    NOTE(review): assuming torchaudio's default MelSpectrogram framing
    (n_fft=400, hop_length=200, centred), a 1 s 16 kHz clip yields 81 frames,
    i.e. 64 * 81 features per step rather than 64 * 32 — confirm and pass
    ``num_inputs`` accordingly.
    """

    def __init__(self, num_inputs=64 * 32, num_hidden=512, num_outputs=35, beta=0.95):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid())
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid())

    def forward(self, x):
        """Simulate the network over every step of the leading time axis.

        Args:
            x: Input tensor ``[time, batch, ...]``; everything after the batch
               axis is flattened to ``num_inputs`` features.

        Returns:
            Output-layer spikes stacked over time: ``[time, batch, num_outputs]``.
        """
        # Membrane state is re-initialised on every call, so no state carries
        # over between batches.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        spk2_rec = []

        for x_t in x:  # one slice per time step
            # flatten (unlike view) also accepts non-contiguous slices.
            cur1 = self.fc1(x_t.flatten(start_dim=1))
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)

        return torch.stack(spk2_rec)


In [None]:


def train_batch(inputs, labels, model, optimizer, loss_fn):
    """Perform a single optimisation step on one batch.

    The model output is summed over its leading (time) axis to obtain
    per-class spike counts, which ``loss_fn`` scores against ``labels``.
    Gradients are cleared, back-propagated and applied with ``optimizer``.

    Returns:
        The batch loss as a Python float, measured before the update.
    """
    model.train()
    optimizer.zero_grad()
    # Collapse dim 0 (time) -> spike count per class.
    spike_counts = model(inputs).sum(dim=0)
    batch_loss = loss_fn(spike_counts, labels)
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()


In [None]:
def evaluate(model, dataloader, time_steps=100, device=None):
    """Return the top-1 accuracy of ``model`` over ``dataloader``.

    Each batch ``(data, label)`` is turned into a spectrogram with
    ``preprocess``, rate-encoded with ``encode_input`` and run through the
    model. The predicted class is the output neuron with the most spikes
    summed over time.

    Args:
        model: Network returning spikes shaped ``[time, batch, classes]``.
            It is left in eval mode afterwards.
        dataloader: Iterable of ``(waveform_batch, label_batch)`` pairs.
            ``label_batch`` is assumed to be an integer class-index tensor
            (raw Speech Commands labels are strings — confirm the collate step).
        time_steps: Number of encoding time steps per sample.
        device: Device to evaluate on. Defaults to the device of the model's
            first parameter, or CPU if the model has none.

    Returns:
        Fraction of correctly classified samples, in ``[0, 1]``.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for data, label in dataloader:
            # preprocess uses the module-level mel transform, so run it where
            # the data already is; move the smaller spectrogram, not the spikes.
            mel = preprocess(data).to(device)
            encoded = encode_input(mel, time_steps=time_steps)
            label = label.to(device)
            spk_out = model(encoded)
            pred = spk_out.sum(dim=0).argmax(dim=1)
            correct += (pred == label).sum().item()
            total += label.size(0)

    if total == 0:
        raise ValueError("evaluate() received no samples from dataloader")
    return correct / total
