In [1]:
import sys
from librosa.core import resample
import pandas as pd
import numpy as np
from IPython.display import Audio
import torch
import pathlib
def create_dir(filename):
    """Make sure the directory that will contain ``filename`` exists.

    The directory is everything before the last '/' of ``filename`` (the empty
    string, i.e. the current directory, when there is no '/'). Missing parent
    directories are created too, and an already existing directory is fine.
    """
    parent_dir, _sep, _basename = filename.rpartition('/')
    pathlib.Path(parent_dir).mkdir(parents=True, exist_ok=True)
from tqdm.notebook import tqdm
sys.path.append('Conv-TasNet/src/')
sys.path.append('SincNet/')
from conv_tasnet import *
from pit_criterion import cal_loss
from dnn_models import *
from data_io import ReadList,read_conf_inp,str_to_bool
from collections import Counter
import os
# GPU index passed to every `.cuda(device)` call in this notebook.
device = 1
# Prefix string-concatenated with the CSV names and with the audio paths listed in them.
root = '../'
# Sample rate (Hz) assumed for the stored .npy audio; load8hz resamples from it.
old_sr = 16000
# Target sample rate (Hz): load8hz output, playback rate and SincNet `fs`.
new_sr = 8000

In [2]:
def load8hz(filename, length=16000):
    """Load a stored recording, resample it to ``new_sr`` and fix its length.

    Despite the name, the target rate is ``new_sr`` (8 kHz in this notebook).

    Args:
        filename: path to a ``.npy`` file. Values are divided by 2**15, so
            int16-range PCM is assumed -- TODO confirm for every dataset file.
        length: number of output samples (default 16000, i.e. 2 s at 8 kHz).
            Longer signals are truncated; shorter ones are zero-padded at the end.

    Returns:
        1-D ``np.ndarray`` with exactly ``length`` samples.
    """
    samples = np.load(filename) / (2 ** 15)
    # librosa >= 0.10 makes the rate arguments keyword-only; the keyword form
    # also works with older librosa releases.
    samples = resample(samples, orig_sr=old_sr, target_sr=new_sr)
    samples = samples[:length]
    if len(samples) < length:
        padding = np.zeros(length - len(samples))
        samples = np.concatenate([samples, padding])
    return samples

class SourceSet(torch.utils.data.Dataset):
    """Pairs of clean source signals listed in an overlay CSV.

    Each CSV row has ``first_file`` / ``second_file`` columns holding paths
    relative to ``root``. ``__getitem__`` loads both through ``load8hz`` so the
    training loop can mix them on the fly.
    """

    def __init__(self, root, csv):
        """
        Args:
            root: prefix string-concatenated with the CSV name and with every
                audio path it lists (so include the trailing '/').
            csv: CSV file name relative to ``root``.
        """
        super().__init__()
        self.root = root
        self.csv = pd.read_csv(root + csv)

    def __len__(self):
        return len(self.csv)

    def __getitem__(self, idx):
        """Return ``(sig1, sig2)``: both files of row ``idx`` loaded via ``load8hz``."""
        row = self.csv.iloc[idx]
        # Use the instance's root (not a module-level global) so a dataset
        # built with a different root reads from the right place.
        sig1 = load8hz(self.root + row['first_file'])
        sig2 = load8hz(self.root + row['second_file'])
        return sig1, sig2
# Training pairs listed in root + 'overlay-train.csv'.
sourceset_train = SourceSet(root, 'overlay-train.csv')


In [3]:
# Load the pretrained Conv-TasNet from 'final.pth.tar' (ConvTasNet comes from the
# Conv-TasNet sources added to sys.path) and move it to GPU `device`.
tasnet = ConvTasNet.load_model('final.pth.tar').cuda(device)
tasnet.train()
optimizer = torch.optim.Adam(tasnet.parameters(), lr = 0.001)
# Resume fine-tuning from the checkpoint the training cell writes, if present;
# this replaces both the pretrained weights and the fresh optimizer state.
if os.path.exists('models/tasnet.pth'):
    print('load model')
    checkpoint = torch.load('models/tasnet.pth')
    tasnet.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Last saved minibatch loss; informational only (the training loop reassigns `loss`).
    loss = checkpoint['loss']


load model


In [8]:
# Fine-tune Conv-TasNet on mixtures formed on the fly from pairs of clean sources.
batch_size = 8
sourceloader_train = torch.utils.data.DataLoader(sourceset_train, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=16)
log_every = 200  # minibatches between loss reports / checkpoints

# torch.save does not create missing directories.
os.makedirs('models', exist_ok=True)
tasnet.train()

for epoch in range(64):
    running_loss = 0.0
    for batch_idx, (sig1, sig2) in enumerate(tqdm(sourceloader_train)):
        optimizer.zero_grad()
        sig1, sig2 = sig1.float().cuda(device), sig2.float().cuda(device)
        # The network sees the sum and must recover both clean sources.
        out = tasnet(sig1 + sig2)
        source = torch.stack([sig1, sig2], dim=1)
        # One valid-length entry per example. Use the actual batch size (the
        # last batch can be smaller than `batch_size`) and the actual length.
        source_lengths = torch.full((sig1.size(0),), sig1.size(-1), dtype=torch.int32, device=sig1.device)
        loss, max_snr, estimate_source, reorder_estimate_source = cal_loss(source, out, source_lengths)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(tasnet.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % log_every == log_every - 1:
            print('[%d, %5d] loss: %.3f ' %
                  (epoch + 1, batch_idx + 1, running_loss / log_every))
            running_loss = 0.0
            torch.save({
                'model_state_dict': tasnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Plain float rather than a tensor attached to the autograd graph.
                'loss': loss.item(),
            }, 'models/tasnet.pth')

HBox(children=(FloatProgress(value=0.0, max=11020.0), HTML(value='')))

[1,   200] loss: -5.306 
[1,   400] loss: -7.088 
[1,   600] loss: -7.693 
[1,   800] loss: -8.213 
[1,  1000] loss: -8.359 
[1,  1200] loss: -9.046 
[1,  1400] loss: -9.422 
[1,  1600] loss: -9.418 
[1,  1800] loss: -9.483 
[1,  2000] loss: -9.617 
[1,  2400] loss: -9.916 
[1,  2600] loss: -10.315 
[1,  2800] loss: -10.095 
[1,  3000] loss: -10.409 
[1,  3400] loss: -10.462 
[1,  3600] loss: -10.715 
[1,  3800] loss: -10.751 
[1,  4000] loss: -10.686 
[1,  4200] loss: -10.880 
[1,  4400] loss: -10.897 
[1,  4600] loss: -10.904 
[1,  4800] loss: -10.872 
[1,  5000] loss: -11.027 
[1,  5200] loss: -11.191 
[1,  5400] loss: -10.955 
[1,  5600] loss: -11.128 
[1,  5800] loss: -11.260 
[1,  6000] loss: -11.420 
[1,  6200] loss: -11.372 
[1,  6400] loss: -11.294 
[1,  6600] loss: -11.246 
[1,  6800] loss: -11.170 
[1,  7000] loss: -11.439 
[1,  7200] loss: -11.343 
[1,  7400] loss: -11.506 
[1,  7600] loss: -11.371 
[1,  7800] loss: -11.489 
[1,  8000] loss: -11.501 
[1,  8200] loss: -11.55

HBox(children=(FloatProgress(value=0.0, max=11020.0), HTML(value='')))

[2,   200] loss: -11.751 
[2,   400] loss: -12.042 
[2,   600] loss: -11.917 
[2,   800] loss: -11.934 
[2,  1000] loss: -11.870 
[2,  1200] loss: -12.128 
[2,  1400] loss: -12.013 
[2,  1600] loss: -11.981 
[2,  1800] loss: -11.894 
[2,  2000] loss: -12.098 
[2,  2200] loss: -12.081 
[2,  2400] loss: -12.062 
[2,  2600] loss: -11.933 


KeyboardInterrupt: 

In [17]:
def reformat(prefix, i, filename):
    """Build the output ``.npy`` path for the demixed version of ``filename``.

    The final extension of ``filename`` is dropped and ``_<prefix>_<i>.npy`` is
    appended, keeping the directory part, e.g.
    ``reformat('test', 3, 'spk1/a.npy') == 'spk1/a_test_3.npy'``.

    Args:
        prefix: tag inserted into the name (the dataset split in this notebook).
        i: row index; keeps names unique when one file appears in several rows.
        filename: original relative path.
    """
    # splitext strips only the last extension. Splitting on '.' would mangle
    # paths such as '../x/a.npy', 'v1.0/a.npy' or 'a.wav.npy', and would raise
    # on names without any dot.
    stem, _ext = os.path.splitext(filename)
    return stem + '_' + prefix + '_' + str(i) + '.npy'

In [25]:
# Write a demixed copy of the `mode` split into the current directory. For each
# row the two clean sources are summed and separated by Conv-TasNet, and the two
# estimates are saved under names built by `reformat`. The rewritten CSV
# (same name, current directory) points at the new files.
mode = 'test'
csv_name = 'overlay-' + mode + '.csv'
csv = pd.read_csv(root + csv_name)

# Inference must run in eval mode (the setup cell left the network in train
# mode); restore the previous mode afterwards so later training is unaffected.
was_training = tasnet.training
tasnet.eval()
try:
    with torch.no_grad():
        for i in tqdm(range(len(csv))):
            row = csv.iloc[i]
            seg1 = load8hz(root + row['first_file'])
            seg2 = load8hz(root + row['second_file'])
            # Outputs mirror the original relative directory layout under the cwd.
            create_dir(row['first_file'])
            create_dir(row['second_file'])
            mixture = torch.Tensor(seg1 + seg2).cuda(device)[None, ...]  # add a batch dimension
            out = tasnet(mixture)
            # out[0]: separated estimates for this single mixture; unpacking
            # assumes the model outputs exactly two sources.
            new_seg1, new_seg2 = out[0].cpu().numpy()
            newfile1 = reformat(mode, i, row['first_file'])
            newfile2 = reformat(mode, i, row['second_file'])
            csv.at[i, 'first_file'] = newfile1
            csv.at[i, 'second_file'] = newfile2
            np.save(newfile1, new_seg1)
            np.save(newfile2, new_seg2)
    csv.to_csv(csv_name, index=False)
finally:
    tasnet.train(was_training)

HBox(children=(FloatProgress(value=0.0, max=11020.0), HTML(value='')))




In [26]:
# Play the first clean reference source of the last row processed by the export loop.
Audio(seg1, rate = new_sr)

In [27]:
# Play the second clean reference source of that same row.
Audio(seg2, rate = new_sr)

In [28]:
# Play Conv-TasNet's first estimate for that mixture (the array saved to newfile1).
# NOTE: training used a permutation-invariant loss (pit_criterion), so the order
# of the estimates need not match seg1/seg2.
Audio(new_seg1, rate = new_sr)

In [29]:
# Play Conv-TasNet's second estimate for that mixture (the array saved to newfile2).
Audio(new_seg2, rate = new_sr)

In [3]:
def chop_chunk(signal, chunk_len=None, min_len=16000):
    """Split a 1-D signal into consecutive, non-overlapping chunks.

    Signals shorter than ``min_len`` samples are first zero-padded at the end
    up to ``min_len``. Trailing samples that do not fill a whole chunk are
    dropped.

    Args:
        signal: 1-D ``np.ndarray``.
        chunk_len: samples per chunk. Defaults to the module-level ``wlen``
            (the SincNet window length from the SincNet config cell).
        min_len: minimum length before chunking (default 16000).

    Returns:
        List of ``chunk_len``-sample slices of the (possibly padded) signal.
    """
    if chunk_len is None:
        chunk_len = wlen
    if signal.shape[-1] < min_len:
        padding = np.zeros(min_len - signal.shape[-1])
        signal = np.concatenate((signal, padding))
    # Count chunks on the padded length, not the original one.
    n_chunks = signal.shape[-1] // chunk_len
    return [signal[k * chunk_len:(k + 1) * chunk_len] for k in range(n_chunks)]

class ChunkSet(torch.utils.data.Dataset):
    """Pairs of demixed signals with a multi-hot speaker target, for SincNet.

    Each CSV row provides ``first_file``/``second_file`` (``.npy`` paths opened
    relative to the current directory) and ``first_speaker``/``second_speaker``.
    ``__getitem__`` returns ``(chunks1, chunks2, target_vec)``:

    * ``mode='train'``: one randomly chosen chunk (see ``chop_chunk``) per signal;
    * ``mode='val'``: all chunks of each signal, stacked into an array.

    ``target_vec`` has a 1 at the label index of each speaker of the row.
    """

    def __init__(self, csv, mode='train', speakers=None):
        """
        Args:
            csv: path (or buffer) accepted by ``pd.read_csv``.
            mode: ``'train'`` or ``'val'``.
            speakers: optional list defining the label order. Pass the training
                set's ``speakers`` to val/test sets so all splits share one label
                mapping. By default, the sorted speakers found in either speaker
                column of this CSV are used.

        Raises:
            ValueError: if ``mode`` is not ``'train'`` or ``'val'``.
        """
        super().__init__()
        if mode not in ('train', 'val'):
            raise ValueError("mode must be 'train' or 'val', got %r" % (mode,))
        self.csv = pd.read_csv(csv)
        if speakers is None:
            # Include second_speaker as well: a speaker that only ever appears
            # second would otherwise have no label index (KeyError).
            speakers = sorted(set(self.csv['first_speaker']) | set(self.csv['second_speaker']))
        self.speakers = list(speakers)
        self.spkr2idx = {spkr: i for i, spkr in enumerate(self.speakers)}
        self.mode = mode

    def __len__(self):
        return len(self.csv)

    def __getitem__(self, idx):
        row = self.csv.iloc[idx]
        spkr1, spkr2 = row['first_speaker'], row['second_speaker']
        sig1, sig2 = np.load(row['first_file']), np.load(row['second_file'])
        chunk1, chunk2 = chop_chunk(sig1), chop_chunk(sig2)
        target_vec = np.zeros(len(self.speakers))
        target_vec[self.spkr2idx[spkr1]] = 1
        target_vec[self.spkr2idx[spkr2]] = 1
        if self.mode == 'val':
            return np.array(chunk1), np.array(chunk2), target_vec
        # mode == 'train' (validated in __init__): one random chunk per signal.
        return (chunk1[np.random.randint(len(chunk1))],
                chunk2[np.random.randint(len(chunk2))],
                target_vec)

# Splits of the demixed data; the CSVs are read from the current directory, where
# the Conv-TasNet export cell writes 'overlay-<mode>.csv'. 'val' mode (all chunks
# per signal) is also used for the test split.
# NOTE(review): each ChunkSet builds its speaker->index mapping from its own CSV,
# so val/test targets line up with the training labels only if every split has
# the same speaker set -- confirm.
chunkset_train = ChunkSet('overlay-train.csv', mode = 'train')
chunkset_val = ChunkSet('overlay-val.csv', mode = 'val')
chunkset_test = ChunkSet('overlay-test.csv', mode = 'val')

In [4]:
def find_max2(tensor):
    """Return, per row of ``tensor``, the column indices of its two largest values.

    Indices are ordered from largest to second largest. The tensor is moved
    to the CPU and detached before conversion to NumPy.

    Returns:
        ``np.ndarray`` with (up to) two indices for each row.
    """
    values = tensor.cpu().detach().numpy()
    return np.array([np.argsort(values_row)[::-1][:2] for values_row in values])

def compute_corrects(tensor1, tensor2):
    """Count rows whose top-2 column indices agree between two score tensors.

    For each row ``i`` (over the rows of ``tensor1``), the two largest-valued
    indices of ``tensor1[i]`` and ``tensor2[i]`` (via ``find_max2``) are compared
    as sets, ignoring order.

    NOTE(review): if a target row has a single 1 (same speaker twice), its
    second "top" index is an arbitrary zero entry, so such rows rarely count
    as correct.

    Returns:
        Number of matching rows (int).
    """
    top_a = find_max2(tensor1)
    top_b = find_max2(tensor2)
    return sum(1 for row in range(top_a.shape[0]) if set(top_a[row]) == set(top_b[row]))

In [5]:
# SincNet framing parameters.
fs=new_sr          # sample rate (Hz) of the audio fed to SincNet
cw_len=200         # window length in ms
cw_shift=10        # window shift in ms; not used (the wshift line below is commented out)

# Window length in samples (1600 with the values above); also chop_chunk's chunk size.
wlen=int(fs*cw_len/1000.00)
#wshift=int(fs*cw_shift/1000.00)




class MixedClassifier(nn.Module):
    """SincNet front-end + MLP classifier producing per-speaker probabilities.

    Pipeline: SincNet (``wlen``-sample windows at ``fs`` Hz, module-level values
    from the config cell) -> 3 x 2048 MLP -> affine layer with ``n_classes``
    outputs -> softmax over classes.
    """

    def __init__(self, n_classes=20):
        """
        Args:
            n_classes: number of output speaker classes. Defaults to 20, the
                previously hard-coded value; it should match the size of the
                dataset's speaker label mapping.
        """
        super().__init__()
        cnn_arch = {
                'input_dim': wlen,
                'fs': fs,
                'cnn_N_filt': [80, 60, 60],
                'cnn_len_filt': [251, 5, 5],
                'cnn_max_pool_len': [3, 3, 3],
                'cnn_use_laynorm_inp': True,
                'cnn_use_batchnorm_inp': False,
                'cnn_use_laynorm': [True, True, True],
                'cnn_use_batchnorm': [False, False, False],
                'cnn_act': ['leaky_relu', 'leaky_relu', 'leaky_relu'],
                'cnn_drop': [0.0, 0.0, 0.0]
                }
        self.cnn_net = SincNet(cnn_arch)

        dnn1_arch = {'input_dim': self.cnn_net.out_dim,
                     'fc_lay': [2048, 2048, 2048],
                     'fc_drop': [0.0, 0.0, 0.0],
                     'fc_use_batchnorm': [True, True, True],
                     'fc_use_laynorm': [False, False, False],
                     'fc_use_laynorm_inp': False,
                     'fc_use_batchnorm_inp': False,
                     'fc_act': ['leaky_relu', 'leaky_relu', 'leaky_relu']
                     }
        self.dnn1 = MLP(dnn1_arch)

        dnn2_arch = {'input_dim': dnn1_arch['fc_lay'][-1],  # width of dnn1's last layer
                     'fc_lay': [n_classes],
                     'fc_drop': [0.0],
                     'fc_use_batchnorm': [False],
                     'fc_use_laynorm': [False],
                     'fc_use_laynorm_inp': False,
                     'fc_use_batchnorm_inp': False,
                     # 'linear' = identity activation in SincNet's MLP (the original
                     # note says LeakyReLU(1)); the softmax is applied in forward().
                     'fc_act': ['linear']
                     }
        self.dnn2 = MLP(dnn2_arch)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, X):
        """Return class probabilities (softmax over dim 1) for waveform windows ``X``.

        In training, ``X`` is a batch of ``wlen``-sample chunks from ``ChunkSet``.
        """
        out = self.cnn_net(X)
        out = self.dnn1(out)
        out = self.dnn2(out)
        return self.softmax(out)

# Build the classifier on GPU `device` with a fresh Adam optimizer.
# NOTE: this rebinds `optimizer`; the Conv-TasNet optimizer is no longer referenced.
cls = MixedClassifier().cuda(device)
cls.train()
optimizer = torch.optim.Adam(cls.parameters(), 0.001)


# Resume from the checkpoint written by the classifier training cell, if present.
if os.path.exists('models/sincnet.pth'):
    print('load model')
    checkpoint = torch.load('models/sincnet.pth')
    cls.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    loss = checkpoint['loss']
    # The training loop in this notebook does not save 'bestacc', so this
    # usually falls back to 0.0.
    if 'bestacc' in checkpoint:
        bestacc = checkpoint['bestacc']
    else:
        bestacc = 0.0
else:
    print('initializing new model')
    bestacc = 0.0

load model


In [7]:
# Train the speaker classifier on chunks of the two demixed signals of each row.
batch_size = 64
chunkloader_train = torch.utils.data.DataLoader(chunkset_train, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=16)
# Multi-label objective: `out` holds per-speaker probabilities, `target` is multi-hot.
criterion = torch.nn.BCELoss()
log_every = 200  # minibatches between progress reports / checkpoints

# torch.save does not create missing directories.
os.makedirs('models', exist_ok=True)
cls.train()

for epoch in range(64):
    running_loss = 0.0
    running_accuracy = 0.0
    for batch_idx, (X1, X2, target) in enumerate(tqdm(chunkloader_train)):
        optimizer.zero_grad()
        X1, X2, target = X1.float().cuda(device), X2.float().cuda(device), target.float().cuda(device)

        # Classify one chunk of each separated signal, then take the element-wise
        # max so that both speakers of the mixture can score high.
        out1, out2 = cls(X1), cls(X2)
        out, _ = torch.max(torch.stack([out1, out2], dim=0), dim=0)
        loss = criterion(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(cls.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        # Normalize by the real batch size: the last batch can be smaller.
        running_accuracy += compute_corrects(out, target) / X1.size(0)
        if batch_idx % log_every == log_every - 1:
            print('[%d, %5d] loss: %.3f accuracy: %.3f' %
                  (epoch + 1, batch_idx + 1, running_loss / log_every, running_accuracy / log_every))
            running_loss = 0.0
            running_accuracy = 0.0
            torch.save({
                'model_state_dict': cls.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Plain float rather than a tensor attached to the autograd graph.
                'loss': loss.item(),
                # Keep the value loaded at startup so a later resume does not lose it.
                'bestacc': bestacc,
            }, 'models/sincnet.pth')


HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[1,   200] loss: 0.034 accuracy: 0.880
[1,   400] loss: 0.033 accuracy: 0.879
[1,   600] loss: 0.032 accuracy: 0.882
[1,   800] loss: 0.032 accuracy: 0.886
[1,  1000] loss: 0.031 accuracy: 0.884
[1,  1200] loss: 0.032 accuracy: 0.884



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[2,   200] loss: 0.031 accuracy: 0.887
[2,   400] loss: 0.032 accuracy: 0.887
[2,   600] loss: 0.032 accuracy: 0.885
[2,   800] loss: 0.031 accuracy: 0.888
[2,  1000] loss: 0.033 accuracy: 0.882
[2,  1200] loss: 0.032 accuracy: 0.887



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[3,   200] loss: 0.030 accuracy: 0.890
[3,   400] loss: 0.031 accuracy: 0.890
[3,   600] loss: 0.031 accuracy: 0.887
[3,   800] loss: 0.032 accuracy: 0.884
[3,  1000] loss: 0.032 accuracy: 0.885
[3,  1200] loss: 0.033 accuracy: 0.882



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[4,   200] loss: 0.030 accuracy: 0.894
[4,   400] loss: 0.031 accuracy: 0.890
[4,   600] loss: 0.031 accuracy: 0.890
[4,   800] loss: 0.031 accuracy: 0.889
[4,  1000] loss: 0.031 accuracy: 0.884
[4,  1200] loss: 0.032 accuracy: 0.882



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[5,   200] loss: 0.031 accuracy: 0.891
[5,   400] loss: 0.030 accuracy: 0.892
[5,   600] loss: 0.030 accuracy: 0.892
[5,   800] loss: 0.030 accuracy: 0.895
[5,  1000] loss: 0.031 accuracy: 0.889
[5,  1200] loss: 0.032 accuracy: 0.885



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[6,   200] loss: 0.030 accuracy: 0.892
[6,   400] loss: 0.029 accuracy: 0.896
[6,   600] loss: 0.030 accuracy: 0.895
[6,   800] loss: 0.030 accuracy: 0.896
[6,  1000] loss: 0.030 accuracy: 0.891
[6,  1200] loss: 0.032 accuracy: 0.889



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[7,   200] loss: 0.029 accuracy: 0.896
[7,   400] loss: 0.029 accuracy: 0.898
[7,   600] loss: 0.031 accuracy: 0.890
[7,   800] loss: 0.030 accuracy: 0.893
[7,  1000] loss: 0.029 accuracy: 0.894
[7,  1200] loss: 0.031 accuracy: 0.891



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[8,   200] loss: 0.028 accuracy: 0.902
[8,   400] loss: 0.030 accuracy: 0.892
[8,   600] loss: 0.029 accuracy: 0.895
[8,   800] loss: 0.029 accuracy: 0.897
[8,  1000] loss: 0.030 accuracy: 0.893
[8,  1200] loss: 0.030 accuracy: 0.896



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[9,   200] loss: 0.027 accuracy: 0.899
[9,   400] loss: 0.028 accuracy: 0.902
[9,   600] loss: 0.029 accuracy: 0.895
[9,   800] loss: 0.029 accuracy: 0.895
[9,  1000] loss: 0.029 accuracy: 0.898
[9,  1200] loss: 0.028 accuracy: 0.898



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[10,   200] loss: 0.028 accuracy: 0.900
[10,   400] loss: 0.028 accuracy: 0.900
[10,   600] loss: 0.028 accuracy: 0.899
[10,   800] loss: 0.028 accuracy: 0.898
[10,  1000] loss: 0.029 accuracy: 0.895
[10,  1200] loss: 0.028 accuracy: 0.896



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[11,   200] loss: 0.028 accuracy: 0.898
[11,   400] loss: 0.027 accuracy: 0.905
[11,   600] loss: 0.028 accuracy: 0.901
[11,   800] loss: 0.029 accuracy: 0.899
[11,  1000] loss: 0.027 accuracy: 0.904
[11,  1200] loss: 0.027 accuracy: 0.900



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[12,   200] loss: 0.026 accuracy: 0.906
[12,   400] loss: 0.028 accuracy: 0.897
[12,   600] loss: 0.027 accuracy: 0.903
[12,   800] loss: 0.027 accuracy: 0.902
[12,  1000] loss: 0.027 accuracy: 0.903
[12,  1200] loss: 0.028 accuracy: 0.901



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[13,   200] loss: 0.026 accuracy: 0.905
[13,   400] loss: 0.028 accuracy: 0.903
[13,   600] loss: 0.026 accuracy: 0.906
[13,   800] loss: 0.027 accuracy: 0.903
[13,  1000] loss: 0.027 accuracy: 0.904
[13,  1200] loss: 0.027 accuracy: 0.905



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[14,   200] loss: 0.027 accuracy: 0.905
[14,   400] loss: 0.026 accuracy: 0.911
[14,   600] loss: 0.027 accuracy: 0.904
[14,   800] loss: 0.028 accuracy: 0.902
[14,  1000] loss: 0.026 accuracy: 0.908
[14,  1200] loss: 0.027 accuracy: 0.905



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[15,   200] loss: 0.026 accuracy: 0.910
[15,   400] loss: 0.025 accuracy: 0.913
[15,   600] loss: 0.025 accuracy: 0.911
[15,   800] loss: 0.026 accuracy: 0.907
[15,  1000] loss: 0.027 accuracy: 0.904
[15,  1200] loss: 0.028 accuracy: 0.902



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[16,   200] loss: 0.026 accuracy: 0.908
[16,   400] loss: 0.026 accuracy: 0.908
[16,   600] loss: 0.026 accuracy: 0.907
[16,   800] loss: 0.027 accuracy: 0.905
[16,  1000] loss: 0.027 accuracy: 0.902
[16,  1200] loss: 0.026 accuracy: 0.907



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[17,   200] loss: 0.024 accuracy: 0.915
[17,   400] loss: 0.025 accuracy: 0.911
[17,   600] loss: 0.025 accuracy: 0.911
[17,   800] loss: 0.026 accuracy: 0.909
[17,  1000] loss: 0.026 accuracy: 0.908
[17,  1200] loss: 0.025 accuracy: 0.909



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[18,   200] loss: 0.026 accuracy: 0.911
[18,   400] loss: 0.025 accuracy: 0.908
[18,   600] loss: 0.025 accuracy: 0.911
[18,   800] loss: 0.025 accuracy: 0.912
[18,  1000] loss: 0.025 accuracy: 0.912
[18,  1200] loss: 0.025 accuracy: 0.914



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[19,   200] loss: 0.024 accuracy: 0.913
[19,   400] loss: 0.025 accuracy: 0.911
[19,   600] loss: 0.026 accuracy: 0.905
[19,   800] loss: 0.025 accuracy: 0.914
[19,  1000] loss: 0.025 accuracy: 0.912
[19,  1200] loss: 0.024 accuracy: 0.916



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[20,   200] loss: 0.024 accuracy: 0.916
[20,   400] loss: 0.025 accuracy: 0.911
[20,   600] loss: 0.024 accuracy: 0.917
[20,   800] loss: 0.025 accuracy: 0.914
[20,  1000] loss: 0.023 accuracy: 0.915
[20,  1200] loss: 0.024 accuracy: 0.912



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[21,   200] loss: 0.023 accuracy: 0.918
[21,   400] loss: 0.022 accuracy: 0.922
[21,   600] loss: 0.025 accuracy: 0.912
[21,   800] loss: 0.024 accuracy: 0.917
[21,  1000] loss: 0.025 accuracy: 0.910
[21,  1200] loss: 0.026 accuracy: 0.908



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[22,   200] loss: 0.024 accuracy: 0.917
[22,   400] loss: 0.023 accuracy: 0.922
[22,   600] loss: 0.023 accuracy: 0.919
[22,   800] loss: 0.023 accuracy: 0.921
[22,  1000] loss: 0.025 accuracy: 0.916
[22,  1200] loss: 0.024 accuracy: 0.912



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[23,   200] loss: 0.023 accuracy: 0.917
[23,   400] loss: 0.022 accuracy: 0.922
[23,   600] loss: 0.023 accuracy: 0.918
[23,   800] loss: 0.024 accuracy: 0.914
[23,  1000] loss: 0.024 accuracy: 0.915
[23,  1200] loss: 0.024 accuracy: 0.915



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[24,   200] loss: 0.022 accuracy: 0.924
[24,   400] loss: 0.022 accuracy: 0.922
[24,   600] loss: 0.024 accuracy: 0.916
[24,   800] loss: 0.023 accuracy: 0.918
[24,  1000] loss: 0.022 accuracy: 0.922
[24,  1200] loss: 0.024 accuracy: 0.917



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[25,   200] loss: 0.022 accuracy: 0.923
[25,   400] loss: 0.022 accuracy: 0.922
[25,   600] loss: 0.023 accuracy: 0.921
[25,   800] loss: 0.023 accuracy: 0.919
[25,  1000] loss: 0.024 accuracy: 0.918
[25,  1200] loss: 0.023 accuracy: 0.916



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[26,   200] loss: 0.022 accuracy: 0.921
[26,   400] loss: 0.023 accuracy: 0.920
[26,   600] loss: 0.022 accuracy: 0.923
[26,   800] loss: 0.024 accuracy: 0.917
[26,  1000] loss: 0.022 accuracy: 0.923
[26,  1200] loss: 0.024 accuracy: 0.919



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[27,   200] loss: 0.022 accuracy: 0.925
[27,   400] loss: 0.023 accuracy: 0.920
[27,   600] loss: 0.023 accuracy: 0.917
[27,   800] loss: 0.023 accuracy: 0.919
[27,  1000] loss: 0.022 accuracy: 0.923
[27,  1200] loss: 0.022 accuracy: 0.921



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[28,   200] loss: 0.022 accuracy: 0.923
[28,   400] loss: 0.021 accuracy: 0.926
[28,   600] loss: 0.023 accuracy: 0.918
[28,   800] loss: 0.023 accuracy: 0.920
[28,  1000] loss: 0.022 accuracy: 0.923
[28,  1200] loss: 0.023 accuracy: 0.922



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[29,   200] loss: 0.022 accuracy: 0.927
[29,   400] loss: 0.020 accuracy: 0.928
[29,   600] loss: 0.022 accuracy: 0.921
[29,   800] loss: 0.022 accuracy: 0.923
[29,  1000] loss: 0.022 accuracy: 0.926
[29,  1200] loss: 0.021 accuracy: 0.928



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[30,   200] loss: 0.021 accuracy: 0.925
[30,   400] loss: 0.021 accuracy: 0.925
[30,   600] loss: 0.021 accuracy: 0.926
[30,   800] loss: 0.022 accuracy: 0.921
[30,  1000] loss: 0.022 accuracy: 0.921
[30,  1200] loss: 0.022 accuracy: 0.922



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[31,   200] loss: 0.021 accuracy: 0.928
[31,   400] loss: 0.021 accuracy: 0.928
[31,   600] loss: 0.022 accuracy: 0.922
[31,   800] loss: 0.022 accuracy: 0.923
[31,  1000] loss: 0.022 accuracy: 0.922
[31,  1200] loss: 0.022 accuracy: 0.925



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[32,   200] loss: 0.021 accuracy: 0.929
[32,   400] loss: 0.021 accuracy: 0.926
[32,   600] loss: 0.021 accuracy: 0.927
[32,   800] loss: 0.022 accuracy: 0.923
[32,  1000] loss: 0.022 accuracy: 0.924
[32,  1200] loss: 0.021 accuracy: 0.925



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[33,   200] loss: 0.020 accuracy: 0.931
[33,   400] loss: 0.020 accuracy: 0.930
[33,   600] loss: 0.021 accuracy: 0.928
[33,   800] loss: 0.021 accuracy: 0.925
[33,  1000] loss: 0.022 accuracy: 0.924
[33,  1200] loss: 0.021 accuracy: 0.926



HBox(children=(FloatProgress(value=0.0, max=1378.0), HTML(value='')))

[34,   200] loss: 0.020 accuracy: 0.930


KeyboardInterrupt: 

In [None]:
# Validation loader: the validation split (not the training set), in fixed order.
# NOTE(review): in 'val' mode each item stacks all chunks of a signal, so default
# collation needs every signal to yield the same number of chunks -- confirm all
# files share one length.
chunkloader_val = torch.utils.data.DataLoader(chunkset_val, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=16)


In [52]:
def create_batches_rnd(batch_size, data_folder, wav_lst, N_snt, wlen, lab_dict, fact_amp, device='cuda'):
    """Draw a random minibatch of fixed-length waveform windows.

    For each of the ``batch_size`` items, a random file among ``wav_lst[:N_snt]``
    is read from ``data_folder`` and divided by 32768 (int16 PCM -> roughly
    [-1, 1)). A random window of ``wlen`` samples is cut out and multiplied by
    a gain drawn uniformly from ``[1 - fact_amp, 1 + fact_amp]``.

    Args:
        batch_size: number of windows.
        data_folder: prefix string-concatenated with each ``wav_lst`` entry.
        wav_lst: list of wav paths relative to ``data_folder``.
        N_snt: files are drawn from the first ``N_snt`` entries of ``wav_lst``.
        wlen: window length in samples; every file must be longer than ``wlen + 1``.
        lab_dict: maps a ``wav_lst`` entry to its numeric label.
        fact_amp: amplitude-jitter factor.
        device: placement of the returned tensors. The default ``'cuda'`` matches
            the former hard-coded ``.cuda()`` (current CUDA device).

    Returns:
        ``(inp, lab)``: float tensors of shape ``(batch_size, wlen)`` and ``(batch_size,)``.
    """
    # scipy is not imported in any visible cell; import it where it is needed.
    import scipy.io.wavfile

    sig_batch = np.zeros([batch_size, wlen])
    lab_batch = np.zeros(batch_size)
    snt_id_arr = np.random.randint(N_snt, size=batch_size)
    rand_amp_arr = np.random.uniform(1.0 - fact_amp, 1 + fact_amp, batch_size)

    for i in range(batch_size):
        wav_name = wav_lst[snt_id_arr[i]]
        _file_sr, signal = scipy.io.wavfile.read(data_folder + wav_name)
        signal = signal.astype(float) / 32768

        # Random window start; requires len(signal) > wlen + 1.
        snt_len = signal.shape[0]
        snt_beg = np.random.randint(snt_len - wlen - 1)
        snt_end = snt_beg + wlen
        sig_batch[i, :] = signal[snt_beg:snt_end] * rand_amp_arr[i]
        lab_batch[i] = lab_dict[wav_name]

    inp = torch.from_numpy(sig_batch).float().to(device).contiguous()
    lab = torch.from_numpy(lab_batch).float().to(device).contiguous()
    return inp, lab

In [None]:
# Full validation pass, apparently adapted from SincNet's training script:
# frame-level and sentence-level error over the test list, a log line, and a
# checkpoint of the three sub-networks.
# NOTE(review): this cell is a fragment of a training-loop body and cannot run as-is:
#   * its first statement (`if epoch%N_eval_epoch==0:`) is indented, which Python
#     rejects with IndentationError at the top level of a cell;
#   * it uses names not defined in any visible cell: N_eval_epoch, CNN_net,
#     DNN1_net, DNN2_net, snt_te, sf, data_folder, wav_lst_te, lab_dict,
#     Batch_dev, class_lay, cost, loss_tot, err_tot, output_folder, and wshift
#     (its definition in the SincNet config cell is commented out);
#   * `.cuda()` targets the default CUDA device, not `device`.
# TODO: port to MixedClassifier / ChunkSet, or delete.
  if epoch%N_eval_epoch==0:
      
   CNN_net.eval()
   DNN1_net.eval()
   DNN2_net.eval()
   # test_flag is set but not used in this fragment.
   test_flag=1 
   loss_sum=0
   err_sum=0
   err_sum_snt=0
   
   with torch.no_grad():  
    for i in range(snt_te):
       
     #[fs,signal]=scipy.io.wavfile.read(data_folder+wav_lst_te[i])
     #signal=signal.astype(float)/32768

     # presumably `sf` is the soundfile module (read -> (data, samplerate)) -- confirm.
     [signal, fs] = sf.read(data_folder+wav_lst_te[i])

     signal=torch.from_numpy(signal).float().cuda().contiguous()
     lab_batch=lab_dict[wav_lst_te[i]]
    
     # split signals into chunks
     beg_samp=0
     end_samp=wlen
     
     N_fr=int((signal.shape[0]-wlen)/(wshift))
     

     sig_arr=torch.zeros([Batch_dev,wlen]).float().cuda().contiguous()
     lab= Variable((torch.zeros(N_fr+1)+lab_batch).cuda().contiguous().long())
     pout=Variable(torch.zeros(N_fr+1,class_lay[-1]).float().cuda().contiguous())
     count_fr=0
     count_fr_tot=0
     # Slide a wlen-sample window by wshift samples and score the frames in
     # batches of Batch_dev; pout collects the per-frame network outputs.
     while end_samp<signal.shape[0]:
         sig_arr[count_fr,:]=signal[beg_samp:end_samp]
         beg_samp=beg_samp+wshift
         end_samp=beg_samp+wlen
         count_fr=count_fr+1
         count_fr_tot=count_fr_tot+1
         if count_fr==Batch_dev:
             inp=Variable(sig_arr)
             pout[count_fr_tot-Batch_dev:count_fr_tot,:]=DNN2_net(DNN1_net(CNN_net(inp)))
             count_fr=0
             sig_arr=torch.zeros([Batch_dev,wlen]).float().cuda().contiguous()
   
     # Score the frames left over in the last partial batch.
     if count_fr>0:
      inp=Variable(sig_arr[0:count_fr])
      pout[count_fr_tot-count_fr:count_fr_tot,:]=DNN2_net(DNN1_net(CNN_net(inp)))

    
     # Frame-level predictions, loss and error rate.
     pred=torch.max(pout,dim=1)[1]
     loss = cost(pout, lab.long())
     err = torch.mean((pred!=lab.long()).float())
    
     # Sentence-level decision: the class with the largest summed frame outputs.
     [val,best_class]=torch.max(torch.sum(pout,dim=0),0)
     err_sum_snt=err_sum_snt+(best_class!=lab[0]).float()
    
    
     loss_sum=loss_sum+loss.detach()
     err_sum=err_sum+err.detach()
    
    # Averages over the snt_te test files.
    err_tot_dev_snt=err_sum_snt/snt_te
    loss_tot_dev=loss_sum/snt_te
    err_tot_dev=err_sum/snt_te

  
   print("epoch %i, loss_tr=%f err_tr=%f loss_te=%f err_te=%f err_te_snt=%f" % (epoch, loss_tot,err_tot,loss_tot_dev,err_tot_dev,err_tot_dev_snt))
  
   # Append the same summary line to <output_folder>/res.res.
   with open(output_folder+"/res.res", "a") as res_file:
    res_file.write("epoch %i, loss_tr=%f err_tr=%f loss_te=%f err_te=%f err_te_snt=%f\n" % (epoch, loss_tot,err_tot,loss_tot_dev,err_tot_dev,err_tot_dev_snt))   

   # Save the three sub-networks' weights.
   checkpoint={'CNN_model_par': CNN_net.state_dict(),
               'DNN1_model_par': DNN1_net.state_dict(),
               'DNN2_model_par': DNN2_net.state_dict(),
               }
   torch.save(checkpoint,output_folder+'/model_raw.pkl')
  
  # Non-evaluation epochs only report the training metrics.
  else:
   print("epoch %i, loss_tr=%f err_tr=%f" % (epoch, loss_tot,err_tot))