In [12]:
import argparse
from pathlib import Path


parser = argparse.ArgumentParser()
'''train'''
parser.add_argument("--max_lr", default=3e-4, type=float)
parser.add_argument("--wd", default=1e-5, type=float)
parser.add_argument("--batch_size", default=128, type=int)
parser.add_argument("--run_name", default=None, type=Path)
parser.add_argument('--loss_type', default="label_smooth", type=str)
parser.add_argument('--n_epochs', default=None, type=int)
parser.add_argument('--epoch_mix', default=None, type=int)
parser.add_argument("--amp", action='store_true')
parser.add_argument("--filter_bias_and_bn", action='store_true', default=True)
parser.add_argument("--ext_pretrained", default=None, type=str)
parser.add_argument("--multilabel", action='store_true')
parser.add_argument('--save_path', default=None, type=Path)
parser.add_argument('--load_path', default=None, type=Path)
parser.add_argument('--scheduler', default=None, type=str)
parser.add_argument('--augs_signal', nargs='+', type=str,
                    default=['amp', 'neg', 'tshift', 'tmask', 'ampsegment', 'cycshift'])
parser.add_argument('--augs_noise', nargs='+', type=str,
                    default=['awgn', 'abgn', 'apgn', 'argn', 'avgn', 'aun', 'phn', 'sine'])
parser.add_argument('--augs_mix', nargs='+', type=str, default=['mixup', 'timemix', 'freqmix', 'phmix'])
parser.add_argument('--mix_loss', default='bce', type=str)
parser.add_argument('--mix_ratio', default=1, type=float)
parser.add_argument('--ema', default=0.995, type=float)
parser.add_argument('--log_interval', default=100, type=int)
parser.add_argument("--kd_model", default=None, type=Path)
parser.add_argument("--use_bg", action='store_true', default=False)
parser.add_argument("--resume_training", action='store_true', default=False)
parser.add_argument("--use_balanced_sampler", action='store_true', default=False)
'''common'''
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--gpu_ids', nargs='+', default=[0])
parser.add_argument("--use_ddp", action='store_true')
parser.add_argument("--use_dp", action='store_true')
parser.add_argument('--save_interval', default=100, type=int)
'''data'''
parser.add_argument('--fold_id', default=1, type=int)
parser.add_argument("--data_subtype", default='balanced', type=str)
parser.add_argument('--seq_len', default=90112, type=int)
parser.add_argument('--dataset', default="urban8k", type=str)
parser.add_argument('--n_classes', default=50, type=int)
'''net'''
parser.add_argument('--ds_factors', nargs='+', type=int, default=[4, 4, 4, 4])
parser.add_argument('--n_head', default=8, type=int)
parser.add_argument('--n_layers', default=4, type=int)
parser.add_argument("--emb_dim", default=128, type=int)
parser.add_argument("--model_type", default='SoundNetRaw', type=str)
parser.add_argument("--nf", default=16, type=int)
parser.add_argument("--dim_feedforward", default=512, type=int)
parser.add_argument("--sampling_rate", default=22050, type=int)
'''system'''
parser.add_argument('--data_dir', default='data/', type=Path)
parser.add_argument('--gpus', type=list, default=[0])
parser.add_argument('--num_workers', type=int, default=32)
args = parser.parse_args(args=[])


## Data

### ESC50 Dataset
The ESC-50 dataset is a labeled collection of 2000 environmental audio recordings suitable for benchmarking methods of environmental sound classification.

The dataset consists of 5-second-long recordings organized into 50 semantical classes.
[Github](https://github.com/karolpiczak/ESC-50)

[Huggingface](https://huggingface.co/datasets/ashraq/esc50) "ashraq/esc50"

In [5]:
from datasets import load_dataset
from datasets import Audio
esc50 = load_dataset("ashraq/esc50")
esc50 = esc50.cast_column("audio", Audio(sampling_rate=args.sampling_rate))
# split into train, val, test
esc50 = esc50['train'].train_test_split(test_size=0.2, shuffle=True)
esc50


Repo card metadata block was not found. Setting CardData to empty.


DatasetDict({
    train: Dataset({
        features: ['filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take', 'audio'],
        num_rows: 1600
    })
    test: Dataset({
        features: ['filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take', 'audio'],
        num_rows: 400
    })
})

In [6]:
# rename column target to label
esc50 = esc50.rename_column("target", "label")

In [8]:

dataset_test = esc50['test'].with_format('torch', columns=['audio', 'label'])
dataset_train = esc50['train'].with_format('torch', columns=['audio', 'label'])

In [9]:
dataset_train

Dataset({
    features: ['filename', 'fold', 'label', 'category', 'esc10', 'src_file', 'take', 'audio'],
    num_rows: 1600
})

In [41]:
import random
import torch.nn.functional as F

def preprocess_audio(example):
    audio = example['audio']
    samling_rate = audio['sampling_rate']
    audio = audio['array']
    if audio.shape[0] >= args.seq_len:
        max_audio_start = audio.size(0) - args.seq_len
        audio_start = random.randint(0, max_audio_start)
        audio = audio[audio_start : audio_start + args.seq_len]
    else:
        audio = F.pad(
            audio, (0, args.seq_len - audio.size(0)), "constant"
        ).data
    example['audio'] = audio
    return example

preprocess_audio(dataset_train[0])

{'label': tensor(39),
 'audio': tensor([-0.0017,  0.0006, -0.0022,  ...,  0.0000,  0.0000,  0.0000])}

In [19]:
dataset_train[0]

{'label': tensor(39),
 'audio': {'path': None,
  'array': tensor([ 3.5112e-07, -3.0661e-07,  2.1097e-07,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]),
  'sampling_rate': tensor(22050)}}

In [26]:
import lightning as L
from torch.utils.data import random_split, DataLoader
from datasets import load_dataset
from datasets import Audio

import random
import torch.nn.functional as F

def preprocess_audio(example):
    audio = example['audio']
    samling_rate = audio['sampling_rate']
    audio = audio['array']
    if audio.shape[0] >= args.seq_len:
        max_audio_start = audio.size(0) - args.seq_len
        audio_start = random.randint(0, max_audio_start)
        audio = audio[audio_start : audio_start + args.seq_len]
    else:
        audio = F.pad(
            audio, (0, args.seq_len - audio.size(0)), "constant"
        ).data
    example['audio'] = audio
    return example

class ESC50DataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = args.data_dir):
        super().__init__()
        self.data_dir = data_dir
        
    # called only within a single process on CPU
    def prepare_data(self):
        # download
        load_dataset("ashraq/esc50")

    # run on each GPU
    def setup(self, stage: str):
        esc50 = load_dataset("ashraq/esc50")
        esc50 = esc50.cast_column("audio", Audio(sampling_rate=args.sampling_rate))
        # rename column target to label
        esc50 = esc50.rename_column("target", "label")

        # split into train, val, test
        esc50 = esc50['train'].train_test_split(test_size=0.2, shuffle=True)
        self.dataset_test = esc50['test'].with_format('torch', columns=['audio', 'label'])
        self.dataset_train = esc50['train'].with_format('torch', columns=['audio', 'label'])
        self.dataset_train = self.dataset_train.map(preprocess_audio)
        self.dataset_test = self.dataset_test.map(preprocess_audio)
       



    def train_dataloader(self):
        return DataLoader(self.dataset_train, batch_size=args.batch_size, num_workers=args.num_workers,
        pin_memory=True,
        shuffle=True,
        drop_last=True,
        )

    def val_dataloader(self):
         return DataLoader(self.dataset_test, batch_size=args.batch_size, num_workers=args.num_workers,
        pin_memory=True,
        shuffle=False,
        drop_last=True,
        )

    # def test_dataloader(self):
    #      return DataLoader(self.dataset_test, batch_size=args.batch_size, num_workers=args.num_workers,
    #     pin_memory=True,
    #     shuffle=False,
    #     drop_last=True,
    #     )

datamodule = ESC50DataModule()

In [44]:
datamodule.setup(stage='fit')

Repo card metadata block was not found. Setting CardData to empty.


Epoch 0:   0%|          | 0/12 [14:57<?, ?it/s]


Map: 100%|██████████| 1600/1600 [01:08<00:00, 23.40 examples/s]
Map: 100%|██████████| 400/400 [00:01<00:00, 220.34 examples/s]


In [45]:
datamodule.train_dataloader()

<torch.utils.data.dataloader.DataLoader at 0x7fceb3500940>

In [46]:
train_dataloader = datamodule.train_dataloader()
first_batch = next(iter(train_dataloader))
x, y = first_batch

In [51]:
x = first_batch['audio']
x

tensor([[ 4.3875e-02, -3.4781e-02, -6.9376e-02,  ..., -2.7814e-02,
          1.0329e-01,  4.5388e-01],
        [ 1.3737e-04, -3.0496e-05,  4.6044e-04,  ...,  6.5663e-04,
          7.5605e-04,  9.8250e-04],
        [ 3.3793e-02,  4.1921e-02,  4.5471e-02,  ..., -8.2139e-02,
         -8.3095e-02, -8.4711e-02],
        ...,
        [ 2.9470e-02,  3.2372e-02,  2.8961e-02,  ..., -2.3377e-02,
          4.7540e-03,  1.9314e-02],
        [-1.2007e-02, -1.0140e-02, -8.9423e-03,  ...,  1.8060e-02,
          2.0214e-02,  2.1049e-02],
        [-1.3702e-01, -1.0605e-01, -1.0402e-01,  ..., -2.8995e-02,
         -2.9097e-02, -2.8035e-02]])

In [48]:
first_batch

{'label': tensor([41, 38,  9, 22,  6, 37, 44,  7,  8,  0, 10, 19, 45, 34, 31, 47,  8, 29,
         24, 40, 37, 39, 45, 12, 44, 42, 37, 32, 13, 47, 37, 23, 22,  0, 20, 16,
         20, 45, 27, 40, 13, 44,  6, 49, 12,  6, 46,  6,  9, 20,  5, 34, 21, 36,
         24, 15, 33, 32, 39, 37, 38, 29, 21, 30, 27, 37, 45, 43,  2, 21,  6,  0,
         27, 32, 13, 10, 21, 13, 39,  4, 25,  5, 32, 30, 20, 45, 14, 35, 46, 17,
          2,  5,  2,  0, 35,  5, 30, 36, 13,  2, 28,  9, 18, 28, 11, 42, 35, 49,
          8,  6, 26, 49, 14, 25, 21, 49, 41,  8, 49, 44,  5, 22, 20, 48,  3,  2,
         40, 26]),
 'audio': tensor([[ 4.3875e-02, -3.4781e-02, -6.9376e-02,  ..., -2.7814e-02,
           1.0329e-01,  4.5388e-01],
         [ 1.3737e-04, -3.0496e-05,  4.6044e-04,  ...,  6.5663e-04,
           7.5605e-04,  9.8250e-04],
         [ 3.3793e-02,  4.1921e-02,  4.5471e-02,  ..., -8.2139e-02,
          -8.3095e-02, -8.4711e-02],
         ...,
         [ 2.9470e-02,  3.2372e-02,  2.8961e-02,  ..., -2.3377e-02,

## Model

In [33]:
import torch
from workspace.datasets.batch_augs import BatchAugs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


ba_params = {
        'seq_len': args.seq_len,
        'fs': args.sampling_rate,
        'augs': args.augs_mix,
        'device': device,
        'mix_ratio': args.mix_ratio,
        'batch_sz': args.local_rank,
        'epoch_mix': args.epoch_mix,
        'resample_factors': [0.8, 0.9, 1.1, 1.2],
        'multilabel': True if args.multilabel else False,
        'mix_loss': args.mix_loss
    }
batch_augs = BatchAugs(ba_params)

In [34]:
import torch.nn as nn
#####################
# losses            #
#####################
if args.loss_type == "label_smooth":
    from modules.losses import LabelSmoothCrossEntropyLoss
    criterion = LabelSmoothCrossEntropyLoss(smoothing=0.1, reduction='sum')
elif args.loss_type == "cross_entropy":
    criterion = nn.CrossEntropyLoss(reduction='sum')
elif args.loss_type == "focal":
    from modules.losses import FocalLoss
    criterion = FocalLoss()
elif args.loss_type == 'bce':
    criterion = nn.BCEWithLogitsLoss(reduction='sum')
else:
    raise ValueError

In [52]:
import numpy as np
import torch, torch.nn as nn
import lightning as L
from workspace.datasets.batch_augs import BatchAugs
from modules.soundnet import SoundNetRaw as SoundNet

class EAT(L.LightningModule):
    def __init__(self):
        super().__init__()
        ds_fac = np.prod(np.array(args.ds_factors)) * 4
        self.model = SoundNet(
                nf=args.nf,
                dim_feedforward=args.dim_feedforward,
                clip_length=args.seq_len // ds_fac,
                embed_dim=args.emb_dim,
                n_layers=args.n_layers,
                nhead=args.n_head,
                n_classes=args.n_classes,
                factors=args.ds_factors,
                )

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x = batch['audio']
        y = batch['label']
        x, targets, is_mixed = batch_augs(x, y) # TODO: removed epoch parameter
        pred = self(x)
        if is_mixed:
            loss_cls = batch_augs.mix_loss(pred, targets, n_classes=args.n_classes,
            pred_one_hot=args.multilabel)
        else:
            loss_cls = criterion(pred, y)
        self.log('loss_cls', loss_cls)
        return loss_cls
    
    # def validation_step(self, batch, batch_idx):
    #     loss = self.training_step(batch, batch_idx)
    #     self.log('val_loss', loss)
    #     return loss
    
    def configure_optimizers(self):
        if args.amp:
            from torch.cuda.amp import GradScaler
            scaler = GradScaler(init_scale=2**10)
            eps = 1e-4
        else:
            scaler = None
            eps = 1e-8
        parameters = self.model.parameters()
        return torch.optim.AdamW(parameters,
                            lr=args.max_lr,
                            betas=[0.9, 0.99],
                            weight_decay=0,
                            eps=eps)
model = EAT()


## Trainer

In [53]:
from lightning import Trainer


trainer = L.Trainer(max_epochs=args.n_epochs, accelerator='gpu', devices=args.gpus) # set devices to a list of GPU ids to train on
# start training 
trainer.fit(model, datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


  rank_zero_warn(
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.
Map: 100%|██████████| 1600/1600 [00:18<00:00, 84.41 examples/s]
Map: 100%|██████████| 400/400 [00:04<00:00, 96.60 examples/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name  | Type        | Params
--------------------------------------
0 | model | SoundNetRaw | 5.2 M 
--------------------------------------
5.2 M     Trainable params
0         Non-trainable params
5.2 M     Total params
20.722    Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:   0%|          | 0/12 [00:00<?, ?it/s] 

RuntimeError: Given groups=1, weight of size [16, 1, 7], expected input[128, 128, 90118] to have 1 channels, but got 128 channels instead