Transfer learning with AudioSet-pretrained YAMNet weights and a new 7-class classifier head

In [None]:
import torch
import torch.nn as nn
import torchaudio
import matplotlib.pyplot as plt
from srcv2full.model import YAMNet
from srcv2full.feature_extraction import WaveformToMelSpec
import srcv2full.params as params

# -----------------------------
# 1. Load pretrained model (AudioSet weights)
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = YAMNet()
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
current_state_dict = torch.load("checkpoints/yamnet_audioset_converted.pth", map_location=device)

# Rename checkpoint keys of the form "layer.<i>.<rest>" to "layer_<i+1>.<rest>";
# every other key is copied through unchanged.
new_state_dict = {}
for k, v in current_state_dict.items():
    if k.startswith("layer."):
        parts = k.split(".")
        layer_idx = int(parts[1]) + 1
        new_key = f"layer_{layer_idx}." + ".".join(parts[2:])
        new_state_dict[new_key] = v
    else:
        new_state_dict[k] = v

# strict=False only tolerates missing/unexpected key *names* (a shape mismatch
# still raises). Keep the returned report and print it, so a wrong key remap
# cannot silently leave the backbone at its initial weights.
load_result = model.load_state_dict(new_state_dict, strict=False)
if load_result.missing_keys:
    print(f"Warning: {len(load_result.missing_keys)} model parameters were not found "
          f"in the checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Warning: {len(load_result.unexpected_keys)} checkpoint keys were not used "
          f"by the model: {load_result.unexpected_keys}")

# Replace classifier for ESC50Artifact (7 classes). The input width is read
# from the existing head instead of being hard-coded.
# The new head is randomly initialised: predictions below are only meaningful
# after it has been trained.
model.classifier = nn.Linear(model.classifier.in_features, 7, bias=True)

# Freeze all layers except classifier
for name, param in model.named_parameters():
    if "classifier" not in name:
        param.requires_grad = False

model.to(device)
model.eval()

# -----------------------------
# 2. Load & preprocess an audio file
# -----------------------------
audio_path = "ESC50Artifact/audio/1-7973-A-7_hiss.wav"  # replace with your audio file
waveform, sr = torchaudio.load(audio_path)

# Bring the clip to the sample rate configured in params when it differs
if sr != params.SAMPLE_RATE:
    waveform = torchaudio.functional.resample(
        waveform, orig_freq=sr, new_freq=params.SAMPLE_RATE
    )

# Collapse multi-channel audio to a single channel by averaging over dim 0
if waveform.size(0) > 1:
    waveform = torch.mean(waveform, dim=0, keepdim=True)

# Move the audio to the same device as the model
waveform = waveform.to(device)

# Turn the waveform into model-input chunks plus the mel spectrogram used for plotting
waveform_to_mel = WaveformToMelSpec(device=device)
x_chunks, mel_spectrogram = waveform_to_mel(waveform, params.SAMPLE_RATE)
x_chunks = x_chunks.to(device)

# -----------------------------
# 3. Run forward pass (inference)
# -----------------------------
# With zero chunks the mean over dim 0 below would be NaN and the reported
# class/confidence meaningless, so fail early with an explicit message.
if x_chunks.shape[0] == 0:
    raise ValueError(
        f"No input chunks were produced for {audio_path!r}; "
        "cannot run inference on an empty batch."
    )

# Inference only: no gradients are needed.
with torch.no_grad():
    logits = model(x_chunks)  # [num_chunks, 7]
    probs = torch.softmax(logits, dim=-1)

# Average the per-chunk class probabilities into one clip-level distribution
avg_probs = probs.mean(dim=0)
pred_class = torch.argmax(avg_probs).item()
confidence = avg_probs[pred_class].item()

print(f"Predicted class: {pred_class}, confidence: {confidence:.3f}")

# -----------------------------
# 4. Visualization
# -----------------------------
# --- Waveform ---
fig_wave, ax_wave = plt.subplots(figsize=(12, 4))
ax_wave.plot(waveform.squeeze().cpu().numpy())
ax_wave.set_title(f"Waveform\nPredicted class: {pred_class}, Confidence: {confidence:.3f}")
ax_wave.set_xlabel("Sample")
ax_wave.set_ylabel("Amplitude")
plt.show()

# --- Mel spectrogram ---
# matplotlib cannot consume a CUDA tensor, so move the spectrogram to host
# memory first (it may live on `device`, since the extractor was built with it).
mel_image = mel_spectrogram
if torch.is_tensor(mel_image):
    mel_image = mel_image.detach().cpu().numpy()
mel_image = mel_image.squeeze()

# No `extent` is passed: the rows are mel bins, not linear Hz, and the content
# cannot reach SAMPLE_RATE Hz, so the axes show plain bin/frame indices.
# NOTE(review): assumes the squeezed array is (mel bins, frames) - confirm
# against WaveformToMelSpec; swap the axis labels if it is (frames, mel bins).
fig_mel, ax_mel = plt.subplots(figsize=(12, 4))
mel_plot = ax_mel.imshow(mel_image, aspect='auto', origin='lower', cmap='seismic')
ax_mel.set_title("Mel Spectrogram")
ax_mel.set_xlabel("Time (frames)")
ax_mel.set_ylabel("Mel bin")
fig_mel.colorbar(mel_plot, ax=ax_mel)
plt.show()

# --- Class probabilities ---
# The number of bars follows the model output instead of a hard-coded 7.
class_probs = avg_probs.cpu().numpy()
fig_bar, ax_bar = plt.subplots(figsize=(8, 4))
ax_bar.bar(range(len(class_probs)), class_probs)
ax_bar.set_title("Prediction Probabilities")
ax_bar.set_xlabel(f"Class ID (0–{len(class_probs) - 1})")
ax_bar.set_ylabel("Probability")
plt.show()


Fine-tuning the classifier head on ESC50Artifact and evaluating on a held-out validation split

In [1]:
import torch
from torch.utils.data import DataLoader, random_split
import srcv2full.params as params
from srcv2full.data import ESC50ArtifactData
from srcv2full.model import YAMNet
from srcv2full.engine import YAMNetEngine
import logging
import os

# -----------------------------
# Logging
# -----------------------------
# Root logger at INFO level; this notebook logs under the "YAMNetTraining" name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("YAMNetTraining")

# -----------------------------
# Compute device
# -----------------------------
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# -----------------------------
# Load pretrained model
# -----------------------------
model = YAMNet()
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
checkpoint = torch.load("checkpoints/yamnet_audioset_converted.pth", map_location=device)

# Rename checkpoint keys of the form "layer.<i>.<rest>" to "layer_<i+1>.<rest>";
# every other key is copied through unchanged.
new_state_dict = {}
for k, v in checkpoint.items():
    if k.startswith("layer."):
        parts = k.split(".")
        layer_idx = int(parts[1]) + 1
        new_key = f"layer_{layer_idx}." + ".".join(parts[2:])
        new_state_dict[new_key] = v
    else:
        new_state_dict[k] = v

# strict=False only tolerates missing/unexpected key *names* (a shape mismatch
# still raises). Log the returned report, so a wrong key remap cannot silently
# leave the backbone at its initial weights before fine-tuning.
load_result = model.load_state_dict(new_state_dict, strict=False)
if load_result.missing_keys:
    logger.warning(
        "%d model parameters were not found in the checkpoint: %s",
        len(load_result.missing_keys), load_result.missing_keys,
    )
if load_result.unexpected_keys:
    logger.warning(
        "%d checkpoint keys were not used by the model: %s",
        len(load_result.unexpected_keys), load_result.unexpected_keys,
    )

# Replace classifier for 7 classes (input width taken from the existing head)
model.classifier = torch.nn.Linear(model.classifier.in_features, 7)

# Freeze backbone: only parameters whose name contains "classifier" stay trainable.
# NOTE(review): freezing parameters does not stop normalisation layers from
# updating running statistics if the engine puts the model in train mode - confirm.
for name, param in model.named_parameters():
    if "classifier" not in name:
        param.requires_grad = False

# -----------------------------
# Dataset and DataLoader
# -----------------------------
data_dir = "ESC50Artifact/"
full_dataset = ESC50ArtifactData(data_dir)

# Split 80/20 train/val
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

# Seed the split so the same samples land in train/val on every run. Without
# this the validation set changes between runs, which makes validation metrics
# and "best" checkpoints non-comparable.
split_seed = 42
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Identity collate: each batch is handed on as the plain list of dataset items.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=lambda x: x)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=lambda x: x)

# -----------------------------
# Training configuration
# -----------------------------
checkpoint_path = "checkpoints/yamnet_finetune_esc50artifact.pth"
num_epochs = 10
tt_chunk_size = params.CHUNK_SIZE

# -----------------------------
# Engine
# -----------------------------
# The engine receives the model, the chunk size from params and the notebook logger.
engine = YAMNetEngine(
    model=model,
    tt_chunk_size=tt_chunk_size,
    logger=logger,
)

# -----------------------------
# Train
# -----------------------------
# Fine-tune for `num_epochs` epochs with 7 output labels, passing both loaders
# and the checkpoint path to the engine.
engine.train_yamnet(
    train_loader=train_loader,
    val_loader=val_loader,
    checkpoint_path=checkpoint_path,
    num_labels=7,
    num_epochs=num_epochs,
)

logger.info("Fine-tuning complete!")


INFO:YAMNetTraining:Using CUDA
INFO:YAMNetTraining:NVIDIA GeForce RTX 3050 Laptop GPU
INFO:YAMNetTraining:Started training
Epoch 1/10 [Train]: 100%|██████████| 2800/2800 [01:47<00:00, 26.01it/s, loss=2.37]
INFO:YAMNetTraining:Epoch 1/10, Train Loss: 1.9768, Train Acc: 0.0333
Epoch 1/10 [Val]: 100%|██████████| 700/700 [00:22<00:00, 31.04it/s, loss=2.32]  
INFO:YAMNetTraining:Epoch 1/10, Val Loss: 2.9743, Val Acc: 0.0108
INFO:YAMNetTraining:New best model saved with val_acc: 0.0108
Epoch 2/10 [Train]: 100%|██████████| 2800/2800 [01:48<00:00, 25.72it/s, loss=1.81]
INFO:YAMNetTraining:Epoch 2/10, Train Loss: 1.9760, Train Acc: 0.0342
Epoch 2/10 [Val]: 100%|██████████| 700/700 [00:23<00:00, 29.72it/s, loss=2.32] 
INFO:YAMNetTraining:Epoch 2/10, Val Loss: 2.1432, Val Acc: 0.0081
Epoch 3/10 [Train]: 100%|██████████| 2800/2800 [01:52<00:00, 24.97it/s, loss=1.69]
INFO:YAMNetTraining:Epoch 3/10, Train Loss: 1.9732, Train Acc: 0.0361
Epoch 3/10 [Val]: 100%|██████████| 700/700 [00:21<00:00, 32.22i