In [1]:
import librosa
import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import time
import glob
from lxml.html import parse
from sphfile import SPHFile
import pydub
import audiosegment
import pandas as pd
from collections import Counter
from bs4 import BeautifulSoup
import sys
import os
from tqdm.notebook import tqdm
class Lambda(nn.Module):
    """Wrap an arbitrary callable so it can be used as a layer in an ``nn.Module``.

    The wrapper registers no parameters of its own; ``forward`` just applies
    the stored callable to its input and returns whatever it returns.
    """

    def __init__(self, func):
        super(Lambda, self).__init__()
        self.func = func

    def forward(self, x):
        fn = self.func
        return fn(x)
# Global configuration shared by the cells below:
#   sr      - sample rate (NOTE(review): make_spectrogram hard-codes 16000 instead of reading this)
#   dropout - dropout used inside the LSTM and multi-head attention layers
#   half    - when True, the model and every batch are cast to float16
#   root    - prefix prepended to the segment file paths stored in the CSVs
sr, dropout, half, root = 16000, 0.3, False, '../'

In [2]:
# GPU index used by every `.cuda(device)` call below. Defaults to 2, the value
# this notebook was run with. To pick another GPU without editing the cell, set
# OVERLAY_CUDA_DEVICE before starting the kernel.
device = int(os.environ.get('OVERLAY_CUDA_DEVICE', 2))
torch.cuda.set_device(device)


In [3]:
class OverlayDataSet(torch.utils.data.Dataset):
    """Three-speaker overlap dataset described by a CSV index.

    Each CSV row names three pre-extracted segments (columns ``first_file``,
    ``second_file``, ``third_file``; paths relative to the global ``root``)
    and their speakers (``first_speaker``, ``second_speaker``,
    ``third_speaker``). Segments are loaded with ``np.load`` and divided by
    ``2**15`` (presumably int16 PCM scaled to [-1, 1) -- TODO confirm). They
    are then zero-padded to a common length and summed to form the overlap.

    Args:
        csv: path of the CSV index.
        compute_original: if True, ``__getitem__`` returns the spectrograms of
            the three individual segments, of their mixture, and the target.
            Otherwise it returns only the mixture spectrogram and the target.
    """

    _FILE_COLUMNS = ('first_file', 'second_file', 'third_file')
    _SPEAKER_COLUMNS = ('first_speaker', 'second_speaker', 'third_speaker')

    def __init__(self, csv, compute_original = False):
        super().__init__()
        self.overlays = pd.read_csv(csv)
        # Build the label vocabulary from all three speaker columns. Using
        # only `first_speaker` makes __getitem__ raise KeyError for any speaker
        # that only ever appears in the second or third position.
        speakers = set()
        for column in self._SPEAKER_COLUMNS:
            speakers.update(self.overlays[column])
        self.speakers = sorted(speakers)
        self.spkr2idx = {spkr: i for i, spkr in enumerate(self.speakers)}
        self.compute_original = compute_original

    def __len__(self):
        return len(self.overlays)

    @staticmethod
    def _pad_to(segment, length):
        """Right-pad ``segment`` with zeros so it has ``length`` samples."""
        return np.pad(segment, (0, length - len(segment)))

    def __getitem__(self, idx):
        overlay = self.overlays.iloc[idx]
        segments = [np.load(root + overlay[column]) / (2**15) for column in self._FILE_COLUMNS]
        # Segment lengths can differ by a few samples (rounding when they were
        # cut), so pad them to a common length before mixing.
        max_len = max(len(segment) for segment in segments)
        first_segment, second_segment, third_segment = (self._pad_to(s, max_len) for s in segments)

        # Multi-hot target over the speaker vocabulary. A speaker listed twice
        # in one row still sets a single 1.
        target = np.zeros(len(self.speakers))
        for column in self._SPEAKER_COLUMNS:
            target[self.spkr2idx[overlay[column]]] = 1.0

        mixture = first_segment + second_segment + third_segment
        if self.compute_original:
            # The mixture includes all three speakers, matching the
            # compute_original=False branch. It previously omitted the third.
            return self.make_spectrogram(first_segment), self.make_spectrogram(second_segment),\
                self.make_spectrogram(third_segment), self.make_spectrogram(mixture), target
        return self.make_spectrogram(mixture), target

    def make_spectrogram(self, segment):
        """Return MFCC-based features for ``segment`` as a ``(frames, 36)`` array.

        Steps:
        - Drop 50 samples from each end of the segment.
        - Compute 20 MFCCs (sr=16000, n_fft=1024, hop_length=160).
        - Keep coefficients 1..13, discarding coefficient 0.
        - Transpose to ``(frames, 13)``.
        - Apply ``np.diff`` twice along the LAST axis, i.e. across neighbouring
          coefficients rather than across time. This adds 12 and 11 columns,
          giving 13 + 12 + 11 = 36 features.

        NOTE(review): conventional delta features are taken across time (e.g.
        ``librosa.feature.delta``). Changing that would change the feature
        width the model was trained with.
        """
        segment = segment[50:-50]  # trim so the result has 200 frames for this data
        # Pass the signal by keyword: `y` is keyword-only in librosa >= 0.10.
        S = librosa.feature.mfcc(y=segment, sr=16000, n_mfcc=20, dct_type=2, n_fft=1024, hop_length=160)[1:14].T
        S1 = np.diff(S)
        S2 = np.diff(S1)
        return np.concatenate((S, S1, S2), axis=-1)

trainset = OverlayDataSet('overlay-train.csv', False)
valset = OverlayDataSet('overlay-val.csv', False)
testset = OverlayDataSet('overlay-test.csv', False)

# Each dataset derives its speaker vocabulary from its own CSV. A split missing
# a speaker would therefore get a different speaker->index mapping, and a
# shorter target, than the training split, and validation would be silently
# misaligned. Sharing the training vocabulary makes target index i mean the
# same speaker in every split. A speaker unseen in training now fails loudly
# with a KeyError.
for _split in (valset, testset):
    _split.speakers = trainset.speakers
    _split.spkr2idx = trainset.spkr2idx
del _split

In [4]:
# Inspect one training example: every spectrogram the dataset returns (only
# the mixture unless compute_original is True), followed by its multi-hot target.
*specs, target = trainset[0]
spec3 = specs[-1]  # the mixture spectrogram is always the last one returned
# With compute_original the order is first, second, third, mixture.
titles = ('first speaker', 'second speaker', 'third speaker', 'mixture')[-len(specs):]
fig, axes = plt.subplots(1, len(specs), figsize=(20, 6), squeeze=False)
for ax, spec, title in zip(axes[0], specs, titles):
    # Plot the whole (frames, features) array transposed; the original
    # `spec[0].T` picked a single 1-D frame, which imshow cannot draw.
    ax.imshow(spec.T, aspect='auto')
    ax.set(title=title, xlabel='frame', ylabel='feature')
print(spec3.max(), spec3.min(), spec3.shape, target)

265.79355 -200.36958 (200, 36) [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0.]


<Figure size 1440x432 with 0 Axes>

## TODO: Try drastically increasing the number of channels in the residual attention stage to see whether the model overfits

In [5]:
num_heads_2 = 4  # number of heads in the multi-head attention layer


class Baseline(nn.Module):
    """BiLSTM + self-attention multi-label speaker classifier.

    Input is presumably ``(batch, frames, 36)``, the dataset's feature layout.
    Output is a per-speaker sigmoid probability of shape ``(batch, 20)``.
    Attribute names are the checkpoint's ``state_dict`` keys, so they must
    not be renamed.
    """

    def __init__(self):
        super().__init__()
        # LayerNorm over the 36 features (the attribute is called `bn` but is not BatchNorm).
        self.bn = nn.LayerNorm(36)
        # (batch, frames, feat) -> (frames, batch, feat). The LSTM
        # (batch_first=False) and MultiheadAttention both take sequence-first input.
        self.reshape = Lambda(lambda x: x.permute((1, 0, 2)))
        # Bidirectional with 32 hidden units per direction -> (frames, batch, 64).
        self.lstm = nn.LSTM(input_size=36, hidden_size=32, num_layers=2,
                            batch_first=False, bidirectional=True, dropout=dropout)
        # Self-attention over frames; keeps the (frames, batch, 64) shape.
        self.mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=num_heads_2, dropout=dropout,
                                               bias=True, kdim=64, vdim=64)
        self.fc1 = nn.Linear(in_features=64, out_features=32)
        # Mean over the frame axis -> (batch, 32).
        self.average = Lambda(lambda x: x.mean(dim=0))
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(in_features=32, out_features=20)
        self.sigmoid = nn.Sigmoid()

    def forward(self, X):
        frames = self.reshape(self.bn(X))
        hidden, _ = self.lstm(frames)
        attended, _ = self.mha(hidden, hidden, hidden)
        pooled = self.average(self.fc1(attended))
        logits = self.fc2(self.tanh(pooled))
        return self.sigmoid(logits)
    
    
overnet = Baseline().cuda(device)

# Tune the hidden layers smaller if the model overfits.
optimizer = torch.optim.Adam(overnet.parameters(), 0.001)

# Map stored tensors onto the GPU selected above, regardless of which device
# the checkpoint was written from.
map_location = torch.device('cuda', device)

if os.path.exists('models/baseline.pth'):
    print('load model')
    checkpoint = torch.load('models/baseline.pth', map_location=map_location)
    overnet.load_state_dict(checkpoint['model_state_dict'])
    try:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except (KeyError, ValueError) as err:
        # Missing entry, or parameter groups that no longer match the model:
        # keep the freshly initialised optimizer but report why.
        print('cannot load optimizer', err)
    loss = checkpoint['loss']
    bestacc = checkpoint.get('bestacc')
    if bestacc is None:
        # The periodic checkpoint may not record the best validation accuracy.
        # Without it bestacc would reset to 0.0, and the next validation pass
        # would overwrite best-baseline.pth even with a worse model. Recover
        # the value from the best-model checkpoint instead.
        if os.path.exists('models/best-baseline.pth'):
            best_checkpoint = torch.load('models/best-baseline.pth', map_location=map_location)
            bestacc = best_checkpoint.get('bestacc', 0.0)
        else:
            bestacc = 0.0
else:
    print('initializing new model')
    bestacc = 0.0

if half:
    overnet.half()  # convert to half precision
    # Keep BatchNorm layers in float32. Baseline uses LayerNorm, so this loop
    # currently matches nothing.
    for layer in overnet.modules():
        if isinstance(layer, nn.BatchNorm2d):
            layer.float()

overnet.train()
'bestacc:', bestacc

load model


('bestacc:', 0.0)

## TODO: Also report a metric for correctly identifying individual speakers, not only all three at once

In [8]:
def find_max3(tensor, k=3):
    """Return the indices of the ``k`` largest entries in each row, largest first.

    Args:
        tensor: 2-D tensor of shape ``(batch, num_classes)``. It is detached and
            copied to the CPU before sorting.
        k: number of indices kept per row. Defaults to 3, one per overlapping
            speaker in this dataset.

    Returns:
        Integer ``np.ndarray`` of shape ``(batch, min(k, num_classes))``.
    """
    scores = tensor.detach().cpu().numpy()
    # argsort is ascending; reverse each row so the highest score comes first.
    return np.argsort(scores, axis=1)[:, ::-1][:, :k]


def compute_corrects(tensor1, tensor2, k=3):
    """Count rows whose top-``k`` class sets agree between two score tensors.

    A row counts as correct when the ``k`` highest-scoring classes of
    ``tensor1`` form the same set (order ignored) as those of ``tensor2``.
    When ``tensor2`` is a multi-hot target with exactly ``k`` ones, this is
    exact-match accuracy over all ``k`` speakers.

    NOTE(review): if a target row has fewer than ``k`` ones, the remaining
    indices are picked among tied zeros, so the result depends on argsort's
    tie-breaking.

    Returns:
        Number of matching rows as a Python ``int``.
    """
    # The indices in each row are distinct, so comparing sorted rows is the
    # same as comparing sets.
    top_1 = np.sort(find_max3(tensor1, k), axis=1)
    top_2 = np.sort(find_max3(tensor2, k), axis=1)
    return int((top_1 == top_2).all(axis=1).sum())

batch_size = 64
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 16)
# Evaluation order does not affect the metric, so the validation set is not shuffled.
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, pin_memory = True, num_workers = 16)
criterion = torch.nn.BCELoss()
log_every = 200  # mini-batches between progress prints and periodic checkpoints
os.makedirs('models', exist_ok=True)  # torch.save does not create missing directories

for epoch in range(64):
    running_loss = 0.0
    running_accuracy = 0.0
    for batch_idx, (spec, target) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        spec, target = spec.float(), target.float()
        if half:
            spec, target = spec.half(), target.half()
        spec = spec.cuda(device)
        target = target.cuda(device)

        out = overnet(spec)
        loss = criterion(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(overnet.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        # Divide by the actual batch length: the last batch of an epoch can be
        # smaller than batch_size.
        running_accuracy += compute_corrects(out, target) / len(target)
        if batch_idx % log_every == log_every - 1:
            print('[%d, %5d] loss: %.3f accuracy: %.3f' %
                  (epoch + 1, batch_idx + 1, running_loss / log_every, running_accuracy / log_every))
            running_loss = 0.0
            running_accuracy = 0.0
            # Save bestacc so resuming from this checkpoint does not reset it.
            # Save the loss as a plain float rather than a CUDA tensor.
            torch.save({
                'model_state_dict': overnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
                'bestacc': bestacc,
            }, 'models/baseline.pth')

    # Validation: eval mode (disables dropout) and no autograd graph. The
    # previous loop built an unused graph for every batch.
    overnet.eval()
    corrects = 0
    with torch.no_grad():
        for spec, target in tqdm(valloader):
            spec, target = spec.float(), target.float()
            if half:
                spec, target = spec.half(), target.half()
            spec = spec.cuda(device)
            target = target.cuda(device)
            corrects += compute_corrects(overnet(spec), target)
    val_acc = corrects / len(valset)
    print('val acc:', val_acc)
    if val_acc > bestacc:
        bestacc = val_acc
        torch.save({
            'model_state_dict': overnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
            'bestacc': bestacc,
        }, 'models/best-baseline.pth')
    overnet.train()

HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[1,   200] loss: 0.151 accuracy: 0.446
[1,   400] loss: 0.151 accuracy: 0.446
[1,   600] loss: 0.151 accuracy: 0.448
[1,   800] loss: 0.149 accuracy: 0.455
[1,  1000] loss: 0.150 accuracy: 0.449
[1,  1200] loss: 0.150 accuracy: 0.447
[1,  1400] loss: 0.151 accuracy: 0.450
[1,  1600] loss: 0.151 accuracy: 0.446
[1,  1800] loss: 0.151 accuracy: 0.443
[1,  2000] loss: 0.151 accuracy: 0.443
[1,  2200] loss: 0.152 accuracy: 0.441
[1,  2400] loss: 0.153 accuracy: 0.441
[1,  2600] loss: 0.152 accuracy: 0.444
[1,  2800] loss: 0.152 accuracy: 0.439
[1,  3000] loss: 0.152 accuracy: 0.441
[1,  3200] loss: 0.150 accuracy: 0.449
[1,  3400] loss: 0.152 accuracy: 0.440
[1,  3600] loss: 0.151 accuracy: 0.441
[1,  3800] loss: 0.152 accuracy: 0.439
[1,  4000] loss: 0.153 accuracy: 0.437
[1,  4200] loss: 0.152 accuracy: 0.439
[1,  4400] loss: 0.151 accuracy: 0.446
[1,  4600] loss: 0.150 accuracy: 0.443
[1,  4800] loss: 0.152 accuracy: 0.445
[1,  5000] loss: 0.151 accuracy: 0.447
[1,  5200] loss: 0.151 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.46012300867110306


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[2,   200] loss: 0.152 accuracy: 0.439
[2,   400] loss: 0.149 accuracy: 0.446
[2,   600] loss: 0.151 accuracy: 0.448
[2,   800] loss: 0.152 accuracy: 0.443
[2,  1000] loss: 0.150 accuracy: 0.452
[2,  1200] loss: 0.150 accuracy: 0.451
[2,  1400] loss: 0.150 accuracy: 0.449
[2,  1600] loss: 0.149 accuracy: 0.456
[2,  1800] loss: 0.150 accuracy: 0.450
[2,  2000] loss: 0.150 accuracy: 0.441
[2,  2200] loss: 0.152 accuracy: 0.438
[2,  2400] loss: 0.150 accuracy: 0.442
[2,  2600] loss: 0.151 accuracy: 0.445
[2,  2800] loss: 0.150 accuracy: 0.451
[2,  3000] loss: 0.150 accuracy: 0.446
[2,  3200] loss: 0.151 accuracy: 0.444
[2,  3400] loss: 0.150 accuracy: 0.449
[2,  3600] loss: 0.152 accuracy: 0.439
[2,  3800] loss: 0.149 accuracy: 0.449
[2,  4000] loss: 0.150 accuracy: 0.453
[2,  4200] loss: 0.150 accuracy: 0.447
[2,  4400] loss: 0.149 accuracy: 0.454
[2,  4600] loss: 0.149 accuracy: 0.446
[2,  4800] loss: 0.150 accuracy: 0.451
[2,  5000] loss: 0.151 accuracy: 0.445
[2,  5200] loss: 0.150 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4648417019560395


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[3,   200] loss: 0.148 accuracy: 0.459
[3,   400] loss: 0.150 accuracy: 0.450
[3,   600] loss: 0.149 accuracy: 0.450
[3,   800] loss: 0.149 accuracy: 0.456
[3,  1000] loss: 0.148 accuracy: 0.454
[3,  1200] loss: 0.150 accuracy: 0.446
[3,  1400] loss: 0.149 accuracy: 0.449
[3,  1600] loss: 0.150 accuracy: 0.450
[3,  1800] loss: 0.149 accuracy: 0.451
[3,  2000] loss: 0.148 accuracy: 0.457
[3,  2200] loss: 0.149 accuracy: 0.453
[3,  2400] loss: 0.148 accuracy: 0.458
[3,  2600] loss: 0.149 accuracy: 0.456
[3,  2800] loss: 0.148 accuracy: 0.456
[3,  3000] loss: 0.149 accuracy: 0.450
[3,  3200] loss: 0.150 accuracy: 0.450
[3,  3400] loss: 0.150 accuracy: 0.448
[3,  3600] loss: 0.150 accuracy: 0.442
[3,  3800] loss: 0.150 accuracy: 0.447
[3,  4000] loss: 0.150 accuracy: 0.445
[3,  4200] loss: 0.148 accuracy: 0.458
[3,  4400] loss: 0.148 accuracy: 0.452
[3,  4600] loss: 0.149 accuracy: 0.455
[3,  4800] loss: 0.149 accuracy: 0.448
[3,  5000] loss: 0.147 accuracy: 0.455
[3,  5200] loss: 0.148 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.47106775559588626


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[4,   200] loss: 0.148 accuracy: 0.453
[4,   400] loss: 0.149 accuracy: 0.454
[4,   600] loss: 0.149 accuracy: 0.448
[4,   800] loss: 0.148 accuracy: 0.453
[4,  1000] loss: 0.148 accuracy: 0.454
[4,  1200] loss: 0.149 accuracy: 0.454
[4,  1400] loss: 0.148 accuracy: 0.449
[4,  1600] loss: 0.149 accuracy: 0.449
[4,  1800] loss: 0.149 accuracy: 0.452
[4,  2000] loss: 0.147 accuracy: 0.459
[4,  2200] loss: 0.149 accuracy: 0.450
[4,  2400] loss: 0.146 accuracy: 0.462
[4,  2600] loss: 0.148 accuracy: 0.454
[4,  2800] loss: 0.149 accuracy: 0.449
[4,  3000] loss: 0.149 accuracy: 0.446
[4,  3200] loss: 0.148 accuracy: 0.451
[4,  3400] loss: 0.149 accuracy: 0.449
[4,  3600] loss: 0.149 accuracy: 0.452
[4,  3800] loss: 0.147 accuracy: 0.457
[4,  4000] loss: 0.150 accuracy: 0.442
[4,  4200] loss: 0.148 accuracy: 0.454
[4,  4400] loss: 0.148 accuracy: 0.455
[4,  4600] loss: 0.150 accuracy: 0.445
[4,  4800] loss: 0.148 accuracy: 0.451
[4,  5000] loss: 0.149 accuracy: 0.451
[4,  5200] loss: 0.148 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4706493244605767


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[5,   200] loss: 0.149 accuracy: 0.450
[5,   400] loss: 0.145 accuracy: 0.460
[5,   600] loss: 0.148 accuracy: 0.457
[5,   800] loss: 0.146 accuracy: 0.463
[5,  1000] loss: 0.146 accuracy: 0.459
[5,  1200] loss: 0.146 accuracy: 0.461
[5,  1400] loss: 0.148 accuracy: 0.458
[5,  1600] loss: 0.149 accuracy: 0.445
[5,  1800] loss: 0.148 accuracy: 0.451
[5,  2000] loss: 0.147 accuracy: 0.454
[5,  2200] loss: 0.147 accuracy: 0.465
[5,  2400] loss: 0.148 accuracy: 0.453
[5,  2600] loss: 0.148 accuracy: 0.456
[5,  2800] loss: 0.147 accuracy: 0.454
[5,  3000] loss: 0.147 accuracy: 0.464
[5,  3200] loss: 0.146 accuracy: 0.461
[5,  3400] loss: 0.147 accuracy: 0.453
[5,  3600] loss: 0.147 accuracy: 0.457
[5,  3800] loss: 0.149 accuracy: 0.441
[5,  4000] loss: 0.147 accuracy: 0.459
[5,  4200] loss: 0.146 accuracy: 0.460
[5,  4400] loss: 0.146 accuracy: 0.456
[5,  4600] loss: 0.147 accuracy: 0.461
[5,  4800] loss: 0.146 accuracy: 0.466
[5,  5000] loss: 0.147 accuracy: 0.458
[5,  5200] loss: 0.148 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.47650231901593065


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[6,   200] loss: 0.147 accuracy: 0.460
[6,   400] loss: 0.147 accuracy: 0.460
[6,   600] loss: 0.147 accuracy: 0.457
[6,   800] loss: 0.146 accuracy: 0.464
[6,  1000] loss: 0.145 accuracy: 0.466
[6,  1200] loss: 0.147 accuracy: 0.458
[6,  1400] loss: 0.145 accuracy: 0.457
[6,  1600] loss: 0.146 accuracy: 0.456
[6,  1800] loss: 0.146 accuracy: 0.463
[6,  2000] loss: 0.147 accuracy: 0.452
[6,  2200] loss: 0.146 accuracy: 0.457
[6,  2400] loss: 0.146 accuracy: 0.462
[6,  2600] loss: 0.147 accuracy: 0.460
[6,  2800] loss: 0.148 accuracy: 0.451
[6,  3000] loss: 0.146 accuracy: 0.463
[6,  3200] loss: 0.146 accuracy: 0.456
[6,  3400] loss: 0.147 accuracy: 0.458
[6,  3600] loss: 0.148 accuracy: 0.448
[6,  3800] loss: 0.147 accuracy: 0.463
[6,  4000] loss: 0.147 accuracy: 0.455
[6,  4200] loss: 0.147 accuracy: 0.459
[6,  4400] loss: 0.145 accuracy: 0.470
[6,  4600] loss: 0.146 accuracy: 0.463
[6,  4800] loss: 0.146 accuracy: 0.458
[6,  5000] loss: 0.146 accuracy: 0.463
[6,  5200] loss: 0.147 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.48540028231498283


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[7,   200] loss: 0.145 accuracy: 0.468
[7,   400] loss: 0.146 accuracy: 0.453
[7,   600] loss: 0.146 accuracy: 0.463
[7,   800] loss: 0.145 accuracy: 0.466
[7,  1000] loss: 0.144 accuracy: 0.469
[7,  1200] loss: 0.145 accuracy: 0.462
[7,  1400] loss: 0.146 accuracy: 0.461
[7,  1600] loss: 0.145 accuracy: 0.467
[7,  1800] loss: 0.146 accuracy: 0.457
[7,  2000] loss: 0.148 accuracy: 0.450
[7,  2200] loss: 0.146 accuracy: 0.463
[7,  2400] loss: 0.147 accuracy: 0.458
[7,  2600] loss: 0.147 accuracy: 0.460
[7,  2800] loss: 0.145 accuracy: 0.464
[7,  3000] loss: 0.147 accuracy: 0.454
[7,  3200] loss: 0.146 accuracy: 0.467
[7,  3400] loss: 0.146 accuracy: 0.460
[7,  3600] loss: 0.144 accuracy: 0.468
[7,  3800] loss: 0.146 accuracy: 0.464
[7,  4000] loss: 0.145 accuracy: 0.469
[7,  4200] loss: 0.144 accuracy: 0.467
[7,  4400] loss: 0.145 accuracy: 0.468
[7,  4600] loss: 0.146 accuracy: 0.459
[7,  4800] loss: 0.146 accuracy: 0.462
[7,  5000] loss: 0.147 accuracy: 0.458
[7,  5200] loss: 0.144 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4875579753982658


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[8,   200] loss: 0.144 accuracy: 0.469
[8,   400] loss: 0.144 accuracy: 0.474
[8,   600] loss: 0.145 accuracy: 0.467
[8,   800] loss: 0.145 accuracy: 0.462
[8,  1000] loss: 0.144 accuracy: 0.471
[8,  1200] loss: 0.145 accuracy: 0.466
[8,  1400] loss: 0.145 accuracy: 0.463
[8,  1600] loss: 0.144 accuracy: 0.466
[8,  1800] loss: 0.146 accuracy: 0.455
[8,  2000] loss: 0.145 accuracy: 0.466
[8,  2200] loss: 0.145 accuracy: 0.462
[8,  2400] loss: 0.143 accuracy: 0.472
[8,  2600] loss: 0.146 accuracy: 0.468
[8,  2800] loss: 0.145 accuracy: 0.471
[8,  3000] loss: 0.146 accuracy: 0.462
[8,  3200] loss: 0.146 accuracy: 0.459
[8,  3400] loss: 0.145 accuracy: 0.464
[8,  3600] loss: 0.145 accuracy: 0.456
[8,  3800] loss: 0.145 accuracy: 0.468
[8,  4000] loss: 0.143 accuracy: 0.475
[8,  4200] loss: 0.147 accuracy: 0.455
[8,  4400] loss: 0.145 accuracy: 0.464
[8,  4600] loss: 0.147 accuracy: 0.454
[8,  4800] loss: 0.145 accuracy: 0.460
[8,  5000] loss: 0.144 accuracy: 0.469
[8,  5200] loss: 0.146 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.48787053841500305


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[9,   200] loss: 0.145 accuracy: 0.463
[9,   400] loss: 0.141 accuracy: 0.482
[9,   600] loss: 0.145 accuracy: 0.462
[9,   800] loss: 0.143 accuracy: 0.471
[9,  1000] loss: 0.145 accuracy: 0.468
[9,  1200] loss: 0.144 accuracy: 0.468
[9,  1400] loss: 0.145 accuracy: 0.467
[9,  1600] loss: 0.144 accuracy: 0.466
[9,  1800] loss: 0.145 accuracy: 0.461
[9,  2000] loss: 0.145 accuracy: 0.467
[9,  2200] loss: 0.145 accuracy: 0.465
[9,  2400] loss: 0.144 accuracy: 0.466
[9,  2600] loss: 0.143 accuracy: 0.475
[9,  2800] loss: 0.145 accuracy: 0.467
[9,  3000] loss: 0.145 accuracy: 0.466
[9,  3200] loss: 0.143 accuracy: 0.470
[9,  3400] loss: 0.144 accuracy: 0.474
[9,  3600] loss: 0.144 accuracy: 0.466
[9,  3800] loss: 0.145 accuracy: 0.466
[9,  4000] loss: 0.147 accuracy: 0.460
[9,  4200] loss: 0.145 accuracy: 0.459
[9,  4400] loss: 0.145 accuracy: 0.462
[9,  4600] loss: 0.145 accuracy: 0.461
[9,  4800] loss: 0.146 accuracy: 0.465
[9,  5000] loss: 0.144 accuracy: 0.470
[9,  5200] loss: 0.144 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.489120790481952


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[10,   200] loss: 0.144 accuracy: 0.465
[10,   400] loss: 0.143 accuracy: 0.470
[10,   600] loss: 0.145 accuracy: 0.468
[10,   800] loss: 0.145 accuracy: 0.460
[10,  1000] loss: 0.145 accuracy: 0.467
[10,  1200] loss: 0.145 accuracy: 0.462
[10,  1400] loss: 0.145 accuracy: 0.465
[10,  1600] loss: 0.145 accuracy: 0.468
[10,  1800] loss: 0.145 accuracy: 0.467
[10,  2000] loss: 0.144 accuracy: 0.469
[10,  2200] loss: 0.143 accuracy: 0.471
[10,  2400] loss: 0.144 accuracy: 0.466
[10,  2600] loss: 0.144 accuracy: 0.466
[10,  2800] loss: 0.143 accuracy: 0.474
[10,  3000] loss: 0.143 accuracy: 0.467
[10,  3200] loss: 0.143 accuracy: 0.472
[10,  3400] loss: 0.143 accuracy: 0.477
[10,  3600] loss: 0.145 accuracy: 0.467
[10,  3800] loss: 0.145 accuracy: 0.463
[10,  4000] loss: 0.143 accuracy: 0.468
[10,  4200] loss: 0.143 accuracy: 0.473
[10,  4400] loss: 0.144 accuracy: 0.470
[10,  4600] loss: 0.144 accuracy: 0.464
[10,  4800] loss: 0.144 accuracy: 0.462
[10,  5000] loss: 0.145 accuracy: 0.462


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4743093365597903


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[11,   200] loss: 0.143 accuracy: 0.469
[11,   400] loss: 0.144 accuracy: 0.470
[11,   600] loss: 0.145 accuracy: 0.460
[11,   800] loss: 0.143 accuracy: 0.473
[11,  1000] loss: 0.143 accuracy: 0.468
[11,  1200] loss: 0.142 accuracy: 0.470
[11,  1400] loss: 0.143 accuracy: 0.471
[11,  1600] loss: 0.144 accuracy: 0.470
[11,  1800] loss: 0.143 accuracy: 0.469
[11,  2000] loss: 0.144 accuracy: 0.472
[11,  2200] loss: 0.143 accuracy: 0.470
[11,  2400] loss: 0.143 accuracy: 0.471
[11,  2600] loss: 0.144 accuracy: 0.465
[11,  2800] loss: 0.144 accuracy: 0.463
[11,  3000] loss: 0.144 accuracy: 0.469
[11,  3200] loss: 0.143 accuracy: 0.465
[11,  3400] loss: 0.143 accuracy: 0.471
[11,  3600] loss: 0.142 accuracy: 0.473
[11,  3800] loss: 0.142 accuracy: 0.471
[11,  4000] loss: 0.142 accuracy: 0.471
[11,  4200] loss: 0.143 accuracy: 0.477
[11,  4400] loss: 0.145 accuracy: 0.459
[11,  4600] loss: 0.144 accuracy: 0.472
[11,  4800] loss: 0.143 accuracy: 0.476
[11,  5000] loss: 0.144 accuracy: 0.470


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.47404214559386976


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[12,   200] loss: 0.143 accuracy: 0.471
[12,   400] loss: 0.142 accuracy: 0.479
[12,   600] loss: 0.141 accuracy: 0.476
[12,   800] loss: 0.142 accuracy: 0.471
[12,  1000] loss: 0.144 accuracy: 0.471
[12,  1200] loss: 0.144 accuracy: 0.464
[12,  1400] loss: 0.144 accuracy: 0.469
[12,  1600] loss: 0.142 accuracy: 0.476
[12,  1800] loss: 0.144 accuracy: 0.465
[12,  2000] loss: 0.143 accuracy: 0.477
[12,  2200] loss: 0.143 accuracy: 0.475
[12,  2400] loss: 0.143 accuracy: 0.476
[12,  2600] loss: 0.142 accuracy: 0.478
[12,  2800] loss: 0.144 accuracy: 0.473
[12,  3000] loss: 0.142 accuracy: 0.470
[12,  3200] loss: 0.143 accuracy: 0.474
[12,  3400] loss: 0.144 accuracy: 0.465
[12,  3600] loss: 0.143 accuracy: 0.468
[12,  3800] loss: 0.143 accuracy: 0.475
[12,  4000] loss: 0.145 accuracy: 0.464
[12,  4200] loss: 0.143 accuracy: 0.475
[12,  4400] loss: 0.144 accuracy: 0.468
[12,  4600] loss: 0.143 accuracy: 0.472
[12,  4800] loss: 0.143 accuracy: 0.474
[12,  5000] loss: 0.143 accuracy: 0.468


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.49958156886469046


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[13,   200] loss: 0.142 accuracy: 0.476
[13,   400] loss: 0.142 accuracy: 0.483
[13,   600] loss: 0.143 accuracy: 0.475
[13,   800] loss: 0.142 accuracy: 0.474
[13,  1000] loss: 0.144 accuracy: 0.462
[13,  1200] loss: 0.142 accuracy: 0.475
[13,  1400] loss: 0.140 accuracy: 0.483
[13,  1600] loss: 0.141 accuracy: 0.478
[13,  1800] loss: 0.142 accuracy: 0.474
[13,  2000] loss: 0.143 accuracy: 0.469
[13,  2200] loss: 0.143 accuracy: 0.471
[13,  2400] loss: 0.143 accuracy: 0.473
[13,  2600] loss: 0.144 accuracy: 0.467
[13,  2800] loss: 0.143 accuracy: 0.478
[13,  3000] loss: 0.142 accuracy: 0.476
[13,  3200] loss: 0.143 accuracy: 0.468
[13,  3400] loss: 0.143 accuracy: 0.467
[13,  3600] loss: 0.143 accuracy: 0.468
[13,  3800] loss: 0.143 accuracy: 0.473
[13,  4000] loss: 0.144 accuracy: 0.474
[13,  4200] loss: 0.142 accuracy: 0.475
[13,  4400] loss: 0.143 accuracy: 0.471
[13,  4600] loss: 0.141 accuracy: 0.476
[13,  4800] loss: 0.143 accuracy: 0.470
[13,  5000] loss: 0.145 accuracy: 0.469


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.48057067957249444


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[14,   200] loss: 0.141 accuracy: 0.479
[14,   400] loss: 0.141 accuracy: 0.479
[14,   600] loss: 0.142 accuracy: 0.475
[14,   800] loss: 0.142 accuracy: 0.471
[14,  1000] loss: 0.141 accuracy: 0.478
[14,  1200] loss: 0.141 accuracy: 0.480
[14,  1400] loss: 0.142 accuracy: 0.471
[14,  1600] loss: 0.142 accuracy: 0.475
[14,  1800] loss: 0.141 accuracy: 0.478
[14,  2000] loss: 0.143 accuracy: 0.472
[14,  2200] loss: 0.143 accuracy: 0.465
[14,  2400] loss: 0.142 accuracy: 0.471
[14,  2600] loss: 0.143 accuracy: 0.472
[14,  2800] loss: 0.141 accuracy: 0.479
[14,  3000] loss: 0.142 accuracy: 0.472
[14,  3200] loss: 0.142 accuracy: 0.480
[14,  3400] loss: 0.141 accuracy: 0.481
[14,  3600] loss: 0.141 accuracy: 0.479
[14,  3800] loss: 0.143 accuracy: 0.468
[14,  4000] loss: 0.142 accuracy: 0.473
[14,  4200] loss: 0.143 accuracy: 0.474
[14,  4400] loss: 0.142 accuracy: 0.477
[14,  4600] loss: 0.143 accuracy: 0.473
[14,  4800] loss: 0.142 accuracy: 0.475
[14,  5000] loss: 0.144 accuracy: 0.466


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.49820024198427104


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[15,   200] loss: 0.140 accuracy: 0.487
[15,   400] loss: 0.141 accuracy: 0.474
[15,   600] loss: 0.143 accuracy: 0.466
[15,   800] loss: 0.143 accuracy: 0.471
[15,  1000] loss: 0.142 accuracy: 0.473
[15,  1200] loss: 0.141 accuracy: 0.478
[15,  1400] loss: 0.141 accuracy: 0.479
[15,  1600] loss: 0.142 accuracy: 0.471
[15,  1800] loss: 0.143 accuracy: 0.468
[15,  2000] loss: 0.141 accuracy: 0.477
[15,  2200] loss: 0.141 accuracy: 0.482
[15,  2400] loss: 0.142 accuracy: 0.472
[15,  2600] loss: 0.141 accuracy: 0.478
[15,  2800] loss: 0.142 accuracy: 0.476
[15,  3000] loss: 0.141 accuracy: 0.481
[15,  3200] loss: 0.141 accuracy: 0.480
[15,  3400] loss: 0.143 accuracy: 0.471
[15,  3600] loss: 0.141 accuracy: 0.478
[15,  3800] loss: 0.143 accuracy: 0.469
[15,  4000] loss: 0.143 accuracy: 0.468
[15,  4200] loss: 0.142 accuracy: 0.479
[15,  4400] loss: 0.141 accuracy: 0.474
[15,  4600] loss: 0.142 accuracy: 0.476
[15,  4800] loss: 0.143 accuracy: 0.480
[15,  5000] loss: 0.142 accuracy: 0.476


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4959971768501714


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[16,   200] loss: 0.142 accuracy: 0.478
[16,   400] loss: 0.140 accuracy: 0.480
[16,   600] loss: 0.141 accuracy: 0.480
[16,   800] loss: 0.140 accuracy: 0.482
[16,  1000] loss: 0.140 accuracy: 0.485
[16,  1200] loss: 0.141 accuracy: 0.485
[16,  1400] loss: 0.141 accuracy: 0.480
[16,  1600] loss: 0.141 accuracy: 0.472
[16,  1800] loss: 0.140 accuracy: 0.477
[16,  2000] loss: 0.142 accuracy: 0.472
[16,  2200] loss: 0.141 accuracy: 0.476
[16,  2400] loss: 0.141 accuracy: 0.479
[16,  2600] loss: 0.139 accuracy: 0.489
[16,  2800] loss: 0.142 accuracy: 0.476
[16,  3000] loss: 0.142 accuracy: 0.476
[16,  3200] loss: 0.141 accuracy: 0.479
[16,  3400] loss: 0.141 accuracy: 0.480
[16,  3600] loss: 0.142 accuracy: 0.479
[16,  3800] loss: 0.142 accuracy: 0.474
[16,  4000] loss: 0.141 accuracy: 0.482
[16,  4200] loss: 0.141 accuracy: 0.476
[16,  4400] loss: 0.139 accuracy: 0.484
[16,  4600] loss: 0.142 accuracy: 0.480
[16,  4800] loss: 0.142 accuracy: 0.481
[16,  5000] loss: 0.141 accuracy: 0.478


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4986085904416213


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[17,   200] loss: 0.141 accuracy: 0.481
[17,   400] loss: 0.140 accuracy: 0.479
[17,   600] loss: 0.141 accuracy: 0.480
[17,   800] loss: 0.141 accuracy: 0.481
[17,  1000] loss: 0.141 accuracy: 0.480
[17,  1200] loss: 0.139 accuracy: 0.487
[17,  1400] loss: 0.142 accuracy: 0.474
[17,  1600] loss: 0.141 accuracy: 0.472
[17,  1800] loss: 0.141 accuracy: 0.478
[17,  2000] loss: 0.141 accuracy: 0.475
[17,  2200] loss: 0.141 accuracy: 0.479
[17,  2400] loss: 0.141 accuracy: 0.486
[17,  2600] loss: 0.141 accuracy: 0.474
[17,  2800] loss: 0.141 accuracy: 0.472
[17,  3000] loss: 0.141 accuracy: 0.481
[17,  3200] loss: 0.140 accuracy: 0.480
[17,  3400] loss: 0.141 accuracy: 0.478
[17,  3600] loss: 0.140 accuracy: 0.483
[17,  3800] loss: 0.140 accuracy: 0.484
[17,  4000] loss: 0.141 accuracy: 0.480
[17,  4200] loss: 0.142 accuracy: 0.480
[17,  4400] loss: 0.142 accuracy: 0.476
[17,  4600] loss: 0.141 accuracy: 0.481
[17,  4800] loss: 0.141 accuracy: 0.482
[17,  5000] loss: 0.140 accuracy: 0.488


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.49739362774752977


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[18,   200] loss: 0.141 accuracy: 0.479
[18,   400] loss: 0.140 accuracy: 0.487
[18,   600] loss: 0.140 accuracy: 0.486
[18,   800] loss: 0.141 accuracy: 0.478
[18,  1000] loss: 0.141 accuracy: 0.483
[18,  1200] loss: 0.139 accuracy: 0.485
[18,  1400] loss: 0.141 accuracy: 0.477
[18,  1600] loss: 0.142 accuracy: 0.475
[18,  1800] loss: 0.140 accuracy: 0.480
[18,  2000] loss: 0.140 accuracy: 0.481
[18,  2200] loss: 0.140 accuracy: 0.478
[18,  2400] loss: 0.140 accuracy: 0.477
[18,  2600] loss: 0.140 accuracy: 0.488
[18,  2800] loss: 0.142 accuracy: 0.472
[18,  3000] loss: 0.140 accuracy: 0.483
[18,  3200] loss: 0.140 accuracy: 0.477
[18,  3400] loss: 0.140 accuracy: 0.479
[18,  3600] loss: 0.139 accuracy: 0.486
[18,  3800] loss: 0.140 accuracy: 0.487
[18,  4000] loss: 0.140 accuracy: 0.478
[18,  4200] loss: 0.140 accuracy: 0.486
[18,  4400] loss: 0.139 accuracy: 0.485
[18,  4600] loss: 0.140 accuracy: 0.480
[18,  4800] loss: 0.139 accuracy: 0.486
[18,  5000] loss: 0.141 accuracy: 0.477


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.49239261947973384


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[19,   200] loss: 0.140 accuracy: 0.481
[19,   400] loss: 0.140 accuracy: 0.484
[19,   600] loss: 0.139 accuracy: 0.489
[19,   800] loss: 0.141 accuracy: 0.480
[19,  1000] loss: 0.142 accuracy: 0.478
[19,  1200] loss: 0.140 accuracy: 0.480
[19,  1400] loss: 0.140 accuracy: 0.484
[19,  1600] loss: 0.141 accuracy: 0.473
[19,  1800] loss: 0.142 accuracy: 0.474
[19,  2000] loss: 0.140 accuracy: 0.475
[19,  2200] loss: 0.141 accuracy: 0.483
[19,  2400] loss: 0.140 accuracy: 0.477
[19,  2600] loss: 0.140 accuracy: 0.477
[19,  2800] loss: 0.140 accuracy: 0.484
[19,  3000] loss: 0.140 accuracy: 0.484
[19,  3200] loss: 0.138 accuracy: 0.489
[19,  3400] loss: 0.140 accuracy: 0.484
[19,  3600] loss: 0.139 accuracy: 0.489
[19,  3800] loss: 0.138 accuracy: 0.488
[19,  4000] loss: 0.137 accuracy: 0.488
[19,  4200] loss: 0.141 accuracy: 0.479
[19,  4400] loss: 0.140 accuracy: 0.482
[19,  4600] loss: 0.141 accuracy: 0.479
[19,  4800] loss: 0.139 accuracy: 0.491
[19,  5000] loss: 0.140 accuracy: 0.488


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.5000907441016333


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[20,   200] loss: 0.140 accuracy: 0.488
[20,   400] loss: 0.139 accuracy: 0.484
[20,   600] loss: 0.141 accuracy: 0.479
[20,   800] loss: 0.141 accuracy: 0.483
[20,  1000] loss: 0.140 accuracy: 0.485
[20,  1200] loss: 0.140 accuracy: 0.479
[20,  1400] loss: 0.141 accuracy: 0.476
[20,  1600] loss: 0.139 accuracy: 0.483
[20,  1800] loss: 0.139 accuracy: 0.485
[20,  2000] loss: 0.138 accuracy: 0.491
[20,  2200] loss: 0.141 accuracy: 0.484
[20,  2400] loss: 0.141 accuracy: 0.480
[20,  2600] loss: 0.141 accuracy: 0.483
[20,  2800] loss: 0.139 accuracy: 0.484
[20,  3000] loss: 0.139 accuracy: 0.489
[20,  3200] loss: 0.140 accuracy: 0.483
[20,  3400] loss: 0.139 accuracy: 0.485
[20,  3600] loss: 0.141 accuracy: 0.480
[20,  3800] loss: 0.140 accuracy: 0.484
[20,  4000] loss: 0.139 accuracy: 0.479
[20,  4200] loss: 0.140 accuracy: 0.479
[20,  4400] loss: 0.140 accuracy: 0.485
[20,  4600] loss: 0.141 accuracy: 0.479
[20,  4800] loss: 0.139 accuracy: 0.489
[20,  5000] loss: 0.140 accuracy: 0.484


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.5068058076225045


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[21,   200] loss: 0.141 accuracy: 0.476
[21,   400] loss: 0.138 accuracy: 0.490
[21,   600] loss: 0.139 accuracy: 0.487
[21,   800] loss: 0.139 accuracy: 0.486
[21,  1000] loss: 0.139 accuracy: 0.485
[21,  1200] loss: 0.138 accuracy: 0.485
[21,  1400] loss: 0.139 accuracy: 0.481
[21,  1600] loss: 0.139 accuracy: 0.483
[21,  1800] loss: 0.140 accuracy: 0.481
[21,  2000] loss: 0.140 accuracy: 0.481
[21,  2200] loss: 0.138 accuracy: 0.488
[21,  2400] loss: 0.139 accuracy: 0.485
[21,  2600] loss: 0.140 accuracy: 0.480
[21,  2800] loss: 0.140 accuracy: 0.488
[21,  3000] loss: 0.141 accuracy: 0.474
[21,  3200] loss: 0.140 accuracy: 0.479
[21,  3400] loss: 0.141 accuracy: 0.481
[21,  3600] loss: 0.140 accuracy: 0.482
[21,  3800] loss: 0.140 accuracy: 0.480
[21,  4000] loss: 0.138 accuracy: 0.493
[21,  4200] loss: 0.139 accuracy: 0.481
[21,  4400] loss: 0.140 accuracy: 0.478
[21,  4600] loss: 0.140 accuracy: 0.483
[21,  4800] loss: 0.141 accuracy: 0.478
[21,  5000] loss: 0.138 accuracy: 0.489


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.5069267997580157


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[22,   200] loss: 0.139 accuracy: 0.484
[22,   400] loss: 0.141 accuracy: 0.483
[22,   600] loss: 0.140 accuracy: 0.479
[22,   800] loss: 0.137 accuracy: 0.484
[22,  1000] loss: 0.138 accuracy: 0.486
[22,  1200] loss: 0.139 accuracy: 0.487
[22,  1400] loss: 0.140 accuracy: 0.482
[22,  1600] loss: 0.140 accuracy: 0.485
[22,  1800] loss: 0.139 accuracy: 0.485
[22,  2000] loss: 0.139 accuracy: 0.483
[22,  2200] loss: 0.140 accuracy: 0.486
[22,  2400] loss: 0.140 accuracy: 0.485
[22,  2600] loss: 0.139 accuracy: 0.489
[22,  2800] loss: 0.139 accuracy: 0.489
[22,  3000] loss: 0.138 accuracy: 0.489
[22,  3200] loss: 0.140 accuracy: 0.479
[22,  3400] loss: 0.139 accuracy: 0.487
[22,  3600] loss: 0.139 accuracy: 0.484
[22,  3800] loss: 0.141 accuracy: 0.476
[22,  4000] loss: 0.139 accuracy: 0.487
[22,  4200] loss: 0.139 accuracy: 0.483
[22,  4400] loss: 0.136 accuracy: 0.496
[22,  4600] loss: 0.138 accuracy: 0.491
[22,  4800] loss: 0.139 accuracy: 0.487
[22,  5000] loss: 0.139 accuracy: 0.489


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.510274248840492


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[23,   200] loss: 0.138 accuracy: 0.491
[23,   400] loss: 0.139 accuracy: 0.482
[23,   600] loss: 0.138 accuracy: 0.483
[23,   800] loss: 0.139 accuracy: 0.484
[23,  1000] loss: 0.140 accuracy: 0.481
[23,  1200] loss: 0.139 accuracy: 0.494
[23,  1400] loss: 0.137 accuracy: 0.487
[23,  1600] loss: 0.138 accuracy: 0.488
[23,  1800] loss: 0.137 accuracy: 0.489
[23,  2000] loss: 0.139 accuracy: 0.484
[23,  2200] loss: 0.140 accuracy: 0.485
[23,  2400] loss: 0.140 accuracy: 0.486
[23,  2600] loss: 0.140 accuracy: 0.484
[23,  2800] loss: 0.139 accuracy: 0.483
[23,  3000] loss: 0.137 accuracy: 0.490
[23,  3200] loss: 0.140 accuracy: 0.493
[23,  3400] loss: 0.139 accuracy: 0.481
[23,  3600] loss: 0.139 accuracy: 0.482
[23,  3800] loss: 0.139 accuracy: 0.491
[23,  4000] loss: 0.141 accuracy: 0.481
[23,  4200] loss: 0.138 accuracy: 0.488
[23,  4400] loss: 0.138 accuracy: 0.492
[23,  4600] loss: 0.139 accuracy: 0.487
[23,  4800] loss: 0.138 accuracy: 0.496
[23,  5000] loss: 0.139 accuracy: 0.485


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.5114640048396855


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[24,   200] loss: 0.139 accuracy: 0.488
[24,   400] loss: 0.137 accuracy: 0.499
[24,   600] loss: 0.137 accuracy: 0.492
[24,   800] loss: 0.138 accuracy: 0.488
[24,  1000] loss: 0.139 accuracy: 0.488
[24,  1200] loss: 0.138 accuracy: 0.487
[24,  1400] loss: 0.140 accuracy: 0.479
[24,  1600] loss: 0.138 accuracy: 0.490
[24,  1800] loss: 0.138 accuracy: 0.491
[24,  2000] loss: 0.139 accuracy: 0.489
[24,  2200] loss: 0.139 accuracy: 0.490
[24,  2400] loss: 0.137 accuracy: 0.497
[24,  2600] loss: 0.138 accuracy: 0.492
[24,  2800] loss: 0.139 accuracy: 0.487
[24,  3000] loss: 0.138 accuracy: 0.491
[24,  3200] loss: 0.138 accuracy: 0.488
[24,  3400] loss: 0.138 accuracy: 0.487
[24,  3600] loss: 0.139 accuracy: 0.485
[24,  3800] loss: 0.139 accuracy: 0.486
[24,  4000] loss: 0.139 accuracy: 0.489
[24,  4200] loss: 0.139 accuracy: 0.492
[24,  4400] loss: 0.140 accuracy: 0.485
[24,  4600] loss: 0.139 accuracy: 0.485
[24,  4800] loss: 0.137 accuracy: 0.493
[24,  5000] loss: 0.137 accuracy: 0.488


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.5047791893526921


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[25,   200] loss: 0.139 accuracy: 0.491
[25,   400] loss: 0.138 accuracy: 0.487
[25,   600] loss: 0.140 accuracy: 0.485
[25,   800] loss: 0.137 accuracy: 0.490
[25,  1000] loss: 0.138 accuracy: 0.490
[25,  1200] loss: 0.138 accuracy: 0.491
[25,  1400] loss: 0.138 accuracy: 0.487
[25,  1600] loss: 0.139 accuracy: 0.483
[25,  1800] loss: 0.138 accuracy: 0.487
[25,  2000] loss: 0.140 accuracy: 0.479
[25,  2200] loss: 0.140 accuracy: 0.475
[25,  2400] loss: 0.139 accuracy: 0.485
[25,  2600] loss: 0.138 accuracy: 0.487
[25,  2800] loss: 0.138 accuracy: 0.486
[25,  3000] loss: 0.138 accuracy: 0.490
[25,  3200] loss: 0.137 accuracy: 0.492
[25,  3400] loss: 0.139 accuracy: 0.493
[25,  3600] loss: 0.138 accuracy: 0.489
[25,  3800] loss: 0.140 accuracy: 0.488
[25,  4000] loss: 0.140 accuracy: 0.481
[25,  4200] loss: 0.139 accuracy: 0.487
[25,  4400] loss: 0.139 accuracy: 0.484
[25,  4600] loss: 0.139 accuracy: 0.488
[25,  4800] loss: 0.139 accuracy: 0.481
[25,  5000] loss: 0.139 accuracy: 0.488


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.514120790481952


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[26,   200] loss: 0.139 accuracy: 0.482
[26,   400] loss: 0.138 accuracy: 0.485
[26,   600] loss: 0.139 accuracy: 0.487
[26,   800] loss: 0.138 accuracy: 0.487
[26,  1000] loss: 0.137 accuracy: 0.495
[26,  1200] loss: 0.139 accuracy: 0.483
[26,  1400] loss: 0.138 accuracy: 0.489
[26,  1600] loss: 0.136 accuracy: 0.498
[26,  1800] loss: 0.137 accuracy: 0.492
[26,  2000] loss: 0.137 accuracy: 0.493
[26,  2200] loss: 0.139 accuracy: 0.482
[26,  2400] loss: 0.138 accuracy: 0.492
[26,  2600] loss: 0.137 accuracy: 0.491
[26,  2800] loss: 0.138 accuracy: 0.486
[26,  3000] loss: 0.138 accuracy: 0.489
[26,  3200] loss: 0.139 accuracy: 0.484
[26,  3400] loss: 0.138 accuracy: 0.482
[26,  3600] loss: 0.136 accuracy: 0.494
[26,  3800] loss: 0.139 accuracy: 0.485
[26,  4000] loss: 0.138 accuracy: 0.490
[26,  4200] loss: 0.138 accuracy: 0.489
[26,  4400] loss: 0.137 accuracy: 0.492
[26,  4600] loss: 0.138 accuracy: 0.491
[26,  4800] loss: 0.139 accuracy: 0.487
[26,  5000] loss: 0.139 accuracy: 0.482


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[26,  8800] loss: 0.139 accuracy: 0.489
[26,  9000] loss: 0.138 accuracy: 0.489
[26,  9200] loss: 0.139 accuracy: 0.491
[26,  9400] loss: 0.139 accuracy: 0.485
[26,  9600] loss: 0.139 accuracy: 0.482
[26,  9800] loss: 0.137 accuracy: 0.495
[26, 10000] loss: 0.138 accuracy: 0.488
[26, 10200] loss: 0.137 accuracy: 0.490
[26, 10400] loss: 0.138 accuracy: 0.491
[26, 10600] loss: 0.140 accuracy: 0.484
[26, 10800] loss: 0.138 accuracy: 0.490
[26, 11000] loss: 0.139 accuracy: 0.489
[26, 11200] loss: 0.139 accuracy: 0.490
[26, 11400] loss: 0.139 accuracy: 0.484
[26, 11600] loss: 0.139 accuracy: 0.487
[26, 11800] loss: 0.138 accuracy: 0.491
[26, 12000] loss: 0.140 accuracy: 0.484
[26, 12200] loss: 0.139 accuracy: 0.489
[26, 12400] loss: 0.139 accuracy: 0.492
[26, 12600] loss: 0.138 accuracy: 0.490
[26, 12800] loss: 0.137 accuracy: 0.494
[26, 13000] loss: 0.137 accuracy: 0.492
[26, 13200] loss: 0.138 accuracy: 0.487
[26, 13400] loss: 0.139 accuracy: 0.490
[26, 13600] loss: 0.138 accuracy: 0.488


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.5034381931841097


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[27,   200] loss: 0.139 accuracy: 0.484
[27,   400] loss: 0.138 accuracy: 0.486
[27,   600] loss: 0.137 accuracy: 0.493
[27,   800] loss: 0.138 accuracy: 0.490
[27,  1000] loss: 0.137 accuracy: 0.495
[27,  1200] loss: 0.138 accuracy: 0.483
[27,  1400] loss: 0.139 accuracy: 0.483
[27,  1600] loss: 0.138 accuracy: 0.488
[27,  1800] loss: 0.137 accuracy: 0.495
[27,  2000] loss: 0.137 accuracy: 0.488
[27,  2200] loss: 0.137 accuracy: 0.495
[27,  2400] loss: 0.137 accuracy: 0.493
[27,  2600] loss: 0.138 accuracy: 0.489
[27,  2800] loss: 0.139 accuracy: 0.484
[27,  3000] loss: 0.138 accuracy: 0.488
[27,  3200] loss: 0.138 accuracy: 0.486
[27,  3400] loss: 0.139 accuracy: 0.487
[27,  3600] loss: 0.139 accuracy: 0.486
[27,  3800] loss: 0.138 accuracy: 0.484
[27,  4000] loss: 0.137 accuracy: 0.491
[27,  4200] loss: 0.137 accuracy: 0.490
[27,  4400] loss: 0.138 accuracy: 0.494
[27,  4600] loss: 0.138 accuracy: 0.492
[27,  4800] loss: 0.137 accuracy: 0.492
[27,  5000] loss: 0.138 accuracy: 0.490


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.5125176446864287


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[28,   200] loss: 0.136 accuracy: 0.498
[28,   400] loss: 0.137 accuracy: 0.490
[28,   600] loss: 0.138 accuracy: 0.491
[28,   800] loss: 0.137 accuracy: 0.489
[28,  1000] loss: 0.138 accuracy: 0.489
[28,  1200] loss: 0.136 accuracy: 0.500
[28,  1400] loss: 0.139 accuracy: 0.484
[28,  1600] loss: 0.137 accuracy: 0.496
[28,  1800] loss: 0.136 accuracy: 0.497
[28,  2000] loss: 0.138 accuracy: 0.487
[28,  2200] loss: 0.138 accuracy: 0.493
[28,  2400] loss: 0.137 accuracy: 0.493
[28,  2600] loss: 0.139 accuracy: 0.486
[28,  2800] loss: 0.137 accuracy: 0.492
[28,  3000] loss: 0.138 accuracy: 0.492
[28,  3200] loss: 0.137 accuracy: 0.489
[28,  3400] loss: 0.138 accuracy: 0.493
[28,  3600] loss: 0.137 accuracy: 0.490
[28,  3800] loss: 0.137 accuracy: 0.487
[28,  4000] loss: 0.138 accuracy: 0.488
[28,  4200] loss: 0.139 accuracy: 0.485
[28,  4400] loss: 0.139 accuracy: 0.490
[28,  4600] loss: 0.139 accuracy: 0.489
[28,  4800] loss: 0.138 accuracy: 0.489
[28,  5000] loss: 0.137 accuracy: 0.493


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.5137426900584795


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[29,   200] loss: 0.138 accuracy: 0.493
[29,   400] loss: 0.137 accuracy: 0.496
[29,   600] loss: 0.138 accuracy: 0.489
[29,   800] loss: 0.138 accuracy: 0.490
[29,  1000] loss: 0.136 accuracy: 0.496
[29,  1200] loss: 0.137 accuracy: 0.491
[29,  1400] loss: 0.138 accuracy: 0.490
[29,  1600] loss: 0.138 accuracy: 0.488
[29,  1800] loss: 0.137 accuracy: 0.495
[29,  2000] loss: 0.138 accuracy: 0.492
[29,  2200] loss: 0.138 accuracy: 0.494
[29,  2400] loss: 0.137 accuracy: 0.490
[29,  2600] loss: 0.137 accuracy: 0.502
[29,  2800] loss: 0.138 accuracy: 0.491
[29,  3000] loss: 0.138 accuracy: 0.486
[29,  3200] loss: 0.137 accuracy: 0.490
[29,  3400] loss: 0.137 accuracy: 0.497
[29,  3600] loss: 0.138 accuracy: 0.491
[29,  3800] loss: 0.137 accuracy: 0.493
[29,  4000] loss: 0.137 accuracy: 0.488
[29,  4200] loss: 0.138 accuracy: 0.490
[29,  4400] loss: 0.138 accuracy: 0.490
[29,  4600] loss: 0.136 accuracy: 0.499
[29,  4800] loss: 0.139 accuracy: 0.487
[29,  5000] loss: 0.137 accuracy: 0.494


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.5147358338374672


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[30,   200] loss: 0.137 accuracy: 0.496
[30,   400] loss: 0.136 accuracy: 0.497
[30,   600] loss: 0.135 accuracy: 0.499
[30,   800] loss: 0.138 accuracy: 0.486
[30,  1000] loss: 0.139 accuracy: 0.486
[30,  1200] loss: 0.137 accuracy: 0.491
[30,  1400] loss: 0.137 accuracy: 0.496
[30,  1600] loss: 0.136 accuracy: 0.499
[30,  1800] loss: 0.139 accuracy: 0.481
[30,  2000] loss: 0.137 accuracy: 0.494
[30,  2200] loss: 0.137 accuracy: 0.497
[30,  2400] loss: 0.138 accuracy: 0.487
[30,  2600] loss: 0.138 accuracy: 0.491
[30,  2800] loss: 0.137 accuracy: 0.493
[30,  3000] loss: 0.137 accuracy: 0.490
[30,  3200] loss: 0.136 accuracy: 0.492
[30,  3400] loss: 0.137 accuracy: 0.496
[30,  3600] loss: 0.137 accuracy: 0.496
[30,  3800] loss: 0.136 accuracy: 0.497
[30,  4000] loss: 0.138 accuracy: 0.487
[30,  4200] loss: 0.137 accuracy: 0.497
[30,  4400] loss: 0.140 accuracy: 0.485
[30,  4600] loss: 0.137 accuracy: 0.496
[30,  4800] loss: 0.138 accuracy: 0.494
[30,  5000] loss: 0.138 accuracy: 0.482


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.5105313571284533


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[31,   200] loss: 0.136 accuracy: 0.497
[31,   400] loss: 0.137 accuracy: 0.494
[31,   600] loss: 0.137 accuracy: 0.488
[31,   800] loss: 0.135 accuracy: 0.493
[31,  1000] loss: 0.136 accuracy: 0.493
[31,  1200] loss: 0.137 accuracy: 0.490
[31,  1400] loss: 0.136 accuracy: 0.494
[31,  1600] loss: 0.138 accuracy: 0.489
[31,  1800] loss: 0.139 accuracy: 0.485
[31,  2000] loss: 0.136 accuracy: 0.499
[31,  2200] loss: 0.138 accuracy: 0.492
[31,  2400] loss: 0.135 accuracy: 0.493
[31,  2600] loss: 0.138 accuracy: 0.492
[31,  2800] loss: 0.137 accuracy: 0.491
[31,  3000] loss: 0.137 accuracy: 0.492
[31,  3200] loss: 0.139 accuracy: 0.487
[31,  3400] loss: 0.138 accuracy: 0.487
[31,  3600] loss: 0.136 accuracy: 0.496
[31,  3800] loss: 0.137 accuracy: 0.486
[31,  4000] loss: 0.139 accuracy: 0.490
[31,  4200] loss: 0.135 accuracy: 0.497
[31,  4400] loss: 0.137 accuracy: 0.497
[31,  4600] loss: 0.139 accuracy: 0.489
[31,  4800] loss: 0.139 accuracy: 0.488
[31,  5000] loss: 0.138 accuracy: 0.494


KeyboardInterrupt: 

In [9]:
# Evaluate the best saved baseline checkpoint on the held-out test set.
# Relies on names defined in earlier cells: `overnet`, `testset`, `compute_corrects`,
# `device`, `half` and `tqdm`.
# map_location keeps the restored tensors on the configured GPU instead of whatever
# device the checkpoint happened to be saved from.
checkpoint = torch.load('models/best-baseline.pth', map_location=torch.device('cuda', device))
overnet.load_state_dict(checkpoint['model_state_dict'])

# Sample order does not affect an accuracy count, so the test set is not shuffled.
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, pin_memory=True, num_workers=16)

corrects = 0
# Switch to inference behaviour (e.g. dropout disabled) once, before iterating.
overnet.eval()
try:
    # Evaluation needs no gradients; skipping the autograd graph saves GPU memory and time.
    with torch.no_grad():
        for spec, target in tqdm(testloader):
            spec, target = spec.float(), target.float()
            if half:
                spec, target = spec.half(), target.half()
            spec = spec.cuda(device)
            target = target.cuda(device)
            out = overnet(spec)
            corrects += compute_corrects(out, target)
    print('test acc:', corrects / len(testset))
finally:
    # Always return the model to training mode, even if evaluation is interrupted.
    overnet.train()

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


test acc: 0.5110859044162129
