In [2]:
import librosa
import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import time
import glob
from lxml.html import parse
from sphfile import SPHFile
import pydub
import audiosegment
import pandas as pd
from collections import Counter
from bs4 import BeautifulSoup
import sys
import os
from tqdm.notebook import tqdm
class Lambda(nn.Module):
    """Adapter that lets a plain callable be used as an ``nn.Module`` layer."""

    def __init__(self, func):
        """Store ``func``; it is applied to the input on every forward pass."""
        nn.Module.__init__(self)
        self.func = func

    def forward(self, x):
        """Return the result of calling the wrapped callable on ``x``."""
        wrapped = self.func
        return wrapped(x)
# Presumably the audio sample rate in Hz; OverlayDataSet.make_spectrogram passes
# the same literal value to librosa rather than reading this name.
sr = 16000
# Dropout probability handed to the LSTM in the Baseline model.
dropout = 0.3
# When True, the model and every batch are converted to half precision.
half = False

In [3]:
# Index of the CUDA device used for the model and for every batch in this notebook.
# GPU 3 is the preferred card; on a host with fewer GPUs fall back to device 0
# instead of failing with an "invalid device ordinal" error.
PREFERRED_DEVICE = 3
device = PREFERRED_DEVICE if torch.cuda.device_count() > PREFERRED_DEVICE else 0
torch.cuda.set_device(device)


In [4]:
class OverlayDataSet(torch.utils.data.Dataset):
    """Two-speaker overlay dataset described by a CSV file.

    Every CSV row must provide ``first_file`` / ``second_file`` (paths readable
    with ``np.load``) and ``first_speaker`` / ``second_speaker``.  An item is the
    feature matrix of the summed waveforms plus a multi-hot speaker target.

    NOTE(review): the speaker -> index mapping is derived from the CSV passed in,
    so the train/val/test splits only share indices if they contain exactly the
    same set of speakers -- confirm against the data.
    """

    def __init__(self, csv, compute_original = False):
        """
        Args:
            csv: path of the CSV file listing the overlays.
            compute_original: if True, ``__getitem__`` additionally returns the
                features of the two individual segments.
        """
        super().__init__()
        self.overlays = pd.read_csv(csv)
        # Sorted union of both speaker columns gives every speaker a stable index.
        self.speakers = sorted(
            set(self.overlays['first_speaker']) | set(self.overlays['second_speaker'])
        )
        self.spkr2idx = {spkr: i for i, spkr in enumerate(self.speakers)}
        self.compute_original = compute_original

    def __len__(self):
        """Number of overlay rows in the CSV."""
        return len(self.overlays)

    @staticmethod
    def _pad_to_length(segment, length):
        """Right-pad a 1-D ``segment`` with zeros up to ``length`` samples."""
        missing = length - len(segment)
        if missing > 0:
            segment = np.concatenate((segment, np.zeros(missing)))
        return segment

    def __getitem__(self, idx):
        """Return ``(mixture, target)`` or, with ``compute_original``,
        ``(first, second, mixture, target)``.

        ``target`` is a float vector of length ``len(self.speakers)`` holding 1.0
        at the indices of the two speakers of this row.
        """
        overlay = self.overlays.iloc[idx]
        # Presumably 16-bit PCM samples scaled to [-1, 1) -- confirm with the data.
        first_segment = np.load(overlay['first_file'])/(2**15)
        second_segment = np.load(overlay['second_file'])/(2**15)

        # Padding to compensate rounding errors: both segments must have the
        # same length before they can be summed.
        common_length = max(len(first_segment), len(second_segment))
        first_segment = self._pad_to_length(first_segment, common_length)
        second_segment = self._pad_to_length(second_segment, common_length)

        target = np.zeros(len(self.speakers))
        target[self.spkr2idx[overlay['first_speaker']]] = 1.0
        target[self.spkr2idx[overlay['second_speaker']]] = 1.0

        mixture = self.make_spectrogram(first_segment + second_segment)
        if self.compute_original:
            return (self.make_spectrogram(first_segment),
                    self.make_spectrogram(second_segment),
                    mixture, target)
        return mixture, target

    def make_spectrogram(self, segment):
        """Turn a waveform into a ``(frames, 36)`` feature matrix.

        Coefficients 1..13 of a 20-coefficient MFCC are kept, transposed to
        ``(frames, 13)`` and extended with two successive ``np.diff`` results
        (12 and 11 columns).

        NOTE(review): ``np.diff`` defaults to the last axis, so the differences
        are taken across cepstral coefficients rather than across time (which is
        what conventional delta features do).  Left unchanged because the
        resulting width of 36 is what the model's ``LayerNorm(36)`` expects.
        """
        # Trim 50 samples from each end (original note: "make size 200");
        # presumably this yields 200 frames for 32000-sample inputs -- confirm.
        segment = segment[50:-50]
        # `y` must be passed by keyword: librosa >= 0.10 rejects positional audio.
        mfcc = librosa.feature.mfcc(y=segment, sr=16000, n_mfcc=20, dct_type=2,
                                    n_fft = 1024, hop_length = 160)
        S = mfcc[1:14].T
        S1 = np.diff(S)
        S2 = np.diff(S1)
        return np.concatenate((S, S1, S2), axis = -1)
    
# Build the three splits.  With compute_original=False every item is a
# (mixture_features, target) pair.
trainset = OverlayDataSet('overlay-train.csv', False)
valset = OverlayDataSet('overlay-val.csv', False)
testset = OverlayDataSet('overlay-test.csv', False)
print(trainset.speakers)

# Peek at the first training item.  The number of returned values depends on
# compute_original, so unpack accordingly instead of assuming two values.
first_item = trainset[0]
if trainset.compute_original:
    spec1, spec2, spec3, target = first_item
    fig, axes = plt.subplots(1, 3, figsize = (20, 6))
    panels = ((spec1, 'first speaker'), (spec2, 'second speaker'), (spec3, 'mixture'))
    for ax, (spec, title) in zip(axes, panels):
        # Items are unbatched (frames, features) matrices; show features on the y axis.
        ax.imshow(spec.T, aspect = 'auto')
        ax.set_title(title)
else:
    spec3, target = first_item

['andrea_arsenault', 'brian_lamb', 'csp_waj_susan', 'david_brancaccio', 'eddie_mair', 'joie_chen', 'kathleen_kennedy', 'leon_harris', 'linda_wertheimer', 'linden_soles', 'lisa_mullins', 'lou_waters', 'lynn_vaughan', 'mark_mullen', 'natalie_allen', 'noah_adams', 'peter_jennings', 'robert_siegel', 'ted_koppel', 'thalia_assuras']


<Figure size 1440x432 with 0 Axes>

## TODO: try drastically increasing the number of channels in the residual attention stage to see whether the model overfits

In [5]:
# NOTE(review): not referenced anywhere in the visible notebook code; looks like
# a leftover from an attention-based variant of the model.
num_heads_2 = 4 # MHA heads


class Baseline(nn.Module):
    """Bidirectional-LSTM baseline that scores 20 speakers per input sequence.

    Pipeline: LayerNorm over the 36 features -> swap the first two axes ->
    2-layer bidirectional LSTM (32 hidden units per direction) -> Linear 64->32
    -> mean over axis 0 -> tanh -> Linear 32->20 -> sigmoid.

    Input is presumably ``(batch, frames, 36)``; the output then has shape
    ``(batch, 20)`` with independent per-speaker probabilities.

    Attribute names are kept as they are (``bn`` is actually a LayerNorm)
    because they are the keys of the saved ``state_dict``.
    """

    def __init__(self):
        super().__init__()
        # Normalises each frame over its 36 features.
        self.bn = nn.LayerNorm(36)
        # The LSTM below is sequence-first (batch_first=False), so swap axes 0 and 1.
        self.reshape = Lambda(lambda x: x.permute((1, 0, 2)))
        self.lstm = nn.LSTM(
            input_size = 36,
            hidden_size = 32,
            num_layers = 2,
            batch_first = False,
            bidirectional = True,
            dropout = dropout,
        )
        # 64 inputs = 2 directions * 32 hidden units.
        self.fc1 = nn.Linear(64, 32)
        # Average over axis 0, i.e. the sequence axis after the swap above.
        self.average = Lambda(lambda x: x.mean(dim = 0))
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(32, 20)
        self.sigmoid = nn.Sigmoid()

    def forward(self, X):
        """Map a batch of feature sequences to per-speaker probabilities."""
        sequence_first = self.reshape(self.bn(X))
        lstm_out, _ = self.lstm(sequence_first)
        pooled = self.average(self.fc1(lstm_out))
        logits = self.fc2(self.tanh(pooled))
        return self.sigmoid(logits)
    
    
overnet = Baseline().cuda(device)

# tune hidden layers smaller if overfit
optimizer = torch.optim.Adam(overnet.parameters(), 0.001)

# Resume from the rolling checkpoint written by the training loop, if present.
if os.path.exists('models/lstm.pth'):
    print('load model')
    # map_location restores the tensors onto the GPU selected for this run,
    # regardless of which device the checkpoint was saved from.
    checkpoint = torch.load('models/lstm.pth', map_location = torch.device('cuda', device))
    overnet.load_state_dict(checkpoint['model_state_dict'])
    try:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except Exception as err:
        # Best effort: keep the freshly initialised optimizer, but report why
        # (a bare `except:` would also swallow KeyboardInterrupt).
        print('cannot load optimizer:', err)
    loss = checkpoint['loss']
    # Checkpoints without a stored best accuracy start again from 0.0.
    bestacc = checkpoint.get('bestacc', 0.0)
else:
    print('initializing new model')
    bestacc = 0.0

if half:
    overnet.half()  # convert to half precision
    # NOTE(review): this model contains a LayerNorm and no BatchNorm2d, so the
    # loop below currently matches nothing.
    for layer in overnet.modules():
        if isinstance(layer, nn.BatchNorm2d):
            layer.float()

overnet.train()
'bestacc:', bestacc

initializing new model


('bestacc:', 0.0)

## TODO: also report a metric for getting at least one of the two speakers right

In [6]:
def find_max2(tensor):
    """Return, for every row of ``tensor``, the indices of its two largest entries.

    The tensor is moved to the CPU and detached first.  The result is a NumPy
    array with one row per input row, ordered from largest to second largest.
    """
    scores = tensor.cpu().detach().numpy()
    top_two = [np.argsort(row)[::-1][:2] for row in scores]
    return np.array(top_two)

def compute_corrects(tensor1, tensor2):
    """Count the rows whose top-2 index pairs agree between the two tensors.

    The two indices are compared as multisets, so their order within a row
    does not matter.  Returns the number of matching rows as an int.
    """
    top_a = find_max2(tensor1)
    top_b = find_max2(tensor2)
    return sum(
        Counter(top_a[row]) == Counter(top_b[row])
        for row in range(top_a.shape[0])
    )

batch_size = 32
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 16)
# Validation accuracy is a plain sum over all samples, so shuffling is unnecessary.
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=False, pin_memory = True, num_workers = 16)
criterion = torch.nn.BCELoss()

LOG_EVERY = 200  # mini-batches between progress reports / rolling checkpoints
os.makedirs('models', exist_ok = True)  # torch.save does not create directories

for epoch in range(64):
    running_loss = 0.0
    running_accuracy = 0.0
    for batch_idx, (spec, target) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        spec, target = spec.float(), target.float()
        if half:
            spec, target = spec.half(), target.half()
        spec = spec.cuda(device)
        target = target.cuda(device)

        out = overnet(spec)
        loss = criterion(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(overnet.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        # Divide by the actual batch size: the last batch of an epoch may be smaller.
        running_accuracy += compute_corrects(out, target)/target.size(0)
        if batch_idx % LOG_EVERY == LOG_EVERY - 1:    # report every LOG_EVERY mini-batches
            print('[%d, %5d] loss: %.3f accuracy: %.3f' %
                  (epoch + 1, batch_idx + 1, running_loss / LOG_EVERY, running_accuracy / LOG_EVERY))
            running_loss = 0.0
            running_accuracy = 0.0
            # Rolling checkpoint.  `bestacc` is stored as well so that resuming
            # does not reset it and let a worse model overwrite best-lstm.pth;
            # the loss is stored as a plain float, not as a live GPU tensor.
            torch.save({
            'model_state_dict': overnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
            'bestacc': bestacc
            }, 'models/lstm.pth')

    # Validation: switch to eval mode once and skip autograd bookkeeping.
    overnet.eval()
    corrects = 0
    with torch.no_grad():
        for batch_idx, (spec, target) in enumerate(tqdm(valloader)):
            spec, target = spec.float(), target.float()
            if half:
                spec, target = spec.half(), target.half()
            spec = spec.cuda(device)
            target = target.cuda(device)
            out = overnet(spec)
            corrects += compute_corrects(out, target)
    val_acc = corrects/len(valset)
    print('val acc:', val_acc)
    if val_acc > bestacc:
        bestacc = val_acc
        torch.save({
        'model_state_dict': overnet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),
        'bestacc': bestacc
        }, 'models/best-lstm.pth')
    overnet.train()

HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[1,   200] loss: 0.369 accuracy: 0.007
[1,   400] loss: 0.325 accuracy: 0.005
[1,   600] loss: 0.324 accuracy: 0.007
[1,   800] loss: 0.314 accuracy: 0.010
[1,  1000] loss: 0.307 accuracy: 0.013
[1,  1200] loss: 0.302 accuracy: 0.021
[1,  1400] loss: 0.287 accuracy: 0.039
[1,  1600] loss: 0.265 accuracy: 0.071
[1,  1800] loss: 0.248 accuracy: 0.102
[1,  2000] loss: 0.240 accuracy: 0.121
[1,  2200] loss: 0.233 accuracy: 0.134
[1,  2400] loss: 0.224 accuracy: 0.159
[1,  2600] loss: 0.217 accuracy: 0.181



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.2013611615245009


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[2,   200] loss: 0.208 accuracy: 0.211
[2,   400] loss: 0.202 accuracy: 0.221
[2,   600] loss: 0.197 accuracy: 0.243
[2,   800] loss: 0.194 accuracy: 0.246
[2,  1000] loss: 0.190 accuracy: 0.271
[2,  1200] loss: 0.187 accuracy: 0.275
[2,  1400] loss: 0.185 accuracy: 0.292
[2,  1600] loss: 0.180 accuracy: 0.300
[2,  1800] loss: 0.177 accuracy: 0.327
[2,  2000] loss: 0.174 accuracy: 0.333
[2,  2200] loss: 0.173 accuracy: 0.329
[2,  2400] loss: 0.169 accuracy: 0.356
[2,  2600] loss: 0.167 accuracy: 0.368



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.37250453720508164


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[3,   200] loss: 0.161 accuracy: 0.376
[3,   400] loss: 0.160 accuracy: 0.382
[3,   600] loss: 0.159 accuracy: 0.390
[3,   800] loss: 0.156 accuracy: 0.399
[3,  1000] loss: 0.154 accuracy: 0.404
[3,  1200] loss: 0.151 accuracy: 0.425
[3,  1400] loss: 0.149 accuracy: 0.432
[3,  1600] loss: 0.149 accuracy: 0.430
[3,  1800] loss: 0.148 accuracy: 0.430
[3,  2000] loss: 0.145 accuracy: 0.453
[3,  2200] loss: 0.144 accuracy: 0.453
[3,  2400] loss: 0.142 accuracy: 0.469
[3,  2600] loss: 0.141 accuracy: 0.470



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.46805807622504536


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[4,   200] loss: 0.138 accuracy: 0.477
[4,   400] loss: 0.136 accuracy: 0.487
[4,   600] loss: 0.134 accuracy: 0.496
[4,   800] loss: 0.135 accuracy: 0.494
[4,  1000] loss: 0.132 accuracy: 0.514
[4,  1200] loss: 0.132 accuracy: 0.506
[4,  1400] loss: 0.130 accuracy: 0.517
[4,  1600] loss: 0.129 accuracy: 0.521
[4,  1800] loss: 0.130 accuracy: 0.512
[4,  2000] loss: 0.129 accuracy: 0.524
[4,  2200] loss: 0.127 accuracy: 0.532
[4,  2400] loss: 0.126 accuracy: 0.537
[4,  2600] loss: 0.123 accuracy: 0.551



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.5484573502722323


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[5,   200] loss: 0.123 accuracy: 0.549
[5,   400] loss: 0.120 accuracy: 0.562
[5,   600] loss: 0.121 accuracy: 0.560
[5,   800] loss: 0.120 accuracy: 0.555
[5,  1000] loss: 0.119 accuracy: 0.568
[5,  1200] loss: 0.118 accuracy: 0.572
[5,  1400] loss: 0.116 accuracy: 0.583
[5,  1600] loss: 0.116 accuracy: 0.576
[5,  1800] loss: 0.117 accuracy: 0.575
[5,  2000] loss: 0.116 accuracy: 0.581
[5,  2200] loss: 0.114 accuracy: 0.592
[5,  2400] loss: 0.113 accuracy: 0.592
[5,  2600] loss: 0.113 accuracy: 0.594



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.5956442831215971


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[6,   200] loss: 0.110 accuracy: 0.604
[6,   400] loss: 0.111 accuracy: 0.600
[6,   600] loss: 0.108 accuracy: 0.615
[6,   800] loss: 0.108 accuracy: 0.615
[6,  1000] loss: 0.110 accuracy: 0.602
[6,  1200] loss: 0.107 accuracy: 0.622
[6,  1400] loss: 0.108 accuracy: 0.606
[6,  1600] loss: 0.107 accuracy: 0.618
[6,  1800] loss: 0.106 accuracy: 0.622
[6,  2000] loss: 0.106 accuracy: 0.618
[6,  2200] loss: 0.108 accuracy: 0.610
[6,  2400] loss: 0.108 accuracy: 0.611
[6,  2600] loss: 0.104 accuracy: 0.624



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.6263157894736842


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[7,   200] loss: 0.100 accuracy: 0.648
[7,   400] loss: 0.102 accuracy: 0.639
[7,   600] loss: 0.102 accuracy: 0.639
[7,   800] loss: 0.104 accuracy: 0.632
[7,  1000] loss: 0.102 accuracy: 0.630
[7,  1200] loss: 0.101 accuracy: 0.638
[7,  1400] loss: 0.101 accuracy: 0.636
[7,  1600] loss: 0.104 accuracy: 0.622
[7,  1800] loss: 0.100 accuracy: 0.643
[7,  2000] loss: 0.099 accuracy: 0.653
[7,  2200] loss: 0.101 accuracy: 0.633
[7,  2400] loss: 0.101 accuracy: 0.632
[7,  2600] loss: 0.099 accuracy: 0.656



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.6332123411978221


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[8,   200] loss: 0.098 accuracy: 0.656
[8,   400] loss: 0.097 accuracy: 0.662
[8,   600] loss: 0.095 accuracy: 0.652
[8,   800] loss: 0.096 accuracy: 0.654
[8,  1000] loss: 0.096 accuracy: 0.666
[8,  1200] loss: 0.095 accuracy: 0.661
[8,  1400] loss: 0.095 accuracy: 0.667
[8,  1600] loss: 0.098 accuracy: 0.657
[8,  1800] loss: 0.097 accuracy: 0.658
[8,  2000] loss: 0.095 accuracy: 0.668
[8,  2200] loss: 0.097 accuracy: 0.658
[8,  2400] loss: 0.097 accuracy: 0.660
[8,  2600] loss: 0.095 accuracy: 0.664



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.6561705989110708


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[9,   200] loss: 0.093 accuracy: 0.670
[9,   400] loss: 0.091 accuracy: 0.675
[9,   600] loss: 0.092 accuracy: 0.676
[9,   800] loss: 0.093 accuracy: 0.677
[9,  1000] loss: 0.092 accuracy: 0.672
[9,  1200] loss: 0.090 accuracy: 0.678
[9,  1400] loss: 0.093 accuracy: 0.673
[9,  1600] loss: 0.092 accuracy: 0.677
[9,  1800] loss: 0.090 accuracy: 0.685
[9,  2000] loss: 0.092 accuracy: 0.675
[9,  2200] loss: 0.093 accuracy: 0.674
[9,  2400] loss: 0.090 accuracy: 0.678
[9,  2600] loss: 0.092 accuracy: 0.674



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.6742286751361162


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[10,   200] loss: 0.089 accuracy: 0.683
[10,   400] loss: 0.089 accuracy: 0.683
[10,   600] loss: 0.092 accuracy: 0.685
[10,   800] loss: 0.090 accuracy: 0.680
[10,  1000] loss: 0.089 accuracy: 0.683
[10,  1200] loss: 0.088 accuracy: 0.693
[10,  1400] loss: 0.088 accuracy: 0.696
[10,  1600] loss: 0.088 accuracy: 0.688
[10,  1800] loss: 0.089 accuracy: 0.688
[10,  2000] loss: 0.089 accuracy: 0.688
[10,  2200] loss: 0.087 accuracy: 0.696
[10,  2400] loss: 0.089 accuracy: 0.688
[10,  2600] loss: 0.088 accuracy: 0.690



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.6764065335753177


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[11,   200] loss: 0.085 accuracy: 0.707
[11,   400] loss: 0.087 accuracy: 0.700
[11,   600] loss: 0.086 accuracy: 0.699
[11,   800] loss: 0.087 accuracy: 0.702
[11,  1000] loss: 0.086 accuracy: 0.701
[11,  1200] loss: 0.086 accuracy: 0.700
[11,  1400] loss: 0.086 accuracy: 0.704
[11,  1600] loss: 0.086 accuracy: 0.693
[11,  1800] loss: 0.086 accuracy: 0.697
[11,  2000] loss: 0.086 accuracy: 0.700
[11,  2200] loss: 0.086 accuracy: 0.701
[11,  2400] loss: 0.087 accuracy: 0.699
[11,  2600] loss: 0.086 accuracy: 0.701



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.678584392014519


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[12,   200] loss: 0.082 accuracy: 0.723
[12,   400] loss: 0.083 accuracy: 0.714
[12,   600] loss: 0.084 accuracy: 0.713
[12,   800] loss: 0.085 accuracy: 0.705
[12,  1000] loss: 0.084 accuracy: 0.711
[12,  1200] loss: 0.086 accuracy: 0.695
[12,  1400] loss: 0.085 accuracy: 0.704
[12,  1600] loss: 0.084 accuracy: 0.705
[12,  1800] loss: 0.083 accuracy: 0.702
[12,  2000] loss: 0.083 accuracy: 0.711
[12,  2200] loss: 0.084 accuracy: 0.710
[12,  2400] loss: 0.083 accuracy: 0.718
[12,  2600] loss: 0.083 accuracy: 0.715



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.6936479128856624


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[13,   200] loss: 0.082 accuracy: 0.722
[13,   400] loss: 0.082 accuracy: 0.717
[13,   600] loss: 0.082 accuracy: 0.716
[13,   800] loss: 0.083 accuracy: 0.715
[13,  1000] loss: 0.082 accuracy: 0.713
[13,  1200] loss: 0.080 accuracy: 0.719
[13,  1400] loss: 0.084 accuracy: 0.704
[13,  1600] loss: 0.080 accuracy: 0.728
[13,  1800] loss: 0.082 accuracy: 0.720
[13,  2000] loss: 0.080 accuracy: 0.723
[13,  2200] loss: 0.082 accuracy: 0.720
[13,  2400] loss: 0.082 accuracy: 0.713
[13,  2600] loss: 0.081 accuracy: 0.720



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.6906533575317604


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[14,   200] loss: 0.079 accuracy: 0.732
[14,   400] loss: 0.080 accuracy: 0.725
[14,   600] loss: 0.079 accuracy: 0.731
[14,   800] loss: 0.079 accuracy: 0.731
[14,  1000] loss: 0.080 accuracy: 0.717
[14,  1200] loss: 0.080 accuracy: 0.725
[14,  1400] loss: 0.081 accuracy: 0.719
[14,  1600] loss: 0.081 accuracy: 0.719
[14,  1800] loss: 0.080 accuracy: 0.726
[14,  2000] loss: 0.080 accuracy: 0.720
[14,  2200] loss: 0.080 accuracy: 0.718
[14,  2400] loss: 0.080 accuracy: 0.726
[14,  2600] loss: 0.080 accuracy: 0.722



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7030852994555354


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[15,   200] loss: 0.079 accuracy: 0.733
[15,   400] loss: 0.081 accuracy: 0.729
[15,   600] loss: 0.079 accuracy: 0.731
[15,   800] loss: 0.077 accuracy: 0.738
[15,  1000] loss: 0.079 accuracy: 0.733
[15,  1200] loss: 0.078 accuracy: 0.727
[15,  1400] loss: 0.079 accuracy: 0.726
[15,  1600] loss: 0.080 accuracy: 0.722
[15,  1800] loss: 0.078 accuracy: 0.734
[15,  2000] loss: 0.076 accuracy: 0.741
[15,  2200] loss: 0.077 accuracy: 0.739
[15,  2400] loss: 0.079 accuracy: 0.719
[15,  2600] loss: 0.079 accuracy: 0.730



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.709437386569873


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[16,   200] loss: 0.075 accuracy: 0.740
[16,   400] loss: 0.077 accuracy: 0.733
[16,   600] loss: 0.076 accuracy: 0.735
[16,   800] loss: 0.077 accuracy: 0.730
[16,  1000] loss: 0.077 accuracy: 0.737
[16,  1200] loss: 0.075 accuracy: 0.740
[16,  1400] loss: 0.077 accuracy: 0.739
[16,  1600] loss: 0.078 accuracy: 0.740
[16,  1800] loss: 0.077 accuracy: 0.738
[16,  2000] loss: 0.080 accuracy: 0.720
[16,  2200] loss: 0.077 accuracy: 0.733
[16,  2400] loss: 0.079 accuracy: 0.725
[16,  2600] loss: 0.077 accuracy: 0.734



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7081669691470055


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[17,   200] loss: 0.075 accuracy: 0.742
[17,   400] loss: 0.076 accuracy: 0.741
[17,   600] loss: 0.076 accuracy: 0.744
[17,   800] loss: 0.078 accuracy: 0.727
[17,  1000] loss: 0.074 accuracy: 0.748
[17,  1200] loss: 0.075 accuracy: 0.738
[17,  1400] loss: 0.075 accuracy: 0.743
[17,  1600] loss: 0.075 accuracy: 0.749
[17,  1800] loss: 0.075 accuracy: 0.745
[17,  2000] loss: 0.076 accuracy: 0.748
[17,  2200] loss: 0.076 accuracy: 0.741
[17,  2400] loss: 0.077 accuracy: 0.732
[17,  2600] loss: 0.077 accuracy: 0.738



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7137931034482758


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[18,   200] loss: 0.072 accuracy: 0.759
[18,   400] loss: 0.073 accuracy: 0.750
[18,   600] loss: 0.073 accuracy: 0.758
[18,   800] loss: 0.076 accuracy: 0.741
[18,  1000] loss: 0.073 accuracy: 0.754
[18,  1200] loss: 0.075 accuracy: 0.743
[18,  1400] loss: 0.075 accuracy: 0.744
[18,  1600] loss: 0.076 accuracy: 0.741
[18,  1800] loss: 0.074 accuracy: 0.750
[18,  2000] loss: 0.074 accuracy: 0.748
[18,  2200] loss: 0.074 accuracy: 0.749
[18,  2400] loss: 0.076 accuracy: 0.740
[18,  2600] loss: 0.078 accuracy: 0.728



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7264065335753176


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[19,   200] loss: 0.071 accuracy: 0.757
[19,   400] loss: 0.072 accuracy: 0.752
[19,   600] loss: 0.072 accuracy: 0.753
[19,   800] loss: 0.075 accuracy: 0.746
[19,  1000] loss: 0.073 accuracy: 0.744
[19,  1200] loss: 0.074 accuracy: 0.750
[19,  1400] loss: 0.074 accuracy: 0.753
[19,  1600] loss: 0.074 accuracy: 0.745
[19,  1800] loss: 0.074 accuracy: 0.748
[19,  2000] loss: 0.074 accuracy: 0.751
[19,  2200] loss: 0.075 accuracy: 0.743
[19,  2400] loss: 0.074 accuracy: 0.751
[19,  2600] loss: 0.074 accuracy: 0.750



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.720508166969147


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[20,   200] loss: 0.071 accuracy: 0.765
[20,   400] loss: 0.073 accuracy: 0.750
[20,   600] loss: 0.071 accuracy: 0.764
[20,   800] loss: 0.070 accuracy: 0.760
[20,  1000] loss: 0.073 accuracy: 0.753
[20,  1200] loss: 0.072 accuracy: 0.758
[20,  1400] loss: 0.074 accuracy: 0.748
[20,  1600] loss: 0.074 accuracy: 0.754
[20,  1800] loss: 0.074 accuracy: 0.750
[20,  2000] loss: 0.073 accuracy: 0.755
[20,  2200] loss: 0.073 accuracy: 0.742
[20,  2400] loss: 0.073 accuracy: 0.746
[20,  2600] loss: 0.074 accuracy: 0.746



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7274047186932849


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[21,   200] loss: 0.072 accuracy: 0.754
[21,   400] loss: 0.071 accuracy: 0.765
[21,   600] loss: 0.071 accuracy: 0.757
[21,   800] loss: 0.072 accuracy: 0.750
[21,  1000] loss: 0.071 accuracy: 0.761
[21,  1200] loss: 0.072 accuracy: 0.756
[21,  1400] loss: 0.072 accuracy: 0.757
[21,  1600] loss: 0.074 accuracy: 0.746
[21,  1800] loss: 0.071 accuracy: 0.753
[21,  2000] loss: 0.071 accuracy: 0.757
[21,  2200] loss: 0.072 accuracy: 0.755
[21,  2400] loss: 0.074 accuracy: 0.746
[21,  2600] loss: 0.073 accuracy: 0.748



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7243194192377496


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[22,   200] loss: 0.070 accuracy: 0.767
[22,   400] loss: 0.071 accuracy: 0.758
[22,   600] loss: 0.072 accuracy: 0.759
[22,   800] loss: 0.070 accuracy: 0.761
[22,  1000] loss: 0.070 accuracy: 0.765
[22,  1200] loss: 0.070 accuracy: 0.766
[22,  1400] loss: 0.073 accuracy: 0.748
[22,  1600] loss: 0.070 accuracy: 0.768
[22,  1800] loss: 0.070 accuracy: 0.757
[22,  2000] loss: 0.071 accuracy: 0.757
[22,  2200] loss: 0.070 accuracy: 0.765
[22,  2400] loss: 0.072 accuracy: 0.757
[22,  2600] loss: 0.072 accuracy: 0.759



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.726497277676951


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[23,   200] loss: 0.067 accuracy: 0.776
[23,   400] loss: 0.070 accuracy: 0.763
[23,   600] loss: 0.071 accuracy: 0.763
[23,   800] loss: 0.071 accuracy: 0.760
[23,  1000] loss: 0.070 accuracy: 0.758
[23,  1200] loss: 0.068 accuracy: 0.767
[23,  1400] loss: 0.069 accuracy: 0.771
[23,  1600] loss: 0.071 accuracy: 0.759
[23,  1800] loss: 0.069 accuracy: 0.769
[23,  2000] loss: 0.071 accuracy: 0.759
[23,  2200] loss: 0.071 accuracy: 0.765
[23,  2400] loss: 0.069 accuracy: 0.765
[23,  2600] loss: 0.072 accuracy: 0.756



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7288566243194192


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[24,   200] loss: 0.068 accuracy: 0.769
[24,   400] loss: 0.070 accuracy: 0.765
[24,   600] loss: 0.069 accuracy: 0.767
[24,   800] loss: 0.068 accuracy: 0.770
[24,  1000] loss: 0.069 accuracy: 0.767
[24,  1200] loss: 0.069 accuracy: 0.768
[24,  1400] loss: 0.069 accuracy: 0.767
[24,  1600] loss: 0.070 accuracy: 0.767
[24,  1800] loss: 0.070 accuracy: 0.763
[24,  2000] loss: 0.070 accuracy: 0.767
[24,  2200] loss: 0.069 accuracy: 0.768
[24,  2400] loss: 0.070 accuracy: 0.765
[24,  2600] loss: 0.070 accuracy: 0.761



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7274047186932849


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[25,   200] loss: 0.067 accuracy: 0.772
[25,   400] loss: 0.067 accuracy: 0.773
[25,   600] loss: 0.067 accuracy: 0.771
[25,   800] loss: 0.067 accuracy: 0.774
[25,  1000] loss: 0.068 accuracy: 0.770
[25,  1200] loss: 0.070 accuracy: 0.768
[25,  1400] loss: 0.068 accuracy: 0.776
[25,  1600] loss: 0.070 accuracy: 0.763
[25,  1800] loss: 0.070 accuracy: 0.762
[25,  2000] loss: 0.068 accuracy: 0.767
[25,  2200] loss: 0.069 accuracy: 0.767
[25,  2400] loss: 0.070 accuracy: 0.766
[25,  2600] loss: 0.070 accuracy: 0.767



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7242286751361161


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[26,   200] loss: 0.067 accuracy: 0.776
[26,   400] loss: 0.066 accuracy: 0.776
[26,   600] loss: 0.068 accuracy: 0.768
[26,   800] loss: 0.067 accuracy: 0.778
[26,  1000] loss: 0.067 accuracy: 0.774
[26,  1200] loss: 0.067 accuracy: 0.773
[26,  1400] loss: 0.068 accuracy: 0.775
[26,  1600] loss: 0.071 accuracy: 0.755
[26,  1800] loss: 0.068 accuracy: 0.765
[26,  2000] loss: 0.068 accuracy: 0.766
[26,  2200] loss: 0.070 accuracy: 0.764
[26,  2400] loss: 0.068 accuracy: 0.773
[26,  2600] loss: 0.069 accuracy: 0.764



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7274047186932849


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[27,   200] loss: 0.066 accuracy: 0.779
[27,   400] loss: 0.066 accuracy: 0.783
[27,   600] loss: 0.067 accuracy: 0.773
[27,   800] loss: 0.067 accuracy: 0.773
[27,  1000] loss: 0.068 accuracy: 0.779
[27,  1200] loss: 0.067 accuracy: 0.780
[27,  1400] loss: 0.067 accuracy: 0.772
[27,  1600] loss: 0.067 accuracy: 0.775
[27,  1800] loss: 0.068 accuracy: 0.771
[27,  2000] loss: 0.068 accuracy: 0.772
[27,  2200] loss: 0.068 accuracy: 0.765
[27,  2400] loss: 0.070 accuracy: 0.763
[27,  2600] loss: 0.068 accuracy: 0.773



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.732486388384755


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[28,   200] loss: 0.066 accuracy: 0.775
[28,   400] loss: 0.065 accuracy: 0.783
[28,   600] loss: 0.067 accuracy: 0.780
[28,   800] loss: 0.066 accuracy: 0.777
[28,  1000] loss: 0.066 accuracy: 0.780
[28,  1200] loss: 0.068 accuracy: 0.771
[28,  1400] loss: 0.067 accuracy: 0.771
[28,  1600] loss: 0.066 accuracy: 0.779
[28,  1800] loss: 0.066 accuracy: 0.779
[28,  2000] loss: 0.067 accuracy: 0.778
[28,  2200] loss: 0.070 accuracy: 0.761
[28,  2400] loss: 0.067 accuracy: 0.770
[28,  2600] loss: 0.066 accuracy: 0.776



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7295825771324864


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[29,   200] loss: 0.066 accuracy: 0.784
[29,   400] loss: 0.064 accuracy: 0.787
[29,   600] loss: 0.064 accuracy: 0.786
[29,   800] loss: 0.066 accuracy: 0.775
[29,  1000] loss: 0.067 accuracy: 0.770
[29,  1200] loss: 0.066 accuracy: 0.775
[29,  1400] loss: 0.067 accuracy: 0.775
[29,  1600] loss: 0.066 accuracy: 0.778
[29,  1800] loss: 0.065 accuracy: 0.787
[29,  2000] loss: 0.067 accuracy: 0.773
[29,  2200] loss: 0.068 accuracy: 0.768
[29,  2400] loss: 0.067 accuracy: 0.773
[29,  2600] loss: 0.069 accuracy: 0.764



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7305807622504538


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[30,   200] loss: 0.063 accuracy: 0.789
[30,   400] loss: 0.065 accuracy: 0.782
[30,   600] loss: 0.065 accuracy: 0.787
[30,   800] loss: 0.067 accuracy: 0.770
[30,  1000] loss: 0.067 accuracy: 0.777
[30,  1200] loss: 0.066 accuracy: 0.789
[30,  1400] loss: 0.066 accuracy: 0.775
[30,  1600] loss: 0.065 accuracy: 0.781
[30,  1800] loss: 0.065 accuracy: 0.781
[30,  2000] loss: 0.068 accuracy: 0.764
[30,  2200] loss: 0.065 accuracy: 0.781
[30,  2400] loss: 0.066 accuracy: 0.776
[30,  2600] loss: 0.067 accuracy: 0.771



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7356624319419238


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[31,   200] loss: 0.063 accuracy: 0.785
[31,   400] loss: 0.065 accuracy: 0.782
[31,   600] loss: 0.063 accuracy: 0.793
[31,   800] loss: 0.066 accuracy: 0.776
[31,  1000] loss: 0.065 accuracy: 0.782
[31,  1200] loss: 0.067 accuracy: 0.775
[31,  1400] loss: 0.065 accuracy: 0.786
[31,  1600] loss: 0.068 accuracy: 0.772
[31,  1800] loss: 0.066 accuracy: 0.784
[31,  2000] loss: 0.065 accuracy: 0.778
[31,  2200] loss: 0.065 accuracy: 0.776
[31,  2400] loss: 0.065 accuracy: 0.783
[31,  2600] loss: 0.065 accuracy: 0.786



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7340290381125227


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[32,   200] loss: 0.063 accuracy: 0.790
[32,   400] loss: 0.063 accuracy: 0.790
[32,   600] loss: 0.064 accuracy: 0.794
[32,   800] loss: 0.066 accuracy: 0.777
[32,  1000] loss: 0.065 accuracy: 0.782
[32,  1200] loss: 0.066 accuracy: 0.780
[32,  1400] loss: 0.064 accuracy: 0.788
[32,  1600] loss: 0.067 accuracy: 0.775
[32,  1800] loss: 0.064 accuracy: 0.782
[32,  2000] loss: 0.064 accuracy: 0.787
[32,  2200] loss: 0.065 accuracy: 0.786
[32,  2400] loss: 0.067 accuracy: 0.775
[32,  2600] loss: 0.065 accuracy: 0.775



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7398366606170599


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[33,   200] loss: 0.062 accuracy: 0.798
[33,   400] loss: 0.065 accuracy: 0.771
[33,   600] loss: 0.062 accuracy: 0.801
[33,   800] loss: 0.063 accuracy: 0.791
[33,  1000] loss: 0.063 accuracy: 0.791
[33,  1200] loss: 0.065 accuracy: 0.777
[33,  1400] loss: 0.064 accuracy: 0.789
[33,  1600] loss: 0.066 accuracy: 0.775
[33,  1800] loss: 0.065 accuracy: 0.788
[33,  2000] loss: 0.065 accuracy: 0.776
[33,  2200] loss: 0.064 accuracy: 0.788
[33,  2400] loss: 0.068 accuracy: 0.767
[33,  2600] loss: 0.065 accuracy: 0.779



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7443738656987295


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[34,   200] loss: 0.061 accuracy: 0.800
[34,   400] loss: 0.063 accuracy: 0.788
[34,   600] loss: 0.063 accuracy: 0.790
[34,   800] loss: 0.063 accuracy: 0.782
[34,  1000] loss: 0.063 accuracy: 0.784
[34,  1200] loss: 0.065 accuracy: 0.782
[34,  1400] loss: 0.062 accuracy: 0.797
[34,  1600] loss: 0.065 accuracy: 0.781
[34,  1800] loss: 0.065 accuracy: 0.781
[34,  2000] loss: 0.063 accuracy: 0.786
[34,  2200] loss: 0.065 accuracy: 0.780
[34,  2400] loss: 0.065 accuracy: 0.781
[34,  2600] loss: 0.065 accuracy: 0.775



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7393829401088929


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[35,   200] loss: 0.061 accuracy: 0.794
[35,   400] loss: 0.065 accuracy: 0.786
[35,   600] loss: 0.061 accuracy: 0.796
[35,   800] loss: 0.064 accuracy: 0.785
[35,  1000] loss: 0.062 accuracy: 0.794
[35,  1200] loss: 0.064 accuracy: 0.792
[35,  1400] loss: 0.063 accuracy: 0.785
[35,  1600] loss: 0.062 accuracy: 0.792
[35,  1800] loss: 0.064 accuracy: 0.785
[35,  2000] loss: 0.063 accuracy: 0.797
[35,  2200] loss: 0.064 accuracy: 0.785
[35,  2400] loss: 0.063 accuracy: 0.789
[35,  2600] loss: 0.066 accuracy: 0.776



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7406533575317604


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

KeyboardInterrupt: 

In [8]:
# Evaluate the checkpoint with the best validation accuracy on the test split.
# map_location restores the weights onto the GPU selected for this run.
checkpoint = torch.load('models/best-lstm.pth', map_location = torch.device('cuda', device))
overnet.load_state_dict(checkpoint['model_state_dict'])

# Accuracy is a plain sum over all samples, so shuffling is unnecessary.
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, pin_memory = True, num_workers = 16)

# Switch to eval mode once and skip autograd bookkeeping during inference.
overnet.eval()
corrects = 0
with torch.no_grad():
    for batch_idx, (spec, target) in enumerate(tqdm(testloader)):
        spec, target = spec.float(), target.float()
        if half:
            spec, target = spec.half(), target.half()
        spec = spec.cuda(device)
        target = target.cuda(device)
        out = overnet(spec)
        corrects += compute_corrects(out, target)
print('test acc:', corrects/len(testset))
overnet.train()
# Keeps the cell from echoing the module repr returned by overnet.train().
pass

HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


test acc: 0.7449183303085299
