In [1]:
import librosa
import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import time
import glob
from lxml.html import parse
from sphfile import SPHFile
import pydub
import audiosegment
import pandas as pd
from collections import Counter
from bs4 import BeautifulSoup
import sys
import os
from tqdm.notebook import tqdm
class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)
sr = 16000
dropout = 0.3
half = False
root = '../'

In [2]:
device = 2
torch.cuda.set_device(device)

In [3]:
class OverlayDataSet(torch.utils.data.Dataset):
    def __init__(self, csv, compute_original = False):
        super().__init__()
        self.overlays = pd.read_csv(csv)
        self.speakers = list(set(self.overlays['first_speaker']))
        self.speakers.sort()
        self.spkr2idx = {spkr:i for i, spkr in enumerate(self.speakers)}
        self.compute_original = compute_original
    def __len__(self):
        return len(self.overlays)
    def __getitem__(self, idx):
        overlay = self.overlays.iloc[idx]
        first_segment = np.load(root+overlay['first_file'])/(2**15)
        second_segment = np.load(root+overlay['second_file'])/(2**15)
        third_segment = np.load(root+overlay['third_file'])/(2**15)
        max_len = max(len(first_segment), len(second_segment), len(third_segment))
        #padding to compensate rounding errors
        if max_len>len(first_segment):
            padding = np.zeros(max_len-len(first_segment))
            first_segment = np.concatenate((first_segment, padding))
        
        if max_len>len(second_segment):
            padding = np.zeros(max_len-len(second_segment))
            second_segment = np.concatenate((second_segment, padding))
            
        if max_len>len(third_segment):
            padding = np.zeros(max_len-len(third_segment))
            third_segment = np.concatenate((third_segment, padding))
        
        first_idx  = self.spkr2idx[overlay['first_speaker']]
        second_idx = self.spkr2idx[overlay['second_speaker']]
        third_idx = self.spkr2idx[overlay['third_speaker']]

        target = np.zeros(len(self.speakers))
        target[first_idx] = 1.0
        target[second_idx] = 1.0
        target[third_idx] = 1.0
        
        if self.compute_original:
            return self.make_spectrogram(first_segment), self.make_spectrogram(second_segment),\
                self.make_spectrogram(third_segment), self.make_spectrogram(first_segment+second_segment), target
        else:
            return self.make_spectrogram(first_segment+second_segment+third_segment), target
    def make_spectrogram(self, segment):
        segment = segment[50:-50] # make size 200
        S = librosa.feature.melspectrogram(segment, n_mels = 256, n_fft = 1024, hop_length = 160) # 32 ms window, 10 ms hop
        S_dB = librosa.power_to_db(S).T[None, :, :] # add channel dimension
        return S_dB

trainset = OverlayDataSet('overlay-train.csv', False)
valset = OverlayDataSet('overlay-val.csv', False)
testset = OverlayDataSet('overlay-test.csv', False)

In [4]:
spec3, target = trainset[0]
plt.figure(figsize = (20, 6))
if trainset.compute_original:
    plt.subplot(131)
    plt.imshow(spec1[0].T)
    plt.subplot(132)
    plt.imshow(spec2[0].T)
    plt.subplot(133)
    plt.imshow(spec3[0].T)
print(spec3.max(), spec3.min(), spec3.shape, target)

29.884724 -50.115276 (1, 200, 256) [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0.]


<Figure size 1440x432 with 0 Axes>

## Maybe try drastically increasing channel number in residual attention stage to see if it overfits

In [5]:
# import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
import numpy as np

class ResidualBlock(nn.Module):
    def __init__(self, input_channels, output_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.stride = stride
        self.bn1 = nn.BatchNorm2d(input_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(input_channels, output_channels, 1, 1, bias = False)
        self.bn2 = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(output_channels, output_channels, 3, stride, padding = 1, bias = False)
        self.bn3 = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(output_channels, output_channels, 1, 1, bias = False)
        self.conv4 = nn.Conv2d(input_channels, output_channels , 1, stride, bias = False)
        
    def forward(self, x):
        residual = x
        out = self.bn1(x)
        out1 = self.relu(out)
        out = self.conv1(out1)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)
        if (self.input_channels != self.output_channels) or (self.stride !=1 ):
            residual = self.conv4(out1)
        out += residual
        return out

class AttentionModule_stage1(nn.Module):
    # input size is 56*56
    def __init__(self, in_channels, out_channels, size1=(200, 128), size2=(100, 64), size3=(50, 32)):
        super(AttentionModule_stage1, self).__init__()
        self.first_residual_blocks = ResidualBlock(in_channels, out_channels)

        self.trunk_branches = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels)
         )

        self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.softmax1_blocks = ResidualBlock(in_channels, out_channels)

        self.skip1_connection_residual_block = ResidualBlock(in_channels, out_channels)

        self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.softmax2_blocks = ResidualBlock(in_channels, out_channels)

        self.skip2_connection_residual_block = ResidualBlock(in_channels, out_channels)

        self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.softmax3_blocks = nn.Sequential(
            ResidualBlock(in_channels, out_channels),
            ResidualBlock(in_channels, out_channels)
        )

        self.interpolation3 = nn.UpsamplingBilinear2d(size=size3)

        self.softmax4_blocks = ResidualBlock(in_channels, out_channels)

        self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)

        self.softmax5_blocks = ResidualBlock(in_channels, out_channels)

        self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)

        self.softmax6_blocks = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels , kernel_size = 1, stride = 1, bias = False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels , kernel_size = 1, stride = 1, bias = False),
            nn.Sigmoid()
        )

        self.last_blocks = ResidualBlock(in_channels, out_channels)

    def forward(self, x):
        #batch_size, nheads, length, n_mels = x.shape
        x = self.first_residual_blocks(x)
        out_trunk = self.trunk_branches(x)
        out_mpool1 = self.mpool1(x) # 100x64
        out_softmax1 = self.softmax1_blocks(out_mpool1)
        out_skip1_connection = self.skip1_connection_residual_block(out_softmax1)
        out_mpool2 = self.mpool2(out_softmax1) # 50x32
        out_softmax2 = self.softmax2_blocks(out_mpool2)
        out_skip2_connection = self.skip2_connection_residual_block(out_softmax2)
        out_mpool3 = self.mpool3(out_softmax2) # 25x16
        out_softmax3 = self.softmax3_blocks(out_mpool3) 
        out_interp3 = self.interpolation3(out_softmax3) + out_softmax2
        out = out_interp3 + out_skip2_connection
        out_softmax4 = self.softmax4_blocks(out)
        out_interp2 = self.interpolation2(out_softmax4) + out_softmax1
        out = out_interp2 + out_skip1_connection
        out_softmax5 = self.softmax5_blocks(out)
        out_interp1 = self.interpolation1(out_softmax5) + out_trunk
        out_softmax6 = self.softmax6_blocks(out_interp1)
        out = (1 + out_softmax6) * out_trunk
        out_last = self.last_blocks(out)
        return out_last

num_heads = 12 # Residual Attention Channels
num_heads_2 = 4 # MHA heads


class OverlayNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln = nn.LayerNorm(256)
        self.downsample1 = ResidualBlock(1, num_heads, (1, 2))
        self.res_att = AttentionModule_stage1(num_heads, num_heads)  # batch_size * num_heads * L *128
        self.downsample2 = nn.Sequential(ResidualBlock(num_heads, num_heads//2),
                                        ResidualBlock(num_heads//2, num_heads//4),
                                        ResidualBlock(num_heads//4, num_heads//4, (1, 2)))
        self.reshape =  Lambda(lambda x: x.permute((1, 2, 0, 3))) # L * batch_size * (num_heads*128)
        self.lstm = nn.LSTM(64, 32, 2, batch_first = False, bidirectional = True, dropout = dropout) # L * batch_size * 200 * n_hidden
        self.mha =  torch.nn.MultiheadAttention(64, num_heads = num_heads_2, dropout=dropout, bias=True, kdim=64, vdim=64) # L * N * 64
        self.fc1 = nn.Linear(64, 32)
        self.average = Lambda(lambda x: x.mean(dim = 0)) # batch * n_hidden
        self.tanh = nn.Tanh()
        #self.norm = Lambda(lambda x: torch.nn.functional.normalize(x, p = 2, dim = 1)) # L2 normalize across n_hidden
        self.fc2 = nn.Linear(32, 20)
        self.softmax = nn.Softmax(1)
    def forward(self, X):
        X = self.ln(X)
        X = self.downsample1(X)
        #print('first downsample ', X.shape)
        X = self.res_att(X)
        #print('residual attention ', X.shape)
        X = self.downsample2(X)
        #print('second downsample ', X.shape)
        X = self.reshape(X)
        X1, X2, X3 = X[0], X[1], X[2]
        X1,_ = self.lstm(X1)
        X2,_ = self.lstm(X1)
        X3,_ = self.lstm(X3)
        #print('lstm ', X.shape)
        X1,_ = self.mha(X1, X1, X1)
        X2,_ = self.mha(X2, X2, X2)
        X3,_ = self.mha(X3, X3, X3)
        #print('mha ', X.shape)
        X1 = self.fc1(X1)
        X2 = self.fc1(X2)
        X3 = self.fc1(X3)
        #print('dense ', X.shape)
        X1 = self.average(X1)
        X2 = self.average(X2)
        X3 = self.average(X3)
        #print('mean ', X.shape)
        X1 = self.tanh(X1)
        X2 = self.tanh(X2)
        X3 = self.tanh(X3)
        X1 = self.fc2(X1)
        X2 = self.fc2(X2)
        X3 = self.fc2(X3)
        X1 = self.softmax(X1)
        X2 = self.softmax(X2)
        X3 = self.softmax(X3)
        X = torch.stack([X1,X2,X3], dim=0)
        X,_ = torch.max(X, dim=0)
        return X
    
    
overnet = OverlayNet().cuda(device)
    
# tune hidden layers smaller if overfit
optimizer = torch.optim.Adam(overnet.parameters(), 0.001)

if os.path.exists('models/overnet8.pth'):
    print('load model')
    checkpoint = torch.load('models/overnet8.pth')
    overnet.load_state_dict(checkpoint['model_state_dict'])
    try:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except:
        print('cannot load optimizer')
    loss = checkpoint['loss']
    if 'bestacc' in checkpoint:
        bestacc = checkpoint['bestacc']
    else:
        bestacc = 0.0
else:
    print('initializing new model')
    bestacc = 0.0
    
if half:
    overnet.half()  # convert to half precision
    for layer in overnet.modules():
        if isinstance(layer, nn.BatchNorm2d):
            layer.float()
            
overnet.train()
'bestacc:', bestacc

load model


('bestacc:', 0.0)

## Also Do metrics on hitting a single person right
## Theoretical justification that when reducing number of channels, neural network could learn to pair similar spectrum representations with each other: only pitch lines that are greater than a certain threshold will pass relu and become postivie. Batchnorm makes this effect stronger. Therefore, when two channels represent spectrums of different people, when they are added the output will be nothing if it doesn't pass the relu threshold

In [7]:
batch_size = 64
def find_max3(tensor):
    array = tensor.cpu().detach().numpy()
    max3 = []
    for row in array:
        max3.append(np.argsort(row)[::-1][:3])
    return np.array(max3)

def compute_corrects(tensor1, tensor2):
    max_1, max_2 = find_max3(tensor1), find_max3(tensor2)
    batch_size = max_1.shape[0]
    batch_corrects = 0
    for i in range(batch_size):
        if Counter(max_1[i])==Counter(max_2[i]):
            batch_corrects+=1
    return batch_corrects

In [None]:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 16)
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True, pin_memory = True, num_workers = 16)
criterion = torch.nn.BCELoss()

for epoch in range(64):
    running_loss = 0.0
    running_accuracy = 0.0
    for batch_idx, (spec, target) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        spec, target = spec.float(), target.float()
        if half:
            spec, target = spec.half(),target.half()
        spec = spec.cuda(device)
        target = target.cuda(device)

        out = overnet(spec)
        loss = criterion(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(overnet.parameters(), 0.5)
        optimizer.step()
                
        running_loss += loss.item()
        running_accuracy += compute_corrects(out, target)/batch_size
        if batch_idx % 200 == 199:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f accuracy: %.3f' % 
                  (epoch + 1, batch_idx + 1, running_loss / 200, running_accuracy / 200))
            running_loss = 0.0
            running_accuracy = 0.0
            torch.save({
            'model_state_dict': overnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
            }, 'models/overnet8.pth')

        
    with torch.no_grad():    
        corrects = 0
        for batch_idx, (spec, target) in enumerate(tqdm(valloader)):
            spec, target = spec.float(), target.float()
            if half:
                spec, target = spec.half(), target.half()
            spec = spec.cuda(device)
            target = target.cuda(device)
            overnet.eval()
            out = overnet(spec) 
            corrects += compute_corrects(out, target)
        print('val acc:', corrects/len(valset))
        if corrects/len(valset) > bestacc:
            bestacc = corrects/len(valset)
            torch.save({
            'model_state_dict': overnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'bestacc': bestacc
            }, 'models/best-overnet8.pth')
        overnet.train()
    pass

HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[1,   200] loss: 0.073 accuracy: 0.749
[1,   400] loss: 0.072 accuracy: 0.751
[1,   600] loss: 0.070 accuracy: 0.766
[1,   800] loss: 0.072 accuracy: 0.755
[1,  1000] loss: 0.071 accuracy: 0.750
[1,  1200] loss: 0.070 accuracy: 0.762
[1,  1400] loss: 0.072 accuracy: 0.757
[1,  1600] loss: 0.072 accuracy: 0.747
[1,  1800] loss: 0.071 accuracy: 0.754
[1,  2000] loss: 0.071 accuracy: 0.755
[1,  2200] loss: 0.072 accuracy: 0.754
[1,  2400] loss: 0.072 accuracy: 0.749
[1,  2600] loss: 0.071 accuracy: 0.753
[1,  2800] loss: 0.071 accuracy: 0.751
[1,  3000] loss: 0.072 accuracy: 0.751
[1,  3200] loss: 0.071 accuracy: 0.760
[1,  3400] loss: 0.072 accuracy: 0.747
[1,  3600] loss: 0.070 accuracy: 0.765
[1,  3800] loss: 0.072 accuracy: 0.749
[1,  4000] loss: 0.071 accuracy: 0.755
[1,  4200] loss: 0.073 accuracy: 0.748
[1,  4400] loss: 0.072 accuracy: 0.758
[1,  4600] loss: 0.073 accuracy: 0.749
[1,  4800] loss: 0.069 accuracy: 0.766
[1,  5000] loss: 0.071 accuracy: 0.756
[1,  5200] loss: 0.070 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.7644837668884856


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[2,   200] loss: 0.070 accuracy: 0.761
[2,   400] loss: 0.069 accuracy: 0.764
[2,   600] loss: 0.070 accuracy: 0.757
[2,   800] loss: 0.069 accuracy: 0.760
[2,  1000] loss: 0.069 accuracy: 0.765
[2,  1200] loss: 0.070 accuracy: 0.762
[2,  1400] loss: 0.070 accuracy: 0.761
[2,  1600] loss: 0.070 accuracy: 0.762
[2,  1800] loss: 0.069 accuracy: 0.765
[2,  2000] loss: 0.070 accuracy: 0.761
[2,  2200] loss: 0.069 accuracy: 0.763
[2,  2400] loss: 0.071 accuracy: 0.754
[2,  2600] loss: 0.069 accuracy: 0.765
[2,  2800] loss: 0.070 accuracy: 0.763
[2,  3000] loss: 0.070 accuracy: 0.765
[2,  3200] loss: 0.070 accuracy: 0.763
[2,  3400] loss: 0.070 accuracy: 0.764
[2,  3600] loss: 0.071 accuracy: 0.759
[2,  3800] loss: 0.070 accuracy: 0.760
[2,  4000] loss: 0.069 accuracy: 0.760
[2,  4200] loss: 0.069 accuracy: 0.766
[2,  4400] loss: 0.069 accuracy: 0.766
[2,  4600] loss: 0.071 accuracy: 0.757
[2,  4800] loss: 0.071 accuracy: 0.755
[2,  5000] loss: 0.069 accuracy: 0.762
[2,  5200] loss: 0.069 ac

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[2, 14200] loss: 0.069 accuracy: 0.763
[2, 14400] loss: 0.071 accuracy: 0.762
[2, 14600] loss: 0.070 accuracy: 0.761
[2, 14800] loss: 0.071 accuracy: 0.756
[2, 15000] loss: 0.072 accuracy: 0.750
[2, 15200] loss: 0.070 accuracy: 0.763
[2, 15400] loss: 0.070 accuracy: 0.761
[2, 15600] loss: 0.070 accuracy: 0.760


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[2, 19000] loss: 0.069 accuracy: 0.767
[2, 19200] loss: 0.069 accuracy: 0.760
[2, 19400] loss: 0.070 accuracy: 0.761
[2, 19600] loss: 0.069 accuracy: 0.759
[2, 19800] loss: 0.068 accuracy: 0.766
[2, 20000] loss: 0.069 accuracy: 0.758
[2, 20200] loss: 0.070 accuracy: 0.763
[2, 20400] loss: 0.069 accuracy: 0.765


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[2, 24000] loss: 0.069 accuracy: 0.759
[2, 24200] loss: 0.069 accuracy: 0.762
[2, 24400] loss: 0.070 accuracy: 0.760
[2, 24600] loss: 0.068 accuracy: 0.770



HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.7624773139745916


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[3,   200] loss: 0.070 accuracy: 0.763
[3,   400] loss: 0.069 accuracy: 0.763
[3,   600] loss: 0.069 accuracy: 0.766
[3,   800] loss: 0.069 accuracy: 0.767
[3,  1000] loss: 0.069 accuracy: 0.759
[3,  1200] loss: 0.067 accuracy: 0.764
[3,  1400] loss: 0.069 accuracy: 0.765
[3,  1600] loss: 0.070 accuracy: 0.759
[3,  1800] loss: 0.069 accuracy: 0.759
[3,  2000] loss: 0.068 accuracy: 0.768
[3,  2200] loss: 0.067 accuracy: 0.774
[3,  2400] loss: 0.069 accuracy: 0.764
[3,  2600] loss: 0.070 accuracy: 0.760
[3,  2800] loss: 0.069 accuracy: 0.764
[3,  3000] loss: 0.069 accuracy: 0.766
[3,  3200] loss: 0.069 accuracy: 0.765
[3,  3400] loss: 0.069 accuracy: 0.762
[3,  3600] loss: 0.069 accuracy: 0.766
[3,  3800] loss: 0.069 accuracy: 0.764
[3,  4000] loss: 0.068 accuracy: 0.770
[3,  4200] loss: 0.069 accuracy: 0.763
[3,  4400] loss: 0.068 accuracy: 0.771
[3,  4600] loss: 0.069 accuracy: 0.763
[3,  4800] loss: 0.069 accuracy: 0.765
[3,  5000] loss: 0.069 accuracy: 0.767
[3,  5200] loss: 0.068 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.7710980036297641


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[4,   200] loss: 0.068 accuracy: 0.764
[4,   400] loss: 0.068 accuracy: 0.772
[4,   600] loss: 0.068 accuracy: 0.770
[4,   800] loss: 0.068 accuracy: 0.764
[4,  1000] loss: 0.067 accuracy: 0.775
[4,  1200] loss: 0.066 accuracy: 0.776
[4,  1400] loss: 0.067 accuracy: 0.771
[4,  1600] loss: 0.067 accuracy: 0.770
[4,  1800] loss: 0.066 accuracy: 0.768
[4,  2000] loss: 0.067 accuracy: 0.774
[4,  2200] loss: 0.069 accuracy: 0.767
[4,  2400] loss: 0.068 accuracy: 0.765
[4,  2600] loss: 0.067 accuracy: 0.771
[4,  2800] loss: 0.066 accuracy: 0.776
[4,  3000] loss: 0.067 accuracy: 0.775
[4,  3200] loss: 0.067 accuracy: 0.772
[4,  3400] loss: 0.067 accuracy: 0.769
[4,  3600] loss: 0.068 accuracy: 0.762
[4,  3800] loss: 0.067 accuracy: 0.769
[4,  4000] loss: 0.066 accuracy: 0.777
[4,  4200] loss: 0.066 accuracy: 0.772
[4,  4400] loss: 0.069 accuracy: 0.765
[4,  4600] loss: 0.068 accuracy: 0.763
[4,  4800] loss: 0.068 accuracy: 0.770
[4,  5000] loss: 0.067 accuracy: 0.772
[4,  5200] loss: 0.067 ac

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[4, 17000] loss: 0.067 accuracy: 0.768
[4, 17200] loss: 0.069 accuracy: 0.764
[4, 17400] loss: 0.067 accuracy: 0.770
[4, 17600] loss: 0.067 accuracy: 0.775
[4, 17800] loss: 0.068 accuracy: 0.768
[4, 18000] loss: 0.067 accuracy: 0.771
[4, 18200] loss: 0.066 accuracy: 0.776
[4, 18400] loss: 0.069 accuracy: 0.764
[4, 18600] loss: 0.068 accuracy: 0.767
[4, 18800] loss: 0.069 accuracy: 0.761
[4, 19000] loss: 0.067 accuracy: 0.773
[4, 19200] loss: 0.068 accuracy: 0.765
[4, 19400] loss: 0.068 accuracy: 0.768
[4, 19600] loss: 0.067 accuracy: 0.770
[4, 19800] loss: 0.066 accuracy: 0.768
[4, 20000] loss: 0.069 accuracy: 0.765
[4, 20200] loss: 0.068 accuracy: 0.770
[4, 20400] loss: 0.069 accuracy: 0.761
[4, 20600] loss: 0.068 accuracy: 0.766
[4, 20800] loss: 0.066 accuracy: 0.777
[4, 21000] loss: 0.067 accuracy: 0.769
[4, 21200] loss: 0.067 accuracy: 0.770
[4, 21400] loss: 0.068 accuracy: 0.773
[4, 21600] loss: 0.065 accuracy: 0.779
[4, 21800] loss: 0.067 accuracy: 0.768
[4, 22000] loss: 0.068 ac

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.7797035692679976


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[5,   200] loss: 0.066 accuracy: 0.776
[5,   400] loss: 0.065 accuracy: 0.777
[5,   600] loss: 0.065 accuracy: 0.778
[5,   800] loss: 0.066 accuracy: 0.776
[5,  1000] loss: 0.066 accuracy: 0.776
[5,  1200] loss: 0.066 accuracy: 0.777
[5,  1400] loss: 0.067 accuracy: 0.769
[5,  1600] loss: 0.066 accuracy: 0.775
[5,  1800] loss: 0.067 accuracy: 0.774
[5,  2000] loss: 0.066 accuracy: 0.773
[5,  2200] loss: 0.067 accuracy: 0.772
[5,  2400] loss: 0.067 accuracy: 0.773
[5,  2600] loss: 0.066 accuracy: 0.775
[5,  2800] loss: 0.067 accuracy: 0.770
[5,  3000] loss: 0.066 accuracy: 0.777
[5,  3200] loss: 0.067 accuracy: 0.773
[5,  3400] loss: 0.066 accuracy: 0.776
[5,  3600] loss: 0.068 accuracy: 0.768
[5,  3800] loss: 0.067 accuracy: 0.768
[5,  4000] loss: 0.065 accuracy: 0.777
[5,  4200] loss: 0.066 accuracy: 0.772
[5,  4400] loss: 0.066 accuracy: 0.778
[5,  4600] loss: 0.065 accuracy: 0.778
[5,  4800] loss: 0.066 accuracy: 0.773
[5,  5000] loss: 0.065 accuracy: 0.778
[5,  5200] loss: 0.065 ac

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[5, 20600] loss: 0.066 accuracy: 0.776
[5, 20800] loss: 0.067 accuracy: 0.772
[5, 21000] loss: 0.068 accuracy: 0.763
[5, 21200] loss: 0.067 accuracy: 0.771
[5, 21400] loss: 0.066 accuracy: 0.771
[5, 21600] loss: 0.065 accuracy: 0.776
[5, 21800] loss: 0.067 accuracy: 0.771


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[5, 24000] loss: 0.066 accuracy: 0.774
[5, 24200] loss: 0.065 accuracy: 0.778
[5, 24400] loss: 0.067 accuracy: 0.769
[5, 24600] loss: 0.067 accuracy: 0.772



HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


val acc: 0.7825065537406736


HBox(children=(FloatProgress(value=0.0, max=24795.0), HTML(value='')))

[6,   200] loss: 0.065 accuracy: 0.781
[6,   400] loss: 0.064 accuracy: 0.783
[6,   600] loss: 0.067 accuracy: 0.772
[6,   800] loss: 0.063 accuracy: 0.782
[6,  1000] loss: 0.066 accuracy: 0.775
[6,  1200] loss: 0.067 accuracy: 0.773
[6,  1400] loss: 0.065 accuracy: 0.775
[6,  1600] loss: 0.066 accuracy: 0.772
[6,  1800] loss: 0.066 accuracy: 0.777
[6,  2000] loss: 0.065 accuracy: 0.775
[6,  2200] loss: 0.066 accuracy: 0.775
[6,  2400] loss: 0.066 accuracy: 0.773
[6,  2600] loss: 0.065 accuracy: 0.782
[6,  2800] loss: 0.065 accuracy: 0.775
[6,  3000] loss: 0.064 accuracy: 0.782
[6,  3200] loss: 0.066 accuracy: 0.773
[6,  3400] loss: 0.065 accuracy: 0.781
[6,  3600] loss: 0.067 accuracy: 0.772
[6,  3800] loss: 0.066 accuracy: 0.778
[6,  4000] loss: 0.066 accuracy: 0.776
[6,  4200] loss: 0.066 accuracy: 0.773
[6,  4800] loss: 0.065 accuracy: 0.777
[6,  5000] loss: 0.064 accuracy: 0.776
[6,  5200] loss: 0.066 accuracy: 0.776
[6,  5400] loss: 0.066 accuracy: 0.772
[6,  5600] loss: 0.065 ac

In [7]:
checkpoint = torch.load('models/best-overnet8.pth')
overnet.load_state_dict(checkpoint['model_state_dict'])

testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True, pin_memory = True, num_workers = 16)

overnet.eval()
corrects = 0
for batch_idx, (spec, target) in enumerate(tqdm(testloader)):
    spec, target = spec.float(), target.float()
    if half:
        spec, target = spec.half(), target.half()
    spec = spec.cuda(device)
    target = target.cuda(device)
    out = overnet(spec) 
    corrects += compute_corrects(out, target)
print('test acc:', corrects/len(testset))
overnet.train()
pass

HBox(children=(FloatProgress(value=0.0, max=3100.0), HTML(value='')))


test acc: 0.731498285944747
