In [2]:
%pip install -r ../../requirements.txt

In [37]:
import json
import sys
from types import SimpleNamespace
from glob import glob
import ntpath

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn import metrics
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import librosa
import librosa.display

sys.path.append('../../src')

from Lucas_custom.data import AudioDataset, DataModule
from Lucas_custom.trainer import TrainModule
from Lucas_custom.net import SimpleCNN
from Lucas_custom.utils import batch_to_device, get_state_dict, get_min_max, min_max_norm
from Lucas_custom.eval import inference_k_random, error_analysis

from custom.utils import Mixup


%load_ext autoreload
%autoreload 2

# Relative prefix that is prepended to every `path` entry of the metadata
# table when it is loaded (presumably the repository root as seen from this
# notebook's directory — confirm).
full_path = '../../'

In [20]:
def collate_fn(batch):
    """Stack the per-item 'wave' and 'labels' tensors into one batch dict.

    Args:
        batch: sequence of dicts, each holding a tensor under 'wave' and a
            tensor under 'labels'. Any other keys are ignored.

    Returns:
        dict with keys 'wave' and 'labels'; each value is `torch.stack` of the
        corresponding item tensors, i.e. a new leading dimension of size
        `len(batch)` is added.
    """
    return {
        field: torch.stack([sample[field] for sample in batch])
        for field in ('wave', 'labels')
    }

In [65]:
# Experiment configuration shared by the dataset and the model in this notebook.
# Keys are inserted in the same order as before so `vars(cfg)` iterates identically.
cfg = SimpleNamespace(
    wav_crop_len=5,
    # Model
    n_classes=66,
    pretrained=True,
    backbone='tf_efficientnetv2_s.in21k',
    in_chans=1,
    num_workers=4,
    include_val=False,
    max_amp=False,
    # Loading
    batch_size=32,
    sample_rate=44100,
    # Mel spectrogram hyperparameters (fixed values)
    n_mels=128,
    n_fft=2048,
    fmin=300,
)

# Mel spectrogram hyperparameters derived from the values above
cfg.fmax = cfg.sample_rate / 2      # half the sample rate, kept as a float
cfg.window_size = cfg.n_fft
cfg.hop_length = cfg.n_fft // 2     # half of n_fft, as an int

vars(cfg).update(
    power=2,
    top_db=80.0,
    # Normalisation
    mel_normalized=True,
    minmax_norm=False,
    # Augmentation parameters
    impulse_prob=0.2,
    noise_prob=0.2,
    mixup=False,
    specaug=False,
    specaug_prob=0.25,
    mixup_prob=1,
    max_noise=0.04,
    min_snr=5,
    max_snr=20,
)

In [66]:
# Load the clip metadata and turn its relative `path` column into paths that
# resolve from this notebook's working directory.
# `full_path` already ends in '/', so strip it once: joining with another '/'
# previously produced paths such as '../..//<path>'.
data_root = full_path.rstrip('/')
df = pd.read_csv(f"{data_root}/data/metadata.csv")
df['path'] = df['path'].apply(lambda x: f'{data_root}/{x}')

In [67]:
# Build the network from the notebook config and keep whatever `.eval()`
# returns as `model` (presumably the module itself switched to inference
# mode, as with torch.nn.Module.eval — confirm SimpleCNN does not override it).
model = SimpleCNN(cfg).eval()

In [68]:
# Sanity check: metadata row 10 is placed first in three batches of different
# size (1, 33 and 11 items) and its `wav2img` output is compared across them.
# The printed sums show whether the image produced for that row changes with
# the other items in the batch.
row_rng = np.random.default_rng(0)   # fixed seed -> same companion rows on every run
n_candidates = min(100, len(df))     # companions come from the first 100 rows, clamped to the frame size

pred_ds = AudioDataset(df.iloc[[10]], mode='val', cfg=cfg)
pred_ds1 = AudioDataset(df.iloc[[10, *row_rng.integers(0, n_candidates, size=32)]], mode='val', cfg=cfg)
pred_ds2 = AudioDataset(df.iloc[[10, *row_rng.integers(0, n_candidates, size=10)]], mode='val', cfg=cfg)

b1 = collate_fn([pred_ds[i] for i in range(len(pred_ds))])
b2 = collate_fn([pred_ds1[i] for i in range(len(pred_ds1))])
b3 = collate_fn([pred_ds2[i] for i in range(len(pred_ds2))])

print('batch x channels x time')
print('Sum(abs(batch1-batch2))')

# The images are only inspected, never back-propagated through, so autograd is
# switched off; this also keeps the later `.numpy()` conversions valid.
# `[:, None, :]` inserts a singleton axis between the batch axis and the last axis.
with torch.no_grad():
    imgbatch = model.wav2img(b1['wave'][:, None, :])
    imgbatch2 = model.wav2img(b2['wave'][:, None, :])
    imgbatch3 = model.wav2img(b3['wave'][:, None, :])

# Differences for the shared first item, in order: b1 vs b3, b1 vs b2, b2 vs b3.
print(torch.sum(torch.abs(imgbatch[0]-imgbatch3[0])))
print(torch.sum(torch.abs(imgbatch[0]-imgbatch2[0])))
print(torch.sum(torch.abs(imgbatch2[0]-imgbatch3[0])))

In [69]:
# Apply `Mixup` to the spectrogram images of the 33-item batch and its labels;
# the call returns the transformed images and the matching labels.
# NOTE(review): this cell uses its own value 0.5 rather than cfg.mixup_prob
# (set to 1 in the config cell) — confirm the difference is intentional, and
# confirm what the constructor argument of `custom.utils.Mixup` actually means.
mixup_prob = 0.5
mixup = Mixup(mixup_prob)
mbatch, mlabels = mixup(imgbatch2, b2['labels'])

In [73]:
# Label tensor of the first item in batch `b2`, i.e. the labels that were passed to mixup.
b2['labels'][0]

In [72]:
# First entry of the labels returned by the mixup call, for comparison with b2['labels'][0].
mlabels[0]

In [74]:
# Show the wav2img output for the first item of the 33-item batch (channel 0),
# i.e. the image that is fed into mixup.
fig, ax = plt.subplots()
# detach()/cpu() keep the conversion valid if the tensor tracks gradients or lives on a GPU.
img = librosa.display.specshow(imgbatch2[0, 0].detach().cpu().numpy(), ax=ax)
ax.set(title='wav2img output, item 0 (before mixup)')
fig.colorbar(img, ax=ax);  # trailing ';' hides the Colorbar repr in the cell output

In [75]:
# Show the first item (channel 0) of the image batch returned by mixup.
fig, ax = plt.subplots()
# detach()/cpu() keep the conversion valid if the tensor tracks gradients or lives on a GPU.
img = librosa.display.specshow(mbatch[0, 0].detach().cpu().numpy(), ax=ax)
ax.set(title='wav2img output, item 0 (after mixup)')
fig.colorbar(img, ax=ax);  # trailing ';' hides the Colorbar repr in the cell output