In [10]:
# byol_audio.py
import os
import math
import time
import random
from typing import List, Tuple

import numpy as np
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision import models
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score

# -------------------------
# Configuration (edit as needed)
# -------------------------
SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 96
NUM_WORKERS = 8
SAMPLE_RATE = 16000
N_MELS = 128
N_FFT = 1024
HOP_LENGTH = 256
EPOCHS_BYOL = 5
EPOCHS_PROBE = 2
LR_BYOL = 3e-4
LR_PROBE = 1e-3
WEIGHT_DECAY = 1e-4
MODEL_DIR = "./ssl_checkpoints"
os.makedirs(MODEL_DIR, exist_ok=True)
BACKBONE = "resnet50"  # or "resnet18" to speed up testing

# reproducibility
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


In [2]:
# -------------------------
# Utilities: Mel Preproc + Augmentations
# -------------------------
def waveform_to_logmel(waveform: torch.Tensor,
                       sample_rate=SAMPLE_RATE,
                       n_fft=N_FFT,
                       hop_length=HOP_LENGTH,
                       n_mels=N_MELS,
                       fmin=20,
                       fmax=8000) -> torch.Tensor:
    """
    waveform: 1D torch tensor (N,)
    returns: torch.FloatTensor shape [3, n_mels, time]
    """
    wav = waveform.cpu().numpy().astype(np.float32)
    mel = librosa.feature.melspectrogram(
        y=wav,
        sr=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax,
        power=2.0
    )
    mel_db = librosa.power_to_db(mel, ref=np.max)
    # normalize per-sample
    mel_db = (mel_db - mel_db.mean()) / (mel_db.std() + 1e-6)
    mel_db = torch.from_numpy(mel_db).float().unsqueeze(0)  # [1, n_mels, T]
    mel_db = mel_db.repeat(3, 1, 1)  # 3 channels for ResNet
    return mel_db

# Simple waveform-level augmentations
def add_noise(wav: np.ndarray, snr_db: float = 20.0) -> np.ndarray:
    if snr_db <= 0:
        return wav
    rms = np.sqrt(np.mean(wav ** 2) + 1e-9)
    snr = 10 ** (snr_db / 20.0)
    noise_rms = rms / snr
    noise = np.random.normal(0, noise_rms, size=wav.shape)
    return wav + noise

def random_gain(wav: np.ndarray, min_gain_db=-6.0, max_gain_db=6.0) -> np.ndarray:
    g = np.random.uniform(min_gain_db, max_gain_db)
    factor = 10 ** (g / 20)
    return wav * factor

def random_time_shift(wav: np.ndarray, max_shift_seconds=0.05, sr=SAMPLE_RATE) -> np.ndarray:
    max_shift = int(max_shift_seconds * sr)
    if max_shift <= 0:
        return wav
    shift = np.random.randint(-max_shift, max_shift)
    if shift == 0:
        return wav
    return np.roll(wav, shift)

# SpecAugment-style (on Log-Mel)
def spec_augment(mel: torch.Tensor,
                 time_mask_max=10,
                 freq_mask_max=15,
                 n_time_masks=1,
                 n_freq_masks=1) -> torch.Tensor:
    """
    mel: [C, F, T] but we apply masks to last two dims; channels are identical anyway.
    Returns same shape.
    """
    _, F, T = mel.shape
    mel = mel.clone()
    for _ in range(n_freq_masks):
        f = np.random.randint(0, freq_mask_max + 1)
        if f == 0:
            continue
        f0 = np.random.randint(0, max(1, F - f))
        mel[:, f0:f0 + f, :] = 0.0
    for _ in range(n_time_masks):
        t = np.random.randint(0, time_mask_max + 1)
        if t == 0:
            continue
        t0 = np.random.randint(0, max(1, T - t))
        mel[:, :, t0:t0 + t] = 0.0
    return mel

# Augmentation pipeline returning two views
def make_two_views(waveform: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    waveform: torch tensor (N,)
    Returns two log-mel tensors: view1, view2 shape [3, n_mels, T]
    We'll apply different random augmentations to waveform, compute log-mel, then specaugment.
    """
    wav_np = waveform.cpu().numpy().astype(np.float32)

    # VIEW 1 (mild)
    v1 = wav_np.copy()
    v1 = add_noise(v1, snr_db=np.random.uniform(15, 30))
    v1 = random_gain(v1, -3.0, 3.0)
    v1 = random_time_shift(v1, max_shift_seconds=0.02)
    mel1 = librosa.feature.melspectrogram(y=v1, sr=SAMPLE_RATE, n_fft=N_FFT,
                                          hop_length=HOP_LENGTH, n_mels=N_MELS,
                                          fmin=20, fmax=8000, power=2.0)
    mel1 = librosa.power_to_db(mel1, ref=np.max)
    mel1 = (mel1 - mel1.mean()) / (mel1.std() + 1e-6)
    mel1 = torch.from_numpy(mel1).float().unsqueeze(0).repeat(3, 1, 1)
    mel1 = spec_augment(mel1, time_mask_max=8, freq_mask_max=10, n_time_masks=1, n_freq_masks=1)

    # VIEW 2 (stronger)
    v2 = wav_np.copy()
    v2 = add_noise(v2, snr_db=np.random.uniform(10, 25))
    v2 = random_gain(v2, -6.0, 6.0)
    v2 = random_time_shift(v2, max_shift_seconds=0.04)
    mel2 = librosa.feature.melspectrogram(y=v2, sr=SAMPLE_RATE, n_fft=N_FFT,
                                          hop_length=HOP_LENGTH, n_mels=N_MELS,
                                          fmin=20, fmax=8000, power=2.0)
    mel2 = librosa.power_to_db(mel2, ref=np.max)
    mel2 = (mel2 - mel2.mean()) / (mel2.std() + 1e-6)
    mel2 = torch.from_numpy(mel2).float().unsqueeze(0).repeat(3, 1, 1)
    mel2 = spec_augment(mel2, time_mask_max=12, freq_mask_max=15, n_time_masks=2, n_freq_masks=2)

    return mel1, mel2

# -------------------------
# Collate function: returns two augmented views + label (label optional)
# -------------------------
def collate_fn_byol(batch):
    """
    batch: list of dataset items where item["audio"] is a torch tensor shape (8000,)
    returns: dict with
       view1: tensor [B, 3, F, T]
       view2: tensor [B, 3, F, T]
       labels: tensor [B] (optional, used later for linear probe)
    """
    view1_list = []
    view2_list = []
    labels = []

    for item in batch:
        wav = item["audio"]  # HF dataset gives a torch tensor already
        v1, v2 = make_two_views(wav)
        view1_list.append(v1)
        view2_list.append(v2)
        labels.append(item["label"])

    # stack (they should all have same time dimension since audio fixed length)
    view1 = torch.stack(view1_list, dim=0)
    view2 = torch.stack(view2_list, dim=0)
    labels = torch.tensor(labels, dtype=torch.long)

    return {"view1": view1, "view2": view2, "label": labels}

In [3]:

# -------------------------
# BYOL model components
# -------------------------
class MLPHead(nn.Module):
    def __init__(self, in_dim, hidden_dim=2048, out_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        # x: [B, in_dim]
        return self.net(x)

class BYOLAudio(nn.Module):
    def __init__(self, backbone_name="resnet50", projector_hidden=2048, projector_out=256, pretrained=False):
        super().__init__()
        # load backbone and adapt first conv to 3 channels (we already have 3 channels)
        if backbone_name == "resnet50":
            backbone = models.resnet50(weights=None if not pretrained else "IMAGENET1K_V1")
        elif backbone_name == "resnet18":
            backbone = models.resnet18(weights=None if not pretrained else "IMAGENET1K_V1")
        else:
            raise ValueError("unsupported backbone")

        # remove classifier
        num_features = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.backbone = backbone
        # projector
        self.projector = MLPHead(num_features, hidden_dim=projector_hidden, out_dim=projector_out)
        # predictor
        self.predictor = MLPHead(projector_out, hidden_dim=projector_hidden // 2, out_dim=projector_out)

    def forward_backbone(self, x):
        # x: [B, 3, F, T]
        out = self.backbone(x)  # [B, num_features]
        return out

    def online_forward(self, x):
        feat = self.forward_backbone(x)
        proj = self.projector(feat)
        pred = self.predictor(proj)
        return feat, proj, pred

    def target_forward(self, x):
        with torch.no_grad():
            feat = self.forward_backbone(x)
            proj = self.projector(feat)
        return feat, proj


In [4]:
# -------------------------
# Helper: create target network as copy of online network
# -------------------------
def initialize_target_network(online_net: BYOLAudio) -> BYOLAudio:
    target = BYOLAudio(backbone_name=BACKBONE)
    target.load_state_dict(online_net.state_dict())
    for p in target.parameters():
        p.requires_grad = False
    return target

# momentum update
@torch.no_grad()
def momentum_update(online: nn.Module, target: nn.Module, m: float):
    # target = m * target + (1 - m) * online
    for param_o, param_t in zip(online.parameters(), target.parameters()):
        param_t.data = param_t.data * m + param_o.data * (1.0 - m)

# loss: negative cosine similarity (or l2 between normalized vectors)
def byol_loss_fn(p, z):
    # p: prediction from online net [B, D]
    # z: projection from target net [B, D] (stop grad)
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return 2 - 2 * (p * z).sum(dim=1)  # per-sample loss (2 - 2 cos) -> mean later


In [5]:
# -------------------------
# Training BYOL
# -------------------------
def train_byol(dataset, epochs=EPOCHS_BYOL, save_every=5):
    # prepare dataloader
    train_loader = DataLoader(dataset["train"], batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=NUM_WORKERS,
                              collate_fn=collate_fn_byol, pin_memory=True)

    online_net = BYOLAudio(backbone_name=BACKBONE).to(DEVICE)
    target_net = initialize_target_network(online_net).to(DEVICE)

    optimizer = optim.AdamW(online_net.parameters(), lr=LR_BYOL, weight_decay=WEIGHT_DECAY)

    # momentum schedule params (cosine schedule)
    base_m = 0.996
    final_m = 1.0
    total_steps = epochs * len(train_loader)
    global_step = 0

    best_loss = float("inf")
    for epoch in range(epochs):
        online_net.train()
        running_loss = 0.0
        pbar = tqdm(train_loader, desc=f"BYOL Epoch {epoch+1}/{epochs}", leave=False)
        for batch in pbar:
            v1 = batch["view1"].to(DEVICE)  # [B, 3, F, T]
            v2 = batch["view2"].to(DEVICE)

            # ONLINE forward: view1 -> pred1, view2 -> pred2
            _, proj1, pred1 = online_net.online_forward(v1)
            _, proj2, pred2 = online_net.online_forward(v2)

            # TARGET forward (no grad): view1 -> z1, view2 -> z2
            with torch.no_grad():
                _, z1 = target_net.target_forward(v1)
                _, z2 = target_net.target_forward(v2)

            # BYOL loss: compare pred1 vs z2 and pred2 vs z1
            loss1 = byol_loss_fn(pred1, z2).mean()
            loss2 = byol_loss_fn(pred2, z1).mean()
            loss = (loss1 + loss2) * 0.5

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update target network momentum
            global_step += 1
            # cosine schedule from base_m -> final_m
            momentum = 1.0 - (1.0 - base_m) * (math.cos(math.pi * global_step / total_steps) + 1) / 2
            momentum_update(online_net, target_net, momentum)

            running_loss += loss.item()
            pbar.set_postfix({"loss": f"{running_loss / (global_step % len(train_loader) + 1):.4f}",
                              "momentum": f"{momentum:.4f}"})

        epoch_loss = running_loss / len(train_loader)
        print(f"BYOL Epoch {epoch+1}/{epochs} loss {epoch_loss:.4f}")

        # save checkpoint
        if (epoch + 1) % save_every == 0 or epoch == epochs - 1:
            ckpt_path = os.path.join(MODEL_DIR, f"byol_epoch{epoch+1}.pth")
            torch.save({"online_state": online_net.state_dict(),
                        "target_state": target_net.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "epoch": epoch+1}, ckpt_path)
            print(f"Saved {ckpt_path}")

    # final save
    final_path = os.path.join(MODEL_DIR, "byol_final.pth")
    torch.save({"online_state": online_net.state_dict()}, final_path)
    print("BYOL training finished. Model saved to", final_path)
    return online_net

In [6]:
# -------------------------
# Linear probe
# -------------------------
class LinearProbe(nn.Module):
    def __init__(self, backbone: BYOLAudio, out_dim=1):
        super().__init__()
        # backbone.backbone is the actual CNN
        self.backbone = backbone.backbone
        # freeze backbone
        for p in self.backbone.parameters():
            p.requires_grad = False
        # classifier
        num_features = self.backbone.fc.in_features if hasattr(self.backbone, "fc") else None
        # if we replaced fc with Identity earlier, note num_features is known from BYOLAudio internals
        # We'll infer by passing a dummy input during initialization if needed
        self.classifier = nn.Linear(backbone.projector.net[-1].out_features if False else backbone.projector.net[-1].out_features, out_dim)
        # but simpler: we will compute features size at runtime in train_probe

    def forward(self, x):
        feat = self.backbone(x)  # shape [B, feat_dim]
        out = self.classifier(feat)
        return out

def extract_feature_dim(backbone_model: BYOLAudio):
    # helper: run a dummy input through backbone to get feature dim
    dummy = torch.randn(1, 3, N_MELS, (SAMPLE_RATE // HOP_LENGTH) + 1).to(DEVICE)
    with torch.no_grad():
        feat = backbone_model.forward_backbone(dummy)
    return feat.shape[1]

def train_linear_probe(online_net: BYOLAudio, dataset, epochs=EPOCHS_PROBE):
    # build dataloaders (no augmentation, just mel preprocessing)
    def collate_probe(batch):
        xs = []
        ys = []
        for item in batch:
            wav = item["audio"]
            mel = waveform_to_logmel(wav)  # deterministic view
            xs.append(mel)
            ys.append(item["label"])
        x = torch.stack(xs, dim=0)
        y = torch.tensor(ys, dtype=torch.float32)
        return x, y

    train_loader = DataLoader(dataset["train"], batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, collate_fn=collate_probe, pin_memory=True)
    val_loader = DataLoader(dataset["val"], batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, collate_fn=collate_probe, pin_memory=True)

    feat_dim = extract_feature_dim(online_net)  # get dim
    # freeze backbone
    backbone_cnn = online_net.backbone
    for p in backbone_cnn.parameters():
        p.requires_grad = False

    # classifier head for binary
    clf = nn.Linear(feat_dim, 1).to(DEVICE)
    optimizer = optim.Adam(clf.parameters(), lr=LR_PROBE, weight_decay=WEIGHT_DECAY)
    criterion = nn.BCEWithLogitsLoss()

    best_auc = 0.0
    for epoch in range(epochs):
        # train
        clf.train()
        running_loss = 0.0
        preds_all = []
        labels_all = []
        pbar = tqdm(train_loader, desc=f"Probe Train {epoch+1}/{epochs}", leave=False)
        for x, y in pbar:
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            with torch.no_grad():
                feats = backbone_cnn(x)

            logits = clf(feats).squeeze(1)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            probs = torch.sigmoid(logits).detach().cpu().numpy()
            preds_all.extend(probs.tolist())
            labels_all.extend(y.cpu().numpy().tolist())

            # partial metrics
            if len(labels_all) >= 1:
                auc = roc_auc_score(labels_all, preds_all) if len(set(labels_all)) > 1 else 0.5
                acc = accuracy_score(np.array(labels_all) > 0.5, np.array(preds_all) > 0.5)
                pbar.set_postfix({"loss": f"{running_loss / (len(preds_all) // BATCH_SIZE + 1):.4f}",
                                  "auc": f"{auc:.4f}", "acc": f"{acc:.4f}"})

        # validate
        clf.eval()
        v_preds = []
        v_labels = []
        v_loss = 0.0
        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(DEVICE)
                y = y.to(DEVICE)
                feats = backbone_cnn(x)
                logits = clf(feats).squeeze(1)
                loss = criterion(logits, y)
                v_loss += loss.item()
                probs = torch.sigmoid(logits).cpu().numpy()
                v_preds.extend(probs.tolist())
                v_labels.extend(y.cpu().numpy().tolist())

        auc_val = roc_auc_score(v_labels, v_preds) if len(set(v_labels)) > 1 else 0.5
        acc_val = accuracy_score(np.array(v_labels) > 0.5, np.array(v_preds) > 0.5)
        print(f"Probe Epoch {epoch+1}/{epochs}  Train loss: {running_loss/len(train_loader):.4f}  Val loss: {v_loss/len(val_loader):.4f}  Val AUC: {auc_val:.4f}  Val Acc: {acc_val:.4f}")

        # save best
        if auc_val > best_auc:
            best_auc = auc_val
            torch.save({"clf_state": clf.state_dict(), "epoch": epoch+1}, os.path.join(MODEL_DIR, "best_probe.pth"))
            print("Saved best linear probe checkpoint (AUC improved)")

    print("Linear probe training finished. Best Val AUC:", best_auc)


In [8]:
from datasets import load_dataset
dataset = load_dataset("Hibou-Foundation/big_ds_4_raw_wav_balanced")
dataset = dataset.with_format("torch", columns=["audio", "label"])

Downloading data: 100%|██████████| 23/23 [00:00<00:00, 4191.21files/s]
Generating train split: 100%|██████████| 352116/352116 [00:44<00:00, 7950.49 examples/s] 
Generating val split: 100%|██████████| 43580/43580 [00:04<00:00, 9607.53 examples/s] 
Generating test split: 100%|██████████| 43478/43478 [00:05<00:00, 8691.22 examples/s] 


In [11]:
# Train BYOL (self-supervised)
ssl_model = train_byol(dataset, epochs=EPOCHS_BYOL, save_every=5)

# Run linear probe
train_linear_probe(ssl_model, dataset, epochs=EPOCHS_PROBE)

                                                                                                   

BYOL Epoch 1/5 loss 0.1178


                                                                                                   

BYOL Epoch 2/5 loss 0.0344


                                                                                                   

BYOL Epoch 3/5 loss 0.0296


                                                                                                   

BYOL Epoch 4/5 loss 0.0290


                                                                                                   

BYOL Epoch 5/5 loss 0.0279
Saved ./ssl_checkpoints/byol_epoch5.pth
BYOL training finished. Model saved to ./ssl_checkpoints/byol_final.pth


                                                                                                          

Probe Epoch 1/2  Train loss: 0.1298  Val loss: 0.1276  Val AUC: 0.9891  Val Acc: 0.9580
Saved best linear probe checkpoint (AUC improved)


                                                                                                          

Probe Epoch 2/2  Train loss: 0.0974  Val loss: 0.1029  Val AUC: 0.9906  Val Acc: 0.9683
Saved best linear probe checkpoint (AUC improved)
Linear probe training finished. Best Val AUC: 0.9906252831430898


In [15]:
sampling_rate = 16000
ds_test_online = load_dataset("Usernameeeeee/drone_test", split="test")
ds_test_online2 = load_dataset("Usernameeeeee/drone_test_2", split="test")

def swap_labels(example):
    example["label"] = 1 - example["label"]   # flips 0 ↔ 1
    return example

ds_test_online = ds_test_online.map(swap_labels)
ds_test_online2 = ds_test_online2.map(swap_labels)

Generating test split: 100%|██████████| 893/893 [00:00<00:00, 6283.64 examples/s]
Generating test split: 100%|██████████| 2805/2805 [00:00<00:00, 3591.47 examples/s]
Map: 100%|██████████| 893/893 [00:00<00:00, 8501.19 examples/s]
Map: 100%|██████████| 2805/2805 [00:00<00:00, 9393.58 examples/s] 


In [13]:
import librosa
import numpy as np
import torch

def convert_to_mel_spectrogram(data):
    mel = librosa.feature.melspectrogram(
        y=data,
        sr=sampling_rate,
        n_fft=1025,
        hop_length=256,
        n_mels=128,
        fmin=20,
        fmax=8000,
        power=2.0
    )

    mel_db = librosa.power_to_db(mel, ref=np.max)
    mel_db = (mel_db - mel_db.mean()) / (mel_db.std() + 1e-6)

    mel_db = torch.tensor(mel_db).float().unsqueeze(0)  # [1, 128, T]

    return mel_db

In [14]:
def collate_fn_librosa_test(batch):
    inputs, labels = [], []

    for item in batch:
        wav = item["audio"]["array"]
        mel = convert_to_mel_spectrogram(wav)   # [1, 128, T]
        mel = mel.repeat(3, 1, 1)               # → [3, 128, T]
        inputs.append(mel)
        labels.append(item["label"])

    # Pad to max T in batch
    max_T = max(x.shape[-1] for x in inputs)

    padded = []
    for x in inputs:
        pad_T = max_T - x.shape[-1]
        padded.append(
            torch.nn.functional.pad(x, (0, pad_T))  # pad last dimension only
        )

    inputs = torch.stack(padded)   # [B, 3, 128, max_T]
    labels = torch.tensor(labels).float()

    return inputs, labels

In [None]:
test_loader_1 = DataLoader(
    ds_test_online,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=1,
    collate_fn=collate_fn_librosa_test,
)

test_loader_2 = DataLoader(
    ds_test_online2,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=1,
    collate_fn=collate_fn_librosa_test,
)

In [16]:
from tqdm import tqdm

def run_test(model, loader):
    y_true, y_pred = [], []
    model.eval()

    with torch.no_grad():
        for x, y in tqdm(loader, desc="Testing"):
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            logits = model(x).squeeze(1)
            probs = torch.sigmoid(logits)
            preds = (probs > 0.5).float()

            y_true.extend(y.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    return y_true, y_pred

In [17]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix

def compute_metrics(y_true, y_pred):
    y_true = np.array(y_true).astype(int)
    y_pred = np.array(y_pred).astype(int)

    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    cm = confusion_matrix(y_true, y_pred)

    try:
        auc = roc_auc_score(y_true, y_pred)
    except:
        auc = float("nan")

    return {
        "accuracy": acc,
        "precision": prec,
        "recall": rec,
        "f1": f1,
        "auc": auc,
        "confusion_matrix": cm,
    }


In [19]:
import seaborn as sns
import matplotlib.pyplot as plt

CLASS_NAMES = ["no_drone", "drone"]

def plot_confusion(cm, title):
    plt.figure(figsize=(4,4))
    sns.heatmap(
        cm, annot=True, fmt="d", cmap="Blues",
        xticklabels=CLASS_NAMES,
        yticklabels=CLASS_NAMES
    )
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(title)
    plt.show()



In [20]:
import torch
from torch import nn
import torchvision.models as models

class ResNet50Audio(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = models.resnet50(weights=None)   # No pretrained weights? → set pretrained=True if needed

        # Modify first conv to accept 3-channel mel (shape [3,128,T])
        self.backbone.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False
        )

        # Replace classifier head
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(in_features, 1)

    def forward(self, x):
        return self.backbone(x)


In [21]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = ResNet50Audio().to(DEVICE)

checkpoint_path = "./models_resnet.pt"

state = torch.load(checkpoint_path, map_location=DEVICE)
model.load_state_dict(state)

model.eval()

print("Loaded model from:", checkpoint_path)


FileNotFoundError: [Errno 2] No such file or directory: './models_resnet.pt'