In [81]:
import librosa
import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import time
import glob
from lxml.html import parse
from sphfile import SPHFile
import pydub
import audiosegment
import pandas as pd
from collections import Counter
from bs4 import BeautifulSoup
import sys
import os
from tqdm.notebook import tqdm
class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)
sr = 16000
dropout = 0.3
half = False

In [82]:
device = 2
torch.cuda.set_device(device)


In [83]:
class OverlayDataSet(torch.utils.data.Dataset):
    def __init__(self, csv, compute_original = False):
        super().__init__()
        self.overlays = pd.read_csv(csv)
        self.speakers = list(set(self.overlays['first_speaker']).union(set(self.overlays['second_speaker'])))
        self.spkr2idx = {spkr:i for i, spkr in enumerate(self.speakers)}
        self.compute_original = compute_original
    def __len__(self):
        return len(self.overlays)
    def __getitem__(self, idx):
        overlay = self.overlays.iloc[idx]
        first_segment = np.load(overlay['first_file'])/(2**15)
        second_segment = np.load(overlay['second_file'])/(2**15)
        #padding to compensate rounding errors
        if len(first_segment)>len(second_segment):
            padding = np.zeros(len(first_segment)-len(second_segment))
            second_segment = np.concatenate((second_segment, padding))
        
        if len(first_segment)<len(second_segment):
            padding = np.zeros(len(second_segment)-len(first_segment))
            first_segment = np.concatenate((first_segment, padding))
        
        
        first_idx  = self.spkr2idx[overlay['first_speaker']]
        second_idx = self.spkr2idx[overlay['second_speaker']]
        target = np.zeros(len(self.speakers))
        target[first_idx] = 1.0
        target[second_idx] = 1.0
        if self.compute_original:
            return self.make_spectrogram(first_segment), self.make_spectrogram(second_segment),\
                self.make_spectrogram(first_segment+second_segment), target
        else:
            return self.make_spectrogram(first_segment+second_segment), target
    def make_spectrogram(self, segment):
        segment = segment[50:-50] # make size 200
        S = librosa.feature.melspectrogram(segment, n_mels = 256, n_fft = 1024, hop_length = 160) # 32 ms window, 10 ms hop
        S_dB = librosa.power_to_db(S, ref=np.max).T[None, :, :] # add channel dimension
        S_dB = (S_dB+40)/40
        return(S_dB)
trainset = OverlayDataSet('overlay-train.csv', False)
valset = OverlayDataSet('overlay-val.csv', False)
spec3, target = trainset[0]
plt.figure(figsize = (20, 6))
if trainset.compute_original:
    plt.subplot(131)
    plt.imshow(spec1[0].T)
    plt.subplot(132)
    plt.imshow(spec2[0].T)
    plt.subplot(133)
    plt.imshow(spec3[0].T)

<Figure size 1440x432 with 0 Axes>

## Maybe try drastically increasing channel number in residual attention stage to see if it overfits

In [85]:
# import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
import numpy as np

class ResidualBlock(nn.Module):
    def __init__(self, input_channels, output_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.stride = stride
        self.bn1 = nn.BatchNorm2d(input_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(input_channels, output_channels, 1, 1, bias = False)
        self.bn2 = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(output_channels, output_channels, 3, stride, padding = 1, bias = False)
        self.bn3 = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(output_channels, output_channels, 1, 1, bias = False)
        self.conv4 = nn.Conv2d(input_channels, output_channels , 1, stride, bias = False)
        
    def forward(self, x):
        residual = x
        out = self.bn1(x)
        out1 = self.relu(out)
        out = self.conv1(out1)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)
        if (self.input_channels != self.output_channels) or (self.stride !=1 ):
            residual = self.conv4(out1)
        out += residual
        return out

class AttentionModule_stage1(nn.Module):
    # input size is 56*56
    def __init__(self, in_channels, out_channels, size1=(200, 128), size2=(100, 64), size3=(50, 32)):
        super(AttentionModule_stage1, self).__init__()
        self.first_residual_blocks = ResidualBlock(in_channels, out_channels)

        self.trunk_branches = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels)
         )

        self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.softmax1_blocks = ResidualBlock(in_channels, out_channels)

        self.skip1_connection_residual_block = ResidualBlock(in_channels, out_channels)

        self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.softmax2_blocks = ResidualBlock(in_channels, out_channels)

        self.skip2_connection_residual_block = ResidualBlock(in_channels, out_channels)

        self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.softmax3_blocks = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels)
        )

        self.interpolation3 = nn.UpsamplingBilinear2d(size=size3)

        self.softmax4_blocks = ResidualBlock(in_channels, out_channels)

        self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)

        self.softmax5_blocks = ResidualBlock(in_channels, out_channels)

        self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)

        self.softmax6_blocks = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels , kernel_size = 1, stride = 1, bias = False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels , kernel_size = 1, stride = 1, bias = False),
            nn.Sigmoid()
        )

        self.last_blocks = ResidualBlock(in_channels, out_channels)

    def forward(self, x):
        #batch_size, nheads, length, n_mels = x.shape
        x = self.first_residual_blocks(x)
        out_trunk = self.trunk_branches(x)
        out_mpool1 = self.mpool1(x) # 100x64
        out_softmax1 = self.softmax1_blocks(out_mpool1)
        out_skip1_connection = self.skip1_connection_residual_block(out_softmax1)
        out_mpool2 = self.mpool2(out_softmax1) # 50x32
        out_softmax2 = self.softmax2_blocks(out_mpool2)
        out_skip2_connection = self.skip2_connection_residual_block(out_softmax2)
        out_mpool3 = self.mpool3(out_softmax2) # 25x16
        out_softmax3 = self.softmax3_blocks(out_mpool3) 
        out_interp3 = self.interpolation3(out_softmax3) + out_softmax2
        out = out_interp3 + out_skip2_connection
        out_softmax4 = self.softmax4_blocks(out)
        out_interp2 = self.interpolation2(out_softmax4) + out_softmax1
        out = out_interp2 + out_skip1_connection
        out_softmax5 = self.softmax5_blocks(out)
        out_interp1 = self.interpolation1(out_softmax5) + out_trunk
        out_softmax6 = self.softmax6_blocks(out_interp1)
        out = (1 + out_softmax6) * out_trunk
        out_last = self.last_blocks(out)
        return out_last

num_heads = 2 # Residual Attention Channels
num_heads_2 = 4 # MHA heads


class OverlayNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.downsample1 = nn.Sequential(nn.Conv2d(1, num_heads, kernel_size=3, stride=(1,2), padding=1), ##downsampling
                        nn.BatchNorm2d(num_heads),
                        nn.ReLU(inplace=True))
        self.res_att = AttentionModule_stage1(num_heads, num_heads)  # batch_size * num_heads * L *128
        self.downsample2 = nn.Sequential(nn.Conv2d(num_heads, num_heads, kernel_size=3, stride=1, padding=1),
                        nn.BatchNorm2d(num_heads),
                        nn.ReLU(inplace=True),
                        nn.MaxPool2d(kernel_size=3, stride=(1, 2), padding=1))
        self.reshape =  Lambda(lambda x: x.permute((1, 2, 0, 3))) # L * batch_size * (num_heads*128)
        self.lstm = nn.LSTM(64, 32, 2, batch_first = False, bidirectional = True, dropout = dropout) # L * batch_size * 200 * n_hidden
        self.mha =  torch.nn.MultiheadAttention(64, num_heads = num_heads_2, dropout=dropout, bias=True, kdim=64, vdim=64) # L * N * 64
        self.fc1 = nn.Linear(64, 32)
        self.average = Lambda(lambda x: x.mean(dim = 0)) # batch * n_hidden
        self.tanh = nn.Tanh()
        #self.norm = Lambda(lambda x: torch.nn.functional.normalize(x, p = 2, dim = 1)) # L2 normalize across n_hidden
        self.fc2 = nn.Linear(32, 20)
        self.softmax = nn.Softmax(1)
    def forward(self, X):
        X = self.downsample1(X)
        #print('first downsample ', X.shape)
        X = self.res_att(X)
        #print('residual attention ', X.shape)
        X = self.downsample2(X)
        #print('second downsample ', X.shape)
        X = self.reshape(X)
        X1, X2 = X[0], X[1]
        X1,_ = self.lstm(X1)
        X2,_ = self.lstm(X1)
        #print('lstm ', X.shape)
        X1,_ = self.mha(X1, X1, X1)
        X2,_ = self.mha(X2, X2, X2)
        #print('mha ', X.shape)
        X1 = self.fc1(X1)
        X2 = self.fc1(X2)
        #print('dense ', X.shape)
        X1 = self.average(X1)
        X2 = self.average(X2)
        #print('mean ', X.shape)
        X1 = self.tanh(X1)
        X2 = self.tanh(X2)
        X1 = self.fc2(X1)
        X2 = self.fc2(X2)
        X1 = self.softmax(X1)
        X2 = self.softmax(X2)
        X = torch.stack([X1,X2], dim=0)
        X,_ = torch.max(X, dim=0)
        return X
    
    
overnet = OverlayNet().cuda(device)
    
# tune hidden layers smaller if overfit
optimizer = torch.optim.Adam(overnet.parameters(), 0.001)

if os.path.exists('models/overnet.pth'):
    print('load model')
    checkpoint = torch.load('models/overnet.pth')
    overnet.load_state_dict(checkpoint['model_state_dict'])
    try:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except:
        print('cannot load optimizer')
    loss = checkpoint['loss']
    if 'bestacc' in checkpoint:
        bestacc = checkpoint['bestacc']
    else:
        bestacc = 0.0
else:
    bestacc = 0.0
    
if half:
    overnet.half()  # convert to half precision
    for layer in overnet.modules():
        if isinstance(layer, nn.BatchNorm2d):
            layer.float()
            
overnet.train()
'bestacc:', bestacc

('bestacc:', 0.0)

## Also Do metrics on hitting a single person right

In [94]:
def find_max2(tensor):
    array = tensor.cpu().detach().numpy()
    max2 = []
    for row in array:
        max2.append(np.argsort(row)[::-1][:2])
    return np.array(max2)

def compute_corrects(tensor1, tensor2):
    max_1, max_2 = find_max2(tensor1), find_max2(tensor2)
    batch_size = max_1.shape[0]
    corrects = 0
    for i in range(batch_size):
        if Counter(max_1[i])==Counter(max_2[i]):
            corrects+=1
    return corrects


trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, pin_memory = True, num_workers = 16)
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True, pin_memory = True, num_workers = 16)
criterion = torch.nn.BCELoss()

for epoch in range(64):
    running_loss = 0.0
    running_accuracy = 0.0
    for batch_idx, (spec, target) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        spec, target = spec.float(), target.float()
        if half:
            spec, target = spec.half(),target.half()
        spec = spec.cuda(device)
        target = target.cuda(device)

        out = overnet(spec)
        loss = criterion(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(overnet.parameters(), 0.5)
        optimizer.step()
                
        running_loss += loss.item()
        running_accuracy += compute_corrects(out, target)/64
        if batch_idx % 200 == 199:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f accuracy: %.3f' % 
                  (epoch + 1, batch_idx + 1, running_loss / 200, running_accuracy / 200))
            running_loss = 0.0
            running_accuracy = 0.0
            torch.save({
            'model_state_dict': overnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
            }, 'models/overnet.pth')
        #measure time
        #print('batch time: ', str(time.time()-lasttime)[:4])
        lasttime = time.time()
        
        
    corrects = 0
    for batch_idx, (spec, target) in enumerate(tqdm(valloader)):
        spec, target = spec.float(), target.float()
        if half:
            spec, target = spec.half(), target.half()
        spec = spec.cuda(device)
        target = target.cuda(device)
        overnet.eval()
        out = overnet(spec) 
        corrects += compute_corrects(out, target)
    print('val acc:', corrects/len(valset))
    if corrects/len(valset) > bestacc:
        bestacc = corrects/len(valset)
        torch.save({
        'model_state_dict': overnet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'bestacc': bestacc
        }, 'models/best-overnet.pth')
    overnet.train()
    pass

HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[1,   200] loss: 0.012 accuracy: 0.959
[1,   400] loss: 0.011 accuracy: 0.961
[1,   600] loss: 0.012 accuracy: 0.958
[1,   800] loss: 0.011 accuracy: 0.961
[1,  1000] loss: 0.012 accuracy: 0.958
[1,  1200] loss: 0.011 accuracy: 0.961



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9118874773139746


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[2,   200] loss: 0.011 accuracy: 0.964
[2,   400] loss: 0.012 accuracy: 0.959
[2,   600] loss: 0.012 accuracy: 0.958
[2,   800] loss: 0.012 accuracy: 0.958
[2,  1000] loss: 0.012 accuracy: 0.961
[2,  1200] loss: 0.012 accuracy: 0.958



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9186025408348457


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[3,   200] loss: 0.012 accuracy: 0.959
[3,   400] loss: 0.011 accuracy: 0.963
[3,   600] loss: 0.011 accuracy: 0.961
[3,   800] loss: 0.012 accuracy: 0.961
[3,  1000] loss: 0.012 accuracy: 0.961
[3,  1200] loss: 0.012 accuracy: 0.960



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9158802177858439


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[4,   200] loss: 0.010 accuracy: 0.963
[4,   400] loss: 0.012 accuracy: 0.960
[4,   600] loss: 0.011 accuracy: 0.962
[4,   800] loss: 0.011 accuracy: 0.961
[4,  1000] loss: 0.013 accuracy: 0.953
[4,  1200] loss: 0.012 accuracy: 0.959



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9140653357531761


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[5,   200] loss: 0.012 accuracy: 0.959
[5,   400] loss: 0.011 accuracy: 0.962
[5,   600] loss: 0.012 accuracy: 0.960
[5,   800] loss: 0.011 accuracy: 0.961
[5,  1000] loss: 0.012 accuracy: 0.958
[5,  1200] loss: 0.012 accuracy: 0.960



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9098003629764065


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[6,   200] loss: 0.011 accuracy: 0.963
[6,   400] loss: 0.010 accuracy: 0.963
[6,   600] loss: 0.010 accuracy: 0.965
[6,   800] loss: 0.013 accuracy: 0.956
[6,  1000] loss: 0.012 accuracy: 0.958
[6,  1200] loss: 0.011 accuracy: 0.962



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9133393829401089


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[7,   200] loss: 0.010 accuracy: 0.964
[7,   400] loss: 0.011 accuracy: 0.961
[7,   600] loss: 0.011 accuracy: 0.961
[7,   800] loss: 0.012 accuracy: 0.956
[7,  1000] loss: 0.012 accuracy: 0.960
[7,  1200] loss: 0.012 accuracy: 0.958



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9117059891107078


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[8,   200] loss: 0.011 accuracy: 0.962
[8,   400] loss: 0.011 accuracy: 0.963
[8,   600] loss: 0.012 accuracy: 0.961
[8,   800] loss: 0.011 accuracy: 0.962
[8,  1000] loss: 0.012 accuracy: 0.957
[8,  1200] loss: 0.011 accuracy: 0.960



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9166969147005445


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[9,   200] loss: 0.011 accuracy: 0.964
[9,   400] loss: 0.011 accuracy: 0.963
[9,   600] loss: 0.011 accuracy: 0.964
[9,   800] loss: 0.011 accuracy: 0.961
[9,  1000] loss: 0.012 accuracy: 0.956
[9,  1200] loss: 0.011 accuracy: 0.964



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9127041742286751


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[10,   200] loss: 0.011 accuracy: 0.964
[10,   400] loss: 0.011 accuracy: 0.963
[10,   600] loss: 0.011 accuracy: 0.962
[10,   800] loss: 0.011 accuracy: 0.961
[10,  1000] loss: 0.012 accuracy: 0.961
[10,  1200] loss: 0.011 accuracy: 0.962



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9156079854809437


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[11,   200] loss: 0.010 accuracy: 0.967
[11,   400] loss: 0.011 accuracy: 0.964
[11,   600] loss: 0.012 accuracy: 0.958
[11,   800] loss: 0.011 accuracy: 0.960
[11,  1000] loss: 0.012 accuracy: 0.957
[11,  1200] loss: 0.011 accuracy: 0.960



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9098003629764065


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[12,   200] loss: 0.011 accuracy: 0.963
[12,   400] loss: 0.011 accuracy: 0.965
[12,   600] loss: 0.011 accuracy: 0.961
[12,   800] loss: 0.011 accuracy: 0.962
[12,  1000] loss: 0.011 accuracy: 0.961
[12,  1200] loss: 0.011 accuracy: 0.963



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9149727767695099


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[13,   200] loss: 0.010 accuracy: 0.965
[13,   400] loss: 0.010 accuracy: 0.965
[13,   600] loss: 0.011 accuracy: 0.960
[13,   800] loss: 0.010 accuracy: 0.965
[13,  1000] loss: 0.012 accuracy: 0.958
[13,  1200] loss: 0.011 accuracy: 0.960



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9164246823956442


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[14,   200] loss: 0.010 accuracy: 0.965
[14,   400] loss: 0.010 accuracy: 0.964
[14,   600] loss: 0.011 accuracy: 0.962
[14,   800] loss: 0.011 accuracy: 0.966
[14,  1000] loss: 0.012 accuracy: 0.960
[14,  1200] loss: 0.012 accuracy: 0.961



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9186932849364792


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[15,   200] loss: 0.010 accuracy: 0.965
[15,   400] loss: 0.009 accuracy: 0.968
[15,   600] loss: 0.012 accuracy: 0.960
[15,   800] loss: 0.013 accuracy: 0.957
[15,  1000] loss: 0.012 accuracy: 0.959
[15,  1200] loss: 0.010 accuracy: 0.964



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9164246823956442


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[16,   200] loss: 0.010 accuracy: 0.968
[16,   400] loss: 0.010 accuracy: 0.966
[16,   600] loss: 0.011 accuracy: 0.961
[16,   800] loss: 0.011 accuracy: 0.962
[16,  1000] loss: 0.012 accuracy: 0.959
[16,  1200] loss: 0.011 accuracy: 0.963



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9159709618874773


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[17,   200] loss: 0.010 accuracy: 0.965
[17,   400] loss: 0.011 accuracy: 0.964
[17,   600] loss: 0.011 accuracy: 0.962
[17,   800] loss: 0.011 accuracy: 0.962
[17,  1000] loss: 0.011 accuracy: 0.962
[17,  1200] loss: 0.012 accuracy: 0.959



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.920961887477314


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[18,   200] loss: 0.010 accuracy: 0.965
[18,   400] loss: 0.011 accuracy: 0.964
[18,   600] loss: 0.010 accuracy: 0.967
[18,   800] loss: 0.011 accuracy: 0.961
[18,  1000] loss: 0.011 accuracy: 0.961
[18,  1200] loss: 0.012 accuracy: 0.959



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9157894736842105


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[19,   200] loss: 0.010 accuracy: 0.966
[19,   400] loss: 0.011 accuracy: 0.964
[19,   600] loss: 0.011 accuracy: 0.961
[19,   800] loss: 0.011 accuracy: 0.964
[19,  1000] loss: 0.011 accuracy: 0.963
[19,  1200] loss: 0.011 accuracy: 0.961



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9150635208711434


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[20,   200] loss: 0.011 accuracy: 0.963
[20,   400] loss: 0.010 accuracy: 0.964
[20,   600] loss: 0.010 accuracy: 0.966
[20,   800] loss: 0.011 accuracy: 0.963
[20,  1000] loss: 0.011 accuracy: 0.960
[20,  1200] loss: 0.012 accuracy: 0.957



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9101633393829401


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[21,   200] loss: 0.010 accuracy: 0.967
[21,   400] loss: 0.010 accuracy: 0.964
[21,   600] loss: 0.011 accuracy: 0.965
[21,   800] loss: 0.010 accuracy: 0.966
[21,  1000] loss: 0.011 accuracy: 0.960
[21,  1200] loss: 0.011 accuracy: 0.962



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9114337568058076


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[22,   200] loss: 0.010 accuracy: 0.968
[22,   400] loss: 0.010 accuracy: 0.968
[22,   600] loss: 0.010 accuracy: 0.964
[22,   800] loss: 0.010 accuracy: 0.963
[22,  1000] loss: 0.011 accuracy: 0.964
[22,  1200] loss: 0.011 accuracy: 0.961



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9147005444646098


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[23,   200] loss: 0.010 accuracy: 0.969
[23,   400] loss: 0.011 accuracy: 0.962
[23,   600] loss: 0.010 accuracy: 0.967
[23,   800] loss: 0.011 accuracy: 0.960
[23,  1000] loss: 0.011 accuracy: 0.961
[23,  1200] loss: 0.011 accuracy: 0.964



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9149727767695099


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[24,   200] loss: 0.010 accuracy: 0.964
[24,   400] loss: 0.011 accuracy: 0.962
[24,   600] loss: 0.011 accuracy: 0.962
[24,   800] loss: 0.011 accuracy: 0.962
[24,  1000] loss: 0.011 accuracy: 0.962
[24,  1200] loss: 0.011 accuracy: 0.966



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9170598911070781


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[25,   200] loss: 0.010 accuracy: 0.966
[25,   400] loss: 0.009 accuracy: 0.969
[25,   600] loss: 0.011 accuracy: 0.963
[25,   800] loss: 0.010 accuracy: 0.966
[25,  1000] loss: 0.010 accuracy: 0.966
[25,  1200] loss: 0.011 accuracy: 0.965



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9196007259528131


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[26,   200] loss: 0.010 accuracy: 0.968
[26,   400] loss: 0.010 accuracy: 0.966
[26,   600] loss: 0.010 accuracy: 0.966
[26,   800] loss: 0.010 accuracy: 0.965
[26,  1000] loss: 0.011 accuracy: 0.960
[26,  1200] loss: 0.011 accuracy: 0.961



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9152450090744102


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[27,   200] loss: 0.010 accuracy: 0.965
[27,   400] loss: 0.010 accuracy: 0.965
[27,   600] loss: 0.011 accuracy: 0.963
[27,   800] loss: 0.011 accuracy: 0.962
[27,  1000] loss: 0.010 accuracy: 0.964
[27,  1200] loss: 0.010 accuracy: 0.966



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9147912885662431


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[28,   200] loss: 0.009 accuracy: 0.970
[28,   400] loss: 0.010 accuracy: 0.965
[28,   600] loss: 0.010 accuracy: 0.962
[28,   800] loss: 0.010 accuracy: 0.965
[28,  1000] loss: 0.011 accuracy: 0.964
[28,  1200] loss: 0.011 accuracy: 0.962



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9167876588021778


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[29,   200] loss: 0.010 accuracy: 0.964
[29,   400] loss: 0.010 accuracy: 0.967
[29,   600] loss: 0.010 accuracy: 0.967
[29,   800] loss: 0.011 accuracy: 0.965
[29,  1000] loss: 0.011 accuracy: 0.965
[29,  1200] loss: 0.010 accuracy: 0.963



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9141560798548094


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[30,   200] loss: 0.009 accuracy: 0.970
[30,   400] loss: 0.010 accuracy: 0.966
[30,   600] loss: 0.010 accuracy: 0.963
[30,   800] loss: 0.010 accuracy: 0.965
[30,  1000] loss: 0.010 accuracy: 0.965
[30,  1200] loss: 0.010 accuracy: 0.967



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9174228675136116


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[31,   200] loss: 0.010 accuracy: 0.965
[31,   400] loss: 0.010 accuracy: 0.965
[31,   600] loss: 0.010 accuracy: 0.965
[31,   800] loss: 0.010 accuracy: 0.965
[31,  1000] loss: 0.010 accuracy: 0.965
[31,  1200] loss: 0.010 accuracy: 0.964



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9136116152450091


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[32,   200] loss: 0.009 accuracy: 0.968
[32,   400] loss: 0.010 accuracy: 0.966
[32,   600] loss: 0.010 accuracy: 0.967
[32,   800] loss: 0.011 accuracy: 0.962
[32,  1000] loss: 0.011 accuracy: 0.964
[32,  1200] loss: 0.010 accuracy: 0.965



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9140653357531761


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[33,   200] loss: 0.010 accuracy: 0.967
[33,   400] loss: 0.010 accuracy: 0.966
[33,   600] loss: 0.010 accuracy: 0.965
[33,   800] loss: 0.010 accuracy: 0.965
[33,  1000] loss: 0.010 accuracy: 0.967
[33,  1200] loss: 0.011 accuracy: 0.965



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9098003629764065


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[34,   200] loss: 0.010 accuracy: 0.968
[34,   400] loss: 0.009 accuracy: 0.969
[34,   600] loss: 0.010 accuracy: 0.966
[34,   800] loss: 0.011 accuracy: 0.962
[34,  1000] loss: 0.010 accuracy: 0.964
[34,  1200] loss: 0.010 accuracy: 0.966



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.918874773139746


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[35,   200] loss: 0.009 accuracy: 0.969
[35,   400] loss: 0.009 accuracy: 0.970
[35,   600] loss: 0.011 accuracy: 0.965
[35,   800] loss: 0.011 accuracy: 0.961
[35,  1000] loss: 0.010 accuracy: 0.965
[35,  1200] loss: 0.010 accuracy: 0.964



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9107078039927404


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[36,   200] loss: 0.009 accuracy: 0.968
[36,   400] loss: 0.009 accuracy: 0.968
[36,   600] loss: 0.010 accuracy: 0.967
[36,   800] loss: 0.010 accuracy: 0.964
[36,  1000] loss: 0.010 accuracy: 0.964
[36,  1200] loss: 0.009 accuracy: 0.969



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9107078039927404


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[37,   200] loss: 0.009 accuracy: 0.968
[37,   400] loss: 0.010 accuracy: 0.968
[37,   600] loss: 0.010 accuracy: 0.966
[37,   800] loss: 0.010 accuracy: 0.967
[37,  1000] loss: 0.011 accuracy: 0.964
[37,  1200] loss: 0.010 accuracy: 0.965



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9156987295825771


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[38,   200] loss: 0.009 accuracy: 0.970
[38,   400] loss: 0.010 accuracy: 0.968
[38,   600] loss: 0.010 accuracy: 0.966
[38,   800] loss: 0.010 accuracy: 0.970
[38,  1000] loss: 0.010 accuracy: 0.966
[38,  1200] loss: 0.010 accuracy: 0.965



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9160617059891107


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[39,   200] loss: 0.010 accuracy: 0.968
[39,   400] loss: 0.010 accuracy: 0.965
[39,   600] loss: 0.011 accuracy: 0.963
[39,   800] loss: 0.010 accuracy: 0.965
[39,  1000] loss: 0.010 accuracy: 0.965
[39,  1200] loss: 0.010 accuracy: 0.965
[40,   400] loss: 0.009 accuracy: 0.969
[40,   600] loss: 0.009 accuracy: 0.966
[40,   800] loss: 0.010 accuracy: 0.965
[40,  1000] loss: 0.010 accuracy: 0.964
[40,  1200] loss: 0.009 accuracy: 0.967



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9174228675136116


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[43,   400] loss: 0.010 accuracy: 0.968
[43,   600] loss: 0.010 accuracy: 0.966
[43,   800] loss: 0.009 accuracy: 0.967
[43,  1000] loss: 0.011 accuracy: 0.963
[43,  1200] loss: 0.009 accuracy: 0.968



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9167876588021778


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[44,   200] loss: 0.008 accuracy: 0.971


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[46,   400] loss: 0.009 accuracy: 0.968
[46,   600] loss: 0.010 accuracy: 0.965
[46,   800] loss: 0.010 accuracy: 0.964
[46,  1000] loss: 0.010 accuracy: 0.967
[46,  1200] loss: 0.010 accuracy: 0.964



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9107078039927404


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)






HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9141560798548094


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[49,   200] loss: 0.009 accuracy: 0.969
[49,   400] loss: 0.009 accuracy: 0.970
[49,   600] loss: 0.010 accuracy: 0.966
[49,   800] loss: 0.010 accuracy: 0.966
[49,  1000] loss: 0.010 accuracy: 0.964


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[51,  1200] loss: 0.010 accuracy: 0.967



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9170598911070781


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[52,   200] loss: 0.009 accuracy: 0.970
[52,   400] loss: 0.009 accuracy: 0.970
[52,   600] loss: 0.010 accuracy: 0.966
[52,   800] loss: 0.010 accuracy: 0.966
[52,  1000] loss: 0.010 accuracy: 0.969


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[54,  1000] loss: 0.010 accuracy: 0.966
[54,  1200] loss: 0.009 accuracy: 0.968



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.9149727767695099


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[55,   200] loss: 0.009 accuracy: 0.971
[55,   400] loss: 0.009 accuracy: 0.970
[55,   600] loss: 0.010 accuracy: 0.967
[55,   800] loss: 0.010 accuracy: 0.968


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[57,  1200] loss: 0.009 accuracy: 0.970



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.915426497277677


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[58,   200] loss: 0.009 accuracy: 0.968
[58,   400] loss: 0.009 accuracy: 0.970
[58,   600] loss: 0.009 accuracy: 0.967
[58,   800] loss: 0.009 accuracy: 0.966


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[60,  1200] loss: 0.009 accuracy: 0.970



HBox(children=(FloatProgress(value=0.0, max=173.0), HTML(value='')))


val acc: 0.915426497277677


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[61,   200] loss: 0.009 accuracy: 0.968
[61,   400] loss: 0.009 accuracy: 0.969
[61,   600] loss: 0.009 accuracy: 0.971
[61,   800] loss: 0.010 accuracy: 0.966
[61,  1000] loss: 0.010 accuracy: 0.967


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[62,  1000] loss: 0.009 accuracy: 0.970
[62,  1200] loss: 0.009 accuracy: 0.971


KeyboardInterrupt: 