In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset
import torchaudio 
from torch.utils.data import random_split
import torchvision
from torch.nn import functional as F

import os
import matplotlib.pyplot as plt
import numpy as np

In [None]:
import tensorflow as tf
import pathlib

DATASET_PATH = 'data/mini_speech_commands'
# HTTPS instead of plain HTTP so the archive cannot be tampered with in transit.
DATASET_URL = "https://storage.googleapis.com/download.tensorflow.org/data/mini_speech_commands.zip"

data_dir = pathlib.Path(DATASET_PATH)
if not data_dir.exists():
    # Download once into ./data/ and unpack there; the archive is expected to
    # contain the `mini_speech_commands/` folder that DATASET_PATH points at.
    tf.keras.utils.get_file(
        'mini_speech_commands.zip',
        origin=DATASET_URL,
        extract=True,
        cache_dir='.', cache_subdir='data')

# Fail fast with a clear message rather than an opaque os.listdir error in the
# dataset cell if extraction landed somewhere else.
if not data_dir.is_dir():
    raise FileNotFoundError(
        f"Dataset directory {data_dir} not found after download; "
        "check where the archive was extracted.")

In [None]:
class TrimOrPadAudio:
    """Force a waveform to exactly ``target_length`` samples along its last axis.

    Shorter inputs are right-padded with zeros and longer inputs are truncated.
    It works for a 1-D mono waveform ``(time,)``, which is what
    ``SpeechCommandDataset`` passes after ``squeeze(0)``. It also works for
    ``(channels, time)`` tensors.

    Args:
        target_length: number of samples to keep/produce on the last axis.
    """

    def __init__(self, target_length=16000):
        self.target_length = target_length

    def __call__(self, waveform):
        num_samples = waveform.size(-1)
        if num_samples < self.target_length:
            # F.pad's (left, right) pair applies to the last dimension.
            waveform = F.pad(waveform, (0, self.target_length - num_samples))
        elif num_samples > self.target_length:
            # Ellipsis indexing works for any rank. A `[:, :n]` slice raises
            # IndexError on the 1-D waveforms the dataset produces.
            waveform = waveform[..., :self.target_length]
        return waveform
    
class Spectrogram:
    """Turn a waveform into a single-channel spectrogram "image" for a 2-D CNN.

    Wraps ``torchaudio.transforms.Spectrogram``, takes ``abs()`` of the result
    and prepends a channel axis so the output can feed ``Conv2d(in_channels=1)``.

    Args:
        n_fft: FFT size; a onesided STFT yields ``n_fft // 2 + 1`` frequency bins.
        hop_length: number of samples between successive frames.
        win_length: window size; ``None`` lets torchaudio fall back to ``n_fft``.
        window_fn: function that builds the analysis window.
        power: exponent applied to the magnitude by torchaudio. 2.0 (the
            torchaudio default, and this class's previous behaviour) gives a
            power spectrogram, 1.0 a magnitude spectrogram, and None the
            complex STFT.
    """

    def __init__(self, n_fft=255, hop_length=128, win_length=None,
                 window_fn=torch.hann_window, power=2.0):
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window_fn = window_fn
        self.power = power
        # Build the transform (and its window buffer) once, rather than
        # re-creating it for every sample loaded by the dataset.
        self._transform = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window_fn=window_fn,
            power=power,
        )

    def __call__(self, waveform):
        spectrogram = self._transform(waveform)
        # With power=None torchaudio returns a complex STFT and abs() gives its
        # magnitude; for a real power/magnitude spectrogram this is a no-op.
        spectrogram = torch.abs(spectrogram)
        # Prepend a channel axis for the CNN.
        return spectrogram.unsqueeze(0)

In [None]:
class SpeechCommandDataset(Dataset):
    """Map-style dataset over a ``<data_dir>/<class_name>/*.wav`` folder layout.

    Each item is ``(features, label)``:
    - ``features`` is the waveform loaded with torchaudio, with its leading
      channel axis squeezed, after ``transform`` (if given) is applied.
    - ``label`` is the integer index of the file's class folder in
      ``self.classes``.

    Args:
        data_dir: root directory containing one sub-directory per class.
        transform: optional callable applied to the squeezed waveform.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        # Sorted so label indices are identical on every machine and run
        # (os.listdir order is arbitrary). Only visible sub-directories count as
        # classes, which skips README.md, .DS_Store and similar clutter.
        self.classes = np.array(sorted(
            entry for entry in os.listdir(data_dir)
            if not entry.startswith('.') and os.path.isdir(os.path.join(data_dir, entry))
        ), dtype=str)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.file_list = self._build_file_list()

    def _build_file_list(self):
        """Return ``(wav_path, label)`` pairs, grouped by class, files sorted by name."""
        file_list = []
        for cls in self.classes:
            class_path = os.path.join(self.data_dir, cls)
            wav_names = sorted(name for name in os.listdir(class_path) if name.endswith('.wav'))
            label = self.class_to_idx[cls]
            file_list.extend((os.path.join(class_path, name), label) for name in wav_names)
        return file_list

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        audio_path, label = self.file_list[idx]
        # NOTE(review): the sample rate is not checked; this assumes every clip
        # matches the rate the downstream transform's lengths were chosen for.
        waveform, _sample_rate = torchaudio.load(audio_path, format='wav')

        # Drop the leading channel axis of a mono file.
        waveform = waveform.squeeze(0)
        if self.transform is not None:
            waveform = self.transform(waveform)

        return waveform, label
    
# Per-sample pipeline: fix every clip to 16000 samples, then turn it into a
# (1, freq, time) spectrogram image for the CNN.
transform = torchvision.transforms.Compose([
    TrimOrPadAudio(target_length=16000),
    Spectrogram()
])
dataset = SpeechCommandDataset(data_dir, transform=transform)

SPLIT_SEED = 42
BATCH_SIZE = 64

total_size = len(dataset)
train_size = int(0.8 * total_size)  # 80% for training, 20% for validation
val_size = total_size - train_size

# A dedicated seeded generator makes the train/validation partition identical
# across kernel restarts (by default random_split draws from the global RNG).
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Shuffle only the training set; validation order does not affect metrics.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Plain nn.Sequential version of the CNN.
# NOTE(review): this `model` is never trained. The Lightning `Model` class
# below declares the same layers, and the training cell rebinds `model` to an
# instance of it.
model = nn.Sequential(
    # Feature extractor: two "same"-padded 3x3 convolutions, then 2x2 pooling.
    nn.Conv2d(1, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Dropout(0.25),
    # Classifier head. 64 * 64 * 62 == 253952 flattened features: 64 channels
    # over the pooled spectrogram grid. This presumably comes from 128 frequency
    # bins x 125 frames before pooling; TODO confirm if the spectrogram
    # settings change.
    nn.Flatten(),
    nn.Linear(64 * 64 * 62, 128),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(128, len(dataset.classes)),
)

In [None]:
    
import pytorch_lightning as L
from torchmetrics import Accuracy
import torch
import torch.nn as nn
from collections import defaultdict

class Model(L.LightningModule):
    """Two-convolution CNN classifier for single-channel spectrograms.

    Layers: conv(1->32) -> ReLU -> conv(32->64) -> ReLU -> 2x2 max-pool ->
    dropout -> flatten -> linear(253952->128) -> ReLU -> dropout ->
    linear(128->num_classes). The hard-coded 253952 (= 64 * 64 * 62) must match
    the flattened size of the pooled feature map produced from the data
    pipeline's spectrograms. TODO confirm whenever spectrogram settings change.

    Each train/validation/test step logs its loss and accuracy with
    ``log_dict``. It also appends them, detached and on CPU, to
    ``self.history[<metric name>]``.

    Args:
        loss: criterion applied to ``(logits, labels)``. Defaults to a fresh
            ``nn.CrossEntropyLoss`` per instance.
        num_classes: number of output classes. Defaults to
            ``len(dataset.classes)`` from the notebook-global ``dataset``.
    """

    def __init__(self, loss=None, num_classes=None):
        super().__init__()
        # Resolve defaults here rather than in the signature: a module default
        # would be a single object shared by every Model instance.
        if loss is None:
            loss = nn.CrossEntropyLoss()
        if num_classes is None:
            num_classes = len(dataset.classes)

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout1 = nn.Dropout(0.25)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(253952, 128)
        self.relu3 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, num_classes)

        self.criterion = loss
        self.accuracy_metric = Accuracy(num_classes=num_classes, task="multiclass")
        # metric name -> list of per-batch values (0-d CPU tensors)
        self.history = defaultdict(list)

    def forward(self, x):
        """Return class logits for a batch.

        ``x`` is presumably shaped (batch, 1, freq, time); TODO confirm.
        """
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool(x)
        x = self.dropout1(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        accuracy, loss, _ = self._common_step(batch, batch_idx)
        # Stashed for on_train_batch_end, which records it into history.
        self.training_step_outputs = {
            'training_loss': loss,
            'training_accuracy': accuracy,
        }
        self.log_dict(self.training_step_outputs, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        accuracy, loss, _ = self._common_step(batch, batch_idx)
        self.validation_step_outputs = {
            'validation_loss': loss,
            'validation_accuracy': accuracy,
        }
        self.log_dict(self.validation_step_outputs, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        accuracy, loss, _ = self._common_step(batch, batch_idx)
        self.test_step_outputs = {
            'test_loss': loss,
            'test_accuracy': accuracy,
        }
        self.log_dict(self.test_step_outputs, prog_bar=True)
        return loss

    def _common_step(self, batch, batch_idx):
        """Run one ``(inputs, labels)`` batch; return ``(accuracy, loss, logits)``."""
        inputs, labels = batch
        logits = self(inputs)
        loss = self.criterion(logits, labels)
        accuracy = self.accuracy_metric(torch.softmax(logits, dim=-1), labels)
        return accuracy, loss, logits

    def on_train_batch_end(self, outputs, batch, batch_idx):
        self._common_on_batch_end(self.training_step_outputs)

    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        # Lightning runs a few validation batches before training starts; keep
        # them out of history, just as their logged values are discarded.
        if self.trainer.sanity_checking:
            return
        self._common_on_batch_end(self.validation_step_outputs)

    def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        self._common_on_batch_end(self.test_step_outputs)

    def _common_on_batch_end(self, step_outputs):
        """Append every value in ``step_outputs`` to ``self.history`` under its key."""
        for name, value in step_outputs.items():
            # detach(): the training loss is still attached to autograd, and
            # storing it would keep that graph referenced for every batch.
            # cpu(): avoid accumulating device memory in history.
            self.history[name].append(value.detach().cpu())

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


In [None]:
# Seed Python/NumPy/torch RNGs (and DataLoader workers) so weight
# initialisation and batch shuffling are reproducible across kernel restarts.
L.seed_everything(42, workers=True)

model = Model()
# NOTE(review): min_epochs only matters if a stopping callback (e.g. early
# stopping) is added; without one, training always runs max_epochs.
trainer = L.Trainer(
    min_epochs=10,
    max_epochs=20,
)
trainer.fit(model, train_loader, val_loader)
            