# Speech Commands Classification with ResNet

Authors: Jakub Borek, Bartosz Dybowski

Speech command classifier built on a pre-trained ResNet-18 model.

## Import

In [None]:
!pip install matplotlib
!pip install scikit-learn
!pip install soundfile

import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import torchvision.models as models
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import random
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import os

## Configuration

In [None]:
# Which class set to train on (use_all_classes):
#   0 - every label found in the training set is its own class
#   1 - 'yes' / 'no' only (all other words are dropped)
#   2 - 10 commands + 'unknown' + 'silence' (12 classes)
#   3 - 10 commands + 'unknown' (11 classes)
use_all_classes = 2
batch_size = 256
learning_rate = 1e-3
epochs = 50

## Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

print("Available GPUs:", torch.cuda.device_count())
# torch.cuda.current_device() / get_device_name() raise when no usable CUDA
# device exists, so only query them when CUDA is actually available; this
# lets the notebook run on CPU-only machines.
if torch.cuda.is_available():
    _cuda_index = torch.cuda.current_device()
    print("Current device:", _cuda_index)
    print("Device name:", torch.cuda.get_device_name(_cuda_index))

def set_seed(seed):
    """Seed every RNG the notebook uses and pin cuDNN to deterministic kernels.

    Seeds torch (CPU, current CUDA device and all CUDA devices), NumPy's global
    RNG and Python's `random`, then turns on cuDNN determinism and turns off
    cuDNN benchmark autotuning.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Seed everything once, before the data and model cells run.
set_seed(1)


## Prepare data

In [None]:
class SubsetSC(torchaudio.datasets.SPEECHCOMMANDS):
    """SPEECHCOMMANDS restricted to one of the official splits.

    The dataset is created under ./SpeechCommands with download=True.
    `subset` selects the split:
      - "validation": files listed in validation_list.txt
      - "testing":    files listed in testing_list.txt
      - "training":   every file in neither list
    Any other value (including None) keeps the full file list.
    """

    def __init__(self, subset: str = None):
        super().__init__(root="./SpeechCommands", download=True)

        def read_split(list_name):
            # Resolve each relative entry of the split file against the dataset
            # root and normalise it so it compares equal to entries in _walker.
            list_path = os.path.join(self._path, list_name)
            with open(list_path) as handle:
                return [os.path.normpath(os.path.join(self._path, entry.strip())) for entry in handle]

        if subset == "validation":
            self._walker = read_split("validation_list.txt")
        elif subset == "testing":
            self._walker = read_split("testing_list.txt")
        elif subset == "training":
            held_out = set(read_split("validation_list.txt"))
            held_out.update(read_split("testing_list.txt"))
            self._walker = [path for path in self._walker if path not in held_out]

# MelSpectrogram transform (16 kHz input, 64 mel bands)
transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=64)

# Prepare labels depending on the mode
_command_words = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go']
if use_all_classes == 0:
    # A recording's label is the name of its parent directory, so read it from
    # the file paths instead of decoding every training wav file.
    labels = sorted({os.path.basename(os.path.dirname(path)) for path in SubsetSC("training")._walker})
elif use_all_classes == 1:
    labels = ['yes', 'no']
elif use_all_classes == 2:
    labels = _command_words + ['unknown', 'silence']
elif use_all_classes == 3:
    labels = _command_words + ['unknown']
else:
    # Fail here with a clear message instead of a NameError on `labels` later.
    raise ValueError(f"use_all_classes must be 0, 1, 2 or 3, got {use_all_classes!r}")

label_to_index = {label: i for i, label in enumerate(labels)}

# Load background noise files (used by collate_fn to build 'silence' examples in mode 2)
background_noises = []
if use_all_classes == 2:
    background_dir = "./SpeechCommands/SpeechCommands/speech_commands_v0.02/_background_noise_"
    if os.path.exists(background_dir):
        print("Found noises")
        # sorted(): os.listdir order depends on the filesystem; a fixed order keeps
        # the noise list, and hence seeded random.choice over it, reproducible.
        for filename in sorted(os.listdir(background_dir)):
            if not filename.endswith('.wav'):
                continue
            waveform, sr = torchaudio.load(os.path.join(background_dir, filename))
            # Down-mix to mono. For a mono file this equals squeeze(0), but it also
            # yields a 1-D tensor for multi-channel files.
            waveform = waveform.mean(dim=0)
            if sr != 16000:
                # Match the 16 kHz rate the MelSpectrogram transform is built for.
                waveform = torchaudio.functional.resample(waveform, sr, 16000)
            background_noises.append(waveform)
    else:
        print(f"Warning: background noise directory not found: {background_dir}; "
              "no 'silence' examples will be generated")

# Custom collate function
def _fit_length(spec, max_len):
    """Trim or right-zero-pad the last (frame) axis of a 2-D spectrogram to exactly `max_len`."""
    frames = spec.shape[-1]
    if frames > max_len:
        return spec[:, :max_len]
    if frames < max_len:
        return torch.nn.functional.pad(spec, (0, max_len - frames))
    return spec


def collate_fn(batch, silence_probability=0.9):
    """Turn a list of SPEECHCOMMANDS items into a (specs, targets) batch.

    For each item, the label is remapped according to `use_all_classes`: in
    modes 1-3, words outside the mode's word list become 'unknown'. Items whose
    (remapped) label is not in `labels` are dropped. Each kept waveform goes
    through `transform`, and the last axis of the result is trimmed or
    zero-padded to 128 frames.

    In mode 2, when background-noise clips are available, the function appends
    `int(n_kept * silence_probability)` synthetic 'silence' examples after the
    speech examples. Each is a random 1 s slice of a random noise clip.
    Despite its name, `silence_probability` is a ratio of added silence examples
    to kept speech examples, not a probability. Because of these appended
    examples, batch position does not map back to dataset position.

    Args:
        batch: list of dataset items; only (waveform, sample_rate, label) are read.
        silence_probability: silence examples added per kept speech example (mode 2).

    Returns:
        (specs, targets): stacked spectrograms and a long tensor of class
        indices, or two empty tensors if nothing was kept.
    """
    max_len = 128
    silence_duration_samples = 16000  # (1s = 16000 samples)
    # Words that keep their own label in modes 1-3; anything else becomes 'unknown'.
    if use_all_classes == 1:
        kept_words = {'yes', 'no'}
    else:
        kept_words = {'yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go'}

    tensors = []
    targets = []
    for waveform, sample_rate, label, *_ in batch:
        if use_all_classes in (1, 2, 3) and label not in kept_words:
            label = 'unknown'
        if label not in labels:
            continue  # e.g. 'unknown' in mode 1, where it is not a class
        tensors.append(_fit_length(transform(waveform).squeeze(0), max_len))
        targets.append(label_to_index[label])

    # Inject background noise as silence
    if use_all_classes == 2 and background_noises:
        for _ in range(int(len(tensors) * silence_probability)):
            noise = random.choice(background_noises)
            excess = noise.size(0) - silence_duration_samples
            if excess >= 0:
                start = random.randint(0, excess)
                clip = noise[start:start + silence_duration_samples]
            else:
                # Clip shorter than 1 s: zero-pad it up to 1 s.
                clip = torch.nn.functional.pad(noise, (0, -excess))
            tensors.append(_fit_length(transform(clip.unsqueeze(0)).squeeze(0), max_len))
            targets.append(label_to_index['silence'])

    if not tensors:
        return torch.empty(0), torch.empty(0, dtype=torch.long)

    return torch.stack(tensors), torch.tensor(targets)


## Load data

In [None]:
# Official train / validation / test splits
train_set = SubsetSC("training")
val_set = SubsetSC("validation")
test_set = SubsetSC("testing")

# Every loader shares the batch size, collate_fn, 8 worker processes and
# pinned memory; only the training loader shuffles.
_loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn, num_workers=8, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **_loader_kwargs)

## Model

In [None]:
class AudioResNet(nn.Module):
    """ImageNet-pretrained ResNet-18 adapted to single-channel input.

    The stock 3-channel stem is replaced by a 1-channel 7x7 stride-2 convolution,
    and the classifier head by a new `nn.Linear` with `num_classes` outputs.

    Args:
        num_classes: number of output logits.
        adapt_pretrained_conv1: if True (default), the new 1-channel stem is
            initialised with the pretrained RGB filters summed over the input
            channel axis, so the pretrained first layer is kept instead of being
            replaced by random weights. Pass False for a randomly initialised
            stem (the previous behavior).
    """

    def __init__(self, num_classes, adapt_pretrained_conv1=True):
        super().__init__()
        # torchvision >= 0.13 deprecates `pretrained=True` in favour of the
        # `weights` enum; fall back to the old flag where the enum is missing.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            self.resnet = models.resnet18(weights=weights_enum.DEFAULT)
        else:
            self.resnet = models.resnet18(pretrained=True)

        pretrained_conv1 = self.resnet.conv1
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        if adapt_pretrained_conv1:
            # (64, 3, 7, 7) -> (64, 1, 7, 7): collapse the RGB filters into one channel.
            with torch.no_grad():
                self.resnet.conv1.weight.copy_(pretrained_conv1.weight.sum(dim=1, keepdim=True))
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return class logits; x must have a single channel at axis 1 (the stem has 1 input channel)."""
        return self.resnet(x)

model = AudioResNet(num_classes=len(labels)).to(device)

## Loss function and optimizer

In [None]:
# Cross-entropy on the model's class logits, minimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

## Training

In [None]:
print("Training...")

def train(model, loader):
    """Run one training epoch of `model` over `loader`.

    Uses the module-level `optimizer`, `criterion` and `device`. Batches with
    zero samples are skipped and do not count toward the averages.

    Returns:
        (mean_loss, accuracy): the mean of the per-batch losses over the
        non-empty batches, and the fraction of correctly classified samples.
        Both are 0.0 if no non-empty batch was seen.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    for inputs, targets in loader:
        if inputs.size(0) == 0:
            continue

        inputs, targets = inputs.to(device), targets.to(device)
        inputs = inputs.unsqueeze(1)  # singleton axis 1 (presumably the model's channel axis)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == targets).sum().item()
        total += targets.size(0)

    if num_batches == 0:
        # Avoid ZeroDivisionError when every batch was empty (or the loader is empty).
        return 0.0, 0.0
    # Divide by the batches actually used, not len(loader), which counts skipped ones.
    return running_loss / num_batches, correct / total


def evaluate(model, loader):
    """Score `model` on every non-empty batch from `loader` with gradients disabled.

    Uses the module-level `device`. Returns (accuracy, targets, predictions).
    The two lists hold per-sample NumPy values in loader order. If the loader
    yields no samples, the result is (0, [], []).
    """
    model.eval()
    n_correct, n_seen = 0, 0
    target_log, pred_log = [], []

    with torch.no_grad():
        for batch_inputs, batch_targets in loader:
            if batch_inputs.numel() == 0:
                continue  # nothing to score in an empty batch
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            predicted = torch.max(model(batch_inputs.unsqueeze(1)), 1)[1]
            n_correct += (predicted == batch_targets).sum().item()
            n_seen += batch_targets.size(0)
            target_log.extend(batch_targets.cpu().numpy())
            pred_log.extend(predicted.cpu().numpy())

    if not n_seen:
        return 0, [], []
    return n_correct / n_seen, target_log, pred_log

train_losses, val_losses, test_losses = [], [], []
train_accs, val_accs, test_accs = [], [], []

# Start below any reachable accuracy so the first epoch always becomes the
# initial best. With 0.0, a run whose validation accuracy never exceeds 0
# would leave best_epoch = 0, and no checkpoint is ever saved under that number.
best_val_acc = -1.0
best_epoch = 0

for epoch in range(epochs):
    train_loss, train_acc = train(model, train_loader)
    val_acc, _, _ = evaluate(model, val_loader)
    # Test accuracy is only recorded for the plots; model selection uses val_acc.
    test_acc, _, _ = evaluate(model, test_loader)

    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    test_accs.append(test_acc)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1

    print(f"Epoch {epoch+1}: Train loss {train_loss:.4f}, Train acc {train_acc:.4f}, Val acc {val_acc:.4f}, Test acc {test_acc:.4f}")

    # Save model after each epoch (one checkpoint file per epoch)
    torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pt")

## Load the best model

In [None]:
# Restore the weights saved at the epoch with the highest validation accuracy.
best_checkpoint = f"model_epoch_{best_epoch}.pt"
print(f"Loading best model from epoch {best_epoch} with val acc {best_val_acc:.4f}")
best_state = torch.load(best_checkpoint)
model.load_state_dict(best_state)
model = model.to(device)

## Analysis

In [None]:
# Accuracy of the restored best model on each split.
# NOTE: in mode 2, collate_fn appends randomly generated 'silence' examples,
# so these numbers can differ slightly between runs of this cell.
train_acc, train_targets, train_preds = evaluate(model, train_loader)
print(f"Train accuracy: {train_acc:.4f}")

val_acc_final, val_targets, val_preds = evaluate(model, val_loader)
print(f"Validation accuracy: {val_acc_final:.4f}")

# all_targets / all_preds (test split) feed the plots and analyses below.
test_acc, all_targets, all_preds = evaluate(model, test_loader)
print(f"Test accuracy: {test_acc:.4f}")

# Training curves: one point per epoch
epoch_axis = range(1, epochs + 1)


def _finish_epoch_plot(ylabel, title, filename):
    """Label and decorate the current figure, save it to `filename`, then show it."""
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.savefig(filename)
    plt.show()


# Loss plot
plt.figure(figsize=(10, 6))
plt.plot(epoch_axis, train_losses, label='Train Loss')
_finish_epoch_plot('Loss', 'Training Loss per Epoch', "train_loss.jpg")

# Accuracy plot
plt.figure(figsize=(10, 6))
for series, series_name in (
    (train_accs, 'Train Accuracy'),
    (val_accs, 'Validation Accuracy'),
    (test_accs, 'Test Accuracy'),
):
    plt.plot(epoch_axis, series, label=series_name)
_finish_epoch_plot('Accuracy', 'Accuracy per Epoch', "accuracy_per_epoch.jpg")

# Generate confusion matrix (test split)
# Pass every class index explicitly. Otherwise sklearn sizes the matrix by the
# classes that actually occur in targets/predictions, and a missing class would
# make it smaller than `display_labels`, so ConfusionMatrixDisplay would fail.
cm = confusion_matrix(all_targets, all_preds, labels=list(range(len(labels))))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
fig, ax = plt.subplots(figsize=(12, 12))
fig.patch.set_facecolor('white')
ax.set_facecolor('white')
disp.plot(
    ax=ax,
    cmap='Blues',
    colorbar=False,
    values_format='d',  # raw integer counts
)
plt.title("Confusion Matrix", fontsize=20)
plt.xlabel('Predicted label', fontsize=16)
plt.ylabel('True label', fontsize=16)
plt.xticks(rotation=45, fontsize=12)
plt.yticks(fontsize=12)
plt.grid(False)
plt.tight_layout()
plt.savefig("confusion_matrix.jpg")
plt.show()

# Easiest and hardest commands (precision per class, test split)
from sklearn.metrics import precision_score

if all_targets:
    # Score every class index explicitly. Without `labels=`, sklearn only reports
    # classes that occur in targets/predictions, so indexing by position in
    # `labels` could misalign or raise IndexError. zero_division=0 reports 0.0
    # for never-predicted classes without a warning.
    class_indices = list(range(len(labels)))
    precisions = precision_score(all_targets, all_preds, labels=class_indices, average=None, zero_division=0)
    for idx, label in enumerate(labels):
        print(f"Precision for class '{label}': {precisions[idx]:.4f}")

    easiest = int(np.argmax(precisions))
    hardest = int(np.argmin(precisions))
    print(f"Easiest command: '{labels[easiest]}' ({precisions[easiest]:.4f})")
    print(f"Hardest command: '{labels[hardest]}' ({precisions[hardest]:.4f})")
else:
    print("No test predictions available; skipping per-class precision.")

# Analysis of recording length (whether shorter recordings are easier to classify)
# and of loudness (mean absolute amplitude of the signal), on the test split.
#
# `all_preds` / `all_targets` cannot be indexed by position in `test_set`:
# collate_fn drops samples whose label is not a class in the current mode and,
# in mode 2, appends synthetic 'silence' examples to every batch, so list
# position i is generally not test recording i. Predictions are therefore
# recomputed here, one entry per real test recording, with silence injection off.

def _predict_spectrograms(specs):
    """Return the model's predicted class index for each spectrogram in `specs`."""
    inputs = torch.stack(specs).unsqueeze(1).to(device)
    _, predicted = torch.max(model(inputs), 1)
    return predicted.cpu().tolist()


def _per_recording_results(dataset, chunk_size=batch_size):
    """Return aligned (lengths, amplitudes, targets, predictions) for each recording collate_fn keeps.

    Entry i of all four lists describes the same recording. Inference runs in
    chunks of `chunk_size` spectrograms.
    """
    lengths, amplitudes, targets, predictions = [], [], [], []
    pending = []
    model.eval()
    with torch.no_grad():
        for sample in dataset:
            spec, target = collate_fn([sample], silence_probability=0.0)
            if spec.numel() == 0:
                continue  # label is not a class in the current mode
            waveform = sample[0]
            lengths.append(waveform.shape[-1])
            amplitudes.append(torch.abs(waveform).mean().item())
            targets.append(int(target[0]))
            pending.append(spec[0])
            if len(pending) == chunk_size:
                predictions.extend(_predict_spectrograms(pending))
                pending = []
        if pending:
            predictions.extend(_predict_spectrograms(pending))
    return lengths, amplitudes, targets, predictions


def _report_accuracy(name, indices, targets, predictions):
    """Print accuracy over `indices`, or n/a for an empty group (avoids division by zero)."""
    if not indices:
        print(f"{name} samples accuracy: n/a (no samples)")
        return
    correct = sum(1 for i in indices if predictions[i] == targets[i])
    print(f"{name} samples accuracy: {correct / len(indices):.4f}")


# Single pass over the test audio (the original loaded it twice).
waveform_lengths, amplitudes, sample_targets, sample_preds = _per_recording_results(test_set)

short_indices = [i for i, length in enumerate(waveform_lengths) if length < 16000]  # less than 1 second
long_indices = [i for i, length in enumerate(waveform_lengths) if length >= 16000]
_report_accuracy("Short", short_indices, sample_targets, sample_preds)
_report_accuracy("Long", long_indices, sample_targets, sample_preds)

# Impact of loudness (mean amplitude of the signal)
high_amp_indices = [i for i, a in enumerate(amplitudes) if a > 0.05]
low_amp_indices = [i for i, a in enumerate(amplitudes) if a <= 0.05]
_report_accuracy("High amplitude", high_amp_indices, sample_targets, sample_preds)
_report_accuracy("Low amplitude", low_amp_indices, sample_targets, sample_preds)

# Per-class accuracy of the "unknown" and "silence" classes (test split)

def _positions_of_class(class_name):
    """Indices into all_targets whose true label is `class_name`; [] if that class is not in use."""
    if class_name not in labels:
        return []
    return [pos for pos, target in enumerate(all_targets) if labels[target] == class_name]


unknown_idx = _positions_of_class('unknown')
silence_idx = _positions_of_class('silence')

if unknown_idx:
    unknown_correct = sum(1 for pos in unknown_idx if all_preds[pos] == all_targets[pos])
    print(f"Unknown class accuracy: {unknown_correct / len(unknown_idx):.4f}")
if silence_idx:
    silence_correct = sum(1 for pos in silence_idx if all_preds[pos] == all_targets[pos])
    print(f"Silence class accuracy: {silence_correct / len(silence_idx):.4f}")