# Let's try framed raw audio on MLP

The model is not learning anything! \
I expected this model to emulate a Conv1D by dividing the audio into frames.

In [5]:
!pip install snntorch --quiet
!pip install torchaudio --quiet

In [11]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchaudio
from torchaudio.datasets import SPEECHCOMMANDS
from torch.utils.data import DataLoader, Dataset
import torchaudio.transforms as T

from snntorch import spikegen, surrogate, functional as SF
import snntorch as snn

### Load and Frame the Audio Data

* Sampling Rate: The original Speech Commands dataset has a sampling rate of 16,000 Hz. In this script, we resample it to 8,000 Hz to reduce computational load.

* Normalization: Normalizing each frame helps stabilize training, especially when dealing with raw audio inputs.

* Label Mapping: The script builds a label-to-index mapping to convert string labels into integer indices suitable for training.





* 400 features per timestep
* 20 timesteps
* total features in 1 second = 400 * 20 = 8000

In [7]:
# Pick the compute device: CUDA when a GPU is usable, CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [8]:
 # Custom dataset class to handle framing
class FramedSpeechCommands(Dataset):
    """Speech Commands clips resampled to 8 kHz and cut into fixed-size frames.

    Each item is a ``(frames, label_idx)`` pair. ``frames`` has shape
    ``[num_frames, frame_size]`` (non-overlapping windows of the clip, each
    standardised to zero mean and unit standard deviation) and ``label_idx``
    is the integer index of the clip's label in ``label_to_index``.

    Args:
        subset: Split name forwarded to ``SPEECHCOMMANDS`` (e.g. ``'training'``).
        frame_size: Number of samples per frame. Defaults to 400. Must be
            greater than 1, otherwise the per-frame standard deviation is
            undefined.
        num_frames: Number of frames kept per clip. Defaults to 20, so a clip
            is ``400 * 20 = 8000`` samples long by default.
        label_to_index: Optional existing ``{label: index}`` mapping to reuse
            (e.g. the one built for the training split) so several splits
            share identical class indices. When omitted, the mapping is built
            from the labels present in this split.
    """

    def __init__(self, subset, frame_size=400, num_frames=20, label_to_index=None):
        self.dataset = SPEECHCOMMANDS(root="./", download=True, subset=subset)
        self.resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=8000)
        self.frame_size = frame_size  # samples per frame
        self.num_frames = num_frames  # frames per clip
        if label_to_index is None:
            label_to_index = self._build_label_index()
        self.label_to_index = label_to_index

    def _build_label_index(self):
        """Return ``{label: index}`` with labels numbered in sorted order.

        Labels are read through the wrapped dataset's ``get_metadata`` when it
        provides one, which avoids decoding every audio file just to obtain
        its label. Without it, the method falls back to iterating the dataset
        items themselves.
        """
        read_metadata = getattr(self.dataset, "get_metadata", None)
        if read_metadata is not None:
            labels = {read_metadata(i)[2] for i in range(len(self.dataset))}
        else:
            labels = {datapoint[2] for datapoint in self.dataset}
        return {label: idx for idx, label in enumerate(sorted(labels))}

    def __len__(self):
        """Number of clips in the wrapped split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load clip ``idx`` and return ``(frames, label_idx)``.

        Raises:
            KeyError: If the clip's label is missing from ``label_to_index``.
        """
        waveform, sample_rate, label, *_ = self.dataset[idx]
        waveform = self.resample(waveform)
        waveform = waveform.squeeze(0)  # Convert from [1, N] to [N]

        # Force every clip to exactly frame_size * num_frames samples:
        # zero-pad short clips on the right, truncate long ones.
        target_len = self.frame_size * self.num_frames
        n_samples = waveform.size(0)
        if n_samples < target_len:
            waveform = torch.nn.functional.pad(waveform, (0, target_len - n_samples))
        else:
            waveform = waveform[:target_len]

        # Non-overlapping windows (step == size) -> [num_frames, frame_size]
        frames = waveform.unfold(0, self.frame_size, self.frame_size)

        # Standardise each frame independently; the epsilon keeps all-constant
        # (e.g. fully padded) frames from dividing by zero.
        frames = (frames - frames.mean(dim=1, keepdim=True)) / (frames.std(dim=1, keepdim=True) + 1e-5)

        label_idx = self.label_to_index[label]
        return frames, label_idx

# Build the training split and a shuffling mini-batch loader over it
train_dataset = FramedSpeechCommands(subset='training')
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)

# Smoke test: fetch only the first batch and report its tensor shapes
# (frames: [batch_size, 20, 400], labels: [batch_size])
batch_idx = 0
frames, labels = next(iter(train_loader))
print(f"Batch {batch_idx}:")
print(f"Frames shape: {frames.shape}")
print(f"Labels shape: {labels.shape}")


100%|██████████| 2.26G/2.26G [00:24<00:00, 101MB/s] 


Batch 0:
Frames shape: torch.Size([64, 20, 400])
Labels shape: torch.Size([64])


### MLP SNN Network

In [9]:
# 3. MLP-SNN Model
beta = 0.95  # default membrane decay rate used for every LIF layer
# spike_grad = surrogate.fast_sigmoid()  # optional alternative surrogate gradient

class MLP_SNN(nn.Module):
    """Fully connected spiking network: three Linear -> Leaky (LIF) stages.

    Layer widths are ``input_size -> 512 -> 256 -> output_size``.

    Args:
        input_size: Number of features fed to the first layer at each time step.
        output_size: Number of output neurons (one per class).
        beta: Membrane decay rate passed to every ``snn.Leaky`` layer.
            Defaults to the module-level ``beta`` (0.95).
        spike_grad: Optional surrogate-gradient function passed to every
            ``snn.Leaky`` layer; ``None`` keeps snnTorch's default.
    """

    def __init__(self, input_size, output_size, beta=beta, spike_grad=None):
        super().__init__()

        self.fc1 = nn.Linear(input_size, 512)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)

        self.fc2 = nn.Linear(512, 256)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)

        # Size the read-out layer from the constructor argument rather than a
        # notebook-level global, so the class works regardless of cell order.
        self.fc3 = nn.Linear(256, output_size)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x):
        """Run the network over every time step of ``x``.

        Args:
            x: Input whose first dimension is time; ``x[t]`` is fed to the
                first linear layer at step ``t`` (the training loop passes a
                ``[T, B, F]`` tensor).

        Returns:
            Output-layer spikes of all time steps stacked along dim 0.
        """
        # Membrane potentials start fresh on every forward pass
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        spk_out = []

        for t in range(x.size(0)):
            x_t = x[t]
            cur1 = self.fc1(x_t)
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = self.fc3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)

            spk_out.append(spk3)

        return torch.stack(spk_out, dim=0)


### Training

In [13]:
# Loss applied to the output spikes, and the number of passes over the data
loss_fn = SF.ce_rate_loss()
num_epochs = 10

# 4. Training
def train():
    """Train the global ``model`` on ``train_loader`` for ``num_epochs`` epochs.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_fn``, ``device``
    and ``train_loader``. After each epoch, prints the mean per-batch loss and
    the mean per-batch accuracy (as a percentage).
    """
    model.train()
    n_batches = len(train_loader)

    for epoch in range(num_epochs):
        running_loss = 0
        running_acc = 0

        for batch, targets in train_loader:
            # Loader yields [B, T, F]; the model steps over dim 0, so go time-major
            batch = batch.permute(1, 0, 2).float().to(device)
            targets = targets.to(device)

            spk_rec = model(batch)
            batch_loss = loss_fn(spk_rec, targets)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()
            running_acc += SF.accuracy_rate(spk_rec, targets)

        avg_loss = running_loss / n_batches
        avg_acc = 100 * running_acc / n_batches
        print(f"--------Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.2f}%--------")


In [14]:
# Size the network from the dataset instead of repeating literals
num_classes = len(train_dataset.label_to_index)  # one output neuron per label
model = MLP_SNN(input_size=train_dataset.frame_size, output_size=num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
train()

--------Epoch 1/10, Loss: 3.5558, Accuracy: 2.16%--------
--------Epoch 2/10, Loss: 3.5554, Accuracy: 2.17%--------
--------Epoch 3/10, Loss: 3.5546, Accuracy: 2.44%--------
--------Epoch 4/10, Loss: 3.5536, Accuracy: 2.45%--------


KeyboardInterrupt: 

### Testing

In [None]:
def test():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.permute(1, 0, 2).float().to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.mean(0).max(1)  # Mean over time, shape [B, num_classes]
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    print(f"Test Accuracy: {100 * correct / total:.2f}%")
