In [1]:
from datasets import load_from_disk
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchmetrics
from tqdm import tqdm
import torchaudio
from datasets import load_dataset, DatasetDict
import numpy as np
import random


# --- Run configuration ---------------------------------------------------
SEED = 42            # shared seed for every RNG seeded below and for the splits
SAMPLE_RATE = 16000  # not referenced by the cells visible in this notebook
BATCH_SIZE = 18      # passed to every DataLoader
NUM_WORKERS = 16     # DataLoader worker processes

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Seed the three RNG sources used here: Python, NumPy and PyTorch.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"Using device: {DEVICE}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
# Location of the dataset previously saved to disk (relative to this notebook).
DATASET_DIR = "../../data/datasets/1_augmented_spectogram_test"
dataset = load_from_disk(DATASET_DIR)

In [3]:
# Build the train / valid / test splits.
# Step 1: hold out 30% of the rows; the remaining 70% becomes the training set.
ds_split = dataset.train_test_split(
    test_size=0.3,
    seed=SEED,
    shuffle=True,
)

# Step 2: cut the held-out 30% in half -> validation and test halves.
test_and_valid = ds_split["test"].train_test_split(
    test_size=0.5,
    seed=SEED,
    shuffle=True,
)

split_by_name = {
    "train": ds_split["train"],
    "valid": test_and_valid["train"],
    "test": test_and_valid["test"],
}
ds = DatasetDict(split_by_name)

print("Dataset splits:", {name: split.shape for name, split in ds.items()})

Dataset splits: {'train': (4446, 2), 'valid': (953, 2), 'test': (953, 2)}


In [4]:
# -----------------------------
# Collate function: stacks batch items and zero-pads them to a common length (no Mel computation is done here)
# -----------------------------
# Define augmentations
# augmentations = torch.nn.Sequential(
#     torchaudio.transforms.Vol(gain=random.uniform(-5, 5)),
# )

def collate_fn(batch, n_bins=128):
    """Stack dataset rows into one zero-padded float batch.

    Args:
        batch: Sequence of rows; each row is indexed with ``"audio"``
            (array-like numeric data) and ``"label"`` (a number).
        n_bins: Size of the third dimension of the padded output. Defaults to
            128, the value that was previously hard-coded.

    Returns:
        A tuple ``(inputs, labels)`` where ``inputs`` is a float32 tensor of
        shape ``(len(batch), 1, n_bins, max_len)`` (``max_len`` is the largest
        last-dimension size in the batch; shorter items are right-padded with
        zeros) and ``labels`` is a float32 tensor of shape ``(len(batch), 1)``.
    """
    xs, ys = [], []

    for b in batch:
        # as_tensor converts lists/arrays and, unlike torch.tensor(), does not
        # copy (or emit a UserWarning) when the row already holds a tensor.
        waveform = torch.as_tensor(b["audio"], dtype=torch.float32)

        if waveform.ndim > 1:
            # Collapse the leading axis by averaging.
            # NOTE(review): for a 2-D (n_bins, time) input this averages the
            # bins away and the 1-D result is then broadcast to every row
            # below -- confirm the stored layout makes this the intended axis.
            waveform = waveform.mean(dim=0)

        xs.append(waveform)
        ys.append(b["label"])

    # Pad every item along its last dimension up to the longest item.
    max_len = max(x.shape[-1] for x in xs)
    xs_padded = torch.zeros((len(xs), 1, n_bins, max_len))
    for i, x in enumerate(xs):
        # A 1-D item broadcasts across all n_bins rows of its slot.
        xs_padded[i, 0, :, :x.shape[-1]] = x

    labels = torch.tensor(ys, dtype=torch.float32).unsqueeze(1)
    return xs_padded, labels
# Options shared by all three loaders; only the training loader shuffles.
_loader_options = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
)

train_loader = DataLoader(ds["train"], shuffle=True, **_loader_options)
valid_loader = DataLoader(ds["valid"], shuffle=False, **_loader_options)
test_loader = DataLoader(ds["test"], shuffle=False, **_loader_options)

In [5]:
# -----------------------------
# Simplified CNN
# -----------------------------
class SimpleAudioCNN(nn.Module):
    """Small three-stage 2-D CNN producing one score per input sample.

    ``features`` holds three Conv2d(3x3, padding=1) -> BatchNorm2d -> PReLU
    stages with 16, 32 and 64 output channels. The first two stages are each
    followed by a 2x2 max-pool; the last one by a global average pool to 1x1.
    ``fc`` maps the resulting 64 values through Linear(64, 64) -> PReLU ->
    Dropout(0.5) -> Linear(64, 1). No activation is applied to the output.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Return the Conv2d / BatchNorm2d / PReLU layers of one stage."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.PReLU(),
        ]

    def __init__(self):
        super().__init__()

        # Layers are created in the same order (and land on the same
        # Sequential indices) as a flat listing would give.
        layers = []
        layers += self._conv_stage(1, 16)
        layers.append(nn.MaxPool2d(kernel_size=2))
        layers += self._conv_stage(16, 32)
        layers.append(nn.MaxPool2d(kernel_size=2))
        layers += self._conv_stage(32, 64)
        layers.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
        self.features = nn.Sequential(*layers)

        head = [
            nn.Linear(64, 64),
            nn.PReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(64, 1),
        ]
        self.fc = nn.Sequential(*head)

    def forward(self, x):
        """Run the conv stack, flatten per sample, and apply the dense head."""
        pooled = self.features(x)
        # Keep the batch dimension, merge everything else (64 * 1 * 1 values).
        flat = pooled.flatten(1)
        return self.fc(flat)

# Instantiate the network, then move its parameters and buffers to DEVICE.
model = SimpleAudioCNN()
model = model.to(DEVICE)

In [6]:
# -----------------------------
# Loss, optimizer, metric
# -----------------------------
# Binary cross-entropy computed directly on the model's raw outputs.
criterion = nn.BCEWithLogitsLoss()

# Batch accuracy metric, kept on the same device as the model outputs.
metric_acc = torchmetrics.classification.BinaryAccuracy()
metric_acc = metric_acc.to(DEVICE)

optimizer = optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-4,
)

# NOTE(review): T_max=10, but the training cell sets EPOCHS = 5 and steps the
# scheduler once per epoch, so only the first half of the cosine cycle is
# used -- confirm this is intended.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=10,
)

In [7]:
# -----------------------------
# Training loop
# -----------------------------
EPOCHS = 5
best_val_acc = 0.0  # best validation accuracy seen so far (plain float)

for epoch in range(EPOCHS):
    # ---- Training pass ----
    model.train()
    train_loss, train_acc = 0.0, 0.0

    for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        # Loss and accuracy are per-batch means; weight them by batch size so
        # that dividing by the dataset size below yields a per-sample average.
        train_loss += loss.item() * x.size(0)
        # Pass probabilities, not raw logits: torchmetrics only applies a
        # sigmoid when some value lies outside [0, 1], so a batch whose logits
        # all fall inside [0, 1] would be thresholded at 0.5 as if they were
        # probabilities and logits in (0, 0.5] would be counted as negatives.
        batch_acc = metric_acc(torch.sigmoid(out.detach()), y.int())
        train_acc += batch_acc.item() * x.size(0)

    # One scheduler step per epoch.
    scheduler.step()
    train_loss /= len(ds["train"])
    train_acc /= len(ds["train"])

    # ---- Validation pass (no gradients, eval-mode layers) ----
    model.eval()
    val_loss, val_acc = 0.0, 0.0
    with torch.no_grad():
        for x, y in tqdm(valid_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Valid]"):
            x, y = x.to(DEVICE), y.to(DEVICE)
            out = model(x)
            loss = criterion(out, y)
            val_loss += loss.item() * x.size(0)
            # Same explicit sigmoid as in training, for the same reason.
            batch_acc = metric_acc(torch.sigmoid(out), y.int())
            val_acc += batch_acc.item() * x.size(0)

    val_loss /= len(ds["valid"])
    val_acc /= len(ds["valid"])

    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | "
          f"Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    # Checkpoint only on a strict improvement in validation accuracy.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_simple_cnn.pt")
        print("✅ Saved new best model!")

Epoch 1/5 [Train]: 100%|██████████| 247/247 [00:41<00:00,  5.94it/s]
Epoch 1/5 [Valid]: 100%|██████████| 53/53 [00:04<00:00, 12.51it/s]


Epoch 1/5 | Train Loss: 0.3913 | Train Acc: 0.8599 | Val Loss: 0.3207 | Val Acc: 0.8751
✅ Saved new best model!


Epoch 2/5 [Train]: 100%|██████████| 247/247 [00:40<00:00,  6.07it/s]
Epoch 2/5 [Valid]: 100%|██████████| 53/53 [00:04<00:00, 12.57it/s]


Epoch 2/5 | Train Loss: 0.3253 | Train Acc: 0.8754 | Val Loss: 0.3675 | Val Acc: 0.8751


Epoch 3/5 [Train]: 100%|██████████| 247/247 [00:40<00:00,  6.09it/s]
Epoch 3/5 [Valid]: 100%|██████████| 53/53 [00:04<00:00, 12.47it/s]


Epoch 3/5 | Train Loss: 0.3184 | Train Acc: 0.8797 | Val Loss: 0.2925 | Val Acc: 0.8783
✅ Saved new best model!


Epoch 4/5 [Train]: 100%|██████████| 247/247 [00:40<00:00,  6.08it/s]
Epoch 4/5 [Valid]: 100%|██████████| 53/53 [00:04<00:00, 12.56it/s]


Epoch 4/5 | Train Loss: 0.3007 | Train Acc: 0.8828 | Val Loss: 0.2970 | Val Acc: 0.8825
✅ Saved new best model!


Epoch 5/5 [Train]: 100%|██████████| 247/247 [00:40<00:00,  6.08it/s]
Epoch 5/5 [Valid]: 100%|██████████| 53/53 [00:04<00:00, 12.43it/s]

Epoch 5/5 | Train Loss: 0.3023 | Train Acc: 0.8835 | Val Loss: 0.2823 | Val Acc: 0.8825





In [11]:
# Ask PyTorch's CUDA allocator to release its cached, unused memory.
# No autograd work happens here, so no gradient context is required.
if torch.cuda.is_available():
    torch.cuda.empty_cache()