In [1]:
import os
from pathlib import Path
import sys
from tqdm import tqdm
import numpy as np
import glob
from nnabla.utils.data_source import DataSource
from nnabla.utils.data_iterator import data_iterator


class DatasetGenerator:
    r""" Splits the Dataset and Labels into batches containing n_utterances for each of the n_speakers

    """

    def __init__(self):

        self.dev_dir = "./data/LibriSpeech/dev-clean"
        self.test_dir = "./data/LibriSpeech/test-clean"
        self.label_dir = "./data/LibriSpeech/labels/"
        self.mfcc_dev_dir = "./data/LibriSpeech/mfcc-dev/"
        self.mfcc_test_dir = "./data/LibriSpeech/mfcc-test/"
        self.n_utterances = 5  # utterances per speaker in batch
        self.batch_size = 200
        self.n_timesteps = 101
        self.n_features = 40
        self.ys = np.load(self.label_dir + 's_id.npy')
        self.xs = []
        mfccs = sorted(glob.glob(self.mfcc_dev_dir + '*.npy'))
        for i in tqdm(range(len(mfccs))):
            self.xs.append((np.load(mfccs[i])).T)
        
        self.n_data = len(self.xs)

    # Batch_size =200 , each batch will contain 5 utterances from each of the 40 speakers

    def get_features(self):
        ys = np.load(self.label_dir + 's_id.npy')

        ''' 
        Get total utterances for each speaker
        '''
        inputs = {k: [] for k in sorted(set(ys))}

        for i, s_id in enumerate(ys):
            inputs[s_id].append(self.xs[i])

        return inputs

    def prep_dataset(self):    
        mfccs = self.get_features()
        n_iters = self.n_data // self.batch_size
        total_data = n_iters*self.batch_size
        dataset = np.zeros((total_data,self.n_timesteps,self.n_features))
        for sid in mfccs:
            np.random.shuffle(mfccs[sid])            
        for i in range(0, total_data, self.batch_size):            
            for sid in mfccs:
                for j in range(self.n_utterances):                    
                    dataset[i+j] = mfccs[sid][j]
                i = i+j+1            
                mfccs[sid] = np.delete(mfccs[sid], range(self.n_utterances), axis=0)
                mfccs[sid] = np.append(mfccs[sid], dataset[-self.n_utterances:], axis=0)
        return dataset

2021-01-24 22:08:17,108 [nnabla][INFO]: Initializing CPU extension...


In [2]:
dg = DatasetGenerator()
dataset = dg.prep_dataset()

100%|██████████████████████████████████████████████████████████████████████████████| 7310/7310 [01:15<00:00, 97.19it/s]


In [3]:
import os
from pathlib import Path
import sys
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import nnabla as nn
import nnabla.parametric_functions as PF
import nnabla.functions as F
import nnabla.solvers as S
from nnabla.utils.data_iterator import data_iterator_simple
import nnabla.monitor as M
import nnabla.initializer as I

In [6]:
class Encoder:

    def __init__(self,dataset):
        '''
        Model Params (to be moved to hparams.py)
        '''
        self.n_utterances = 5
        self.n_speakers = 40
        self.batch_size = 200
        self.lstm_layers = 3
        self.lstm_hidden = 256
        self.lstm_directions = 1
        self.affine_hidden = 256
        self.embed_size = 256
        self.eps = 1e-6
        self.dataset = dataset 
#         DatasetGenerator().split_dataset()

        self.sim_weight = 0.0 
        self.sim_bias = 0.0
        self.sim_matrix = 0.0 
        self.eps = 1e-5

    def get_batch(self,i):
        '''
        Divide the dataset into batches containing 'm' utterances for 'n' speakers
        '''

        batch_size = self.batch_size
        batch_data = self.dataset[i*batch_size:(i+1)*batch_size]
#         self.dataset = np.delete(self.dataset, range(batch_size))
#         self.dataset = np.concatenate((self.dataset,batch_data), axis = 0)
        # convert shape from (B,T,I)  -> (T,B,I)  ((200, 138, 13) -> (138, 200, 13))
        batch_data = np.transpose(batch_data, (1, 0, 2))

        return batch_data

    def encoder(self, inputs, training=True):
        '''
        Speaker Encoder network
        '''

        L = self.lstm_layers
        D = self.lstm_directions
        B = self.batch_size
        H = self.lstm_hidden
        T, _, I = inputs.shape
        h = nn.Variable(shape=(L, D, B, H), need_grad=True)
        c = nn.Variable(shape=(L, D, B, H), need_grad=True)
        h.d = np.zeros(h.shape)
        c.d = np.zeros(c.shape)

        y, hn, cn = PF.lstm(inputs, h, c, num_layers=L, training=training)
        
        out = PF.affine(hn[-1][0], self.affine_hidden)
        out = F.relu(out)
        out = out/(F.norm(out,axis=1).reshape((-1,1)) + self.eps)
        
#        print(out.d)

        return out

    def similarity_matrix(self,emb):
        N = self.n_speakers
        M = self.n_utterances
        P = self.embed_size
        
        emb_re = emb.reshape((N, M, P))

        # compute the inclusion centroids
        cen = F.mean(emb_re, axis=1) 
        cen = cen / F.norm(cen, axis=1, keepdims=True)
        cen = F.reshape(cen, (-1, P))

        # compute the exclusion centroids
        exc = F.sum(emb_re, axis=1, keepdims=True) - emb_re
        exc = exc / (M - 1)
        exc = exc / F.norm(exc, axis=2, keepdims=True)
        exc = F.reshape(exc, emb.shape)

        diag = F.sum(exc * emb, axis=1, keepdims=True) # 20 x 1
        sim = F.affine(emb,  F.transpose(cen, (1, 0))) 

        mask = np.concatenate([np.tile(w, (M, 1)) for w in np.eye(N)])
        mask = nn.Variable.from_numpy_array(mask)
        sm = (1 - mask) * sim + F.tile(diag, N) * mask
        #sm = self.sim_weight*sm + self.sim_bias

        sm = F.add_scalar(F.mul_scalar(sm, self.sim_weight.d), self.sim_bias.d)
        
        return sm

#     def similarity_matrix(self, embeddings):
#         """
#             Computes the similarity matrix according the section 2.1 of GE2E.
#         """
#         N = self.n_speakers
#         M = self.n_utterances
#         P = self.embed_size
#         embeddings = embeddings.reshape(
#             (self.n_speakers, self.n_utterances, self.embed_size))
       
#         center = F.mean(embeddings, axis=1)             # [N,P] normalized center vectors eq.(1)
#         center_norm = F.norm(center,2,keepdims=True)
#         center = center/center_norm
#         center_except = F.reshape(F.sum(embeddings, axis=1, keepdims=True) - embeddings, (N*M,P))  # [NM,P] center vectors eq.(8)
#         # make similarity matrix eq.(9)
                
#         sm = np.concatenate(
#             [np.concatenate([np.sum(center_except.d[i*M:(i+1)*M,:]*embeddings.d[j,:,:], axis=1, keepdims=True) if i==j
#                         else np.sum(center.d[i:(i+1),:]*embeddings.d[j,:,:], axis=1, keepdims=True) for i in range(N)],
#                        axis=1) for j in range(N)], axis=0)
#         self.sim_matrix = nn.Variable.from_numpy_array(sm)
#         self.sim_matrix = F.add_scalar(F.mul_scalar(self.sim_matrix, self.sim_weight.d), self.sim_bias.d)
#         #print(self.sim_matrix.d)
#         return nn.Variable.from_numpy_array(sm)
        
    def loss_fn(self, embeddings):
        '''
        Define a cross entropy loss between the similarity matrix and the ground truth labels
        '''
        sim_matrix = self.similarity_matrix(embeddings)
        #self.sim_matrix.data.copy_from(sm.data)
        #self.sim_matrix = self.sim_matrix.reshape(
        #    (self.n_speakers * self.n_utterances, self.n_speakers))

        # Create ground truth labels
        ground_truth = nn.Variable.from_numpy_array(
            np.repeat(np.arange(self.n_speakers), self.n_utterances))
        ground_truth = ground_truth.reshape(
            (self.n_speakers * self.n_utterances, 1))

        loss = F.softmax_cross_entropy(sim_matrix, ground_truth)
        
        return loss
          

    def train(self):
        nn.clear_parameters()
        #nn.set_auto_forward(True)
        mfcc_dim1, mfcc_dim2 = self.dataset[0].shape
        n_batch = len(self.dataset) // self.batch_size
        max_epochs = 600
        loss_scale = 8
        monitor = M.Monitor('.')
        monitor_loss = M.MonitorSeries(
            "Training loss", monitor, interval=2, verbose=True)
        monitor_time = M.MonitorTimeElapsed(
            "Training time", monitor, interval=1000, verbose=True)
        self.sim_weight = nn.parameter.get_parameter_or_create(
            'sim_weight', (1,), I.ConstantInitializer(10.0), need_grad=True)
        self.sim_bias = nn.parameter.get_parameter_or_create(
             'sim_bias', (1,), I.ConstantInitializer(-5.0), need_grad=True)
        
#         self.sim_matrix = nn.Variable.from_numpy_array(np.zeros((self.n_speakers * self.n_utterances, self.n_speakers)), 
#                                                        need_grad = True)

        solver = S.Adam()
        xi = nn.Variable((mfcc_dim1, self.batch_size, mfcc_dim2), need_grad=True)
        # Get the embeddings from the encoder
        embeddings = self.encoder(xi, training=True)
        loss = self.loss_fn(embeddings) # Define loss
        solver.set_parameters(nn.get_parameters())
        print(nn.get_parameters())

        for epoch in range(max_epochs):
            
            # Iterations per epoch

            for i in range(n_batch):

                # Returns current batch
                xi.d = self.get_batch(i)

                loss.forward()
                solver.zero_grad()
                loss.backward(clear_buffer=True)
                solver.scale_grad(1. / loss_scale)
                solver.update()

                # monitor
                itr = epoch * n_batch + i
                monitor_loss.add(itr, loss.d)
            print(f"Epoch:{epoch} | Loss:{loss.d}")

In [7]:
enc = Encoder(dataset)
enc.train()

OrderedDict([('sim_weight', <Variable((1,), need_grad=True) at 0x2996100b450>), ('sim_bias', <Variable((1,), need_grad=True) at 0x29961043810>), ('lstm/weight_l0', <Variable((1, 4, 256, 296), need_grad=True) at 0x299612fa6d0>), ('lstm/weight', <Variable((2, 1, 4, 256, 512), need_grad=True) at 0x299612faa90>), ('lstm/bias', <Variable((3, 1, 4, 256), need_grad=True) at 0x299612fabd0>), ('affine/W', <Variable((256, 256), need_grad=True) at 0x299612faae0>), ('affine/b', <Variable((256,), need_grad=True) at 0x299612fac70>)])


2021-01-24 22:11:19,776 [nnabla][INFO]: iter=1 {Training loss}=3.5063488483428955
2021-01-24 22:11:38,153 [nnabla][INFO]: iter=3 {Training loss}=3.296698570251465
2021-01-24 22:11:57,328 [nnabla][INFO]: iter=5 {Training loss}=3.2159783840179443
2021-01-24 22:12:16,760 [nnabla][INFO]: iter=7 {Training loss}=3.0217413902282715
2021-01-24 22:12:36,028 [nnabla][INFO]: iter=9 {Training loss}=2.9759716987609863
2021-01-24 22:12:55,646 [nnabla][INFO]: iter=11 {Training loss}=2.984682559967041
2021-01-24 22:13:14,974 [nnabla][INFO]: iter=13 {Training loss}=3.074453353881836
2021-01-24 22:13:33,942 [nnabla][INFO]: iter=15 {Training loss}=2.9013309478759766
2021-01-24 22:13:52,806 [nnabla][INFO]: iter=17 {Training loss}=2.781064987182617
2021-01-24 22:14:11,942 [nnabla][INFO]: iter=19 {Training loss}=2.9247536659240723
2021-01-24 22:14:30,685 [nnabla][INFO]: iter=21 {Training loss}=2.72047758102417
2021-01-24 22:14:49,615 [nnabla][INFO]: iter=23 {Training loss}=2.7683305740356445
2021-01-24 22:1

Epoch:0 | Loss:[[ 1.1343552 ]
 [ 0.93748546]
 [ 1.2255394 ]
 [ 0.9675657 ]
 [ 0.9978937 ]
 [ 2.2459593 ]
 [ 1.6836097 ]
 [ 1.6575428 ]
 [ 1.668905  ]
 [ 1.6626154 ]
 [ 3.1590657 ]
 [ 2.4845529 ]
 [ 2.3917468 ]
 [ 4.165096  ]
 [ 4.165096  ]
 [ 1.4441309 ]
 [ 3.8939636 ]
 [ 1.4441309 ]
 [ 2.8362255 ]
 [ 1.7490937 ]
 [ 1.5424505 ]
 [ 1.4881225 ]
 [ 1.3491092 ]
 [ 1.3916821 ]
 [ 1.5216861 ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 6.0893693 ]
 [ 5.1690626 ]
 [ 5.9380436 ]
 [ 2.039061  ]
 [ 2.19873   ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.838387  ]
 [ 2.1456673 ]
 [ 1.938043  ]
 [ 1.7798697 ]
 [ 2.138361  ]
 [ 4.027737  ]
 [ 2.83838

2021-01-24 22:17:01,334 [nnabla][INFO]: iter=37 {Training loss}=2.7550432682037354
2021-01-24 22:17:20,302 [nnabla][INFO]: iter=39 {Training loss}=2.6989781856536865
2021-01-24 22:17:39,074 [nnabla][INFO]: iter=41 {Training loss}=2.555396795272827
2021-01-24 22:17:57,980 [nnabla][INFO]: iter=43 {Training loss}=2.630783796310425
2021-01-24 22:18:17,094 [nnabla][INFO]: iter=45 {Training loss}=2.7999613285064697
2021-01-24 22:18:36,039 [nnabla][INFO]: iter=47 {Training loss}=2.5470457077026367
2021-01-24 22:18:54,786 [nnabla][INFO]: iter=49 {Training loss}=2.7106215953826904
2021-01-24 22:19:13,417 [nnabla][INFO]: iter=51 {Training loss}=2.56636118888855
2021-01-24 22:19:32,242 [nnabla][INFO]: iter=53 {Training loss}=2.3179919719696045
2021-01-24 22:19:51,269 [nnabla][INFO]: iter=55 {Training loss}=2.56302547454834
2021-01-24 22:20:10,227 [nnabla][INFO]: iter=57 {Training loss}=2.58589243888855
2021-01-24 22:20:28,836 [nnabla][INFO]: iter=59 {Training loss}=2.513296604156494
2021-01-24 22

Epoch:1 | Loss:[[1.3675661 ]
 [1.2616485 ]
 [1.3119102 ]
 [1.2453952 ]
 [1.2799665 ]
 [4.1200924 ]
 [1.4924916 ]
 [1.6246169 ]
 [1.9535486 ]
 [1.9175243 ]
 [3.0902555 ]
 [3.1477802 ]
 [3.679193  ]
 [6.6379538 ]
 [6.6379538 ]
 [1.9281788 ]
 [1.879652  ]
 [1.9281788 ]
 [1.7137048 ]
 [2.222343  ]
 [2.2563443 ]
 [0.55323267]
 [0.90098035]
 [0.5774286 ]
 [0.87513953]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.370552  ]
 [2.2622638 ]
 [3.4437957 ]
 [3.1152973 ]
 [1.6743511 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.1080291 ]
 [2.3940406 ]
 [1.9225305 ]
 [2.8776238 ]
 [2.1240656 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [2.7761855 ]
 [1.9

2021-01-24 22:22:40,135 [nnabla][INFO]: iter=73 {Training loss}=2.5659737586975098
2021-01-24 22:22:58,565 [nnabla][INFO]: iter=75 {Training loss}=2.515228748321533
2021-01-24 22:23:17,372 [nnabla][INFO]: iter=77 {Training loss}=2.5283291339874268
2021-01-24 22:23:36,041 [nnabla][INFO]: iter=79 {Training loss}=2.498563289642334
2021-01-24 22:23:54,819 [nnabla][INFO]: iter=81 {Training loss}=2.5928497314453125
2021-01-24 22:24:13,749 [nnabla][INFO]: iter=83 {Training loss}=2.3167104721069336
2021-01-24 22:24:32,394 [nnabla][INFO]: iter=85 {Training loss}=2.5422873497009277
2021-01-24 22:24:51,036 [nnabla][INFO]: iter=87 {Training loss}=2.381408929824829
2021-01-24 22:25:09,296 [nnabla][INFO]: iter=89 {Training loss}=2.2195353507995605
2021-01-24 22:25:27,737 [nnabla][INFO]: iter=91 {Training loss}=2.301358461380005
2021-01-24 22:25:46,731 [nnabla][INFO]: iter=93 {Training loss}=2.382939338684082
2021-01-24 22:26:05,323 [nnabla][INFO]: iter=95 {Training loss}=2.276193618774414
2021-01-24

Epoch:2 | Loss:[[0.8653343 ]
 [0.8752197 ]
 [0.86411005]
 [0.8676011 ]
 [0.84125173]
 [1.802418  ]
 [1.1715033 ]
 [2.086986  ]
 [0.9812367 ]
 [1.1910301 ]
 [2.814827  ]
 [2.8191915 ]
 [3.12975   ]
 [5.8628664 ]
 [5.8628664 ]
 [1.797902  ]
 [1.9000118 ]
 [1.797902  ]
 [2.471164  ]
 [1.9772964 ]
 [0.8443697 ]
 [0.64369327]
 [0.6855148 ]
 [0.68634313]
 [0.9906296 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [1.8508666 ]
 [1.6670697 ]
 [2.1891274 ]
 [1.8383013 ]
 [1.6534367 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [1.3751565 ]
 [1.8613853 ]
 [1.491255  ]
 [1.7978213 ]
 [2.1261187 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [2.7698927 ]
 [1.4

2021-01-24 22:28:14,081 [nnabla][INFO]: iter=109 {Training loss}=2.635732650756836
2021-01-24 22:28:32,573 [nnabla][INFO]: iter=111 {Training loss}=2.5408859252929688
2021-01-24 22:28:51,270 [nnabla][INFO]: iter=113 {Training loss}=2.3992443084716797
2021-01-24 22:29:10,021 [nnabla][INFO]: iter=115 {Training loss}=2.5622973442077637
2021-01-24 22:29:28,164 [nnabla][INFO]: iter=117 {Training loss}=2.5285041332244873
2021-01-24 22:29:46,528 [nnabla][INFO]: iter=119 {Training loss}=2.2657830715179443
2021-01-24 22:30:05,243 [nnabla][INFO]: iter=121 {Training loss}=2.4464404582977295
2021-01-24 22:30:23,887 [nnabla][INFO]: iter=123 {Training loss}=2.367598533630371
2021-01-24 22:30:42,508 [nnabla][INFO]: iter=125 {Training loss}=2.185542106628418
2021-01-24 22:31:00,661 [nnabla][INFO]: iter=127 {Training loss}=2.605356454849243
2021-01-24 22:31:19,524 [nnabla][INFO]: iter=129 {Training loss}=2.4198877811431885
2021-01-24 22:31:38,209 [nnabla][INFO]: iter=131 {Training loss}=2.3705635070800

Epoch:3 | Loss:[[ 0.88761175]
 [ 0.89570385]
 [ 0.94925034]
 [ 0.90813947]
 [ 0.9569048 ]
 [ 2.59198   ]
 [ 1.6084218 ]
 [ 1.675326  ]
 [ 1.5668521 ]
 [ 1.8071128 ]
 [ 3.0614471 ]
 [ 1.9301097 ]
 [ 2.3042605 ]
 [ 5.4823627 ]
 [ 5.4823627 ]
 [ 1.8148698 ]
 [ 2.3847556 ]
 [ 1.8148698 ]
 [ 2.6461978 ]
 [ 2.1070428 ]
 [ 1.4459112 ]
 [ 1.5063293 ]
 [ 1.0110682 ]
 [ 1.5687407 ]
 [ 1.3811879 ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 0.75881803]
 [ 0.7330587 ]
 [ 3.2867537 ]
 [ 2.2575693 ]
 [ 1.0176997 ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 2.816036  ]
 [ 1.6369822 ]
 [ 1.6087459 ]
 [ 1.3544265 ]
 [ 2.630979  ]
 [ 1.4892975 ]
 [ 2.81603

2021-01-24 22:33:48,716 [nnabla][INFO]: iter=145 {Training loss}=2.436481237411499
2021-01-24 22:34:06,961 [nnabla][INFO]: iter=147 {Training loss}=2.303014039993286
2021-01-24 22:34:25,654 [nnabla][INFO]: iter=149 {Training loss}=2.2492308616638184
2021-01-24 22:34:44,523 [nnabla][INFO]: iter=151 {Training loss}=2.192913293838501
2021-01-24 22:35:03,041 [nnabla][INFO]: iter=153 {Training loss}=2.38405179977417
2021-01-24 22:35:21,562 [nnabla][INFO]: iter=155 {Training loss}=2.2618868350982666
2021-01-24 22:35:39,995 [nnabla][INFO]: iter=157 {Training loss}=2.4260270595550537
2021-01-24 22:35:58,554 [nnabla][INFO]: iter=159 {Training loss}=2.363929033279419
2021-01-24 22:36:17,256 [nnabla][INFO]: iter=161 {Training loss}=2.1758882999420166
2021-01-24 22:36:35,937 [nnabla][INFO]: iter=163 {Training loss}=2.4161951541900635
2021-01-24 22:36:54,400 [nnabla][INFO]: iter=165 {Training loss}=2.340636968612671
2021-01-24 22:37:12,724 [nnabla][INFO]: iter=167 {Training loss}=2.270401954650879


Epoch:4 | Loss:[[ 0.8723094 ]
 [ 0.8377958 ]
 [ 0.8742547 ]
 [ 0.82674456]
 [ 0.8326591 ]
 [ 2.4495735 ]
 [ 1.906151  ]
 [ 1.4306489 ]
 [ 1.2468878 ]
 [ 1.8501413 ]
 [ 2.4286394 ]
 [ 2.438681  ]
 [ 2.8337154 ]
 [ 5.5220027 ]
 [ 5.5220027 ]
 [ 1.1589084 ]
 [ 2.998796  ]
 [ 1.1589084 ]
 [ 1.6258211 ]
 [ 1.4230847 ]
 [ 1.4043038 ]
 [ 1.159064  ]
 [ 1.0841714 ]
 [ 1.7054315 ]
 [ 2.072805  ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 1.6783302 ]
 [ 1.7019978 ]
 [ 4.0175643 ]
 [ 2.1938043 ]
 [ 1.5451913 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.7779768 ]
 [ 2.787196  ]
 [ 2.921031  ]
 [ 2.390805  ]
 [ 2.0943289 ]
 [ 2.6171582 ]
 [ 2.77797

2021-01-24 22:39:23,515 [nnabla][INFO]: iter=181 {Training loss}=2.396416664123535
2021-01-24 22:39:42,297 [nnabla][INFO]: iter=183 {Training loss}=2.251049280166626
2021-01-24 22:40:00,686 [nnabla][INFO]: iter=185 {Training loss}=2.328432083129883
2021-01-24 22:40:19,358 [nnabla][INFO]: iter=187 {Training loss}=2.2990779876708984
2021-01-24 22:40:37,890 [nnabla][INFO]: iter=189 {Training loss}=2.3049254417419434
2021-01-24 22:40:56,562 [nnabla][INFO]: iter=191 {Training loss}=2.0444328784942627
2021-01-24 22:41:15,312 [nnabla][INFO]: iter=193 {Training loss}=2.351693630218506
2021-01-24 22:41:33,813 [nnabla][INFO]: iter=195 {Training loss}=2.3213412761688232
2021-01-24 22:41:52,796 [nnabla][INFO]: iter=197 {Training loss}=2.0697710514068604
2021-01-24 22:42:11,521 [nnabla][INFO]: iter=199 {Training loss}=2.2539925575256348
2021-01-24 22:42:30,078 [nnabla][INFO]: iter=201 {Training loss}=2.2124786376953125
2021-01-24 22:42:48,881 [nnabla][INFO]: iter=203 {Training loss}=2.3377873897552

Epoch:5 | Loss:[[1.1241581 ]
 [0.7443495 ]
 [0.7500229 ]
 [0.7190602 ]
 [0.8156676 ]
 [1.9816222 ]
 [1.4329513 ]
 [1.6905818 ]
 [2.0647955 ]
 [1.8981451 ]
 [1.710107  ]
 [2.1588228 ]
 [3.448708  ]
 [4.901683  ]
 [4.901683  ]
 [1.9056836 ]
 [2.6032586 ]
 [1.9056836 ]
 [2.2729263 ]
 [1.8094597 ]
 [1.0408411 ]
 [0.76781917]
 [0.5838635 ]
 [1.1417242 ]
 [1.0898733 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [1.6584175 ]
 [1.6234174 ]
 [2.4198437 ]
 [2.05228   ]
 [1.6217749 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.19448   ]
 [2.2440648 ]
 [1.709256  ]
 [2.116867  ]
 [1.7611668 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.7906606 ]
 [2.4

2021-01-24 22:44:58,509 [nnabla][INFO]: iter=217 {Training loss}=2.4953930377960205
2021-01-24 22:45:16,769 [nnabla][INFO]: iter=219 {Training loss}=2.481599807739258
2021-01-24 22:45:35,810 [nnabla][INFO]: iter=221 {Training loss}=2.3745222091674805
2021-01-24 22:45:54,583 [nnabla][INFO]: iter=223 {Training loss}=2.2030770778656006
2021-01-24 22:46:12,805 [nnabla][INFO]: iter=225 {Training loss}=2.4088687896728516
2021-01-24 22:46:31,593 [nnabla][INFO]: iter=227 {Training loss}=2.190786838531494
2021-01-24 22:46:49,850 [nnabla][INFO]: iter=229 {Training loss}=2.2634005546569824
2021-01-24 22:47:08,539 [nnabla][INFO]: iter=231 {Training loss}=2.2146689891815186
2021-01-24 22:47:27,101 [nnabla][INFO]: iter=233 {Training loss}=2.1587462425231934
2021-01-24 22:47:45,533 [nnabla][INFO]: iter=235 {Training loss}=2.2149994373321533
2021-01-24 22:48:04,417 [nnabla][INFO]: iter=237 {Training loss}=2.1797854900360107
2021-01-24 22:48:23,063 [nnabla][INFO]: iter=239 {Training loss}=2.20008206367

Epoch:6 | Loss:[[0.6885852 ]
 [0.6599421 ]
 [0.90025157]
 [0.64035815]
 [0.6590691 ]
 [1.5205681 ]
 [1.1545854 ]
 [1.3020995 ]
 [0.89451015]
 [1.2847695 ]
 [2.4653354 ]
 [2.2874732 ]
 [2.2368245 ]
 [3.985735  ]
 [3.985735  ]
 [1.6233348 ]
 [1.8216376 ]
 [1.6233348 ]
 [1.7315049 ]
 [1.5504115 ]
 [0.9667575 ]
 [0.60521305]
 [0.58825755]
 [0.9594463 ]
 [0.90549725]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [1.2814276 ]
 [1.1187828 ]
 [1.4825041 ]
 [2.01951   ]
 [1.3547215 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [1.6174595 ]
 [1.7850347 ]
 [1.7504197 ]
 [1.7725718 ]
 [2.5137522 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.8329802 ]
 [2.2

2021-01-24 22:50:33,418 [nnabla][INFO]: iter=253 {Training loss}=2.080251932144165
2021-01-24 22:50:52,024 [nnabla][INFO]: iter=255 {Training loss}=1.9834179878234863
2021-01-24 22:51:10,841 [nnabla][INFO]: iter=257 {Training loss}=2.1022462844848633
2021-01-24 22:51:29,568 [nnabla][INFO]: iter=259 {Training loss}=2.0915255546569824
2021-01-24 22:51:47,986 [nnabla][INFO]: iter=261 {Training loss}=2.1510138511657715
2021-01-24 22:52:06,272 [nnabla][INFO]: iter=263 {Training loss}=1.9472815990447998
2021-01-24 22:52:24,639 [nnabla][INFO]: iter=265 {Training loss}=2.1909265518188477
2021-01-24 22:52:43,312 [nnabla][INFO]: iter=267 {Training loss}=2.006927728652954
2021-01-24 22:53:01,871 [nnabla][INFO]: iter=269 {Training loss}=1.8223159313201904
2021-01-24 22:53:20,357 [nnabla][INFO]: iter=271 {Training loss}=2.031790256500244
2021-01-24 22:53:38,435 [nnabla][INFO]: iter=273 {Training loss}=2.0495519638061523
2021-01-24 22:53:56,753 [nnabla][INFO]: iter=275 {Training loss}=2.069704055786

Epoch:7 | Loss:[[0.8404051 ]
 [0.8045059 ]
 [0.94287086]
 [0.8952632 ]
 [0.81243855]
 [1.703419  ]
 [1.0702078 ]
 [2.0673652 ]
 [1.0085115 ]
 [1.0857433 ]
 [3.1254947 ]
 [1.9278864 ]
 [2.1059341 ]
 [6.249978  ]
 [6.249978  ]
 [2.028489  ]
 [2.5169873 ]
 [2.028489  ]
 [1.6994882 ]
 [1.2636812 ]
 [1.8192586 ]
 [0.8830267 ]
 [0.6565179 ]
 [1.1116371 ]
 [0.89621836]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [1.0286398 ]
 [0.9573944 ]
 [1.6613913 ]
 [1.677447  ]
 [2.330995  ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.3630238 ]
 [1.668496  ]
 [1.6397144 ]
 [1.8858352 ]
 [2.1449728 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [2.7707562 ]
 [1.7

2021-01-24 22:56:06,069 [nnabla][INFO]: iter=289 {Training loss}=2.02249813079834
2021-01-24 22:56:24,482 [nnabla][INFO]: iter=291 {Training loss}=1.9323519468307495
2021-01-24 22:56:42,351 [nnabla][INFO]: iter=293 {Training loss}=1.9301341772079468
2021-01-24 22:57:01,198 [nnabla][INFO]: iter=295 {Training loss}=1.9536226987838745
2021-01-24 22:57:19,964 [nnabla][INFO]: iter=297 {Training loss}=2.13161563873291
2021-01-24 22:57:38,476 [nnabla][INFO]: iter=299 {Training loss}=1.8499536514282227
2021-01-24 22:57:56,596 [nnabla][INFO]: iter=301 {Training loss}=1.9764893054962158
2021-01-24 22:58:15,312 [nnabla][INFO]: iter=303 {Training loss}=1.9501385688781738
2021-01-24 22:58:33,907 [nnabla][INFO]: iter=305 {Training loss}=1.650172472000122
2021-01-24 22:58:52,108 [nnabla][INFO]: iter=307 {Training loss}=1.964810848236084
2021-01-24 22:59:10,458 [nnabla][INFO]: iter=309 {Training loss}=2.0918703079223633
2021-01-24 22:59:28,931 [nnabla][INFO]: iter=311 {Training loss}=2.076798915863037

Epoch:8 | Loss:[[1.4976851 ]
 [0.5038369 ]
 [0.55566454]
 [0.78086275]
 [0.5696901 ]
 [2.1124234 ]
 [0.9893247 ]
 [0.7374017 ]
 [0.8857727 ]
 [1.0580837 ]
 [3.81871   ]
 [1.941451  ]
 [2.133133  ]
 [4.555847  ]
 [4.555847  ]
 [1.644582  ]
 [1.4486098 ]
 [1.644582  ]
 [1.799379  ]
 [1.605095  ]
 [0.79074734]
 [0.95864636]
 [0.9342938 ]
 [1.0528402 ]
 [1.017729  ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [0.96504194]
 [0.88320416]
 [1.71729   ]
 [1.8769866 ]
 [1.050518  ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [1.5867705 ]
 [1.5788395 ]
 [1.7355663 ]
 [2.1347206 ]
 [1.6589916 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [2.7948782 ]
 [1.1

2021-01-24 23:01:37,514 [nnabla][INFO]: iter=325 {Training loss}=1.989101529121399
2021-01-24 23:01:56,138 [nnabla][INFO]: iter=327 {Training loss}=1.9468700885772705
2021-01-24 23:02:14,725 [nnabla][INFO]: iter=329 {Training loss}=1.9108717441558838
2021-01-24 23:02:32,517 [nnabla][INFO]: iter=331 {Training loss}=1.8216438293457031
2021-01-24 23:02:50,570 [nnabla][INFO]: iter=333 {Training loss}=1.9280434846878052
2021-01-24 23:03:08,701 [nnabla][INFO]: iter=335 {Training loss}=1.779225468635559
2021-01-24 23:03:27,359 [nnabla][INFO]: iter=337 {Training loss}=1.9529849290847778
2021-01-24 23:03:46,256 [nnabla][INFO]: iter=339 {Training loss}=2.062739372253418
2021-01-24 23:04:04,691 [nnabla][INFO]: iter=341 {Training loss}=1.680403709411621
2021-01-24 23:04:22,741 [nnabla][INFO]: iter=343 {Training loss}=1.785764455795288
2021-01-24 23:04:41,075 [nnabla][INFO]: iter=345 {Training loss}=1.9452835321426392
2021-01-24 23:04:59,181 [nnabla][INFO]: iter=347 {Training loss}=1.89267551898956

Epoch:9 | Loss:[[0.735772  ]
 [0.7568685 ]
 [0.7644828 ]
 [1.0019166 ]
 [0.68540907]
 [2.224726  ]
 [0.89630705]
 [0.87936944]
 [1.0722183 ]
 [1.0634328 ]
 [3.736922  ]
 [1.6910628 ]
 [1.704981  ]
 [5.1090965 ]
 [5.1090965 ]
 [1.982861  ]
 [1.7300258 ]
 [1.982861  ]
 [1.7978369 ]
 [1.3525717 ]
 [0.48356146]
 [0.66606283]
 [0.72066444]
 [1.29175   ]
 [1.1288953 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [0.7566956 ]
 [0.6696903 ]
 [2.754695  ]
 [1.4207226 ]
 [1.4112197 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [1.176129  ]
 [1.1193297 ]
 [1.5628341 ]
 [1.5803392 ]
 [2.589725  ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [2.7765398 ]
 [1.4

2021-01-24 23:07:06,883 [nnabla][INFO]: iter=361 {Training loss}=1.8513637781143188
2021-01-24 23:07:25,123 [nnabla][INFO]: iter=363 {Training loss}=1.9512077569961548
2021-01-24 23:07:43,714 [nnabla][INFO]: iter=365 {Training loss}=2.003016710281372
2021-01-24 23:08:02,438 [nnabla][INFO]: iter=367 {Training loss}=1.835451602935791
2021-01-24 23:08:20,625 [nnabla][INFO]: iter=369 {Training loss}=2.0967798233032227
2021-01-24 23:08:38,395 [nnabla][INFO]: iter=371 {Training loss}=1.79111647605896
2021-01-24 23:08:56,384 [nnabla][INFO]: iter=373 {Training loss}=1.8644229173660278
2021-01-24 23:09:14,821 [nnabla][INFO]: iter=375 {Training loss}=1.794245958328247
2021-01-24 23:09:32,811 [nnabla][INFO]: iter=377 {Training loss}=1.6980174779891968
2021-01-24 23:09:51,247 [nnabla][INFO]: iter=379 {Training loss}=1.8785220384597778
2021-01-24 23:10:09,066 [nnabla][INFO]: iter=381 {Training loss}=1.9091925621032715
2021-01-24 23:10:26,802 [nnabla][INFO]: iter=383 {Training loss}=1.85594058036804

Epoch:10 | Loss:[[0.8181421 ]
 [0.70629746]
 [0.78228676]
 [0.85317993]
 [0.64058584]
 [1.5894765 ]
 [1.3379486 ]
 [1.0213445 ]
 [1.4856588 ]
 [1.2049847 ]
 [4.194795  ]
 [1.6789435 ]
 [1.537419  ]
 [5.7998114 ]
 [5.7998114 ]
 [2.1153846 ]
 [1.9948545 ]
 [2.1153846 ]
 [1.5085258 ]
 [1.4641023 ]
 [0.8301855 ]
 [0.80217683]
 [0.77631545]
 [1.029909  ]
 [0.94426537]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [1.2144964 ]
 [1.0225186 ]
 [1.380947  ]
 [1.7135175 ]
 [1.0656087 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [1.4768844 ]
 [1.5190903 ]
 [1.9295015 ]
 [1.8740503 ]
 [1.3813621 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [2.7786791 ]
 [1.

2021-01-24 23:12:35,618 [nnabla][INFO]: iter=397 {Training loss}=1.800217628479004
2021-01-24 23:12:53,960 [nnabla][INFO]: iter=399 {Training loss}=1.7917895317077637
2021-01-24 23:13:11,744 [nnabla][INFO]: iter=401 {Training loss}=1.7820144891738892
2021-01-24 23:13:29,988 [nnabla][INFO]: iter=403 {Training loss}=1.702552318572998
2021-01-24 23:13:48,188 [nnabla][INFO]: iter=405 {Training loss}=1.8415051698684692
2021-01-24 23:14:06,289 [nnabla][INFO]: iter=407 {Training loss}=1.6888320446014404
2021-01-24 23:14:24,311 [nnabla][INFO]: iter=409 {Training loss}=1.7250511646270752
2021-01-24 23:14:42,162 [nnabla][INFO]: iter=411 {Training loss}=1.8688864707946777
2021-01-24 23:15:00,327 [nnabla][INFO]: iter=413 {Training loss}=1.576442003250122
2021-01-24 23:15:19,005 [nnabla][INFO]: iter=415 {Training loss}=1.7714760303497314
2021-01-24 23:15:37,354 [nnabla][INFO]: iter=417 {Training loss}=1.8590307235717773
2021-01-24 23:15:55,655 [nnabla][INFO]: iter=419 {Training loss}=1.866517543792

Epoch:11 | Loss:[[0.3953639 ]
 [0.33843428]
 [0.91672915]
 [1.0473387 ]
 [0.44820696]
 [1.8198729 ]
 [1.310638  ]
 [1.2389731 ]
 [1.5414246 ]
 [1.8284538 ]
 [2.885965  ]
 [1.8687253 ]
 [2.0324838 ]
 [6.8370433 ]
 [6.8370433 ]
 [1.1625589 ]
 [1.7739502 ]
 [1.1625589 ]
 [1.4090102 ]
 [1.0495129 ]
 [1.0829002 ]
 [1.0635049 ]
 [0.8602967 ]
 [1.1522155 ]
 [1.9838278 ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [0.43301198]
 [0.33543125]
 [1.00128   ]
 [0.9726446 ]
 [0.69916403]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [1.1829156 ]
 [1.4327763 ]
 [1.2352377 ]
 [3.1436071 ]
 [0.8303665 ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [2.757369  ]
 [1.

2021-01-24 23:18:06,899 [nnabla][INFO]: iter=433 {Training loss}=1.7299307584762573
2021-01-24 23:18:25,544 [nnabla][INFO]: iter=435 {Training loss}=1.6940972805023193
2021-01-24 23:18:44,322 [nnabla][INFO]: iter=437 {Training loss}=1.7770376205444336
2021-01-24 23:19:02,889 [nnabla][INFO]: iter=439 {Training loss}=1.5955986976623535
2021-01-24 23:19:21,443 [nnabla][INFO]: iter=441 {Training loss}=1.7869294881820679
2021-01-24 23:19:40,075 [nnabla][INFO]: iter=443 {Training loss}=1.5917391777038574
2021-01-24 23:19:58,990 [nnabla][INFO]: iter=445 {Training loss}=1.8222922086715698
2021-01-24 23:20:17,551 [nnabla][INFO]: iter=447 {Training loss}=1.955286979675293
2021-01-24 23:20:36,393 [nnabla][INFO]: iter=449 {Training loss}=1.7391265630722046
2021-01-24 23:20:55,447 [nnabla][INFO]: iter=451 {Training loss}=1.8029571771621704
2021-01-24 23:21:14,143 [nnabla][INFO]: iter=453 {Training loss}=1.8649290800094604
2021-01-24 23:21:32,828 [nnabla][INFO]: iter=455 {Training loss}=1.8541063070

Epoch:12 | Loss:[[ 0.7116388 ]
 [ 0.6850746 ]
 [ 0.71641034]
 [ 0.83063436]
 [ 0.73361135]
 [ 2.176107  ]
 [ 0.7869606 ]
 [ 0.6218951 ]
 [ 1.2514929 ]
 [ 0.7980007 ]
 [ 3.2280536 ]
 [ 2.0387475 ]
 [ 2.8412437 ]
 [ 7.837415  ]
 [ 7.837415  ]
 [ 1.6210132 ]
 [ 1.7566357 ]
 [ 1.6210132 ]
 [ 1.9052929 ]
 [ 1.7710059 ]
 [ 1.4467145 ]
 [ 1.543197  ]
 [ 0.8184636 ]
 [ 0.9711407 ]
 [ 0.8230087 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 0.22732477]
 [ 0.33312818]
 [ 1.796114  ]
 [ 3.475043  ]
 [ 0.23892602]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 2.7532513 ]
 [ 1.1638472 ]
 [ 1.6119634 ]
 [ 1.483773  ]
 [ 3.5036998 ]
 [ 1.7806609 ]
 [ 2.7532

2021-01-24 23:23:43,387 [nnabla][INFO]: iter=469 {Training loss}=1.69423246383667
2021-01-24 23:24:02,046 [nnabla][INFO]: iter=471 {Training loss}=1.632779598236084
2021-01-24 23:24:20,271 [nnabla][INFO]: iter=473 {Training loss}=1.7093065977096558
2021-01-24 23:24:38,769 [nnabla][INFO]: iter=475 {Training loss}=1.5712509155273438
2021-01-24 23:24:57,216 [nnabla][INFO]: iter=477 {Training loss}=1.7539420127868652
2021-01-24 23:25:15,591 [nnabla][INFO]: iter=479 {Training loss}=1.5886414051055908
2021-01-24 23:25:34,054 [nnabla][INFO]: iter=481 {Training loss}=1.6360946893692017
2021-01-24 23:25:52,737 [nnabla][INFO]: iter=483 {Training loss}=1.7244871854782104
2021-01-24 23:26:11,313 [nnabla][INFO]: iter=485 {Training loss}=1.4276334047317505
2021-01-24 23:26:29,898 [nnabla][INFO]: iter=487 {Training loss}=1.5596362352371216
2021-01-24 23:26:48,176 [nnabla][INFO]: iter=489 {Training loss}=1.6498373746871948
2021-01-24 23:27:06,800 [nnabla][INFO]: iter=491 {Training loss}=1.532147169113

Epoch:13 | Loss:[[ 0.3053896 ]
 [ 0.2865725 ]
 [ 0.3352281 ]
 [ 0.6801262 ]
 [ 0.39217782]
 [ 2.2353296 ]
 [ 0.9444619 ]
 [ 0.68917483]
 [ 0.85367936]
 [ 0.8399429 ]
 [ 2.8281317 ]
 [ 2.222795  ]
 [ 2.6369863 ]
 [ 6.6993074 ]
 [ 6.6993074 ]
 [ 1.2403482 ]
 [ 1.1170275 ]
 [ 1.2403482 ]
 [ 1.5931108 ]
 [ 1.1148431 ]
 [ 1.0489922 ]
 [ 1.1567185 ]
 [ 0.8509086 ]
 [ 1.129875  ]
 [ 1.1058496 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 0.5054905 ]
 [ 0.38230458]
 [ 1.4645227 ]
 [ 2.2695434 ]
 [ 0.50324494]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 2.7609942 ]
 [ 0.72924453]
 [ 1.6606367 ]
 [ 1.1764138 ]
 [ 1.3625606 ]
 [ 1.2766297 ]
 [ 2.7609

2021-01-24 23:29:17,139 [nnabla][INFO]: iter=505 {Training loss}=1.5910378694534302
2021-01-24 23:29:35,618 [nnabla][INFO]: iter=507 {Training loss}=1.5047122240066528
2021-01-24 23:29:53,921 [nnabla][INFO]: iter=509 {Training loss}=1.5091701745986938
2021-01-24 23:30:12,682 [nnabla][INFO]: iter=511 {Training loss}=1.371121883392334
2021-01-24 23:30:31,419 [nnabla][INFO]: iter=513 {Training loss}=1.7639758586883545
2021-01-24 23:30:50,316 [nnabla][INFO]: iter=515 {Training loss}=1.5457061529159546
2021-01-24 23:31:08,450 [nnabla][INFO]: iter=517 {Training loss}=1.5162975788116455
2021-01-24 23:31:26,896 [nnabla][INFO]: iter=519 {Training loss}=1.625199556350708
2021-01-24 23:31:45,116 [nnabla][INFO]: iter=521 {Training loss}=1.424853801727295
2021-01-24 23:32:03,448 [nnabla][INFO]: iter=523 {Training loss}=1.5448483228683472
2021-01-24 23:32:22,202 [nnabla][INFO]: iter=525 {Training loss}=1.6317050457000732
2021-01-24 23:32:40,714 [nnabla][INFO]: iter=527 {Training loss}=1.557545423507

Epoch:14 | Loss:[[0.5170929 ]
 [0.50596356]
 [0.555598  ]
 [0.95067436]
 [0.604603  ]
 [2.1101642 ]
 [1.2945197 ]
 [1.2057016 ]
 [1.1543119 ]
 [2.060532  ]
 [2.8366601 ]
 [1.8213336 ]
 [1.7814302 ]
 [6.4951897 ]
 [6.4951897 ]
 [1.5829678 ]
 [1.4180293 ]
 [1.5829678 ]
 [1.7163961 ]
 [1.5448283 ]
 [1.05458   ]
 [1.2540704 ]
 [0.87264705]
 [1.1137701 ]
 [0.9430396 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [0.27886555]
 [0.23100016]
 [0.5787486 ]
 [1.7952925 ]
 [0.73052746]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [0.8705229 ]
 [1.3952467 ]
 [1.114162  ]
 [2.9941938 ]
 [1.1525978 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [2.7656524 ]
 [1.

2021-01-24 23:34:49,699 [nnabla][INFO]: iter=541 {Training loss}=1.4901132583618164
2021-01-24 23:35:08,042 [nnabla][INFO]: iter=543 {Training loss}=1.4873501062393188
2021-01-24 23:35:26,338 [nnabla][INFO]: iter=545 {Training loss}=1.4416736364364624
2021-01-24 23:35:44,443 [nnabla][INFO]: iter=547 {Training loss}=1.394999384880066
2021-01-24 23:36:02,587 [nnabla][INFO]: iter=549 {Training loss}=1.5797622203826904
2021-01-24 23:36:20,374 [nnabla][INFO]: iter=551 {Training loss}=1.4178881645202637


KeyboardInterrupt: 

In [None]:
speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
centroids_incl = F.mean(embeds, axis=1, keepdims=True)
centroids_incl = centroids_incl / F.norm(centroids_incl, axis=2, keepdims=True)
centroids_excl = F.sum(embeds, axis=1, keepdims=True) - embeds
centroids_excl /= (utterances_per_speaker - 1)
centroids_excl = centroids_excl / F.norm(centroids_excl, axis=2, keepdims=True)
sim_matrix = nn.Variable.from_numpy_array(np.zeros((speakers_per_batch, utterances_per_speaker, speakers_per_batch),dtype=np.int))
mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
for j in range(speakers_per_batch):
    mask = np.where(mask_matrix[j])[0]
    sim_matrix[mask, :, j] = F.sum(embeds[mask] * centroids_incl[j],axis=2)
    sim_matrix[j, :, j] = F.sum(embeds[j] * centroids_excl[j],axis=1)

In [None]:
center = F.mean(embedded_split, axis=1)             # [N,P] normalized center vectors eq.(1)
center_norm = F.norm(center,2,keepdims=True)
center = center/center_norm

center_except = F.reshape(F.sum(embedded_split, axis=1, keepdims=True) - embedded_split, (N*M,P))  # [NM,P] center vectors eq.(8)
# make similarity matrix eq.(9)
S = tf.concat(
    [tf.concat([tf.reduce_sum(center_except[i*M:(i+1)*M,:]*embedded_split[j,:,:], axis=1, keep_dims=True) if i==j
                else tf.reduce_sum(center[i:(i+1),:]*embedded_split[j,:,:], axis=1, keep_dims=True) for i in range(N)],
               axis=1) for j in range(N)], axis=0)

nnabla.functions.concatenate

In [None]:
zz = {"s1":np.array([np.random.randint(1,10,(2,4)),np.random.randint(1,10,(2,4)),
                     np.random.randint(1,10,(2,4)),np.random.randint(1,10,(2,4))]),
      "s2":np.array([np.random.randint(1,10,(2,4)),np.random.randint(1,10,(2,4)),
                     np.random.randint(1,10,(2,4)),np.random.randint(1,10,(2,4))])    
     
     }
zz

In [None]:
xx = zz.copy()
curr= xx["s1"][0:2]
print(xx["s1"].shape)
xx["s1"]=np.delete(xx["s1"], [0,1], axis=0)
xx["s1"]=np.append(xx["s1"], curr, axis=0)
print(xx["s1"].shape)
print(curr)
xx

In [14]:
nn.set_auto_forward(True)

def compute_similarity(emb, n_speakers, n_utterances):
    # embedding reshape
    emb_re = emb.reshape((n_speakers, n_utterances, -1))
    
    # compute the inclusion centroids
    cen = F.mean(emb_re, axis=1) 
    cen = cen / F.norm(cen, axis=1, keepdims=True)
    cen = F.reshape(cen, (-1, emb.shape[-1]))
    
    # compute the exclusion centroids
    exc = F.sum(emb_re, axis=1, keepdims=True) - emb_re
    exc = exc / (n_utterances - 1)
    exc = exc / F.norm(exc, axis=2, keepdims=True)
    exc = F.reshape(exc, emb.shape)

    diag = F.sum(exc * emb, axis=1, keepdims=True) # 20 x 1
    sim = F.affine(emb,  F.transpose(cen, (1, 0))) 

    mask = np.concatenate([np.tile(w, (n_utterances, 1)) for w in np.eye(n_speakers)])
    mask = nn.Variable.from_numpy_array(mask)
    ret = (1 - mask) * sim + F.tile(diag, n_speakers) * mask
    return  ret



bz = 5
dim = 6
n_speakers = 5
n_utterances = 4

rng = np.random.RandomState(1234)
data = np.arange(bz * n_utterances * dim).reshape((bz * n_utterances, dim))

rng.randn(bz, dim)

emb = nn.Variable.from_numpy_array(data)
emb = emb / F.norm(emb, axis=1, keepdims=True)

ret =  compute_similarity(emb, n_speakers, n_utterances)

In [15]:
ret.shape

(20, 5)

In [16]:
ret.d

array([[0.8931736 , 0.85282373, 0.8417678 , 0.8371577 , 0.8346132 ],
       [0.99777466, 0.9890098 , 0.98571754, 0.98425335, 0.98342335],
       [0.9846763 , 0.99773246, 0.99611515, 0.99533224, 0.99487424],
       [0.975716  , 0.99944794, 0.9985397 , 0.99804544, 0.997745  ],
       [0.9831892 , 0.9998139 , 0.99937737, 0.9990421 , 0.99882853],
       [0.98095524, 0.9999939 , 0.99972546, 0.9994907 , 0.99933195],
       [0.9793355 , 0.99997264, 0.99988353, 0.99971807, 0.99959725],
       [0.9781107 , 0.9998819 , 0.99995685, 0.999842  , 0.99974895],
       [0.9771528 , 0.99987066, 0.99998003, 0.9999125 , 0.99984056],
       [0.97638404, 0.9998066 , 0.9999988 , 0.99995357, 0.9998982 ],
       [0.97575355, 0.99974537, 0.99999726, 0.99997735, 0.9999353 ],
       [0.9752273 , 0.99968845, 0.99998474, 0.9999905 , 0.99995965],
       [0.97478133, 0.9996364 , 0.9999811 , 0.9999952 , 0.9999757 ],
       [0.9743989 , 0.99958897, 0.9999692 , 0.99999964, 0.99998623],
       [0.9740671 , 0.99954563, 0.

In [11]:
mask = np.concatenate([np.tile(w, (4, 1)) for w in np.eye(5)])

In [13]:
1-mask

array([[0., 1., 1., 1., 1.],
       [0., 1., 1., 1., 1.],
       [0., 1., 1., 1., 1.],
       [0., 1., 1., 1., 1.],
       [1., 0., 1., 1., 1.],
       [1., 0., 1., 1., 1.],
       [1., 0., 1., 1., 1.],
       [1., 0., 1., 1., 1.],
       [1., 1., 0., 1., 1.],
       [1., 1., 0., 1., 1.],
       [1., 1., 0., 1., 1.],
       [1., 1., 0., 1., 1.],
       [1., 1., 1., 0., 1.],
       [1., 1., 1., 0., 1.],
       [1., 1., 1., 0., 1.],
       [1., 1., 1., 0., 1.],
       [1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 0.]])