In [1]:
import librosa
import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import time
import glob
from lxml.html import parse
from sphfile import SPHFile
import pydub
import audiosegment
import pandas as pd
from collections import Counter
from bs4 import BeautifulSoup
import sys
import os
from tqdm.notebook import tqdm
class Lambda(nn.Module):
    """Wrap an arbitrary callable so it can be used as an ``nn.Module`` layer."""

    def __init__(self, func):
        nn.Module.__init__(self)
        # Kept as a plain attribute: the callable contributes no parameters.
        self.func = func

    def forward(self, x):
        fn = self.func
        return fn(x)
# Shared configuration: audio sample rate, dropout probability used by the model's
# LSTM/attention layers, and whether to run the model in fp16.
sr, dropout, half = 16000, 0.3, False

In [2]:
# Index of the GPU used by every `.cuda(device)` call in this notebook. Defaults to 2 (the
# original hard-coded value); set OVERLAY_CUDA_DEVICE to run on a machine with fewer GPUs.
device = int(os.environ.get('OVERLAY_CUDA_DEVICE', '2'))
torch.cuda.set_device(device)


In [3]:
class OverlayDataSet(torch.utils.data.Dataset):
    """Two-speaker overlay dataset indexed by a CSV file.

    Each CSV row names two ``.npy`` audio segments (``first_file``, ``second_file``) and the
    speaker of each (``first_speaker``, ``second_speaker``). An item is the feature matrix of
    the summed (overlaid) segments plus a multi-hot target over ``self.speakers``.

    Args:
        csv: path to the overlay index CSV.
        compute_original: if True, items are ``(first, second, mixture, target)`` feature
            tuples instead of ``(mixture, target)``.
        speakers: optional speaker vocabulary (e.g. ``trainset.speakers``) so that target
            columns line up across splits. By default the sorted set of speakers found in
            this CSV is used, which only matches other splits if they contain the same speakers.

    Raises:
        ValueError: if ``speakers`` is given and the CSV mentions a speaker not in it.
    """

    # Sample rate (Hz) passed to librosa when computing MFCCs.
    SAMPLE_RATE = 16000

    def __init__(self, csv, compute_original = False, speakers = None):
        super().__init__()
        self.overlays = pd.read_csv(csv)
        present = set(self.overlays['first_speaker']) | set(self.overlays['second_speaker'])
        if speakers is None:
            self.speakers = sorted(present)
        else:
            unknown = present - set(speakers)
            if unknown:
                raise ValueError('speakers missing from vocabulary: %s' % sorted(unknown))
            self.speakers = list(speakers)
        self.spkr2idx = {spkr: i for i, spkr in enumerate(self.speakers)}
        self.compute_original = compute_original

    def __len__(self):
        return len(self.overlays)

    def __getitem__(self, idx):
        overlay = self.overlays.iloc[idx]
        # NOTE(review): presumably int16 PCM, hence the 2**15 scaling to roughly [-1, 1) — confirm.
        first_segment = np.load(overlay['first_file']) / (2**15)
        second_segment = np.load(overlay['second_file']) / (2**15)
        # Zero-pad the shorter segment at the end (segment lengths can differ by rounding)
        # so the two can be summed sample by sample.
        length = max(len(first_segment), len(second_segment))
        first_segment = np.pad(first_segment, (0, length - len(first_segment)))
        second_segment = np.pad(second_segment, (0, length - len(second_segment)))

        # Multi-hot target; only a single 1 when both segments come from the same speaker.
        target = np.zeros(len(self.speakers))
        target[self.spkr2idx[overlay['first_speaker']]] = 1.0
        target[self.spkr2idx[overlay['second_speaker']]] = 1.0

        mixture = self.make_spectrogram(first_segment + second_segment)
        if self.compute_original:
            return (self.make_spectrogram(first_segment), self.make_spectrogram(second_segment),
                    mixture, target)
        return mixture, target

    def make_spectrogram(self, segment):
        """Compute the (frames, 36) feature matrix of a 1-D audio segment.

        Columns are MFCCs 1..13 (coefficient 0 dropped) followed by their first and second
        differences: 13 + 12 + 11 = 36.
        """
        # Trim 50 samples from each end; per the author this yields 200 frames — presumably
        # for 32000-sample (2 s) segments with hop_length=160. TODO confirm.
        segment = segment[50:-50]
        # `y=` must be passed by keyword: librosa >= 0.10 makes mfcc's arguments keyword-only.
        mfcc = librosa.feature.mfcc(y=segment, sr=self.SAMPLE_RATE, n_mfcc=20, dct_type=2,
                                    n_fft=1024, hop_length=160)
        static = mfcc[1:14].T  # (frames, 13)
        # NOTE(review): np.diff works on the last axis, so these are differences between
        # neighbouring MFCC coefficients, not the usual time deltas (librosa.feature.delta).
        # Changing it would change the model's input features, so it is left as is.
        first_diff = np.diff(static)
        second_diff = np.diff(first_diff)
        return np.concatenate((static, first_diff, second_diff), axis=-1)
    
trainset = OverlayDataSet('overlay-train.csv', False)
valset = OverlayDataSet('overlay-val.csv', False)
testset = OverlayDataSet('overlay-test.csv', False)
# Each dataset builds its speaker -> column mapping from its own CSV, so model outputs only
# mean the same speaker across splits if every split has exactly the same speakers.
assert valset.speakers == trainset.speakers, 'val speakers differ from train speakers'
assert testset.speakers == trainset.speakers, 'test speakers differ from train speakers'
print(trainset.speakers)

# Show one training example: the mixture, plus both source segments when available.
sample = trainset[0]
if trainset.compute_original:
    spec1, spec2, spec3, target = sample
    panels = [('first speaker', spec1), ('second speaker', spec2), ('mixture', spec3)]
else:
    spec3, target = sample
    panels = [('mixture', spec3)]
fig, axes = plt.subplots(1, len(panels), figsize=(20, 6), squeeze=False)
for ax, (title, spec) in zip(axes[0], panels):
    # spec is (frames, features); transpose so time runs along the x-axis.
    ax.imshow(spec.T, aspect='auto', origin='lower')
    ax.set(title=title, xlabel='frame', ylabel='feature index')
plt.show()

['andrea_arsenault', 'brian_lamb', 'csp_waj_susan', 'david_brancaccio', 'eddie_mair', 'joie_chen', 'kathleen_kennedy', 'leon_harris', 'linda_wertheimer', 'linden_soles', 'lisa_mullins', 'lou_waters', 'lynn_vaughan', 'mark_mullen', 'natalie_allen', 'noah_adams', 'peter_jennings', 'robert_siegel', 'ted_koppel', 'thalia_assuras']


<Figure size 1440x432 with 0 Axes>

## TODO: try drastically increasing the number of channels in the residual-attention stage to check whether the model can overfit

In [13]:
num_heads_2 = 4  # heads of the MultiheadAttention layer


class Baseline(nn.Module):
    """BiLSTM + self-attention multi-label speaker classifier.

    Expects a batch of feature sequences whose last dimension is 36 (fixed by the LayerNorm
    and LSTM input size); presumably (batch, frames, 36) as produced by the dataset — TODO
    confirm. Returns (batch, 20) independent per-speaker probabilities (sigmoid outputs),
    intended for use with ``BCELoss``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are load-bearing: they fix the state_dict keys
        # of saved checkpoints, the optimizer's parameter order and the RNG order of init.
        self.bn = nn.LayerNorm(36)  # despite the name, a LayerNorm over the 36 features
        self.reshape = Lambda(lambda t: t.permute(1, 0, 2))  # swap dims 0 and 1 -> (L, batch, C)
        self.lstm = nn.LSTM(input_size=36, hidden_size=32, num_layers=2,
                            batch_first=False, bidirectional=True,
                            dropout=dropout)  # -> (L, batch, 2 * 32 = 64)
        self.mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=num_heads_2,
                                               dropout=dropout, bias=True,
                                               kdim=64, vdim=64)  # -> (L, batch, 64)
        self.fc1 = nn.Linear(in_features=64, out_features=32)
        self.average = Lambda(lambda t: t.mean(0))  # mean over the sequence -> (batch, 32)
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(in_features=32, out_features=20)
        self.sigmoid = nn.Sigmoid()

    def forward(self, X):
        seq = self.reshape(self.bn(X))
        seq = self.lstm(seq)[0]
        seq = self.mha(seq, seq, seq)[0]  # self-attention; attention weights discarded
        pooled = self.tanh(self.average(self.fc1(seq)))
        return self.sigmoid(self.fc2(pooled))
    
    
overnet = Baseline().cuda(device)

# tune hidden layers smaller if overfit
optimizer = torch.optim.Adam(overnet.parameters(), 0.001)

bestacc = 0.0
if os.path.exists('models/baseline.pth'):
    print('load model')
    # map_location keeps loading working if the checkpoint was written from another GPU index.
    checkpoint = torch.load('models/baseline.pth', map_location=f'cuda:{device}')
    overnet.load_state_dict(checkpoint['model_state_dict'])
    try:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except (KeyError, ValueError) as err:
        # Missing optimizer state, or state saved for a different parameter layout.
        print('cannot load optimizer:', err)
    loss = checkpoint['loss']
    bestacc = checkpoint.get('bestacc', 0.0)
else:
    print('initializing new model')

# The rolling checkpoint may not carry 'bestacc'. Without this fallback a restart resets
# bestacc to 0.0 and the first validation pass overwrites best-baseline.pth even when the
# new model is worse than the one already saved there.
if os.path.exists('models/best-baseline.pth'):
    best_checkpoint = torch.load('models/best-baseline.pth', map_location='cpu')
    bestacc = max(bestacc, best_checkpoint.get('bestacc', 0.0))

if half:
    overnet.half()  # convert to half precision
    # Keep BatchNorm layers in fp32. NOTE(review): Baseline uses LayerNorm, not BatchNorm2d,
    # so this check currently matches nothing — confirm the intended behaviour.
    for layer in overnet.modules():
        if isinstance(layer, nn.BatchNorm2d):
            layer.float()

overnet.train()
'bestacc:', bestacc

load model


('bestacc:', 0.0)

## TODO: also report a metric for correctly identifying at least one of the two speakers

In [15]:
def find_max2(tensor, k=2):
    """Return the column indices of the ``k`` largest values in each row.

    Args:
        tensor: 2-D torch tensor (batch, n_classes); may require grad or live on the GPU.
        k: number of indices to return per row (default 2, hence the name).

    Returns:
        Integer ndarray of shape (batch, k), largest value first.
    """
    array = tensor.detach().cpu().numpy()
    # argsort is ascending; reverse each row so the largest values come first.
    return np.argsort(array, axis=1)[:, ::-1][:, :k]


def compute_corrects(tensor1, tensor2):
    """Count rows whose top-2 column sets agree between two (batch, n_classes) tensors.

    The order within each pair is ignored. In this notebook it is called as
    ``compute_corrects(output, target)`` with multi-hot targets.

    NOTE(review): if a target row has a single positive (same speaker on both segments), its
    second "top" index is an arbitrary tied zero, so such rows are scored unreliably.
    """
    top_a = np.sort(find_max2(tensor1), axis=1)
    top_b = np.sort(find_max2(tensor2), axis=1)
    return int((top_a == top_b).all(axis=1).sum())

batch_size = 64
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 16)
# Order is irrelevant for evaluation, so the validation loader does not shuffle.
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, pin_memory = True, num_workers = 16)
criterion = torch.nn.BCELoss()
log_every = 200  # mini-batches between progress prints and rolling checkpoints


def _to_device(spec, target):
    """Cast a (spec, target) batch to float32 (fp16 when ``half``) and move it to the GPU."""
    spec, target = spec.float(), target.float()
    if half:
        spec, target = spec.half(), target.half()
    return spec.cuda(device), target.cuda(device)


def _evaluate(model, loader):
    """Fraction of samples in ``loader`` whose top-2 predicted speakers match the target.

    Runs in eval mode without building an autograd graph, then restores train mode.
    """
    model.eval()
    corrects = 0
    with torch.no_grad():
        for spec, target in tqdm(loader):
            spec, target = _to_device(spec, target)
            corrects += compute_corrects(model(spec), target)
    model.train()
    return corrects / len(loader.dataset)


for epoch in range(64):
    running_loss = 0.0
    running_corrects = 0
    running_samples = 0
    for batch_idx, (spec, target) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        spec, target = _to_device(spec, target)

        out = overnet(spec)
        loss = criterion(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(overnet.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        # Count real samples: the last batch of an epoch can be smaller than batch_size.
        running_corrects += compute_corrects(out, target)
        running_samples += target.shape[0]
        if batch_idx % log_every == log_every - 1:
            print('[%d, %5d] loss: %.3f accuracy: %.3f' %
                  (epoch + 1, batch_idx + 1, running_loss / log_every, running_corrects / running_samples))
            running_loss = 0.0
            running_corrects = 0
            running_samples = 0
            # Rolling checkpoint. 'bestacc' is stored so that a restart does not reset it to 0.
            torch.save({
                'model_state_dict': overnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
                'bestacc': bestacc,
            }, 'models/baseline.pth')

    valacc = _evaluate(overnet, valloader)
    print('val acc:', valacc)
    if valacc > bestacc:
        bestacc = valacc
        torch.save({
            'model_state_dict': overnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
            'bestacc': bestacc,
        }, 'models/best-baseline.pth')

HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[1,   200] loss: 0.053 accuracy: 0.838
[1,   400] loss: 0.053 accuracy: 0.832
[1,   600] loss: 0.053 accuracy: 0.833
[1,   800] loss: 0.054 accuracy: 0.830
[1,  1000] loss: 0.052 accuracy: 0.836
[1,  1200] loss: 0.053 accuracy: 0.832



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7734119782214156


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[2,   200] loss: 0.053 accuracy: 0.832
[2,   400] loss: 0.051 accuracy: 0.838
[2,   600] loss: 0.053 accuracy: 0.829
[2,   800] loss: 0.053 accuracy: 0.832
[2,  1000] loss: 0.054 accuracy: 0.831
[2,  1200] loss: 0.054 accuracy: 0.824



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7665154264972777


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[3,   200] loss: 0.050 accuracy: 0.841
[3,   400] loss: 0.053 accuracy: 0.834
[3,   600] loss: 0.053 accuracy: 0.831
[3,   800] loss: 0.053 accuracy: 0.832
[3,  1000] loss: 0.053 accuracy: 0.831
[3,  1200] loss: 0.054 accuracy: 0.829



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.773956442831216


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[4,   200] loss: 0.052 accuracy: 0.840
[4,   400] loss: 0.052 accuracy: 0.839
[4,   600] loss: 0.052 accuracy: 0.837
[4,   800] loss: 0.053 accuracy: 0.829
[4,  1000] loss: 0.053 accuracy: 0.834
[4,  1200] loss: 0.052 accuracy: 0.835



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7720508166969147


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[5,   200] loss: 0.051 accuracy: 0.841
[5,   400] loss: 0.052 accuracy: 0.835
[5,   600] loss: 0.053 accuracy: 0.832
[5,   800] loss: 0.052 accuracy: 0.836
[5,  1000] loss: 0.053 accuracy: 0.828
[5,  1200] loss: 0.054 accuracy: 0.828



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.768874773139746


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[6,   200] loss: 0.051 accuracy: 0.841
[6,   400] loss: 0.052 accuracy: 0.835
[6,   600] loss: 0.053 accuracy: 0.833
[6,   800] loss: 0.052 accuracy: 0.834
[6,  1000] loss: 0.053 accuracy: 0.831
[6,  1200] loss: 0.053 accuracy: 0.829



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7690562613430127


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[7,   200] loss: 0.051 accuracy: 0.840
[7,   400] loss: 0.052 accuracy: 0.840
[7,   600] loss: 0.052 accuracy: 0.838
[7,   800] loss: 0.052 accuracy: 0.838
[7,  1000] loss: 0.053 accuracy: 0.832
[7,  1200] loss: 0.052 accuracy: 0.834



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7754990925589836


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[8,   200] loss: 0.050 accuracy: 0.838
[8,   400] loss: 0.051 accuracy: 0.838
[8,   600] loss: 0.052 accuracy: 0.837
[8,   800] loss: 0.052 accuracy: 0.836
[8,  1000] loss: 0.053 accuracy: 0.829
[8,  1200] loss: 0.054 accuracy: 0.828



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7763157894736842


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[9,   200] loss: 0.050 accuracy: 0.839
[9,   400] loss: 0.052 accuracy: 0.840
[9,   600] loss: 0.052 accuracy: 0.834
[9,   800] loss: 0.053 accuracy: 0.837
[9,  1000] loss: 0.052 accuracy: 0.834
[9,  1200] loss: 0.053 accuracy: 0.832



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7742286751361162


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[10,   200] loss: 0.050 accuracy: 0.845
[10,   400] loss: 0.052 accuracy: 0.835
[10,   600] loss: 0.051 accuracy: 0.838
[10,   800] loss: 0.051 accuracy: 0.839
[10,  1000] loss: 0.052 accuracy: 0.840
[10,  1200] loss: 0.053 accuracy: 0.834



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7775862068965518


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[11,   200] loss: 0.050 accuracy: 0.844
[11,   400] loss: 0.051 accuracy: 0.842
[11,   600] loss: 0.051 accuracy: 0.838
[11,   800] loss: 0.052 accuracy: 0.837
[11,  1000] loss: 0.052 accuracy: 0.834
[11,  1200] loss: 0.051 accuracy: 0.841



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7735934664246824


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[12,   200] loss: 0.051 accuracy: 0.839
[12,   400] loss: 0.051 accuracy: 0.841
[12,   600] loss: 0.051 accuracy: 0.836
[12,   800] loss: 0.051 accuracy: 0.838
[12,  1000] loss: 0.052 accuracy: 0.837
[12,  1200] loss: 0.051 accuracy: 0.838



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7770417422867514


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[13,   200] loss: 0.051 accuracy: 0.840
[13,   400] loss: 0.049 accuracy: 0.847
[13,   600] loss: 0.051 accuracy: 0.836
[13,   800] loss: 0.052 accuracy: 0.835
[13,  1000] loss: 0.054 accuracy: 0.824
[13,  1200] loss: 0.051 accuracy: 0.840



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7735934664246824


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[14,   200] loss: 0.050 accuracy: 0.838
[14,   400] loss: 0.050 accuracy: 0.842
[14,   600] loss: 0.050 accuracy: 0.840
[14,   800] loss: 0.051 accuracy: 0.834
[14,  1000] loss: 0.051 accuracy: 0.839
[14,  1200] loss: 0.053 accuracy: 0.831



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7726860254083484


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[15,   200] loss: 0.051 accuracy: 0.841
[15,   400] loss: 0.051 accuracy: 0.840
[15,   600] loss: 0.051 accuracy: 0.840
[15,   800] loss: 0.051 accuracy: 0.839
[15,  1000] loss: 0.053 accuracy: 0.835
[15,  1200] loss: 0.051 accuracy: 0.840



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7788566243194193


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[16,   200] loss: 0.050 accuracy: 0.847


KeyboardInterrupt: 

In [16]:
# Evaluation order does not matter, so the test loader does not shuffle.
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, pin_memory = True, num_workers = 16)

# NOTE(review): this scores the current in-memory weights, not models/best-baseline.pth;
# load that checkpoint first if the test score should reflect best-validation selection.
overnet.eval()
corrects = 0
# no_grad: evaluation needs no autograd graph, which would otherwise cost GPU memory per batch.
with torch.no_grad():
    for spec, target in tqdm(testloader):
        spec, target = spec.float(), target.float()
        if half:
            spec, target = spec.half(), target.half()
        spec = spec.cuda(device)
        target = target.cuda(device)
        corrects += compute_corrects(overnet(spec), target)
print('test acc:', corrects/len(testset))
overnet.train()

HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


test acc: 0.7659709618874773
