In [1]:
from pathlib import Path
import os
from glob import glob
import librosa
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import Wav2Vec2Model
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

In [2]:
class SpeechCommandsDataset(Dataset):
    def __init__(self, data_dir, target_sr=16000):
        self.data_dir = data_dir
        self.target_sr = target_sr

        self.classes = sorted(
            d.name for d in os.scandir(data_dir) if d.is_dir()
        )
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

        self.files = []
        for cls in self.classes:
            pattern = os.path.join(data_dir, cls, "*.wav")
            for path in glob(pattern):
                self.files.append((path, self.class_to_idx[cls]))

    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        path, label = self.files[idx]

        y, sr = librosa.load(path, sr=self.target_sr, mono=True)
        waveform = torch.from_numpy(y)

        return {'input_values': waveform, 'labels': label}

In [3]:
wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
wav2vec.eval()



Wav2Vec2Model(
  (feature_extractor): Wav2Vec2FeatureEncoder(
    (conv_layers): ModuleList(
      (0): Wav2Vec2GroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): Wav2Vec2FeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): Wav2Vec2Encoder(
    (pos_conv_embed): Wav2Vec2PositionalConvEmbedding(
  

In [4]:
def collate_fn(batch):
    inputs = [b["input_values"] for b in batch]
    labels = torch.tensor([b["labels"] for b in batch], dtype=torch.long)

    input_values_padded = torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True)

    return {"input_values": input_values_padded, "labels": labels}

In [5]:
def extract_features(wav2vec, input_values):
    with torch.no_grad():
        out = wav2vec(input_values=input_values, output_hidden_states=False)
    feats = out.last_hidden_state.mean(dim=1)
    return feats

In [6]:
class Wav2Vec2Classifier(nn.Module):
    def __init__(self, backbone_name="facebook/wav2vec2-base", num_classes=35):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(backbone_name)
        self.classifier = nn.Linear(self.backbone.config.hidden_size, num_classes)

    def forward(self, input_values, freeze_backbone=True):
        if freeze_backbone:
            with torch.no_grad():
                out = self.backbone(input_values=input_values)
        else:
            out = self.backbone(input_values=input_values)
        x = out.last_hidden_state.mean(dim=1)
        logits = self.classifier(x)
        return logits

In [7]:
data_dir = 'datasets/speech_commands/'

num_classes = len(os.listdir(data_dir))

In [8]:
model = Wav2Vec2Classifier(num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [12]:
data_dir_split = 'datasets/speech_commands_split/'

train_ds = SpeechCommandsDataset(os.path.join(data_dir_split, 'train'))
val_ds = SpeechCommandsDataset(os.path.join(data_dir_split, 'val'))
test_ds = SpeechCommandsDataset(os.path.join(data_dir_split, 'test'))

In [13]:
import random
from collections import defaultdict

def subset_per_class(dataset, max_per_class=1000, seed=42):
    random.seed(seed)
    by_class = defaultdict(list)
    for idx, (_, label) in enumerate(dataset.files):
        by_class[label].append(idx)

    keep_indices = []
    for label, idxs in by_class.items():
        if len(idxs) > max_per_class:
            idxs = random.sample(idxs, max_per_class)
        keep_indices.extend(idxs)

    subset = torch.utils.data.Subset(dataset, keep_indices)
    return subset

train_ds = subset_per_class(train_ds, max_per_class=700)
val_ds = subset_per_class(val_ds, max_per_class=150)
test_ds = subset_per_class(test_ds, max_per_class=150)

In [14]:
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False, collate_fn=collate_fn)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [16]:
def train_one_epoch(model, loader, optimizer, freeze_backbone):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch in tqdm(loader, desc='Train:'):
        inputs = batch["input_values"].to(device)
        labels = batch["labels"].to(device)
        
        optimizer.zero_grad()
        
        logits = model(inputs, freeze_backbone=freeze_backbone)
        loss = criterion(logits, labels)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        
    return total_loss / len(loader), correct / total

def validate(model, loader, freeze_backbone=True):
    model.eval()
    correct = 0
    total = 0
    total_loss = 0
    
    with torch.no_grad():
        for batch in tqdm(loader, desc='Val:'):
            inputs = batch["input_values"].to(device)
            labels = batch["labels"].to(device)
            
            logits = model(inputs, freeze_backbone=freeze_backbone)
            loss = criterion(logits, labels)
            
            total_loss += loss.item()
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            
    return total_loss / len(loader), correct / total

In [18]:
for epoch in range(1):
    print(f'Epoch {epoch+1}')
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, freeze_backbone=True)
    print(f'    train_loss: {train_loss:.4f} acc: {train_acc:.4%}')
    val_loss, val_acc = validate(model, val_loader, freeze_backbone=True)
    print(f'    val_loss: {val_loss:.4f} acc: {val_acc:.2%}')

Epoch 1


Train:: 100%|██████████| 307/307 [00:24<00:00, 12.78it/s]


    train_loss: 1.4992 acc: 68.9388%


Val:: 100%|██████████| 66/66 [00:13<00:00,  4.80it/s]

    val_loss: 1.5724 acc: 55.52%





In [19]:
optimizer_ft = torch.optim.AdamW(model.parameters(), lr=2e-5)

In [20]:
for epoch in range(2):
    print(f'Epoch {epoch+1}')
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, freeze_backbone=False)
    print(f'    train_loss: {train_loss:.4f} acc: {train_acc:.4%}')
    val_loss, val_acc = validate(model, val_loader, freeze_backbone=False)
    print(f'    val_loss: {val_loss:.4f} acc: {val_acc:.2%}')

Epoch 1


Train:: 100%|██████████| 307/307 [01:24<00:00,  3.63it/s]


    train_loss: 0.4252 acc: 86.8776%


Val:: 100%|██████████| 66/66 [00:05<00:00, 11.67it/s]


    val_loss: 0.1029 acc: 96.76%
Epoch 2


Train:: 100%|██████████| 307/307 [10:54<00:00,  2.13s/it]


    train_loss: 0.2893 acc: 90.4082%


Val:: 100%|██████████| 66/66 [00:26<00:00,  2.53it/s]

    val_loss: 0.1686 acc: 94.76%





In [21]:
model.eval()
correct = 0
total = 0
total_loss = 0

all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc='Val:'):
        inputs = batch["input_values"].to(device)
        labels = batch["labels"].to(device)
        
        logits = model(inputs, freeze_backbone=False)
        loss = criterion(logits, labels)
        
        total_loss += loss.item()
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

avg_loss = total_loss / len(test_loader)
accuracy = correct / total

Val:: 100%|██████████| 66/66 [00:33<00:00,  1.97it/s]


In [22]:
print(f'loss: {avg_loss:.4f}, acc: {accuracy:.4f}')

loss: 0.1688, acc: 0.9457


In [None]:
from sklearn.metrics import classification_report

In [24]:
print(classification_report(all_labels, all_preds))

              precision    recall  f1-score   support

           0       0.96      0.99      0.97       150
           1       0.86      0.99      0.92       150
           2       0.98      0.87      0.93       150
           3       0.97      0.92      0.94       150
           4       0.98      0.91      0.94       150
           5       0.98      0.98      0.98       150
           6       0.92      0.96      0.94       150

    accuracy                           0.95      1050
   macro avg       0.95      0.95      0.95      1050
weighted avg       0.95      0.95      0.95      1050

