# Train a CNN model on sliding windows!

### TO-DOs
* [x] train/val/test split
* [ ] random crop from the signal
* [x] tensorboard
* [x] weights saving
* [ ] Visualization of the results during training
* [x] P/R during training

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from scipy import signal
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import IPython.display as ipydisplay
import functools
import librosa
import librosa.display as ldisplay
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pathlib
import scipy.io.wavfile as wav
import torch

In [None]:
%matplotlib inline

In [None]:
from src import dataset, bincounts

### Dataset/dataloader

In [None]:
# Training split: the project's SignalWindowDataset over the train/ folder,
# served one item per batch with the DataLoader's default settings.
dataset_train = dataset.SignalWindowDataset(folder_path='/home/anuj/data/m/p_cl/train/',)
dataloader_train = DataLoader(dataset_train, batch_size=1)

# Validation split, built the same way from the sibling val/ folder.
dataset_val = dataset.SignalWindowDataset(folder_path='/home/anuj/data/m/p_cl/val/',)
dataloader_val = DataLoader(dataset_val, batch_size=1)

# Number of batches in each loader (equals the dataset length since batch_size=1).
print(len(dataloader_train), len(dataloader_val))

In [None]:
# Pull one training batch to inspect the feature/label tensor shapes and the raw labels.
# The same `batch` is reused further down for the model's shape sanity check.
batch = next(iter(dataloader_train))
print(batch['features'].shape, batch['labels'].shape)
print(batch['labels'])

### Bincounts

In [None]:
%%time
# Compute per-key bin counts for the 'labels' entry of the training batches; the
# result is indexed by key below and later fed to the loss as class weights.
# NOTE(review): the meaning of n_iters=10 is defined in src.bincounts — confirm there.
weights = bincounts.get_bin_counts(dataloader_train, keys=['labels'], n_iters=10)

In [None]:
# Display the values computed for the 'labels' key.
weights['labels']

### Model

In [None]:
import torch.nn.functional as F

class ConvBlock(torch.nn.Module):
    """Conv2d -> ReLU -> BatchNorm2d -> Dropout2d, applied as a single stage.

    Args:
        n_in: number of input channels of the convolution.
        n_out: number of output channels of the convolution (and width of the batch norm).
        kernel_size: forwarded unchanged to ``torch.nn.Conv2d``.
        stride: forwarded unchanged to ``torch.nn.Conv2d``.
        padding: forwarded unchanged to ``torch.nn.Conv2d``.
    """

    def __init__(self, n_in, n_out, kernel_size, stride, padding):
        super().__init__()

        conv = torch.nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=padding)
        layers = [
            conv,
            torch.nn.ReLU(inplace=True),
            torch.nn.BatchNorm2d(n_out),
            # p=0 means no activations are dropped.
            torch.nn.Dropout2d(p=0, inplace=True),
        ]
        # Layer order fixes the sub-module indices (block.0 ... block.3) used in state dicts.
        self.block = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the four layers in order and return the result."""
        out = self.block(x)
        return out
        
        
class SimpleFrameCNN(torch.nn.Module):
    """Four size-preserving ConvBlocks followed by a conv that collapses the feature axis.

    Args:
        n_feats: height of the classifier kernel; with no padding, an input whose
            feature axis (dim 2) has exactly ``n_feats`` rows is reduced to one row.
        n_channels_in: number of input channels of the first block.
        n_classes: number of output channels of the classifier (one per class).

    ``forward`` returns log-softmax scores over dim 1 (the class axis).
    """

    def __init__(self, n_feats, n_channels_in=1, n_classes=2,) -> None:
        super().__init__()

        # (in_channels, out_channels, kernel_size). Stride is 1 and padding is
        # kernel_size // 2, so every block keeps the spatial size unchanged.
        stage_specs = [
            (n_channels_in, 16, 9),
            (16, 32, 7),
            (32, 64, 5),
            (64, 128, 3),
        ]
        stages = [
            ConvBlock(c_in, c_out, kernel_size=k, stride=1, padding=k // 2)
            for c_in, c_out, k in stage_specs
        ]
        self.feature_extractor = torch.nn.Sequential(*stages)

        # Kernel spans the whole feature axis and a single step along the last axis.
        head = torch.nn.Conv2d(128, n_classes, kernel_size=(n_feats, 1), stride=1, padding=0)
        self.classifier = torch.nn.Sequential(head)

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1) for ``x``."""
        scores = self.classifier(self.feature_extractor(x))
        return F.log_softmax(scores, dim=1)

In [None]:
# Device for the model, the loss weights and every batch: CUDA device index 1.
DEVICE = 'cuda:1'

In [None]:
# n_feats=513 is the height of the classifier kernel, so the input's feature
# axis must have 513 rows for the output's feature axis to collapse to 1.
model = SimpleFrameCNN(n_feats=513).to(DEVICE)

In [None]:
# Smoke test: run the sample batch through the model and check the output layout.
inputs = batch['features'].to(DEVICE)
# NOTE: despite its name, pred_probs holds the output *shape* without the batch dimension.
pred_probs = model(inputs).shape[1:]
# Expect 2 classes, a feature axis of size 1, and the input's last-axis length preserved.
assert np.all(pred_probs == np.array([2, 1, inputs.shape[-1]]))

### Loss / optimizer

In [None]:
# Per-class weights for the loss, converted to a float32 tensor. Tensors can be
# passed to the loss directly; the deprecated autograd Variable wrapper is not needed.
# NOTE(review): confirm in src.bincounts that weights['labels'] holds class
# weights rather than raw counts.
weights_l = torch.from_numpy(weights['labels'].astype(np.float32))
# Negative log-likelihood on the model's log-softmax output; targets equal to
# -100 are ignored. .to(DEVICE) moves the weight tensor along with the module.
loss_func = torch.nn.NLLLoss(weight=weights_l, ignore_index=-100).to(DEVICE)

In [None]:
# Adam over all model parameters, using the library's default hyperparameters.
optimizer = torch.optim.Adam(params=model.parameters())

### logging

In [None]:
# Run name; it is used for both the checkpoint folder and the tensorboard folder.
model_str = 'docmus-1.01'

# logging
weights_folder = "/opt/weights/{}".format(model_str)
log_folder =  '../tensorboard-logs/{}'.format(model_str)

# Create the checkpoint folder before anything else: this call is the guard
# against reusing a run name, so it has to fail before the tensorboard writer
# is opened on the existing run's log folder.
os.makedirs(weights_folder)  # MEANT TO FAIL IF IT ALREADY EXISTS

writer = SummaryWriter(log_folder) # writing log to tensorboard
# NOTE(review): the message says "logging to" but prints the weights folder, not log_folder.
print('logging to: {}'.format(weights_folder))

### Train

In [None]:
from sklearn.metrics import precision_recall_fscore_support

In [None]:
import collections

In [None]:
# Record of the precision / recall / F1 / support values produced by the evaluation helper.
Results = collections.namedtuple('Results', ['precision', 'recall', 'f1', 'support', ])

In [None]:
def predict_and_evalaute(batch, model, loss_func, device, visualize=False):
    """Run ``model`` on one batch and compute its loss, optionally with class-1 metrics.

    Args:
        batch: mapping with a ``'features'`` tensor (the model input) and a
            ``'labels'`` tensor (the targets handed to ``loss_func``).
        model: callable returning per-class scores whose dimensions after the
            batch axis are ``(2, 1, T)``, where ``T`` is the last input dimension.
        loss_func: callable taking ``(model_output, target_labels)``.
        device: device the batch tensors are moved to before prediction.
        visualize: when true, also compute precision/recall/F1 against the targets.

    Returns:
        ``(pred_labels, loss)`` when ``visualize`` is false; ``pred_labels`` is the
        argmax over dim 1 as a tensor.
        ``(pred_labels, loss, results)`` when ``visualize`` is true; ``pred_labels``
        is then a NumPy array and ``results`` is a ``Results`` tuple holding the
        precision, recall and F1 of class 1 plus the per-class support array for
        labels ``[0, 1]``.

    Raises:
        AssertionError: if the model output does not have the expected layout.
    """
    inputs = batch['features'].to(device)
    target_labels = batch['labels'].to(device)

    # Predict
    label_probs = model(inputs)
    expected_layout = (2, 1, inputs.shape[-1])
    assert tuple(label_probs.shape[1:]) == expected_layout
    pred_labels = torch.argmax(label_probs, dim=1)

    # loss
    loss = loss_func(label_probs, target_labels)

    if not visualize:
        return pred_labels, loss

    pred_labels = pred_labels.detach().cpu().numpy()
    target_labels = target_labels.detach().cpu().numpy()

    # ravel() always yields 1-D label vectors, whatever the batch size or T.
    # labels=[0, 1] pins the metric arrays to two entries so index 1 exists even
    # when a batch contains (and the model predicts) only one class.
    p, r, f, s = precision_recall_fscore_support(
        target_labels.ravel(), pred_labels.ravel(), labels=[0, 1])
    results = Results(precision=p[1], recall=r[1], f1=f[1], support=s)

    return pred_labels, loss, results

In [None]:
n_epochs = 100000  # Each epoch would only see a sample each from 26 files
val_every = 10  # run validation whenever the global iteration is a multiple of this
save_every = 1000  # save a checkpoint whenever the global iteration is a multiple of this
n_val = 5  # number of validation batches the validation averages are computed over

In [None]:
# Batches per epoch; the training loop uses it to turn (epoch, batch index) into a global iteration.
train_size = len(dataloader_train)

In [None]:
# Epoch counter read and advanced by the training loop below.
epoch = 0

In [None]:
from tqdm import tqdm

In [None]:
# Main training loop: one optimizer step per training batch, with periodic
# validation and checkpointing keyed on the global iteration number.
while epoch < n_epochs:
    # tqdm wraps the loader itself (not the enumerate object) so it can read the
    # loader's length and show a per-epoch total.
    for i_batch, train_batch in enumerate(tqdm(dataloader_train)):
        # Global step shared by the tensorboard scalars and the checkpoint names.
        iteration = epoch * train_size + i_batch

        # predict
        pred_labels, loss = predict_and_evalaute(train_batch, model, loss_func, DEVICE)

        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        writer.add_scalar('loss.train', loss.item(), iteration)

        if iteration % val_every == 0:
            val_loss_total = 0
            average_precision, average_recall = 0, 0
            n_val_seen = 0

            # Validate in eval mode with autograd off, so validation batches do not
            # update the BatchNorm running statistics or build backward graphs.
            # The finally clause restores train mode even if validation is
            # interrupted (e.g. the cell is stopped by hand).
            model.eval()
            try:
                with torch.no_grad():
                    for val_batch in dataloader_val:
                        # Evaluate at most n_val batches per validation pass.
                        if n_val_seen >= n_val:
                            break
                        _, val_loss, results = predict_and_evalaute(
                            val_batch, model, loss_func, DEVICE, visualize=True)
                        val_loss_total += val_loss.item()
                        average_precision += results.precision
                        average_recall += results.recall
                        n_val_seen += 1
            finally:
                model.train()

            # Average over the batches actually evaluated, so the numbers stay
            # correct when the validation loader has fewer than n_val batches.
            if n_val_seen > 0:
                writer.add_scalar('loss.val', val_loss_total / n_val_seen, iteration)
                writer.add_scalar('acc.precision.val', average_precision / n_val_seen, iteration)
                writer.add_scalar('acc.recall.val', average_recall / n_val_seen, iteration)

        if iteration % save_every == 0:
            # Checkpoints hold only the state dict and are named after the iteration.
            torch.save(model.state_dict(), os.path.join(weights_folder, '{}.pt'.format(iteration)))

    epoch += 1

In [None]:
# Current value of the epoch counter.
epoch

In [None]:
# Global iteration number of the most recent training step.
iteration

In [None]:
# Batch index within the epoch of the most recent training step.
i_batch