# Finetune the ConvNext BirdSet model on a custom dataset

This notebook demonstrates how to finetune the ConvNext BirdSet model on a custom dataset. As an example we use the [ESC50 dataset](https://github.com/karolpiczak/ESC-50), which contains 2000 environmental audio recordings in 50 classes. We use Lightning as a high-level interface for PyTorch to simplify the training process.

In [3]:
CACHE_DIR = "../../data_birdset"

## ESC50 datamodule
First we define a datamodule for the ESC50 dataset. The datamodule downloads the dataset, resamples the audio to 32 kHz, and splits it by fold into a training set and a test set. Augmentations could be added in the `_transforms` method.

In [None]:
from lightning import LightningDataModule
from datasets import load_dataset, Audio
from torch.utils.data import DataLoader
import torchaudio


class ESC50DataModule(LightningDataModule):
    def __init__(
        self,
    ):
        super().__init__()
        self.resampler = torchaudio.transforms.Resample(orig_freq=44100, new_freq=32000)
        self.train_dataset = None
        self.test_dataset = None

    def _transforms(self, batch):
        # collate audio
        waveform_batch = [audio["array"] for audio in batch["audio"]]
        # TODO add data augmentation here
        return {"input_values": waveform_batch, "labels": batch["labels"]}

    def prepare_data(self):
        dataset = load_dataset(
            path="ashraq/esc50",
            cache_dir=CACHE_DIR,
        )
        dataset = dataset.cast_column(
            column="audio",
            feature=Audio(
                sampling_rate=32000,
                mono=True,
                decode=True,
            ),
        )

        # rename the target column to labels
        dataset = dataset.rename_column("target", "labels")
        # ESC50 ships with 5 predefined folds, numbered 1 to 5. This split trains on
        # folds 1-3 (1200 clips) and tests on fold 4 (400 clips); fold 5 is left unused.
        self.train_dataset = dataset["train"].filter(
            lambda x: x["fold"] in [1, 2, 3]
        )
        self.test_dataset = dataset["train"].filter(lambda x: x["fold"] == 4)
        # keep only the columns the model needs
        self.train_dataset = self.train_dataset.select_columns(["audio", "labels"])
        self.test_dataset = self.test_dataset.select_columns(["audio", "labels"])
        self.train_dataset.set_transform(self._transforms)
        self.test_dataset.set_transform(self._transforms)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=32, shuffle=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=32, shuffle=False)


dm = ESC50DataModule()
dm.prepare_data()
dm.setup("fit")
print(dm.train_dataset)

Repo card metadata block was not found. Setting CardData to empty.


Dataset({
    features: ['audio', 'labels'],
    num_rows: 1200
})
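
The `_transforms` method above leaves a TODO for data augmentation. As a rough sketch of what could go there, the function below applies a random circular time shift and a random gain to a single waveform using NumPy. The name `augment_waveform` and its parameters `max_shift` and `gain_db_range` are hypothetical and not part of BirdSet or the notebook. Note that the notebook registers the same `_transforms` on both datasets, so an augmentation like this would need a separate transform for the training set only.

```python
import numpy as np


def augment_waveform(waveform, rng, max_shift=0.1, gain_db_range=(-6.0, 6.0)):
    """Apply a random circular time shift and a random gain to a 1-D waveform.

    Hypothetical helper for illustration; parameter names are assumptions.
    """
    waveform = np.asarray(waveform, dtype=np.float32)
    # shift by at most `max_shift` of the clip length, in either direction
    limit = int(max_shift * len(waveform))
    shift = int(rng.integers(-limit, limit + 1))
    shifted = np.roll(waveform, shift)
    # draw a gain in decibels and convert it to a linear amplitude factor
    gain_db = float(rng.uniform(*gain_db_range))
    gain = 10.0 ** (gain_db / 20.0)
    return (shifted * gain).astype(np.float32)


rng = np.random.default_rng(0)
# one second of a 440 Hz tone at 32 kHz as a stand-in for an ESC50 clip
clip = np.sin(2 * np.pi * 440.0 * np.arange(32000) / 32000.0).astype(np.float32)
augmented = augment_waveform(clip, rng)
print(augmented.shape, augmented.dtype)
```

Inside a training-only transform, this could be applied to each `audio["array"]` before the batch is returned.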


## Define Module
Next, we define a Lightning module that loads the ConvNext BirdSet model and defines the training and test steps, including logging of the accuracy and loss. We use the Adam optimizer with cross-entropy as the loss function.

In [None]:
from birdset import ConvNextBirdSet
import torch
import torch.nn.functional as F
from lightning import LightningModule
from torchmetrics import Accuracy


class BirdsetLightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = ConvNextBirdSet(num_classes=50)
        self.train_acc_metric = Accuracy(task="multiclass", num_classes=50)
        self.test_acc_metric = Accuracy(task="multiclass", num_classes=50)

    def forward(self, x):
        return self.model(x)

    def _step(self, batch):
        x = batch["input_values"]
        y = batch["labels"]
        preprocessed = self.model.preprocess(x)
        y_hat = self.model(preprocessed)
        loss = F.cross_entropy(y_hat, y)
        return loss, y_hat

    def training_step(self, batch, batch_idx):
        loss, y_hat = self._step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        # update accuracy metric for training batch
        self.train_acc_metric.update(y_hat, batch["labels"])
        acc = self.train_acc_metric.compute()
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, y_hat = self._step(batch)
        self.log("test_loss", loss)
        # update accuracy metric for test batch
        self.test_acc_metric.update(y_hat, batch["labels"])
        acc = self.test_acc_metric.compute()
        self.log("test_acc", acc, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        # reset the training accuracy metric at the end of the epoch
        self.train_acc_metric.reset()

    def on_test_epoch_end(self):
        # reset the test accuracy metric at the end of testing
        self.test_acc_metric.reset()

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=5e-5)


module = BirdsetLightningModule()

Some weights of ConvNextForImageClassification were not initialized from the model checkpoint at DBD-research-group/ConvNeXT-Base-BirdSet-XCL and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([9736]) in the checkpoint and torch.Size([50]) in the model instantiated
- classifier.weight: found shape torch.Size([9736, 1024]) in the checkpoint and torch.Size([50, 1024]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
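
One detail of the logging pattern above is worth keeping in mind. The metric's `compute()` returns the running accuracy over all batches seen since the last `reset()`, and `self.log` then averages the values logged at each step to form the epoch-level number. That epoch-level value is therefore a mean of running accuracies, which can differ from the accuracy over the whole set. The plain-Python sketch below illustrates the difference with made-up correctness flags; it does not use Lightning or torchmetrics, and the helper names are hypothetical.

```python
def running_accuracies(batches):
    """Return the cumulative accuracy after each batch of 0/1 correctness flags."""
    correct = 0
    seen = 0
    out = []
    for flags in batches:
        correct += sum(flags)
        seen += len(flags)
        out.append(correct / seen)
    return out


def mean_of_running(batches):
    """Batch-size-weighted mean of the running accuracies."""
    running = running_accuracies(batches)
    total = sum(len(b) for b in batches)
    return sum(acc * len(b) for acc, b in zip(running, batches)) / total


def overall_accuracy(batches):
    """Fraction of correct predictions over all batches."""
    flat = [flag for batch in batches for flag in batch]
    return sum(flat) / len(flat)


# toy data: the first batch is all correct, the second mostly wrong
batches = [[1, 1, 1, 1], [1, 0, 0, 0]]
print(mean_of_running(batches))   # mean of the per-step running values
print(overall_accuracy(batches))  # accuracy over all 8 predictions
```

A common way to avoid the discrepancy in Lightning is to pass the metric object itself to `self.log` (for example `self.log("test_acc", self.test_acc_metric)`), so that the metric is computed once at the end of the epoch. The numbers reported in this notebook were produced with the code as shown above, not with that variant.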


In [43]:
# define trainer
from lightning import Trainer

trainer = Trainer(max_epochs=10, accelerator="gpu", devices=[0])

# train the model
trainer.fit(module, dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Repo card metadata block was not found. Setting CardData to empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name             | Type               | Params | Mode 
----------------------------------------------------------------
0 | model            | ConvNextBirdSet    | 87.6 M | train
1 | train_acc_metric | MulticlassAccuracy | 0      | train
2 | test_acc_metric  | MulticlassAccuracy | 0      | train
----------------------------------------------------------------
87.6 M    Trainable params
0         Non-trainable params
87.6 M    Total params
350.454   Total estimated model params size (MB)
8         Modules in train mode
279       Modules in eval mode
/home/vscode/.cache/pypoetry/virtualenvs/birdset-xS3fZVNL-py3.10/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataL

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


## Test the model
Finally, we test the model on the test set and report the accuracy.
The result is about 0.89, not bad for a simple finetuning! See [paperswithcode](https://paperswithcode.com/sota/audio-classification-on-esc-50) for the state of the art on the ESC50 dataset.
Keep in mind that the datamodule above trains on folds 1-3 and tests on fold 4 only, so the result is not directly comparable to the state of the art, which involves cross-validation on all 5 folds.

In [44]:
trainer.test(module, dm.test_dataloader())

Repo card metadata block was not found. Setting CardData to empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 0.5572055578231812, 'test_acc': 0.8874237537384033}]
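
For a number that is comparable to published ESC50 results, the usual protocol is 5-fold cross-validation: each fold is held out once for testing while the model is trained on the remaining four, and the five accuracies are averaged. The sketch below only generates the fold assignments and averages per-fold results; the function names are hypothetical, and wiring each split into the datamodule and a fresh model is left out.

```python
FOLDS = [1, 2, 3, 4, 5]  # ESC50 ships with five predefined folds


def cross_validation_splits(folds=FOLDS):
    """Yield (train_folds, test_fold) pairs, holding out each fold once."""
    for held_out in folds:
        train = [f for f in folds if f != held_out]
        yield train, held_out


def mean_accuracy(per_fold_accuracy):
    """Average the accuracies obtained on the individual held-out folds."""
    return sum(per_fold_accuracy) / len(per_fold_accuracy)


for train_folds, test_fold in cross_validation_splits():
    print(f"train on folds {train_folds}, test on fold {test_fold}")
```

Each iteration would need a newly initialized module and trainer, so that no weights carry over from one fold to the next.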