In [1]:
import librosa
import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import time
import glob
from lxml.html import parse
from sphfile import SPHFile
import pydub
import audiosegment
import pandas as pd
from collections import Counter
from bs4 import BeautifulSoup
import sys
import os
from tqdm.notebook import tqdm
class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)
sr = 16000
dropout = 0.3
half = False

In [2]:
device = 3
torch.cuda.set_device(device)

In [3]:
class OverlayDataSet(torch.utils.data.Dataset):
    def __init__(self, csv, compute_original = False):
        super().__init__()
        self.overlays = pd.read_csv(csv)
        self.speakers = list(set(self.overlays['first_speaker']).union(set(self.overlays['second_speaker'])))
        self.speakers.sort()
        self.spkr2idx = {spkr:i for i, spkr in enumerate(self.speakers)}
        self.compute_original = compute_original
    def __len__(self):
        return len(self.overlays)
    def __getitem__(self, idx):
        overlay = self.overlays.iloc[idx]
        first_segment = np.load(overlay['first_file'])/(2**15)
        second_segment = np.load(overlay['second_file'])/(2**15)
        #padding to compensate rounding errors
        if len(first_segment)>len(second_segment):
            padding = np.zeros(len(first_segment)-len(second_segment))
            second_segment = np.concatenate((second_segment, padding))
        
        if len(first_segment)<len(second_segment):
            padding = np.zeros(len(second_segment)-len(first_segment))
            first_segment = np.concatenate((first_segment, padding))
        
        
        first_idx  = self.spkr2idx[overlay['first_speaker']]
        second_idx = self.spkr2idx[overlay['second_speaker']]
        target = np.zeros(len(self.speakers))
        target[first_idx] = 1.0
        target[second_idx] = 1.0
        if self.compute_original:
            return self.make_spectrogram(first_segment), self.make_spectrogram(second_segment),\
                self.make_spectrogram(first_segment+second_segment), target
        else:
            return self.make_spectrogram(first_segment+second_segment), target
    def make_spectrogram(self, segment):
        segment = segment[50:-50] # make size 200
        S = librosa.feature.melspectrogram(segment, n_mels = 256, n_fft = 1024, hop_length = 160) # 32 ms window, 10 ms hop
        S_dB = librosa.power_to_db(S).T[None, :, :] # add channel dimension
        return S_dB

trainset = OverlayDataSet('overlay-train.csv', False)
valset = OverlayDataSet('overlay-val.csv', False)
testset = OverlayDataSet('overlay-test.csv', False)
spec3, target = trainset[0]
plt.figure(figsize = (20, 6))
if trainset.compute_original:
    plt.subplot(131)
    plt.imshow(spec1[0].T)
    plt.subplot(132)
    plt.imshow(spec2[0].T)
    plt.subplot(133)
    plt.imshow(spec3[0].T)
print(spec3.max(), spec3.min(), spec3.shape)
print(trainset.speakers)

27.971525 -52.028473 (1, 200, 256)
['andrea_arsenault', 'brian_lamb', 'csp_waj_susan', 'david_brancaccio', 'eddie_mair', 'joie_chen', 'kathleen_kennedy', 'leon_harris', 'linda_wertheimer', 'linden_soles', 'lisa_mullins', 'lou_waters', 'lynn_vaughan', 'mark_mullen', 'natalie_allen', 'noah_adams', 'peter_jennings', 'robert_siegel', 'ted_koppel', 'thalia_assuras']


<Figure size 1440x432 with 0 Axes>

## Maybe try drastically increasing channel number in residual attention stage to see if it overfits

In [4]:
# import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
import numpy as np

class ResidualBlock(nn.Module):
    def __init__(self, input_channels, output_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.stride = stride
        self.bn1 = nn.BatchNorm2d(input_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(input_channels, output_channels, 1, 1, bias = False)
        self.bn2 = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(output_channels, output_channels, 3, stride, padding = 1, bias = False)
        self.bn3 = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(output_channels, output_channels, 1, 1, bias = False)
        self.conv4 = nn.Conv2d(input_channels, output_channels , 1, stride, bias = False)
        
    def forward(self, x):
        residual = x
        out = self.bn1(x)
        out1 = self.relu(out)
        out = self.conv1(out1)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)
        if (self.input_channels != self.output_channels) or (self.stride !=1 ):
            residual = self.conv4(out1)
        out += residual
        return out

class AttentionModule_stage1(nn.Module):
    # input size is 56*56
    def __init__(self, in_channels, out_channels, size1=(200, 128), size2=(100, 64), size3=(50, 32)):
        super(AttentionModule_stage1, self).__init__()
        self.first_residual_blocks = ResidualBlock(in_channels, out_channels)

        self.trunk_branches = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels)
         )

        self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.softmax1_blocks = ResidualBlock(in_channels, out_channels)

        self.skip1_connection_residual_block = ResidualBlock(in_channels, out_channels)

        self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.softmax2_blocks = ResidualBlock(in_channels, out_channels)

        self.skip2_connection_residual_block = ResidualBlock(in_channels, out_channels)

        self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.softmax3_blocks = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels)
        )

        self.interpolation3 = nn.UpsamplingBilinear2d(size=size3)

        self.softmax4_blocks = ResidualBlock(in_channels, out_channels)

        self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)

        self.softmax5_blocks = ResidualBlock(in_channels, out_channels)

        self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)

        self.softmax6_blocks = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels , kernel_size = 1, stride = 1, bias = False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels , kernel_size = 1, stride = 1, bias = False),
            nn.Sigmoid()
        )

        self.last_blocks = ResidualBlock(in_channels, out_channels)

    def forward(self, x):
        #batch_size, nheads, length, n_mels = x.shape
        x = self.first_residual_blocks(x)
        out_trunk = self.trunk_branches(x)
        out_mpool1 = self.mpool1(x) # 100x64
        out_softmax1 = self.softmax1_blocks(out_mpool1)
        out_skip1_connection = self.skip1_connection_residual_block(out_softmax1)
        out_mpool2 = self.mpool2(out_softmax1) # 50x32
        out_softmax2 = self.softmax2_blocks(out_mpool2)
        out_skip2_connection = self.skip2_connection_residual_block(out_softmax2)
        out_mpool3 = self.mpool3(out_softmax2) # 25x16
        out_softmax3 = self.softmax3_blocks(out_mpool3) 
        out_interp3 = self.interpolation3(out_softmax3) + out_softmax2
        out = out_interp3 + out_skip2_connection
        out_softmax4 = self.softmax4_blocks(out)
        out_interp2 = self.interpolation2(out_softmax4) + out_softmax1
        out = out_interp2 + out_skip1_connection
        out_softmax5 = self.softmax5_blocks(out)
        out_interp1 = self.interpolation1(out_softmax5) + out_trunk
        out_softmax6 = self.softmax6_blocks(out_interp1)
        out = (1 + out_softmax6) * out_trunk
        out_last = self.last_blocks(out)
        return out_last

num_heads = 8 # Residual Attention Channels
num_heads_2 = 4 # MHA heads


class OverlayNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln = nn.LayerNorm(256)
        self.downsample1 = ResidualBlock(1, num_heads, (1, 2))
        self.res_att = AttentionModule_stage1(num_heads, num_heads)  # batch_size * num_heads * L *128
        self.downsample2 = nn.Sequential(ResidualBlock(num_heads, num_heads//2),
                                        ResidualBlock(num_heads//2, 2),
                                        ResidualBlock(2, 2, (1, 2)))
        self.reshape =  Lambda(lambda x: x.permute((1, 2, 0, 3))) # L * batch_size * (num_heads*128)
        self.lstm = nn.LSTM(64, 32, 2, batch_first = False, bidirectional = True, dropout = dropout) # L * batch_size * 200 * n_hidden
        self.fc1 = nn.Linear(64, 32)
        self.average = Lambda(lambda x: x.mean(dim = 0)) # batch * n_hidden
        self.tanh = nn.Tanh()
        #self.norm = Lambda(lambda x: torch.nn.functional.normalize(x, p = 2, dim = 1)) # L2 normalize across n_hidden
        self.fc2 = nn.Linear(32, 20)
        self.softmax = nn.Softmax(1)
    def forward(self, X):
        X = self.ln(X)
        X = self.downsample1(X)
        #print('first downsample ', X.shape)
        X = self.res_att(X)
        #print('residual attention ', X.shape)
        X = self.downsample2(X)
        #print('second downsample ', X.shape)
        X = self.reshape(X)
        X1, X2 = X[0], X[1]
        X1,_ = self.lstm(X1)
        X2,_ = self.lstm(X1)
        #print('lstm ', X.shape)
        X1 = self.fc1(X1)
        X2 = self.fc1(X2)
        #print('dense ', X.shape)
        X1 = self.average(X1)
        X2 = self.average(X2)
        #print('mean ', X.shape)
        X1 = self.tanh(X1)
        X2 = self.tanh(X2)
        X1 = self.fc2(X1)
        X2 = self.fc2(X2)
        X1 = self.softmax(X1)
        X2 = self.softmax(X2)
        X = torch.stack([X1,X2], dim=0)
        X,_ = torch.max(X, dim=0)
        return X
    
    
overnet = OverlayNet().cuda(device)
    
# tune hidden layers smaller if overfit
optimizer = torch.optim.Adam(overnet.parameters(), 0.001)

if os.path.exists('models/resatt.pth'):
    print('load model')
    checkpoint = torch.load('models/resatt.pth')
    overnet.load_state_dict(checkpoint['model_state_dict'])
    try:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except:
        print('cannot load optimizer')
    loss = checkpoint['loss']
    if 'bestacc' in checkpoint:
        bestacc = checkpoint['bestacc']
    else:
        bestacc = 0.0
else:
    print('initializing new model')
    bestacc = 0.0
    
if half:
    overnet.half()  # convert to half precision
    for layer in overnet.modules():
        if isinstance(layer, nn.BatchNorm2d):
            layer.float()
            
overnet.train()
'bestacc:', bestacc

initializing new model


('bestacc:', 0.0)

## Also Do metrics on hitting a single person right
## Theoretical justification that when reducing number of channels, neural network could learn to pair similar spectrum representations with each other: only pitch lines that are greater than a certain threshold will pass relu and become postivie. Batchnorm makes this effect stronger. Therefore, when two channels represent spectrums of different people, when they are added the output will be nothing if it doesn't pass the relu threshold

In [5]:
batch_size = 32
def find_max2(tensor):
    array = tensor.cpu().detach().numpy()
    max2 = []
    for row in array:
        max2.append(np.argsort(row)[::-1][:2])
    return np.array(max2)

def compute_corrects(tensor1, tensor2):
    max_1, max_2 = find_max2(tensor1), find_max2(tensor2)
    batch_size = max_1.shape[0]
    batch_corrects = 0
    for i in range(batch_size):
        if Counter(max_1[i])==Counter(max_2[i]):
            batch_corrects+=1
    return batch_corrects


trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 16)
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True, pin_memory = True, num_workers = 16)
criterion = torch.nn.BCELoss()

for epoch in range(64):
    running_loss = 0.0
    running_accuracy = 0.0
    for batch_idx, (spec, target) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        spec, target = spec.float(), target.float()
        if half:
            spec, target = spec.half(),target.half()
        spec = spec.cuda(device)
        target = target.cuda(device)

        out = overnet(spec)
        loss = criterion(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(overnet.parameters(), 0.5)
        optimizer.step()
                
        running_loss += loss.item()
        running_accuracy += compute_corrects(out, target)/batch_size
        if batch_idx % 200 == 199:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f accuracy: %.3f' % 
                  (epoch + 1, batch_idx + 1, running_loss / 200, running_accuracy / 200))
            running_loss = 0.0
            running_accuracy = 0.0
            torch.save({
            'model_state_dict': overnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
            }, 'models/resatt.pth')

        
        
    corrects = 0
    for batch_idx, (spec, target) in enumerate(tqdm(valloader)):
        spec, target = spec.float(), target.float()
        if half:
            spec, target = spec.half(), target.half()
        spec = spec.cuda(device)
        target = target.cuda(device)
        overnet.eval()
        out = overnet(spec) 
        corrects += compute_corrects(out, target)
    print('val acc:', corrects/len(valset))
    if corrects/len(valset) > bestacc:
        bestacc = corrects/len(valset)
        torch.save({
        'model_state_dict': overnet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'bestacc': bestacc
        }, 'models/best-resatt.pth')
    overnet.train()
    pass

HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[1,   200] loss: 0.318 accuracy: 0.015
[1,   400] loss: 0.279 accuracy: 0.037
[1,   600] loss: 0.250 accuracy: 0.086
[1,   800] loss: 0.232 accuracy: 0.119
[1,  1000] loss: 0.219 accuracy: 0.154
[1,  1200] loss: 0.208 accuracy: 0.183
[1,  1400] loss: 0.199 accuracy: 0.204
[1,  1600] loss: 0.191 accuracy: 0.231
[1,  1800] loss: 0.186 accuracy: 0.253
[1,  2000] loss: 0.183 accuracy: 0.261
[1,  2200] loss: 0.176 accuracy: 0.291
[1,  2400] loss: 0.173 accuracy: 0.302
[1,  2600] loss: 0.167 accuracy: 0.329



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.3547186932849365


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[2,   200] loss: 0.159 accuracy: 0.358
[2,   400] loss: 0.153 accuracy: 0.386
[2,   600] loss: 0.148 accuracy: 0.407
[2,   800] loss: 0.146 accuracy: 0.415
[2,  1000] loss: 0.140 accuracy: 0.454
[2,  1200] loss: 0.136 accuracy: 0.467
[2,  1400] loss: 0.133 accuracy: 0.473
[2,  1600] loss: 0.129 accuracy: 0.495
[2,  1800] loss: 0.128 accuracy: 0.498
[2,  2000] loss: 0.123 accuracy: 0.518
[2,  2200] loss: 0.120 accuracy: 0.535
[2,  2400] loss: 0.117 accuracy: 0.548
[2,  2600] loss: 0.114 accuracy: 0.551



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.5586206896551724


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[3,   200] loss: 0.109 accuracy: 0.577
[3,   400] loss: 0.106 accuracy: 0.602
[3,   600] loss: 0.105 accuracy: 0.599
[3,   800] loss: 0.106 accuracy: 0.602
[3,  1000] loss: 0.102 accuracy: 0.612
[3,  1200] loss: 0.100 accuracy: 0.624
[3,  1400] loss: 0.099 accuracy: 0.625
[3,  1600] loss: 0.096 accuracy: 0.638
[3,  1800] loss: 0.097 accuracy: 0.638
[3,  2000] loss: 0.094 accuracy: 0.653
[3,  2200] loss: 0.090 accuracy: 0.661
[3,  2400] loss: 0.093 accuracy: 0.652
[3,  2600] loss: 0.092 accuracy: 0.650



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.6295825771324864


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[4,   200] loss: 0.085 accuracy: 0.683
[4,   400] loss: 0.084 accuracy: 0.688
[4,   600] loss: 0.083 accuracy: 0.690
[4,   800] loss: 0.083 accuracy: 0.689
[4,  1000] loss: 0.081 accuracy: 0.699
[4,  1200] loss: 0.082 accuracy: 0.692
[4,  1400] loss: 0.082 accuracy: 0.693
[4,  1600] loss: 0.079 accuracy: 0.700
[4,  1800] loss: 0.077 accuracy: 0.721
[4,  2000] loss: 0.078 accuracy: 0.709
[4,  2200] loss: 0.077 accuracy: 0.718
[4,  2400] loss: 0.076 accuracy: 0.725
[4,  2600] loss: 0.077 accuracy: 0.719



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7349364791288566


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[5,   200] loss: 0.070 accuracy: 0.746
[5,   400] loss: 0.069 accuracy: 0.747
[5,   600] loss: 0.071 accuracy: 0.738
[5,   800] loss: 0.070 accuracy: 0.741
[5,  1000] loss: 0.070 accuracy: 0.743
[5,  1200] loss: 0.068 accuracy: 0.755
[5,  1400] loss: 0.067 accuracy: 0.761
[5,  1600] loss: 0.069 accuracy: 0.750
[5,  1800] loss: 0.067 accuracy: 0.763
[5,  2000] loss: 0.070 accuracy: 0.740
[5,  2200] loss: 0.067 accuracy: 0.757
[5,  2400] loss: 0.068 accuracy: 0.752
[5,  2600] loss: 0.066 accuracy: 0.757



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7561705989110707


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[6,   200] loss: 0.061 accuracy: 0.782
[6,   400] loss: 0.061 accuracy: 0.781
[6,   600] loss: 0.061 accuracy: 0.777
[6,   800] loss: 0.062 accuracy: 0.773
[6,  1000] loss: 0.061 accuracy: 0.778
[6,  1200] loss: 0.060 accuracy: 0.789
[6,  1400] loss: 0.060 accuracy: 0.789
[6,  1600] loss: 0.058 accuracy: 0.784
[6,  1800] loss: 0.060 accuracy: 0.781
[6,  2000] loss: 0.058 accuracy: 0.787
[6,  2200] loss: 0.057 accuracy: 0.793
[6,  2400] loss: 0.058 accuracy: 0.788
[6,  2600] loss: 0.059 accuracy: 0.789



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7927404718693285


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[7,   200] loss: 0.055 accuracy: 0.796
[7,   400] loss: 0.051 accuracy: 0.817
[7,   600] loss: 0.053 accuracy: 0.807
[7,   800] loss: 0.055 accuracy: 0.803
[7,  1000] loss: 0.055 accuracy: 0.793
[7,  1200] loss: 0.052 accuracy: 0.811
[7,  1400] loss: 0.053 accuracy: 0.807
[7,  1600] loss: 0.055 accuracy: 0.802
[7,  1800] loss: 0.053 accuracy: 0.806
[7,  2000] loss: 0.053 accuracy: 0.812
[7,  2200] loss: 0.053 accuracy: 0.810
[7,  2400] loss: 0.050 accuracy: 0.818
[7,  2600] loss: 0.053 accuracy: 0.813



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.7843920145190563


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[8,   200] loss: 0.049 accuracy: 0.821
[8,   400] loss: 0.047 accuracy: 0.830
[8,   600] loss: 0.048 accuracy: 0.823
[8,   800] loss: 0.047 accuracy: 0.835
[8,  1000] loss: 0.049 accuracy: 0.823
[8,  1200] loss: 0.049 accuracy: 0.826
[8,  1400] loss: 0.047 accuracy: 0.830
[8,  1600] loss: 0.047 accuracy: 0.828
[8,  1800] loss: 0.048 accuracy: 0.825
[8,  2000] loss: 0.048 accuracy: 0.834
[8,  2200] loss: 0.048 accuracy: 0.829
[8,  2400] loss: 0.047 accuracy: 0.831
[8,  2600] loss: 0.049 accuracy: 0.822



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8225045372050817


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[9,   200] loss: 0.044 accuracy: 0.845
[9,   400] loss: 0.044 accuracy: 0.843
[9,   600] loss: 0.046 accuracy: 0.835
[9,   800] loss: 0.046 accuracy: 0.837
[9,  1000] loss: 0.042 accuracy: 0.850
[9,  1200] loss: 0.044 accuracy: 0.846
[9,  1400] loss: 0.045 accuracy: 0.838
[9,  1600] loss: 0.044 accuracy: 0.840
[9,  1800] loss: 0.044 accuracy: 0.842
[9,  2000] loss: 0.043 accuracy: 0.848
[9,  2200] loss: 0.046 accuracy: 0.841
[9,  2400] loss: 0.043 accuracy: 0.849
[9,  2600] loss: 0.042 accuracy: 0.853



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8078947368421052


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[10,   200] loss: 0.040 accuracy: 0.862
[10,   400] loss: 0.040 accuracy: 0.858
[10,   600] loss: 0.041 accuracy: 0.857
[10,   800] loss: 0.041 accuracy: 0.851
[10,  1000] loss: 0.041 accuracy: 0.851
[10,  1200] loss: 0.041 accuracy: 0.851
[10,  1400] loss: 0.040 accuracy: 0.855
[10,  1600] loss: 0.041 accuracy: 0.852
[10,  1800] loss: 0.042 accuracy: 0.848
[10,  2000] loss: 0.040 accuracy: 0.855
[10,  2200] loss: 0.040 accuracy: 0.860
[10,  2400] loss: 0.042 accuracy: 0.849
[10,  2600] loss: 0.041 accuracy: 0.851



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8359346642468239


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[11,   200] loss: 0.036 accuracy: 0.873
[11,   400] loss: 0.037 accuracy: 0.870
[11,   600] loss: 0.037 accuracy: 0.866
[11,   800] loss: 0.038 accuracy: 0.864
[11,  1000] loss: 0.036 accuracy: 0.869
[11,  1200] loss: 0.039 accuracy: 0.860
[11,  1400] loss: 0.038 accuracy: 0.863
[11,  1600] loss: 0.037 accuracy: 0.870
[11,  1800] loss: 0.038 accuracy: 0.868
[11,  2000] loss: 0.038 accuracy: 0.864
[11,  2200] loss: 0.037 accuracy: 0.864
[11,  2400] loss: 0.037 accuracy: 0.871
[11,  2600] loss: 0.038 accuracy: 0.866



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8459165154264973


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[12,   200] loss: 0.033 accuracy: 0.887
[12,   400] loss: 0.034 accuracy: 0.877
[12,   600] loss: 0.033 accuracy: 0.883
[12,   800] loss: 0.035 accuracy: 0.874
[12,  1000] loss: 0.036 accuracy: 0.873
[12,  1200] loss: 0.037 accuracy: 0.870
[12,  1400] loss: 0.035 accuracy: 0.877
[12,  1600] loss: 0.033 accuracy: 0.881
[12,  1800] loss: 0.036 accuracy: 0.868
[12,  2000] loss: 0.036 accuracy: 0.875
[12,  2200] loss: 0.035 accuracy: 0.878
[12,  2400] loss: 0.036 accuracy: 0.873
[12,  2600] loss: 0.035 accuracy: 0.873



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8575317604355717


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[13,   200] loss: 0.032 accuracy: 0.889
[13,   400] loss: 0.032 accuracy: 0.890
[13,   600] loss: 0.033 accuracy: 0.884
[13,   800] loss: 0.032 accuracy: 0.889
[13,  1000] loss: 0.034 accuracy: 0.879
[13,  1200] loss: 0.034 accuracy: 0.880
[13,  1400] loss: 0.033 accuracy: 0.881
[13,  1600] loss: 0.032 accuracy: 0.883
[13,  1800] loss: 0.034 accuracy: 0.880
[13,  2000] loss: 0.033 accuracy: 0.882
[13,  2200] loss: 0.032 accuracy: 0.884
[13,  2400] loss: 0.033 accuracy: 0.880
[13,  2600] loss: 0.032 accuracy: 0.885



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8596188747731397


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[14,   200] loss: 0.030 accuracy: 0.897
[14,   400] loss: 0.029 accuracy: 0.899
[14,   600] loss: 0.031 accuracy: 0.893
[14,   800] loss: 0.029 accuracy: 0.900
[14,  1000] loss: 0.032 accuracy: 0.885
[14,  1200] loss: 0.030 accuracy: 0.891
[14,  1400] loss: 0.031 accuracy: 0.886
[14,  1600] loss: 0.030 accuracy: 0.890
[14,  1800] loss: 0.033 accuracy: 0.883
[14,  2000] loss: 0.031 accuracy: 0.888
[14,  2200] loss: 0.030 accuracy: 0.895
[14,  2400] loss: 0.030 accuracy: 0.891
[14,  2600] loss: 0.033 accuracy: 0.882



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8642468239564428


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[15,   200] loss: 0.028 accuracy: 0.901
[15,   400] loss: 0.028 accuracy: 0.901
[15,   600] loss: 0.029 accuracy: 0.897
[15,   800] loss: 0.029 accuracy: 0.897
[15,  1000] loss: 0.029 accuracy: 0.900
[15,  1200] loss: 0.029 accuracy: 0.892
[15,  1400] loss: 0.030 accuracy: 0.897
[15,  1600] loss: 0.029 accuracy: 0.898
[15,  1800] loss: 0.030 accuracy: 0.897
[15,  2000] loss: 0.030 accuracy: 0.892
[15,  2200] loss: 0.030 accuracy: 0.890
[15,  2400] loss: 0.030 accuracy: 0.898
[15,  2600] loss: 0.030 accuracy: 0.893



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.867513611615245


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[16,   200] loss: 0.026 accuracy: 0.906
[16,   400] loss: 0.027 accuracy: 0.903
[16,   600] loss: 0.027 accuracy: 0.906
[16,   800] loss: 0.026 accuracy: 0.907
[16,  1000] loss: 0.026 accuracy: 0.910
[16,  1200] loss: 0.028 accuracy: 0.898
[16,  1400] loss: 0.027 accuracy: 0.903
[16,  1600] loss: 0.031 accuracy: 0.890
[16,  1800] loss: 0.028 accuracy: 0.903
[16,  2000] loss: 0.028 accuracy: 0.903
[16,  2200] loss: 0.027 accuracy: 0.902
[16,  2400] loss: 0.030 accuracy: 0.895
[16,  2600] loss: 0.027 accuracy: 0.902



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8717785843920145


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[17,   200] loss: 0.026 accuracy: 0.905
[17,   400] loss: 0.027 accuracy: 0.906
[17,   600] loss: 0.026 accuracy: 0.910
[17,   800] loss: 0.026 accuracy: 0.904
[17,  1000] loss: 0.026 accuracy: 0.913
[17,  1200] loss: 0.026 accuracy: 0.905
[17,  1400] loss: 0.027 accuracy: 0.906
[17,  1600] loss: 0.027 accuracy: 0.905
[17,  1800] loss: 0.027 accuracy: 0.902
[17,  2000] loss: 0.027 accuracy: 0.910
[17,  2200] loss: 0.026 accuracy: 0.909
[17,  2400] loss: 0.026 accuracy: 0.909
[17,  2600] loss: 0.026 accuracy: 0.910



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8696914700544465


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[18,   200] loss: 0.024 accuracy: 0.916
[18,   400] loss: 0.023 accuracy: 0.918
[18,   600] loss: 0.024 accuracy: 0.918
[18,   800] loss: 0.026 accuracy: 0.906
[18,  1000] loss: 0.026 accuracy: 0.906
[18,  1200] loss: 0.026 accuracy: 0.906
[18,  1400] loss: 0.024 accuracy: 0.910
[18,  1600] loss: 0.025 accuracy: 0.907
[18,  1800] loss: 0.025 accuracy: 0.910
[18,  2000] loss: 0.026 accuracy: 0.910
[18,  2200] loss: 0.027 accuracy: 0.906
[18,  2400] loss: 0.025 accuracy: 0.909
[18,  2600] loss: 0.026 accuracy: 0.905



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.880399274047187


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[19,   200] loss: 0.023 accuracy: 0.921
[19,   400] loss: 0.024 accuracy: 0.916
[19,   600] loss: 0.024 accuracy: 0.914
[19,   800] loss: 0.023 accuracy: 0.922
[19,  1000] loss: 0.024 accuracy: 0.914
[19,  1200] loss: 0.024 accuracy: 0.917
[19,  1400] loss: 0.026 accuracy: 0.907
[19,  1600] loss: 0.024 accuracy: 0.913
[19,  1800] loss: 0.025 accuracy: 0.910
[19,  2000] loss: 0.026 accuracy: 0.908
[19,  2200] loss: 0.024 accuracy: 0.915
[19,  2400] loss: 0.025 accuracy: 0.914
[19,  2600] loss: 0.026 accuracy: 0.908



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8737749546279492


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[20,   200] loss: 0.022 accuracy: 0.924
[20,   400] loss: 0.021 accuracy: 0.924
[20,   600] loss: 0.023 accuracy: 0.916
[20,   800] loss: 0.023 accuracy: 0.920
[20,  1000] loss: 0.023 accuracy: 0.915
[20,  1200] loss: 0.023 accuracy: 0.920
[20,  1400] loss: 0.023 accuracy: 0.913
[20,  1600] loss: 0.025 accuracy: 0.906
[20,  1800] loss: 0.024 accuracy: 0.915
[20,  2000] loss: 0.023 accuracy: 0.916
[20,  2200] loss: 0.025 accuracy: 0.914
[20,  2400] loss: 0.023 accuracy: 0.920
[20,  2600] loss: 0.024 accuracy: 0.916



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8733212341197822


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[21,   200] loss: 0.021 accuracy: 0.930
[21,   400] loss: 0.020 accuracy: 0.927
[21,   600] loss: 0.024 accuracy: 0.915
[21,   800] loss: 0.021 accuracy: 0.928
[21,  1000] loss: 0.022 accuracy: 0.919
[21,  1200] loss: 0.023 accuracy: 0.917
[21,  1400] loss: 0.022 accuracy: 0.920
[21,  1600] loss: 0.022 accuracy: 0.922
[21,  1800] loss: 0.021 accuracy: 0.926
[21,  2000] loss: 0.025 accuracy: 0.914
[21,  2200] loss: 0.023 accuracy: 0.918
[21,  2400] loss: 0.023 accuracy: 0.915
[21,  2600] loss: 0.023 accuracy: 0.922



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8828493647912886


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[22,   200] loss: 0.021 accuracy: 0.927
[22,   400] loss: 0.020 accuracy: 0.930
[22,   600] loss: 0.020 accuracy: 0.929
[22,   800] loss: 0.022 accuracy: 0.925
[22,  1000] loss: 0.022 accuracy: 0.921
[22,  1200] loss: 0.023 accuracy: 0.918
[22,  1400] loss: 0.022 accuracy: 0.924
[22,  1600] loss: 0.021 accuracy: 0.928
[22,  1800] loss: 0.021 accuracy: 0.923
[22,  2000] loss: 0.022 accuracy: 0.918
[22,  2200] loss: 0.022 accuracy: 0.921
[22,  2400] loss: 0.022 accuracy: 0.920
[22,  2600] loss: 0.022 accuracy: 0.922



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.888021778584392


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[23,   200] loss: 0.019 accuracy: 0.938
[23,   400] loss: 0.020 accuracy: 0.930
[23,   600] loss: 0.021 accuracy: 0.926
[23,   800] loss: 0.020 accuracy: 0.925
[23,  1000] loss: 0.021 accuracy: 0.925
[23,  1200] loss: 0.021 accuracy: 0.926
[23,  1400] loss: 0.021 accuracy: 0.923
[23,  1600] loss: 0.021 accuracy: 0.924
[23,  1800] loss: 0.022 accuracy: 0.921
[23,  2000] loss: 0.022 accuracy: 0.920
[23,  2200] loss: 0.022 accuracy: 0.923
[23,  2400] loss: 0.021 accuracy: 0.927
[23,  2600] loss: 0.024 accuracy: 0.918



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8876588021778584


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[24,   200] loss: 0.018 accuracy: 0.937
[24,   400] loss: 0.019 accuracy: 0.932
[24,   600] loss: 0.019 accuracy: 0.934
[24,   800] loss: 0.019 accuracy: 0.933
[24,  1000] loss: 0.022 accuracy: 0.925
[24,  1200] loss: 0.021 accuracy: 0.927
[24,  1400] loss: 0.021 accuracy: 0.926
[24,  1600] loss: 0.021 accuracy: 0.926
[24,  1800] loss: 0.020 accuracy: 0.928
[24,  2000] loss: 0.021 accuracy: 0.926
[24,  2200] loss: 0.020 accuracy: 0.926
[24,  2400] loss: 0.022 accuracy: 0.926
[24,  2600] loss: 0.021 accuracy: 0.927



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8882032667876588


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[25,   200] loss: 0.017 accuracy: 0.942
[25,   400] loss: 0.018 accuracy: 0.940
[25,   600] loss: 0.020 accuracy: 0.928
[25,   800] loss: 0.021 accuracy: 0.925
[25,  1000] loss: 0.019 accuracy: 0.934
[25,  1200] loss: 0.020 accuracy: 0.931
[25,  1400] loss: 0.019 accuracy: 0.931
[25,  1600] loss: 0.019 accuracy: 0.935
[25,  1800] loss: 0.021 accuracy: 0.926
[25,  2000] loss: 0.020 accuracy: 0.927
[25,  2200] loss: 0.020 accuracy: 0.932
[25,  2400] loss: 0.020 accuracy: 0.928
[25,  2600] loss: 0.020 accuracy: 0.929



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8874773139745916


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[26,   200] loss: 0.017 accuracy: 0.943
[26,   400] loss: 0.017 accuracy: 0.940
[26,   600] loss: 0.018 accuracy: 0.936
[26,   800] loss: 0.020 accuracy: 0.927
[26,  1000] loss: 0.019 accuracy: 0.930
[26,  1200] loss: 0.019 accuracy: 0.931
[26,  1400] loss: 0.019 accuracy: 0.934
[26,  1600] loss: 0.020 accuracy: 0.930
[26,  1800] loss: 0.019 accuracy: 0.936
[26,  2000] loss: 0.020 accuracy: 0.927
[26,  2200] loss: 0.020 accuracy: 0.931
[26,  2400] loss: 0.019 accuracy: 0.934
[26,  2600] loss: 0.021 accuracy: 0.924



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8873865698729583


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[27,   200] loss: 0.018 accuracy: 0.941
[27,   400] loss: 0.018 accuracy: 0.937
[27,   600] loss: 0.018 accuracy: 0.939
[27,   800] loss: 0.017 accuracy: 0.943
[27,  1000] loss: 0.017 accuracy: 0.940
[27,  1200] loss: 0.020 accuracy: 0.935
[27,  1400] loss: 0.019 accuracy: 0.933
[27,  1600] loss: 0.018 accuracy: 0.935
[27,  1800] loss: 0.019 accuracy: 0.934
[27,  2000] loss: 0.021 accuracy: 0.926
[27,  2200] loss: 0.020 accuracy: 0.931
[27,  2400] loss: 0.018 accuracy: 0.936
[27,  2600] loss: 0.019 accuracy: 0.931



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8753176043557169


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[28,   200] loss: 0.016 accuracy: 0.943
[28,   400] loss: 0.017 accuracy: 0.940
[28,   600] loss: 0.018 accuracy: 0.936
[28,   800] loss: 0.017 accuracy: 0.937
[28,  1000] loss: 0.017 accuracy: 0.940
[28,  1200] loss: 0.018 accuracy: 0.931
[28,  1400] loss: 0.019 accuracy: 0.933
[28,  1600] loss: 0.020 accuracy: 0.929
[28,  1800] loss: 0.019 accuracy: 0.934
[28,  2000] loss: 0.018 accuracy: 0.935
[28,  2200] loss: 0.018 accuracy: 0.937
[28,  2400] loss: 0.018 accuracy: 0.936
[28,  2600] loss: 0.018 accuracy: 0.936



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8903811252268603


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[29,   200] loss: 0.016 accuracy: 0.946
[29,   400] loss: 0.016 accuracy: 0.947
[29,   600] loss: 0.017 accuracy: 0.938
[29,   800] loss: 0.017 accuracy: 0.938
[29,  1000] loss: 0.017 accuracy: 0.940
[29,  1200] loss: 0.019 accuracy: 0.933
[29,  1400] loss: 0.017 accuracy: 0.939
[29,  1600] loss: 0.016 accuracy: 0.941
[29,  1800] loss: 0.018 accuracy: 0.935
[29,  2000] loss: 0.019 accuracy: 0.935
[29,  2200] loss: 0.019 accuracy: 0.933
[29,  2400] loss: 0.018 accuracy: 0.935
[29,  2600] loss: 0.018 accuracy: 0.939



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8836660617059892


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[30,   200] loss: 0.015 accuracy: 0.950
[30,   400] loss: 0.015 accuracy: 0.951
[30,   600] loss: 0.016 accuracy: 0.945
[30,   800] loss: 0.017 accuracy: 0.943
[30,  1000] loss: 0.017 accuracy: 0.938
[30,  1200] loss: 0.018 accuracy: 0.935
[30,  1400] loss: 0.018 accuracy: 0.938
[30,  1600] loss: 0.018 accuracy: 0.936
[30,  1800] loss: 0.017 accuracy: 0.936
[30,  2000] loss: 0.018 accuracy: 0.935
[30,  2200] loss: 0.018 accuracy: 0.942
[30,  2400] loss: 0.018 accuracy: 0.936
[30,  2600] loss: 0.019 accuracy: 0.937



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8929219600725953


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[31,   200] loss: 0.014 accuracy: 0.952
[31,   400] loss: 0.016 accuracy: 0.943
[31,   600] loss: 0.017 accuracy: 0.943
[31,   800] loss: 0.017 accuracy: 0.940
[31,  1000] loss: 0.015 accuracy: 0.947
[31,  1200] loss: 0.016 accuracy: 0.946
[31,  1400] loss: 0.018 accuracy: 0.934
[31,  1600] loss: 0.018 accuracy: 0.937
[31,  1800] loss: 0.018 accuracy: 0.936
[31,  2000] loss: 0.017 accuracy: 0.938
[31,  2200] loss: 0.019 accuracy: 0.933
[31,  2400] loss: 0.017 accuracy: 0.940
[31,  2600] loss: 0.018 accuracy: 0.936



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8996370235934664


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[32,   200] loss: 0.015 accuracy: 0.949
[32,   400] loss: 0.016 accuracy: 0.946
[32,   600] loss: 0.016 accuracy: 0.943
[32,   800] loss: 0.015 accuracy: 0.944
[32,  1000] loss: 0.015 accuracy: 0.947
[32,  1200] loss: 0.016 accuracy: 0.940
[32,  1400] loss: 0.016 accuracy: 0.944
[32,  1600] loss: 0.017 accuracy: 0.940
[32,  1800] loss: 0.018 accuracy: 0.934
[32,  2000] loss: 0.017 accuracy: 0.942
[32,  2200] loss: 0.016 accuracy: 0.944
[32,  2400] loss: 0.017 accuracy: 0.938
[32,  2600] loss: 0.017 accuracy: 0.940



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8985480943738657


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[33,   200] loss: 0.015 accuracy: 0.950
[33,   400] loss: 0.016 accuracy: 0.945
[33,   600] loss: 0.015 accuracy: 0.946
[33,   800] loss: 0.015 accuracy: 0.946
[33,  1000] loss: 0.015 accuracy: 0.947
[33,  1200] loss: 0.016 accuracy: 0.941
[33,  1400] loss: 0.016 accuracy: 0.943
[33,  1600] loss: 0.016 accuracy: 0.944
[33,  1800] loss: 0.017 accuracy: 0.943
[33,  2000] loss: 0.017 accuracy: 0.939
[33,  2200] loss: 0.017 accuracy: 0.941
[33,  2400] loss: 0.018 accuracy: 0.935
[33,  2600] loss: 0.016 accuracy: 0.941



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8891107078039927


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[34,   200] loss: 0.014 accuracy: 0.953
[34,   400] loss: 0.014 accuracy: 0.950
[34,   600] loss: 0.015 accuracy: 0.945
[34,   800] loss: 0.016 accuracy: 0.939
[34,  1000] loss: 0.014 accuracy: 0.949
[34,  1200] loss: 0.014 accuracy: 0.950
[34,  1400] loss: 0.016 accuracy: 0.945
[34,  1600] loss: 0.016 accuracy: 0.941
[34,  1800] loss: 0.016 accuracy: 0.943
[34,  2000] loss: 0.017 accuracy: 0.938
[34,  2200] loss: 0.016 accuracy: 0.944
[34,  2400] loss: 0.015 accuracy: 0.948
[34,  2600] loss: 0.017 accuracy: 0.940



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8950998185117968


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[35,   200] loss: 0.014 accuracy: 0.954
[35,   400] loss: 0.016 accuracy: 0.944
[35,   600] loss: 0.014 accuracy: 0.953
[35,   800] loss: 0.015 accuracy: 0.945
[35,  1000] loss: 0.015 accuracy: 0.949
[35,  1200] loss: 0.015 accuracy: 0.948
[35,  1400] loss: 0.016 accuracy: 0.946
[35,  1600] loss: 0.014 accuracy: 0.951
[35,  1800] loss: 0.016 accuracy: 0.947
[35,  2000] loss: 0.015 accuracy: 0.948
[35,  2200] loss: 0.015 accuracy: 0.949
[35,  2400] loss: 0.015 accuracy: 0.945
[35,  2600] loss: 0.016 accuracy: 0.945



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9023593466424682


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[36,   200] loss: 0.012 accuracy: 0.960
[36,   400] loss: 0.013 accuracy: 0.955
[36,   600] loss: 0.015 accuracy: 0.949
[36,   800] loss: 0.013 accuracy: 0.956
[36,  1000] loss: 0.015 accuracy: 0.949
[36,  1200] loss: 0.014 accuracy: 0.952
[36,  1400] loss: 0.014 accuracy: 0.954
[36,  1600] loss: 0.016 accuracy: 0.944
[36,  1800] loss: 0.015 accuracy: 0.951
[36,  2000] loss: 0.016 accuracy: 0.943
[36,  2200] loss: 0.015 accuracy: 0.951
[36,  2400] loss: 0.016 accuracy: 0.945
[36,  2600] loss: 0.014 accuracy: 0.949



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8960980036297641


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[37,   200] loss: 0.012 accuracy: 0.958
[37,   400] loss: 0.013 accuracy: 0.956
[37,   600] loss: 0.013 accuracy: 0.956
[37,   800] loss: 0.016 accuracy: 0.949
[37,  1000] loss: 0.015 accuracy: 0.947
[37,  1200] loss: 0.014 accuracy: 0.952
[37,  1400] loss: 0.014 accuracy: 0.951
[37,  1600] loss: 0.015 accuracy: 0.947
[37,  1800] loss: 0.014 accuracy: 0.951
[37,  2000] loss: 0.015 accuracy: 0.949
[37,  2200] loss: 0.016 accuracy: 0.947
[37,  2400] loss: 0.015 accuracy: 0.948
[37,  2600] loss: 0.015 accuracy: 0.946



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8997277676950998


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[38,   200] loss: 0.012 accuracy: 0.959
[38,   400] loss: 0.013 accuracy: 0.956
[38,   600] loss: 0.014 accuracy: 0.953
[38,   800] loss: 0.013 accuracy: 0.956
[38,  1000] loss: 0.013 accuracy: 0.955
[38,  1200] loss: 0.013 accuracy: 0.956
[38,  1400] loss: 0.013 accuracy: 0.952
[38,  1600] loss: 0.014 accuracy: 0.955
[38,  1800] loss: 0.014 accuracy: 0.949
[38,  2000] loss: 0.015 accuracy: 0.949
[38,  2200] loss: 0.015 accuracy: 0.945
[38,  2400] loss: 0.015 accuracy: 0.950
[38,  2600] loss: 0.014 accuracy: 0.948



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9019056261343013


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[39,   200] loss: 0.012 accuracy: 0.961
[39,   400] loss: 0.012 accuracy: 0.953
[39,   600] loss: 0.012 accuracy: 0.960
[39,   800] loss: 0.013 accuracy: 0.952
[39,  1000] loss: 0.012 accuracy: 0.958
[39,  1200] loss: 0.013 accuracy: 0.952
[39,  1400] loss: 0.014 accuracy: 0.950
[39,  1600] loss: 0.014 accuracy: 0.950
[39,  1800] loss: 0.013 accuracy: 0.956
[39,  2000] loss: 0.014 accuracy: 0.949
[39,  2200] loss: 0.014 accuracy: 0.949
[39,  2400] loss: 0.014 accuracy: 0.950
[39,  2600] loss: 0.013 accuracy: 0.955



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.8941016333938294


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[40,   200] loss: 0.011 accuracy: 0.964
[40,   400] loss: 0.013 accuracy: 0.957
[40,   600] loss: 0.013 accuracy: 0.957
[40,   800] loss: 0.013 accuracy: 0.954


KeyboardInterrupt: 

In [6]:
checkpoint = torch.load('models/best-resatt.pth')
overnet.load_state_dict(checkpoint['model_state_dict'])

testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True, pin_memory = True, num_workers = 16)
        
corrects = 0
for batch_idx, (spec, target) in enumerate(tqdm(testloader)):
    spec, target = spec.float(), target.float()
    if half:
        spec, target = spec.half(), target.half()
    spec = spec.cuda(device)
    target = target.cuda(device)
    overnet.eval()
    out = overnet(spec) 
    corrects += compute_corrects(out, target)
print('test acc:', corrects/len(testset))
overnet.train()
pass

HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


test acc: 0.9022686025408349
