Transfer learning using AudioSet weights and a 7-class classifier head

In [None]:
import torch
import torch.nn as nn
import torchaudio
import matplotlib.pyplot as plt
from srcv2full.model import YAMNet
from srcv2full.feature_extraction import WaveformToMelSpec
import srcv2full.params as params

# -----------------------------
# 1. Load pretrained model (AudioSet weights)
# -----------------------------
# Builds a YAMNet, loads the converted AudioSet checkpoint into it, swaps the
# classifier for a 7-output linear head and freezes every other parameter.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = YAMNet()
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
current_state_dict = torch.load("checkpoints/yamnet_audioset_converted.pth", map_location=device)

# Fix state_dict keys: entries named "layer.<i>.<rest>" are renamed to
# "layer_<i+1>.<rest>"; every other key is copied through unchanged.
new_state_dict = {}
for k, v in current_state_dict.items():
    if k.startswith("layer."):
        parts = k.split(".")
        layer_idx = int(parts[1]) + 1
        new_key = f"layer_{layer_idx}." + ".".join(parts[2:])
        new_state_dict[new_key] = v
    else:
        new_state_dict[k] = v

# strict=False only tolerates missing/unexpected keys (shape mismatches still
# raise). Report any such keys so a wrong key remap cannot silently leave the
# backbone at its random initialisation.
load_result = model.load_state_dict(new_state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "Warning: checkpoint/model key mismatch - "
        f"missing: {load_result.missing_keys}, "
        f"unexpected: {load_result.unexpected_keys}"
    )

# Replace classifier for ESC50Artifact (7 classes); the input width is read
# from the existing head instead of being hard-coded.
model.classifier = nn.Linear(model.classifier.in_features, 7, bias=True)

# Freeze all layers except classifier
for name, param in model.named_parameters():
    if "classifier" not in name:
        param.requires_grad = False

model.to(device)
# The new head is freshly initialised and not trained in this cell, so the
# predictions computed below only check that the pipeline runs end to end.
model.eval()

# -----------------------------
# 2. Load & preprocess an audio file
# -----------------------------
audio_path = "ESC50Artifact/audio/1-7973-A-7_hiss.wav"  # replace with your audio file
waveform, sr = torchaudio.load(audio_path)

# Bring the clip to params.SAMPLE_RATE when the file uses a different rate
target_sr = params.SAMPLE_RATE
if sr != target_sr:
    waveform = torchaudio.functional.resample(waveform, sr, target_sr)

# Collapse to a single row by averaging over dim 0 when there is more than one
num_rows = waveform.shape[0]
if num_rows > 1:
    waveform = waveform.mean(dim=0, keepdim=True)

waveform = waveform.to(device)

# Turn the waveform into mel-spectrogram chunks for the model; the helper also
# returns the mel spectrogram that is plotted later.
waveform_to_mel = WaveformToMelSpec(device=device)
x_chunks, mel_spectrogram = waveform_to_mel(waveform, target_sr)
x_chunks = x_chunks.to(device)

# -----------------------------
# 3. Run forward pass (inference)
# -----------------------------
# Gradients are not needed for a prediction-only pass.
with torch.no_grad():
    logits = model(x_chunks)  # expected [num_chunks, 7]
    probs = logits.softmax(dim=-1)

# Average the per-chunk probabilities (mean over dim 0), then take the
# highest-scoring class and its averaged probability.
avg_probs = torch.mean(probs, dim=0)
pred_class = avg_probs.argmax().item()
confidence = avg_probs[pred_class].item()

print(f"Predicted class: {pred_class}, confidence: {confidence:.3f}")

# -----------------------------
# 4. Visualization
# -----------------------------
# Figure 1: raw waveform, annotated with the prediction.
plt.figure(figsize=(12,4))
plt.plot(waveform.squeeze().cpu().numpy())
plt.title(f"Waveform\nPredicted class: {pred_class}, Confidence: {confidence:.3f}")
plt.show()

# Figure 2: mel spectrogram. Convert to a host NumPy array first (works for a
# tensor on any device as well as for an ndarray), mirroring what the waveform
# plot already does; matplotlib cannot consume a CUDA tensor directly.
mel_image = torch.as_tensor(mel_spectrogram).detach().cpu().squeeze().numpy()
plt.figure(figsize=(12,4))
# The x-extent is taken from the squeezed image so it always equals the number
# of plotted columns, even when the input has a leading singleton dimension.
# NOTE(review): the y-extent tops out at params.SAMPLE_RATE, but a signal
# sampled at that rate has no content above SAMPLE_RATE / 2 - confirm the
# intended axis range against the mel filterbank settings.
# NOTE(review): the title says "Reds" while the colormap is 'seismic'.
plt.imshow(mel_image, aspect='auto', origin='lower', cmap='seismic',
           extent=[0, mel_image.shape[1], 0, params.SAMPLE_RATE])
plt.title("Mel Spectrogram (Reds)")
plt.xlabel("Time")
plt.ylabel("Frequency (Hz)")
plt.colorbar()
plt.show()

# Figure 3: averaged class probabilities. The number of bars follows the
# length of avg_probs instead of a hard-coded 7.
class_probs = avg_probs.cpu().numpy()
plt.figure(figsize=(8,4))
plt.bar(range(len(class_probs)), class_probs)
plt.title("Prediction Probabilities")
plt.xlabel("Class ID (0–6)")
plt.ylabel("Probability")
plt.show()


Fine-tuning on ESC50Artifact, first with only the classifier unfrozen

In [None]:
import torch
from torch.utils.data import DataLoader, random_split
import srcv2full.params as params
from srcv2full.data import ESC50ArtifactData
from srcv2full.model import YAMNet
from srcv2full.engine import YAMNetEngine
import logging
import os

# -----------------------------
# Logger setup
# -----------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("YAMNetTraining")

# -----------------------------
# Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# Load pretrained model
# -----------------------------
model = YAMNet()
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load("checkpoints/yamnet_audioset_converted.pth", map_location=device)

# Fix layer names if needed: "layer.<i>.<rest>" becomes "layer_<i+1>.<rest>";
# every other key is copied through unchanged.
new_state_dict = {}
for k, v in checkpoint.items():
    if k.startswith("layer."):
        parts = k.split(".")
        layer_idx = int(parts[1]) + 1
        new_key = f"layer_{layer_idx}." + ".".join(parts[2:])
        new_state_dict[new_key] = v
    else:
        new_state_dict[k] = v

# strict=False tolerates missing/unexpected keys without raising, so log them:
# otherwise a wrong key remap would leave the backbone randomly initialised
# and then frozen, without any visible error.
load_result = model.load_state_dict(new_state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    logger.warning(
        "Checkpoint/model key mismatch - missing: %s, unexpected: %s",
        load_result.missing_keys,
        load_result.unexpected_keys,
    )

# Replace classifier for 7 classes
model.classifier = torch.nn.Linear(model.classifier.in_features, 7)

# Freeze backbone: only parameters whose name contains "classifier" stay trainable
for name, param in model.named_parameters():
    if "classifier" not in name:
        param.requires_grad = False

# -----------------------------
# Dataset and DataLoader
# -----------------------------
data_dir = "ESC50Artifact/"
full_dataset = ESC50ArtifactData(data_dir)

# Split 80/20 train/val
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
# Seed the split so the same samples land in train/val on every run; without
# this, validation numbers from different runs are measured on different data.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

# The identity collate_fn hands each batch over as the plain list of dataset items.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=lambda x: x)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=lambda x: x)

# -----------------------------
# Engine
# -----------------------------
tt_chunk_size = params.CHUNK_SIZE
engine = YAMNetEngine(logger=logger, tt_chunk_size=tt_chunk_size, model=model)

# -----------------------------
# Train
# -----------------------------
checkpoint_path = "checkpoints/yamnet_finetune_esc50artifact.pth"
num_epochs = 10

# NOTE(review): the captured log for this run shows accuracy around 0.03 with a
# loss near ln(7); 0.03 is below the 1/7 chance level for 7 classes, which may
# point at how accuracy is computed inside the engine - worth confirming.
train_kwargs = {
    "train_loader": train_loader,
    "val_loader": val_loader,
    "checkpoint_path": checkpoint_path,
    "num_labels": 7,
    "num_epochs": num_epochs,
}
engine.train_yamnet(**train_kwargs)

logger.info("v1 Fine-tuning complete! -1 unfrozen")


INFO:YAMNetTraining:Using CUDA
INFO:YAMNetTraining:NVIDIA GeForce RTX 3050 Laptop GPU
INFO:YAMNetTraining:Started training
Epoch 1/10 [Train]: 100%|██████████| 2800/2800 [01:47<00:00, 26.01it/s, loss=2.37]
INFO:YAMNetTraining:Epoch 1/10, Train Loss: 1.9768, Train Acc: 0.0333
Epoch 1/10 [Val]: 100%|██████████| 700/700 [00:22<00:00, 31.04it/s, loss=2.32]  
INFO:YAMNetTraining:Epoch 1/10, Val Loss: 2.9743, Val Acc: 0.0108
INFO:YAMNetTraining:New best model saved with val_acc: 0.0108
Epoch 2/10 [Train]: 100%|██████████| 2800/2800 [01:48<00:00, 25.72it/s, loss=1.81]
INFO:YAMNetTraining:Epoch 2/10, Train Loss: 1.9760, Train Acc: 0.0342
Epoch 2/10 [Val]: 100%|██████████| 700/700 [00:23<00:00, 29.72it/s, loss=2.32] 
INFO:YAMNetTraining:Epoch 2/10, Val Loss: 2.1432, Val Acc: 0.0081
Epoch 3/10 [Train]: 100%|██████████| 2800/2800 [01:52<00:00, 24.97it/s, loss=1.69]
INFO:YAMNetTraining:Epoch 3/10, Train Loss: 1.9732, Train Acc: 0.0361
Epoch 3/10 [Val]: 100%|██████████| 700/700 [00:21<00:00, 32.22i

```text
1. Backbone is still frozen

In your current setup, if you froze all the pre-trained layers, the classifier alone might not have enough capacity to learn meaningful features from ESC50Artifact.

Especially since AudioSet classes (521) are very different from your 7 artifact classes.

Freezing everything except the last layer works well only if your dataset is very similar to the pre-trained one.

✅ Solution: Unfreeze at least the last few convolutional/separable layers of YAMNet, not just the classifier.

2. Learning rate may need tuning

For the classifier alone, a learning rate of 0.0001 may be okay.

But if you unfreeze some backbone layers, you’ll likely need a slightly lower learning rate for the backbone and a higher one for the classifier.

3. Loss & chunk averaging

Currently, your forward pass appears to use only the first chunk of each sample when computing the loss.

You should average all chunks’ predictions per audio sample before computing the loss. Otherwise, the classifier gets extremely noisy gradients.

4. Data normalization

If your Mel spectrogram preprocessing is different from what AudioSet training used, the backbone may not recognize features well.

You may need same normalization/scaling as pre-trained model.

5. Tiny dataset

ESC50Artifact is small; the model may overfit or barely learn if backbone is frozen.

Consider data augmentation (e.g., noise, pitch shift, time stretch) to help fine-tuning.

✅ Recommended Steps

Unfreeze last few YAMNet layers, e.g., layer_12, layer_13, and the classifier:

for name, param in model.named_parameters():
    if "layer_12" in name or "layer_13" in name or "classifier" in name:
        param.requires_grad = True
    else:
        param.requires_grad = False


Use separate learning rates for backbone vs classifier:

optimizer = torch.optim.Adam([
    {"params": [p for n, p in model.named_parameters() if "layer_12" in n or "layer_13" in n], "lr": 1e-4},
    {"params": model.classifier.parameters(), "lr": 5e-4}
])


Average predictions over all chunks per sample for training loss.

Optional: Add data augmentation for each audio batch.

No changes are strictly required — your ESC50ArtifactData class is structurally fine. It correctly:

Loads audio

Converts multi-channel to mono

Returns integer labels for classification

…but there are a few optional improvements that can help fine-tuning:

1. Ensure all audio is long enough

Your current chunking in WaveformToMelSpec might pad very short files, which can create lots of near-zero chunks.

Option: Filter out extremely short audio or pad intelligently.

if waveform.shape[1] < min_length_samples:
    waveform = torch.nn.functional.pad(waveform, (0, min_length_samples - waveform.shape[1]))

2. Data augmentation

For small datasets like ESC50Artifact, augmentations help the model generalize:

# Example augmentations
def augment(waveform, sr):
    """Return a randomly perturbed copy of ``waveform`` for training.

    Applies, in order: additive Gaussian noise, a random pitch shift of
    -2..+2 semitone steps, and a random time stretch by a factor drawn from
    [0.9, 1.1). The input tensor is left untouched.

    Args:
        waveform: audio tensor; assumed to be (channels, time) - TODO confirm
            against the dataset's __getitem__.
        sr: sample rate of ``waveform`` in Hz.

    Returns:
        The augmented waveform; its length differs from the input because of
        the time stretch.
    """
    import math

    # Random noise, added out of place so the caller's tensor is not mutated
    waveform = waveform + 0.005 * torch.randn_like(waveform)

    # Random pitch shift (skipped when the draw is 0 steps, which is a no-op)
    n_steps = torch.randint(-2, 3, (1,)).item()
    if n_steps != 0:
        waveform = torchaudio.functional.pitch_shift(waveform, sr, n_steps)

    # Random time stretch. torchaudio.functional has no time_stretch(); stretch
    # in the STFT domain with the phase vocoder and invert back to a waveform.
    rate = 0.9 + 0.2 * torch.rand(1).item()  # 0.9-1.1x
    n_fft, hop_length = 1024, 256
    window = torch.hann_window(n_fft, device=waveform.device, dtype=waveform.dtype)
    spec = torch.stft(
        waveform, n_fft, hop_length=hop_length, window=window, return_complex=True
    )
    phase_advance = torch.linspace(
        0, math.pi * hop_length, spec.shape[-2], device=spec.device
    )[..., None]
    stretched = torchaudio.functional.phase_vocoder(spec, rate, phase_advance)
    waveform = torch.istft(stretched, n_fft, hop_length=hop_length, window=window)

    return waveform


You can apply this in __getitem__ during training only (not validation).

3. Label consistency

Make sure the artifact2id mapping is always deterministic: build it from sorted(self.annotations['artifact_label'].unique()) so that the same label always maps to the same index.

self.artifact2id = {lbl: i for i, lbl in enumerate(sorted(self.annotations['artifact_label'].unique()))}

4. Optional normalization

If your pre-trained YAMNet used normalized Mel spectrograms, scale your input similarly:

mel = (mel - mel.mean()) / (mel.std() + 1e-6)

✅ TL;DR

No major changes needed.

Recommended: data augmentation, label sorting, and padding very short audio.

Your current dataset class is already compatible with your engine.
```

In [1]:
import torch
from torch.utils.data import DataLoader, random_split
import srcv2full.params as params
from srcv2full.data import ESC50ArtifactData
from srcv2full.model import YAMNet
from srcv2full.engine import YAMNetEngine
import logging
import os

# -----------------------------
# Logger setup
# -----------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("YAMNetTraining")

# -----------------------------
# Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# Load pretrained model
# -----------------------------
model = YAMNet()
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load("checkpoints/yamnet_audioset_converted.pth", map_location=device)

# Fix layer names if needed: "layer.<i>.<rest>" becomes "layer_<i+1>.<rest>";
# every other key is copied through unchanged.
new_state_dict = {}
for k, v in checkpoint.items():
    if k.startswith("layer."):
        parts = k.split(".")
        layer_idx = int(parts[1]) + 1
        new_key = f"layer_{layer_idx}." + ".".join(parts[2:])
        new_state_dict[new_key] = v
    else:
        new_state_dict[k] = v

# strict=False tolerates missing/unexpected keys without raising, so log them:
# otherwise a wrong key remap would leave the backbone randomly initialised
# without any visible error.
load_result = model.load_state_dict(new_state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    logger.warning(
        "Checkpoint/model key mismatch - missing: %s, unexpected: %s",
        load_result.missing_keys,
        load_result.unexpected_keys,
    )

# Replace classifier for 7 classes
model.classifier = torch.nn.Linear(model.classifier.in_features, 7)

# Partially unfreeze: parameters whose name contains layer_12, layer_13 or
# classifier are trainable; everything else is frozen.
trainable_tags = ("layer_12", "layer_13", "classifier")
for name, param in model.named_parameters():
    param.requires_grad = any(tag in name for tag in trainable_tags)

# -----------------------------
# Dataset and DataLoader
# -----------------------------
data_dir = "ESC50Artifact/"
full_dataset = ESC50ArtifactData(data_dir)

# Split 80/20 train/val
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
# Seed the split so the same samples land in train/val on every run; without
# this, validation numbers from different runs are measured on different data.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

# The identity collate_fn hands each batch over as the plain list of dataset items.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=lambda x: x)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=lambda x: x)

# -----------------------------
# Engine
# -----------------------------
tt_chunk_size = params.CHUNK_SIZE
engine = YAMNetEngine(logger=logger, tt_chunk_size=tt_chunk_size, model=model)

# -----------------------------
# Train
# -----------------------------
checkpoint_path = "checkpoints/yamnet_finetune_esc50artifact_frozen_-3.pth"
num_epochs = 10

# NOTE(review): the captured log for this run shows accuracy around 0.03-0.04,
# below the 1/7 chance level for 7 classes, which may point at how accuracy is
# computed inside the engine - worth confirming.
train_kwargs = {
    "train_loader": train_loader,
    "val_loader": val_loader,
    "checkpoint_path": checkpoint_path,
    "num_labels": 7,
    "num_epochs": num_epochs,
}
engine.train_yamnet(**train_kwargs)

logger.info("v2 Fine-tuning complete! -3 unfrozen")


INFO:YAMNetTraining:Using CUDA
INFO:YAMNetTraining:NVIDIA GeForce RTX 3050 Laptop GPU
INFO:YAMNetTraining:Started training
Epoch 1/10 [Train]: 100%|██████████| 2800/2800 [01:21<00:00, 34.27it/s, loss=1.82] 
INFO:YAMNetTraining:Epoch 1/10, Train Loss: 2.0976, Train Acc: 0.0349
Epoch 1/10 [Val]: 100%|██████████| 700/700 [00:16<00:00, 43.23it/s, loss=2.49]  
INFO:YAMNetTraining:Epoch 1/10, Val Loss: 2.6005, Val Acc: 0.0081
INFO:YAMNetTraining:New best model saved with val_acc: 0.0081
Epoch 2/10 [Train]: 100%|██████████| 2800/2800 [01:19<00:00, 35.22it/s, loss=2.22] 
INFO:YAMNetTraining:Epoch 2/10, Train Loss: 2.0899, Train Acc: 0.0384
Epoch 2/10 [Val]: 100%|██████████| 700/700 [00:15<00:00, 44.27it/s, loss=1.65] 
INFO:YAMNetTraining:Epoch 2/10, Val Loss: 2.6293, Val Acc: 0.0101
INFO:YAMNetTraining:New best model saved with val_acc: 0.0101
Epoch 3/10 [Train]: 100%|██████████| 2800/2800 [01:16<00:00, 36.64it/s, loss=3.16] 
INFO:YAMNetTraining:Epoch 3/10, Train Loss: 2.0786, Train Acc: 0.037