# 02. Model Prototyping

In this notebook, we define the PyTorch Dataset class to handle audio loading and preprocessing, and we design a Convolutional Neural Network (CNN) for classification.

## 1. Custom Dataset Class
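We need a class that:
- collects the `.ogg` / `.wav` files from the per-digit sub-folders `0`–`9` and uses the folder name as the label;
- resamples each clip to a common sample rate and mixes multi-channel audio down to mono;
- zero-pads or truncates every clip to a fixed duration so that all items in a batch have the same shape;
- converts the waveform into a dB-scaled Mel spectrogram for the CNN.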

In [None]:
import os
import glob
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader

class SpokenDigitDataset(Dataset):
    """Spoken-digit clips served as fixed-size, dB-scaled Mel spectrograms.

    ``data_path`` is expected to contain one sub-directory per digit
    (``0`` .. ``9``) holding ``.ogg`` / ``.wav`` files. The sub-directory name
    is used as the integer label, and missing digit directories are skipped.

    Args:
        data_path: Root directory containing the per-digit sub-directories.
        sample_rate: Target sample rate in Hz; clips at other rates are resampled.
        n_mels: Number of Mel filter banks in the spectrogram.
        max_duration: Clip length in seconds. Shorter clips are zero-padded at
            the end and longer clips are truncated.

    Each item is ``(spectrogram, label)``. ``spectrogram`` is the dB-scaled Mel
    spectrogram of the mono waveform, with a leading channel dimension of 1.
    ``label`` is an ``int`` in 0..9.
    """

    def __init__(self, data_path, sample_rate=16000, n_mels=64, max_duration=1.0):
        self.data_path = data_path
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.max_length = int(sample_rate * max_duration)  # clip length in samples
        self.file_list = []
        self.labels = []

        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_mels=n_mels)
        # MelSpectrogram outputs a power spectrogram by default (power=2.0), which
        # matches AmplitudeToDB's default stype='power'.
        self.db_transform = torchaudio.transforms.AmplitudeToDB()
        # Resample transforms keyed by source sample rate. Constructing one builds
        # a filter kernel, so reuse it instead of rebuilding it for every item.
        self._resamplers = {}

        self._load_dataset()

    def _load_dataset(self):
        """Fill ``file_list`` / ``labels`` from the per-digit sub-directories.

        Files are sorted because ``glob`` returns them in arbitrary,
        filesystem-dependent order. Sorting keeps dataset indices (and
        therefore any index-based splits) reproducible across machines.
        """
        for label in range(10):
            label_dir = os.path.join(self.data_path, str(label))
            if not os.path.isdir(label_dir):
                continue
            # We look for .ogg and .wav files (converted from m4a)
            files = sorted(
                path
                for ext in ('*.ogg', '*.wav')
                for path in glob.glob(os.path.join(label_dir, ext))
            )
            self.file_list.extend(files)
            self.labels.extend([label] * len(files))

    def __len__(self):
        return len(self.file_list)

    def _get_resampler(self, orig_sr):
        """Return a cached ``Resample`` transform from ``orig_sr`` to ``self.sample_rate``."""
        resampler = self._resamplers.get(orig_sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_sr, self.sample_rate)
            self._resamplers[orig_sr] = resampler
        return resampler

    def _fix_length(self, waveform):
        """Zero-pad at the end, or truncate, ``waveform`` to ``self.max_length`` samples."""
        shortfall = self.max_length - waveform.shape[-1]
        if shortfall > 0:
            return torch.nn.functional.pad(waveform, (0, shortfall))
        return waveform[..., :self.max_length]

    def __getitem__(self, idx):
        """Load clip ``idx`` and return ``(dB Mel spectrogram, label)``."""
        file_path = self.file_list[idx]
        label = self.labels[idx]

        # torchaudio.load returns (channels, frames) by default.
        waveform, sr = torchaudio.load(file_path)

        # Down-mix before resampling so that only a single channel is resampled.
        # Both steps are linear, so the order does not change the result beyond
        # float rounding.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if sr != self.sample_rate:
            waveform = self._get_resampler(sr)(waveform)

        waveform = self._fix_length(waveform)

        melspec = self.mel_transform(waveform)
        melspec = self.db_transform(melspec)

        return melspec, label

# Build the dataset from the processed audio and wrap it in a shuffling loader.
data_dir = '../data/processed'
dataset = SpokenDigitDataset(data_dir)
print(f"Total samples: {len(dataset)}")
# DataLoader(shuffle=True) rejects an empty dataset with a ValueError raised by its
# RandomSampler ("num_samples=0"). Fail early with a message that names the data path.
if len(dataset) == 0:
    raise FileNotFoundError(
        f"No .ogg/.wav files found in the digit sub-directories 0-9 of {data_dir!r}")
loader = DataLoader(dataset, batch_size=4, shuffle=True)

## 2. CNN Model Architecture
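We use a simple CNN with four convolutional blocks. Each block is a 3×3 convolution, then BatchNorm, then ReLU, then pooling. The first three blocks use 2×2 max pooling; the last uses adaptive average pooling to a 4×4 grid. A two-layer fully connected head with dropout produces the class logits.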

In [None]:
import torch.nn as nn

class SimpleCNN(nn.Module):
    """Small 2-D CNN classifier for single-channel inputs.

    Four convolutional blocks (3x3 conv -> BatchNorm -> ReLU -> pooling) widen
    the channels 1 -> 16 -> 32 -> 64 -> 128. The first three blocks halve the
    spatial size with 2x2 max pooling. The last block adaptively average-pools
    to a fixed 4x4 grid, so the flattened feature size is always 128 * 4 * 4.
    Inputs must still be at least 8x8 to survive the three max-pool steps.

    Head: Flatten -> Linear(2048, 256) -> ReLU -> Dropout(0.5) -> Linear(256, num_classes).

    Args:
        num_classes: Number of output logits.

    The input is a ``(batch, 1, H, W)`` tensor, since the first conv expects one
    channel. It is presumably a batch of spectrograms from SpokenDigitDataset;
    TODO confirm. The output is raw logits of shape ``(batch, num_classes)``.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def conv_block(in_ch, out_ch, pool):
            # A 3x3 kernel with stride 1 and padding 1 keeps H and W unchanged;
            # `pool` does the downsampling.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, 1, 1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                pool,
            )

        self.conv1 = conv_block(1, 16, nn.MaxPool2d(2))
        self.conv2 = conv_block(16, 32, nn.MaxPool2d(2))
        self.conv3 = conv_block(32, 64, nn.MaxPool2d(2))
        self.conv4 = conv_block(64, 128, nn.AdaptiveAvgPool2d((4, 4)))

        flat_features = 128 * 4 * 4
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        for block in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = block(x)
        return self.fc(x)

# Instantiate the classifier for the ten digit classes and print its layer structure.
model = SimpleCNN(num_classes=10)
print(model)