## Generate train-val split

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import csv

import numpy as np
import pandas as pd

# Fraction of the rows that goes to the training split; the rest is validation.
TRAIN_FRAC = 0.9

# Fixed seed so that the permutation drawn from `rng` is the same on every run.
SEED = 8228
rng = np.random.default_rng(seed=SEED)

In [None]:
# Read the full label table that is split into train/validation parts below.
all_labels_path = 'data/ssp/all_labels.csv'
ds = pd.read_csv(all_labels_path)

In [None]:
# Draw one random ordering of the row positions, then cut it at the
# TRAIN_FRAC boundary: everything before the cut is training data,
# everything from the cut onwards is validation data.
p = rng.permutation(len(ds))
n_train = int(TRAIN_FRAC * len(ds))
train_ds, val_ds = ds.iloc[p[:n_train]], ds.iloc[p[n_train:]]

In [None]:
# Write both splits with identical CSV settings: no index column and every
# field quoted.
split_outputs = (
    (train_ds, 'data/ssp/train_labels.csv'),
    (val_ds, 'data/ssp/val_labels.csv'),
)
for frame, out_path in split_outputs:
    frame.to_csv(out_path, index=False, quoting=csv.QUOTE_ALL)

## Fine-tuning

In [1]:
#!g1.1
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (the cells below import from the local `src` package) are picked up
# live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
#!g1.1
from pathlib import Path
import warnings

import pandas as pd
import torch

from src.datasets import MultiLabelClassificationCollator, SSPNetVC
from src.models import PretrainedPaSST
from src.trainers import MultiLabelClassificationTrainer

# Settings for the fine-tuning run.
BATCH_SIZE = 8
NUM_AUDIOSET_CLASSES = 527
THRESHOLD = 0.2  # NOTE(review): not referenced in the cells visible here - confirm where it is used

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

In [3]:
#!g1.1
# Build the label -> index mapping from the AudioSet class table.
labels_df = pd.read_csv('data/audioset/class_labels_indices.csv')
label2ind = {
    display_name: class_index
    for display_name, class_index in zip(labels_df['display_name'], labels_df['index'])
}
# Guard: the table must yield exactly the expected number of distinct names.
assert len(label2ind) == NUM_AUDIOSET_CLASSES
# The additional 'Filler' label gets the first index after the AudioSet classes.
label2ind.update({'Filler': NUM_AUDIOSET_CLASSES})

In [4]:
#!g1.1
# Build the training and validation datasets from the two label CSVs that
# live side by side in the SSP data directory.
SSP_DIR = Path('data/ssp')
train_ds = SSPNetVC(csv=SSP_DIR / 'train_labels.csv')
val_ds = SSPNetVC(csv=SSP_DIR / 'val_labels.csv')

In [5]:
#!g1.1
# Both loaders share one collator built from the label -> index mapping.
collator = MultiLabelClassificationCollator(label2ind)

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(train_ds, BATCH_SIZE,
                                           shuffle=True, collate_fn=collator,
                                           pin_memory=True)

# Validation data is deliberately NOT shuffled: ordering has no benefit for
# evaluation, and a fixed order keeps the validation batches identical from
# epoch to epoch and from run to run.
val_loader = torch.utils.data.DataLoader(val_ds, BATCH_SIZE,
                                         shuffle=False, collate_fn=collator,
                                         pin_memory=True)

In [6]:
#!g1.1
# Instantiate the pretrained model with one extra output class (presumably the
# 'Filler' label appended to the AudioSet label map - confirm), then move it
# to the selected device.
model = PretrainedPaSST(num_new_classes=1, max_audio_len=11.0)
model = model.to(DEVICE)

# Adam over all model parameters with a learning rate of 1e-4.
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

# Bundle model, optimizer, both loaders and the device into the trainer.
trainer = MultiLabelClassificationTrainer(model, opt, train_loader, val_loader, DEVICE)

Downloading: "https://github.com/kkoutini/PaSST/releases/download/v0.0.1-audioset/passt-s-f128-p16-s10-ap.476-swa.pt" to /tmp/xdg_cache/torch/hub/checkpoints/passt-s-f128-p16-s10-ap.476-swa.pt




 Loading PASST TRAINED ON AUDISET 


PaSST(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwi

In [7]:
#!g1.1
# Suppress UserWarnings whose message starts with 'Input image size' so they
# do not flood the training log.
warnings.filterwarnings(action='ignore', message='Input image size', category=UserWarning)

# Ten rounds of `train_loop(1)`, writing a checkpoint after each round.
# The printed counter is 1-based while checkpoint file names are 0-based.
# NOTE(review): the whole model object is saved, not just its state_dict.
for i in range(10):
    epoch_number = i + 1
    print('Epoch', epoch_number)
    trainer.train_loop(1)
    checkpoint_path = f'checkpoint{i}.pth'
    torch.save(trainer.model, checkpoint_path)

  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
100%|██████████| 311/311 [03:19<00:00,  1.56it/s]
100%|██████████| 35/35 [00:09<00:00,  3.78it/s]
100%|██████████| 311/311 [03:19<00:00,  1.56it/s]
100%|██████████| 35/35 [00:09<00:00,  3.78it/s]
100%|██████████| 311/311 [03:20<00:00,  1.55it/s]
100%|██████████| 35/35 [00:09<00:00,  3.77it/s]
100%|██████████| 311/311 [03:20<00:00,  1.55it/s]
100%|██████████| 35/35 [00:09<00:00,  3.78it/s]
100%|██████████| 311/311 [03:20<00:00,  1.55it/s]
100%|██████████| 35/35 [00:09<00:00,  3.78it/s]
100%|██████████| 311/311 [03:20<00:00,  1.55it/s]
100%|██████████| 35/35 [00:09<00:00,  3.77it/s]
100%|██████████| 311/311 [03:20<00:00,  1.55it/s]
100%|██████████| 35/35 [00:09<00:00,  3.78it/s]
100%|██████████| 311/311 [03:20<00:00,  1.55it/s]
100%|██████████| 35/35 [00:09<00:00,  3.77it/s]
100%|██████████| 311/311 [03:20<00:00,  1.55it/s]
100%|██████████| 35/35 [00:09<00:00,  3.78it/s]
100%|██████████| 311/311

Epoch 1
Epoch 1:
x torch.Size([8, 1, 128, 1100])
self.norm(x) torch.Size([8, 768, 12, 109])
 patch_embed :  torch.Size([8, 768, 12, 109])
 self.time_new_pos_embed.shape torch.Size([1, 768, 1, 1100])
 CUT time_new_pos_embed.shape torch.Size([1, 768, 1, 109])
 self.freq_new_pos_embed.shape torch.Size([1, 768, 12, 1])
X flattened torch.Size([8, 1308, 768])
 self.new_pos_embed.shape torch.Size([1, 2, 768])
 self.cls_tokens.shape torch.Size([8, 1, 768])
 self.dist_token.shape torch.Size([8, 1, 768])
 final sequence x torch.Size([8, 1310, 768])
 after 12 atten blocks x torch.Size([8, 1310, 768])
forward_features torch.Size([8, 768])
head torch.Size([8, 528])
Train loss: tensor(0.0086, device='cuda:0')
Train mAP: tensor(0.8558, device='cuda:0')
Validation loss: tensor(0.0065, device='cuda:0')
Validation mAP: tensor(0.9014, device='cuda:0')
----------------------------------------------------------------------------------------------------
Epoch 2
Epoch 1:
Train loss: tensor(0.0070, device='cu