In [None]:
import torch
import torch.nn as nn
import torchaudio
import snntorch as snn
from torch.utils.data import Dataset, DataLoader
from snntorch import surrogate
import matplotlib.pyplot as plt
import os

# Hyperparameters
batch_size = 64
num_steps = 100  # 1 second at 10 ms frames
num_inputs = 1300  # 13 MFCCs × 100 frames
num_hidden = 100
num_outputs = 10  # 10 command words
learning_rate = 1e-3
num_epochs = 15
patience = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Check dataset
dataset_dir = "speech_commands_v0.02"
if not os.path.exists(dataset_dir):
    print("Please download speech_commands_v0.02.tar.gz from http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz and extract to 'speech_commands_v0.02'")
    exit()

# Custom Dataset for Google Speech Commands
class SpeechCommandsDataset(Dataset):
    def __init__(self, root="speech_commands_v0.02", train=True, target_words=["yes", "no", "up", "down", "left", "right", "on", "off", "stop", "go"]):
        self.root = root
        self.train = train
        self.target_words = target_words
        self.transform = torchaudio.transforms.MFCC(
            sample_rate=16000,
            n_mfcc=13,
            melkwargs={
                'n_fft': 400,
                'hop_length': 160,
                'f_min': 20,
                'f_max': 4000,
                'n_mels': 40
            }
        )
        self.data = []
        self.labels = []
        for word in os.listdir(root):
            if word in target_words:
                word_dir = os.path.join(root, word)
                for file in os.listdir(word_dir):
                    if file.endswith(".wav"):
                        file_path = os.path.join(word_dir, file)
                        label = target_words.index(word)
                        self.data.append(file_path)
                        self.labels.append(label)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        file_path, label = self.data[idx], self.labels[idx]
        try:
            waveform, sample_rate = torchaudio.load(file_path)
            if waveform.numel() == 0:
                raise ValueError("Empty waveform")
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                waveform = resampler(waveform)
            mfcc = self.transform(waveform)
            if mfcc.dim() == 3 and mfcc.shape[0] == 1:
                mfcc = mfcc.squeeze(0)  # Shape: [13, num_frames]
            num_frames = mfcc.shape[1]
            if num_frames < 100:
                pad_width = 100 - num_frames
                mfcc = torch.nn.functional.pad(mfcc, (0, pad_width))
            elif num_frames > 100:
                mfcc = mfcc[:, :100]
            flattened = mfcc.flatten()  # Shape: [1300]
            return flattened, label
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
            return torch.zeros(1300), -1  # Dummy tensor and label

# Custom collate function
def collate_fn(batch):
    data = [item[0] for item in batch]
    targets = [item[1] for item in batch]
    valid_indices = [i for i, target in enumerate(targets) if target != -1]
    if not valid_indices:
        return torch.zeros(1, 1300), torch.zeros(1, dtype=torch.long)
    data = [data[i] for i in valid_indices]
    targets = [targets[i] for i in valid_indices]
    data = torch.stack(data)
    targets = torch.tensor(targets)
    return data, targets

# Load datasets
train_dataset = SpeechCommandsDataset(train=True)
test_dataset = SpeechCommandsDataset(train=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# Surrogate gradient
spike_grad = surrogate.fast_sigmoid(slope=25)

# Baseline SNN
class BaselineSNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif1 = snn.Leaky(beta=0.9, spike_grad=spike_grad)
        self.lif2 = snn.Leaky(beta=0.9, spike_grad=spike_grad)
        for layer in [self.fc1, self.fc2]:
            nn.init.xavier_uniform_(layer.weight, gain=0.1)

    def forward(self, x):
        batch_size = x.size(1)
        mem1, mem2 = self.lif1.init_leaky(), self.lif2.init_leaky()
        spk1_rec, spk2_rec = [], []
        for step in range(num_steps):
            cur1 = self.fc1(x[step])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk1_rec.append(spk1)
            spk2_rec.append(spk2)
        return torch.stack(spk1_rec, dim=0), torch.stack(spk2_rec, dim=0)

# Instantiate model and optimizer
net = BaselineSNN().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.2)
criterion = nn.CrossEntropyLoss()

# Training loop
best_val_accuracy = 0.0
epochs_no_improve = 0
accuracies, val_accuracies, losses, total_spikes_list = [], [], [], []
for epoch in range(num_epochs):
    net.train()
    total_spikes = 0
    for data, targets in train_loader:
        valid_indices = targets != -1
        if not valid_indices.any():
            continue
        data = data[valid_indices]
        targets = targets[valid_indices]
        data, targets = data.to(device), targets.to(device)
        batch_size_actual = data.size(0)
        if data.max() > 0:
            data = data / data.max()
        spike_data = (torch.rand(num_steps, batch_size_actual, num_inputs, device=device) < data * 0.2).float()
        spk1_rec, outputs = net(spike_data)
        spk_count = outputs.sum(dim=0)
        loss = criterion(spk_count, targets)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=3.0)
        optimizer.step()
        correct = (spk_count.argmax(dim=1) == targets).sum().item()
        accuracy = correct / batch_size_actual
        total_spikes += (spk1_rec.sum() + outputs.sum()).item()
    scheduler.step()
    losses.append(loss.item())
    accuracies.append(accuracy)

    # Validation
    net.eval()
    total_correct = 0
    total_samples = 0
    val_total_spikes = 0
    with torch.no_grad():
        for data, targets in test_loader:
            valid_indices = targets != -1
            if not valid_indices.any():
                continue
            data = data[valid_indices]
            targets = targets[valid_indices]
            data, targets = data.to(device), targets.to(device)
            batch_size_actual = data.size(0)
            if data.max() > 0:
                data = data / data.max()
            spike_data = (torch.rand(num_steps, batch_size_actual, num_inputs, device=device) < data * 0.2).float()
            spk1_rec, outputs = net(spike_data)
            spk_count = outputs.sum(dim=0)
            total_correct += (spk_count.argmax(dim=1) == targets).sum().item()
            total_samples += batch_size_actual
            val_total_spikes += (spk1_rec.sum() + outputs.sum()).item()
        val_accuracy = total_correct / total_samples if total_samples > 0 else 0.0
        val_accuracies.append(val_accuracy)
        total_spikes_list.append(val_total_spikes)
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Accuracy: {accuracy:.4f}, Validation Accuracy: {val_accuracy:.4f}, Total Spikes: {val_total_spikes:.0f}")

    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch + 1} epochs")
            break

# Plot results
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.plot(accuracies, label="Train Accuracy")
plt.plot(val_accuracies, label="Validation Accuracy")
plt.title("Accuracy")
plt.legend()
plt.subplot(1, 3, 2)
plt.plot(losses, label="Loss")
plt.title("Loss")
plt.legend()
plt.subplot(1, 3, 3)
plt.plot(total_spikes_list, label="Total Spikes")
plt.title("Total Spikes")
plt.legend()
plt.tight_layout()
plt.savefig("baseline_snn_results.png")
plt.show()

# Final summary
final_val_accuracy = max(val_accuracies) if val_accuracies else 0.0
final_total_spikes = total_spikes_list[val_accuracies.index(final_val_accuracy)] if val_accuracies else 0
energy = final_total_spikes * 20 / 1e6  # µJ
print("\nBaseline SNN Final Results:")
print(f"Peak Validation Accuracy: {final_val_accuracy:.2%}")
print(f"Total Spikes at Peak Accuracy: {final_total_spikes:.0f}")
print(f"Energy Consumption: {energy:.2f} µJ")

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
from torch.utils.data import DataLoader
import random

# Set random seeds
torch.manual_seed(42)
random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Constants
data_path = './speechcommands_data'
target_classes = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go']
class_to_idx = {label: idx for idx, label in enumerate(target_classes)}
FIXED_MFCC_LENGTH = 100

# Preprocessing
def preprocess(waveform, sample_rate):
    transform = torchaudio.transforms.MFCC(
        sample_rate=sample_rate,
        n_mfcc=20,
        melkwargs={'n_fft': 400, 'hop_length': 160, 'n_mels': 64}
    )
    mfcc = transform(waveform).squeeze(0)
    time_steps = mfcc.shape[1]

    if time_steps < FIXED_MFCC_LENGTH:
        pad_amt = FIXED_MFCC_LENGTH - time_steps
        mfcc = F.pad(mfcc, (0, pad_amt))
    else:
        mfcc = mfcc[:, :FIXED_MFCC_LENGTH]

    return mfcc

# Dataset
class SubsetSC(torch.utils.data.Dataset):
    def __init__(self, samples):
        self.samples = [(waveform, sample_rate, label) for waveform, sample_rate, label, *_ in samples if label in target_classes]

    def __getitem__(self, index):
        waveform, sample_rate, label = self.samples[index]
        mfcc = preprocess(waveform, sample_rate)
        return mfcc, class_to_idx[label]

    def __len__(self):
        return len(self.samples)

# Load Data
full_train_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_path, download=True, subset='training')
full_val_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_path, download=True, subset='validation')

train_dataset = SubsetSC(full_train_dataset)
val_dataset = SubsetSC(full_val_dataset)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, drop_last=False)

# 🔥 Improved LIF Neuron with sharper surrogate and optional dropout
class LIFNeuron(nn.Module):
    def __init__(self, tau=2.0, dropout=0.0):
        super().__init__()
        self.tau = tau
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x, mem):
        mem = mem * self.tau + x
        spk = self.surrogate_spike(mem - 1.0)
        mem = mem * (1.0 - spk)
        if self.dropout is not None:
            spk = self.dropout(spk)
        return spk, mem

    def surrogate_spike(self, x):
        return torch.sigmoid(10 * x)  # 🔥 steeper surrogate

    def init_mem(self, batch_size, features, device):
        return torch.zeros(batch_size, features, device=device)

# 🔥 Improved SNN Model with LayerNorm
class SNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)  # 🔥 normalization
        self.lif1 = LIFNeuron(dropout=0.2)       # 🔥 slight dropout
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.lif2 = LIFNeuron()

    def forward(self, x):
        batch_size, steps, features = x.shape
        spk2_rec = []
        mem1 = self.lif1.init_mem(batch_size, self.fc1.out_features, x.device)
        mem2 = self.lif2.init_mem(batch_size, self.fc2.out_features, x.device)

        for step in range(steps):
            cur1 = self.fc1(x[:, step, :])
            cur1 = self.norm1(cur1)        # 🔥 normalize activations
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)

        return torch.stack(spk2_rec)

# 🔥 Rate encoding instead of Poisson
def rate_encode(x, num_steps):
    x = x.unsqueeze(0).expand(num_steps, *x.shape)
    return (torch.rand_like(x) < x).float()

# Config
input_size = 20 * FIXED_MFCC_LENGTH
hidden_size = 512     # 🔥 bigger hidden size
output_size = 10
num_epochs = 15       # 🔥 longer training
num_steps = 20

net = SNN(input_size, hidden_size, output_size).to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.5)
criterion = nn.CrossEntropyLoss()

# Train
for epoch in range(num_epochs):
    net.train()
    running_loss, correct, total_spikes = 0, 0, 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        inputs = inputs / inputs.max()
        inputs = inputs.view(inputs.size(0), inputs.size(2), inputs.size(1))

        spk_in = rate_encode(inputs, num_steps)
        spk_in = spk_in.permute(1, 0, 2, 3)
        spk_in = spk_in.reshape(spk_in.size(0), spk_in.size(1), -1)

        spk2_rec = net(spk_in)
        spk_count = spk2_rec.sum(0) / num_steps

        loss = criterion(spk_count, targets)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        correct += (spk_count.argmax(1) == targets).sum().item()
        total_spikes += spk2_rec.sum().item()

    scheduler.step()

    train_acc = correct / len(train_loader.dataset)
    avg_loss = running_loss / len(train_loader)
    avg_spike_rate = total_spikes / (len(train_loader.dataset) * num_steps * hidden_size)

    # Validation
    net.eval()
    correct = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            inputs = inputs / inputs.max()
            inputs = inputs.view(inputs.size(0), inputs.size(2), inputs.size(1))

            spk_in = rate_encode(inputs, num_steps)
            spk_in = spk_in.permute(1, 0, 2, 3)
            spk_in = spk_in.reshape(spk_in.size(0), spk_in.size(1), -1)

            spk2_rec = net(spk_in)
            spk_count = spk2_rec.sum(0) / num_steps

            correct += (spk_count.argmax(1) == targets).sum().item()

    val_acc = correct / len(val_loader.dataset)

    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {train_acc:.4f}, Validation Accuracy: {val_acc:.4f}, Avg Spike Rate: {avg_spike_rate:.4f}")


Epoch 1, Loss: 2.3626, Accuracy: 0.1042, Validation Accuracy: 0.1059, Avg Spike Rate: 0.0047
Epoch 2, Loss: 2.3046, Accuracy: 0.1070, Validation Accuracy: 0.1118, Avg Spike Rate: 0.0001
Epoch 3, Loss: 2.3026, Accuracy: 0.1069, Validation Accuracy: 0.1110, Avg Spike Rate: 0.0000
Epoch 4, Loss: 2.3026, Accuracy: 0.1054, Validation Accuracy: 0.1080, Avg Spike Rate: 0.0000


KeyboardInterrupt: 

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
from torch.utils.data import DataLoader
import random

# Set random seeds
torch.manual_seed(42)
random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Constants
data_path = './speechcommands_data'
target_classes = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go']
class_to_idx = {label: idx for idx, label in enumerate(target_classes)}
FIXED_MFCC_LENGTH = 100

# Preprocessing with proper normalization
def preprocess(waveform, sample_rate):
    transform = torchaudio.transforms.MFCC(
        sample_rate=sample_rate,
        n_mfcc=20,
        melkwargs={'n_fft': 400, 'hop_length': 160, 'n_mels': 64}
    )
    mfcc = transform(waveform).squeeze(0)
    time_steps = mfcc.shape[1]

    if time_steps < FIXED_MFCC_LENGTH:
        pad_amt = FIXED_MFCC_LENGTH - time_steps
        mfcc = F.pad(mfcc, (0, pad_amt))
    else:
        mfcc = mfcc[:, :FIXED_MFCC_LENGTH]

    # Shift and scale MFCCs to [0, 1]
    mfcc = (mfcc - mfcc.min()) / (mfcc.max() - mfcc.min() + 1e-8)
    return mfcc

# Dataset
class SubsetSC(torch.utils.data.Dataset):
    def __init__(self, samples):
        self.samples = [(waveform, sample_rate, label) for waveform, sample_rate, label, *_ in samples if label in target_classes]

    def __getitem__(self, index):
        waveform, sample_rate, label = self.samples[index]
        mfcc = preprocess(waveform, sample_rate)
        return mfcc, class_to_idx[label]

    def __len__(self):
        return len(self.samples)

# Load Data
full_train_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_path, download=True, subset='training')
full_val_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_path, download=True, subset='validation')

train_dataset = SubsetSC(full_train_dataset)
val_dataset = SubsetSC(full_val_dataset)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, drop_last=False)

# LIF Neuron with sharper surrogate
class LIFNeuron(nn.Module):
    def __init__(self, tau=2.0):
        super().__init__()
        self.tau = tau

    def forward(self, x, mem):
        mem = mem * self.tau + x
        spk = self.surrogate_spike(mem - 1.0)
        mem = mem * (1.0 - spk)
        return spk, mem

    def surrogate_spike(self, x):
        return torch.sigmoid(10 * x)  # Steeper surrogate

    def init_mem(self, batch_size, features, device):
        return torch.zeros(batch_size, features, device=device)

# Simplified SNN Model without LayerNorm or Dropout
class SNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.lif1 = LIFNeuron()
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.lif2 = LIFNeuron()

    def forward(self, x):
        batch_size, steps, features = x.shape
        spk2_rec = []
        mem1 = self.lif1.init_mem(batch_size, self.fc1.out_features, x.device)
        mem2 = self.lif2.init_mem(batch_size, self.fc2.out_features, x.device)

        for step in range(steps):
            cur1 = self.fc1(x[:, step, :])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)

        return torch.stack(spk2_rec)

# Rate encoding with increased time steps
num_steps = 50  # Increased from 20 to 50
def rate_encode(x, num_steps):
    x = x.unsqueeze(0).expand(num_steps, *x.shape)
    return (torch.rand_like(x) < x).float()

# Config
input_size = 20 * FIXED_MFCC_LENGTH
hidden_size = 256  # Adjusted hidden size
output_size = 10
num_epochs = 15
net = SNN(input_size, hidden_size, output_size).to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7)  # Slower decay
criterion = nn.CrossEntropyLoss()

# Train
for epoch in range(num_epochs):
    net.train()
    running_loss, correct, total_spikes = 0, 0, 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        inputs = inputs.view(inputs.size(0), inputs.size(2), inputs.size(1))  # [batch_size, num_frames, num_mfcc]

        spk_in = rate_encode(inputs, num_steps)  # [num_steps, batch_size, num_frames, num_mfcc]
        spk_in = spk_in.permute(1, 0, 2, 3)  # [batch_size, num_steps, num_frames, num_mfcc]
        spk_in = spk_in.reshape(spk_in.size(0), spk_in.size(1), -1)  # [batch_size, num_steps, input_size]

        spk2_rec = net(spk_in)
        spk_count = spk2_rec.sum(0)  # Sum over time steps: [batch_size, num_classes]

        loss = criterion(spk_count, targets)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        correct += (spk_count.argmax(1) == targets).sum().item()
        total_spikes += spk2_rec.sum().item()

    scheduler.step()

    train_acc = correct / len(train_loader.dataset)
    avg_loss = running_loss / len(train_loader)
    avg_spike_rate = total_spikes / (len(train_loader.dataset) * num_steps * hidden_size)

    # Validation
    net.eval()
    correct = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            inputs = inputs.view(inputs.size(0), inputs.size(2), inputs.size(1))

            spk_in = rate_encode(inputs, num_steps)
            spk_in = spk_in.permute(1, 0, 2, 3)
            spk_in = spk_in.reshape(spk_in.size(0), spk_in.size(1), -1)

            spk2_rec = net(spk_in)
            spk_count = spk2_rec.sum(0)  # Sum over time steps

            correct += (spk_count.argmax(1) == targets).sum().item()

    val_acc = correct / len(val_loader.dataset)

    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {train_acc:.4f}, Validation Accuracy: {val_acc:.4f}, Avg Spike Rate: {avg_spike_rate:.4f}")

Epoch 1, Loss: 9.0104, Accuracy: 0.0982, Validation Accuracy: 0.1007, Avg Spike Rate: 0.0334
Epoch 2, Loss: 7.1013, Accuracy: 0.0977, Validation Accuracy: 0.1018, Avg Spike Rate: 0.0351
Epoch 3, Loss: 7.1014, Accuracy: 0.1018, Validation Accuracy: 0.1018, Avg Spike Rate: 0.0351
Epoch 4, Loss: 7.0994, Accuracy: 0.0969, Validation Accuracy: 0.0945, Avg Spike Rate: 0.0351
Epoch 5, Loss: 7.1025, Accuracy: 0.0957, Validation Accuracy: 0.0945, Avg Spike Rate: 0.0351


KeyboardInterrupt: 

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
from torch.utils.data import DataLoader
import random

# Set random seeds for reproducibility
torch.manual_seed(42)
random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Constants
data_path = './speechcommands_data'
target_classes = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go']
class_to_idx = {label: idx for idx, label in enumerate(target_classes)}
FIXED_MFCC_LENGTH = 100
num_steps = 100  # Match the number of MFCC frames
num_mfcc = 20  # Number of MFCC coefficients per frame
num_epochs = 10  # Define number of epochs

# Preprocessing with proper normalization
def preprocess(waveform, sample_rate):
    transform = torchaudio.transforms.MFCC(
        sample_rate=sample_rate,
        n_mfcc=num_mfcc,
        melkwargs={'n_fft': 400, 'hop_length': 160, 'n_mels': 64}
    )
    mfcc = transform(waveform).squeeze(0)  # Shape: [num_mfcc, num_frames]
    time_steps = mfcc.shape[1]

    if time_steps < FIXED_MFCC_LENGTH:
        pad_amt = FIXED_MFCC_LENGTH - time_steps
        mfcc = F.pad(mfcc, (0, pad_amt))
    else:
        mfcc = mfcc[:, :FIXED_MFCC_LENGTH]

    # Normalize MFCCs to [0, 1]
    mfcc_min = mfcc.min()
    mfcc_max = mfcc.max()
    mfcc = (mfcc - mfcc_min) / (mfcc_max - mfcc_min + 1e-8)
    return mfcc  # Shape: [num_mfcc, FIXED_MFCC_LENGTH]

# Dataset
class SubsetSC(torch.utils.data.Dataset):
    def __init__(self, samples):
        self.samples = [(waveform, sample_rate, label) for waveform, sample_rate, label, *_ in samples if label in target_classes]

    def __getitem__(self, index):
        waveform, sample_rate, label = self.samples[index]
        mfcc = preprocess(waveform, sample_rate)  # Shape: [num_mfcc, FIXED_MFCC_LENGTH]
        mfcc = mfcc.transpose(0, 1)  # Shape: [FIXED_MFCC_LENGTH, num_mfcc]
        return mfcc, class_to_idx[label]

    def __len__(self):
        return len(self.samples)

# Load Data
full_train_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_path, download=True, subset='training')
full_val_dataset = torchaudio.datasets.SPEECHCOMMANDS(root=data_path, download=True, subset='validation')

train_dataset = SubsetSC(full_train_dataset)
val_dataset = SubsetSC(full_val_dataset)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, drop_last=False)

# LIF Neuron with corrected tau and simple surrogate
class LIFNeuron(nn.Module):
    def __init__(self, tau=0.9):
        super().__init__()
        self.tau = tau

    def forward(self, x, mem):
        mem = mem * self.tau + x
        spk = (mem >= 1.0).float()  # Simple step function for spikes
        mem = mem * (1.0 - spk)  # Reset membrane potential
        return spk, mem

    def init_mem(self, batch_size, features, device):
        return torch.zeros(batch_size, features, device=device)

# SNN Model without LayerNorm or Dropout
class SNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.lif1 = LIFNeuron(tau=0.9)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.lif2 = LIFNeuron(tau=0.9)

    def forward(self, x):
        batch_size, steps, features = x.shape
        mem1 = self.lif1.init_mem(batch_size, self.fc1.out_features, x.device)
        mem2 = self.lif2.init_mem(batch_size, self.fc2.out_features, x.device)
        spk2_rec = []

        for step in range(steps):
            cur1 = self.fc1(x[:, step, :])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)

        return torch.stack(spk2_rec, dim=0)  # [steps, batch_size, output_size]

# Rate encoding for sequential input
def rate_encode(mfcc, num_steps):
    # mfcc: [batch_size, num_frames, num_mfcc]
    batch_size, num_frames, num_mfcc = mfcc.shape
    assert num_frames == num_steps, "Number of frames must match num_steps"
    # Generate spikes for each frame
    spikes = (torch.rand(batch_size, num_steps, num_mfcc, device=device) < mfcc).float()
    return spikes  # [batch_size, num_steps, num_mfcc]

# Config
input_size = num_mfcc  # 20 MFCC coefficients per frame
hidden_size = 256
output_size = 10

net = SNN(input_size, hidden_size, output_size).to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7)
criterion = nn.CrossEntropyLoss()

# Train
for epoch in range(num_epochs):
    net.train()
    running_loss, correct, total_spikes = 0, 0, 0

    for mfcc, targets in train_loader:
        mfcc, targets = mfcc.to(device), targets.to(device)  # mfcc: [batch_size, 100, 20]
        batch_size_actual = mfcc.size(0)

        # Rate encoding: generate spikes for each frame
        spk_in = rate_encode(mfcc, num_steps)  # [batch_size, 100, 20]

        # Forward pass
        spk2_rec = net(spk_in)  # [100, batch_size, 10]
        spk_count = spk2_rec.sum(0)  # [batch_size, 10]

        loss = criterion(spk_count, targets)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        correct += (spk_count.argmax(1) == targets).sum().item()
        total_spikes += spk2_rec.sum().item()

    scheduler.step()

    train_acc = correct / len(train_loader.dataset)
    avg_loss = running_loss / len(train_loader)
    avg_spike_rate = total_spikes / (len(train_loader.dataset) * num_steps * hidden_size)

    # Validation
    net.eval()
    correct = 0
    with torch.no_grad():
        for mfcc, targets in val_loader:
            mfcc, targets = mfcc.to(device), targets.to(device)
            batch_size_actual = mfcc.size(0)

            spk_in = rate_encode(mfcc, num_steps)
            spk2_rec = net(spk_in)
            spk_count = spk2_rec.sum(0)
            correct += (spk_count.argmax(1) == targets).sum().item()

    val_acc = correct / len(val_loader.dataset)

    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {train_acc:.4f}, Validation Accuracy: {val_acc:.4f}, Avg Spike Rate: {avg_spike_rate:.4f}")

MemoryError: Unable to allocate 62.5 KiB for an array with shape (16000, 1) and data type float32

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchaudio.datasets import SPEECHCOMMANDS
import torchaudio
import os
import random
import numpy as np

# ---------------------
# 1. Dataset Setup
# ---------------------
class SubsetSC(SPEECHCOMMANDS):
    def __init__(self, subset: str = None):
        super().__init__("./", download=True)
        def load_list(filename):
            with open(os.path.join(self._path, filename)) as f:
                return [os.path.join(self._path, line.strip()) for line in f]

        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = load_list("validation_list.txt") + load_list("testing_list.txt")
            excludes = set(excludes)
            self._walker = [w for w in self._walker if w not in excludes]

labels = sorted(list(set(dat[2] for dat in SubsetSC("training"))))
label_to_index = {label: i for i, label in enumerate(labels)}

def label_to_tensor(label):
    return torch.tensor(label_to_index[label])

def pad_sequence(batch):
    tensors, targets = zip(*batch)
    tensors = [t.squeeze(0).t() for t in tensors]
    tensors = nn.utils.rnn.pad_sequence(tensors, batch_first=True).permute(0, 2, 1)
    targets = torch.stack(targets)
    return tensors, targets

def collate_fn(batch):
    batch = [(torchaudio.transforms.MFCC()(waveform), label_to_tensor(label)) for waveform, _, label, *_ in batch]
    return pad_sequence(batch)

train_set = SubsetSC("training")
val_set = SubsetSC("validation")
test_set = SubsetSC("testing")

train_loader = DataLoader(train_set, batch_size=64, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_set, batch_size=64, collate_fn=collate_fn)
test_loader = DataLoader(test_set, batch_size=64, collate_fn=collate_fn)

# ---------------------
# 2. RadLIF SNN Module
# ---------------------
class RadLIFNeuron(nn.Module):
    def __init__(self, size, beta=0.9):
        super().__init__()
        self.size = size
        self.beta = beta
        self.register_buffer("membrane", torch.zeros(size))

    def forward(self, input):
        self.membrane = self.beta * self.membrane + input
        out = (self.membrane > 1.0).float()
        self.membrane *= (1 - out)
        return out

class RadLIFNet(nn.Module):
    def __init__(self, input_size, layer_sizes, output_size, dropout=0.2):
        super().__init__()
        layers = []
        in_size = input_size
        self.neurons = nn.ModuleList()

        for h_size in layer_sizes:
            layers.append(nn.Linear(in_size, h_size))
            layers.append(nn.LayerNorm(h_size))
            layers.append(nn.Dropout(dropout))
            self.neurons.append(RadLIFNeuron(h_size))
            in_size = h_size

        self.hidden = nn.Sequential(*layers)
        self.output = nn.Linear(in_size, output_size)

    def forward(self, x):
        batch_size, channels, steps = x.shape
        x = x.permute(2, 0, 1)  # T x B x C
        spike_sum = torch.zeros(batch_size, len(labels), device=x.device)

        for t in range(x.size(0)):
            out = x[t]
            for i, layer in enumerate(self.hidden):
                out = layer(out)
                if isinstance(layer, nn.Linear):
                    out = self.neurons[i](out)
            spike_sum += self.output(out)

        return spike_sum / x.size(0)

# ---------------------
# 3. Training Functions
# ---------------------
def train(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        correct += (out.argmax(1) == y).sum().item()
        total += x.size(0)
    return total_loss / total, correct / total

def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            total_loss += loss.item() * x.size(0)
            correct += (out.argmax(1) == y).sum().item()
            total += x.size(0)
    return total_loss / total, correct / total

# ---------------------
# 4. Run Training
# ---------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RadLIFNet(input_size=40, layer_sizes=[256, 256, 128], output_size=len(labels)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(1, 21):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
    print(f"Epoch {epoch:02d}: Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")




IndexError: index 3 is out of range

In [2]:
# Baseline SNN for Google Speech Commands using SpArch

# 1. Install dependencies
# Make sure these are installed in your environment:
# pip install torch torchaudio numpy matplotlib
# Clone SpArch if not done:
# git clone https://github.com/Intelligent-Computing-Lab-YZU/SpArch

import os
import torch
import torchaudio
import matplotlib.pyplot as plt
from sparch.models.radlif import RadLIFNet
from sparch.utils.trainer import Trainer
from sparch.utils.metrics import calculate_accuracy
from sparch.utils.dataset import get_speechcommands_dataloaders

# 2. Configuration
config = {
    'model_type': 'RadLIF',
    'dataset_name': 'sc',
    'data_folder': './speech_commands',
    'batch_size': 64,
    'nb_epochs': 30,
    'lr': 1e-3,
    'beta1': 0.9,
    'beta2': 0.999,
    'weight_decay': 1e-4,
    'dropout': 0.2,
    'layer_sizes': [256, 256, 128],
    'bidirectional': True,
    'normalization': 'layernorm',
    'optimizer': 'adamw',
    'scheduler': 'cosine',
    't_max': 30,
    'spike_tracking': True,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu')
}

# 3. Load dataset
train_loader, val_loader, test_loader, num_classes = get_speechcommands_dataloaders(
    config['data_folder'],
    batch_size=config['batch_size']
)

# 4. Build model
model = RadLIFNet(
    input_size=40,  # log-Mel spectrogram features
    hidden_sizes=config['layer_sizes'],
    output_size=num_classes,
    dropout=config['dropout'],
    normalization=config['normalization'],
    bidirectional=config['bidirectional']
).to(config['device'])

# 5. Trainer setup
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    config=config
)

# 6. Train
train_loss, train_acc, val_loss, val_acc, spike_rates, total_spikes = trainer.train()

# 7. Evaluate
print("\nTest Accuracy:", trainer.evaluate())

# 8. Plot results
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.plot(train_loss, label='Train Loss')
plt.plot(val_loss, label='Val Loss')
plt.title("Loss")
plt.legend()

plt.subplot(1, 3, 2)
plt.plot(train_acc, label='Train Acc')
plt.plot(val_acc, label='Val Acc')
plt.title("Accuracy")
plt.legend()

plt.subplot(1, 3, 3)
plt.plot(spike_rates, label='Spike Rate')
plt.plot(total_spikes, label='Total Spikes')
plt.title("Spike Activity")
plt.legend()

plt.tight_layout()
plt.show()


ModuleNotFoundError: No module named 'sparch'

In [3]:
# Google Speech Commands - Baseline SNN using SpArch

# Install and import required libraries
!pip install torch torchaudio numpy matplotlib --quiet
!pip install git+https://github.com/SelinaJiang/SpArch.git --quiet

import os
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from sparch.datasets import get_dataset
from sparch.models import get_model
from sparch.utils.train import train
from sparch.utils.eval import evaluate
from sparch.utils.misc import count_spikes

# Dataset setup
DATA_FOLDER = './speech_commands'
if not os.path.exists(DATA_FOLDER):
    os.makedirs(DATA_FOLDER)

# Download the dataset (optional, if you haven't already downloaded it manually)
torchaudio.datasets.SPEECHCOMMANDS(root=DATA_FOLDER, download=True)

# Hyperparameters
params = {
    'model_type': 'RadLIF',
    'dataset_name': 'sc',
    'data_folder': DATA_FOLDER,
    'batch_size': 64,
    'nb_epochs': 100,
    'lr': 1e-3,
    'beta1': 0.9,
    'beta2': 0.999,
    'weight_decay': 1e-4,
    'dropout': 0.2,
    'layer_sizes': [256, 256, 128],
    'bidirectional': True,
    'normalization': 'layernorm',
    'optimizer': 'adamw',
    'scheduler': 'cosine',
    't_max': 100,
    'exp_folder': './experiments/gsc_baseline',
}

# Load dataset
data = get_dataset(
    dataset_name=params['dataset_name'],
    data_folder=params['data_folder'],
    batch_size=params['batch_size']
)

# Create model
model = get_model(
    model_type=params['model_type'],
    input_shape=data['input_shape'],
    nb_classes=data['nb_classes'],
    layer_sizes=params['layer_sizes'],
    bidirectional=params['bidirectional'],
    dropout=params['dropout'],
    normalization=params['normalization']
)

# Training
train(
    model=model,
    train_loader=data['train_loader'],
    valid_loader=data['valid_loader'],
    nb_epochs=params['nb_epochs'],
    optimizer_name=params['optimizer'],
    scheduler_name=params['scheduler'],
    lr=params['lr'],
    weight_decay=params['weight_decay'],
    beta1=params['beta1'],
    beta2=params['beta2'],
    t_max=params['t_max'],
    exp_folder=params['exp_folder']
)

# Evaluation
test_loss, test_acc = evaluate(
    model=model,
    data_loader=data['test_loader'],
    nb_classes=data['nb_classes'],
    exp_folder=params['exp_folder']
)
print(f"\nTest Accuracy: {test_acc*100:.2f}%")

# Measure spike activity
print("\nMeasuring spike activity...")
model.eval()
spike_info = count_spikes(
    model=model,
    data_loader=data['test_loader'],
    nb_batches=10  # You can increase this for more accuracy
)
print(f"Average Spike Rate per Neuron: {spike_info['spike_rate']:.4f}")
print(f"Total Spikes: {spike_info['total_spikes']}")

# Plot training results
log_path = os.path.join(params['exp_folder'], 'train_log.npy')
if os.path.exists(log_path):
    logs = np.load(log_path, allow_pickle=True).item()
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(logs['train_acc'], label='Train Acc')
    plt.plot(logs['val_acc'], label='Val Acc')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.legend(); plt.title('Accuracy')

    plt.subplot(1, 2, 2)
    plt.plot(logs['train_loss'], label='Train Loss')
    plt.plot(logs['val_loss'], label='Val Loss')
    plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend(); plt.title('Loss')
    plt.tight_layout(); plt.show()
else:
    
    print("Training logs not found. Skipping plots.")


  error: subprocess-exited-with-error
  
  × git clone --filter=blob:none --quiet https://github.com/SelinaJiang/SpArch.git 'C:\Users\Aswin Kumar\AppData\Local\Temp\pip-req-build-f5ttamej' did not run successfully.
  │ exit code: 128
  ╰─> [2 lines of output]
      remote: Repository not found.
      fatal: repository 'https://github.com/SelinaJiang/SpArch.git/' not found
      [end of output]
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
error: subprocess-exited-with-error

× git clone --filter=blob:none --quiet https://github.com/SelinaJiang/SpArch.git 'C:\Users\Aswin Kumar\AppData\Local\Temp\pip-req-build-f5ttamej' did not run successfully.
│ exit code: 128
╰─> See above for output.

note: This error originates from a subprocess, and is likely not a problem with pip.


ModuleNotFoundError: No module named 'sparch'