## Generate train-val split

In [1]:
# Enable IPython's autoreload extension so that edits to imported modules are
# picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import csv

import numpy as np
import pandas as pd

# Fraction of the labelled rows that goes to the training split; the remaining
# rows form the validation split.
TRAIN_FRAC = 0.9

# Fixed seed so the shuffle, and therefore the split, is reproducible.
SEED = 8228
rng = np.random.default_rng(SEED)

In [3]:
# Load the complete label table; the following cell splits it into train and
# validation subsets.
ds = pd.read_csv(filepath_or_buffer='data/ssp/all_labels.csv')

In [4]:
# Shuffle the row positions once, then cut the permutation at the train/val
# boundary so the two subsets are disjoint and together cover every row.
num_rows = len(ds)
num_train = int(TRAIN_FRAC * num_rows)  # truncates towards zero

p = rng.permutation(num_rows)
train_ds = ds.iloc[p[:num_train]]
val_ds = ds.iloc[p[num_train:]]

In [5]:
# Write both splits with the same settings: no index column, every field quoted.
csv_options = dict(index=False, quoting=csv.QUOTE_ALL)
train_ds.to_csv('data/ssp/train_labels.csv', **csv_options)
val_ds.to_csv('data/ssp/val_labels.csv', **csv_options)

## Fine-tuning

In [1]:
# Enable IPython's autoreload extension so that edits to imported modules
# (e.g. the project's `src` package) are picked up live, without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
import warnings

import pandas as pd
import torch

from src.datasets import MultiLabelClassificationCollator, SSPNetVC
from src.models import PretrainedPaSST
from src.trainers import MultiLabelClassificationTrainer

# Number of samples per batch for both data loaders.
BATCH_SIZE = 1

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

# Expected number of rows in the AudioSet class-label table; also the index
# given to the extra 'Filler' class.
NUM_AUDIOSET_CLASSES = 527

# NOTE(review): not referenced in the visible cells; presumably the decision
# threshold applied to per-class scores -- confirm against its users.
THRESHOLD = 0.2

In [3]:
# Build the label-name -> class-index mapping from the AudioSet label table.
labels_df = pd.read_csv('data/audioset/class_labels_indices.csv')
label2ind = {
    display_name: class_index
    for display_name, class_index in zip(labels_df['display_name'], labels_df['index'])
}

# Guard against a truncated table or duplicated display names.
assert len(label2ind) == NUM_AUDIOSET_CLASSES

# The extra 'Filler' class takes the first index after the AudioSet classes.
label2ind['Filler'] = NUM_AUDIOSET_CLASSES

In [4]:
# Wrap the train/validation label files written by the split notebook in
# dataset objects.
ssp_dir = Path('data/ssp')
train_ds = SSPNetVC(csv=ssp_dir / 'train_labels.csv')
val_ds = SSPNetVC(csv=ssp_dir / 'val_labels.csv')

In [5]:
# Collate samples into batches using the label-name -> index mapping.
collator = MultiLabelClassificationCollator(label2ind)

# Page-locked host memory only helps when batches are copied to a CUDA device;
# on a CPU-only machine PyTorch warns and ignores the request.
use_pinned_memory = DEVICE.type == 'cuda'

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(train_ds, BATCH_SIZE,
                                           shuffle=True, collate_fn=collator,
                                           pin_memory=use_pinned_memory)

# Validation keeps a fixed order: shuffling brings no benefit for evaluation
# and makes per-batch validation output differ from run to run.
val_loader = torch.utils.data.DataLoader(val_ds, BATCH_SIZE,
                                         shuffle=False, collate_fn=collator,
                                         pin_memory=use_pinned_memory)

In [6]:
# Instantiate the pretrained model with one additional output class and move
# it to the selected device.
model = PretrainedPaSST(num_new_classes=1, max_audio_len=11.0)
model = model.to(DEVICE)

# Adam over all model parameters with a learning rate of 1e-4.
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

trainer = MultiLabelClassificationTrainer(
    model, opt, train_loader, val_loader, DEVICE
)



 Loading PASST TRAINED ON AUDISET 


PaSST(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwi

In [8]:
# Silence UserWarnings whose message starts with 'Input image size' so they do
# not flood the training output.
warnings.filterwarnings(
    action='ignore',
    message='Input image size',
    category=UserWarning,
)

# Run the training loop with an argument of 1 (one epoch, per the logged output).
trainer.train_loop(1)

Epoch 1:


  0%|          | 6/2486 [00:17<2:02:44,  2.97s/it]


KeyboardInterrupt: 