In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import MFCC
import torch.nn.functional as F
import torchaudio
import os
import glob
from tqdm import tqdm
import random


class MaskedConv1d(nn.Module):
    """1-D convolution that zeroes padded time steps before convolving.

    The wrapped ``nn.Conv1d`` is stored as ``self.conv``. ``forward`` takes the
    per-example number of valid time steps alongside the input and returns the
    number of valid time steps of the output.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded unchanged to ``nn.Conv1d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.padding = padding
        self.dilation = dilation
        self.stride = stride
        self.kernel_size = kernel_size

    def _output_length(self, length):
        """Return the Conv1d output length for each input length (int64)."""
        receptive_field = self.dilation * (self.kernel_size - 1) + 1
        numerator = length + 2 * self.padding - receptive_field
        return (torch.div(numerator, self.stride, rounding_mode='floor') + 1).long()

    def forward(self, x, length):
        """Mask ``x`` beyond ``length``, convolve, and update the lengths.

        Args:
            x: input whose last (third) dimension is time.
            length: 1-D tensor with the number of valid time steps per example.

        Returns:
            ``(y, new_length)`` where ``y`` is the convolution output and
            ``new_length`` is the per-example output length as ``long``.
        """
        # The mask must describe the *input*: positions at or beyond the
        # incoming valid length are padding. Masking with the post-convolution
        # length would erase real frames whenever the conv shortens the
        # sequence (no padding, or stride > 1).
        steps = torch.arange(x.size(2), device=x.device)
        mask = steps[None, :] < length.to(x.device)[:, None]
        x = x * mask.unsqueeze(1)
        x = self.conv(x)
        return x, self._output_length(length)

class JasperBlock(nn.Module):
    """Depthwise-separable conv block with an optional projected residual.

    Main path (``self.mconv``): a depthwise ``MaskedConv1d``
    (``groups=in_channels``), a pointwise 1x1 ``MaskedConv1d`` and a
    ``BatchNorm1d``. When ``residual`` is true, ``self.res`` holds one branch
    (1x1 ``MaskedConv1d`` + ``BatchNorm1d``) whose output is added to the main
    path before the ReLU/Dropout stage in ``self.mout``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, dropout=0.0, residual=False):
        super().__init__()
        depthwise = MaskedConv1d(in_channels, in_channels, kernel_size, stride=stride, padding=padding,
                                 dilation=dilation, groups=in_channels, bias=False)
        pointwise = MaskedConv1d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.mconv = nn.ModuleList([depthwise, pointwise, nn.BatchNorm1d(out_channels)])

        self.res = None
        if residual:
            projection = nn.ModuleList([
                MaskedConv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_channels),
            ])
            self.res = nn.ModuleList([projection])

        self.mout = nn.Sequential(nn.ReLU(inplace=True), nn.Dropout(p=dropout))

    def forward(self, x, length):
        """Run the block.

        Args:
            x: input features.
            length: per-example valid lengths matching ``x``.

        Returns:
            ``(out, out_length)``; ``out_length`` comes from the main path.
        """
        # Main path: masked convs carry the length along, BatchNorm does not.
        out, out_length = x, length
        for layer in self.mconv:
            if isinstance(layer, MaskedConv1d):
                out, out_length = layer(out, out_length)
            else:
                out = layer(out)

        if self.res is not None:
            # Residual path: every masked conv sees the block's input lengths
            # and its updated length is discarded.
            shortcut = x
            for branch in self.res:
                for layer in branch:
                    if isinstance(layer, MaskedConv1d):
                        shortcut = layer(shortcut, length)[0]
                    else:
                        shortcut = layer(shortcut)
            out = out + shortcut

        return self.mout(out), out_length


class ConvASREncoder(nn.Module):
    """Stack of ``JasperBlock`` modules applied one after another.

    Args:
        in_channels: number of input feature channels (stored only).
        blocks_params: iterable of keyword-argument dicts, one per
            ``JasperBlock``, in execution order.
    """

    def __init__(self, in_channels, blocks_params):
        super().__init__()
        # nn.Sequential is used purely as an ordered container; forward()
        # walks it by hand because each block takes and returns two values.
        self.encoder = nn.Sequential(*(JasperBlock(**cfg) for cfg in blocks_params))
        self.in_channels = in_channels

    def forward(self, x, length):
        """Feed ``(x, length)`` through every block and return the final pair."""
        features, lengths = x, length
        for block in self.encoder:
            features, lengths = block(features, lengths)
        return features, lengths

class ConvASRDecoderClassification(nn.Module):
    """Classification head: average over time, then a linear layer.

    Args:
        in_features: channel count of the encoder output.
        num_classes: number of output logits.
    """

    def __init__(self, in_features, num_classes):
        super().__init__()
        self.pooling = nn.AdaptiveAvgPool1d(output_size=1)
        self.decoder_layers = nn.Sequential(nn.Linear(in_features, num_classes))

    def forward(self, x):
        """Map a 3-D input (time on the last axis) to per-example logits."""
        # The pooling collapses the last axis to size 1; drop that axis so the
        # linear layer receives one feature vector per example.
        pooled = self.pooling(x).squeeze(2)
        return self.decoder_layers(pooled)

class AudioToMFCCPreprocessor(nn.Module):
    """Turn raw waveforms into MFCC features and sample counts into frame counts.

    Args:
        sample_rate, n_mfcc: forwarded to ``torchaudio.transforms.MFCC``.
        n_mels, n_fft, hop_length, f_min, f_max: forwarded through
            ``melkwargs`` to the underlying mel spectrogram.
    """

    def __init__(self, sample_rate=16000, n_mels=64, n_mfcc=64, n_fft=512, hop_length=160, f_min=0, f_max=8000):
        super().__init__()
        # Kept so forward() can translate lengths from samples to frames.
        self.hop_length = hop_length
        self.featurizer = MFCC(
            sample_rate=sample_rate,
            n_mfcc=n_mfcc,
            melkwargs={
                "n_fft": n_fft,
                "n_mels": n_mels,
                "hop_length": hop_length,
                "f_min": f_min,
                "f_max": f_max,
            },
        )

    def forward(self, x, length):
        """Compute features without tracking gradients.

        Args:
            x: waveform batch, time on the last axis.
            length: per-example number of valid samples.

        Returns:
            ``(features, frame_length)``. ``frame_length`` is the number of
            feature frames covering each example's valid samples, so that
            downstream masking and augmentation index the feature time axis
            rather than the sample axis.
        """
        with torch.no_grad():
            x = self.featurizer(x)
            # torchaudio's spectrogram is centred by default (melkwargs does
            # not override it), which yields samples // hop_length + 1 frames.
            length = torch.div(length, self.hop_length, rounding_mode='floor') + 1
            # Never report more frames than the feature tensor actually has.
            length = torch.clamp(length, max=x.size(-1))
        return x, length


class SpecCutout(nn.Module):
    """Zero out short random time spans of a spectrogram batch, in place.

    Args:
        rect_masks: exclusive upper bound on the width of each span.
        rect_time_masks: number of spans attempted per example.
    """

    def __init__(self, rect_masks=5, rect_time_masks=10):
        super().__init__()
        self.rect_masks = rect_masks
        self.rect_time_masks = rect_time_masks

    def forward(self, specgram, length):
        """Apply the cutouts and return ``(specgram, length)``.

        ``specgram`` must be 3-D with time on the last axis; it is modified in
        place and returned. ``length`` is returned untouched.
        """
        n_items, _n_feats, n_steps = specgram.shape
        for _ in range(self.rect_time_masks):
            # One width and one start per example; widths are drawn first.
            widths = torch.randint(0, self.rect_masks, (n_items,))
            starts = torch.randint(0, n_steps, (n_items,))
            stops = starts + widths
            for row, (start, stop) in enumerate(zip(starts, stops)):
                # A span is applied only if it ends strictly inside the valid
                # region of that example; otherwise it is dropped.
                if stop < length[row]:
                    specgram[row, :, start:stop] = 0
        return specgram, length


class SpecAugment(nn.Module):
    """Random frequency and time masking, applied in place.

    Args:
        freq_masks: number of frequency masks per example.
        freq_width: exclusive upper bound on the height of a frequency mask.
        time_masks: number of time masks per example.
        time_width: time-mask width as a fraction of the example's valid length.
    """

    def __init__(self, freq_masks=2, freq_width=27, time_masks=10, time_width=0.05):
        super().__init__()
        self.freq_masks = freq_masks
        self.freq_width = freq_width
        self.time_masks = time_masks
        self.time_width = time_width

    def forward(self, specgram, length):
        """Mask ``specgram`` in place and return ``(specgram, length)``.

        Args:
            specgram: 3-D tensor ``(batch, features, time)``.
            length: per-example number of valid time steps.

        Degenerate cases (no feature rows, zero-length examples, a mask as
        wide as the valid region) skip the mask instead of raising from
        ``torch.randint``. For ordinary inputs the random draws are the same
        as without these guards.
        """
        batch_size, n_mels, time_len = specgram.shape

        # A mask can never be taller than the feature axis; capping the bound
        # also guarantees n_mels - f >= 1 for the start-position draw.
        max_height = min(self.freq_width, n_mels)
        if max_height > 0:
            for _ in range(self.freq_masks):
                for i in range(batch_size):
                    f = torch.randint(low=0, high=max_height, size=(1,)).item()
                    f0 = torch.randint(low=0, high=n_mels - f, size=(1,)).item()
                    specgram[i, f0:f0 + f, :] = 0

        # Pull the lengths to Python ints once (avoids a device sync per mask)
        # and clamp them to the real time axis so masks cannot be positioned
        # beyond the data.
        raw_lengths = length.tolist() if torch.is_tensor(length) else list(length)
        valid = [min(int(v), time_len) for v in raw_lengths]

        for _ in range(self.time_masks):
            for i in range(batch_size):
                t = int(self.time_width * valid[i])
                start_range = valid[i] - t
                if start_range <= 0:
                    # Nothing valid to mask, or nowhere to place the mask.
                    continue
                t0 = torch.randint(low=0, high=start_range, size=(1,)).item()
                specgram[i, :, t0:t0 + t] = 0
        return specgram, length



class SpectrogramAugmentation(nn.Module):
    """Chain ``SpecCutout`` and ``SpecAugment`` with their default settings."""

    def __init__(self):
        super().__init__()
        self.spec_cutout = SpecCutout()
        self.spec_augment = SpecAugment()

    def forward(self, x, length):
        """Apply cutout first, then SpecAugment; return ``(x, length)``."""
        for stage in (self.spec_cutout, self.spec_augment):
            x, length = stage(x, length)
        return x, length

class CropOrPadSpectrogramAugmentation(nn.Module):
    """Force the time axis of a batch to exactly ``audio_length`` steps.

    Args:
        audio_length: target size of the last (time) axis, expressed in the
            same unit as that axis.
    """

    def __init__(self, audio_length=1024):
        super().__init__()
        self.audio_length = audio_length

    def forward(self, x, length):
        """Randomly crop longer inputs, right-pad shorter ones with zeros.

        Args:
            x: 3-D tensor with time on the last axis.
            length: per-example valid lengths.

        Returns:
            ``(x, length)``. After a crop every length equals
            ``audio_length``; after padding each length grows by the number of
            padded steps; inputs that already match are returned as-is.
        """
        batch_size, _, seq_len = x.size()
        if seq_len > self.audio_length:
            offset = torch.randint(0, seq_len - self.audio_length + 1, (batch_size,))
            # new_empty keeps the dtype and device of x; a plain torch.zeros
            # would silently cast half/double inputs to float32.
            new_x = x.new_empty((batch_size, x.size(1), self.audio_length))
            for i, start in enumerate(offset.tolist()):
                new_x[i] = x[i, :, start:start + self.audio_length]
            length = torch.full((batch_size,), self.audio_length, dtype=torch.long, device=x.device)
            return new_x, length
        elif seq_len < self.audio_length:
            pad_amount = self.audio_length - seq_len
            padded_x = F.pad(x, (0, pad_amount), "constant", 0)
            # The padded steps are counted as valid here.
            # NOTE(review): an earlier comment on this line said the length
            # should NOT grow; confirm which behaviour is intended.
            length = length + pad_amount
            return padded_x, length
        else:
            return x, length

class EncDecClassificationModel(nn.Module):
    """MFCC front end + convolutional encoder + pooled linear classifier.

    Args:
        num_classes: number of output classes.
        sample_rate, n_mels, n_mfcc, n_fft, hop_length, f_min, f_max:
            passed to ``AudioToMFCCPreprocessor``.
        audio_length: nominal clip length in *samples*. It is converted to a
            number of feature frames before being used as the crop/pad target.
    """

    def __init__(self, num_classes, sample_rate=16000, n_mels=64, n_mfcc=64, n_fft=512, hop_length=160, f_min=0, f_max=8000, audio_length=16000*2):
        super().__init__()
        self.spec_augmentation = SpectrogramAugmentation()
        # crop_or_pad runs on the feature tensor, whose time axis is in
        # frames. Handing it the sample count would pad every clip out to
        # `audio_length` frames (e.g. 32000 instead of ~201), which makes the
        # encoder vastly slower and drowns the signal in the average pooling.
        # samples // hop_length + 1 is the frame count of a centred STFT.
        target_frames = audio_length // hop_length + 1
        self.crop_or_pad = CropOrPadSpectrogramAugmentation(audio_length=target_frames)
        self.preprocessor = AudioToMFCCPreprocessor(sample_rate=sample_rate, n_mels=n_mels, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length, f_min=f_min, f_max=f_max)
        blocks_params = [
            {"in_channels": n_mfcc, "out_channels": 128, "kernel_size": 11, "stride": 1, "padding": 5, "dilation": 1, "dropout": 0.0, "residual": False},
            {"in_channels": 128, "out_channels": 64, "kernel_size": 13, "stride": 1, "padding": 6, "dilation": 1, "dropout": 0.0, "residual": True},
            {"in_channels": 64, "out_channels": 64, "kernel_size": 15, "stride": 1, "padding": 7, "dilation": 1, "dropout": 0.0, "residual": True},
            {"in_channels": 64, "out_channels": 64, "kernel_size": 17, "stride": 1, "padding": 8, "dilation": 1, "dropout": 0.0, "residual": True},
            {"in_channels": 64, "out_channels": 128, "kernel_size": 29, "stride": 1, "padding": 28, "dilation": 2, "dropout": 0.0, "residual": False},
            {"in_channels": 128, "out_channels": 128, "kernel_size": 1, "stride": 1, "padding": 0, "dilation": 1, "dropout": 0.0, "residual": False},
        ]
        self.encoder = ConvASREncoder(in_channels=n_mfcc, blocks_params=blocks_params)
        self.decoder = ConvASRDecoderClassification(in_features=128, num_classes=num_classes)
        self.loss = nn.CrossEntropyLoss()
        self._accuracy = TopKClassificationAccuracy()

    def forward(self, x, length, y=None):
        """Classify a batch of waveforms.

        Args:
            x: waveform batch.
            length: per-example valid lengths.
            y: optional integer class targets.

        Returns:
            ``(loss, acc, logits)`` when ``y`` is given, otherwise ``logits``.
        """
        x, length = self.preprocessor(x, length)
        # Random masking is a training-time regulariser only; applying it in
        # eval mode would corrupt validation metrics and predictions.
        if self.training:
            x, length = self.spec_augmentation(x, length)
        x, length = self.crop_or_pad(x, length)
        x, length = self.encoder(x, length)
        logits = self.decoder(x)
        if y is not None:
            loss = self.loss(logits, y)
            acc = self._accuracy(logits, y)
            return loss, acc, logits
        else:
            return logits

    def predict(self, x, length):
        """Return the arg-max class index per example, without gradients.

        This does not change the module's train/eval mode; call ``eval()``
        first for inference.
        """
        with torch.no_grad():
            logits = self.forward(x, length)
            return torch.argmax(logits, dim=-1)

class TopKClassificationAccuracy(nn.Module):
    """Top-k accuracy in percent for one or several values of k.

    Args:
        k: tuple of k values to evaluate.
    """

    def __init__(self, k=(1,)):
        super().__init__()
        self.k = k

    def forward(self, logits, targets):
        """Return accuracy as a 1-element tensor (percent) per k.

        A single tensor is returned when only one k is configured, otherwise a
        list with one tensor per k, in the order given.
        """
        with torch.no_grad():
            n_items = targets.size(0)
            # Indices of the max(k) best classes, one column per example.
            ranked = logits.topk(max(self.k), 1, True, True)[1].t()
            hits = ranked.eq(targets.view(1, -1).expand_as(ranked))
            scores = [
                hits[:top].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_items)
                for top in self.k
            ]
            if len(scores) == 1:
                return scores[0]
            return scores

class SpeechCommandsDataset(Dataset):
    """Speech Commands clips labelled as one of ``keywords`` or "unknown".

    Args:
        root_dir: dataset root containing one sub-directory per word plus
            ``validation_list.txt`` and ``testing_list.txt``.
        keywords: words that get their own class (index = position in the
            list); every other word is mapped to the extra "unknown" class.
        split: ``"training"``, ``"validation"`` or ``"testing"``.
        transform: optional callable ``(waveform, length) -> (waveform, length)``.
        fixed_length: number of samples every clip is trimmed/padded to.

    Raises:
        ValueError: for an unknown ``split`` or when the split has no files.
    """

    def __init__(self, root_dir, keywords, split, transform=None, fixed_length=32000):
        super().__init__()
        self.root_dir = root_dir
        self.keywords = keywords
        self.split = split
        self.transform = transform
        self.fixed_length = fixed_length
        self.file_paths = []
        self.labels = []
        self.label_to_idx = {keyword: i for i, keyword in enumerate(keywords)}
        self.label_to_idx["unknown"] = len(self.label_to_idx)
        self.idx_to_label = {i: keyword for keyword, i in self.label_to_idx.items()}

        if self.split in ["validation", "testing"]:
            for rel_path in self._read_split_list(f"{self.split}_list.txt"):
                # List entries look like "<word>/<file>.wav".
                word = rel_path.split('/')[0]
                class_name = word if word in self.keywords else "unknown"
                self.file_paths.append(os.path.join(self.root_dir, rel_path))
                self.labels.append(self.label_to_idx[class_name])

        elif self.split == "training":
            self._create_training_data()

        else:
            raise ValueError("Invalid split. Must be 'training', 'validation', or 'testing'.")

        if not self.file_paths:
            raise ValueError(f"No .wav files found for split '{self.split}' in {root_dir}.")

        if self.split == "training":
            # Balance only after the initial file list has been built.
            self.balance_dataset()

    def _read_split_list(self, list_name):
        """Return the non-empty, stripped lines of ``root_dir/list_name``."""
        with open(os.path.join(self.root_dir, list_name), 'r') as f:
            stripped = (line.strip() for line in f)
            # Blank lines (e.g. a trailing newline) are not clips.
            return [entry for entry in stripped if entry]

    def _create_training_data(self):
        """Collect keyword and unknown clips that are not in the val/test lists."""
        keyword_files = []
        unknown_files = []

        for subdir in os.listdir(self.root_dir):
            subdir_path = os.path.join(self.root_dir, subdir)
            if not os.path.isdir(subdir_path):
                continue
            files = [os.path.join(self.root_dir, subdir, f)
                     for f in os.listdir(subdir_path) if f.endswith(".wav")]

            if subdir in self.keywords:
                keyword_files.extend((f, self.label_to_idx[subdir]) for f in files)
            elif subdir != '_background_noise_':
                unknown_files.extend((f, self.label_to_idx["unknown"]) for f in files)

        # Held-out clips, as normalized full paths so that the "/" used in the
        # list files and the platform separator compare equal.
        held_out = {
            os.path.normpath(os.path.join(self.root_dir, rel_path))
            for list_name in ("validation_list.txt", "testing_list.txt")
            for rel_path in self._read_split_list(list_name)
        }

        # Keyword clips first, then unknown clips.
        for file_path, label in keyword_files + unknown_files:
            if os.path.normpath(file_path) in held_out:
                continue
            self.file_paths.append(file_path)
            self.labels.append(label)

    def balance_dataset(self):
        """Oversample (with replacement) every class up to the largest class size."""
        # Group sample indices per label in a single pass; dict order keeps
        # labels in order of first appearance.
        indices_by_label = {}
        for idx, label in enumerate(self.labels):
            indices_by_label.setdefault(label, []).append(idx)
        if not indices_by_label:
            return

        max_count = max(len(indices) for indices in indices_by_label.values())

        new_file_paths = []
        new_labels = []
        for indices in indices_by_label.values():
            # All original samples, followed by random duplicates if needed.
            chosen = list(indices)
            deficit = max_count - len(indices)
            if deficit > 0:
                chosen.extend(random.choices(indices, k=deficit))
            new_file_paths.extend(self.file_paths[i] for i in chosen)
            new_labels.extend(self.labels[i] for i in chosen)

        self.file_paths = new_file_paths
        self.labels = new_labels

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(waveform, length, label)`` for sample ``idx``.

        A clip that fails to load with ``RuntimeError`` is replaced by silence
        and gets label ``-1`` so the training loop can filter it out.
        """
        file_path = self.file_paths[idx]
        label = self.labels[idx]
        try:
            waveform, sample_rate = torchaudio.load(file_path)
        except RuntimeError as e:
            print(f"Error loading {file_path}: {e}")
            waveform = torch.zeros(1, self.fixed_length)
            # Must be bound: it is passed to pad_or_trim below, and leaving it
            # undefined turned every failed load into a NameError.
            sample_rate = None
            label = -1  # Keep -1 for failed loads

        waveform = self.pad_or_trim(waveform, sample_rate)
        length = torch.tensor([waveform.shape[1]])
        if self.transform:
            waveform, length = self.transform(waveform, length)
        return waveform.squeeze(0), length.squeeze(0), label

    def pad_or_trim(self, waveform, sample_rate):
        """Trim or right-pad ``waveform`` with zeros to ``fixed_length`` samples.

        ``sample_rate`` is accepted for interface compatibility and not used.
        """
        num_frames = waveform.shape[1]
        target_frames = self.fixed_length
        if num_frames > target_frames:
            waveform = waveform[:, :target_frames]
        elif num_frames < target_frames:
            waveform = F.pad(waveform, (0, target_frames - num_frames))
        return waveform

def create_data_loaders(root_dir, keywords, batch_size, sample_rate=16000, audio_length=32000):
    """Build the training and validation ``DataLoader`` objects.

    Only the training loader shuffles. Both use 4 worker processes and pinned
    memory. ``sample_rate`` is accepted but not used here. No loader is built
    for the "testing" split.

    Returns:
        ``(train_loader, val_loader)``.
    """
    loaders = []
    for split_name, do_shuffle in (("training", True), ("validation", False)):
        dataset = SpeechCommandsDataset(root_dir, keywords, split=split_name, fixed_length=audio_length)
        loaders.append(
            DataLoader(dataset, batch_size=batch_size, shuffle=do_shuffle, num_workers=4, pin_memory=True)
        )
    train_loader, val_loader = loaders
    return train_loader, val_loader

def train_epoch(model, train_loader, optimizer, device, epoch):
    """Run one optimisation pass over ``train_loader``.

    Samples labelled ``-1`` are dropped from each batch, and batches left
    empty are skipped.

    Returns:
        ``(mean_loss, mean_accuracy)``, both weighted by the number of samples
        actually used.
    """
    model.train()
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0
    bar = tqdm(train_loader, desc=f"Epoch {epoch} [Train]", leave=False)
    for waveforms, lengths, labels in bar:
        keep = labels != -1
        if not torch.any(keep):
            continue

        waveforms, lengths, labels = (
            t[keep].to(device) for t in (waveforms, lengths, labels)
        )

        optimizer.zero_grad()
        loss, acc, _ = model(waveforms, lengths, labels)
        loss.backward()
        optimizer.step()

        n_used = waveforms.size(0)
        batch_loss = loss.item()
        batch_acc = acc[0].item()
        loss_sum += batch_loss * n_used
        acc_sum += batch_acc * n_used
        seen += n_used
        bar.set_postfix({"loss": f"{batch_loss:.4f}", "acc": f"{batch_acc:.4f}"})

    return loss_sum / seen, acc_sum / seen


def validate_epoch(model, val_loader, device, epoch):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Samples labelled ``-1`` are dropped from each batch, and batches left
    empty are skipped.

    Returns:
        ``(mean_loss, mean_accuracy, predictions, targets)`` where the means
        are sample-weighted and the last two are flat Python lists.
    """
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0
    predictions = []
    targets = []
    with torch.no_grad():
        bar = tqdm(val_loader, desc=f"Epoch {epoch} [Val]", leave=False)
        for waveforms, lengths, labels in bar:
            keep = labels != -1
            if not torch.any(keep):
                continue
            waveforms, lengths, labels = (
                t[keep].to(device) for t in (waveforms, lengths, labels)
            )

            loss, acc, logits = model(waveforms, lengths, labels)

            n_used = waveforms.size(0)
            batch_loss = loss.item()
            batch_acc = acc[0].item()
            loss_sum += batch_loss * n_used
            acc_sum += batch_acc * n_used
            seen += n_used
            bar.set_postfix({"loss": f"{batch_loss:.4f}", "acc": f"{batch_acc:.4f}"})

            best_class = torch.max(logits, 1)[1]
            predictions.extend(best_class.cpu().tolist())
            targets.extend(labels.cpu().tolist())

    return loss_sum / seen, acc_sum / seen, predictions, targets


# --- Configuration ---
# Hyper-parameters and data settings for a single training run.
num_epochs = 5
batch_size = 64
learning_rate = 0.001
# Must equal len(keywords) + 1: SpeechCommandsDataset adds an "unknown" class
# after the keyword classes.
num_classes = 2
keywords = ["marvin"]
sample_rate = 16000
# Clip length in samples (two seconds at `sample_rate`).
audio_length = sample_rate * 2

# Dataset Path
# Machine-specific location; fail early with a readable message if it is absent.
root_dir = "/mnt/c/Users/Andrea/Downloads/wake_word_detection/data"
if not os.path.exists(root_dir):
    raise ValueError("Please set root_dir to where the SpeechCommands dataset is located.")

# Device
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Model, Optimizer, DataLoaders ---
model = EncDecClassificationModel(num_classes=num_classes, sample_rate=sample_rate, audio_length=audio_length).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
train_loader, val_loader = create_data_loaders(root_dir, keywords, batch_size, sample_rate, audio_length)

# --- Training Loop ---
# Accuracies are percentages (0-100), as produced by TopKClassificationAccuracy.
best_val_accuracy = 0.0
for epoch in range(1, num_epochs + 1):
    train_loss, train_accuracy = train_epoch(model, train_loader, optimizer, device, epoch)
    # The per-sample predictions and targets returned by validation are not used here.
    val_loss, val_accuracy, _, _ = validate_epoch(model, val_loader, device, epoch)

    print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

    # Checkpoint only on a strict improvement of validation accuracy; the file
    # is overwritten each time.
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), "best_model.pth")
        print("Saved best model!")

print("Finished Training")

Using device: cuda


Epoch 1 [Train]:   0%|                        | 4/1552 [01:46<11:24:01, 26.51s/it, loss=0.7020, acc=45.3125]