# PANNs Post-Training Tutorial

In this tutorial, we will go through the basic steps for post-training of PANNs (Pretrained Audio Neural Networks):

1. Dataset preparation
2. Data loading and preprocessing
3. Model setup
4. Training loop
5. Validation and metrics
6. Saving the best models

Let‚Äôs start by importing the necessary libraries and setting up our environment!



In [1]:
import os
import random
import json
import librosa
import numpy as np
import torch
import logging
from datetime import datetime
from textgrid import TextGrid
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader, Dataset
from torch import nn, optim
from torch.nn import functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from panns_inference.models import Cnn14_DecisionLevelMax  # From PANNs
os.chdir('/data2/nitin/main/separation/audioset_tagging_cnn')

from utils import config  # From PANNs

## 1Ô∏è‚É£ Dataset Class and Label Extraction

We start by defining the dataset class `TextGridDataset` which:
- Loads audio files and corresponding TextGrid label files
- Generates frame-wise labels
- Computes class weights to handle class imbalance

Below is the code for the dataset.


In [2]:
from textgrid import TextGrid
import numpy as np
import librosa
import os

def time_to_frame(time, sr, hop_size):
    return int(time * sr // hop_size)

def get_ground_truth(textgrid_file, audio_file, fixed_sample_points=64000, hop_size_samples=12800):
    # Fixed label mapping (4 classes)
    label_mapping = {
        "Crow": 0,
        "Speech": 1,
        "2+Crows": 2,
        "Other": 3
    }

    # Load TextGrid and audio
    tg = TextGrid.fromFile(textgrid_file)
    waveform, _ = librosa.load(audio_file, sr=32000, mono=True)

    # Extract labeled intervals
    intervals = []
    for tier in tg.tiers:
        if tier.__class__.__name__ == "IntervalTier":
            for i in tier.intervals:
                label = i.mark.strip()
                if label and label in label_mapping:
                    intervals.append((i.minTime, i.maxTime, label))
                elif label and label not in label_mapping:
                    print(f"‚ö†Ô∏è Skipping unknown label '{label}' in {os.path.basename(textgrid_file)}")

    def create_labels(start_sample):
        num_frames = fixed_sample_points // 320 + 1
        clip_labels = np.zeros((num_frames, len(label_mapping)))
        clip_barktype = np.full(num_frames, -1)

        for s, e, label in intervals:
            label_idx = label_mapping[label]
            start_frame = time_to_frame(max(s, start_sample / 32000), 32000, 320) - (start_sample // 320)
            end_frame = time_to_frame(min(e, (start_sample + fixed_sample_points) / 32000), 32000, 320) - (start_sample // 320)
            clip_barktype[max(0, start_frame):end_frame] = label_idx
            clip_labels[max(0, start_frame):end_frame, label_idx] = 1

        return clip_labels, clip_barktype

    # Segment audio
    all_waveforms, all_labels, all_barktype, all_original_length = [], [], [], []
    start_sample = 0
    while start_sample + fixed_sample_points <= waveform.shape[0]:
        clip = waveform[start_sample:start_sample + fixed_sample_points]
        labels, barktype = create_labels(start_sample)
        all_waveforms.append(clip)
        all_labels.append(labels)
        all_barktype.append(barktype)
        all_original_length.append(fixed_sample_points)
        start_sample += hop_size_samples

    return all_waveforms, all_labels, all_barktype, all_original_length, label_mapping

In [3]:
class TextGridDataset(Dataset):
    """
    Dataset class that loads .wav files and TextGrid label files.
    It also computes class weights for balanced training.
    """
    def __init__(self, file_pairs, sample_rate=32000, hop_length=320):
        self.waveforms, self.framewise_labels, self.barktype, self.original_length = [], [], [], []
        self.hop_length = hop_length
        self.label_mapping = None

        for audio_path, textgrid_path in file_pairs:
            waveform_list, gt_list, barktype_list, original_length_list, label_mapping = get_ground_truth(textgrid_path, audio_path)
            if self.label_mapping is None:
                self.label_mapping = label_mapping

            self.waveforms.extend(waveform_list)
            self.framewise_labels.extend(gt_list)
            self.barktype.extend(barktype_list)
            self.original_length.extend(original_length_list)

        print(f"Loaded {len(self.waveforms)} samples")

        # Validate consistent class count
        num_classes = self.framewise_labels[0].shape[1]
        for i, lbl in enumerate(self.framewise_labels):
            if lbl.shape[1] != num_classes:
                raise ValueError(f"Sample {i} has {lbl.shape[1]} classes, expected {num_classes}")

        # Compute class weights
        self.class_counts = torch.zeros(num_classes)
        valid_frame_total = 0

        for labels, orig_len in zip(self.framewise_labels, self.original_length):
            labels_tensor = torch.from_numpy(labels).float()
            valid_frames = int(orig_len // self.hop_length) + 1
            self.class_counts += labels_tensor[:valid_frames].sum(dim=0)
            valid_frame_total += valid_frames

        self.class_counts[self.class_counts == 0] = 1e-6
        self.class_weights = valid_frame_total / (num_classes * self.class_counts)
        self.class_weights = self.class_weights / self.class_weights.sum() * num_classes

        print("Class weights:", self.class_weights)

    def __len__(self):
        return len(self.framewise_labels)

    def __getitem__(self, idx):
        waveform = torch.from_numpy(self.waveforms[idx]).float()
        labels = torch.from_numpy(self.framewise_labels[idx]).float()
        barktype = torch.from_numpy(self.barktype[idx]).float()
        return waveform, labels, barktype, self.original_length[idx]

## üêæ Model Building Blocks

We define two key components of our model:

- `ConvBlock`: Basic convolutional unit with BatchNorm, ReLU, and pooling.
- `Cnn14_DecisionLevelMax`: Our 14-layer CNN with decision-level max pooling for audio event detection and classification.

These modules will help you understand how to stack convolutional layers and modify pretrained models.


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from pytorch.models import SpecAugmentation, init_layer, init_bn
from pytorch.pytorch_utils import interpolate, pad_framewise_output, do_mixup


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.init_weight()

    def init_weight(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, x, pool_size=(2, 2), pool_type="avg"):
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))

        if pool_type == "max":
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == "avg":
            x = F.avg_pool2d(x, kernel_size=pool_size)
        elif pool_type == "avg+max":
            x = F.avg_pool2d(x, kernel_size=pool_size) + F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == "no-downsample-avg":
            x = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        elif pool_type == "no-pooling":
            pass  # no pooling
        else:
            raise ValueError(f"Invalid pooling type: {pool_type}")

        return x


class Cnn14_DecisionLevelMax(nn.Module):
    def __init__(
        self,
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        freeze_base,
        pooling_type="avg",
    ):
        super().__init__()
        self.interpolate_ratio = 32
        self.freeze_base = freeze_base
        self.pooling_type = pooling_type

        # Spectrogram and logmel feature extraction
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window="hann",
            center=True,
            pad_mode="reflect",
            freeze_parameters=True,
        )
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=1.0,
            amin=1e-10,
            top_db=None,
            freeze_parameters=True,
        )

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64, time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2
        )

        self.bn0 = nn.BatchNorm2d(64)
        self.conv_block1 = ConvBlock(1, 64)
        self.conv_block2 = ConvBlock(64, 128)
        self.conv_block3 = ConvBlock(128, 256)
        self.conv_block4 = ConvBlock(256, 512)
        self.conv_block5 = ConvBlock(512, 1024)
        self.conv_block6 = ConvBlock(1024, 2048)

        self.fc1 = nn.Linear(2048, 2048)
        self.fc_audioset = nn.Linear(2048, classes_num)
        self.init_weight()

    def init_weight(self):
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input, mixup_lambda=None):
        """
        input: (batch_size, data_length)
        """
        x = self.spectrogram_extractor(input)
        x = self.logmel_extractor(x)
        frames_num = x.shape[2]

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        if self.training:
            x = self.spec_augmenter(x)
            if mixup_lambda is not None:
                x = do_mixup(x, mixup_lambda)

        if self.pooling_type in ["no-downsample-avg", "no-pooling"]:
            self.interpolate_ratio = 1

        x = self.conv_block1(x, (2, 2), self.pooling_type)
        x = F.dropout(x, 0.2, self.training)
        x = self.conv_block2(x, (2, 2), self.pooling_type)
        x = F.dropout(x, 0.2, self.training)
        x = self.conv_block3(x, (2, 2), self.pooling_type)
        x = F.dropout(x, 0.2, self.training)
        x = self.conv_block4(x, (2, 2), self.pooling_type)
        x = F.dropout(x, 0.2, self.training)
        x = self.conv_block5(x, (2, 2), self.pooling_type)
        x = F.dropout(x, 0.2, self.training)
        x = self.conv_block6(x, (1, 1), self.pooling_type)
        x = F.dropout(x, 0.2, self.training)

        x = torch.mean(x, dim=3)
        x = F.max_pool1d(x, 3, 1, 1) + F.avg_pool1d(x, 3, 1, 1)
        x = F.dropout(x, 0.5, self.training)
        x = x.transpose(1, 2)
        x = F.relu_(self.fc1(x))
        x = F.dropout(x, 0.5, self.training)
        segmentwise_output = torch.sigmoid(self.fc_audioset(x))
        clipwise_output, _ = torch.max(segmentwise_output, dim=1)

        framewise_output = interpolate(segmentwise_output, self.interpolate_ratio)
        framewise_output = pad_framewise_output(framewise_output, frames_num)

        return {"framewise_output": framewise_output, "clipwise_output": clipwise_output}

    def load_and_modify_state_dict(self, state_dict, selected_classes_idx, num_additional_classes):
        self.load_state_dict(state_dict)
        original_fc_weight = self.fc_audioset.weight.data.clone()
        original_fc_bias = self.fc_audioset.bias.data.clone()

        selected_weight = original_fc_weight[selected_classes_idx, :]
        selected_bias = original_fc_bias[selected_classes_idx]

        new_fc = nn.Linear(2048, len(selected_classes_idx) + num_additional_classes)
        new_fc.weight.data[:len(selected_classes_idx)] = selected_weight
        new_fc.bias.data[:len(selected_classes_idx)] = selected_bias

        if self.freeze_base:
            for name, param in self.named_parameters():
                if "fc" not in name:
                    param.requires_grad = False

        self.fc_audioset = new_fc
        self.fc_audioset.weight.requires_grad = True
        self.fc_audioset.bias.requires_grad = True

        print("‚úÖ Pretrained weights loaded and final FC layer modified.")


## 2Ô∏è‚É£ Training Function

Now let‚Äôs define the function for training the model. We will:
- Load the dataset
- Instantiate the model
- Define the loss and optimizer
- Perform training and validation


In [5]:
import torch
import torch.nn.functional as F
from torch import nn

class FrameBceMaskedWeighted(nn.Module):
    """
    BCE Loss for frame-wise predictions with class weights,
    and increased penalty on false positives.
    """
    def __init__(self, hop_length, pos_weight=1.0, neg_weight=1.0):
        """
        pos_weight: weight for positive class (target=1)
        neg_weight: weight for negative class (target=0), higher to penalize false positives
        """
        super().__init__()
        self.hop_length = hop_length
        self.pos_weight = pos_weight
        self.neg_weight = neg_weight

    def forward(self, pred, target, origin_length):
        batch_ori_length = (origin_length // self.hop_length) + 1
        loss_total = 0.0
        for i in range(batch_ori_length.shape[0]):
            ori_len = int(batch_ori_length[i])
            p = pred.clamp(1e-7, 1.0)[i, :ori_len]  # predicted probs
            t = target[i, :ori_len]                  # targets

            # Create per-element weights:
            # Weight positive targets by pos_weight, negatives by neg_weight
            weights = torch.where(t == 1, 
                                  torch.tensor(self.pos_weight, device=t.device), 
                                  torch.tensor(self.neg_weight, device=t.device))

            loss = F.binary_cross_entropy(p, t, weight=weights)
            loss_total += loss

        return loss_total / batch_ori_length.shape[0]


In [6]:
def train_model(train_dataset, valid_dataset, model_type="Cnn14_DecisionLevelMax", epochs=30, learning_rate=1e-3, device="cuda:0"):
    device = device if torch.cuda.is_available() else "cpu"
    model = eval(model_type)(32000, 1024, 320, 64, 50, 14000, config.classes_num, freeze_base=False, pooling_type="avg+max")
    pretrained_checkpoint_path = "../Cnn14_mAP=0.431.pth"
    logging.info(
        "Load pretrained model from {}".format(pretrained_checkpoint_path)
    )
    checkpoint = torch.load(
        pretrained_checkpoint_path, map_location=device, weights_only=True
    )
    model.load_and_modify_state_dict(
        checkpoint["model"],
        selected_classes_idx=[117, 0],
        num_additional_classes=2,
    )
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=True)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=3, factor=0.1, min_lr=1e-7)
    loss_fn = FrameBceMaskedWeighted(320, train_dataset.class_weights.to(device))

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
    valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=4)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for waveforms, labels, _, orig_len in train_loader:
            waveforms, labels = waveforms.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(waveforms, None)["framewise_output"]
            loss = loss_fn(outputs, labels, orig_len)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss}")

        # Validation
        model.eval() 
        val_loss, correct, total = 0, 0, 0
        crow_correct, crow_total = 0, 0  # <-- For Crow label (index 0)

        with torch.no_grad():
            for waveforms, labels, _, orig_len in valid_loader:
                waveforms, labels = waveforms.to(device), labels.to(device)
                outputs = model(waveforms, None)["framewise_output"]
                loss = loss_fn(outputs, labels, orig_len)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=2)
                targets = torch.argmax(labels, dim=2)

                correct += (preds == targets).sum().item()
                total += targets.numel()

                # Compute Crow label accuracy (label index 0)
                crow_mask = (targets == 0)  # Locations where Crow is ground truth
                crow_correct += ((preds == targets) & crow_mask).sum().item()
                crow_total += crow_mask.sum().item()

        avg_val_loss = val_loss / len(valid_loader)
        val_acc = correct / total
        crow_acc = crow_correct / crow_total if crow_total > 0 else 0.0

        print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_acc:.4f}")
        print(f"Crow Accuracy (class 0): {crow_acc:.4f}")
        scheduler.step(avg_val_loss)
        return model



## 3Ô∏è‚É£ Putting It All Together

Now let‚Äôs load the dataset and start training!


In [7]:
file_folder = "../fine_tuning_annotations"
all_files = []

for f in os.listdir(file_folder):
    if f.endswith(".wav"):
        wav_path = os.path.join(file_folder, f)
        tg_path = wav_path.replace(".wav", ".TextGrid")
        if os.path.exists(tg_path):
            all_files.append((wav_path, tg_path))

print(f"Found {len(all_files)} audio-textgrid pairs.")
random.shuffle(all_files)
train_size = int(0.8 * len(all_files))
train_dataset = TextGridDataset(all_files[:train_size])
valid_dataset = TextGridDataset(all_files[train_size:])

model = train_model(train_dataset, valid_dataset, epochs=20)
save_path = "../fine_tuned_cnn14_framewise.pth"
torch.save({"model": model.state_dict()}, save_path)
print(f"Model saved to {save_path}")



Found 350 audio-textgrid pairs.
Loaded 14268 samples
Class weights: tensor([0.0850, 3.3190, 0.2338, 0.3622])
Loaded 3766 samples
Class weights: tensor([0.3295, 2.5750, 0.6050, 0.4904])
‚úÖ Pretrained weights loaded and final FC layer modified.


  torch.tensor(self.pos_weight, device=t.device),


Epoch 1, Training Loss: 0.11714814295950492
Validation Loss: 0.1759, Accuracy: 0.6774
Crow Accuracy (class 0): 0.6879
Model saved to ../fine_tuned_cnn14_clipwise.pth
