In [None]:
# Reload imported modules (e.g. edits to the local components/ and utils/ packages)
# automatically before each cell executes, so changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import random

import torch
from IPython.display import Audio

from components.dataset import AudioBinaryClassifyDS
from components.loaders import loaders
from components.loaders import loaders as make_loaders
from components.model import AudioRNNBinary
from components.trainer import Trainer
from utils.plot import plot_waveform, plot_spectrogram

In [None]:
# Two-folder binary audio dataset. class_dict maps label 0 -> 'cat', 1 -> 'dog';
# presumably dir1 feeds label 0 and dir2 label 1 — confirm in AudioBinaryClassifyDS.
cat_dir = '.catdog_audio/cat/'
dog_dir = '.catdog_audio/dog/'
label_names = {0: 'cat', 1: 'dog'}

dataset = AudioBinaryClassifyDS(dir1=cat_dir, dir2=dog_dir, class_dict=label_names)

# Precompute per-item features up front. TODO: document what n_cut=3 controls
# (defined in AudioBinaryClassifyDS.pre_comp, not visible here).
dataset.pre_comp(n_cut=3)

In [None]:
# Pick one random clip and show its waveform, mel spectrogram and an audio player.
# random.randrange(n) draws from [0, n). random.randint's upper bound is inclusive,
# so randint(0, len(dataset)) could index one past the end and raise IndexError.
# `sample` is also read by the model cell to size the RNN.
random_idx = random.randrange(len(dataset))
sample = dataset[random_idx]

plot_waveform(sample['waveform'], sample['sr'], 'Waveform - {}: {}'.format(random_idx, dataset.class_dict[sample['label'].item()]))
plot_spectrogram(sample['specgram'], 'MelSpectogram - {}'.format(random_idx))
Audio(sample['waveform'], rate=sample['sr'])

In [None]:
# Display the dimensions of the sampled spectrogram; the model cell sizes the RNN
# from len(specgram[0]).
specgram = sample['specgram']
specgram.shape

In [None]:
# Build the train/val loaders (presumably torch DataLoaders: later cells read
# ['train'], ['val'], .batch_size and .dataset).
# The factory is called through its alias `make_loaders` because this cell rebinds
# the name `loaders` to the returned mapping. Calling `loaders(...)` directly makes
# the cell fail with TypeError when it is re-run.
BATCH_SIZE = 32
loaders = make_loaders(dataset, BATCH_SIZE)

In [None]:
# Size the RNN from the spectrogram of `sample` (drawn in the visualisation cell):
# input width is the length of its first row, hidden width is twice that.
# NOTE(review): assumes specgram rows are time steps and columns are features — confirm.
feat_dim = len(sample['specgram'][0])
input_size = feat_dim
hidden_size = 2 * feat_dim

model = AudioRNNBinary(input_size, hidden_size, num_layers=2, drop_out=0.6)
trainer = Trainer(model, loaders['train'], loaders['val'])
model

In [None]:
# Train, plot the recorded history, then report the best validation accuracy.
trainer.start(120)  # presumably the number of epochs — confirm in Trainer.start
trainer.plot_history()

best_acc = trainer.best_acc
print('best val acc: {}'.format(best_acc))

In [None]:
# Reload the saved checkpoint and classify one random validation clip.
model = trainer.load_checkpoint()
model.eval();

val_ds = loaders['val'].dataset
# Sample uniformly over the whole validation set. randrange excludes its upper bound,
# so the index can never run past the end. Sampling up to the batch size would only
# cover the first batch and could overflow a small val split.
val_idx = random.randrange(len(val_ds))
val_sample = val_ds[val_idx]

wf = val_sample['waveform']
sr = val_sample['sr']
spec = val_sample['specgram']
label = val_sample['label'].item()

# Send the input to whatever device the model's weights are on instead of
# hard-coding 'mps'. no_grad skips autograd bookkeeping during inference.
device = next(model.parameters()).device
with torch.no_grad():
    out = model(spec.unsqueeze(0).to(device)).item()
# NOTE(review): assumes the model outputs a probability for class 1 ('dog') — confirm.
res = 0 if out < 0.5 else 1

print('true: {}'.format(dataset.class_dict[label]))
print('pred: {}, out={}'.format(dataset.class_dict[res], out))

Audio(wf, rate=sr)