# Sliding window CNN for predicting notes!

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from scipy import signal
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import IPython.display as ipydisplay
import functools
import librosa
import librosa.display as ldisplay
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pathlib
import scipy.io.wavfile as wav
import torch

In [None]:
%matplotlib inline

In [None]:
from src import dataset, bincounts

### Dataset/dataloader

In [None]:
def get_dataset_dataloader(folder_path, crop_len_sec=2, sample=False, num_samples=1):
    """Build a ``SignalWindowDataset`` for ``folder_path`` and a batch-size-1 ``DataLoader`` over it.

    Args:
        folder_path: directory handed to ``dataset.SignalWindowDataset``.
        crop_len_sec: forwarded to ``SignalWindowDataset`` as ``crop_len_sec``.
        sample: if True, draw examples with replacement according to the
            dataset's ``sampling_weights`` (``WeightedRandomSampler``);
            otherwise the loader iterates the dataset in order.
        num_samples: number of examples the weighted sampler draws per pass
            over the loader (ignored when ``sample`` is False). The default of 1
            means one pass over the loader yields exactly one batch.

    Returns:
        ``(dataset, dataloader)``. The lengths of both are printed as a side effect.
    """
    _dataset = dataset.SignalWindowDataset(folder_path=folder_path, crop_len_sec=crop_len_sec)

    sampler = None
    if sample:
        sampler = torch.utils.data.WeightedRandomSampler(
            weights=_dataset.sampling_weights, num_samples=num_samples, replacement=True)

    # batch_size=1: the plotting/eval helpers in this notebook index the first element of each batch.
    _dataloader = DataLoader(_dataset, batch_size=1, sampler=sampler)
    print(len(_dataset), len(_dataloader))
    return _dataset, _dataloader

In [None]:
# Train and validation splits; both are drawn with the dataset's sampling weights.
dataset_train, dataloader_train = get_dataset_dataloader(
    folder_path='/home/anuj/data/m/disk/train/',
    sample=True,
)
dataset_val, dataloader_val = get_dataset_dataloader(
    folder_path='/home/anuj/data/m/disk/val',
    sample=True,
)

In [None]:
# # Check the sampling
# ix = 0
# file_paths = []
# while ix < 1000:
#     batch = next(iter(dataloader_train))
#     file_paths.append(batch['file_path'])
#     ix += 1
# dataset_train.df_stats[['file_path', 'seconds']].sort_values(['seconds'], ascending=False)

### Bincounts

In [None]:
%%time
# Class statistics of the onset labels, estimated from the (weighted) train loader.
# Presumably get_bin_counts draws n_iters=1000 batches — confirm in src/bincounts.
# The result is used further down as the class weights of the onset NLLLoss.
weights = bincounts.get_bin_counts(dataloader_train, keys=['labels'], n_iters=1000)

In [None]:
# Inspect the per-class values for the onset labels.
weights['labels']

In [None]:
plt.plot(weights['labels'], 'x-')

### Model

In [None]:
from src.models.frame_cnn import SimpleFrameCNNWithNotes

In [None]:
# NOTE(review): machine-specific GPU index; adjust to the devices available on the host.
DEVICE = 'cuda:3'

In [None]:
# NOTE(review): 513 input features is presumably n_fft // 2 + 1 frequency bins for n_fft=1024 — confirm against the dataset.
model = SimpleFrameCNNWithNotes(n_feats=513).to(DEVICE)

In [None]:
batch = next(iter(dataloader_train))
print(batch['features'].shape, batch['labels'].shape)
print(batch['labels'])

# Smoke test: one forward pass to check the output shapes. No gradients are needed,
# so don't build an autograd graph.
inputs = batch['features'].to(DEVICE)
with torch.no_grad():
    onset_probs, notes_activations = model(inputs)
# Onset head: 2 classes, a singleton dim, and one entry per position of the input's last dim.
assert tuple(onset_probs.shape[1:]) == (2, 1, inputs.shape[-1]), tuple(onset_probs.shape)

### Loss / optimizer

In [None]:
# Per-class weights for the onset loss, taken from the bin counts computed above.
# (Wrapping in `Variable` is a no-op in modern PyTorch, so the plain tensor is used.)
# NOTE(review): NLLLoss scales each class's loss by its weight; if weights['labels'] are raw
# counts rather than inverse frequencies, this up-weights the majority class — confirm.
weights_l = torch.from_numpy(weights['labels'].astype(np.float32))
# NOTE(review): NLLLoss expects log-probabilities; confirm the model's onset head ends in log_softmax.
# .to(DEVICE) so the class weights live on the same device as the model outputs.
onset_loss_func = torch.nn.NLLLoss(weight=weights_l, ignore_index=-100).to(DEVICE)
# Multi-label note targets scored against raw logits.
notes_loss_func = torch.nn.BCEWithLogitsLoss()

In [None]:
# Adam over every model parameter with PyTorch's default hyperparameters.
optimizer = torch.optim.Adam(model.parameters())

### logging

In [None]:
model_str = 'docmus-with-notes-all-1.00'

# Output locations for this run, keyed by the run name.
weights_folder = "../weights/{}".format(model_str)
log_folder = '../tensorboard-logs/{}'.format(model_str)

# Create the weights folder BEFORE opening the SummaryWriter: this is meant to fail if the
# run name was already used, and checking first avoids writing new tensorboard events into
# the old run's log folder right before the failure.
os.makedirs(weights_folder, exist_ok=False)  # MEANT TO FAIL IF IT ALREADY EXISTS

writer = SummaryWriter(log_folder)  # writing log to tensorboard
print('saving weights to: {}'.format(weights_folder))
print('logging to: {}'.format(log_folder))

### Train

In [None]:
from sklearn.metrics import precision_recall_fscore_support
from tqdm import tqdm
import collections
import mir_eval

In [None]:
# Per-batch output of `predict_and_evalaute`: the two loss values plus the onset-window
# precision / recall / F1 and per-class support.
Results = collections.namedtuple(
    'Results',
    'onset_loss notes_loss precision recall f1 support',
)

In [None]:
def get_onset_times_from_window_labels(onset_windows, window_size, sr):
    """Convert per-window onset labels into onset times in seconds.

    Every window whose label equals 1 contributes the time of its centre,
    ``(index + 0.5) * window_size / sr``. Times are rounded to 2 decimals and
    de-duplicated (``np.unique`` also sorts them).

    Args:
        onset_windows: array of window labels; singleton dims are squeezed away,
            e.g. a ``(1, 1, n_windows)`` prediction array.
        window_size: number of samples per window.
        sr: sample rate in Hz.

    Returns:
        1-D numpy array of onset times in seconds (empty if no window is labelled 1).
    """
    seconds_in_window = window_size / sr
    # atleast_1d: a single-window input would otherwise squeeze to a 0-d array,
    # on which np.where/np.nonzero is deprecated (and an error in recent NumPy).
    labels = np.atleast_1d(onset_windows.squeeze())
    pred_onsets = np.where(labels == 1)[0] * seconds_in_window + seconds_in_window / 2.
    pred_onsets = np.unique(np.round(pred_onsets, decimals=2))
    return pred_onsets


def predict_and_evalaute(batch, model, onset_loss_func, notes_loss_func, device, visualize=False, window_size=1024):
    """Run ``model`` on one batch and compute its losses and onset-window metrics.

    Args:
        batch: dataloader batch (dict). Reads ``'features'``, ``'labels'`` and
            ``'notes'``; when ``visualize`` is True also ``'sr'`` and whatever
            ``plot_preds_gt`` reads.
        model: called as ``model(features)``; must return
            ``(onset_probs, notes_activations)``.
        onset_loss_func: applied to ``(onset_probs, labels)``.
        notes_loss_func: applied to the squeezed ``(notes_activations, notes)``.
        device: device the features and targets are moved to.
        visualize: if True, plot the predicted onset times against the ground truth.
        window_size: samples per window; used to convert predicted window
            indices into seconds when ``visualize`` is True.

    Returns:
        ``(pred_onsets, results)``. ``pred_onsets`` is the numpy array of
        per-window predicted classes (argmax over dim 1), regardless of
        ``visualize``. ``results`` is a ``Results`` tuple whose two losses are
        the raw loss tensors (still attached to the autograd graph when
        gradients are enabled, so the caller can backpropagate), and whose
        precision / recall / f1 are macro averages over classes 0 and 1.
    """
    inputs = batch['features'].to(device)
    target_labels, target_notes = batch['labels'].to(device), batch['notes'].to(device)

    # Predict
    onset_probs, notes_activations = model(inputs)
    # Onset head: 2 classes, a singleton dim, and one entry per position of the input's last dim.
    expected_shape = (2, 1, inputs.shape[-1])
    assert tuple(onset_probs.shape[1:]) == expected_shape, \
        'unexpected onset output shape {}'.format(tuple(onset_probs.shape))
    pred_onsets = torch.argmax(onset_probs, dim=1)

    # Losses
    onset_loss = onset_loss_func(onset_probs, target_labels)
    notes_loss = notes_loss_func(notes_activations.squeeze(), target_notes.squeeze())

    pred_onsets = pred_onsets.detach().cpu().numpy()
    target_labels = target_labels.detach().cpu().numpy()

    # ravel (not squeeze) so sklearn always gets 1-D arrays: every window of every
    # batch element counts, and a single window does not collapse to a 0-d array.
    p, r, f, s = precision_recall_fscore_support(target_labels.ravel(), pred_onsets.ravel(), labels=[0, 1])
    results = Results(onset_loss=onset_loss, notes_loss=notes_loss,
                      precision=p.mean(), recall=r.mean(), f1=f.mean(), support=s)

    if visualize:
        # Plot onset *times*, but keep returning the per-window labels so the
        # meaning of the return value does not depend on `visualize`.
        sr = batch['sr'].data.numpy()[0]
        pred_onset_times = get_onset_times_from_window_labels(pred_onsets, window_size, sr)
        plot_preds_gt(batch, pred_onset_times, window_size=window_size)

    return pred_onsets, results

In [None]:
def visualize_predictions(signal, sr, spec, target_onsets, pred_onsets, figsize=(40, 15), hop_length=1024):
    """Plot the waveform and spectrogram overlaid with onset and segment lines.

    Both panels get the same three sets of dashed vertical lines: ground-truth
    onsets (green), evenly spaced segment boundaries (gray) and predicted
    onsets (red).

    Args:
        signal: 1-D audio waveform (this parameter name shadows ``scipy.signal``
            inside the function only).
        sr: sample rate in Hz.
        spec: amplitude spectrogram; ``spec.shape[1]`` is taken as the number of
            segments (presumably ``(n_freq, n_frames)`` — confirm).
        target_onsets: ground-truth onset times in seconds.
        pred_onsets: predicted onset times in seconds.
        figsize: matplotlib figure size.
        hop_length: hop length passed to ``specshow`` for the time axis.
    """
    duration = signal.shape[0] / sr
    n_segments = spec.shape[1]
    # n_segments + 1 evenly spaced boundaries from 0 to the end of the signal.
    segment_starts_in_s = np.linspace(0, duration, n_segments + 1)

    plt.figure(figsize=figsize)

    # Waveform panel. NOTE: librosa.display.waveplot was removed in librosa 0.10 (use waveshow there).
    plt.subplot(2, 1, 1)
    ldisplay.waveplot(signal, sr=sr)
    _draw_onset_lines(target_onsets, segment_starts_in_s, pred_onsets,
                      ymin=signal.min() - 0.05, ymax=signal.max() + 0.05)

    # Spectrogram panel; lines span 0 Hz up to the Nyquist frequency.
    plt.subplot(2, 1, 2)
    ldisplay.specshow(librosa.amplitude_to_db(spec), sr=sr, x_axis='time', y_axis='hz', hop_length=hop_length)
    _draw_onset_lines(target_onsets, segment_starts_in_s, pred_onsets, ymin=0, ymax=sr / 2.)


def _draw_onset_lines(target_onsets, segment_starts_in_s, pred_onsets, ymin, ymax):
    """Draw ground-truth (green), segment-boundary (gray) and predicted (red) lines on the current axes."""
    plt.vlines(target_onsets, ymin=ymin, ymax=ymax, colors='g', linestyles='--', linewidths=3)
    plt.vlines(segment_starts_in_s, ymax=ymax, ymin=ymin, colors='gray', linestyles='--', linewidths=1)
    plt.vlines(pred_onsets, ymax=ymax, ymin=ymin, colors='r', linestyles='--', linewidths=2)


def plot_preds_gt(batch, pred_onsets, window_size=1024, figsize=(40, 15)):
    """Plot the first example of ``batch`` with its ground-truth onsets and ``pred_onsets``.

    Args:
        batch: dataloader batch (dict); reads ``'signal'``, ``'sr'``,
            ``'features'`` and ``'onsets'``. Only element 0 is plotted.
        pred_onsets: predicted onset times in seconds.
        window_size: NOTE(review): currently not forwarded — ``visualize_predictions``
            uses its own hop length for the spectrogram.
        figsize: matplotlib figure size, forwarded to ``visualize_predictions``.
    """
    # Local name `audio` avoids shadowing the module-level `scipy.signal` import.
    audio = batch['signal'][0].data.cpu().numpy().ravel()
    sr = batch['sr'].data.numpy()[0]
    spec = batch['features'][0].data.cpu().numpy().squeeze()
    target_onsets = batch['onsets'][0].data.cpu().numpy().squeeze()

    visualize_predictions(audio, sr, spec, target_onsets, pred_onsets, figsize=figsize)

In [None]:
# Training schedule; the loop below counts one `iteration` per training batch.
n_epochs = 1000000
val_every = 1000  # run validation every `val_every` training iterations
save_every = 10000  # checkpoint the model every `save_every` iterations
n_val = 5  # NOTE(review): not referenced by the training loop — validation iterates all of dataloader_val

In [None]:
# NOTE(review): not used elsewhere in this notebook.
train_size = len(dataloader_train)

In [None]:
# Initialised outside the training cell, so re-running that cell continues from the current values.
epoch = 0
alpha = 1.  # weight of the onset loss relative to the notes loss in the total loss

In [None]:
iteration = 0  # global step: x-axis of the tensorboard scalars and the checkpoint file name

In [None]:
# Main training loop. `epoch` and `iteration` are initialised in earlier cells, so
# re-running this cell resumes from their current values.
while epoch < n_epochs:
    for i_batch, train_batch in tqdm(enumerate(dataloader_train)):
        iteration += 1

        # predict (the losses come back as autograd-attached tensors)
        pred_labels, train_results = predict_and_evalaute(train_batch, model, onset_loss_func, notes_loss_func, DEVICE)
        onset_loss, notes_loss = train_results.onset_loss, train_results.notes_loss
        loss = alpha * onset_loss + notes_loss

        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log plain Python floats rather than graph-attached tensors.
        writer.add_scalar('loss.train', loss.item(), iteration)
        writer.add_scalar('loss.onset.train', onset_loss.item(), iteration)
        writer.add_scalar('loss.notes.train', notes_loss.item(), iteration)

        writer.add_scalar('acc.precision.train', train_results.precision, iteration)
        writer.add_scalar('acc.recall.train', train_results.recall, iteration)

        if iteration % val_every == 0:
            val_onset_loss, val_notes_loss = 0., 0.
            average_precision, average_recall = 0., 0.
            n_val_batches = 0

            # Validate in eval mode without building autograd graphs, and always switch
            # back to train mode afterwards (even if validation is interrupted).
            # NOTE(review): `n_val` is not used; this covers every batch of dataloader_val.
            model.eval()
            try:
                with torch.no_grad():
                    for ix, val_batch in enumerate(dataloader_val):
                        _, results = predict_and_evalaute(val_batch, model, onset_loss_func, notes_loss_func, DEVICE, visualize=ix < 2)
                        val_onset_loss += results.onset_loss.item()
                        val_notes_loss += results.notes_loss.item()

                        average_precision += results.precision
                        average_recall += results.recall
                        n_val_batches += 1
            finally:
                model.train()

            # average out over all the batches (nothing to log if the val loader was empty)
            if n_val_batches:
                val_onset_loss, val_notes_loss = val_onset_loss / n_val_batches, val_notes_loss / n_val_batches

                # log!
                writer.add_scalar('loss.val', alpha * val_onset_loss + val_notes_loss, iteration)
                writer.add_scalar('loss.onset.val', val_onset_loss, iteration)
                writer.add_scalar('loss.notes.val', val_notes_loss, iteration)

                writer.add_scalar('acc.precision.val', average_precision / n_val_batches, iteration)
                writer.add_scalar('acc.recall.val', average_recall / n_val_batches, iteration)
            plt.show()

        if iteration % save_every == 0:
            torch.save(model.state_dict(), os.path.join(weights_folder, '{}.pt'.format(iteration)))

    epoch += 1

In [None]:
# Inspect where the training loop stopped (e.g. after interrupting it).
epoch, iteration, i_batch

In [None]:
# Raise the epoch limit and rebuild the optimizer with a lower learning rate, so that
# re-running the training cell continues from the current epoch/iteration.
# Note: a new Adam instance starts with fresh state; the previous optimizer's
# accumulated moment estimates are discarded.
n_epochs = 10000000000
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)