In [10]:
import os
from pathlib import Path
from glob import glob
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import librosa
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB

In [14]:
class SpeechCommandsDataset(Dataset):
    def __init__(self, data_dir, target_sr=16000, transform=None, fixed_num_samples=16000):
        self.data_dir = data_dir
        self.target_sr = target_sr
        self.transform = transform
        self.fixed_num_samples = fixed_num_samples

        self.classes = sorted(
            d.name for d in os.scandir(data_dir) if d.is_dir()
        )
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

        self.files = []
        for cls in self.classes:
            pattern = os.path.join(data_dir, cls, "*.wav")
            for path in glob(pattern):
                self.files.append((path, self.class_to_idx[cls]))

    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        path, label = self.files[idx]

        y, sr = librosa.load(path, sr=self.target_sr, mono=True)

        if len(y) > self.fixed_num_samples:
            y = y[:self.fixed_num_samples]
        elif len(y) < self.fixed_num_samples:
            pad = self.fixed_num_samples - len(y)
            y = np.pad(y, (0, pad), mode="constant")

        waveform = torch.from_numpy(y).float().unsqueeze(0)
        if self.transform:
            waveform = self.transform(waveform)

        return waveform, label

In [18]:
data_dir = 'datasets/speech_commands_split/'

mel_spectrogram = MelSpectrogram(
    sample_rate=16000,
    n_mels=64,
    f_min=20,
    f_max=8000
)
amp_to_db = AmplitudeToDB()

def waveform_to_logmel(waveform):
    mel = mel_spectrogram(waveform)
    logmel = amp_to_db(mel)
    return logmel

train_ds = SpeechCommandsDataset(os.path.join(data_dir, 'train'), transform=waveform_to_logmel)
val_ds = SpeechCommandsDataset(os.path.join(data_dir, 'val'), transform=waveform_to_logmel)
test_ds = SpeechCommandsDataset(os.path.join(data_dir, 'test'), transform=waveform_to_logmel)

batch_size = 64

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

num_classes = len(os.listdir(os.path.join(data_dir, 'train')))

In [4]:
class SelfAttentionBlock(nn.Module):
    def __init__(self, d_model=128, n_heads=7, d_ff=256, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        attn_out, _ = self.attn(
            x, x, x,
            key_padding_mask=key_padding_mask
        )
        x = self.norm1(x + self.dropout(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.dropout(ff_out))
        return x

In [5]:
class CNNAttentionSpeechCommands(nn.Module):
    def __init__(self, n_classes, d_model=128, n_heads=4, d_ff=256, n_layers=2):
        super().__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
        )

        self.proj = nn.Linear(64, d_model)

        self.layers = nn.ModuleList([
            SelfAttentionBlock(d_model, n_heads, d_ff)
            for _ in range(n_layers)
        ])

        self.classifier = nn.Linear(d_model, n_classes)

    def forward(self, x):
        x = self.cnn(x)
        B, C, Fp, Tp = x.shape

        x = x.mean(dim=2)
        x = x.permute(0, 2, 1)

        x = self.proj(x)

        for layer in self.layers:
            x = layer(x)
        
        x = x.mean(dim=1)

        logits = self.classifier(x)
        return logits

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [7]:
def train_one_epoch(model, loader, optimizer, criterion):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for feats, labels in tqdm(loader, desc='Train:'):
        feats = feats.to(device)  
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(feats)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * feats.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    return running_loss / total, correct / total

@torch.no_grad()
def evaluate(model, loader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    for feats, labels in tqdm(loader, desc='Val:'):
        feats = feats.to(device)
        labels = labels.to(device)

        outputs = model(feats)
        loss = criterion(outputs, labels)

        running_loss += loss.item() * feats.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    return running_loss / total, correct / total

In [8]:
model = CNNAttentionSpeechCommands(
    n_classes=num_classes,
    d_model=128,
    n_heads=4,
    d_ff=256,
    n_layers=2
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [19]:
num_epochs = 20

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}')
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    print(f'    train_loss: {train_loss:.4f} acc: {train_acc:.4%}')
    val_loss, val_acc = evaluate(model, val_loader, criterion)
    print(f'    val_loss: {val_loss:.4f} acc: {val_acc:.2%}')

Epoch 1


Train:: 100%|██████████| 181/181 [00:14<00:00, 12.45it/s]


    train_loss: 0.6375 acc: 77.8558%


Val:: 100%|██████████| 39/39 [00:13<00:00,  2.81it/s]


    val_loss: 0.6939 acc: 75.97%
Epoch 2


Train:: 100%|██████████| 181/181 [00:15<00:00, 12.03it/s]


    train_loss: 0.4874 acc: 83.0428%


Val:: 100%|██████████| 39/39 [00:02<00:00, 13.85it/s]


    val_loss: 0.4386 acc: 84.66%
Epoch 3


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.38it/s]


    train_loss: 0.4195 acc: 85.7316%


Val:: 100%|██████████| 39/39 [00:02<00:00, 15.22it/s]


    val_loss: 0.3958 acc: 86.53%
Epoch 4


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.75it/s]


    train_loss: 0.3502 acc: 88.3858%


Val:: 100%|██████████| 39/39 [00:02<00:00, 15.73it/s]


    val_loss: 0.4537 acc: 84.29%
Epoch 5


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.48it/s]


    train_loss: 0.3312 acc: 88.6720%


Val:: 100%|██████████| 39/39 [00:02<00:00, 15.72it/s]


    val_loss: 0.3339 acc: 88.80%
Epoch 6


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.80it/s]


    train_loss: 0.2969 acc: 89.9644%


Val:: 100%|██████████| 39/39 [00:02<00:00, 14.99it/s]


    val_loss: 0.2866 acc: 90.62%
Epoch 7


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.50it/s]


    train_loss: 0.2669 acc: 90.7624%


Val:: 100%|██████████| 39/39 [00:02<00:00, 15.48it/s]


    val_loss: 0.3062 acc: 89.41%
Epoch 8


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.68it/s]


    train_loss: 0.2661 acc: 90.6583%


Val:: 100%|██████████| 39/39 [00:02<00:00, 15.22it/s]


    val_loss: 0.2887 acc: 90.46%
Epoch 9


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.54it/s]


    train_loss: 0.2263 acc: 91.9507%


Val:: 100%|██████████| 39/39 [00:02<00:00, 15.09it/s]


    val_loss: 0.2838 acc: 90.95%
Epoch 10


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.53it/s]


    train_loss: 0.2188 acc: 92.4972%


Val:: 100%|██████████| 39/39 [00:02<00:00, 15.42it/s]


    val_loss: 0.2556 acc: 91.72%
Epoch 11


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.73it/s]


    train_loss: 0.2114 acc: 92.7140%


Val:: 100%|██████████| 39/39 [00:02<00:00, 15.49it/s]


    val_loss: 0.3807 acc: 87.82%
Epoch 12


Train:: 100%|██████████| 181/181 [00:16<00:00, 11.13it/s]


    train_loss: 0.2285 acc: 92.0028%


Val:: 100%|██████████| 39/39 [00:02<00:00, 14.66it/s]


    val_loss: 0.2404 acc: 92.45%
Epoch 13


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.38it/s]


    train_loss: 0.1998 acc: 93.0783%


Val:: 100%|██████████| 39/39 [00:02<00:00, 15.42it/s]


    val_loss: 0.2491 acc: 92.05%
Epoch 14


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.57it/s]


    train_loss: 0.1734 acc: 93.9544%


Val:: 100%|██████████| 39/39 [00:02<00:00, 14.84it/s]


    val_loss: 0.3045 acc: 89.94%
Epoch 15


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.53it/s]


    train_loss: 0.1691 acc: 94.1712%


Val:: 100%|██████████| 39/39 [00:02<00:00, 14.55it/s]


    val_loss: 0.6140 acc: 83.40%
Epoch 16


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.45it/s]


    train_loss: 0.1975 acc: 93.1043%


Val:: 100%|██████████| 39/39 [00:02<00:00, 15.21it/s]


    val_loss: 0.2165 acc: 92.78%
Epoch 17


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.53it/s]


    train_loss: 0.1720 acc: 94.0411%


Val:: 100%|██████████| 39/39 [00:02<00:00, 15.55it/s]


    val_loss: 0.2460 acc: 92.00%
Epoch 18


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.67it/s]


    train_loss: 0.1536 acc: 94.9258%


Val:: 100%|██████████| 39/39 [00:02<00:00, 15.45it/s]


    val_loss: 0.2799 acc: 91.11%
Epoch 19


Train:: 100%|██████████| 181/181 [00:15<00:00, 11.35it/s]


    train_loss: 0.1589 acc: 94.5095%


Val:: 100%|██████████| 39/39 [00:02<00:00, 14.48it/s]


    val_loss: 0.2409 acc: 92.86%
Epoch 20


Train:: 100%|██████████| 181/181 [00:16<00:00, 11.00it/s]


    train_loss: 0.1445 acc: 94.8911%


Val:: 100%|██████████| 39/39 [00:02<00:00, 14.88it/s]

    val_loss: 0.2613 acc: 92.45%





In [21]:
model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for feats, labels in tqdm(test_loader):
        feats = feats.to(device)
        labels = labels.to(device)

        outputs = model(feats)
        loss = criterion(outputs, labels)

        _, preds = outputs.max(1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

100%|██████████| 39/39 [00:15<00:00,  2.56it/s]


In [22]:
from sklearn.metrics import classification_report

In [23]:
print(classification_report(all_labels, all_preds))

              precision    recall  f1-score   support

           0       0.90      0.97      0.93       354
           1       0.98      0.92      0.95       354
           2       0.86      0.94      0.90       354
           3       0.95      0.88      0.91       354
           4       0.93      0.93      0.93       354
           5       0.95      0.96      0.95       354
           6       0.94      0.92      0.93       354

    accuracy                           0.93      2478
   macro avg       0.93      0.93      0.93      2478
weighted avg       0.93      0.93      0.93      2478

