In [6]:
import librosa
import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import time
import glob
from lxml.html import parse
from sphfile import SPHFile
import pydub
import audiosegment
import pandas as pd
from collections import Counter
from bs4 import BeautifulSoup
import sys
import os
from tqdm.notebook import tqdm
class Lambda(nn.Module):
    """Adapter that turns an arbitrary callable into an ``nn.Module`` layer.

    The wrapped callable owns no parameters; ``forward`` just applies it to
    its input, so it can sit inside ``nn.Sequential`` or be registered as a
    sub-module.
    """

    def __init__(self, func):
        super(Lambda, self).__init__()
        self.func = func

    def forward(self, x):
        fn = self.func
        return fn(x)
# Notebook-wide settings, unpacked in a single statement:
#   sr      - presumed audio sample rate in Hz (same value as the MFCC sr)
#   dropout - dropout probability between the stacked LSTM layers of Baseline
#   half    - when True, the model and every batch are cast to float16
#   root    - path prefix prepended to the .npy file paths listed in the CSVs
sr, dropout, half, root = 16000, 0.3, False, '../'

In [3]:
# Make GPU #2 the current CUDA device; later cells also pass `device`
# explicitly to every .cuda(...) call.
device = 2
torch.cuda.set_device(torch.device('cuda', device))


In [4]:
class OverlayDataSet(torch.utils.data.Dataset):
    """Three-speaker overlapped-speech dataset described by a CSV file.

    Each CSV row names three waveform files (``first_file``, ``second_file``,
    ``third_file``, loaded with ``np.load`` relative to the global ``root``)
    and the three matching speaker labels. Items are MFCC-based features of
    the summed waveform plus a multi-hot speaker target.

    Args:
        csv: path (or file-like object) accepted by ``pd.read_csv``.
        compute_original: when True, ``__getitem__`` also returns the features
            of each isolated speaker (see ``__getitem__``).
        speakers: optional iterable of speaker labels to use as the label
            vocabulary. Pass ``trainset.speakers`` to val/test sets so that
            target indices agree across splits. When omitted, the vocabulary is
            built from all three speaker columns of this CSV.
    """

    _FILE_COLUMNS = ('first_file', 'second_file', 'third_file')
    _SPEAKER_COLUMNS = ('first_speaker', 'second_speaker', 'third_speaker')

    def __init__(self, csv, compute_original=False, speakers=None):
        super().__init__()
        self.overlays = pd.read_csv(csv)
        if speakers is None:
            # Use every speaker column. A speaker that only ever appears second
            # or third would otherwise be missing from spkr2idx (KeyError).
            speakers = set().union(*(self.overlays[col] for col in self._SPEAKER_COLUMNS))
        self.speakers = sorted(set(speakers))
        self.spkr2idx = {spkr: i for i, spkr in enumerate(self.speakers)}
        self.compute_original = compute_original

    def __len__(self):
        return len(self.overlays)

    @staticmethod
    def _pad_to(segment, length):
        """Right-pad ``segment`` with zeros so it is at least ``length`` samples long."""
        if length > len(segment):
            segment = np.concatenate((segment, np.zeros(length - len(segment))))
        return segment

    def __getitem__(self, idx):
        """Return features and multi-hot target for row ``idx``.

        Returns ``(mixture_features, target)``, or, when ``compute_original``
        is set, ``(first, second, third, mixture_features, target)``.
        """
        overlay = self.overlays.iloc[idx]
        # Scale integer samples by 2**15 (presumably int16 PCM -> [-1, 1)).
        segments = [np.load(root + overlay[col]) / (2 ** 15) for col in self._FILE_COLUMNS]
        # Zero-pad shorter clips to compensate for rounding differences in length.
        max_len = max(len(segment) for segment in segments)
        first_segment, second_segment, third_segment = (self._pad_to(s, max_len) for s in segments)

        target = np.zeros(len(self.speakers))
        for col in self._SPEAKER_COLUMNS:
            target[self.spkr2idx[overlay[col]]] = 1.0

        mixture = first_segment + second_segment + third_segment
        if self.compute_original:
            # The mixture returned here now includes all three speakers, matching
            # the non-original branch (it previously omitted the third speaker).
            return (self.make_spectrogram(first_segment),
                    self.make_spectrogram(second_segment),
                    self.make_spectrogram(third_segment),
                    self.make_spectrogram(mixture),
                    target)
        return self.make_spectrogram(mixture), target

    def make_spectrogram(self, segment):
        """Compute MFCC-based features of shape (frames, 36) for a waveform.

        Columns are 13 MFCCs (coefficients 1..13, c0 dropped) followed by
        their first (12) and second (11) differences.
        """
        # Drop 50 samples from each end of the waveform.
        segment = segment[50:-50]
        # `y` is passed by keyword: it is keyword-only in librosa >= 0.10.
        mfcc = librosa.feature.mfcc(y=segment, sr=16000, n_mfcc=20, dct_type=2,
                                    n_fft=1024, hop_length=160)
        S = mfcc[1:14].T  # (frames, 13)
        # NOTE(review): np.diff defaults to axis=-1, i.e. differences across MFCC
        # coefficients rather than time frames. The model's 36-dim input
        # (13 + 12 + 11) depends on this, so it is intentionally kept.
        S1 = np.diff(S)
        S2 = np.diff(S1)
        return np.concatenate((S, S1, S2), axis=-1)

# One dataset per split, created in train / val / test order; each reads its own CSV.
trainset, valset, testset = [
    OverlayDataSet(csv_path, False)
    for csv_path in ('overlay-train.csv', 'overlay-val.csv', 'overlay-test.csv')
]

In [7]:
# Inspect one training example. With compute_original=True the dataset returns
# (isolated spectrograms..., mixed spectrogram, target); otherwise (mixed, target).
*spectrograms, target = trainset[0]
spec3 = spectrograms[-1]  # the mixed-signal spectrogram is always last
plt.figure(figsize=(20, 6))
if trainset.compute_original:
    # One panel per isolated speaker. Each spectrogram is 2-D (frames, features),
    # so transpose the whole array instead of indexing a single row.
    for panel, spec in enumerate(spectrograms[:3], start=1):
        plt.subplot(1, 3, panel)
        plt.imshow(spec.T, aspect='auto')
        plt.title('speaker %d' % panel)
else:
    plt.imshow(spec3.T, aspect='auto')
    plt.title('mixture')
print(spec3.max(), spec3.min(), spec3.shape, target)

265.79355 -200.36958 (200, 36) [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0.]


<Figure size 1440x432 with 0 Axes>

## TODO: try drastically increasing the number of channels in the residual attention stage to check whether the model can overfit

In [8]:
class Baseline(nn.Module):
    """Bidirectional-LSTM multi-label speaker classifier.

    Input:  float tensor of shape (batch, frames, n_features).
    Output: (batch, n_classes) independent per-speaker probabilities in (0, 1).

    The defaults reproduce the original architecture: 36 input features, 2
    layers of 32 hidden units per direction, and 20 speakers. Parameter names
    (``bn``, ``lstm``, ``fc1``, ``fc2``) are unchanged, so existing checkpoints
    still load. No lambda-wrapping layers are used, so the module can also be
    pickled as a whole.

    Args:
        lstm_dropout: dropout between LSTM layers; ``None`` uses the notebook-wide
            ``dropout`` setting.
    """

    def __init__(self, n_features=36, n_hidden=32, n_classes=20, n_layers=2, lstm_dropout=None):
        super().__init__()
        if lstm_dropout is None:
            lstm_dropout = dropout  # notebook-wide setting from the config cell
        # Per-frame feature normalisation. It is a LayerNorm; the name `bn` is kept
        # for checkpoint compatibility.
        self.bn = nn.LayerNorm(n_features)
        # batch_first=True consumes (batch, frames, features) directly. The weight
        # layout is identical to batch_first=False, so checkpoints stay compatible.
        self.lstm = nn.LSTM(n_features, n_hidden, n_layers, batch_first=True,
                            bidirectional=True, dropout=lstm_dropout)
        self.fc1 = nn.Linear(2 * n_hidden, n_hidden)  # both LSTM directions -> hidden
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(n_hidden, n_classes)
        self.sigmoid = nn.Sigmoid()

    def forward(self, X):
        X = self.bn(X)        # (batch, frames, n_features)
        X, _ = self.lstm(X)   # (batch, frames, 2 * n_hidden)
        X = self.fc1(X)       # (batch, frames, n_hidden)
        X = X.mean(dim=1)     # average over frames -> (batch, n_hidden)
        X = self.tanh(X)
        X = self.fc2(X)       # (batch, n_classes)
        return self.sigmoid(X)
    
    
overnet = Baseline().cuda(device)

# tune hidden layers smaller if overfit
optimizer = torch.optim.Adam(overnet.parameters(), 0.001)

CHECKPOINT_PATH = 'models/lstm.pth'            # periodic checkpoint written during training
BEST_CHECKPOINT_PATH = 'models/best-lstm.pth'  # checkpoint of the best validation accuracy

bestacc = 0.0
if os.path.exists(CHECKPOINT_PATH):
    print('load model')
    # map_location keeps loading working even if the checkpoint was saved on another GPU.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=torch.device('cuda', device))
    overnet.load_state_dict(checkpoint['model_state_dict'])
    try:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except Exception as err:
        # Best effort: e.g. parameter groups changed since the checkpoint was written.
        print('cannot load optimizer:', err)
    loss = checkpoint['loss']
    if 'bestacc' in checkpoint:
        bestacc = checkpoint['bestacc']
    elif os.path.exists(BEST_CHECKPOINT_PATH):
        # Older periodic checkpoints did not record 'bestacc'. Without this fallback,
        # a resumed run would restart at 0.0 and overwrite the best model with a worse one.
        bestacc = torch.load(BEST_CHECKPOINT_PATH, map_location='cpu').get('bestacc', 0.0)
else:
    print('initializing new model')

if half:
    overnet.half()  # convert to half precision
    # Keep BatchNorm layers in float32 for numerical stability.
    # NOTE(review): Baseline contains no BatchNorm2d, so this loop is currently a no-op.
    for layer in overnet.modules():
        if isinstance(layer, nn.BatchNorm2d):
            layer.float()

overnet.train()
'bestacc:', bestacc

initializing new model


('bestacc:', 0.0)

## TODO: also compute metrics for correctly identifying individual speakers (not only exact matches of all three)

In [11]:
def find_max3(tensor):
    """Top-3 column indices per row.

    For every first-axis slice (row) of ``tensor``, return the indices of its
    three largest entries, ordered from largest to smallest, stacked into a
    NumPy array. The tensor is detached and copied to the CPU first, so it
    may require grad and live on a GPU.
    """
    values = tensor.detach().cpu().numpy()
    return np.array([np.argsort(values_row)[::-1][:3] for values_row in values])

def compute_corrects(tensor1, tensor2):
    """Count rows whose top-3 index sets agree between two score tensors.

    Order inside the top three is ignored. A row is counted when the three
    highest-scoring columns of ``tensor1`` are exactly the three
    highest-scoring columns of ``tensor2`` (as returned by ``find_max3``).
    """
    top_a = find_max3(tensor1)
    top_b = find_max3(tensor2)
    # Each comparison yields a bool; summing them gives the integer count.
    return sum(Counter(top_a[row]) == Counter(top_b[row]) for row in range(top_a.shape[0]))

batch_size = 64
N_EPOCHS = 64
LOG_EVERY = 200  # mini-batches between progress prints and periodic checkpoints

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=16)
# Validation order does not affect the metric, so there is no need to shuffle.
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=16)
criterion = torch.nn.BCELoss()

# torch.save does not create missing directories.
os.makedirs('models', exist_ok=True)


def _to_device(spec, target):
    """Cast a batch to float32 (float16 when `half` is set) and move it to the GPU."""
    spec, target = spec.float(), target.float()
    if half:
        spec, target = spec.half(), target.half()
    return spec.cuda(device), target.cuda(device)


def _evaluate(loader, n_examples):
    """Return the exact top-3 match accuracy of `overnet` over `loader`."""
    overnet.eval()
    corrects = 0
    # no_grad: validation needs no autograd graph, which saves memory and time.
    with torch.no_grad():
        for spec, target in tqdm(loader):
            spec, target = _to_device(spec, target)
            corrects += compute_corrects(overnet(spec), target)
    overnet.train()
    return corrects / n_examples


for epoch in range(N_EPOCHS):
    running_loss = 0.0
    running_accuracy = 0.0
    for batch_idx, (spec, target) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        spec, target = _to_device(spec, target)

        out = overnet(spec)
        loss = criterion(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(overnet.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        # Divide by the actual batch length: the last batch may be smaller than batch_size.
        running_accuracy += compute_corrects(out, target) / len(target)
        if batch_idx % LOG_EVERY == LOG_EVERY - 1:  # every LOG_EVERY mini-batches
            print('[%d, %5d] loss: %.3f accuracy: %.3f' %
                  (epoch + 1, batch_idx + 1, running_loss / LOG_EVERY, running_accuracy / LOG_EVERY))
            running_loss = 0.0
            running_accuracy = 0.0
            torch.save({
                'model_state_dict': overnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),  # plain float; avoids serialising the autograd graph
                'bestacc': bestacc,   # so a resumed run keeps the best accuracy so far
            }, 'models/lstm.pth')

    val_acc = _evaluate(valloader, len(valset))
    print('val acc:', val_acc)
    if val_acc > bestacc:
        bestacc = val_acc
        torch.save({
            'model_state_dict': overnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
            'bestacc': bestacc,
        }, 'models/best-lstm.pth')

HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[1,   200] loss: 0.153 accuracy: 0.438
[1,   400] loss: 0.154 accuracy: 0.431
[1,   600] loss: 0.153 accuracy: 0.435
[1,   800] loss: 0.153 accuracy: 0.438
[1,  1000] loss: 0.153 accuracy: 0.436
[1,  1200] loss: 0.154 accuracy: 0.430
[1,  1400] loss: 0.154 accuracy: 0.441
[1,  1600] loss: 0.153 accuracy: 0.432
[1,  1800] loss: 0.154 accuracy: 0.428
[1,  2000] loss: 0.152 accuracy: 0.440
[1,  2200] loss: 0.153 accuracy: 0.439
[1,  2400] loss: 0.154 accuracy: 0.435
[1,  2600] loss: 0.154 accuracy: 0.434
[1,  2800] loss: 0.155 accuracy: 0.432
[1,  3000] loss: 0.153 accuracy: 0.436
[1,  3200] loss: 0.155 accuracy: 0.437
[1,  3400] loss: 0.153 accuracy: 0.440
[1,  3600] loss: 0.155 accuracy: 0.435
[1,  3800] loss: 0.152 accuracy: 0.443
[1,  4000] loss: 0.154 accuracy: 0.436
[1,  4200] loss: 0.154 accuracy: 0.432
[1,  4400] loss: 0.154 accuracy: 0.434
[1,  4600] loss: 0.153 accuracy: 0.442
[1,  4800] loss: 0.154 accuracy: 0.426
[1,  5000] loss: 0.153 accuracy: 0.437
[1,  5200] loss: 0.153 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4497227263561202


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[2,   200] loss: 0.151 accuracy: 0.445
[2,   400] loss: 0.152 accuracy: 0.436
[2,   600] loss: 0.153 accuracy: 0.442
[2,   800] loss: 0.152 accuracy: 0.441
[2,  1000] loss: 0.152 accuracy: 0.441
[2,  1200] loss: 0.153 accuracy: 0.439
[2,  1400] loss: 0.153 accuracy: 0.436
[2,  1600] loss: 0.154 accuracy: 0.437
[2,  1800] loss: 0.153 accuracy: 0.436
[2,  2000] loss: 0.152 accuracy: 0.439
[2,  2200] loss: 0.153 accuracy: 0.435
[2,  2400] loss: 0.153 accuracy: 0.438
[2,  2600] loss: 0.152 accuracy: 0.444
[2,  2800] loss: 0.153 accuracy: 0.436
[2,  3000] loss: 0.152 accuracy: 0.436
[2,  3200] loss: 0.153 accuracy: 0.437
[2,  3400] loss: 0.152 accuracy: 0.440
[2,  3600] loss: 0.153 accuracy: 0.438
[2,  3800] loss: 0.154 accuracy: 0.431
[2,  4000] loss: 0.153 accuracy: 0.433
[2,  4200] loss: 0.154 accuracy: 0.438
[2,  4400] loss: 0.153 accuracy: 0.437
[2,  4600] loss: 0.153 accuracy: 0.432
[2,  4800] loss: 0.153 accuracy: 0.441
[2,  5000] loss: 0.154 accuracy: 0.438
[2,  5200] loss: 0.154 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.45617059891107076


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[3,   200] loss: 0.151 accuracy: 0.440
[3,   400] loss: 0.152 accuracy: 0.439
[3,   600] loss: 0.151 accuracy: 0.442
[3,   800] loss: 0.152 accuracy: 0.440
[3,  1000] loss: 0.151 accuracy: 0.442
[3,  1200] loss: 0.153 accuracy: 0.436
[3,  1400] loss: 0.151 accuracy: 0.443
[3,  1600] loss: 0.152 accuracy: 0.438
[3,  1800] loss: 0.152 accuracy: 0.444
[3,  2000] loss: 0.151 accuracy: 0.447
[3,  2200] loss: 0.153 accuracy: 0.436
[3,  2400] loss: 0.151 accuracy: 0.441
[3,  2600] loss: 0.153 accuracy: 0.440
[3,  2800] loss: 0.151 accuracy: 0.446
[3,  3000] loss: 0.151 accuracy: 0.445
[3,  3200] loss: 0.152 accuracy: 0.440
[3,  3400] loss: 0.151 accuracy: 0.440
[3,  3600] loss: 0.152 accuracy: 0.439
[3,  3800] loss: 0.153 accuracy: 0.439
[3,  4000] loss: 0.151 accuracy: 0.443
[3,  4200] loss: 0.152 accuracy: 0.442
[3,  4400] loss: 0.153 accuracy: 0.435
[3,  4600] loss: 0.152 accuracy: 0.444
[3,  4800] loss: 0.153 accuracy: 0.440
[3,  5000] loss: 0.153 accuracy: 0.441
[3,  5200] loss: 0.151 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4527979431336963


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[4,   200] loss: 0.152 accuracy: 0.436
[4,   400] loss: 0.150 accuracy: 0.448
[4,   600] loss: 0.150 accuracy: 0.442
[4,   800] loss: 0.152 accuracy: 0.441
[4,  1000] loss: 0.151 accuracy: 0.439
[4,  1200] loss: 0.151 accuracy: 0.444
[4,  1400] loss: 0.151 accuracy: 0.444
[4,  1600] loss: 0.151 accuracy: 0.440
[4,  1800] loss: 0.150 accuracy: 0.449
[4,  2000] loss: 0.152 accuracy: 0.443
[4,  2200] loss: 0.151 accuracy: 0.444
[4,  2400] loss: 0.151 accuracy: 0.450
[4,  2600] loss: 0.152 accuracy: 0.437
[4,  2800] loss: 0.152 accuracy: 0.440
[4,  3000] loss: 0.152 accuracy: 0.441
[4,  3200] loss: 0.149 accuracy: 0.446
[4,  3400] loss: 0.149 accuracy: 0.447
[4,  3600] loss: 0.151 accuracy: 0.443
[4,  3800] loss: 0.154 accuracy: 0.432
[4,  4000] loss: 0.152 accuracy: 0.436
[4,  4200] loss: 0.152 accuracy: 0.438
[4,  4400] loss: 0.151 accuracy: 0.445
[4,  4600] loss: 0.151 accuracy: 0.449
[4,  4800] loss: 0.151 accuracy: 0.446
[4,  5000] loss: 0.152 accuracy: 0.439
[4,  5200] loss: 0.150 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.46104053236539627


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[5,   200] loss: 0.150 accuracy: 0.441
[5,   400] loss: 0.150 accuracy: 0.446
[5,   600] loss: 0.149 accuracy: 0.454
[5,   800] loss: 0.149 accuracy: 0.451
[5,  1000] loss: 0.151 accuracy: 0.447
[5,  1200] loss: 0.149 accuracy: 0.449
[5,  1400] loss: 0.150 accuracy: 0.449
[5,  1600] loss: 0.151 accuracy: 0.445
[5,  1800] loss: 0.151 accuracy: 0.443
[5,  2000] loss: 0.150 accuracy: 0.449
[5,  2200] loss: 0.151 accuracy: 0.441
[5,  2400] loss: 0.150 accuracy: 0.446
[5,  2600] loss: 0.150 accuracy: 0.452
[5,  2800] loss: 0.151 accuracy: 0.449
[5,  3000] loss: 0.152 accuracy: 0.437
[5,  3200] loss: 0.151 accuracy: 0.443
[5,  3400] loss: 0.149 accuracy: 0.448
[5,  3600] loss: 0.149 accuracy: 0.447
[5,  3800] loss: 0.150 accuracy: 0.446
[5,  4000] loss: 0.152 accuracy: 0.441
[5,  4200] loss: 0.150 accuracy: 0.448
[5,  4400] loss: 0.150 accuracy: 0.442
[5,  4600] loss: 0.151 accuracy: 0.441
[5,  4800] loss: 0.149 accuracy: 0.453
[5,  5000] loss: 0.150 accuracy: 0.449
[5,  5200] loss: 0.149 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4664750957854406


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[6,   200] loss: 0.150 accuracy: 0.449
[6,   400] loss: 0.150 accuracy: 0.445
[6,   600] loss: 0.150 accuracy: 0.444
[6,   800] loss: 0.150 accuracy: 0.447
[6,  1000] loss: 0.149 accuracy: 0.452
[6,  1200] loss: 0.149 accuracy: 0.449
[6,  1400] loss: 0.150 accuracy: 0.439
[6,  1600] loss: 0.147 accuracy: 0.454
[6,  1800] loss: 0.150 accuracy: 0.443
[6,  2000] loss: 0.151 accuracy: 0.440
[6,  2200] loss: 0.149 accuracy: 0.448
[6,  2400] loss: 0.149 accuracy: 0.453
[6,  2600] loss: 0.148 accuracy: 0.457
[6,  2800] loss: 0.149 accuracy: 0.447
[6,  3000] loss: 0.150 accuracy: 0.442
[6,  3200] loss: 0.150 accuracy: 0.449
[6,  3400] loss: 0.149 accuracy: 0.455
[6,  3600] loss: 0.149 accuracy: 0.450
[6,  3800] loss: 0.152 accuracy: 0.437
[6,  4000] loss: 0.150 accuracy: 0.443
[6,  4200] loss: 0.150 accuracy: 0.445
[6,  4400] loss: 0.150 accuracy: 0.445
[6,  4600] loss: 0.150 accuracy: 0.448
[6,  4800] loss: 0.149 accuracy: 0.446
[6,  5000] loss: 0.149 accuracy: 0.456
[6,  5200] loss: 0.151 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.46277979431336963


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[7,   200] loss: 0.149 accuracy: 0.447
[7,   400] loss: 0.149 accuracy: 0.449
[7,   600] loss: 0.149 accuracy: 0.447
[7,   800] loss: 0.150 accuracy: 0.449
[7,  1000] loss: 0.151 accuracy: 0.445
[7,  1200] loss: 0.149 accuracy: 0.450
[7,  1400] loss: 0.150 accuracy: 0.445
[7,  1600] loss: 0.149 accuracy: 0.451
[7,  1800] loss: 0.148 accuracy: 0.458
[7,  2000] loss: 0.148 accuracy: 0.453
[7,  2200] loss: 0.149 accuracy: 0.449
[7,  2400] loss: 0.149 accuracy: 0.450
[7,  2600] loss: 0.148 accuracy: 0.454
[7,  2800] loss: 0.151 accuracy: 0.445
[7,  3000] loss: 0.149 accuracy: 0.450
[7,  3200] loss: 0.149 accuracy: 0.450
[7,  3400] loss: 0.149 accuracy: 0.450
[7,  3600] loss: 0.150 accuracy: 0.444
[7,  3800] loss: 0.149 accuracy: 0.451
[7,  4000] loss: 0.148 accuracy: 0.452
[7,  4200] loss: 0.150 accuracy: 0.450
[7,  4400] loss: 0.148 accuracy: 0.450
[7,  4600] loss: 0.149 accuracy: 0.448
[7,  4800] loss: 0.149 accuracy: 0.455
[7,  5000] loss: 0.149 accuracy: 0.456
[7,  5200] loss: 0.149 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4648366606170599


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[8,   200] loss: 0.149 accuracy: 0.450
[8,   400] loss: 0.149 accuracy: 0.450
[8,   600] loss: 0.150 accuracy: 0.444
[8,   800] loss: 0.148 accuracy: 0.453
[8,  1000] loss: 0.149 accuracy: 0.456
[8,  1200] loss: 0.147 accuracy: 0.451
[8,  1400] loss: 0.150 accuracy: 0.443
[8,  1600] loss: 0.149 accuracy: 0.447
[8,  1800] loss: 0.149 accuracy: 0.449
[8,  2000] loss: 0.149 accuracy: 0.458
[8,  2200] loss: 0.149 accuracy: 0.448
[8,  2400] loss: 0.148 accuracy: 0.455
[8,  2600] loss: 0.148 accuracy: 0.449
[8,  2800] loss: 0.148 accuracy: 0.454
[8,  3000] loss: 0.149 accuracy: 0.450
[8,  3200] loss: 0.149 accuracy: 0.450
[8,  3400] loss: 0.148 accuracy: 0.453
[8,  3600] loss: 0.149 accuracy: 0.452
[8,  3800] loss: 0.149 accuracy: 0.454
[8,  4000] loss: 0.150 accuracy: 0.446
[8,  4200] loss: 0.148 accuracy: 0.453
[8,  4400] loss: 0.148 accuracy: 0.459
[8,  4600] loss: 0.148 accuracy: 0.454
[8,  4800] loss: 0.146 accuracy: 0.457
[8,  5000] loss: 0.148 accuracy: 0.452
[8,  5200] loss: 0.148 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4641107078039927


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[9,   200] loss: 0.147 accuracy: 0.459
[9,   400] loss: 0.149 accuracy: 0.450
[9,   600] loss: 0.148 accuracy: 0.447
[9,   800] loss: 0.147 accuracy: 0.459
[9,  1000] loss: 0.149 accuracy: 0.454
[9,  1200] loss: 0.149 accuracy: 0.454
[9,  1400] loss: 0.149 accuracy: 0.450
[9,  1600] loss: 0.147 accuracy: 0.458
[9,  1800] loss: 0.147 accuracy: 0.464
[9,  2000] loss: 0.150 accuracy: 0.443
[9,  2200] loss: 0.148 accuracy: 0.453
[9,  2400] loss: 0.151 accuracy: 0.441
[9,  2600] loss: 0.149 accuracy: 0.450
[9,  2800] loss: 0.147 accuracy: 0.456
[9,  3000] loss: 0.148 accuracy: 0.459
[9,  3200] loss: 0.150 accuracy: 0.448
[9,  3400] loss: 0.147 accuracy: 0.461
[9,  3600] loss: 0.149 accuracy: 0.450
[9,  3800] loss: 0.148 accuracy: 0.456
[9,  4000] loss: 0.149 accuracy: 0.446
[9,  4200] loss: 0.149 accuracy: 0.452
[9,  4400] loss: 0.148 accuracy: 0.451
[9,  4600] loss: 0.148 accuracy: 0.447
[9,  4800] loss: 0.148 accuracy: 0.454
[9,  5000] loss: 0.148 accuracy: 0.452
[9,  5200] loss: 0.148 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4654920346844122


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[10,   200] loss: 0.148 accuracy: 0.460
[10,   400] loss: 0.146 accuracy: 0.461
[10,   600] loss: 0.148 accuracy: 0.455
[10,   800] loss: 0.148 accuracy: 0.462
[10,  1000] loss: 0.147 accuracy: 0.454
[10,  1200] loss: 0.147 accuracy: 0.459
[10,  1400] loss: 0.148 accuracy: 0.454
[10,  1600] loss: 0.150 accuracy: 0.446
[10,  1800] loss: 0.147 accuracy: 0.464
[10,  2000] loss: 0.147 accuracy: 0.456
[10,  2200] loss: 0.148 accuracy: 0.456
[10,  2400] loss: 0.148 accuracy: 0.453
[10,  2600] loss: 0.148 accuracy: 0.454
[10,  2800] loss: 0.147 accuracy: 0.461
[10,  3000] loss: 0.147 accuracy: 0.458
[10,  3200] loss: 0.149 accuracy: 0.448
[10,  3400] loss: 0.148 accuracy: 0.450
[10,  3600] loss: 0.148 accuracy: 0.451
[10,  3800] loss: 0.149 accuracy: 0.456
[10,  4000] loss: 0.146 accuracy: 0.459
[10,  4200] loss: 0.148 accuracy: 0.448
[10,  4400] loss: 0.147 accuracy: 0.456
[10,  4600] loss: 0.147 accuracy: 0.455
[10,  4800] loss: 0.147 accuracy: 0.454
[10,  5000] loss: 0.149 accuracy: 0.454


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.45745614035087717


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[11,   200] loss: 0.150 accuracy: 0.450
[11,   400] loss: 0.146 accuracy: 0.461
[11,   600] loss: 0.148 accuracy: 0.450
[11,   800] loss: 0.146 accuracy: 0.456
[11,  1000] loss: 0.145 accuracy: 0.463
[11,  1200] loss: 0.148 accuracy: 0.461
[11,  1400] loss: 0.147 accuracy: 0.461
[11,  1600] loss: 0.148 accuracy: 0.451
[11,  1800] loss: 0.145 accuracy: 0.459
[11,  2000] loss: 0.148 accuracy: 0.458
[11,  2200] loss: 0.148 accuracy: 0.451
[11,  2400] loss: 0.148 accuracy: 0.449
[11,  2600] loss: 0.147 accuracy: 0.453
[11,  2800] loss: 0.147 accuracy: 0.456
[11,  3000] loss: 0.146 accuracy: 0.465
[11,  3200] loss: 0.147 accuracy: 0.453
[11,  3400] loss: 0.148 accuracy: 0.453
[11,  3600] loss: 0.147 accuracy: 0.455
[11,  3800] loss: 0.148 accuracy: 0.456
[11,  4000] loss: 0.147 accuracy: 0.458
[11,  4200] loss: 0.149 accuracy: 0.450
[11,  4400] loss: 0.147 accuracy: 0.456
[11,  4600] loss: 0.148 accuracy: 0.454
[11,  4800] loss: 0.146 accuracy: 0.460
[11,  5000] loss: 0.146 accuracy: 0.455


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.46863783020770317


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[12,   200] loss: 0.147 accuracy: 0.456
[12,   400] loss: 0.148 accuracy: 0.454
[12,   600] loss: 0.146 accuracy: 0.459
[12,   800] loss: 0.147 accuracy: 0.451
[12,  1000] loss: 0.145 accuracy: 0.456
[12,  1200] loss: 0.147 accuracy: 0.454
[12,  1400] loss: 0.148 accuracy: 0.451
[12,  1600] loss: 0.148 accuracy: 0.456
[12,  1800] loss: 0.147 accuracy: 0.456
[12,  2000] loss: 0.146 accuracy: 0.459
[12,  2200] loss: 0.148 accuracy: 0.454
[12,  2400] loss: 0.147 accuracy: 0.460
[12,  2600] loss: 0.147 accuracy: 0.460
[12,  2800] loss: 0.148 accuracy: 0.454
[12,  3000] loss: 0.148 accuracy: 0.452
[12,  3200] loss: 0.147 accuracy: 0.453
[12,  3400] loss: 0.149 accuracy: 0.450
[12,  3600] loss: 0.145 accuracy: 0.466
[12,  3800] loss: 0.148 accuracy: 0.448
[12,  4000] loss: 0.147 accuracy: 0.456
[12,  4200] loss: 0.148 accuracy: 0.451
[12,  4400] loss: 0.147 accuracy: 0.451
[12,  4600] loss: 0.148 accuracy: 0.456
[12,  4800] loss: 0.146 accuracy: 0.462
[12,  5000] loss: 0.148 accuracy: 0.454


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.47101734220608993


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[13,   200] loss: 0.146 accuracy: 0.464
[13,   400] loss: 0.147 accuracy: 0.458
[13,   600] loss: 0.145 accuracy: 0.459
[13,   800] loss: 0.146 accuracy: 0.460
[13,  1000] loss: 0.146 accuracy: 0.460
[13,  1200] loss: 0.147 accuracy: 0.456
[13,  1400] loss: 0.146 accuracy: 0.461
[13,  1600] loss: 0.147 accuracy: 0.450
[13,  1800] loss: 0.148 accuracy: 0.455
[13,  2000] loss: 0.147 accuracy: 0.457
[13,  2200] loss: 0.145 accuracy: 0.465
[13,  2400] loss: 0.147 accuracy: 0.462
[13,  2600] loss: 0.147 accuracy: 0.455
[13,  2800] loss: 0.147 accuracy: 0.457
[13,  3000] loss: 0.146 accuracy: 0.461
[13,  3200] loss: 0.145 accuracy: 0.464
[13,  3400] loss: 0.147 accuracy: 0.456
[13,  3600] loss: 0.147 accuracy: 0.449
[13,  3800] loss: 0.146 accuracy: 0.462
[13,  4000] loss: 0.147 accuracy: 0.454
[13,  4200] loss: 0.147 accuracy: 0.464
[13,  4400] loss: 0.148 accuracy: 0.452
[13,  4600] loss: 0.146 accuracy: 0.460
[13,  4800] loss: 0.147 accuracy: 0.453
[13,  5000] loss: 0.146 accuracy: 0.455


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4755646299657189


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[14,   200] loss: 0.146 accuracy: 0.459
[14,   400] loss: 0.147 accuracy: 0.460
[14,   600] loss: 0.146 accuracy: 0.455
[14,   800] loss: 0.146 accuracy: 0.463
[14,  1000] loss: 0.146 accuracy: 0.459
[14,  1200] loss: 0.147 accuracy: 0.461
[14,  1400] loss: 0.146 accuracy: 0.463
[14,  1600] loss: 0.147 accuracy: 0.451
[14,  1800] loss: 0.146 accuracy: 0.456
[14,  2000] loss: 0.146 accuracy: 0.458
[14,  2200] loss: 0.147 accuracy: 0.456
[14,  2400] loss: 0.146 accuracy: 0.458
[14,  2600] loss: 0.146 accuracy: 0.460
[14,  2800] loss: 0.146 accuracy: 0.458
[14,  3000] loss: 0.146 accuracy: 0.460
[14,  3200] loss: 0.147 accuracy: 0.458
[14,  3400] loss: 0.146 accuracy: 0.461
[14,  3600] loss: 0.146 accuracy: 0.464
[14,  3800] loss: 0.146 accuracy: 0.464
[14,  4000] loss: 0.147 accuracy: 0.459
[14,  4200] loss: 0.146 accuracy: 0.458
[14,  4400] loss: 0.147 accuracy: 0.457
[14,  4600] loss: 0.145 accuracy: 0.463
[14,  4800] loss: 0.146 accuracy: 0.455
[14,  5000] loss: 0.149 accuracy: 0.452


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4680127041742287


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[15,   200] loss: 0.146 accuracy: 0.459
[15,   400] loss: 0.146 accuracy: 0.457
[15,   600] loss: 0.147 accuracy: 0.459
[15,   800] loss: 0.146 accuracy: 0.457
[15,  1000] loss: 0.145 accuracy: 0.466
[15,  1200] loss: 0.146 accuracy: 0.456
[15,  1400] loss: 0.144 accuracy: 0.467
[15,  1600] loss: 0.145 accuracy: 0.463
[15,  1800] loss: 0.145 accuracy: 0.466
[15,  2000] loss: 0.148 accuracy: 0.456
[15,  2200] loss: 0.146 accuracy: 0.461
[15,  2400] loss: 0.145 accuracy: 0.467
[15,  2600] loss: 0.145 accuracy: 0.462
[15,  2800] loss: 0.147 accuracy: 0.459
[15,  3000] loss: 0.145 accuracy: 0.466
[15,  3200] loss: 0.146 accuracy: 0.460
[15,  3400] loss: 0.146 accuracy: 0.463
[15,  3600] loss: 0.148 accuracy: 0.452
[15,  3800] loss: 0.145 accuracy: 0.463
[15,  4000] loss: 0.146 accuracy: 0.466
[15,  4200] loss: 0.145 accuracy: 0.466
[15,  4400] loss: 0.146 accuracy: 0.465
[15,  4600] loss: 0.147 accuracy: 0.456
[15,  4800] loss: 0.146 accuracy: 0.460
[15,  5000] loss: 0.147 accuracy: 0.457


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.47426396450897357


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[16,   200] loss: 0.146 accuracy: 0.462
[16,   400] loss: 0.146 accuracy: 0.463
[16,   600] loss: 0.145 accuracy: 0.465
[16,   800] loss: 0.146 accuracy: 0.456
[16,  1000] loss: 0.146 accuracy: 0.460
[16,  1200] loss: 0.146 accuracy: 0.463
[16,  1400] loss: 0.145 accuracy: 0.464
[16,  1600] loss: 0.147 accuracy: 0.453
[16,  1800] loss: 0.146 accuracy: 0.458
[16,  2000] loss: 0.145 accuracy: 0.465
[16,  2200] loss: 0.146 accuracy: 0.457
[16,  2400] loss: 0.144 accuracy: 0.466
[16,  2600] loss: 0.145 accuracy: 0.468
[16,  2800] loss: 0.145 accuracy: 0.464
[16,  3000] loss: 0.145 accuracy: 0.460
[16,  3200] loss: 0.146 accuracy: 0.464
[16,  3400] loss: 0.147 accuracy: 0.457
[16,  3600] loss: 0.145 accuracy: 0.461
[16,  3800] loss: 0.145 accuracy: 0.461
[16,  4000] loss: 0.146 accuracy: 0.458
[16,  4200] loss: 0.145 accuracy: 0.465
[16,  4400] loss: 0.146 accuracy: 0.460
[16,  4600] loss: 0.145 accuracy: 0.461
[16,  4800] loss: 0.144 accuracy: 0.465
[16,  5000] loss: 0.146 accuracy: 0.459


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.47759124823553134


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[17,   200] loss: 0.146 accuracy: 0.460
[17,   400] loss: 0.145 accuracy: 0.467
[17,   600] loss: 0.145 accuracy: 0.466
[17,   800] loss: 0.146 accuracy: 0.470
[17,  1000] loss: 0.146 accuracy: 0.458
[17,  1200] loss: 0.146 accuracy: 0.459
[17,  1400] loss: 0.145 accuracy: 0.463
[17,  1600] loss: 0.146 accuracy: 0.454
[17,  1800] loss: 0.145 accuracy: 0.468
[17,  2000] loss: 0.145 accuracy: 0.464
[17,  2200] loss: 0.145 accuracy: 0.466
[17,  2400] loss: 0.145 accuracy: 0.465
[17,  2600] loss: 0.144 accuracy: 0.474
[17,  2800] loss: 0.146 accuracy: 0.459
[17,  3000] loss: 0.145 accuracy: 0.458
[17,  3200] loss: 0.146 accuracy: 0.460
[17,  3400] loss: 0.145 accuracy: 0.463
[17,  3600] loss: 0.146 accuracy: 0.454
[17,  3800] loss: 0.145 accuracy: 0.464
[17,  4000] loss: 0.146 accuracy: 0.465
[17,  4200] loss: 0.145 accuracy: 0.468
[17,  4400] loss: 0.143 accuracy: 0.466
[17,  4600] loss: 0.144 accuracy: 0.472
[17,  4800] loss: 0.145 accuracy: 0.461
[17,  5000] loss: 0.146 accuracy: 0.459


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.48084291187739464


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[18,   200] loss: 0.144 accuracy: 0.468
[18,   400] loss: 0.144 accuracy: 0.464
[18,   600] loss: 0.144 accuracy: 0.462
[18,   800] loss: 0.145 accuracy: 0.463
[18,  1000] loss: 0.146 accuracy: 0.460
[18,  1200] loss: 0.145 accuracy: 0.463
[18,  1400] loss: 0.146 accuracy: 0.462
[18,  1600] loss: 0.144 accuracy: 0.467
[18,  1800] loss: 0.145 accuracy: 0.466
[18,  2000] loss: 0.145 accuracy: 0.464
[18,  2200] loss: 0.145 accuracy: 0.463
[18,  2400] loss: 0.145 accuracy: 0.466
[18,  2600] loss: 0.144 accuracy: 0.466
[18,  2800] loss: 0.146 accuracy: 0.457
[18,  3000] loss: 0.145 accuracy: 0.462
[18,  3200] loss: 0.145 accuracy: 0.464
[18,  3400] loss: 0.145 accuracy: 0.466
[18,  3600] loss: 0.144 accuracy: 0.459
[18,  3800] loss: 0.144 accuracy: 0.471
[18,  4000] loss: 0.144 accuracy: 0.468
[18,  4200] loss: 0.145 accuracy: 0.462
[18,  4400] loss: 0.145 accuracy: 0.468
[18,  4600] loss: 0.146 accuracy: 0.458
[18,  4800] loss: 0.146 accuracy: 0.456
[18,  5000] loss: 0.145 accuracy: 0.463


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.47896753377697115


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[19,   200] loss: 0.145 accuracy: 0.461
[19,   400] loss: 0.145 accuracy: 0.464
[19,   600] loss: 0.145 accuracy: 0.462
[19,   800] loss: 0.146 accuracy: 0.465
[19,  1000] loss: 0.146 accuracy: 0.460
[19,  1200] loss: 0.145 accuracy: 0.464
[19,  1400] loss: 0.143 accuracy: 0.467
[19,  1600] loss: 0.144 accuracy: 0.471
[19,  1800] loss: 0.144 accuracy: 0.465
[19,  2000] loss: 0.143 accuracy: 0.468
[19,  2200] loss: 0.145 accuracy: 0.464
[19,  2400] loss: 0.144 accuracy: 0.466
[19,  2600] loss: 0.145 accuracy: 0.464
[19,  2800] loss: 0.143 accuracy: 0.467
[19,  3000] loss: 0.143 accuracy: 0.470
[19,  3200] loss: 0.144 accuracy: 0.473
[19,  3400] loss: 0.145 accuracy: 0.466
[19,  3600] loss: 0.146 accuracy: 0.463
[19,  3800] loss: 0.145 accuracy: 0.463
[19,  4000] loss: 0.147 accuracy: 0.455
[19,  4200] loss: 0.146 accuracy: 0.461
[19,  4400] loss: 0.143 accuracy: 0.473
[19,  4600] loss: 0.145 accuracy: 0.465
[19,  4800] loss: 0.144 accuracy: 0.472
[19,  5000] loss: 0.144 accuracy: 0.462


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.48049505948779997


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[20,   200] loss: 0.143 accuracy: 0.472
[20,   400] loss: 0.143 accuracy: 0.475
[20,   600] loss: 0.144 accuracy: 0.463
[20,   800] loss: 0.145 accuracy: 0.463
[20,  1000] loss: 0.144 accuracy: 0.467
[20,  1200] loss: 0.146 accuracy: 0.455
[20,  1400] loss: 0.145 accuracy: 0.461
[20,  1600] loss: 0.143 accuracy: 0.468
[20,  1800] loss: 0.142 accuracy: 0.469
[20,  2000] loss: 0.143 accuracy: 0.471
[20,  2200] loss: 0.146 accuracy: 0.464
[20,  2400] loss: 0.146 accuracy: 0.467
[20,  2600] loss: 0.144 accuracy: 0.461
[20,  2800] loss: 0.143 accuracy: 0.469
[20,  3000] loss: 0.144 accuracy: 0.471
[20,  3200] loss: 0.144 accuracy: 0.471
[20,  3400] loss: 0.143 accuracy: 0.474
[20,  3600] loss: 0.145 accuracy: 0.459
[20,  3800] loss: 0.145 accuracy: 0.462
[20,  4000] loss: 0.145 accuracy: 0.461
[20,  4200] loss: 0.144 accuracy: 0.471
[20,  4400] loss: 0.146 accuracy: 0.460
[20,  4600] loss: 0.143 accuracy: 0.469
[20,  4800] loss: 0.144 accuracy: 0.471
[20,  5000] loss: 0.144 accuracy: 0.468


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4766081871345029


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[21,   200] loss: 0.143 accuracy: 0.471
[21,   400] loss: 0.143 accuracy: 0.470
[21,   600] loss: 0.144 accuracy: 0.469
[21,   800] loss: 0.144 accuracy: 0.466
[21,  1000] loss: 0.145 accuracy: 0.466
[21,  1200] loss: 0.143 accuracy: 0.468
[21,  1400] loss: 0.144 accuracy: 0.469
[21,  1600] loss: 0.143 accuracy: 0.470
[21,  1800] loss: 0.142 accuracy: 0.475
[21,  2000] loss: 0.144 accuracy: 0.465
[21,  2200] loss: 0.142 accuracy: 0.471
[21,  2400] loss: 0.144 accuracy: 0.467
[21,  2600] loss: 0.143 accuracy: 0.471
[21,  2800] loss: 0.143 accuracy: 0.468
[21,  3000] loss: 0.145 accuracy: 0.456
[21,  3200] loss: 0.146 accuracy: 0.464
[21,  3400] loss: 0.144 accuracy: 0.463
[21,  3600] loss: 0.145 accuracy: 0.460
[21,  3800] loss: 0.143 accuracy: 0.472
[21,  4000] loss: 0.145 accuracy: 0.462
[21,  4200] loss: 0.145 accuracy: 0.467
[21,  4400] loss: 0.146 accuracy: 0.459
[21,  4600] loss: 0.146 accuracy: 0.463
[21,  4800] loss: 0.145 accuracy: 0.461
[21,  5000] loss: 0.143 accuracy: 0.463


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4793708408953418


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[22,   200] loss: 0.144 accuracy: 0.467
[22,   400] loss: 0.144 accuracy: 0.463
[22,   600] loss: 0.145 accuracy: 0.462
[22,   800] loss: 0.143 accuracy: 0.474
[22,  1000] loss: 0.143 accuracy: 0.467
[22,  1200] loss: 0.143 accuracy: 0.468
[22,  1400] loss: 0.145 accuracy: 0.463
[22,  1600] loss: 0.145 accuracy: 0.470
[22,  1800] loss: 0.142 accuracy: 0.476
[22,  2000] loss: 0.144 accuracy: 0.469
[22,  2200] loss: 0.144 accuracy: 0.469
[22,  2400] loss: 0.145 accuracy: 0.465
[22,  2600] loss: 0.143 accuracy: 0.468
[22,  2800] loss: 0.145 accuracy: 0.466
[22,  3000] loss: 0.144 accuracy: 0.466
[22,  3200] loss: 0.144 accuracy: 0.461
[22,  3400] loss: 0.143 accuracy: 0.464
[22,  3600] loss: 0.144 accuracy: 0.467
[22,  3800] loss: 0.145 accuracy: 0.465
[22,  4000] loss: 0.145 accuracy: 0.460
[22,  4200] loss: 0.144 accuracy: 0.471
[22,  4400] loss: 0.145 accuracy: 0.465
[22,  4600] loss: 0.146 accuracy: 0.457
[22,  4800] loss: 0.143 accuracy: 0.469
[22,  5000] loss: 0.145 accuracy: 0.465


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.48338374672312967


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[23,   200] loss: 0.143 accuracy: 0.472
[23,   400] loss: 0.144 accuracy: 0.469
[23,   600] loss: 0.142 accuracy: 0.477
[23,   800] loss: 0.143 accuracy: 0.470
[23,  1000] loss: 0.143 accuracy: 0.465
[23,  1200] loss: 0.142 accuracy: 0.469
[23,  1400] loss: 0.144 accuracy: 0.472
[23,  1600] loss: 0.144 accuracy: 0.471
[23,  1800] loss: 0.144 accuracy: 0.464
[23,  2000] loss: 0.144 accuracy: 0.465
[23,  2200] loss: 0.145 accuracy: 0.462
[23,  2400] loss: 0.143 accuracy: 0.469
[23,  2600] loss: 0.144 accuracy: 0.465
[23,  2800] loss: 0.143 accuracy: 0.475
[23,  3000] loss: 0.144 accuracy: 0.471
[23,  3200] loss: 0.144 accuracy: 0.470
[23,  3400] loss: 0.143 accuracy: 0.475
[23,  3600] loss: 0.143 accuracy: 0.471
[23,  3800] loss: 0.145 accuracy: 0.468
[23,  4000] loss: 0.144 accuracy: 0.471
[23,  4200] loss: 0.143 accuracy: 0.473
[23,  4400] loss: 0.144 accuracy: 0.471
[23,  4600] loss: 0.144 accuracy: 0.473
[23,  4800] loss: 0.143 accuracy: 0.474
[23,  5000] loss: 0.144 accuracy: 0.471


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4837114337568058


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[24,   200] loss: 0.143 accuracy: 0.477
[24,   400] loss: 0.142 accuracy: 0.473
[24,   600] loss: 0.142 accuracy: 0.482
[24,   800] loss: 0.142 accuracy: 0.476
[24,  1000] loss: 0.144 accuracy: 0.471
[24,  1200] loss: 0.142 accuracy: 0.475
[24,  1400] loss: 0.142 accuracy: 0.476
[24,  1600] loss: 0.143 accuracy: 0.470
[24,  1800] loss: 0.143 accuracy: 0.474
[24,  2000] loss: 0.143 accuracy: 0.470
[24,  2200] loss: 0.144 accuracy: 0.461
[24,  2400] loss: 0.144 accuracy: 0.463
[24,  2600] loss: 0.142 accuracy: 0.475
[24,  2800] loss: 0.144 accuracy: 0.470
[24,  3000] loss: 0.144 accuracy: 0.471
[24,  3200] loss: 0.144 accuracy: 0.459
[24,  3400] loss: 0.145 accuracy: 0.466
[24,  3600] loss: 0.142 accuracy: 0.471
[24,  3800] loss: 0.144 accuracy: 0.469
[24,  4000] loss: 0.145 accuracy: 0.464
[24,  4200] loss: 0.144 accuracy: 0.463
[24,  4400] loss: 0.143 accuracy: 0.469
[24,  4600] loss: 0.145 accuracy: 0.462
[24,  4800] loss: 0.145 accuracy: 0.463
[24,  5000] loss: 0.143 accuracy: 0.463


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.48341399475700747


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[25,   200] loss: 0.142 accuracy: 0.480
[25,   400] loss: 0.142 accuracy: 0.472
[25,   600] loss: 0.143 accuracy: 0.474
[25,   800] loss: 0.145 accuracy: 0.462
[25,  1000] loss: 0.143 accuracy: 0.471
[25,  1200] loss: 0.144 accuracy: 0.474
[25,  1400] loss: 0.143 accuracy: 0.466
[25,  1600] loss: 0.143 accuracy: 0.469
[25,  1800] loss: 0.143 accuracy: 0.471
[25,  2000] loss: 0.143 accuracy: 0.474
[25,  2200] loss: 0.144 accuracy: 0.466
[25,  2400] loss: 0.143 accuracy: 0.466
[25,  2600] loss: 0.144 accuracy: 0.467
[25,  2800] loss: 0.144 accuracy: 0.462
[25,  3000] loss: 0.144 accuracy: 0.468
[25,  3200] loss: 0.144 accuracy: 0.469
[25,  3400] loss: 0.143 accuracy: 0.471
[25,  3600] loss: 0.143 accuracy: 0.468
[25,  3800] loss: 0.144 accuracy: 0.462
[25,  4000] loss: 0.144 accuracy: 0.464
[25,  4200] loss: 0.145 accuracy: 0.461
[25,  4400] loss: 0.142 accuracy: 0.473
[25,  4600] loss: 0.144 accuracy: 0.465
[25,  4800] loss: 0.142 accuracy: 0.471
[25,  5000] loss: 0.144 accuracy: 0.468


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4775357935067554


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[26,   200] loss: 0.142 accuracy: 0.470
[26,   400] loss: 0.143 accuracy: 0.470
[26,   600] loss: 0.143 accuracy: 0.474
[26,   800] loss: 0.141 accuracy: 0.475
[26,  1000] loss: 0.144 accuracy: 0.467
[26,  1200] loss: 0.145 accuracy: 0.466
[26,  1400] loss: 0.143 accuracy: 0.472
[26,  1600] loss: 0.142 accuracy: 0.471
[26,  1800] loss: 0.145 accuracy: 0.473
[26,  2000] loss: 0.142 accuracy: 0.474
[26,  2200] loss: 0.142 accuracy: 0.475
[26,  2400] loss: 0.143 accuracy: 0.466
[26,  2600] loss: 0.143 accuracy: 0.466
[26,  2800] loss: 0.144 accuracy: 0.470
[26,  3000] loss: 0.144 accuracy: 0.463
[26,  3200] loss: 0.142 accuracy: 0.475
[26,  3400] loss: 0.142 accuracy: 0.474
[26,  3600] loss: 0.142 accuracy: 0.473
[26,  3800] loss: 0.143 accuracy: 0.469
[26,  4000] loss: 0.144 accuracy: 0.469
[26,  4200] loss: 0.143 accuracy: 0.467
[26,  4400] loss: 0.144 accuracy: 0.466
[26,  4600] loss: 0.143 accuracy: 0.464
[26,  4800] loss: 0.143 accuracy: 0.465
[26,  5000] loss: 0.143 accuracy: 0.468


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.47982456140350876


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[27,   200] loss: 0.143 accuracy: 0.468
[27,   400] loss: 0.142 accuracy: 0.481
[27,   600] loss: 0.141 accuracy: 0.477
[27,   800] loss: 0.143 accuracy: 0.467
[27,  1000] loss: 0.143 accuracy: 0.464
[27,  1200] loss: 0.144 accuracy: 0.468
[27,  1400] loss: 0.143 accuracy: 0.470
[27,  1600] loss: 0.142 accuracy: 0.471
[27,  1800] loss: 0.144 accuracy: 0.467
[27,  2000] loss: 0.143 accuracy: 0.471
[27,  2200] loss: 0.141 accuracy: 0.476
[27,  2400] loss: 0.145 accuracy: 0.467
[27,  2600] loss: 0.141 accuracy: 0.477
[27,  2800] loss: 0.143 accuracy: 0.467
[27,  3000] loss: 0.142 accuracy: 0.473
[27,  3200] loss: 0.142 accuracy: 0.474
[27,  3400] loss: 0.144 accuracy: 0.468
[27,  3600] loss: 0.142 accuracy: 0.476
[27,  3800] loss: 0.144 accuracy: 0.467
[27,  4000] loss: 0.144 accuracy: 0.462
[27,  4200] loss: 0.143 accuracy: 0.475
[27,  4400] loss: 0.145 accuracy: 0.461
[27,  4600] loss: 0.142 accuracy: 0.471
[27,  4800] loss: 0.143 accuracy: 0.472
[27,  5000] loss: 0.142 accuracy: 0.466


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.479890098810244


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[28,   200] loss: 0.142 accuracy: 0.473
[28,   400] loss: 0.143 accuracy: 0.471
[28,   600] loss: 0.143 accuracy: 0.470
[28,   800] loss: 0.142 accuracy: 0.473
[28,  1000] loss: 0.143 accuracy: 0.474
[28,  1200] loss: 0.141 accuracy: 0.483
[28,  1400] loss: 0.145 accuracy: 0.466
[28,  1600] loss: 0.142 accuracy: 0.477
[28,  1800] loss: 0.141 accuracy: 0.475
[28,  2000] loss: 0.142 accuracy: 0.469
[28,  2200] loss: 0.142 accuracy: 0.471
[28,  2400] loss: 0.141 accuracy: 0.473
[28,  2600] loss: 0.142 accuracy: 0.472
[28,  2800] loss: 0.142 accuracy: 0.470
[28,  3000] loss: 0.141 accuracy: 0.479
[28,  3200] loss: 0.143 accuracy: 0.469
[28,  3400] loss: 0.143 accuracy: 0.468
[28,  3600] loss: 0.143 accuracy: 0.475
[28,  3800] loss: 0.143 accuracy: 0.469
[28,  4000] loss: 0.143 accuracy: 0.470
[28,  4200] loss: 0.143 accuracy: 0.470
[28,  4400] loss: 0.142 accuracy: 0.473
[28,  4600] loss: 0.142 accuracy: 0.470
[28,  4800] loss: 0.144 accuracy: 0.462
[28,  5000] loss: 0.144 accuracy: 0.470


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[28,  6000] loss: 0.142 accuracy: 0.473
[28,  6200] loss: 0.142 accuracy: 0.471
[28,  6400] loss: 0.142 accuracy: 0.479
[28,  6600] loss: 0.144 accuracy: 0.473
[28,  6800] loss: 0.144 accuracy: 0.469
[28,  7000] loss: 0.142 accuracy: 0.474
[28,  7200] loss: 0.142 accuracy: 0.474
[28,  7400] loss: 0.142 accuracy: 0.474
[28,  7600] loss: 0.145 accuracy: 0.469
[28,  7800] loss: 0.143 accuracy: 0.466
[28,  8000] loss: 0.142 accuracy: 0.472
[28,  8200] loss: 0.142 accuracy: 0.476
[28,  8400] loss: 0.142 accuracy: 0.471
[28,  8600] loss: 0.144 accuracy: 0.475
[28,  8800] loss: 0.141 accuracy: 0.474
[28,  9000] loss: 0.143 accuracy: 0.468
[28,  9200] loss: 0.144 accuracy: 0.473
[28,  9400] loss: 0.142 accuracy: 0.470
[28,  9600] loss: 0.141 accuracy: 0.472
[28,  9800] loss: 0.143 accuracy: 0.472
[28, 10000] loss: 0.142 accuracy: 0.476
[28, 10200] loss: 0.143 accuracy: 0.472
[28, 10400] loss: 0.142 accuracy: 0.474
[28, 10600] loss: 0.142 accuracy: 0.469
[28, 10800] loss: 0.142 accuracy: 0.473


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4879864892115346


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[29,   200] loss: 0.143 accuracy: 0.476
[29,   400] loss: 0.143 accuracy: 0.474
[29,   600] loss: 0.143 accuracy: 0.469
[29,   800] loss: 0.141 accuracy: 0.477
[29,  1000] loss: 0.143 accuracy: 0.473
[29,  1200] loss: 0.143 accuracy: 0.471
[29,  1400] loss: 0.141 accuracy: 0.474
[29,  1600] loss: 0.142 accuracy: 0.471
[29,  1800] loss: 0.142 accuracy: 0.471
[29,  2000] loss: 0.143 accuracy: 0.467
[29,  2200] loss: 0.143 accuracy: 0.471
[29,  2400] loss: 0.143 accuracy: 0.467
[29,  2600] loss: 0.140 accuracy: 0.486
[29,  2800] loss: 0.143 accuracy: 0.469
[29,  3000] loss: 0.141 accuracy: 0.472
[29,  3200] loss: 0.141 accuracy: 0.470
[29,  3400] loss: 0.143 accuracy: 0.471
[29,  3600] loss: 0.143 accuracy: 0.470
[29,  3800] loss: 0.142 accuracy: 0.470
[29,  4000] loss: 0.142 accuracy: 0.471
[29,  4200] loss: 0.142 accuracy: 0.476
[29,  4400] loss: 0.141 accuracy: 0.476
[29,  4600] loss: 0.143 accuracy: 0.468
[29,  4800] loss: 0.143 accuracy: 0.472
[29,  5000] loss: 0.142 accuracy: 0.466


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4880822746521476


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[30,   200] loss: 0.141 accuracy: 0.478
[30,   400] loss: 0.140 accuracy: 0.483
[30,   600] loss: 0.143 accuracy: 0.469
[30,   800] loss: 0.142 accuracy: 0.476
[30,  1000] loss: 0.142 accuracy: 0.474
[30,  1200] loss: 0.144 accuracy: 0.468
[30,  1400] loss: 0.142 accuracy: 0.475
[30,  1600] loss: 0.144 accuracy: 0.469
[30,  1800] loss: 0.141 accuracy: 0.480
[30,  2000] loss: 0.141 accuracy: 0.473
[30,  2200] loss: 0.143 accuracy: 0.469
[30,  2400] loss: 0.142 accuracy: 0.468
[30,  2600] loss: 0.144 accuracy: 0.469
[30,  2800] loss: 0.143 accuracy: 0.470
[30,  3000] loss: 0.141 accuracy: 0.477
[30,  3200] loss: 0.141 accuracy: 0.482
[30,  3400] loss: 0.142 accuracy: 0.469
[30,  3600] loss: 0.142 accuracy: 0.472
[30,  3800] loss: 0.143 accuracy: 0.474
[30,  4000] loss: 0.143 accuracy: 0.472
[30,  4200] loss: 0.141 accuracy: 0.479
[30,  4400] loss: 0.142 accuracy: 0.479
[30,  4600] loss: 0.143 accuracy: 0.466
[30,  4800] loss: 0.143 accuracy: 0.470
[30,  5000] loss: 0.142 accuracy: 0.473


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.4832526719096592


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[31,   200] loss: 0.141 accuracy: 0.480
[31,   400] loss: 0.140 accuracy: 0.479
[31,   600] loss: 0.142 accuracy: 0.471
[31,   800] loss: 0.141 accuracy: 0.480
[31,  1000] loss: 0.142 accuracy: 0.474
[31,  1200] loss: 0.143 accuracy: 0.465
[31,  1400] loss: 0.144 accuracy: 0.465
[31,  1600] loss: 0.143 accuracy: 0.476
[31,  1800] loss: 0.141 accuracy: 0.482
[31,  2000] loss: 0.141 accuracy: 0.483
[31,  2200] loss: 0.143 accuracy: 0.471
[31,  2400] loss: 0.142 accuracy: 0.479
[31,  2600] loss: 0.141 accuracy: 0.476
[31,  2800] loss: 0.141 accuracy: 0.480
[31,  3000] loss: 0.143 accuracy: 0.465
[31,  3200] loss: 0.141 accuracy: 0.476
[31,  3400] loss: 0.143 accuracy: 0.470
[31,  3600] loss: 0.144 accuracy: 0.470
[31,  3800] loss: 0.144 accuracy: 0.467
[31,  4000] loss: 0.142 accuracy: 0.477
[31,  4200] loss: 0.142 accuracy: 0.476
[31,  4400] loss: 0.142 accuracy: 0.475
[31,  4600] loss: 0.142 accuracy: 0.474
[31,  4800] loss: 0.142 accuracy: 0.474
[31,  5000] loss: 0.142 accuracy: 0.475


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.48649425287356324


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[32,   200] loss: 0.142 accuracy: 0.479
[32,   400] loss: 0.142 accuracy: 0.473
[32,   600] loss: 0.142 accuracy: 0.476
[32,   800] loss: 0.141 accuracy: 0.473
[32,  1000] loss: 0.142 accuracy: 0.474
[32,  1200] loss: 0.143 accuracy: 0.466
[32,  1400] loss: 0.138 accuracy: 0.485
[32,  1600] loss: 0.142 accuracy: 0.471
[32,  1800] loss: 0.142 accuracy: 0.473
[32,  2000] loss: 0.141 accuracy: 0.485
[32,  2200] loss: 0.142 accuracy: 0.472
[32,  2400] loss: 0.141 accuracy: 0.476
[32,  2600] loss: 0.141 accuracy: 0.482
[32,  2800] loss: 0.142 accuracy: 0.474
[32,  3000] loss: 0.143 accuracy: 0.465
[32,  3200] loss: 0.141 accuracy: 0.478
[32,  3400] loss: 0.142 accuracy: 0.473
[32,  3600] loss: 0.141 accuracy: 0.476
[32,  3800] loss: 0.143 accuracy: 0.469
[32,  4000] loss: 0.143 accuracy: 0.470
[32,  4200] loss: 0.144 accuracy: 0.467
[32,  4400] loss: 0.143 accuracy: 0.473
[32,  4600] loss: 0.142 accuracy: 0.475
[32,  4800] loss: 0.142 accuracy: 0.470
[32,  5000] loss: 0.143 accuracy: 0.467


HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.48944847751562814


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[33,   200] loss: 0.141 accuracy: 0.475
[33,   400] loss: 0.142 accuracy: 0.477
[33,   600] loss: 0.140 accuracy: 0.482
[33,   800] loss: 0.142 accuracy: 0.466
[33,  1000] loss: 0.141 accuracy: 0.477
[33,  1200] loss: 0.141 accuracy: 0.475
[33,  1400] loss: 0.143 accuracy: 0.473
[33,  1600] loss: 0.142 accuracy: 0.481
[33,  1800] loss: 0.141 accuracy: 0.478
[33,  2000] loss: 0.142 accuracy: 0.478
[33,  2200] loss: 0.141 accuracy: 0.475
[33,  2400] loss: 0.144 accuracy: 0.464
[33,  2600] loss: 0.141 accuracy: 0.473
[33,  2800] loss: 0.141 accuracy: 0.475
[33,  3000] loss: 0.143 accuracy: 0.472
[33,  3200] loss: 0.141 accuracy: 0.479
[33,  3400] loss: 0.143 accuracy: 0.470
[33,  3600] loss: 0.144 accuracy: 0.468
[33,  3800] loss: 0.142 accuracy: 0.468
[33,  4000] loss: 0.144 accuracy: 0.471
[33,  4200] loss: 0.143 accuracy: 0.469
[33,  4400] loss: 0.142 accuracy: 0.471
[33,  4600] loss: 0.140 accuracy: 0.481
[33,  4800] loss: 0.141 accuracy: 0.477
[33,  5000] loss: 0.142 accuracy: 0.477


KeyboardInterrupt: 

In [12]:
# Evaluate the best saved checkpoint on the held-out test set.
# map_location restores the weights directly onto the GPU selected earlier
# (`device`), regardless of which device the checkpoint was saved from.
checkpoint = torch.load('models/best-lstm.pth', map_location=torch.device('cuda', device))
overnet.load_state_dict(checkpoint['model_state_dict'])

# Batch order does not affect aggregate accuracy, so no shuffling is needed here.
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, pin_memory = True, num_workers = 16)

corrects = 0
overnet.eval()  # switch once before the loop (e.g. disables dropout), not per batch
try:
    # Inference only: skip building the autograd graph to save GPU memory and time.
    with torch.no_grad():
        for spec, target in tqdm(testloader):
            spec, target = spec.float(), target.float()
            if half:
                spec, target = spec.half(), target.half()
            # pin_memory=True above allows asynchronous host-to-device copies.
            spec = spec.cuda(device, non_blocking=True)
            target = target.cuda(device, non_blocking=True)
            out = overnet(spec)
            corrects += compute_corrects(out, target)
finally:
    # Restore training mode even if evaluation is interrupted, so a later
    # training cell does not silently run with dropout disabled.
    overnet.train()
print('test acc:', corrects/len(testset))

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


test acc: 0.49054748941318815
