In [5]:
import os
import pathlib
import sys

import numpy as np
import pandas as pd
import torch
from IPython.display import Audio
from librosa.core import resample
def create_dir(filename):
    """Make sure the directory part of ``filename`` exists on disk.

    Everything before the last ``'/'`` in ``filename`` is taken as the
    directory. It is created together with any missing parents; an
    already-existing directory is not an error.
    """
    parent, _, _ = filename.rpartition('/')
    target = pathlib.Path(parent)
    target.mkdir(parents=True, exist_ok=True)
from tqdm.notebook import tqdm
sys.path.append('Conv-TasNet/src/')
sys.path.append('SincNet/')
from conv_tasnet import *
from pit_criterion import cal_loss
from dnn_models import *
from data_io import ReadList,read_conf_inp,str_to_bool
from collections import Counter
# Runtime configuration shared by the cells below.
device = 1                     # CUDA device index handed to every .cuda(...) call
root = '../'                   # prefix prepended to the file paths read from the CSVs
old_sr, new_sr = 16000, 8000   # source / target sampling rates used when resampling

In [7]:
def load8hz(filename, target_len=16000):
    """Load a ``.npy`` clip, resample it and force it to a fixed length.

    The array stored in ``filename`` is scaled by ``1 / 2**15`` (presumably
    int16-range PCM samples -- TODO confirm), resampled from the module-level
    ``old_sr`` to ``new_sr``, then truncated or zero-padded at the end so the
    result has exactly ``target_len`` samples.

    Args:
        filename: path of the ``.npy`` file to read.
        target_len: number of samples in the returned array (default 16000).

    Returns:
        1-D ``numpy`` array of length ``target_len``.
    """
    samples = np.load(filename) / (2 ** 15)
    # Pass the rates by keyword: recent librosa releases make them
    # keyword-only, while older releases accept the same names.
    samples = resample(samples, orig_sr=old_sr, target_sr=new_sr)

    n_samples = len(samples)
    if n_samples > target_len:
        samples = samples[:target_len]
    elif n_samples < target_len:
        # zero-pad at the end up to the requested length
        padding = np.zeros(target_len - n_samples)
        samples = np.concatenate([samples, padding])

    return samples

class SourceSet(torch.utils.data.Dataset):
    """Dataset of utterance pairs listed in a CSV file.

    Each CSV row names two audio files (columns ``first_file`` and
    ``second_file``); an item is the pair of signals returned by ``load8hz``
    for those two files.

    Args:
        root: prefix prepended to the CSV name and to every file path in it.
        csv: CSV file name, relative to ``root``.
    """

    def __init__(self, root, csv):
        super().__init__()
        self.root = root
        self.csv = pd.read_csv(root + csv)

    def __len__(self):
        """Number of pairs, i.e. rows in the CSV."""
        return len(self.csv)

    def __getitem__(self, idx):
        """Return ``(sig1, sig2)`` for row ``idx``."""
        row = self.csv.iloc[idx]
        # Resolve paths against this instance's root rather than a global
        # `root`, so the dataset honours the argument given to __init__.
        sig1 = load8hz(self.root + row['first_file'])
        sig2 = load8hz(self.root + row['second_file'])
        return sig1, sig2
# Training pairs listed in overlay-train.csv, resolved relative to `root`.
sourceset_train = SourceSet(root=root, csv='overlay-train.csv')


In [6]:
# Load the Conv-TasNet weights stored in final.pth.tar, move the model to the
# selected GPU and put it in training mode.
tasnet = ConvTasNet.load_model('final.pth.tar').cuda(device)
tasnet.train()
optimizer = torch.optim.Adam(tasnet.parameters(), lr=0.001)

# Resume from the checkpoint written by the training cell, when one exists.
if os.path.exists('models/tasnet.pth'):
    print('load model')
    checkpoint = torch.load('models/tasnet.pth')
    # Restore into `tasnet`, the model whose state_dict the training cell saves.
    tasnet.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    loss = checkpoint['loss']


ConvTasNet(
  (encoder): Encoder(
    (conv1d_U): Conv1d(1, 256, kernel_size=(20,), stride=(10,), bias=False)
  )
  (separator): TemporalConvNet(
    (network): Sequential(
      (0): ChannelwiseLayerNorm()
      (1): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
      (2): Sequential(
        (0): Sequential(
          (0): TemporalBlock(
            (net): Sequential(
              (0): Conv1d(256, 512, kernel_size=(1,), stride=(1,), bias=False)
              (1): PReLU(num_parameters=1)
              (2): GlobalLayerNorm()
              (3): DepthwiseSeparableConv(
                (net): Sequential(
                  (0): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,), groups=512, bias=False)
                  (1): PReLU(num_parameters=1)
                  (2): GlobalLayerNorm()
                  (3): Conv1d(512, 256, kernel_size=(1,), stride=(1,), bias=False)
                )
              )
            )
          )
          (1): TemporalBlock(
      

In [None]:
batch_size = 8
sourceloader_train = torch.utils.data.DataLoader(
    sourceset_train, batch_size=batch_size, shuffle=True,
    pin_memory=True, num_workers=16)

log_every = 200  # mini-batches between progress reports / checkpoints
# torch.save does not create missing directories.
pathlib.Path('models').mkdir(parents=True, exist_ok=True)

for epoch in range(64):
    running_loss = 0.0
    for batch_idx, (sig1, sig2) in enumerate(tqdm(sourceloader_train)):
        optimizer.zero_grad()
        sig1, sig2 = sig1.float().cuda(device), sig2.float().cuda(device)

        # The network separates the sum of the two clean signals.
        out = tasnet(sig1 + sig2)
        source = torch.stack([sig1, sig2], dim=1).detach()

        # One length entry per example actually present in this batch: the
        # final batch of an epoch can hold fewer than `batch_size` items.
        source_lengths = torch.full(
            (sig1.shape[0],), sig1.shape[-1],
            dtype=torch.int32, device=sig1.device)
        loss, max_snr, estimate_source, reorder_estimate_source = \
            cal_loss(source, out, source_lengths)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(tasnet.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % log_every == log_every - 1:    # report every `log_every` mini-batches
            print('[%d, %5d] loss: %.3f ' %
                  (epoch + 1, batch_idx + 1, running_loss / log_every))
            running_loss = 0.0
            torch.save({
                'model_state_dict': tasnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # store a plain float instead of a live CUDA tensor
                'loss': loss.item(),
            }, 'models/tasnet.pth')

HBox(children=(FloatProgress(value=0.0, max=11020.0), HTML(value='')))

[1,   200] loss: -5.306 
[1,   400] loss: -7.088 
[1,   600] loss: -7.693 
[1,   800] loss: -8.213 
[1,  1000] loss: -8.359 
[1,  1200] loss: -9.046 
[1,  1400] loss: -9.422 
[1,  1600] loss: -9.418 
[1,  1800] loss: -9.483 
[1,  2000] loss: -9.617 
[1,  2400] loss: -9.916 
[1,  2600] loss: -10.315 
[1,  2800] loss: -10.095 
[1,  3000] loss: -10.409 
[1,  3400] loss: -10.462 
[1,  3600] loss: -10.715 
[1,  3800] loss: -10.751 
[1,  4000] loss: -10.686 


In [1]:



def reformat(prefix, i, filename):
    """Build the output file name for a de-mixed signal.

    The extension of ``filename`` is dropped and ``_<prefix>_<i>.npy`` is
    appended, keeping any directory part intact.

    Args:
        prefix: tag inserted after the original stem.
        i: index appended after the prefix (converted with ``str``).
        filename: original file path.

    Returns:
        The new path as a string, always ending in ``.npy``.
    """
    # splitext strips only the final extension, so paths containing other
    # dots (e.g. './x/y.npy', 'a/b.c.npy') and extension-less names work.
    stem, _ = os.path.splitext(filename)
    return stem + '_' + prefix + '_' + str(i) + '.npy'

In [2]:
# create a new copy of training data in pwd, with same filename, but replace audio data with de-mixed audio data
# NOTE(review): `csv_name`, `mode` and `model` are not defined in any cell
# visible here; they must already exist in the kernel for this cell to run.
# TODO define them explicitly (presumably `model` is the trained separator and
# should be in eval mode -- confirm).
csv = pd.read_csv(root+csv_name)

with torch.no_grad():
    for i in tqdm(range(len(csv))): 
        row = csv.iloc[i]
        # Source signals are read relative to `root` ...
        seg1 = load8hz(root+row['first_file'])
        seg2 = load8hz(root+row['second_file'])
        # ... while output directories are created relative to the working directory.
        create_dir(row['first_file'])
        create_dir(row['second_file'])
        # Trim both signals to the shorter one. load8hz pads/truncates to a
        # fixed length, so this looks like a no-op safeguard.
        shorter = min(len(seg1), len(seg2))
        if len(seg1)>shorter:
            seg1 = seg1[:shorter]
        if len(seg2)>shorter:
            seg2 = seg2[:shorter]
        mixture = torch.Tensor(seg1+seg2).cuda(device)
        # add a leading batch dimension of size 1
        mixture = mixture[None, ...]
        out = model(mixture)
        # The unpacking requires out[0] to hold exactly two separated signals.
        new_seg1, new_seg2 = out[0].cpu().detach().numpy()
        newfile1, newfile2 = reformat(mode, i, row['first_file']), reformat(mode, i, row['second_file'])
        # Point the CSV rows at the de-mixed files before saving them.
        csv.at[i, 'first_file'] = newfile1
        csv.at[i, 'second_file'] = newfile2
        np.save(newfile1, new_seg1)
        np.save(newfile2, new_seg2)
    # The updated CSV is written under `csv_name` without the `root` prefix,
    # i.e. into the working directory, leaving the file under `root` untouched.
    csv.to_csv(csv_name, index = False)

NameError: name 'pd' is not defined

In [3]:
def chop_chunk(signal, chunk_len=None, min_len=16000):
    """Split a 1-D signal into consecutive, non-overlapping chunks.

    Signals shorter than ``min_len`` are zero-padded at the end up to
    ``min_len`` before splitting. Trailing samples that do not fill a whole
    chunk are dropped.

    Args:
        signal: 1-D ``numpy`` array.
        chunk_len: samples per chunk; defaults to the module-level ``wlen``.
        min_len: minimum length the signal is padded to (default 16000).

    Returns:
        List of arrays, each of length ``chunk_len``.
    """
    if chunk_len is None:
        chunk_len = wlen  # notebook-wide window length

    signal_len = signal.shape[-1]
    if signal_len < min_len:
        padding = np.zeros(min_len - signal_len)
        signal = np.concatenate((signal, padding))
        # refresh the length so the padded samples are chunked as well
        signal_len = signal.shape[-1]

    n_chunks = signal_len // chunk_len
    return [signal[k * chunk_len:(k + 1) * chunk_len] for k in range(n_chunks)]

class ChunkSet(torch.utils.data.Dataset):
    """Dataset of de-mixed signal pairs with a two-hot speaker target.

    Each CSV row provides two ``.npy`` files (``first_file``,
    ``second_file``) and their speakers (``first_speaker``,
    ``second_speaker``). Both signals are cut into chunks with
    ``chop_chunk``.

    Args:
        csv: path of the CSV file.
        mode: ``'train'`` returns one randomly chosen chunk per signal;
            ``'val'`` returns all chunks of each signal as an array.
    """

    def __init__(self, csv, mode='train'):
        super().__init__()
        self.csv = pd.read_csv(csv)
        # Collect speakers from BOTH columns so that a speaker appearing only
        # as `second_speaker` still gets an index.
        self.speakers = sorted(
            set(self.csv['first_speaker']) | set(self.csv['second_speaker']))
        self.spkr2idx = {spkr: i for i, spkr in enumerate(self.speakers)}
        self.mode = mode

    def __len__(self):
        """Number of pairs, i.e. rows in the CSV."""
        return len(self.csv)

    def __getitem__(self, idx):
        """Return ``(chunk(s) of signal 1, chunk(s) of signal 2, target)``.

        ``target`` has one entry per known speaker, set to 1 at the indices
        of the two speakers of this row.

        Raises:
            ValueError: if ``self.mode`` is neither ``'train'`` nor ``'val'``.
        """
        row = self.csv.iloc[idx]
        spkr1, spkr2 = row['first_speaker'], row['second_speaker']
        sig1, sig2 = np.load(row['first_file']), np.load(row['second_file'])
        chunk1, chunk2 = chop_chunk(sig1), chop_chunk(sig2)

        target_vec = np.zeros(len(self.speakers))
        target_vec[self.spkr2idx[spkr1]] = 1
        target_vec[self.spkr2idx[spkr2]] = 1

        if self.mode == 'val':
            return np.array(chunk1), np.array(chunk2), target_vec
        if self.mode == 'train':
            pick1 = np.random.randint(len(chunk1))
            pick2 = np.random.randint(len(chunk2))
            return chunk1[pick1], chunk2[pick2], target_vec
        # Fail loudly instead of silently returning None.
        raise ValueError("unknown mode %r (expected 'train' or 'val')" % (self.mode,))

# De-mixed datasets, read from the CSVs in the working directory.
chunkset_train = ChunkSet(csv='overlay-train.csv', mode='train')
chunkset_val = ChunkSet(csv='overlay-val.csv', mode='val')

NameError: name 'torch' is not defined

In [3]:
def find_max2(tensor):
    """Return, for every row of ``tensor``, the indices of its two largest values.

    The tensor is detached and copied to the CPU first. Each output row
    lists the index of the largest value followed by the second largest.

    Returns:
        ``numpy`` array with one row of indices per input row.
    """
    scores = tensor.detach().cpu().numpy()
    # argsort is ascending, so walk backwards from the end for the top two
    top2 = [np.argsort(row)[-1:-3:-1] for row in scores]
    return np.array(top2)

def compute_corrects(tensor1, tensor2):
    """Count rows whose top-2 index sets agree between two tensors.

    For each row, the two highest-scoring indices of ``tensor1`` and of
    ``tensor2`` are compared irrespective of order.

    Returns:
        Number of rows for which both top-2 sets are equal.
    """
    top_a, top_b = find_max2(tensor1), find_max2(tensor2)
    n_rows = top_a.shape[0]
    # Counter equality compares the index pairs as multisets (order-free).
    return sum(Counter(top_a[k]) == Counter(top_b[k]) for k in range(n_rows))

In [13]:
# Chunking parameters for the classifier front end.
fs = new_sr                    # sampling rate of the signals being chunked
cw_len, cw_shift = 200, 10     # chunk window length / shift, in milliseconds

# Window length expressed in samples.
wlen = int(cw_len * fs / 1000.00)
#wshift=int(fs*cw_shift/1000.00)




class MixedClassifier(nn.Module):
    """Speaker classifier for a pair of de-mixed signals.

    Both inputs go through the same SincNet front end and the same two MLPs;
    each branch ends in a softmax over the classes, and the two probability
    vectors are merged with an element-wise maximum.

    Args:
        num_classes: size of the output layer (default 20).
    """

    def __init__(self, num_classes=20):
        super().__init__()
        cnn_arch = {
            'input_dim': wlen,
            'fs': fs,
            'cnn_N_filt': [80, 60, 60],
            'cnn_len_filt': [251, 5, 5],
            'cnn_max_pool_len': [3, 3, 3],
            'cnn_use_laynorm_inp': True,
            'cnn_use_batchnorm_inp': False,
            'cnn_use_laynorm': [True, True, True],
            'cnn_use_batchnorm': [False, False, False],
            'cnn_act': ['leaky_relu', 'leaky_relu', 'leaky_relu'],
            'cnn_drop': [0.0, 0.0, 0.0],
        }
        self.cnn_net = SincNet(cnn_arch)

        hidden_sizes = [2048, 2048, 2048]
        dnn1_arch = {
            'input_dim': self.cnn_net.out_dim,
            'fc_lay': list(hidden_sizes),
            'fc_drop': [0.0, 0.0, 0.0],
            'fc_use_batchnorm': [True, True, True],
            'fc_use_laynorm': [False, False, False],
            'fc_use_laynorm_inp': False,
            'fc_use_batchnorm_inp': False,
            'fc_act': ['leaky_relu', 'leaky_relu', 'leaky_relu'],
        }
        self.dnn1 = MLP(dnn1_arch)

        dnn2_arch = {
            # fed by the last hidden layer of dnn1
            'input_dim': hidden_sizes[-1],
            'fc_lay': [num_classes],
            'fc_drop': [0.0],
            'fc_use_batchnorm': [False],
            'fc_use_laynorm': [False],
            'fc_use_laynorm_inp': False,
            'fc_use_batchnorm_inp': False,
            'fc_act': ['linear'],  # leakyrelu(1) is just identity mapping
        }
        self.dnn2 = MLP(dnn2_arch)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, X1, X2):
        """Classify both inputs and merge the per-class probabilities.

        Returns:
            Tensor holding, for every class, the larger of the two branches'
            softmax outputs.
        """
        out1, out2 = self.cnn_net(X1), self.cnn_net(X2)
        out1, out2 = self.dnn1(out1), self.dnn1(out2)
        out1, out2 = self.dnn2(out1), self.dnn2(out2)
        out1, out2 = self.softmax(out1), self.softmax(out2)
        # Stack the branches on a new leading axis and keep the maximum.
        out = torch.stack([out1, out2], dim=0)
        out, _ = torch.max(out, dim=0)
        return out

# Build the classifier on the selected GPU and switch it to training mode.
cls = MixedClassifier().cuda(device)
cls.train()
# Note: this rebinds `optimizer`, the name also used for the Conv-TasNet optimizer.
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)

In [19]:
batch_size = 32
chunkloader_train = torch.utils.data.DataLoader(
    chunkset_train, batch_size=batch_size, shuffle=True,
    pin_memory=True, num_workers=16)
criterion = torch.nn.BCELoss()

log_every = 200  # mini-batches between progress reports / checkpoints
# torch.save does not create missing directories.
pathlib.Path('models').mkdir(parents=True, exist_ok=True)

for epoch in range(64):
    running_loss = 0.0
    running_accuracy = 0.0
    for batch_idx, (X1, X2, target) in enumerate(tqdm(chunkloader_train)):
        optimizer.zero_grad()
        X1, X2, target = X1.float().cuda(device), X2.float().cuda(device), target.float().cuda(device)

        out = cls(X1, X2)
        loss = criterion(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(cls.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        # Divide by the number of examples actually in this batch: the final
        # batch of an epoch can hold fewer than `batch_size` items.
        running_accuracy += compute_corrects(out, target) / target.shape[0]
        if batch_idx % log_every == log_every - 1:    # report every `log_every` mini-batches
            print('[%d, %5d] loss: %.3f accuracy: %.3f' %
                  (epoch + 1, batch_idx + 1, running_loss / log_every, running_accuracy / log_every))
            running_loss = 0.0
            running_accuracy = 0.0
            torch.save({
                'model_state_dict': cls.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # store a plain float instead of a live CUDA tensor
                'loss': loss.item(),
            }, 'models/sincnet.pth')


HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[1,   200] loss: 0.189 accuracy: 0.300
[1,   400] loss: 0.187 accuracy: 0.312
[1,   600] loss: 0.190 accuracy: 0.303
[1,   800] loss: 0.186 accuracy: 0.318
[1,  1000] loss: 0.189 accuracy: 0.302
[1,  1200] loss: 0.189 accuracy: 0.300
[1,  1400] loss: 0.188 accuracy: 0.310
[1,  1600] loss: 0.189 accuracy: 0.301
[1,  1800] loss: 0.189 accuracy: 0.302
[1,  2000] loss: 0.189 accuracy: 0.300
[1,  2200] loss: 0.190 accuracy: 0.302
[1,  2400] loss: 0.187 accuracy: 0.308
[1,  2600] loss: 0.189 accuracy: 0.302



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[2,   200] loss: 0.329 accuracy: 0.559
[2,   400] loss: 0.186 accuracy: 0.312
[2,   600] loss: 0.182 accuracy: 0.332
[2,   800] loss: 0.185 accuracy: 0.313
[2,  1000] loss: 0.184 accuracy: 0.322
[2,  1200] loss: 0.184 accuracy: 0.320
[2,  1400] loss: 0.185 accuracy: 0.307
[2,  1600] loss: 0.187 accuracy: 0.305
[2,  1800] loss: 0.185 accuracy: 0.318
[2,  2000] loss: 0.187 accuracy: 0.302
[2,  2200] loss: 0.185 accuracy: 0.312
[2,  2400] loss: 0.186 accuracy: 0.308
[2,  2600] loss: 0.185 accuracy: 0.318



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[3,   200] loss: 0.329 accuracy: 0.553
[3,   400] loss: 0.182 accuracy: 0.328
[3,   600] loss: 0.184 accuracy: 0.316
[3,   800] loss: 0.184 accuracy: 0.322
[3,  1000] loss: 0.180 accuracy: 0.325
[3,  1200] loss: 0.184 accuracy: 0.312
[3,  1400] loss: 0.181 accuracy: 0.331
[3,  1600] loss: 0.183 accuracy: 0.327
[3,  1800] loss: 0.185 accuracy: 0.320
[3,  2000] loss: 0.183 accuracy: 0.323
[3,  2200] loss: 0.182 accuracy: 0.327
[3,  2400] loss: 0.185 accuracy: 0.317
[3,  2600] loss: 0.180 accuracy: 0.329



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[4,   200] loss: 0.322 accuracy: 0.589
[4,   400] loss: 0.183 accuracy: 0.321
[4,   600] loss: 0.182 accuracy: 0.322
[4,   800] loss: 0.182 accuracy: 0.324
[4,  1000] loss: 0.179 accuracy: 0.333
[4,  1200] loss: 0.181 accuracy: 0.329
[4,  1400] loss: 0.184 accuracy: 0.322
[4,  1600] loss: 0.180 accuracy: 0.335
[4,  1800] loss: 0.180 accuracy: 0.332
[4,  2000] loss: 0.179 accuracy: 0.338
[4,  2200] loss: 0.180 accuracy: 0.336
[4,  2400] loss: 0.183 accuracy: 0.319
[4,  2600] loss: 0.180 accuracy: 0.338



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[5,   200] loss: 0.317 accuracy: 0.601
[5,   400] loss: 0.178 accuracy: 0.339
[5,   600] loss: 0.178 accuracy: 0.337
[5,   800] loss: 0.180 accuracy: 0.333
[5,  1000] loss: 0.178 accuracy: 0.345
[5,  1200] loss: 0.180 accuracy: 0.333
[5,  1400] loss: 0.178 accuracy: 0.330
[5,  1600] loss: 0.177 accuracy: 0.343
[5,  1800] loss: 0.181 accuracy: 0.332
[5,  2000] loss: 0.180 accuracy: 0.333
[5,  2200] loss: 0.178 accuracy: 0.340
[5,  2400] loss: 0.181 accuracy: 0.330
[5,  2600] loss: 0.179 accuracy: 0.330



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[6,   200] loss: 0.316 accuracy: 0.596
[6,   400] loss: 0.177 accuracy: 0.344
[6,   600] loss: 0.177 accuracy: 0.344
[6,   800] loss: 0.178 accuracy: 0.335
[6,  1000] loss: 0.176 accuracy: 0.345
[6,  1200] loss: 0.177 accuracy: 0.336
[6,  1400] loss: 0.177 accuracy: 0.344
[6,  1600] loss: 0.175 accuracy: 0.349
[6,  1800] loss: 0.177 accuracy: 0.344
[6,  2000] loss: 0.177 accuracy: 0.341
[6,  2200] loss: 0.175 accuracy: 0.345
[6,  2400] loss: 0.177 accuracy: 0.346
[6,  2600] loss: 0.176 accuracy: 0.349



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[7,   200] loss: 0.311 accuracy: 0.618
[7,   400] loss: 0.174 accuracy: 0.351
[7,   600] loss: 0.174 accuracy: 0.348
[7,   800] loss: 0.175 accuracy: 0.349
[7,  1000] loss: 0.175 accuracy: 0.342
[7,  1200] loss: 0.175 accuracy: 0.343
[7,  1400] loss: 0.173 accuracy: 0.355
[7,  1600] loss: 0.174 accuracy: 0.356
[7,  1800] loss: 0.173 accuracy: 0.353
[7,  2000] loss: 0.174 accuracy: 0.347
[7,  2200] loss: 0.174 accuracy: 0.353
[7,  2400] loss: 0.176 accuracy: 0.346
[7,  2600] loss: 0.175 accuracy: 0.352



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[8,   200] loss: 0.311 accuracy: 0.617
[8,   400] loss: 0.172 accuracy: 0.361
[8,   600] loss: 0.173 accuracy: 0.356
[8,   800] loss: 0.174 accuracy: 0.352
[8,  1000] loss: 0.173 accuracy: 0.359
[8,  1200] loss: 0.172 accuracy: 0.356
[8,  1400] loss: 0.173 accuracy: 0.354
[8,  1600] loss: 0.171 accuracy: 0.369
[8,  1800] loss: 0.173 accuracy: 0.349
[8,  2000] loss: 0.170 accuracy: 0.369
[8,  2200] loss: 0.173 accuracy: 0.354
[8,  2400] loss: 0.173 accuracy: 0.361
[8,  2600] loss: 0.171 accuracy: 0.358



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[9,   200] loss: 0.305 accuracy: 0.648
[9,   400] loss: 0.170 accuracy: 0.367
[9,   600] loss: 0.172 accuracy: 0.364
[9,   800] loss: 0.172 accuracy: 0.366
[9,  1000] loss: 0.171 accuracy: 0.356
[9,  1200] loss: 0.170 accuracy: 0.364
[9,  1400] loss: 0.171 accuracy: 0.357
[9,  1600] loss: 0.169 accuracy: 0.366
[9,  1800] loss: 0.172 accuracy: 0.352
[9,  2000] loss: 0.171 accuracy: 0.363
[9,  2200] loss: 0.171 accuracy: 0.359
[9,  2400] loss: 0.170 accuracy: 0.366
[9,  2600] loss: 0.172 accuracy: 0.357



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[10,   200] loss: 0.303 accuracy: 0.640
[10,   400] loss: 0.168 accuracy: 0.369
[10,   600] loss: 0.168 accuracy: 0.378
[10,   800] loss: 0.171 accuracy: 0.365
[10,  1000] loss: 0.168 accuracy: 0.369
[10,  1200] loss: 0.167 accuracy: 0.375
[10,  1400] loss: 0.169 accuracy: 0.361
[10,  1600] loss: 0.171 accuracy: 0.360
[10,  1800] loss: 0.170 accuracy: 0.366
[10,  2000] loss: 0.169 accuracy: 0.368
[10,  2200] loss: 0.169 accuracy: 0.378
[10,  2400] loss: 0.169 accuracy: 0.373
[10,  2600] loss: 0.168 accuracy: 0.384



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[11,   200] loss: 0.302 accuracy: 0.651
[11,   400] loss: 0.167 accuracy: 0.376
[11,   600] loss: 0.166 accuracy: 0.381
[11,   800] loss: 0.171 accuracy: 0.361
[11,  1000] loss: 0.167 accuracy: 0.375
[11,  1200] loss: 0.167 accuracy: 0.387
[11,  1400] loss: 0.171 accuracy: 0.364
[11,  1600] loss: 0.166 accuracy: 0.385
[11,  1800] loss: 0.168 accuracy: 0.371
[11,  2000] loss: 0.169 accuracy: 0.363
[11,  2200] loss: 0.166 accuracy: 0.382
[11,  2400] loss: 0.168 accuracy: 0.372
[11,  2600] loss: 0.167 accuracy: 0.372



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[12,   200] loss: 0.292 accuracy: 0.677
[12,   400] loss: 0.165 accuracy: 0.384
[12,   600] loss: 0.165 accuracy: 0.383
[12,   800] loss: 0.168 accuracy: 0.375
[12,  1000] loss: 0.168 accuracy: 0.379
[12,  1200] loss: 0.167 accuracy: 0.375
[12,  1400] loss: 0.163 accuracy: 0.392
[12,  1600] loss: 0.166 accuracy: 0.387
[12,  1800] loss: 0.166 accuracy: 0.382
[12,  2000] loss: 0.167 accuracy: 0.369
[12,  2200] loss: 0.166 accuracy: 0.379
[12,  2400] loss: 0.164 accuracy: 0.388
[12,  2600] loss: 0.165 accuracy: 0.385



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[13,   200] loss: 0.291 accuracy: 0.683
[13,   400] loss: 0.163 accuracy: 0.389
[13,   600] loss: 0.167 accuracy: 0.378
[13,   800] loss: 0.166 accuracy: 0.381
[13,  1000] loss: 0.166 accuracy: 0.383
[13,  1200] loss: 0.165 accuracy: 0.383
[13,  1400] loss: 0.164 accuracy: 0.387
[13,  1600] loss: 0.163 accuracy: 0.387
[14,   200] loss: 0.288 accuracy: 0.692
[14,   400] loss: 0.160 accuracy: 0.403
[14,   600] loss: 0.163 accuracy: 0.396
[14,   800] loss: 0.161 accuracy: 0.401
[14,  1000] loss: 0.165 accuracy: 0.385
[14,  1200] loss: 0.165 accuracy: 0.388
[14,  1400] loss: 0.162 accuracy: 0.396
[14,  1600] loss: 0.161 accuracy: 0.399
[14,  1800] loss: 0.164 accuracy: 0.389
[14,  2000] loss: 0.163 accuracy: 0.391
[14,  2200] loss: 0.165 accuracy: 0.387
[14,  2400] loss: 0.162 accuracy: 0.389
[14,  2600] loss: 0.162 accuracy: 0.394



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[16,  1600] loss: 0.158 accuracy: 0.413
[16,  1800] loss: 0.162 accuracy: 0.400
[16,  2000] loss: 0.163 accuracy: 0.397
[16,  2200] loss: 0.161 accuracy: 0.400
[16,  2400] loss: 0.159 accuracy: 0.409
[16,  2600] loss: 0.160 accuracy: 0.401



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[17,   200] loss: 0.283 accuracy: 0.720
[17,   400] loss: 0.157 accuracy: 0.410
[17,   600] loss: 0.158 accuracy: 0.406


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[19,  1400] loss: 0.153 accuracy: 0.430
[19,  1600] loss: 0.155 accuracy: 0.415
[19,  1800] loss: 0.157 accuracy: 0.422
[19,  2000] loss: 0.157 accuracy: 0.406
[19,  2200] loss: 0.156 accuracy: 0.410
[19,  2400] loss: 0.156 accuracy: 0.416
[19,  2600] loss: 0.155 accuracy: 0.423



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[20,   200] loss: 0.275 accuracy: 0.745
[20,   400] loss: 0.155 accuracy: 0.418
[20,   600] loss: 0.152 accuracy: 0.432
[20,   800] loss: 0.151 accuracy: 0.432
[20,  1000] loss: 0.156 accuracy: 0.415
[20,  1200] loss: 0.155 accuracy: 0.426
[20,  1400] loss: 0.153 accuracy: 0.426


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[23,   400] loss: 0.151 accuracy: 0.429
[23,   600] loss: 0.149 accuracy: 0.440
[23,   800] loss: 0.149 accuracy: 0.441
[23,  1000] loss: 0.150 accuracy: 0.432
[23,  1200] loss: 0.150 accuracy: 0.438
[23,  1400] loss: 0.149 accuracy: 0.452
[23,  1600] loss: 0.150 accuracy: 0.434
[23,  1800] loss: 0.152 accuracy: 0.438
[23,  2000] loss: 0.150 accuracy: 0.442
[23,  2200] loss: 0.151 accuracy: 0.438
[23,  2400] loss: 0.151 accuracy: 0.434
[23,  2600] loss: 0.148 accuracy: 0.454



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[24,   200] loss: 0.266 accuracy: 0.771
[24,   400] loss: 0.149 accuracy: 0.440


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[24,  1400] loss: 0.147 accuracy: 0.445
[24,  1600] loss: 0.149 accuracy: 0.448
[24,  1800] loss: 0.149 accuracy: 0.449
[24,  2000] loss: 0.150 accuracy: 0.436
[24,  2200] loss: 0.150 accuracy: 0.444
[24,  2400] loss: 0.151 accuracy: 0.439
[24,  2600] loss: 0.150 accuracy: 0.443



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[25,   200] loss: 0.263 accuracy: 0.797
[25,   400] loss: 0.146 accuracy: 0.454
[25,   600] loss: 0.149 accuracy: 0.441
[25,   800] loss: 0.149 accuracy: 0.444
[25,  1000] loss: 0.149 accuracy: 0.440
[25,  1200] loss: 0.153 accuracy: 0.427
[25,  1400] loss: 0.148 accuracy: 0.449
[25,  1600] loss: 0.146 accuracy: 0.466
[25,  1800] loss: 0.150 accuracy: 0.436
[25,  2000] loss: 0.150 accuracy: 0.442
[25,  2200] loss: 0.147 accuracy: 0.450
[25,  2400] loss: 0.148 accuracy: 0.441
[25,  2600] loss: 0.147 accuracy: 0.454



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[26,   200] loss: 0.262 accuracy: 0.801
[26,   400] loss: 0.149 accuracy: 0.450
[26,   600] loss: 0.149 accuracy: 0.448
[26,   800] loss: 0.149 accuracy: 0.444
[26,  1000] loss: 0.146 accuracy: 0.455
[26,  1200] loss: 0.149 accuracy: 0.440
[26,  1400] loss: 0.146 accuracy: 0.453
[26,  1600] loss: 0.148 accuracy: 0.449
[26,  1800] loss: 0.147 accuracy: 0.457
[26,  2000] loss: 0.147 accuracy: 0.455
[26,  2200] loss: 0.146 accuracy: 0.456
[26,  2400] loss: 0.150 accuracy: 0.440
[26,  2600] loss: 0.146 accuracy: 0.458



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[27,   200] loss: 0.259 accuracy: 0.803
[27,   400] loss: 0.145 accuracy: 0.463
[27,   600] loss: 0.146 accuracy: 0.460
[27,   800] loss: 0.146 accuracy: 0.454
[27,  1000] loss: 0.148 accuracy: 0.442
[27,  1200] loss: 0.146 accuracy: 0.456
[27,  1400] loss: 0.145 accuracy: 0.458
[27,  1600] loss: 0.149 accuracy: 0.444
[27,  1800] loss: 0.146 accuracy: 0.463
[27,  2000] loss: 0.148 accuracy: 0.447
[27,  2200] loss: 0.146 accuracy: 0.458
[27,  2400] loss: 0.147 accuracy: 0.454
[27,  2600] loss: 0.146 accuracy: 0.461



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[28,   200] loss: 0.256 accuracy: 0.824
[28,   400] loss: 0.145 accuracy: 0.457
[28,   600] loss: 0.144 accuracy: 0.463
[28,   800] loss: 0.144 accuracy: 0.461
[28,  1000] loss: 0.142 accuracy: 0.472
[28,  1200] loss: 0.147 accuracy: 0.450
[28,  1400] loss: 0.145 accuracy: 0.464
[28,  1600] loss: 0.148 accuracy: 0.449
[28,  1800] loss: 0.146 accuracy: 0.464
[28,  2000] loss: 0.144 accuracy: 0.466
[28,  2200] loss: 0.145 accuracy: 0.465
[28,  2400] loss: 0.146 accuracy: 0.452
[28,  2600] loss: 0.146 accuracy: 0.460



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[29,   200] loss: 0.255 accuracy: 0.836
[29,   400] loss: 0.142 accuracy: 0.475
[29,   600] loss: 0.144 accuracy: 0.458
[29,   800] loss: 0.142 accuracy: 0.473
[29,  1000] loss: 0.143 accuracy: 0.472
[29,  1200] loss: 0.144 accuracy: 0.467
[29,  1400] loss: 0.145 accuracy: 0.464
[29,  1600] loss: 0.147 accuracy: 0.451
[29,  1800] loss: 0.143 accuracy: 0.467
[29,  2000] loss: 0.145 accuracy: 0.461
[29,  2200] loss: 0.146 accuracy: 0.449
[29,  2400] loss: 0.146 accuracy: 0.449
[29,  2600] loss: 0.145 accuracy: 0.465



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[30,   200] loss: 0.252 accuracy: 0.840
[30,   400] loss: 0.143 accuracy: 0.476
[30,   600] loss: 0.144 accuracy: 0.463
[30,   800] loss: 0.143 accuracy: 0.463
[30,  1000] loss: 0.141 accuracy: 0.477
[30,  1200] loss: 0.142 accuracy: 0.474
[30,  1400] loss: 0.144 accuracy: 0.459
[30,  1600] loss: 0.143 accuracy: 0.469
[30,  1800] loss: 0.145 accuracy: 0.460
[30,  2000] loss: 0.142 accuracy: 0.471
[30,  2200] loss: 0.142 accuracy: 0.472
[30,  2400] loss: 0.146 accuracy: 0.456
[30,  2600] loss: 0.141 accuracy: 0.471



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[31,   200] loss: 0.252 accuracy: 0.840
[31,   400] loss: 0.142 accuracy: 0.469
[31,   600] loss: 0.141 accuracy: 0.473
[31,   800] loss: 0.144 accuracy: 0.468
[31,  1000] loss: 0.141 accuracy: 0.473
[31,  1200] loss: 0.144 accuracy: 0.469
[31,  1400] loss: 0.140 accuracy: 0.477
[31,  1600] loss: 0.141 accuracy: 0.472
[31,  1800] loss: 0.144 accuracy: 0.464
[31,  2000] loss: 0.140 accuracy: 0.482
[31,  2200] loss: 0.143 accuracy: 0.472
[31,  2400] loss: 0.145 accuracy: 0.470
[31,  2600] loss: 0.143 accuracy: 0.467



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[32,   200] loss: 0.251 accuracy: 0.843
[32,   400] loss: 0.139 accuracy: 0.479
[32,   600] loss: 0.141 accuracy: 0.477
[32,   800] loss: 0.143 accuracy: 0.468
[32,  1000] loss: 0.141 accuracy: 0.477
[32,  1200] loss: 0.142 accuracy: 0.470
[32,  1400] loss: 0.144 accuracy: 0.463
[32,  1600] loss: 0.142 accuracy: 0.475
[32,  1800] loss: 0.141 accuracy: 0.478
[32,  2000] loss: 0.141 accuracy: 0.472
[32,  2200] loss: 0.141 accuracy: 0.470
[32,  2400] loss: 0.141 accuracy: 0.478
[32,  2600] loss: 0.139 accuracy: 0.479



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[33,   200] loss: 0.251 accuracy: 0.841
[33,   400] loss: 0.138 accuracy: 0.486
[33,   600] loss: 0.139 accuracy: 0.485
[33,   800] loss: 0.141 accuracy: 0.482
[33,  1000] loss: 0.139 accuracy: 0.481
[33,  1200] loss: 0.139 accuracy: 0.491
[33,  1400] loss: 0.139 accuracy: 0.478
[33,  1600] loss: 0.140 accuracy: 0.469
[33,  1800] loss: 0.140 accuracy: 0.482
[33,  2000] loss: 0.140 accuracy: 0.485
[33,  2200] loss: 0.140 accuracy: 0.480
[33,  2400] loss: 0.143 accuracy: 0.462
[33,  2600] loss: 0.140 accuracy: 0.475



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[34,   200] loss: 0.246 accuracy: 0.856
[34,   400] loss: 0.141 accuracy: 0.476
[34,   600] loss: 0.141 accuracy: 0.476
[34,   800] loss: 0.140 accuracy: 0.474
[34,  1000] loss: 0.139 accuracy: 0.492
[34,  1200] loss: 0.137 accuracy: 0.492
[34,  1400] loss: 0.139 accuracy: 0.486
[34,  1600] loss: 0.137 accuracy: 0.484
[34,  1800] loss: 0.141 accuracy: 0.482
[34,  2000] loss: 0.139 accuracy: 0.482
[34,  2200] loss: 0.139 accuracy: 0.487
[34,  2400] loss: 0.140 accuracy: 0.480
[34,  2600] loss: 0.139 accuracy: 0.485



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[35,   200] loss: 0.247 accuracy: 0.856
[35,   400] loss: 0.137 accuracy: 0.490
[35,   600] loss: 0.137 accuracy: 0.488
[35,   800] loss: 0.138 accuracy: 0.478
[35,  1000] loss: 0.137 accuracy: 0.495
[35,  1200] loss: 0.140 accuracy: 0.477
[35,  1400] loss: 0.138 accuracy: 0.489
[35,  1600] loss: 0.138 accuracy: 0.488
[35,  1800] loss: 0.139 accuracy: 0.488
[35,  2000] loss: 0.139 accuracy: 0.483
[35,  2200] loss: 0.137 accuracy: 0.492
[35,  2400] loss: 0.140 accuracy: 0.477
[35,  2600] loss: 0.137 accuracy: 0.484



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[36,   200] loss: 0.246 accuracy: 0.873
[36,   400] loss: 0.138 accuracy: 0.490
[36,   600] loss: 0.137 accuracy: 0.497
[36,   800] loss: 0.139 accuracy: 0.491
[36,  1000] loss: 0.137 accuracy: 0.489
[36,  1200] loss: 0.136 accuracy: 0.496
[36,  1400] loss: 0.137 accuracy: 0.496
[36,  1600] loss: 0.137 accuracy: 0.490
[36,  1800] loss: 0.140 accuracy: 0.481
[36,  2000] loss: 0.136 accuracy: 0.496
[36,  2200] loss: 0.136 accuracy: 0.502
[36,  2400] loss: 0.137 accuracy: 0.492
[36,  2600] loss: 0.137 accuracy: 0.491



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[37,   200] loss: 0.241 accuracy: 0.871
[37,   400] loss: 0.135 accuracy: 0.497
[37,   600] loss: 0.134 accuracy: 0.500
[37,   800] loss: 0.137 accuracy: 0.494
[37,  1000] loss: 0.138 accuracy: 0.487
[37,  1200] loss: 0.138 accuracy: 0.480
[37,  1400] loss: 0.135 accuracy: 0.503
[37,  1600] loss: 0.136 accuracy: 0.495
[37,  1800] loss: 0.136 accuracy: 0.494
[37,  2000] loss: 0.136 accuracy: 0.494
[37,  2200] loss: 0.135 accuracy: 0.499
[37,  2400] loss: 0.137 accuracy: 0.493
[37,  2600] loss: 0.136 accuracy: 0.490



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[38,   200] loss: 0.240 accuracy: 0.883
[38,   400] loss: 0.137 accuracy: 0.494
[38,   600] loss: 0.137 accuracy: 0.493
[38,   800] loss: 0.137 accuracy: 0.489
[38,  1000] loss: 0.137 accuracy: 0.496
[38,  1200] loss: 0.135 accuracy: 0.499
[38,  1400] loss: 0.135 accuracy: 0.505
[38,  1600] loss: 0.134 accuracy: 0.507
[38,  1800] loss: 0.138 accuracy: 0.493
[38,  2000] loss: 0.136 accuracy: 0.495
[38,  2200] loss: 0.137 accuracy: 0.496
[38,  2400] loss: 0.135 accuracy: 0.494
[38,  2600] loss: 0.137 accuracy: 0.497



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[39,   200] loss: 0.242 accuracy: 0.891
[39,   400] loss: 0.133 accuracy: 0.507
[39,   600] loss: 0.134 accuracy: 0.498
[39,   800] loss: 0.137 accuracy: 0.492
[39,  1000] loss: 0.135 accuracy: 0.503
[39,  1200] loss: 0.137 accuracy: 0.486
[39,  1400] loss: 0.133 accuracy: 0.504
[39,  1600] loss: 0.135 accuracy: 0.499
[39,  1800] loss: 0.136 accuracy: 0.498
[39,  2000] loss: 0.134 accuracy: 0.506
[39,  2200] loss: 0.135 accuracy: 0.495
[39,  2400] loss: 0.132 accuracy: 0.504
[39,  2600] loss: 0.134 accuracy: 0.502



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[40,   200] loss: 0.242 accuracy: 0.882
[40,   400] loss: 0.134 accuracy: 0.508
[40,   600] loss: 0.133 accuracy: 0.508
[40,   800] loss: 0.135 accuracy: 0.506
[40,  1000] loss: 0.133 accuracy: 0.506
[40,  1200] loss: 0.136 accuracy: 0.495
[40,  1400] loss: 0.134 accuracy: 0.501
[40,  1600] loss: 0.132 accuracy: 0.511
[40,  1800] loss: 0.131 accuracy: 0.517
[40,  2000] loss: 0.134 accuracy: 0.499
[40,  2200] loss: 0.133 accuracy: 0.508
[40,  2400] loss: 0.135 accuracy: 0.505
[40,  2600] loss: 0.134 accuracy: 0.500



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[41,   200] loss: 0.235 accuracy: 0.905
[41,   400] loss: 0.134 accuracy: 0.512
[41,   600] loss: 0.131 accuracy: 0.516
[41,   800] loss: 0.135 accuracy: 0.502
[41,  1000] loss: 0.134 accuracy: 0.507
[41,  1200] loss: 0.130 accuracy: 0.522
[41,  1400] loss: 0.135 accuracy: 0.502
[41,  1600] loss: 0.135 accuracy: 0.495
[41,  1800] loss: 0.131 accuracy: 0.519
[41,  2000] loss: 0.134 accuracy: 0.499
[41,  2200] loss: 0.135 accuracy: 0.507
[41,  2400] loss: 0.133 accuracy: 0.512
[41,  2600] loss: 0.133 accuracy: 0.510



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[42,   200] loss: 0.234 accuracy: 0.917
[42,   400] loss: 0.132 accuracy: 0.522
[42,   600] loss: 0.132 accuracy: 0.511
[42,   800] loss: 0.130 accuracy: 0.516
[42,  1000] loss: 0.132 accuracy: 0.509
[42,  1200] loss: 0.135 accuracy: 0.508
[42,  1400] loss: 0.133 accuracy: 0.508
[42,  1600] loss: 0.135 accuracy: 0.501
[42,  1800] loss: 0.133 accuracy: 0.507
[42,  2000] loss: 0.133 accuracy: 0.512
[42,  2200] loss: 0.133 accuracy: 0.502
[42,  2400] loss: 0.131 accuracy: 0.514
[42,  2600] loss: 0.132 accuracy: 0.511



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[43,   200] loss: 0.232 accuracy: 0.919
[43,   400] loss: 0.131 accuracy: 0.512
[43,   600] loss: 0.132 accuracy: 0.508
[43,   800] loss: 0.131 accuracy: 0.511
[43,  1000] loss: 0.131 accuracy: 0.510
[43,  1200] loss: 0.133 accuracy: 0.504
[43,  1400] loss: 0.132 accuracy: 0.514
[43,  1600] loss: 0.133 accuracy: 0.510
[43,  1800] loss: 0.133 accuracy: 0.502
[43,  2000] loss: 0.132 accuracy: 0.513
[43,  2200] loss: 0.132 accuracy: 0.518
[43,  2400] loss: 0.132 accuracy: 0.512
[43,  2600] loss: 0.131 accuracy: 0.516



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[44,   200] loss: 0.233 accuracy: 0.928
[44,   400] loss: 0.130 accuracy: 0.515
[44,   600] loss: 0.131 accuracy: 0.510
[44,   800] loss: 0.130 accuracy: 0.524
[44,  1000] loss: 0.131 accuracy: 0.520
[44,  1200] loss: 0.129 accuracy: 0.526
[44,  1400] loss: 0.131 accuracy: 0.521
[44,  1600] loss: 0.134 accuracy: 0.508
[44,  1800] loss: 0.132 accuracy: 0.513
[44,  2000] loss: 0.133 accuracy: 0.513
[44,  2200] loss: 0.131 accuracy: 0.513
[44,  2400] loss: 0.132 accuracy: 0.510
[44,  2600] loss: 0.129 accuracy: 0.523



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[45,   200] loss: 0.228 accuracy: 0.937
[45,   400] loss: 0.131 accuracy: 0.521
[45,   600] loss: 0.131 accuracy: 0.518
[45,   800] loss: 0.132 accuracy: 0.512
[45,  1000] loss: 0.130 accuracy: 0.521
[45,  1200] loss: 0.131 accuracy: 0.515
[45,  1400] loss: 0.129 accuracy: 0.523
[45,  1600] loss: 0.131 accuracy: 0.516
[45,  1800] loss: 0.129 accuracy: 0.526
[45,  2000] loss: 0.129 accuracy: 0.524
[45,  2200] loss: 0.129 accuracy: 0.518
[45,  2400] loss: 0.132 accuracy: 0.516
[45,  2600] loss: 0.130 accuracy: 0.515



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[46,   200] loss: 0.230 accuracy: 0.929
[46,   400] loss: 0.126 accuracy: 0.532
[46,   600] loss: 0.130 accuracy: 0.521
[46,   800] loss: 0.132 accuracy: 0.512
[46,  1000] loss: 0.129 accuracy: 0.525
[46,  1200] loss: 0.129 accuracy: 0.524
[46,  1400] loss: 0.127 accuracy: 0.537
[46,  1600] loss: 0.132 accuracy: 0.516
[46,  1800] loss: 0.130 accuracy: 0.523
[46,  2000] loss: 0.130 accuracy: 0.525
[46,  2200] loss: 0.129 accuracy: 0.534
[46,  2400] loss: 0.130 accuracy: 0.514
[46,  2600] loss: 0.129 accuracy: 0.522



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[47,   200] loss: 0.230 accuracy: 0.929
[47,   400] loss: 0.127 accuracy: 0.535
[47,   600] loss: 0.129 accuracy: 0.535
[47,   800] loss: 0.129 accuracy: 0.519
[47,  1000] loss: 0.127 accuracy: 0.532
[47,  1200] loss: 0.131 accuracy: 0.521
[47,  1400] loss: 0.128 accuracy: 0.527
[47,  1600] loss: 0.130 accuracy: 0.516
[47,  1800] loss: 0.129 accuracy: 0.523
[47,  2000] loss: 0.127 accuracy: 0.535
[47,  2200] loss: 0.130 accuracy: 0.526
[47,  2400] loss: 0.129 accuracy: 0.527
[47,  2600] loss: 0.127 accuracy: 0.530



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[48,   200] loss: 0.224 accuracy: 0.948
[48,   400] loss: 0.126 accuracy: 0.541
[48,   600] loss: 0.128 accuracy: 0.527
[48,   800] loss: 0.127 accuracy: 0.539
[48,  1000] loss: 0.127 accuracy: 0.537
[48,  1200] loss: 0.130 accuracy: 0.528
[48,  1400] loss: 0.128 accuracy: 0.524
[48,  1600] loss: 0.129 accuracy: 0.533
[48,  1800] loss: 0.130 accuracy: 0.519
[48,  2000] loss: 0.128 accuracy: 0.533
[48,  2200] loss: 0.129 accuracy: 0.518
[48,  2400] loss: 0.131 accuracy: 0.515
[48,  2600] loss: 0.129 accuracy: 0.525



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[49,   200] loss: 0.226 accuracy: 0.948
[49,   400] loss: 0.127 accuracy: 0.534
[49,   600] loss: 0.125 accuracy: 0.539
[49,   800] loss: 0.128 accuracy: 0.534
[49,  1000] loss: 0.129 accuracy: 0.525
[49,  1200] loss: 0.127 accuracy: 0.534
[49,  1400] loss: 0.128 accuracy: 0.527
[49,  1600] loss: 0.128 accuracy: 0.527
[49,  1800] loss: 0.128 accuracy: 0.534
[49,  2000] loss: 0.126 accuracy: 0.540
[49,  2200] loss: 0.130 accuracy: 0.523
[49,  2400] loss: 0.130 accuracy: 0.517
[49,  2600] loss: 0.127 accuracy: 0.529



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[50,   200] loss: 0.228 accuracy: 0.926
[50,   400] loss: 0.127 accuracy: 0.535
[50,   600] loss: 0.125 accuracy: 0.537
[50,   800] loss: 0.127 accuracy: 0.528
[50,  1000] loss: 0.126 accuracy: 0.539
[50,  1200] loss: 0.126 accuracy: 0.537
[50,  1400] loss: 0.126 accuracy: 0.535
[50,  1600] loss: 0.127 accuracy: 0.531
[50,  1800] loss: 0.126 accuracy: 0.540
[50,  2000] loss: 0.127 accuracy: 0.526
[50,  2200] loss: 0.127 accuracy: 0.531
[50,  2400] loss: 0.130 accuracy: 0.527
[50,  2600] loss: 0.127 accuracy: 0.540



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[51,   200] loss: 0.223 accuracy: 0.950
[51,   400] loss: 0.125 accuracy: 0.545
[51,   600] loss: 0.128 accuracy: 0.528
[51,   800] loss: 0.127 accuracy: 0.528
[51,  1000] loss: 0.124 accuracy: 0.549
[51,  1200] loss: 0.127 accuracy: 0.531
[51,  1400] loss: 0.125 accuracy: 0.543
[51,  1600] loss: 0.127 accuracy: 0.531
[51,  1800] loss: 0.125 accuracy: 0.543
[51,  2000] loss: 0.125 accuracy: 0.539
[51,  2200] loss: 0.126 accuracy: 0.540
[51,  2400] loss: 0.128 accuracy: 0.536
[51,  2600] loss: 0.128 accuracy: 0.527



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[52,   200] loss: 0.223 accuracy: 0.958
[52,   400] loss: 0.124 accuracy: 0.546
[52,   600] loss: 0.121 accuracy: 0.560
[52,   800] loss: 0.127 accuracy: 0.549
[52,  1000] loss: 0.124 accuracy: 0.547
[52,  1200] loss: 0.126 accuracy: 0.537
[52,  1400] loss: 0.126 accuracy: 0.538
[52,  1600] loss: 0.126 accuracy: 0.540
[52,  1800] loss: 0.126 accuracy: 0.535
[52,  2000] loss: 0.124 accuracy: 0.543
[52,  2200] loss: 0.126 accuracy: 0.540
[52,  2400] loss: 0.125 accuracy: 0.552
[52,  2600] loss: 0.125 accuracy: 0.544



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[53,   200] loss: 0.222 accuracy: 0.957
[53,   400] loss: 0.123 accuracy: 0.547
[53,   600] loss: 0.125 accuracy: 0.540
[53,   800] loss: 0.125 accuracy: 0.543
[53,  1000] loss: 0.124 accuracy: 0.543
[53,  1200] loss: 0.125 accuracy: 0.536
[53,  1400] loss: 0.123 accuracy: 0.549
[53,  1600] loss: 0.124 accuracy: 0.540
[53,  1800] loss: 0.125 accuracy: 0.545
[53,  2000] loss: 0.124 accuracy: 0.555
[53,  2200] loss: 0.126 accuracy: 0.537
[53,  2400] loss: 0.125 accuracy: 0.537
[53,  2600] loss: 0.125 accuracy: 0.539



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[54,   200] loss: 0.220 accuracy: 0.975
[54,   400] loss: 0.124 accuracy: 0.538
[54,   600] loss: 0.122 accuracy: 0.553
[54,   800] loss: 0.125 accuracy: 0.544
[54,  1000] loss: 0.126 accuracy: 0.540
[54,  1200] loss: 0.125 accuracy: 0.545
[54,  1400] loss: 0.123 accuracy: 0.550
[54,  1600] loss: 0.125 accuracy: 0.538
[54,  1800] loss: 0.126 accuracy: 0.536
[54,  2000] loss: 0.125 accuracy: 0.541
[54,  2200] loss: 0.122 accuracy: 0.558
[54,  2400] loss: 0.123 accuracy: 0.546
[54,  2600] loss: 0.122 accuracy: 0.553



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[55,   200] loss: 0.220 accuracy: 0.970
[55,   400] loss: 0.123 accuracy: 0.548
[55,   600] loss: 0.122 accuracy: 0.551
[55,   800] loss: 0.124 accuracy: 0.551
[55,  1000] loss: 0.120 accuracy: 0.560
[55,  1200] loss: 0.123 accuracy: 0.547
[55,  1400] loss: 0.122 accuracy: 0.546
[55,  1600] loss: 0.126 accuracy: 0.534
[55,  1800] loss: 0.123 accuracy: 0.544
[55,  2000] loss: 0.125 accuracy: 0.539
[55,  2200] loss: 0.123 accuracy: 0.552
[55,  2400] loss: 0.124 accuracy: 0.546
[55,  2600] loss: 0.123 accuracy: 0.547



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[56,   200] loss: 0.215 accuracy: 0.989
[56,   400] loss: 0.121 accuracy: 0.558
[56,   600] loss: 0.122 accuracy: 0.552
[56,   800] loss: 0.121 accuracy: 0.563
[56,  1000] loss: 0.124 accuracy: 0.546
[56,  1200] loss: 0.124 accuracy: 0.542
[56,  1400] loss: 0.122 accuracy: 0.555
[56,  1600] loss: 0.125 accuracy: 0.540
[56,  1800] loss: 0.124 accuracy: 0.541
[56,  2000] loss: 0.123 accuracy: 0.545
[56,  2200] loss: 0.122 accuracy: 0.553
[56,  2400] loss: 0.123 accuracy: 0.547
[56,  2600] loss: 0.122 accuracy: 0.553



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[57,   200] loss: 0.215 accuracy: 0.979
[57,   400] loss: 0.121 accuracy: 0.561
[57,   600] loss: 0.121 accuracy: 0.555
[57,   800] loss: 0.121 accuracy: 0.561
[57,  1000] loss: 0.121 accuracy: 0.556
[57,  1200] loss: 0.123 accuracy: 0.551
[57,  1400] loss: 0.121 accuracy: 0.553
[57,  1600] loss: 0.122 accuracy: 0.555
[57,  1800] loss: 0.123 accuracy: 0.554
[57,  2000] loss: 0.120 accuracy: 0.568
[57,  2200] loss: 0.121 accuracy: 0.558
[57,  2400] loss: 0.121 accuracy: 0.562
[57,  2600] loss: 0.119 accuracy: 0.566



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[58,   200] loss: 0.216 accuracy: 0.988
[58,   400] loss: 0.121 accuracy: 0.559
[58,   600] loss: 0.121 accuracy: 0.563
[58,   800] loss: 0.123 accuracy: 0.553
[58,  1000] loss: 0.120 accuracy: 0.555
[58,  1200] loss: 0.123 accuracy: 0.551
[58,  1400] loss: 0.121 accuracy: 0.559
[58,  1600] loss: 0.120 accuracy: 0.553
[58,  1800] loss: 0.121 accuracy: 0.558
[58,  2000] loss: 0.123 accuracy: 0.547
[58,  2200] loss: 0.120 accuracy: 0.564
[58,  2400] loss: 0.120 accuracy: 0.559
[58,  2600] loss: 0.121 accuracy: 0.560



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[59,   200] loss: 0.213 accuracy: 1.000
[59,   400] loss: 0.121 accuracy: 0.562
[59,   600] loss: 0.118 accuracy: 0.569
[59,   800] loss: 0.123 accuracy: 0.554
[59,  1000] loss: 0.121 accuracy: 0.553
[59,  1200] loss: 0.122 accuracy: 0.554
[59,  1400] loss: 0.119 accuracy: 0.566
[59,  1600] loss: 0.121 accuracy: 0.561
[59,  1800] loss: 0.121 accuracy: 0.557
[59,  2000] loss: 0.120 accuracy: 0.555
[59,  2200] loss: 0.120 accuracy: 0.559
[59,  2400] loss: 0.121 accuracy: 0.551
[59,  2600] loss: 0.121 accuracy: 0.560



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[60,   200] loss: 0.212 accuracy: 0.999
[60,   400] loss: 0.120 accuracy: 0.562
[60,   600] loss: 0.120 accuracy: 0.562
[60,   800] loss: 0.123 accuracy: 0.545
[60,  1000] loss: 0.121 accuracy: 0.562
[60,  1200] loss: 0.121 accuracy: 0.552
[60,  1400] loss: 0.121 accuracy: 0.560
[60,  1600] loss: 0.123 accuracy: 0.557
[60,  1800] loss: 0.123 accuracy: 0.551
[60,  2000] loss: 0.118 accuracy: 0.575
[60,  2200] loss: 0.121 accuracy: 0.559
[60,  2400] loss: 0.120 accuracy: 0.557
[60,  2600] loss: 0.120 accuracy: 0.561



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[61,   200] loss: 0.214 accuracy: 1.001
[61,   400] loss: 0.122 accuracy: 0.552
[61,   600] loss: 0.120 accuracy: 0.561
[61,   800] loss: 0.122 accuracy: 0.556
[61,  1000] loss: 0.118 accuracy: 0.568
[61,  1200] loss: 0.121 accuracy: 0.559
[61,  1400] loss: 0.119 accuracy: 0.563
[61,  1600] loss: 0.117 accuracy: 0.573
[61,  1800] loss: 0.118 accuracy: 0.570
[61,  2000] loss: 0.120 accuracy: 0.559
[61,  2200] loss: 0.119 accuracy: 0.573
[61,  2400] loss: 0.119 accuracy: 0.565
[61,  2600] loss: 0.120 accuracy: 0.558



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[62,   200] loss: 0.216 accuracy: 0.996
[62,   400] loss: 0.117 accuracy: 0.575
[62,   600] loss: 0.122 accuracy: 0.563
[62,   800] loss: 0.121 accuracy: 0.564
[62,  1000] loss: 0.117 accuracy: 0.577
[62,  1200] loss: 0.119 accuracy: 0.564
[62,  1400] loss: 0.117 accuracy: 0.572
[62,  1600] loss: 0.120 accuracy: 0.560
[62,  1800] loss: 0.117 accuracy: 0.571
[62,  2000] loss: 0.120 accuracy: 0.559
[62,  2200] loss: 0.121 accuracy: 0.553
[62,  2400] loss: 0.119 accuracy: 0.571
[62,  2600] loss: 0.121 accuracy: 0.553



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[63,   200] loss: 0.211 accuracy: 1.012
[63,   400] loss: 0.117 accuracy: 0.567
[63,   600] loss: 0.119 accuracy: 0.555
[63,   800] loss: 0.117 accuracy: 0.576
[63,  1000] loss: 0.119 accuracy: 0.568
[63,  1200] loss: 0.118 accuracy: 0.575
[63,  1400] loss: 0.118 accuracy: 0.566
[63,  1600] loss: 0.117 accuracy: 0.574
[63,  1800] loss: 0.119 accuracy: 0.571
[63,  2000] loss: 0.118 accuracy: 0.570
[63,  2200] loss: 0.120 accuracy: 0.563
[63,  2400] loss: 0.117 accuracy: 0.573
[63,  2600] loss: 0.119 accuracy: 0.569



HBox(children=(FloatProgress(value=0.0, max=2755.0), HTML(value='')))

[64,   200] loss: 0.211 accuracy: 0.998
[64,   400] loss: 0.118 accuracy: 0.575
[64,   600] loss: 0.117 accuracy: 0.568
[64,   800] loss: 0.119 accuracy: 0.563
[64,  1000] loss: 0.119 accuracy: 0.567
[64,  1200] loss: 0.118 accuracy: 0.570
[64,  1400] loss: 0.118 accuracy: 0.572
[64,  1600] loss: 0.119 accuracy: 0.568
[64,  1800] loss: 0.118 accuracy: 0.569
[64,  2000] loss: 0.117 accuracy: 0.573
[64,  2200] loss: 0.117 accuracy: 0.573
[64,  2400] loss: 0.119 accuracy: 0.566
[64,  2600] loss: 0.117 accuracy: 0.577



In [None]:
# Batched, shuffled loader stored under the "validation" name.
# NOTE(review): this wraps `chunkset_train`, not a separate validation set, and
# shuffles it -- confirm that is intended before trusting validation numbers.
chunkloader_val = torch.utils.data.DataLoader(
    chunkset_train,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=16,
)


In [52]:
def create_batches_rnd(batch_size,data_folder,wav_lst,N_snt,wlen,lab_dict,fact_amp):
    """Build one random mini-batch of fixed-length waveform chunks.

    For every batch item a file index is drawn uniformly (with replacement)
    from ``[0, N_snt)``, the file ``data_folder + wav_lst[index]`` is read,
    divided by 32768 (assumes 16-bit PCM samples -- TODO confirm), and a
    window of ``wlen`` samples starting at a random offset is copied into the
    batch after multiplication by a random gain drawn uniformly from
    ``[1 - fact_amp, 1 + fact_amp)``.

    Recordings that are too short to offer a random offset
    (``len(signal) <= wlen + 1``) previously raised ``ValueError`` inside
    ``np.random.randint``; they are now taken from sample 0 and the missing
    tail is left at zero.

    Args:
        batch_size: number of chunks in the batch.
        data_folder: string prefix prepended to each file name.
        wav_lst: sequence of wav file names.
        N_snt: exclusive upper bound of the file indices to sample from.
        wlen: chunk length in samples.
        lab_dict: mapping from file name to its numeric label.
        fact_amp: half-width of the random gain interval around 1.0.

    Returns:
        ``(inp, lab)``: float32 tensors of shape ``(batch_size, wlen)`` and
        ``(batch_size,)``, both moved with ``.cuda()`` (default CUDA device).
    """
    # Local import: the notebook's import cell does not import scipy
    # explicitly, so the original body depended on it leaking in from elsewhere.
    import scipy.io.wavfile

    # Zero-initialised so that short recordings end up zero-padded.
    sig_batch = np.zeros([batch_size, wlen])
    lab_batch = np.zeros(batch_size)

    # Draw order (file ids, gains, then one offset per item) is kept as in the
    # original so seeded runs reproduce the same batches.
    snt_id_arr = np.random.randint(N_snt, size=batch_size)
    rand_amp_arr = np.random.uniform(1.0 - fact_amp, 1 + fact_amp, batch_size)

    for i in range(batch_size):
        wav_name = wav_lst[snt_id_arr[i]]

        # The sampling rate returned by the reader is not used.
        _, signal = scipy.io.wavfile.read(data_folder + wav_name)
        signal = signal.astype(float) / 32768

        # Pick a random chunk; fall back to the start of the file when there
        # is no room for a random offset.
        snt_len = signal.shape[0]
        n_offsets = snt_len - wlen - 1
        snt_beg = np.random.randint(n_offsets) if n_offsets > 0 else 0
        chunk = signal[snt_beg:snt_beg + wlen]

        sig_batch[i, :chunk.shape[0]] = chunk * rand_amp_arr[i]
        lab_batch[i] = lab_dict[wav_name]

    inp = torch.from_numpy(sig_batch).float().cuda().contiguous()
    lab = torch.from_numpy(lab_batch).float().cuda().contiguous()

    return inp, lab

In [None]:
# Full validation pass (fragment).
# NOTE(review): this cell starts indented and reads names it never defines
# (`epoch`, `N_eval_epoch`, `loss_tot`, `err_tot`, `snt_te`, `wav_lst_te`,
# `cost`, `wlen`, `wshift`, `Batch_dev`, `class_lay`, `output_folder`, ...).
# Presumably it was pasted from inside a per-epoch training loop -- confirm;
# it cannot run as a standalone cell.
  # Evaluate only on epochs that are a multiple of N_eval_epoch.
  if epoch%N_eval_epoch==0:
      
   # Switch the three chained networks to evaluation mode.
   CNN_net.eval()
   DNN1_net.eval()
   DNN2_net.eval()
   # test_flag is assigned here but not read anywhere in this fragment.
   test_flag=1 
   # Accumulators over all test sentences: loss, frame-level error,
   # sentence-level error.
   loss_sum=0
   err_sum=0
   err_sum_snt=0
   
   # Inference only: gradient tracking is disabled for the whole pass.
   with torch.no_grad():  
    for i in range(snt_te):
       
     # Previous scipy-based loader, kept commented out:
     #[fs,signal]=scipy.io.wavfile.read(data_folder+wav_lst_te[i])
     #signal=signal.astype(float)/32768

     # NOTE(review): `sf` is presumably the `soundfile` module; it is not
     # imported in the notebook's visible import cell -- confirm.
     [signal, fs] = sf.read(data_folder+wav_lst_te[i])

     signal=torch.from_numpy(signal).float().cuda().contiguous()
     # Label of the current sentence (one value for the whole file).
     lab_batch=lab_dict[wav_lst_te[i]]
    
     # split signals into chunks
     beg_samp=0
     end_samp=wlen
     
     # Number of whole hops of size wshift that fit after the first window.
     N_fr=int((signal.shape[0]-wlen)/(wshift))
     

     # sig_arr buffers up to Batch_dev windows; lab repeats the sentence label
     # N_fr+1 times; pout holds one output row (class_lay[-1] wide) per window.
     sig_arr=torch.zeros([Batch_dev,wlen]).float().cuda().contiguous()
     lab= Variable((torch.zeros(N_fr+1)+lab_batch).cuda().contiguous().long())
     pout=Variable(torch.zeros(N_fr+1,class_lay[-1]).float().cuda().contiguous())
     # count_fr: windows currently in the buffer; count_fr_tot: windows so far.
     count_fr=0
     count_fr_tot=0
     # Slide a wlen-sample window over the signal in steps of wshift.
     # NOTE(review): because the condition is strict (<), when
     # signal.shape[0]-wlen is an exact multiple of wshift only N_fr windows
     # are produced, so the last of the N_fr+1 rows of pout stays zero yet is
     # still scored below -- confirm this is intended.
     while end_samp<signal.shape[0]:
         sig_arr[count_fr,:]=signal[beg_samp:end_samp]
         beg_samp=beg_samp+wshift
         end_samp=beg_samp+wlen
         count_fr=count_fr+1
         count_fr_tot=count_fr_tot+1
         # Buffer full: run it through CNN -> DNN1 -> DNN2, store the rows,
         # then start a fresh buffer.
         if count_fr==Batch_dev:
             inp=Variable(sig_arr)
             pout[count_fr_tot-Batch_dev:count_fr_tot,:]=DNN2_net(DNN1_net(CNN_net(inp)))
             count_fr=0
             sig_arr=torch.zeros([Batch_dev,wlen]).float().cuda().contiguous()
   
     # Process the windows left over in a partially filled buffer.
     if count_fr>0:
      inp=Variable(sig_arr[0:count_fr])
      pout[count_fr_tot-count_fr:count_fr_tot,:]=DNN2_net(DNN1_net(CNN_net(inp)))

    
     # Frame-level decision (argmax per row), loss and error rate.
     pred=torch.max(pout,dim=1)[1]
     loss = cost(pout, lab.long())
     err = torch.mean((pred!=lab.long()).float())
    
     # Sentence-level decision: sum the outputs over all rows, take the
     # arg-max class and compare it with the sentence label.
     [val,best_class]=torch.max(torch.sum(pout,dim=0),0)
     err_sum_snt=err_sum_snt+(best_class!=lab[0]).float()
    
    
     loss_sum=loss_sum+loss.detach()
     err_sum=err_sum+err.detach()
    
    # Average the accumulated metrics over the snt_te test sentences.
    err_tot_dev_snt=err_sum_snt/snt_te
    loss_tot_dev=loss_sum/snt_te
    err_tot_dev=err_sum/snt_te

  
   # loss_tot / err_tot are training metrics computed outside this fragment.
   print("epoch %i, loss_tr=%f err_tr=%f loss_te=%f err_te=%f err_te_snt=%f" % (epoch, loss_tot,err_tot,loss_tot_dev,err_tot_dev,err_tot_dev_snt))
  
   # Append the same summary line to the results file.
   with open(output_folder+"/res.res", "a") as res_file:
    res_file.write("epoch %i, loss_tr=%f err_tr=%f loss_te=%f err_te=%f err_te_snt=%f\n" % (epoch, loss_tot,err_tot,loss_tot_dev,err_tot_dev,err_tot_dev_snt))   

   # Save the weights of all three networks; the same path is written on every
   # evaluation epoch, so the previous checkpoint is overwritten.
   checkpoint={'CNN_model_par': CNN_net.state_dict(),
               'DNN1_model_par': DNN1_net.state_dict(),
               'DNN2_model_par': DNN2_net.state_dict(),
               }
   torch.save(checkpoint,output_folder+'/model_raw.pkl')
  
  # Non-evaluation epochs: report training metrics only.
  else:
   print("epoch %i, loss_tr=%f err_tr=%f" % (epoch, loss_tot,err_tot))