In [1]:
import sys
from librosa.core import resample
import pandas as pd
import numpy as np
from IPython.display import Audio
import torch
import pathlib
def create_dir(filename):
    """Create the directory that ``filename`` lives in, including missing ancestors.

    Args:
        filename: Path of a file that is about to be written. Only the directory
            part is created; the file itself is never touched. A path that ends
            in a separator is treated as a directory and is created in full.

    A bare file name with no directory part resolves to the current directory,
    which already exists, so the call is a no-op. Calling it repeatedly is safe
    because existing directories are accepted.
    """
    # os.path.dirname handles every separator the platform accepts and maps
    # '/file' to '/', whereas a manual split('/') collapsed that case to the
    # current working directory.
    pathlib.Path(os.path.dirname(filename)).mkdir(parents=True, exist_ok=True)
from tqdm.notebook import tqdm
sys.path.append('Conv-TasNet/src/')
from conv_tasnet import *
from pit_criterion import cal_loss
from collections import Counter
import os
device = 0
device_ids = [0, 1, 2, 3]
root = './'
sr = 8000
torch.cuda.set_device(device)

# Pre-Train Baseline

In [8]:
# Load the reference checkpoint and build the network that will actually be trained.
# NOTE(review): `tas2` presumably is a 2-speaker model and `C = 3` the number of
# output speakers of the new one -- confirm against conv_tasnet.ConvTasNet.
tas2 = ConvTasNet.load_model('final.pth.tar').cuda(device)
tasnet = ConvTasNet(N = 256, L = 20, B = 256, H = 512, P = 3, X = 8, R = 4, C = 3, norm_type="gLN", causal=0,
             mask_nonlinear='relu').cuda(device)

# Warm start: copy every tensor whose name exists in both models. Tensors whose
# shapes differ are reported and left at their fresh initialisation.
own_state = tasnet.state_dict()
for name, param in tqdm(tas2.state_dict().items()):
    if name not in own_state:
        continue
    if isinstance(param, torch.nn.Parameter):
        param = param.data
    try:
        own_state[name].copy_(param)
    except RuntimeError:
        # Tensor.copy_ raises RuntimeError when the shapes cannot be broadcast.
        # Catching only that keeps KeyboardInterrupt and genuine bugs visible
        # instead of mislabelling them as a shape mismatch.
        print('shape mismatch')

# The optimizer is created on the bare module's parameters before wrapping; the
# DataParallel wrapper shares those same parameter tensors.
optimizer = torch.optim.Adam(tasnet.parameters(), lr = 0.0001)
tasnet = nn.DataParallel(tasnet, device_ids = device_ids)

# Resume from the rolling training checkpoint when one exists. It was saved from
# the DataParallel-wrapped model, so it must be loaded after wrapping.
if os.path.exists('models/tasnet.pth'):
    print('load model')
    checkpoint = torch.load('models/tasnet.pth')
    tasnet.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    loss = checkpoint['loss']

tasnet.train()
# Trailing no-op so the cell does not echo the module repr returned by train().
pass

HBox(children=(FloatProgress(value=0.0, max=294.0), HTML(value='')))

shape mismatch

load model


In [9]:
class OverlayDataSet(torch.utils.data.Dataset):
    """Builds multi-talker mixtures on the fly from single-speaker segments.

    The segment list is read from a CSV file (path relative to the module-level
    ``root``) that must provide at least two columns:

    * ``speaker`` -- speaker label of the segment
    * ``segfile`` -- path (relative to ``root``) of a ``.npy`` file with the samples

    Each item combines ``num_talker`` segments from distinct speakers. Every
    segment is normalised to zero mean and unit variance before mixing.

    NOTE(review): the extra talkers are drawn with the global NumPy RNG. With
    multi-process DataLoaders some PyTorch versions give every worker the same
    NumPy state -- confirm for the installed version or pass a worker_init_fn.
    """

    def __init__(self, csv, num_talker = 3):
        """
        Args:
            csv: Path of the segment CSV, relative to ``root``.
            num_talker: Number of distinct speakers overlaid in each mixture.
        """
        super().__init__()
        self.segments = pd.read_csv(root+csv)
        self.speakers = sorted(set(self.segments['speaker']))
        self.spkr2idx = {spkr:i for i, spkr in enumerate(self.speakers)}
        self.num_talker = num_talker

    def __len__(self):
        """Number of segments; each one seeds one mixture as its first talker."""
        return len(self.segments)

    def __getitem__(self, idx):
        """Return ``(mixture, sources)`` for the segment at position ``idx``.

        The first source is the segment at ``idx``; the remaining ones are drawn
        at random and redrawn until their speaker differs from all talkers
        already chosen.

        Returns:
            mixture: sum of the normalised sources along axis 0.
            sources: array stacking the ``num_talker`` normalised sources.
                Stacking assumes all segments have the same length -- TODO confirm.

        Raises:
            ValueError: if the CSV holds fewer distinct speakers than
                ``num_talker``; the rejection sampling below could never finish.
        """
        if len(self.speakers) < self.num_talker:
            raise ValueError(
                'need at least %d distinct speakers to build a mixture, found %d'
                % (self.num_talker, len(self.speakers)))
        talkers = []
        sigs = []
        for i in range(self.num_talker):
            if i != 0:
                idx = np.random.randint(len(self.segments))
            seg = self.segments.iloc[idx]
            # Rejection-sample until the speaker is not already in the mixture.
            while seg['speaker'] in talkers:
                idx = np.random.randint(len(self.segments))
                seg = self.segments.iloc[idx]
            sig = np.load(root+seg['segfile'])
            # Out-of-place arithmetic so integer-typed arrays are promoted to
            # float instead of raising on an in-place update.
            sig = sig - np.mean(sig)
            std = np.std(sig)
            # A constant (e.g. silent) segment has zero variance; leave it at
            # zero rather than dividing by zero and injecting NaN into the mixture.
            if std > 0:
                sig = sig / std
            sigs.append(sig)
            talkers.append(seg['speaker'])
        sigs = np.array(sigs)
        return np.sum(sigs, axis = 0), sigs


# One on-the-fly mixing dataset per split, in train / val / test order.
trainset, valset, testset = (
    OverlayDataSet(split_csv)
    for split_csv in (
        'files/train-segments.csv',
        'files/val-segments.csv',
        'files/test-segments.csv',
    )
)

# Sanity check: draw one random test mixture and listen to it.
idx = np.random.randint(len(testset))
mixture, sources = testset[idx]

Audio(mixture, rate = sr)

In [10]:
# Play the first separated-out source of the example held in `sources`.
# `sources` is also rebound by the training/evaluation loops further down, so
# run this right after the preview cell to hear the sampled test example.
Audio(sources[0], rate = sr)

In [11]:
# Play the second source of the example currently held in `sources`.
Audio(sources[1], rate = sr)

In [12]:
# Play the third source of the example currently held in `sources`.
Audio(sources[2], rate = sr)

In [14]:
batch_size = 32
trainloader  = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 32)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 32)

# The rolling checkpoint is written inside the loop; make sure its directory exists.
os.makedirs('models', exist_ok=True)

# The epoch counter starts at 4 -- presumably earlier epochs ran in a previous
# session and were restored from the checkpoint; adjust when starting fresh.
for epoch in range(4, 64):
    # The validation pass at the end of each epoch switches the model to eval
    # mode, so training mode has to be restored at the start of every epoch,
    # not just once before the loop.
    tasnet.train()
    running_loss = 0.0
    for batch_idx, (mixture, sources) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        mixture, sources = mixture.float().cuda(device), sources.float().cuda(device)
        out = tasnet(mixture)
        # Third argument: per-example source length in samples.
        loss, max_snr, estimate_source, reorder_estimate_source = \
            cal_loss(sources, out, torch.ones(mixture.shape[0], dtype = torch.int32).cuda(device)*2*sr) # 2 seconds has 2*sr samples
        loss.backward()
        torch.nn.utils.clip_grad_norm_(tasnet.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 200 == 199:    # report and checkpoint every 200 mini-batches
            print('[%d, %5d] loss: %.3f ' % 
                  (epoch + 1, batch_idx + 1, running_loss / 200))
            running_loss = 0.0
            # Saved from the DataParallel wrapper, so keys carry its prefix and
            # the file must be loaded into a wrapped model.
            torch.save({
            'model_state_dict': tasnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
            }, 'models/tasnet.pth')
    with torch.no_grad():
        running_loss = 0.0
        tasnet.eval()
        for batch_idx, (mixture, sources) in enumerate(tqdm(valloader)):
            mixture, sources = mixture.float().cuda(device), sources.float().cuda(device)
            out = tasnet(mixture)
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(sources, out, torch.ones(mixture.shape[0], dtype = torch.int32).cuda(device)*2*sr)

            # Weight each batch by its size so a short final batch does not skew the mean.
            running_loss += loss.item()*sources.shape[0]
        print(running_loss/len(valset))

HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[5,   200] loss: -7.928 
[5,   400] loss: -7.895 
[5,   600] loss: -7.682 
[5,   800] loss: -7.773 
[5,  1000] loss: -7.862 
[5,  1200] loss: -7.998 
[5,  1400] loss: -7.806 
[5,  1600] loss: -7.667 
[5,  1800] loss: -7.899 
[5,  2000] loss: -7.803 
[5,  2200] loss: -8.122 
[5,  2400] loss: -7.979 
[5,  2600] loss: -7.746 
[5,  2800] loss: -8.007 
[5,  3000] loss: -7.924 
[5,  3200] loss: -7.942 
[5,  3400] loss: -7.958 
[5,  3600] loss: -7.824 
[5,  3800] loss: -7.760 
[5,  4000] loss: -7.952 
[5,  4200] loss: -7.758 
[5,  4400] loss: -7.811 
[5,  4600] loss: -7.710 
[5,  4800] loss: -8.013 
[5,  5000] loss: -7.849 
[5,  5200] loss: -8.020 
[5,  5400] loss: -7.859 
[5,  5600] loss: -7.809 
[5,  5800] loss: -7.913 
[5,  6000] loss: -7.938 
[5,  6200] loss: -7.909 
[5,  6400] loss: -7.940 
[5,  6600] loss: -8.010 
[5,  6800] loss: -7.900 
[5,  7000] loss: -7.996 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-7.6879274742156865


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[6,   200] loss: -7.976 
[6,   400] loss: -7.996 
[6,   600] loss: -7.765 
[6,   800] loss: -7.889 
[6,  1000] loss: -7.891 
[6,  1200] loss: -8.080 
[6,  1400] loss: -7.920 
[6,  1600] loss: -7.755 
[6,  1800] loss: -7.947 
[6,  2000] loss: -7.794 
[6,  2200] loss: -8.123 
[6,  2400] loss: -8.074 
[6,  2600] loss: -7.907 
[6,  2800] loss: -8.060 
[6,  3000] loss: -8.051 
[6,  3200] loss: -8.022 
[6,  3400] loss: -8.022 
[6,  3600] loss: -7.896 
[6,  3800] loss: -7.790 
[6,  4000] loss: -8.077 
[6,  4200] loss: -7.940 
[6,  4400] loss: -7.910 
[6,  4600] loss: -7.700 
[6,  4800] loss: -8.112 
[6,  5000] loss: -7.973 
[6,  5200] loss: -8.023 
[6,  5400] loss: -7.875 
[6,  5600] loss: -7.958 
[6,  5800] loss: -8.018 
[6,  6000] loss: -7.974 
[6,  6200] loss: -7.925 
[6,  6400] loss: -7.896 
[6,  6600] loss: -8.014 
[6,  6800] loss: -7.932 
[6,  7000] loss: -7.997 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-7.6902729885511665


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[7,   200] loss: -7.996 
[7,   400] loss: -8.059 
[7,   600] loss: -7.815 
[7,   800] loss: -7.870 
[7,  1000] loss: -7.964 
[7,  1200] loss: -8.082 
[7,  1400] loss: -7.912 
[7,  1600] loss: -7.748 
[7,  1800] loss: -7.947 
[7,  2000] loss: -7.816 
[7,  2200] loss: -8.144 
[7,  2400] loss: -8.061 
[7,  2600] loss: -7.853 
[7,  2800] loss: -8.070 
[7,  3000] loss: -8.061 
[7,  3200] loss: -8.059 
[7,  3400] loss: -8.019 
[7,  3600] loss: -7.862 
[7,  3800] loss: -7.856 
[7,  4000] loss: -8.037 
[7,  4200] loss: -7.846 
[7,  4400] loss: -8.013 
[7,  4600] loss: -7.776 
[7,  4800] loss: -8.096 
[7,  5000] loss: -8.003 
[7,  5200] loss: -8.103 
[7,  5400] loss: -7.930 
[7,  5600] loss: -7.991 
[7,  5800] loss: -8.044 
[7,  6000] loss: -8.010 
[7,  6200] loss: -7.953 
[7,  6400] loss: -7.988 
[7,  6600] loss: -8.105 
[7,  6800] loss: -7.978 
[7,  7000] loss: -7.896 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-7.802821043098115


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[8,   200] loss: -8.018 
[8,   400] loss: -8.077 
[8,   600] loss: -7.818 
[8,   800] loss: -7.835 
[8,  1000] loss: -7.903 
[8,  1200] loss: -8.097 
[8,  1400] loss: -7.936 
[8,  1600] loss: -7.821 
[8,  1800] loss: -8.010 
[8,  2000] loss: -7.898 
[8,  2200] loss: -8.150 
[8,  2400] loss: -8.115 
[8,  2600] loss: -7.886 
[8,  2800] loss: -8.177 
[8,  3000] loss: -8.070 
[8,  3200] loss: -8.063 
[8,  3400] loss: -7.989 
[8,  3600] loss: -7.972 
[8,  3800] loss: -7.798 
[8,  4000] loss: -8.053 
[8,  4200] loss: -7.995 
[8,  4400] loss: -8.019 
[8,  4600] loss: -7.839 
[8,  4800] loss: -8.089 
[8,  5000] loss: -7.956 
[8,  5200] loss: -8.009 
[8,  5400] loss: -7.889 
[8,  5600] loss: -7.997 
[8,  5800] loss: -7.961 
[8,  6000] loss: -7.933 
[8,  6200] loss: -7.975 
[8,  6400] loss: -8.012 
[8,  6600] loss: -8.062 
[8,  6800] loss: -8.013 
[8,  7000] loss: -8.015 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-7.789959502505116


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[9,   200] loss: -8.070 
[9,   400] loss: -8.116 
[9,   600] loss: -7.849 
[9,   800] loss: -7.976 
[9,  1000] loss: -7.893 
[9,  1200] loss: -8.156 
[9,  1400] loss: -8.036 
[9,  1600] loss: -7.881 
[9,  1800] loss: -7.950 
[9,  2000] loss: -7.843 
[9,  2200] loss: -8.118 
[9,  2400] loss: -8.078 
[9,  2600] loss: -7.973 
[9,  2800] loss: -8.082 
[9,  3000] loss: -8.110 
[9,  3200] loss: -8.141 
[9,  3400] loss: -8.119 
[9,  3600] loss: -8.029 
[9,  3800] loss: -7.960 
[9,  4000] loss: -8.087 
[9,  4200] loss: -7.986 
[9,  4400] loss: -8.034 
[9,  4600] loss: -7.861 
[9,  4800] loss: -8.062 
[9,  5000] loss: -8.015 
[9,  5200] loss: -8.112 
[9,  5400] loss: -7.928 
[9,  5600] loss: -8.037 
[9,  5800] loss: -7.972 
[9,  6000] loss: -7.989 
[9,  6200] loss: -7.933 
[9,  6400] loss: -8.117 
[9,  6600] loss: -8.102 
[9,  6800] loss: -8.015 
[9,  7000] loss: -8.074 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-7.765552412238254


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[10,   200] loss: -7.976 
[10,   400] loss: -7.994 
[10,   600] loss: -7.882 
[10,   800] loss: -7.965 
[10,  1000] loss: -7.947 
[10,  1200] loss: -8.071 
[10,  1400] loss: -7.972 
[10,  1600] loss: -7.830 
[10,  1800] loss: -7.997 
[10,  2000] loss: -7.840 
[10,  2200] loss: -8.156 
[10,  2400] loss: -8.178 
[10,  2600] loss: -7.929 
[10,  2800] loss: -8.173 
[10,  3000] loss: -8.142 
[10,  3200] loss: -8.147 
[10,  3400] loss: -8.085 
[10,  3600] loss: -7.972 
[10,  3800] loss: -7.875 
[10,  4000] loss: -8.133 
[10,  4200] loss: -7.967 
[10,  4400] loss: -8.000 
[10,  4600] loss: -7.854 
[10,  4800] loss: -8.135 
[10,  5000] loss: -8.086 
[10,  5200] loss: -8.111 
[10,  5400] loss: -8.026 
[10,  5600] loss: -8.034 
[10,  5800] loss: -8.130 
[10,  6000] loss: -8.109 
[10,  6200] loss: -8.043 
[10,  6400] loss: -8.122 
[10,  6600] loss: -8.126 
[10,  6800] loss: -7.969 
[10,  7000] loss: -8.062 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-7.810931867926244


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[11,   200] loss: -8.102 
[11,   400] loss: -8.099 
[11,   600] loss: -7.836 
[11,   800] loss: -7.888 
[11,  1000] loss: -7.956 
[11,  1200] loss: -8.230 
[11,  1400] loss: -7.992 
[11,  1600] loss: -7.863 
[11,  1800] loss: -8.068 
[11,  2000] loss: -7.950 
[11,  2200] loss: -8.212 
[11,  2400] loss: -8.205 
[11,  2600] loss: -7.965 
[11,  2800] loss: -8.233 
[11,  3000] loss: -8.209 
[11,  3200] loss: -8.133 
[11,  3400] loss: -8.103 
[11,  3600] loss: -8.009 
[11,  3800] loss: -7.922 
[11,  4000] loss: -8.128 
[11,  4200] loss: -8.069 
[11,  4400] loss: -8.049 
[11,  4600] loss: -7.914 
[11,  4800] loss: -8.223 
[11,  5000] loss: -8.064 
[11,  5200] loss: -8.158 
[11,  5400] loss: -7.997 
[11,  5600] loss: -8.049 
[11,  5800] loss: -8.133 
[11,  6000] loss: -8.116 
[11,  6200] loss: -8.041 
[11,  6400] loss: -8.088 
[11,  6600] loss: -8.167 
[11,  6800] loss: -8.079 
[11,  7000] loss: -8.091 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-7.7507980652159425


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[12,   200] loss: -8.169 
[12,   400] loss: -8.186 
[12,   600] loss: -7.907 
[12,   800] loss: -7.988 
[12,  1000] loss: -8.051 
[12,  1200] loss: -8.206 
[12,  1400] loss: -8.001 
[12,  1600] loss: -7.927 
[12,  1800] loss: -8.073 
[12,  2000] loss: -7.939 
[12,  2200] loss: -8.242 
[12,  2400] loss: -8.185 
[12,  2600] loss: -8.024 
[12,  2800] loss: -8.232 
[12,  3000] loss: -8.233 
[12,  3200] loss: -8.222 
[12,  3400] loss: -8.135 
[12,  3600] loss: -8.054 
[12,  3800] loss: -7.987 
[12,  4000] loss: -8.208 
[12,  4200] loss: -8.106 
[12,  4400] loss: -8.085 
[12,  4600] loss: -7.913 
[12,  4800] loss: -8.211 
[12,  5000] loss: -8.085 
[12,  5200] loss: -8.219 
[12,  5400] loss: -8.097 
[12,  5600] loss: -8.100 
[12,  5800] loss: -8.141 
[12,  6000] loss: -8.111 
[12,  6200] loss: -8.101 
[12,  6400] loss: -8.114 
[12,  6600] loss: -8.210 
[12,  6800] loss: -8.062 
[12,  7000] loss: -8.099 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-7.867658137743216


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[13,   200] loss: -8.134 
[13,   400] loss: -8.226 
[13,   600] loss: -7.949 
[13,   800] loss: -7.970 
[13,  1000] loss: -8.139 
[13,  1200] loss: -8.153 
[13,  1400] loss: -8.114 
[13,  1600] loss: -7.991 
[13,  1800] loss: -8.054 
[13,  2000] loss: -8.020 
[13,  2200] loss: -8.234 
[13,  2400] loss: -8.153 
[13,  2600] loss: -8.022 
[13,  2800] loss: -8.138 
[13,  3000] loss: -8.137 
[13,  3200] loss: -8.198 
[13,  3400] loss: -8.140 
[13,  3600] loss: -8.095 
[13,  3800] loss: -7.975 
[13,  4000] loss: -8.209 
[13,  4200] loss: -8.038 
[13,  4400] loss: -8.101 
[13,  4600] loss: -7.935 
[13,  4800] loss: -8.289 
[13,  5000] loss: -8.210 
[13,  5200] loss: -8.169 
[13,  5400] loss: -8.105 
[13,  5600] loss: -8.162 
[13,  5800] loss: -8.229 
[13,  6000] loss: -8.085 
[13,  6200] loss: -8.045 
[13,  6400] loss: -8.091 
[13,  6600] loss: -8.169 
[13,  6800] loss: -8.051 
[13,  7000] loss: -8.194 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-7.900429351320304


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[14,   200] loss: -8.199 
[14,   400] loss: -8.297 
[14,   600] loss: -7.963 
[14,   800] loss: -7.990 
[14,  1000] loss: -8.149 
[14,  1200] loss: -8.290 
[14,  1400] loss: -8.175 
[14,  1600] loss: -7.964 
[14,  1800] loss: -8.114 
[14,  2000] loss: -8.076 
[14,  2200] loss: -8.287 
[14,  2400] loss: -8.298 
[14,  2600] loss: -8.127 
[14,  2800] loss: -8.279 
[14,  3000] loss: -8.286 
[14,  3200] loss: -8.118 
[14,  3400] loss: -8.184 
[14,  3600] loss: -8.103 
[14,  3800] loss: -8.016 
[14,  4000] loss: -8.175 
[14,  4200] loss: -8.102 
[14,  4400] loss: -8.199 
[14,  4600] loss: -7.984 
[14,  4800] loss: -8.341 
[14,  5000] loss: -8.153 
[14,  5200] loss: -8.283 
[14,  5400] loss: -8.142 
[14,  5600] loss: -8.127 
[14,  5800] loss: -8.164 
[14,  6000] loss: -8.136 
[14,  6200] loss: -8.109 
[14,  6400] loss: -8.146 
[14,  6600] loss: -8.207 
[14,  6800] loss: -8.113 
[14,  7000] loss: -8.220 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-7.946555136874378


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[15,   200] loss: -8.269 
[15,   400] loss: -8.194 
[15,   600] loss: -7.942 
[15,   800] loss: -7.971 
[15,  1000] loss: -7.989 
[15,  1200] loss: -8.172 
[15,  1400] loss: -8.059 
[15,  1600] loss: -7.879 
[15,  1800] loss: -8.035 
[15,  2000] loss: -7.924 
[15,  2200] loss: -8.146 
[15,  2400] loss: -8.287 
[15,  2600] loss: -8.030 
[15,  2800] loss: -8.259 
[15,  3000] loss: -8.199 
[15,  3200] loss: -8.166 
[15,  3400] loss: -8.219 
[15,  3600] loss: -8.080 
[15,  3800] loss: -7.923 
[15,  4000] loss: -8.204 
[15,  4200] loss: -8.084 
[15,  4400] loss: -8.112 
[15,  4600] loss: -7.982 
[15,  4800] loss: -8.253 
[15,  5000] loss: -8.059 
[15,  5200] loss: -8.254 
[15,  5400] loss: -8.166 
[15,  5600] loss: -8.178 
[15,  5800] loss: -8.171 
[15,  6000] loss: -8.086 
[15,  6200] loss: -8.051 
[15,  6400] loss: -8.073 
[15,  6600] loss: -8.103 
[15,  6800] loss: -8.106 
[15,  7000] loss: -8.135 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-7.917469855988643


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[16,   200] loss: -8.264 
[16,   400] loss: -8.287 
[16,   600] loss: -8.027 
[16,   800] loss: -8.079 
[16,  1000] loss: -8.076 
[16,  1200] loss: -8.226 
[16,  1400] loss: -8.192 
[16,  1600] loss: -7.986 
[16,  1800] loss: -8.091 
[16,  2000] loss: -8.080 
[16,  2200] loss: -8.324 
[16,  2400] loss: -8.284 
[16,  2600] loss: -8.098 
[16,  2800] loss: -8.304 
[16,  3000] loss: -8.321 
[16,  3200] loss: -8.236 
[16,  3400] loss: -8.256 
[16,  3600] loss: -8.130 
[16,  3800] loss: -7.955 
[16,  4000] loss: -8.225 
[16,  4200] loss: -8.178 
[16,  4400] loss: -8.152 
[16,  4600] loss: -7.988 
[16,  4800] loss: -8.312 
[16,  5000] loss: -8.231 
[16,  5200] loss: -8.271 
[16,  5400] loss: -8.151 
[16,  5600] loss: -8.186 
[16,  5800] loss: -8.173 
[16,  6000] loss: -8.212 
[16,  6200] loss: -8.152 
[16,  6400] loss: -8.217 
[16,  6600] loss: -8.267 
[16,  6800] loss: -8.121 
[16,  7000] loss: -8.165 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-7.973725526665311


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[17,   200] loss: -8.210 
[17,   400] loss: -8.330 
[17,   600] loss: -8.016 
[17,   800] loss: -8.062 
[17,  1000] loss: -8.140 
[17,  1200] loss: -8.333 
[17,  1400] loss: -8.110 
[17,  1600] loss: -8.030 
[17,  1800] loss: -8.159 
[17,  2000] loss: -8.080 
[17,  2200] loss: -8.270 
[17,  2400] loss: -8.272 
[17,  2600] loss: -8.085 
[17,  2800] loss: -8.283 
[17,  3000] loss: -8.270 
[17,  3200] loss: -8.146 
[17,  3400] loss: -8.275 
[17,  3600] loss: -8.082 
[17,  3800] loss: -7.977 
[17,  4000] loss: -8.190 
[17,  4200] loss: -8.114 
[17,  4400] loss: -8.132 
[17,  4600] loss: -8.047 
[17,  4800] loss: -8.291 
[17,  5000] loss: -8.212 
[17,  5200] loss: -8.259 
[17,  5400] loss: -8.228 
[17,  5600] loss: -8.242 
[17,  5800] loss: -8.227 
[17,  6000] loss: -8.152 
[17,  6200] loss: -8.208 
[17,  6400] loss: -8.204 
[17,  6600] loss: -8.267 
[17,  6800] loss: -8.229 
[17,  7000] loss: -8.219 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-7.969617295664145


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[18,   200] loss: -8.260 
[18,   400] loss: -8.334 
[18,   600] loss: -7.995 
[18,   800] loss: -8.041 
[18,  1000] loss: -8.025 
[18,  1200] loss: -8.307 
[18,  1400] loss: -8.141 
[18,  1600] loss: -8.023 
[18,  1800] loss: -8.125 
[18,  2000] loss: -8.118 
[18,  2200] loss: -8.389 
[18,  2400] loss: -8.255 
[18,  2600] loss: -8.177 
[18,  2800] loss: -8.325 
[18,  3000] loss: -8.330 
[18,  3200] loss: -8.307 
[18,  3400] loss: -8.290 
[18,  3600] loss: -8.124 
[18,  3800] loss: -8.074 
[18,  4000] loss: -8.233 
[18,  4200] loss: -8.153 
[18,  4400] loss: -8.230 
[18,  4600] loss: -8.021 
[18,  4800] loss: -8.338 
[18,  5000] loss: -8.133 
[18,  5200] loss: -8.325 
[18,  5400] loss: -8.154 
[18,  5600] loss: -8.184 
[18,  5800] loss: -8.270 
[18,  6000] loss: -8.277 
[18,  6200] loss: -8.206 
[18,  6400] loss: -8.244 
[18,  6600] loss: -8.215 
[18,  6800] loss: -8.186 
[18,  7000] loss: -8.202 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-8.081558935309786


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[19,   200] loss: -8.343 
[19,   400] loss: -8.393 
[19,   600] loss: -8.049 
[19,   800] loss: -8.200 
[19,  1000] loss: -8.184 
[19,  1200] loss: -8.401 
[19,  1400] loss: -8.123 
[19,  1600] loss: -8.061 
[19,  1800] loss: -8.262 
[19,  2000] loss: -8.179 
[19,  2200] loss: -8.370 
[19,  2400] loss: -8.284 
[19,  2600] loss: -8.169 
[19,  2800] loss: -8.405 
[19,  3000] loss: -8.450 
[19,  3200] loss: -8.289 
[19,  3400] loss: -8.377 
[19,  3600] loss: -8.215 
[19,  3800] loss: -8.042 
[19,  4000] loss: -8.307 
[19,  4200] loss: -8.220 
[19,  4400] loss: -8.291 
[19,  4600] loss: -8.018 
[19,  4800] loss: -8.325 
[19,  5000] loss: -8.219 
[19,  5200] loss: -8.323 
[19,  5400] loss: -8.254 
[19,  5600] loss: -8.194 
[19,  5800] loss: -8.256 
[19,  6000] loss: -8.204 
[19,  6200] loss: -8.253 
[19,  6400] loss: -8.373 
[19,  6600] loss: -8.341 
[19,  6800] loss: -8.193 
[19,  7000] loss: -8.261 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-8.065014668194896


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[20,   200] loss: -8.373 
[20,   400] loss: -8.313 
[20,   600] loss: -8.109 
[20,   800] loss: -8.223 
[20,  1000] loss: -8.271 
[20,  1200] loss: -8.500 
[20,  1400] loss: -8.257 
[20,  1600] loss: -8.071 
[20,  1800] loss: -8.270 
[20,  2000] loss: -8.177 
[20,  2200] loss: -8.464 
[20,  2400] loss: -8.412 
[20,  2600] loss: -8.184 
[20,  2800] loss: -8.428 
[20,  3000] loss: -8.405 
[20,  3200] loss: -8.393 
[20,  3400] loss: -8.335 
[20,  3600] loss: -8.219 
[20,  3800] loss: -8.108 
[20,  4000] loss: -8.333 
[20,  4200] loss: -8.233 
[20,  4400] loss: -8.278 
[20,  4600] loss: -8.123 
[20,  4800] loss: -8.424 
[20,  5000] loss: -8.275 
[20,  5200] loss: -8.301 
[20,  5400] loss: -8.268 
[20,  5600] loss: -8.246 
[20,  5800] loss: -8.320 
[20,  6000] loss: -8.298 
[20,  6200] loss: -8.202 
[20,  6400] loss: -8.295 
[20,  6600] loss: -8.404 
[20,  6800] loss: -8.198 
[20,  7000] loss: -8.276 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-7.978830134327193


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[21,   200] loss: -8.292 
[21,   400] loss: -8.354 
[21,   600] loss: -8.075 
[21,   800] loss: -8.152 
[21,  1000] loss: -8.232 
[21,  1200] loss: -8.414 
[21,  1400] loss: -8.304 
[21,  1600] loss: -8.159 
[21,  1800] loss: -8.264 
[21,  2000] loss: -8.167 
[21,  2200] loss: -8.382 
[21,  2400] loss: -8.390 
[21,  2600] loss: -8.148 
[21,  2800] loss: -8.289 
[21,  3000] loss: -8.381 
[21,  3200] loss: -8.342 
[21,  3400] loss: -8.290 
[21,  3600] loss: -8.149 
[21,  3800] loss: -8.059 
[21,  4000] loss: -8.351 
[21,  4200] loss: -8.279 
[21,  4400] loss: -8.269 
[21,  4600] loss: -8.108 
[21,  4800] loss: -8.325 
[21,  5000] loss: -8.287 
[21,  5200] loss: -8.326 
[21,  5400] loss: -8.304 
[21,  5600] loss: -8.270 
[21,  5800] loss: -8.400 
[21,  6000] loss: -8.275 
[21,  6200] loss: -8.278 
[21,  6400] loss: -8.313 
[21,  6600] loss: -8.351 
[21,  6800] loss: -8.279 
[21,  7000] loss: -8.322 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-8.010127165896959


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[22,   200] loss: -8.304 
[22,   400] loss: -8.321 
[22,   600] loss: -8.166 
[22,   800] loss: -8.180 
[22,  1000] loss: -8.252 
[22,  1200] loss: -8.404 
[22,  1400] loss: -8.236 
[22,  1600] loss: -8.094 
[22,  1800] loss: -8.275 
[22,  2000] loss: -8.180 
[22,  2200] loss: -8.425 
[22,  2400] loss: -8.402 
[22,  2600] loss: -8.208 
[22,  2800] loss: -8.376 
[22,  3000] loss: -8.427 
[22,  3200] loss: -8.385 
[22,  3400] loss: -8.342 
[22,  3600] loss: -8.181 
[22,  3800] loss: -8.065 
[22,  4000] loss: -8.377 
[22,  4200] loss: -8.256 
[22,  4400] loss: -8.341 
[22,  4600] loss: -8.136 
[22,  4800] loss: -8.436 
[22,  5000] loss: -8.345 
[22,  5200] loss: -8.429 
[22,  5400] loss: -8.371 
[22,  5600] loss: -8.309 
[22,  5800] loss: -8.290 
[22,  6000] loss: -8.334 
[22,  6200] loss: -8.339 
[22,  6400] loss: -8.314 
[22,  6600] loss: -8.343 
[22,  6800] loss: -8.315 
[22,  7000] loss: -8.372 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-8.11629737762816


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[23,   200] loss: -8.421 
[23,   400] loss: -8.538 
[23,   600] loss: -8.118 
[23,   800] loss: -8.251 
[23,  1000] loss: -8.263 
[23,  1200] loss: -8.459 
[23,  1400] loss: -8.245 
[23,  1600] loss: -8.181 
[23,  1800] loss: -8.304 
[23,  2000] loss: -8.202 
[23,  2200] loss: -8.488 
[23,  2400] loss: -8.413 
[23,  2600] loss: -8.286 
[23,  2800] loss: -8.476 
[23,  3000] loss: -8.487 
[23,  3200] loss: -8.386 
[23,  3400] loss: -8.381 
[23,  3600] loss: -8.292 
[23,  3800] loss: -8.195 
[23,  4000] loss: -8.396 
[23,  4400] loss: -8.340 
[23,  4600] loss: -8.116 
[23,  4800] loss: -8.366 
[23,  5000] loss: -8.277 
[23,  5200] loss: -8.390 
[23,  5400] loss: -8.339 
[23,  5600] loss: -8.325 


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[24,  1800] loss: -8.276 
[24,  2000] loss: -8.240 
[24,  2200] loss: -8.464 
[24,  2400] loss: -8.433 
[24,  2600] loss: -8.223 
[24,  2800] loss: -8.547 
[24,  3000] loss: -8.480 


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[24,  6800] loss: -8.319 
[24,  7000] loss: -8.467 



HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-8.17405060543957


HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

[25,   200] loss: -8.514 
[25,   400] loss: -8.501 
[25,   600] loss: -8.159 


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[25,  3400] loss: -8.427 
[25,  3600] loss: -8.340 
[25,  3800] loss: -8.211 
[25,  4000] loss: -8.453 
[25,  4200] loss: -8.424 
[25,  4400] loss: -8.391 
[25,  4600] loss: -8.235 
[25,  4800] loss: -8.486 


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[26,   600] loss: -8.228 
[26,   800] loss: -8.207 
[26,  1000] loss: -8.330 
[26,  1200] loss: -8.541 
[26,  1400] loss: -8.290 
[26,  1600] loss: -8.218 
[26,  1800] loss: -8.380 


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[26,  4400] loss: -8.343 
[26,  4600] loss: -8.123 


Traceback (most recent call last):


KeyboardInterrupt: 

  File "/home/junzhez2/anaconda3/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/home/junzhez2/anaconda3/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/junzhez2/anaconda3/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/home/junzhez2/anaconda3/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


# Maybe also try doing recursive defogging

In [19]:
# Seed both RNGs before evaluating: the dataset draws its interfering talkers
# with NumPy and the loader shuffles with PyTorch.
torch.manual_seed(0)
np.random.seed(0)

testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory = True,
    num_workers = 32,
)

# Gradient-free pass over the test split; the loss is averaged per example.
tasnet.eval()
running_loss = 0.0
with torch.no_grad():
    for batch_idx, (mixture, sources) in enumerate(tqdm(testloader)):
        mixture = mixture.float().cuda(device)
        sources = sources.float().cuda(device)
        out = tasnet(mixture)
        loss, max_snr, estimate_source, reorder_estimate_source = cal_loss(
            sources,
            out,
            torch.ones(mixture.shape[0], dtype = torch.int32).cuda(device)*2*sr,
        )

        # Batch means are re-weighted by batch size before the final division.
        running_loss += loss.item()*sources.shape[0]
print(running_loss/len(testset))

HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-8.075121239824593


In [18]:
# baseline TasNet SNR: -8.117dB
# Persist the baseline. torch.save does not create missing directories, so make
# sure the target folder exists before writing.
os.makedirs('models/stable', exist_ok=True)
# `tasnet.module` unwraps DataParallel, so the saved keys load into a bare
# ConvTasNet. `loss` is whatever the most recently executed loop or checkpoint
# load left in that variable, not a dataset-wide average.
torch.save({
'model_state_dict': tasnet.module.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss
}, 'models/stable/tasnet-baseline.pth')