---
# <center>Dataset creation

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T
import torch.nn.functional as F
from tqdm import tqdm

#################################
# Dataset and Data Preparation  #
#################################

class SubsetSC(torchaudio.datasets.SPEECHCOMMANDS):
    def __init__(self, subset: str = None, root="./datasets/speechcommand", download=True):
        super().__init__(root=root, download=download)
        def load_list(filename):
            filepath = os.path.join(self._path, filename)
            with open(filepath) as f:
                return [os.path.join(self._path, line.strip()) for line in f]
        if subset == "validation":
            self._walker = load_list("validation_list.txt")
        elif subset == "testing":
            self._walker = load_list("testing_list.txt")
        elif subset == "training":
            excludes = load_list("validation_list.txt") + load_list("testing_list.txt")
            excludes = set(excludes)
            self._walker = [w for w in self._walker if w not in excludes]


class InMemorySpeechCommands(Dataset):
    def __init__(self, subset="training", fixed_length=16000):
        """
        Args:
          subset: One of "training", "validation", "testing".
          fixed_length: The desired number of audio samples per clip.
        """
        self.fixed_length = fixed_length
        self.dataset = SubsetSC(subset=subset, root="./datasets/speechcommand", download=True)
        self.data = []
        self.labels = []
        
        # Build label index from the dataset's actual folder location
        dataset_root = self.dataset._path  # use the path used by the dataset
        all_labels = [
            d for d in os.listdir(dataset_root)
            if os.path.isdir(os.path.join(dataset_root, d)) and d != '_background_noise_'
        ]
        self.label_set = sorted(all_labels)
        self.label_to_idx = {label: idx for idx, label in enumerate(self.label_set)}
        print(f"Found {len(self.label_set)} labels: {self.label_set}")
        
        # Load all samples into memory (resample, pad/trim)
        for waveform, sample_rate, label, *_ in tqdm(self.dataset, desc=f"Loading {subset} data", leave=False):
            if sample_rate != 16000:
                resampler = T.Resample(orig_freq=sample_rate, new_freq=16000)
                waveform = resampler(waveform)
            waveform = waveform.squeeze(0)  # Remove channel dimension if exists
            if waveform.size(0) > fixed_length:
                waveform = waveform[:fixed_length]
            elif waveform.size(0) < fixed_length:
                waveform = F.pad(waveform, (0, fixed_length - waveform.size(0)))
            self.data.append(waveform)
            try:
                self.labels.append(self.label_to_idx[label])
            except KeyError:
                # If a label is encountered that's not in our mapping, add it
                # (or alternatively, you could skip it or raise an error)
                new_idx = len(self.label_to_idx)
                print(f"New label found: {label}. Assigning new index {new_idx}")
                self.label_to_idx[label] = new_idx
                self.label_set.append(label)
                self.labels.append(new_idx)
            
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.fixed_length, self.labels[idx]



---
# <center>Model Architecture

In [2]:
class MaskedConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.padding = padding
        self.dilation = dilation
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x, length):
        max_length = x.size(2)
        length = torch.div(((length + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1).float() + self.stride),
                           self.stride, rounding_mode='floor').long()
        mask = torch.arange(max_length, device=x.device)[None, :] < length[:, None]
        x = x * mask.unsqueeze(1)
        x = self.conv(x)
        return x, length

class JasperBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, dropout=0.0, residual=False):
        super().__init__()
        self.mconv = nn.ModuleList([
            MaskedConv1d(in_channels, in_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=in_channels, bias=False),
            MaskedConv1d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(out_channels)
        ])
        self.res = None
        if residual:
            self.res = nn.ModuleList([
                nn.ModuleList([
                    MaskedConv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm1d(out_channels)
                ])
            ])
        self.mout = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout)
        )

    def forward(self, x, length):
        residual = x
        res_length = length
        out = x
        out_length = length
        for layer in self.mconv:
            if isinstance(layer, MaskedConv1d):
                out, out_length = layer(out, out_length)
            else:
                out = layer(out)
        if self.res:
            for res_layer_list in self.res:
                res = residual
                for layer in res_layer_list:
                    if isinstance(layer, MaskedConv1d):
                        res, _ = layer(res, res_length)
                    else:
                        res = layer(res)
                residual = res
        if self.res is not None:
            out = out + residual
        out = self.mout(out)
        return out, out_length

class ConvASREncoder(nn.Module):
    def __init__(self, in_channels, blocks_params):
        super().__init__()
        layers = [JasperBlock(**params) for params in blocks_params]
        self.encoder = nn.Sequential(*layers)
        self.in_channels = in_channels

    def forward(self, x, length):
        for layer in self.encoder:
            x, length = layer(x, length)
        return x, length

class AudioToMFCCPreprocessor(nn.Module):
    def __init__(self, sample_rate=16000, n_mels=64, n_mfcc=64, n_fft=512, hop_length=160, f_min=0, f_max=8000):
        super().__init__()
        self.featurizer = T.MFCC(
            sample_rate=sample_rate,
            n_mfcc=n_mfcc,
            melkwargs={
                "n_fft": n_fft,
                "n_mels": n_mels,
                "hop_length": hop_length,
                "f_min": f_min,
                "f_max": f_max,
            },
        )

    def forward(self, x, length):
        with torch.no_grad():
            x = self.featurizer(x)
        return x, length

class ConvASRDecoderClassification(nn.Module):
    def __init__(self, in_features, num_classes):
        super().__init__()
        self.pooling = nn.AdaptiveAvgPool1d(output_size=1)
        self.decoder_layers = nn.Sequential(
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        x = self.pooling(x)
        x = x.squeeze(2)
        x = self.decoder_layers(x)
        return x

class TopKClassificationAccuracy(nn.Module):
    def __init__(self, k=(1,)):
        super().__init__()
        self.k = k

    def forward(self, logits, targets):
        with torch.no_grad():
            maxk = max(self.k)
            batch_size = targets.size(0)
            _, pred = logits.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(targets.view(1, -1).expand_as(pred))
            res = []
            for k in self.k:
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size))
            return res[0] if len(res) == 1 else res

class EncDecClassificationModel(nn.Module):
    def __init__(self, num_classes, sample_rate=16000, n_mels=64, n_mfcc=64, n_fft=512, hop_length=160, f_min=0, f_max=8000):
        super().__init__()
        self.preprocessor = AudioToMFCCPreprocessor(
            sample_rate=sample_rate, n_mels=n_mels, n_mfcc=n_mfcc,
            n_fft=n_fft, hop_length=hop_length, f_min=f_min, f_max=f_max)
        blocks_params = [
            {"in_channels": n_mfcc, "out_channels": 128, "kernel_size": 11, "stride": 1, "padding": 5, "dilation": 1, "dropout": 0.0, "residual": False},
            {"in_channels": 128, "out_channels": 64, "kernel_size": 13, "stride": 1, "padding": 6, "dilation": 1, "dropout": 0.0, "residual": True},
            {"in_channels": 64, "out_channels": 64, "kernel_size": 15, "stride": 1, "padding": 7, "dilation": 1, "dropout": 0.0, "residual": True},
            {"in_channels": 64, "out_channels": 64, "kernel_size": 17, "stride": 1, "padding": 8, "dilation": 1, "dropout": 0.0, "residual": True},
            {"in_channels": 64, "out_channels": 128, "kernel_size": 29, "stride": 1, "padding": 28, "dilation": 2, "dropout": 0.0, "residual": False},
            {"in_channels": 128, "out_channels": 128, "kernel_size": 1, "stride": 1, "padding": 0, "dilation": 1, "dropout": 0.0, "residual": False},
        ]
        self.encoder = ConvASREncoder(in_channels=n_mfcc, blocks_params=blocks_params)
        self.decoder = ConvASRDecoderClassification(in_features=128, num_classes=num_classes)
        self.loss = nn.CrossEntropyLoss()
        self._accuracy = TopKClassificationAccuracy()

    def forward(self, x, length, y=None):
        x, length = self.preprocessor(x, length)
        x, length = self.encoder(x, length)
        logits = self.decoder(x)
        if y is not None:
            loss = self.loss(logits, y)
            acc = self._accuracy(logits, y)
            return loss, acc, logits
        else:
            return logits

    def predict(self, x, length):
        with torch.no_grad():
            logits = self.forward(x, length)
            return torch.argmax(logits, dim=-1)


---
# <center>Training

In [3]:
def train_epoch(model, train_loader, optimizer, device, epoch):
    model.train()
    running_loss = 0.0
    running_accuracy = 0.0
    total_samples = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch} [Train]", leave=False)
    for waveforms, lengths, labels in progress_bar:
        waveforms = waveforms.to(device)
        lengths = torch.tensor([lengths]).to(device) if isinstance(lengths, int) else lengths.to(device)
        labels = torch.tensor(labels).to(device) if isinstance(labels, int) else labels.to(device)

        optimizer.zero_grad()
        loss, acc, _ = model(waveforms, lengths, labels)
        loss.backward()
        optimizer.step()

        batch_size = waveforms.size(0)
        running_loss += loss.item() * batch_size
        running_accuracy += acc.item() * batch_size
        total_samples += batch_size

        progress_bar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{acc.item():.2f}%"})
    epoch_loss = running_loss / total_samples
    epoch_accuracy = running_accuracy / total_samples
    return epoch_loss, epoch_accuracy

def validate_epoch(model, val_loader, device, epoch):
    model.eval()
    running_loss = 0.0
    running_accuracy = 0.0
    total_samples = 0

    progress_bar = tqdm(val_loader, desc=f"Epoch {epoch} [Val]", leave=False)
    with torch.no_grad():
        for waveforms, lengths, labels in progress_bar:
            waveforms = waveforms.to(device)
            lengths = torch.tensor([lengths]).to(device) if isinstance(lengths, int) else lengths.to(device)
            labels = torch.tensor(labels).to(device) if isinstance(labels, int) else labels.to(device)

            loss, acc, _ = model(waveforms, lengths, labels)
            batch_size = waveforms.size(0)
            running_loss += loss.item() * batch_size
            running_accuracy += acc.item() * batch_size
            total_samples += batch_size

            progress_bar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{acc.item():.2f}%"})
    epoch_loss = running_loss / total_samples
    epoch_accuracy = running_accuracy / total_samples
    return epoch_loss, epoch_accuracy

In [None]:
# Configuration
num_epochs = 5
batch_size = 64
learning_rate = 0.001
sample_rate = 16000
fixed_length = sample_rate  # 1 second clips (SpeechCommands are ~1 sec)

# Prepare datasets (in-memory)
print("Preparing datasets...")
train_dataset = InMemorySpeechCommands(subset="training", fixed_length=fixed_length)
val_dataset = InMemorySpeechCommands(subset="validation", fixed_length=fixed_length)
num_classes = len(train_dataset.label_set)
print("Label set:", train_dataset.label_set)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Initialize model, optimizer
model = EncDecClassificationModel(num_classes=num_classes, sample_rate=sample_rate).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Wrap epoch loop with tqdm for overall progress monitoring
for epoch in tqdm(range(1, num_epochs + 1), desc="Epochs"):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, device, epoch)
    val_loss, val_acc = validate_epoch(model, val_loader, device, epoch)
    print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    # Save the best model
    if val_acc > 0:
        torch.save(model.state_dict(), "best_model.pth")
        print("Saved best model!")
print("Training complete.")

Preparing datasets...
Found 35 labels: ['backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow', 'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine', 'no', 'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three', 'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero']


Loading training data:  15%|█▍        | 15473/105829 [00:24<02:25, 620.14it/s]