In [1]:
import librosa
import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import time
import glob
from lxml.html import parse
from sphfile import SPHFile
import pydub
import audiosegment
import pandas as pd
from collections import Counter
from bs4 import BeautifulSoup
import sys
import os
from tqdm.notebook import tqdm
class Lambda(nn.Module):
    """Adapter that lets a plain callable be used as an ``nn.Module`` layer.

    The wrapped callable is stored on ``self.func`` and holds no parameters
    of its own as far as this class is concerned.
    """

    def __init__(self, func):
        # Initialise the nn.Module machinery before assigning attributes.
        nn.Module.__init__(self)
        self.func = func

    def forward(self, x):
        """Apply the wrapped callable to ``x`` and return its result."""
        result = self.func(x)
        return result
# Notebook-wide settings.
# Sampling rate in Hz. NOTE(review): OverlayDataSet.make_spectrogram hard-codes
# sr=16000 instead of reading this name, so changing it here has no effect there.
sr = 16000
# Dropout probability passed to the LSTM and the multi-head attention in Baseline.
dropout = 0.3
# When True, the model and the input batches are converted to half precision.
half = False

In [2]:
# Index of the CUDA device that the model and every batch below are moved to.
# NOTE(review): hard-coded to GPU 3; adjust for machines with a different GPU layout.
device = 3
torch.cuda.set_device(device)


In [3]:
class OverlayDataSet(torch.utils.data.Dataset):
    """Dataset of two-speaker overlays listed in a CSV file.

    The CSV must provide the columns ``first_file`` and ``second_file``
    (paths to ``.npy`` waveforms) and ``first_speaker`` / ``second_speaker``
    (speaker names). Each item consists of the features of the summed
    waveforms and a multi-hot target with 1.0 at the index of each speaker.

    Args:
        csv: Path (or file-like object) of the CSV describing the overlays.
        compute_original: When True, ``__getitem__`` also returns the features
            of the two individual segments ahead of the mixture features.

    Attributes:
        speakers: Sorted list of every speaker name found in this CSV.
        spkr2idx: Mapping from speaker name to its index in ``speakers``.
    """

    def __init__(self, csv, compute_original = False):
        super().__init__()
        self.overlays = pd.read_csv(csv)
        # Sorted union of both speaker columns so label indices are deterministic.
        # NOTE(review): the label space is derived from this CSV alone, so two
        # splits only share indices if they contain the same set of speakers.
        self.speakers = sorted(
            set(self.overlays['first_speaker']).union(self.overlays['second_speaker'])
        )
        self.spkr2idx = {spkr: i for i, spkr in enumerate(self.speakers)}
        self.compute_original = compute_original

    def __len__(self):
        """Number of overlays (rows of the CSV)."""
        return len(self.overlays)

    @staticmethod
    def _pad_to_same_length(first_segment, second_segment):
        """Zero-pad the shorter waveform at its end so both have equal length."""
        gap = len(first_segment) - len(second_segment)
        if gap > 0:
            second_segment = np.concatenate((second_segment, np.zeros(gap)))
        elif gap < 0:
            first_segment = np.concatenate((first_segment, np.zeros(-gap)))
        return first_segment, second_segment

    def __getitem__(self, idx):
        """Return ``(mixture_features, target)`` for row ``idx``.

        With ``compute_original`` set, returns
        ``(first_features, second_features, mixture_features, target)``.
        """
        overlay = self.overlays.iloc[idx]
        # Scale by 2**15 -- presumably the files hold int16 PCM samples; confirm.
        first_segment = np.load(overlay['first_file'])/(2**15)
        second_segment = np.load(overlay['second_file'])/(2**15)
        # Padding compensates for rounding errors in the segment lengths.
        first_segment, second_segment = self._pad_to_same_length(first_segment, second_segment)

        # Multi-hot target: one entry per known speaker.
        target = np.zeros(len(self.speakers))
        target[self.spkr2idx[overlay['first_speaker']]] = 1.0
        target[self.spkr2idx[overlay['second_speaker']]] = 1.0

        if self.compute_original:
            return self.make_spectrogram(first_segment), self.make_spectrogram(second_segment),\
                self.make_spectrogram(first_segment+second_segment), target
        return self.make_spectrogram(first_segment+second_segment), target

    def make_spectrogram(self, segment):
        """Turn a waveform into a ``(frames, 36)`` feature matrix.

        Columns are MFCC coefficients 1..13 followed by their first and second
        differences as computed by ``np.diff`` (see the note in the body).
        """
        # Trim 50 samples from each end; per the original author's note this
        # brings the frame count to 200 (depends on the input length).
        segment = segment[50:-50]
        # `y` is passed by keyword: librosa >= 0.10 made the audio arguments
        # keyword-only, and the keyword form also works on older releases.
        mfcc = librosa.feature.mfcc(y=segment, sr=16000, n_mfcc=20, dct_type=2, n_fft = 1024, hop_length = 160)
        S = mfcc[1:14].T  # (frames, 13): drop coefficient 0, keep the next 13
        # NOTE(review): np.diff defaults to axis=-1, so these are differences
        # between neighbouring coefficients (12 and 11 columns), not deltas
        # over time. The resulting 13 + 12 + 11 = 36 columns match the
        # LayerNorm(36)/LSTM(36, ...) input size of Baseline, so changing the
        # axis would also require changing the model and retraining.
        S1 = np.diff(S)
        S2 = np.diff(S1)
        return np.concatenate((S, S1, S2), axis = -1)
    
# Build the three splits; only the mixture features are requested.
trainset = OverlayDataSet('overlay-train.csv', False)
valset = OverlayDataSet('overlay-val.csv', False)
testset = OverlayDataSet('overlay-test.csv', False)
print(trainset.speakers)

# Sanity check on the first training example. The number of values returned by
# the dataset depends on compute_original, so unpack accordingly.
sample = trainset[0]
plt.figure(figsize = (20, 6))
if trainset.compute_original:
    spec1, spec2, spec3, target = sample
    # One panel each for the first segment, the second segment and the mixture.
    # Each feature matrix is (frames, features); transpose so time runs along x.
    for position, spec in zip((131, 132, 133), (spec1, spec2, spec3)):
        plt.subplot(position)
        plt.imshow(spec.T)
else:
    spec3, target = sample

['andrea_arsenault', 'brian_lamb', 'csp_waj_susan', 'david_brancaccio', 'eddie_mair', 'joie_chen', 'kathleen_kennedy', 'leon_harris', 'linda_wertheimer', 'linden_soles', 'lisa_mullins', 'lou_waters', 'lynn_vaughan', 'mark_mullen', 'natalie_allen', 'noah_adams', 'peter_jennings', 'robert_siegel', 'ted_koppel', 'thalia_assuras']


<Figure size 1440x432 with 0 Axes>

## TODO: try drastically increasing the number of channels in the residual attention stage to see whether the model overfits

In [4]:
num_heads_2 = 4 # MHA heads


class Baseline(nn.Module):
    """BiLSTM + self-attention baseline for multi-label speaker identification.

    Input is a batch of feature sequences shaped ``(batch, frames, n_features)``;
    output is ``(batch, n_speakers)`` with an independent sigmoid score per
    speaker.

    Args:
        n_features: Width of each input frame.
        hidden_size: Hidden size of each LSTM direction; the bidirectional
            output (``2 * hidden_size``) is the attention embedding size and
            must be divisible by ``num_heads_2``.
        n_speakers: Number of output scores.

    The defaults reproduce the original fixed architecture, and the attribute
    names and construction order are unchanged, so existing checkpoints load.
    """
    def __init__(self, n_features = 36, hidden_size = 32, n_speakers = 20):
        super().__init__()
        embed_dim = 2 * hidden_size  # forward + backward LSTM outputs
        # Despite the attribute name, this is a LayerNorm over the feature axis.
        self.bn = nn.LayerNorm(n_features)
        # (batch, frames, features) -> (frames, batch, features) for the
        # sequence-first LSTM and attention layers.
        self.reshape =  Lambda(lambda x: x.permute((1, 0, 2)))
        self.lstm = nn.LSTM(n_features, hidden_size, 2, batch_first = False, bidirectional = True, dropout = dropout) # frames * batch * embed_dim
        self.mha =  torch.nn.MultiheadAttention(embed_dim, num_heads = num_heads_2, dropout=dropout, bias=True, kdim=embed_dim, vdim=embed_dim) # frames * batch * embed_dim
        self.fc1 = nn.Linear(embed_dim, hidden_size)
        self.average = Lambda(lambda x: x.mean(dim = 0)) # mean over frames -> batch * hidden_size
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(hidden_size, n_speakers)
        self.sigmoid = nn.Sigmoid()
    def forward(self, X):
        """Map ``(batch, frames, n_features)`` to ``(batch, n_speakers)`` scores in [0, 1]."""
        X = self.bn(X)
        X = self.reshape(X)
        X, _ = self.lstm(X)          # discard the final hidden/cell states
        X, _ = self.mha(X, X, X)     # self-attention; discard attention weights
        X = self.fc1(X)
        X = self.average(X)
        X = self.tanh(X)
        X = self.fc2(X)
        X = self.sigmoid(X)
        return X
    
    
overnet = Baseline().cuda(device)

# tune hidden layers smaller if overfit
optimizer = torch.optim.Adam(overnet.parameters(), 0.001)

# Resume from the periodic training checkpoint when one exists.
if os.path.exists('models/baseline.pth'):
    print('load model')
    # map_location puts the tensors on the selected GPU regardless of the
    # device they were saved from.
    checkpoint = torch.load('models/baseline.pth', map_location = 'cuda:%d' % device)
    overnet.load_state_dict(checkpoint['model_state_dict'])
    try:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except Exception as err:
        # Best effort: keep the freshly initialised optimizer, but report why
        # (and no longer swallow KeyboardInterrupt/SystemExit).
        print('cannot load optimizer', err)
    loss = checkpoint['loss']
    # Checkpoints written without a 'bestacc' entry start again from 0.0.
    bestacc = checkpoint.get('bestacc', 0.0)
else:
    print('initializing new model')
    bestacc = 0.0

if half:
    overnet.half()  # convert to half precision
    # NOTE(review): Baseline contains no BatchNorm2d layers (it uses LayerNorm),
    # so this loop currently matches nothing.
    for layer in overnet.modules():
        if isinstance(layer, nn.BatchNorm2d):
            layer.float()

overnet.train()
'bestacc:', bestacc

load model


('bestacc:', 0.0)

## TODO: also report a metric for identifying at least one of the two speakers correctly

In [6]:
def find_max2(tensor):
    """Return, for every row, the indices of its two largest values.

    Args:
        tensor: 2-D torch tensor (rows are samples); may live on any device
            and may require grad.

    Returns:
        Integer numpy array with one ``[largest, second_largest]`` index pair
        per row.
    """
    rows = tensor.detach().cpu().numpy()
    # argsort is ascending, so reverse it and keep the first two positions.
    top_two = [np.argsort(scores)[::-1][:2] for scores in rows]
    return np.array(top_two)

def compute_corrects(tensor1, tensor2):
    """Count rows whose top-2 indices agree between the two tensors.

    A row counts as correct when the two highest-scoring indices of
    ``tensor1`` and of ``tensor2`` are the same, regardless of their order.

    Returns:
        Number of matching rows as a plain ``int``.
    """
    top_a = find_max2(tensor1)
    top_b = find_max2(tensor2)
    # Counter equality compares the index pairs as unordered collections.
    return sum(
        1
        for row in range(top_a.shape[0])
        if Counter(top_a[row]) == Counter(top_b[row])
    )

batch_size = 128
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 16)
# Validation order does not affect the accuracy, so no shuffling is needed.
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=False, pin_memory = True, num_workers = 16)
criterion = torch.nn.BCELoss()
# Make sure the checkpoint directory exists before the first torch.save.
os.makedirs('models', exist_ok = True)

for epoch in range(64):
    running_loss = 0.0
    running_accuracy = 0.0
    for batch_idx, (spec, target) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        spec, target = spec.float(), target.float()
        if half:
            spec, target = spec.half(),target.half()
        spec = spec.cuda(device)
        target = target.cuda(device)

        out = overnet(spec)
        loss = criterion(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(overnet.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        # Normalise by the actual batch size; the last batch may be smaller.
        running_accuracy += compute_corrects(out, target)/target.size(0)
        if batch_idx % 200 == 199:    # report and checkpoint every 200 mini-batches
            print('[%d, %5d] loss: %.3f accuracy: %.3f' % 
                  (epoch + 1, batch_idx + 1, running_loss / 200, running_accuracy / 200))
            running_loss = 0.0
            running_accuracy = 0.0
            # Store bestacc as well so that resuming from this checkpoint does
            # not reset it to 0 and overwrite a better 'best' model.
            torch.save({
            'model_state_dict': overnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
            'bestacc': bestacc
            }, 'models/baseline.pth')

    # Validation: evaluation mode (dropout off) and no autograd bookkeeping.
    overnet.eval()
    corrects = 0
    with torch.no_grad():
        for spec, target in tqdm(valloader):
            spec, target = spec.float(), target.float()
            if half:
                spec, target = spec.half(), target.half()
            spec = spec.cuda(device)
            target = target.cuda(device)
            out = overnet(spec)
            corrects += compute_corrects(out, target)
    val_acc = corrects/len(valset)
    print('val acc:', val_acc)
    if val_acc > bestacc:
        # New best validation accuracy: keep a separate copy of this model.
        bestacc = val_acc
        torch.save({
        'model_state_dict': overnet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),
        'bestacc': bestacc
        }, 'models/best-baseline.pth')
    overnet.train()

HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[1,   200] loss: 0.046 accuracy: 0.855
[1,   400] loss: 0.047 accuracy: 0.854
[1,   600] loss: 0.047 accuracy: 0.853



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7870235934664247


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[2,   200] loss: 0.047 accuracy: 0.855
[2,   400] loss: 0.047 accuracy: 0.854
[2,   600] loss: 0.047 accuracy: 0.852



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7793103448275862


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[3,   200] loss: 0.046 accuracy: 0.855
[3,   400] loss: 0.047 accuracy: 0.855
[3,   600] loss: 0.048 accuracy: 0.852



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7827586206896552


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[4,   200] loss: 0.046 accuracy: 0.859
[4,   400] loss: 0.047 accuracy: 0.851
[4,   600] loss: 0.047 accuracy: 0.853



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7816696914700545


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[5,   200] loss: 0.046 accuracy: 0.857
[5,   400] loss: 0.047 accuracy: 0.854
[5,   600] loss: 0.048 accuracy: 0.850



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7768602540834846


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[6,   200] loss: 0.047 accuracy: 0.854
[6,   400] loss: 0.047 accuracy: 0.852
[6,   600] loss: 0.047 accuracy: 0.854



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7787658802177858


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[7,   200] loss: 0.046 accuracy: 0.857
[7,   400] loss: 0.047 accuracy: 0.855
[7,   600] loss: 0.048 accuracy: 0.849



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7809437386569873


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[8,   200] loss: 0.046 accuracy: 0.856
[8,   400] loss: 0.047 accuracy: 0.852
[8,   600] loss: 0.048 accuracy: 0.850



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.782940108892922


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[9,   200] loss: 0.047 accuracy: 0.854
[9,   400] loss: 0.046 accuracy: 0.856
[9,   600] loss: 0.048 accuracy: 0.854



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.780852994555354


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[10,   200] loss: 0.046 accuracy: 0.856
[10,   400] loss: 0.046 accuracy: 0.856
[10,   600] loss: 0.047 accuracy: 0.853



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7789473684210526


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[11,   200] loss: 0.046 accuracy: 0.860
[11,   400] loss: 0.047 accuracy: 0.852
[11,   600] loss: 0.047 accuracy: 0.854



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7813974591651542


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[12,   200] loss: 0.046 accuracy: 0.859
[12,   400] loss: 0.047 accuracy: 0.852
[12,   600] loss: 0.047 accuracy: 0.855



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.782486388384755


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[13,   200] loss: 0.046 accuracy: 0.855
[13,   400] loss: 0.046 accuracy: 0.855
[13,   600] loss: 0.047 accuracy: 0.855



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7841197822141561


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[14,   200] loss: 0.046 accuracy: 0.859
[14,   400] loss: 0.047 accuracy: 0.855
[14,   600] loss: 0.047 accuracy: 0.852



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7800362976406534


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[15,   200] loss: 0.046 accuracy: 0.857
[15,   400] loss: 0.047 accuracy: 0.852
[15,   600] loss: 0.047 accuracy: 0.852



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7784936479128857


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[16,   200] loss: 0.045 accuracy: 0.860
[16,   400] loss: 0.047 accuracy: 0.853
[16,   600] loss: 0.047 accuracy: 0.852



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.77994555353902


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[17,   200] loss: 0.046 accuracy: 0.859
[17,   400] loss: 0.046 accuracy: 0.859
[17,   600] loss: 0.047 accuracy: 0.854



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.781578947368421


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[18,   200] loss: 0.046 accuracy: 0.858
[18,   400] loss: 0.046 accuracy: 0.858
[18,   600] loss: 0.047 accuracy: 0.853



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7823956442831216


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[19,   200] loss: 0.045 accuracy: 0.863
[19,   400] loss: 0.046 accuracy: 0.856
[19,   600] loss: 0.047 accuracy: 0.853



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7828493647912885


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[20,   200] loss: 0.044 accuracy: 0.863
[20,   400] loss: 0.046 accuracy: 0.857
[20,   600] loss: 0.047 accuracy: 0.853



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7747731397459166


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[21,   200] loss: 0.044 accuracy: 0.865
[21,   400] loss: 0.046 accuracy: 0.855
[21,   600] loss: 0.046 accuracy: 0.856



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7818511796733212


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[22,   200] loss: 0.045 accuracy: 0.861
[22,   400] loss: 0.046 accuracy: 0.855
[22,   600] loss: 0.047 accuracy: 0.854



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7827586206896552


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[23,   200] loss: 0.045 accuracy: 0.860
[23,   400] loss: 0.045 accuracy: 0.858
[23,   600] loss: 0.046 accuracy: 0.858



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7801270417422868


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[24,   200] loss: 0.044 accuracy: 0.861
[24,   400] loss: 0.046 accuracy: 0.857
[24,   600] loss: 0.046 accuracy: 0.854



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7798548094373866


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[25,   200] loss: 0.045 accuracy: 0.863
[25,   400] loss: 0.046 accuracy: 0.857
[25,   600] loss: 0.046 accuracy: 0.857



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7868421052631579


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[26,   200] loss: 0.045 accuracy: 0.860
[26,   400] loss: 0.045 accuracy: 0.862
[26,   600] loss: 0.046 accuracy: 0.856



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7803992740471869


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[27,   200] loss: 0.045 accuracy: 0.859
[27,   400] loss: 0.046 accuracy: 0.857
[27,   600] loss: 0.046 accuracy: 0.857



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7821234119782214


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[28,   200] loss: 0.045 accuracy: 0.862
[28,   400] loss: 0.046 accuracy: 0.857
[28,   600] loss: 0.046 accuracy: 0.854



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7791288566243194


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[29,   200] loss: 0.044 accuracy: 0.862
[29,   400] loss: 0.046 accuracy: 0.855
[29,   600] loss: 0.045 accuracy: 0.859



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.778584392014519


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[30,   200] loss: 0.045 accuracy: 0.860
[30,   400] loss: 0.045 accuracy: 0.857
[30,   600] loss: 0.046 accuracy: 0.860



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7766787658802178


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[31,   200] loss: 0.044 accuracy: 0.863
[31,   400] loss: 0.045 accuracy: 0.858
[31,   600] loss: 0.046 accuracy: 0.856



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.77513611615245


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[32,   200] loss: 0.045 accuracy: 0.861
[32,   400] loss: 0.045 accuracy: 0.858
[32,   600] loss: 0.045 accuracy: 0.860



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7791288566243194


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[33,   200] loss: 0.045 accuracy: 0.860
[33,   400] loss: 0.045 accuracy: 0.859
[33,   600] loss: 0.046 accuracy: 0.855



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7851179673321234


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[34,   200] loss: 0.044 accuracy: 0.866
[34,   400] loss: 0.045 accuracy: 0.857
[34,   600] loss: 0.046 accuracy: 0.856



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7835753176043557


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[35,   200] loss: 0.044 accuracy: 0.865
[35,   400] loss: 0.045 accuracy: 0.861
[35,   600] loss: 0.045 accuracy: 0.859



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7802177858439201


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[36,   200] loss: 0.044 accuracy: 0.863
[36,   400] loss: 0.044 accuracy: 0.862
[36,   600] loss: 0.046 accuracy: 0.857



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7843920145190563


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[37,   200] loss: 0.044 accuracy: 0.866
[37,   400] loss: 0.045 accuracy: 0.860
[37,   600] loss: 0.045 accuracy: 0.858



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7816696914700545


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[38,   200] loss: 0.044 accuracy: 0.863
[38,   400] loss: 0.045 accuracy: 0.860
[38,   600] loss: 0.045 accuracy: 0.859



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7796733212341198


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[39,   200] loss: 0.043 accuracy: 0.868
[39,   400] loss: 0.045 accuracy: 0.859
[39,   600] loss: 0.045 accuracy: 0.858



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7861161524500907


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[40,   200] loss: 0.044 accuracy: 0.868
[40,   400] loss: 0.045 accuracy: 0.860
[40,   600] loss: 0.045 accuracy: 0.858



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7840290381125227


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[41,   200] loss: 0.044 accuracy: 0.864
[41,   400] loss: 0.044 accuracy: 0.861
[41,   600] loss: 0.044 accuracy: 0.861



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7823956442831216


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[42,   200] loss: 0.044 accuracy: 0.864
[42,   400] loss: 0.045 accuracy: 0.861
[42,   600] loss: 0.045 accuracy: 0.860



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7788566243194193


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[43,   200] loss: 0.044 accuracy: 0.863
[43,   400] loss: 0.044 accuracy: 0.861
[43,   600] loss: 0.045 accuracy: 0.859



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7805807622504537


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[44,   200] loss: 0.043 accuracy: 0.866
[44,   400] loss: 0.044 accuracy: 0.861
[44,   600] loss: 0.046 accuracy: 0.855



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7817604355716878


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[45,   200] loss: 0.043 accuracy: 0.869
[45,   400] loss: 0.045 accuracy: 0.859
[45,   600] loss: 0.045 accuracy: 0.861



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.782940108892922


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[46,   200] loss: 0.044 accuracy: 0.865
[46,   400] loss: 0.044 accuracy: 0.863
[46,   600] loss: 0.045 accuracy: 0.858



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7816696914700545


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[47,   200] loss: 0.044 accuracy: 0.865
[47,   400] loss: 0.045 accuracy: 0.861
[47,   600] loss: 0.044 accuracy: 0.861



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7804900181488204


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[48,   200] loss: 0.044 accuracy: 0.864
[48,   400] loss: 0.044 accuracy: 0.863
[48,   600] loss: 0.044 accuracy: 0.863



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7775862068965518


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[49,   200] loss: 0.043 accuracy: 0.866
[49,   400] loss: 0.044 accuracy: 0.861
[49,   600] loss: 0.045 accuracy: 0.860



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7848457350272232


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[50,   200] loss: 0.044 accuracy: 0.865
[50,   400] loss: 0.045 accuracy: 0.859
[50,   600] loss: 0.044 accuracy: 0.864



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7822141560798548


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[51,   200] loss: 0.043 accuracy: 0.864
[51,   400] loss: 0.044 accuracy: 0.866
[51,   600] loss: 0.045 accuracy: 0.858



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7852087114337568


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[52,   200] loss: 0.045 accuracy: 0.861
[52,   400] loss: 0.044 accuracy: 0.862
[52,   600] loss: 0.044 accuracy: 0.862



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7826678765880217


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[53,   200] loss: 0.043 accuracy: 0.866
[53,   400] loss: 0.045 accuracy: 0.861
[53,   600] loss: 0.045 accuracy: 0.860



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.780852994555354


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[54,   200] loss: 0.043 accuracy: 0.868
[54,   400] loss: 0.044 accuracy: 0.862
[54,   600] loss: 0.045 accuracy: 0.859



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7811252268602541


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[55,   200] loss: 0.044 accuracy: 0.864
[55,   400] loss: 0.044 accuracy: 0.864
[55,   600] loss: 0.044 accuracy: 0.863



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7781306715063521


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[56,   200] loss: 0.043 accuracy: 0.863
[56,   400] loss: 0.044 accuracy: 0.863
[56,   600] loss: 0.044 accuracy: 0.862



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7737749546279492


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[57,   200] loss: 0.043 accuracy: 0.865
[57,   400] loss: 0.044 accuracy: 0.861
[57,   600] loss: 0.044 accuracy: 0.862



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7838475499092559


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[58,   200] loss: 0.043 accuracy: 0.866
[58,   400] loss: 0.043 accuracy: 0.865
[58,   600] loss: 0.044 accuracy: 0.863



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7776769509981851


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[59,   200] loss: 0.042 accuracy: 0.869
[59,   400] loss: 0.044 accuracy: 0.861
[59,   600] loss: 0.044 accuracy: 0.865



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7857531760435572


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[60,   200] loss: 0.043 accuracy: 0.868
[60,   400] loss: 0.044 accuracy: 0.863
[60,   600] loss: 0.044 accuracy: 0.863



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7784029038112523


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[61,   200] loss: 0.043 accuracy: 0.868
[61,   400] loss: 0.043 accuracy: 0.867
[61,   600] loss: 0.045 accuracy: 0.860



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7805807622504537


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[62,   200] loss: 0.043 accuracy: 0.867
[62,   400] loss: 0.044 accuracy: 0.863
[62,   600] loss: 0.045 accuracy: 0.860



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7833030852994556


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[63,   200] loss: 0.043 accuracy: 0.867
[63,   400] loss: 0.044 accuracy: 0.864
[63,   600] loss: 0.044 accuracy: 0.865



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7798548094373866


HBox(children=(FloatProgress(value=0.0, max=689.0), HTML(value='')))

[64,   200] loss: 0.042 accuracy: 0.871
[64,   400] loss: 0.044 accuracy: 0.867
[64,   600] loss: 0.045 accuracy: 0.860



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7874773139745916


In [9]:
# Evaluate the checkpoint with the best validation accuracy on the test split.
# map_location puts the tensors on the selected GPU regardless of the device
# they were saved from.
checkpoint = torch.load('models/best-baseline.pth', map_location = 'cuda:%d' % device)
overnet.load_state_dict(checkpoint['model_state_dict'])

# Ordering does not affect the accuracy, so no shuffling is needed.
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, pin_memory = True, num_workers = 16)

overnet.eval()  # disable dropout once, before the loop
corrects = 0
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for spec, target in tqdm(testloader):
        spec, target = spec.float(), target.float()
        if half:
            spec, target = spec.half(), target.half()
        spec = spec.cuda(device)
        target = target.cuda(device)
        out = overnet(spec)
        corrects += compute_corrects(out, target)
print('test acc:', corrects/len(testset))
# Trailing semicolon keeps the notebook from echoing the module repr.
overnet.train();

HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


test acc: 0.7786751361161525
