In [1]:
import sys
from librosa.core import resample
import pandas as pd
import numpy as np
from IPython.display import Audio
import torch
import pathlib
def create_dir(filename):
    """Make sure the directory part of ``filename`` exists.

    The directory is everything before the last ``'/'`` in ``filename``.
    Missing parents are created and an already existing directory is left
    untouched. Nothing is returned.
    """
    parent, _, _ = filename.rpartition('/')
    target = pathlib.Path(parent)
    target.mkdir(parents=True, exist_ok=True)
from tqdm.notebook import tqdm
sys.path.append('Conv-TasNet/src/')
sys.path.append('SincNet/')
from conv_tasnet import *
from pit_criterion import cal_loss
from dnn_models import *
from data_io import ReadList,read_conf_inp,str_to_bool
from collections import Counter
import os
# CUDA device index passed to every `.cuda(device)` call in this notebook.
device = 3
# Device list handed to nn.DataParallel when the end-to-end model is wrapped.
device_ids = [3]
# Prefix prepended to the audio file paths read from the overlay CSVs.
root = '../'
# Sample rates (Hz) given to `resample` in load8hz: source rate, then target rate.
old_sr = 16000
new_sr = 8000

In [2]:
def load8hz(filename):
    """Load a stored waveform, resample it and force it to exactly two seconds.

    The array in ``filename`` is read with ``np.load``, divided by ``2**15``
    (presumably int16 PCM scaled to [-1, 1) -- confirm against the data
    export), resampled from ``old_sr`` to ``new_sr``, then truncated or
    zero-padded at the end so that it holds ``2 * new_sr`` samples.

    Args:
        filename: path of a ``.npy`` file; expected to hold a 1-D sample array.

    Returns:
        ``np.ndarray`` with ``2 * new_sr`` samples.
    """
    target_len = 2 * new_sr
    samples = np.load(filename) / (2 ** 15)
    # Pass the rates by keyword: recent librosa releases made them
    # keyword-only, and older releases accept the same parameter names.
    samples = resample(samples, orig_sr=old_sr, target_sr=new_sr)
    if len(samples) > target_len:
        samples = samples[:target_len]
    elif len(samples) < target_len:
        # Pad with trailing silence up to the fixed length.
        padding = np.zeros(target_len - len(samples))
        samples = np.concatenate([samples, padding])

    return samples

class SourceSet(torch.utils.data.Dataset):
    """Dataset of source pairs listed in an overlay CSV.

    Each CSV row names two recordings (columns ``first_file`` and
    ``second_file``). An item is the pair of waveforms that ``load8hz``
    returns for those files, with ``root`` prepended to each path.
    """

    def __init__(self, csv):
        super().__init__()
        # `csv` is the path of the listing; the parsed table is stored under the same name.
        self.csv = pd.read_csv(csv)

    def __len__(self):
        """Number of rows in the listing."""
        return self.csv.shape[0]

    def __getitem__(self, idx):
        """Return ``(first_waveform, second_waveform)`` for row ``idx``."""
        record = self.csv.iloc[idx]
        first = load8hz(root + record['first_file'])
        second = load8hz(root + record['second_file'])
        return first, second
# Source pairs listed in overlay-train.csv; consumed by the separator training loop.
sourceset_train = SourceSet('overlay-train.csv')


In [4]:
# Start from the Conv-TasNet weights in final.pth.tar, then resume from our
# own checkpoint (model + optimizer state) when one exists.
tasnet = ConvTasNet.load_model('final.pth.tar').cuda(device)
optimizer = torch.optim.Adam(tasnet.parameters(), lr = 0.001)
# Defined up front so the cell's final expression does not raise NameError on
# a first run, when there is no checkpoint to read a loss from.
loss = None
if os.path.exists('models/tasnet.pth'):
    print('load model')
    # map_location places the restored tensors on the GPU used in this
    # notebook even if the checkpoint was written from another device index.
    checkpoint = torch.load('models/tasnet.pth', map_location='cuda:%d' % device)
    tasnet.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    loss = checkpoint['loss']
# Show the loss stored with the checkpoint (None when starting fresh).
loss

load model


tensor(-13.8432, device='cuda:3', requires_grad=True)

In [4]:
# Fine-tune the separator: mix each source pair by addition, separate the
# mixture and optimise the loss returned by cal_loss.
batch_size = 8
sourceloader_train  = torch.utils.data.DataLoader(sourceset_train, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 16)
tasnet.train()
# torch.save does not create missing directories.
os.makedirs('models', exist_ok=True)

for epoch in range(64):
    running_loss = 0.0
    for batch_idx, (sig1, sig2) in enumerate(tqdm(sourceloader_train)):
        optimizer.zero_grad()
        sig1, sig2 = sig1.float().cuda(device), sig2.float().cuda(device)
        out = tasnet(sig1+sig2)
        source = torch.stack([sig1, sig2], dim = 1).detach()
        # Size the length vector from the batch itself: the last batch of an
        # epoch can hold fewer than batch_size examples (drop_last is False).
        source_lengths = torch.ones(sig1.shape[0], dtype = torch.int32).cuda(device)*2*new_sr
        loss, max_snr, estimate_source, reorder_estimate_source = \
            cal_loss(source, out, source_lengths)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(tasnet.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 200 == 199:    # report and checkpoint every 200 mini-batches
            print('[%d, %5d] loss: %.3f ' % 
                  (epoch + 1, batch_idx + 1, running_loss / 200))
            running_loss = 0.0
            torch.save({
                'model_state_dict': tasnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Store a plain float rather than the live tensor, which
                # still carries requires_grad and a device binding.
                'loss': loss.item()
            }, 'models/tasnet.pth')

HBox(children=(FloatProgress(value=0.0, max=9500.0), HTML(value='')))

  frame = signal.new_tensor(frame).long()  # signal may in GPU or CPU


[1,   200] loss: -5.167 
[1,   400] loss: -6.867 
[1,   600] loss: -7.426 
[1,   800] loss: -7.843 
[1,  1000] loss: -8.265 
[1,  1200] loss: -8.652 
[1,  1400] loss: -8.868 
[1,  1600] loss: -9.058 
[1,  1800] loss: -9.171 
[1,  2000] loss: -9.386 
[1,  2200] loss: -9.449 
[1,  2400] loss: -9.769 
[1,  2600] loss: -9.942 
[1,  2800] loss: -9.882 
[1,  3000] loss: -9.975 
[1,  3200] loss: -10.000 
[1,  3400] loss: -10.031 
[1,  3600] loss: -10.213 
[1,  3800] loss: -10.247 
[1,  4000] loss: -10.360 
[1,  4200] loss: -10.266 
[1,  4400] loss: -10.541 
[1,  4600] loss: -10.666 
[1,  4800] loss: -10.592 
[1,  5000] loss: -10.697 
[1,  5200] loss: -10.730 
[1,  5400] loss: -10.705 
[1,  5600] loss: -10.744 
[1,  5800] loss: -10.895 
[1,  6000] loss: -11.050 
[1,  6200] loss: -10.962 
[1,  6400] loss: -11.099 
[1,  6600] loss: -11.007 
[1,  6800] loss: -11.065 
[1,  7000] loss: -11.032 
[1,  7200] loss: -11.152 
[1,  7400] loss: -11.268 
[1,  7600] loss: -11.165 
[1,  7800] loss: -11.271 
[

HBox(children=(FloatProgress(value=0.0, max=9500.0), HTML(value='')))

[2,   200] loss: -11.438 
[2,   400] loss: -11.493 
[2,   600] loss: -11.572 
[2,   800] loss: -11.673 
[2,  1000] loss: -11.650 
[2,  1200] loss: -11.520 
[2,  1400] loss: -11.775 
[2,  1600] loss: -11.709 
[2,  1800] loss: -11.776 
[2,  2000] loss: -11.600 
[2,  2200] loss: -11.804 
[2,  2400] loss: -11.794 
[2,  2600] loss: -11.769 
[2,  2800] loss: -11.780 
[2,  3000] loss: -11.720 
[2,  3200] loss: -11.950 
[2,  3400] loss: -11.748 
[2,  3600] loss: -11.908 
[2,  3800] loss: -11.997 
[2,  4000] loss: -11.885 
[2,  4200] loss: -11.854 
[2,  4400] loss: -11.910 
[2,  4600] loss: -11.927 
[2,  4800] loss: -11.831 
[2,  5000] loss: -11.940 
[2,  5200] loss: -11.819 
[2,  5400] loss: -11.880 
[2,  5600] loss: -11.876 
[2,  5800] loss: -12.022 
[2,  6000] loss: -11.914 
[2,  6200] loss: -11.783 
[2,  6400] loss: -11.979 
[2,  6600] loss: -11.900 
[2,  6800] loss: -12.005 
[2,  7000] loss: -11.891 
[2,  7200] loss: -12.066 
[2,  7400] loss: -11.958 
[2,  7600] loss: -12.135 
[2,  7800] l

HBox(children=(FloatProgress(value=0.0, max=9500.0), HTML(value='')))

[3,   200] loss: -12.362 
[3,   400] loss: -12.326 
[3,   600] loss: -12.300 
[3,   800] loss: -12.266 
[3,  1000] loss: -12.254 
[3,  1200] loss: -12.245 
[3,  1400] loss: -12.312 
[3,  1600] loss: -12.201 
[3,  1800] loss: -12.231 
[3,  2000] loss: -12.165 
[3,  2200] loss: -12.342 
[3,  2400] loss: -12.292 
[3,  2600] loss: -12.323 
[3,  2800] loss: -12.338 
[3,  3000] loss: -12.404 
[3,  3200] loss: -12.299 
[3,  3400] loss: -12.360 
[3,  3600] loss: -12.338 
[3,  3800] loss: -12.484 
[3,  4000] loss: -12.401 
[3,  4200] loss: -12.341 
[3,  4400] loss: -12.334 
[3,  4600] loss: -12.428 
[3,  4800] loss: -12.428 
[3,  5000] loss: -12.421 
[3,  5200] loss: -12.227 
[3,  5400] loss: -12.475 
[3,  5600] loss: -12.448 
[3,  5800] loss: -12.440 
[3,  6000] loss: -12.479 
[3,  6200] loss: -12.346 
[3,  6400] loss: -12.504 
[3,  6600] loss: -12.483 
[3,  6800] loss: -12.545 
[3,  7000] loss: -12.575 
[3,  7200] loss: -12.474 
[3,  7400] loss: -12.446 
[3,  7600] loss: -12.532 
[3,  7800] l

HBox(children=(FloatProgress(value=0.0, max=9500.0), HTML(value='')))

[4,   200] loss: -12.710 
[4,   400] loss: -12.694 
[4,   600] loss: -12.805 
[4,   800] loss: -12.686 
[4,  1000] loss: -12.725 
[4,  1200] loss: -12.733 
[4,  1400] loss: -12.705 
[4,  1600] loss: -12.629 
[4,  1800] loss: -12.598 
[4,  2000] loss: -12.731 
[4,  2200] loss: -12.713 
[4,  2400] loss: -12.666 
[4,  2600] loss: -12.855 
[4,  2800] loss: -12.729 
[4,  3000] loss: -12.706 
[4,  3200] loss: -12.695 
[4,  3400] loss: -12.747 
[4,  3600] loss: -12.697 
[4,  3800] loss: -12.754 
[4,  4000] loss: -12.845 
[4,  4200] loss: -12.683 
[4,  4400] loss: -12.720 
[4,  4600] loss: -12.787 
[4,  4800] loss: -12.553 
[4,  5000] loss: -12.757 
[4,  5200] loss: -12.802 
[4,  5400] loss: -12.701 
[4,  5600] loss: -12.611 
[4,  5800] loss: -12.782 
[4,  6000] loss: -12.795 
[4,  6200] loss: -12.680 
[4,  6400] loss: -12.808 
[4,  6600] loss: -12.663 
[4,  6800] loss: -12.889 


KeyboardInterrupt: 

In [5]:
def reformat(prefix, i, filename):
    """Build the output path for a separated signal.

    The final extension of ``filename`` is removed and ``_<prefix>_<i>.npy``
    is appended, e.g. ``reformat('test', 3, 'a/b.npy') == 'a/b_test_3.npy'``.

    Args:
        prefix: tag inserted into the name (the split name in this notebook).
        i: index inserted into the name; converted with ``str``.
        filename: original file path.

    Returns:
        The derived ``.npy`` path as a string.
    """
    # splitext strips only the final extension, so dots elsewhere in the path
    # ('./data/x.npy', 'v1.0/x.npy') no longer cut off the directory part, and
    # a name without any dot no longer raises IndexError.
    stem, _ext = os.path.splitext(filename)
    return stem + '_' + prefix + '_' + str(i) + '.npy'

In [8]:
# Write a de-mixed copy of the `mode` split under the working directory: each
# source pair is mixed, separated by tasnet, and the two reordered estimates
# are saved under names derived from the original file names. A CSV that
# points at the separated files is written once all rows are done.
mode = 'test'
csv_name = 'overlay-'+mode+'.csv'
csv_separated = 'separated-'+mode+'.csv'
csv = pd.read_csv(csv_name)
tasnet.eval()
batch_size = 8  # not used below: mixtures are separated one at a time

with torch.no_grad():
    # Each forward pass holds a single mixture, so cal_loss gets exactly one
    # length entry instead of batch_size of them.
    source_lengths = torch.ones(1, dtype = torch.int32).cuda(device)*2*new_sr
    for i in tqdm(range(len(csv))):
        row = csv.iloc[i]
        seg1 = load8hz(root+row['first_file'])
        seg2 = load8hz(root+row['second_file'])
        # Output files reuse the relative directory layout of the inputs.
        create_dir(row['first_file'])
        create_dir(row['second_file'])
        sig1, sig2 = torch.Tensor(seg1)[None, ...].float().cuda(device), torch.Tensor(seg2)[None, ...].float().cuda(device)
        source = torch.stack([sig1, sig2], dim = 1).detach()
        mixture = sig1+sig2
        out = tasnet(mixture)
        # cal_loss is called for its reordered estimates; the loss value
        # itself is not used in this cell.
        loss, max_snr, estimate_source, reorder_estimate_source = \
            cal_loss(source, out, source_lengths)
        new_seg1, new_seg2 = reorder_estimate_source[0].cpu().detach().numpy()
        newfile1, newfile2 = reformat(mode, i, row['first_file']), reformat(mode, i, row['second_file'])
        csv.at[i, 'first_file'] = newfile1
        csv.at[i, 'second_file'] = newfile2
        np.save(newfile1, new_seg1)
        np.save(newfile2, new_seg2)
    csv.to_csv(csv_separated, index = False)

HBox(children=(FloatProgress(value=0.0, max=17480.0), HTML(value='')))




In [15]:
# Separate one randomly chosen test mixture so that the sources and the
# estimates can be auditioned in the following cells.
mode = 'test'
csv_name = 'overlay-'+mode+'.csv'
csv = pd.read_csv(csv_name)
i = np.random.randint(len(csv))
row = csv.iloc[i]
seg1 = load8hz(root+row['first_file'])
seg2 = load8hz(root+row['second_file'])
create_dir(row['first_file'])
create_dir(row['second_file'])
mixture = torch.Tensor(seg1+seg2).cuda(device)
# Add the batch dimension expected by the separator.
mixture = mixture[None, ...]
# Inference only: skip building the autograd graph.
with torch.no_grad():
    out = tasnet(mixture)
new_seg1, new_seg2 = out[0].cpu().detach().numpy()

In [16]:
# Play the first source as returned by load8hz.
Audio(seg1, rate = new_sr)

In [17]:
# Play the second source as returned by load8hz.
Audio(seg2, rate = new_sr)

In [18]:
# Play the first channel of the separator output for the mixture.
Audio(new_seg1, rate = new_sr)

In [19]:
# Play the second channel of the separator output for the mixture.
Audio(new_seg2, rate = new_sr)

In [23]:
# def chop_chunk(signal):
#     signal_len = signal.shape[-1]
#     if signal_len < 2*new_sr:
#         padding = np.zeros(2*new_sr-len(signal))
#         signal = np.cat((signal, padding))
#     N_fr=signal_len//wlen
#     chunks = []
#     for i in range(N_fr):
#         chunks.append(signal[i*wlen:(i+1)*wlen])
#     return chunks

class ChunkSet(torch.utils.data.Dataset):
    """Dataset of separated signal pairs with one-hot speaker targets.

    Each CSV row provides two ``.npy`` paths (``first_file``,
    ``second_file``) and the matching speaker labels (``first_speaker``,
    ``second_speaker``).

    Args:
        csv: path of the CSV listing.
        speakers: optional iterable of all speaker labels. Pass the same
            collection to every split so that they share one label-to-index
            mapping. When omitted, the labels found in either speaker column
            of this CSV are used.
    """

    def __init__(self, csv, speakers=None):
        super().__init__()
        self.csv = pd.read_csv(csv)
        if speakers is None:
            # Use both columns: a speaker that only ever appears as the
            # second speaker must still receive an index.
            speakers = set(self.csv['first_speaker']) | set(self.csv['second_speaker'])
        self.speakers = sorted(set(speakers))
        self.spkr2idx = {spkr:i for i, spkr in enumerate(self.speakers)}

    def __len__(self):
        """Number of rows in the listing."""
        return len(self.csv)

    def __getitem__(self, idx):
        """Return ``(sig1, sig2, target1, target2)`` for row ``idx``.

        ``sig1``/``sig2`` are the arrays loaded from the two files;
        ``target1``/``target2`` are one-hot vectors of length
        ``len(self.speakers)`` marking the respective speaker.
        """
        row = self.csv.iloc[idx]
        spkr1, spkr2 = row['first_speaker'], row['second_speaker']
        sig1, sig2 = np.load(row['first_file']), np.load(row['second_file'])
        target1 = np.zeros(len(self.speakers))
        target2 = np.zeros(len(self.speakers))
        target1[self.spkr2idx[spkr1]] = 1
        target2[self.spkr2idx[spkr2]] = 1
        return sig1, sig2, target1, target2

# One dataset per split of the separated signals.
chunkset_train = ChunkSet('separated-train.csv')
chunkset_val = ChunkSet('separated-val.csv')
# The test split gets its own name. It used to be assigned to chunkset_val,
# which silently replaced the validation set, so checkpoint selection ran on
# test data.
chunkset_test = ChunkSet('separated-test.csv')

In [24]:
def find_max2(tensor):
    """Return, for every row of ``tensor``, the indices of its two largest entries.

    The tensor is moved to the CPU and converted to NumPy first. Within each
    row the indices are ordered from the largest entry to the second largest.
    The per-row results are collected into one NumPy array.
    """
    rows = tensor.cpu().detach().numpy()
    top_two = [np.argsort(scores)[::-1][:2] for scores in rows]
    return np.array(top_two)

def compute_corrects(tensor1, tensor2):
    """Count the rows on which both inputs agree about their top-two indices.

    ``find_max2`` is applied to each input. A row counts as correct when the
    two index pairs hold the same values, regardless of order.

    Returns:
        Number of agreeing rows as an ``int``.
    """
    top_a = find_max2(tensor1)
    top_b = find_max2(tensor2)
    n_rows = top_a.shape[0]
    return sum(1 for k in range(n_rows) if Counter(top_a[k]) == Counter(top_b[k]))

In [25]:
# Framing used by the classifier: sample rate, then window length and window
# shift in milliseconds.
fs=new_sr
cw_len=200
cw_shift=10

# The same window length and shift converted to a number of samples.
wlen=int(fs*cw_len/1000.00)
wshift=int(fs*cw_shift/1000.00)




class MixedClassifier(nn.Module):
    """Speaker classifier that averages per-window predictions over a signal.

    A signal is cut into overlapping windows of ``wlen`` samples taken every
    ``wshift`` samples. Each window goes through SincNet, a three-layer MLP
    and a linear output layer followed by a softmax. The per-window
    distributions are averaged into one score vector per signal.

    Args:
        n_speakers: size of the output layer (number of speaker classes).

    Attributes:
        use_all_chunks: when True every window is scored; when False a
            random tenth of the windows (drawn with replacement) is scored.
    """

    def __init__(self, n_speakers=20):
        super().__init__()
        cnn_arch = {
                'input_dim':wlen,
                'fs':fs,
                'cnn_N_filt':[80,60,60],
                'cnn_len_filt':[251,5,5],
                'cnn_max_pool_len':[3,3,3],
                'cnn_use_laynorm_inp':True,
                'cnn_use_batchnorm_inp':False,
                'cnn_use_laynorm':[True,True,True],
                'cnn_use_batchnorm':[False,False,False],
                'cnn_act':['leaky_relu','leaky_relu','leaky_relu'],
                'cnn_drop':[0.0,0.0,0.0]
                }
        self.cnn_net = SincNet(cnn_arch)

        dnn1_arch = {'input_dim': self.cnn_net.out_dim,
                  'fc_lay': [2048,2048,2048],
                  'fc_drop': [0.0,0.0,0.0], 
                  'fc_use_batchnorm': [True,True,True],
                  'fc_use_laynorm': [False,False,False],
                  'fc_use_laynorm_inp': False,
                  'fc_use_batchnorm_inp': False,
                  'fc_act': ['leaky_relu','leaky_relu','leaky_relu']
                  }
        self.dnn1 = MLP(dnn1_arch)


        dnn2_arch = {'input_dim':2048 ,
                  'fc_lay': [n_speakers],
                  'fc_drop': [0.0], 
                  'fc_use_batchnorm': [False],
                  'fc_use_laynorm': [False],
                  'fc_use_laynorm_inp': False,
                  'fc_use_batchnorm_inp': False,
                  'fc_act': ['linear'] # leakyrelu(1) is just identity mapping
                  }
        self.dnn2 = MLP(dnn2_arch)
        
        self.softmax = nn.Softmax(dim = 1)
        # nn.Module.__init__ sets self.training to True, so every window is
        # used until a caller overrides this flag.
        self.use_all_chunks = self.training
        
    def chop_chunk(self, signal):
        """Split a ``(batch, samples)`` signal into a list of windows.

        Returns a list of tensors, each ``(batch, wlen)``. The list is empty
        when the signal is not longer than one window.
        """
        _, signal_len = signal.shape
        # NOTE: this count is one less than the number of complete windows
        # that fit. It is left unchanged so the framing stays the same.
        N_fr=(signal_len-wlen)//wshift
        return [signal[..., i*wshift:i*wshift+wlen] for i in range(N_fr)]
    
    def estimate(self, chunks):
        """Average the softmax outputs of the given windows.

        Args:
            chunks: list of ``(batch, wlen)`` tensors from ``chop_chunk``.

        Returns:
            Tensor of shape ``(batch, n_speakers)``.

        Raises:
            ValueError: if ``chunks`` is empty.
        """
        if len(chunks) == 0:
            raise ValueError('signal is too short to form a single analysis window')
        if not self.use_all_chunks:
            # Score a random tenth of the windows, but always at least one.
            # With fewer than ten windows the sample would otherwise be empty
            # and torch.stack would fail.
            n_pick = max(1, len(chunks)//10)
            indices = np.random.randint(len(chunks), size = n_pick)
            chunks = [chunks[idx] for idx in indices]
        out_vecs = [self.softmax(self.dnn2(self.dnn1(self.cnn_net(chunk))))
                    for chunk in chunks] # one (batch, n_speakers) distribution per window
        out_tensor = torch.stack(out_vecs, dim = 1) # batch, windows, n_speakers
        out_tensor = out_tensor.mean(dim = 1) # batch, n_speakers
        return out_tensor 
    
    def forward(self, signal):
        """Score a ``(batch, samples)`` signal; see ``estimate`` for the output."""
        X = self.chop_chunk(signal)
        out = self.estimate(X)
        return out
    

# Build the classifier with a fresh optimizer, then resume both from the
# rolling checkpoint when it exists.
cls = MixedClassifier().cuda(device)
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)


if os.path.exists('models/sincnet.pth'):
    print('load model')
    checkpoint = torch.load('models/sincnet.pth')
    cls.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    loss = checkpoint['loss']
    # A checkpoint without a stored best accuracy restarts the tracking at zero.
    bestacc = checkpoint.get('bestacc', 0.0)
else:
    print('initializing new model')
    bestacc = 0.0

initializing new model


In [None]:
# Train the speaker classifier on separated signals and validate after every
# epoch, keeping the weights with the best validation accuracy.
batch_size = 32
chunkloader_train = torch.utils.data.DataLoader(chunkset_train, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 16)
chunkloader_val = torch.utils.data.DataLoader(chunkset_val, batch_size=32, shuffle=True, pin_memory = True, num_workers = 16)
criterion = torch.nn.BCELoss()
# torch.save does not create missing directories.
os.makedirs('models', exist_ok=True)


for epoch in range(64):
    running_loss = 0.0
    # Correct predictions and examples seen since the last report.
    running_corrects = 0
    running_seen = 0
    cls.train()
    # Training scores only a random subset of the windows of each signal.
    cls.use_all_chunks = False
    for batch_idx, (sig1, sig2, target1, target2) in enumerate(tqdm(chunkloader_train)):
        optimizer.zero_grad()
        sig1, sig2 = sig1.float().cuda(device), sig2.float().cuda(device)
        target1, target2 = target1.float().cuda(device), target2.float().cuda(device)
        out1, out2 = cls(sig1), cls(sig2)
        
        loss = criterion(out1, target1)+criterion(out2, target2)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(cls.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        # Merge the two predictions by element-wise maximum and compare the
        # top-two classes with the two-hot sum of the targets.
        out = torch.stack([out1, out2], dim = 0)
        out, _ = torch.max(out, dim = 0)
        target = target1+target2
        running_corrects += compute_corrects(out, target)
        # Count real examples: the last batch can be smaller than batch_size.
        running_seen += sig1.shape[0]
        
        if batch_idx % 200 == 199:    # report and checkpoint every 200 mini-batches
            print('[%d, %5d] loss: %.3f accuracy: %.3f' % 
                  (epoch + 1, batch_idx + 1, running_loss / 200, running_corrects / running_seen))
            running_loss = 0.0
            running_corrects = 0
            running_seen = 0
            torch.save({
                'model_state_dict': cls.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Store a plain float rather than the live loss tensor.
                'loss': loss.item(),
                'bestacc': bestacc
            }, 'models/sincnet.pth')
    
    
    cls.eval()
    # Validation scores every window.
    cls.use_all_chunks = True
    corrects = 0
    with torch.no_grad():
        for batch_idx, (sig1, sig2, target1, target2) in enumerate(tqdm(chunkloader_val)):
            sig1, sig2 = sig1.float().cuda(device), sig2.float().cuda(device)
            target1, target2 = target1.float().cuda(device), target2.float().cuda(device)
            out1, out2 = cls(sig1), cls(sig2)
            
            out = torch.stack([out1, out2], dim = 0)
            out, _ = torch.max(out, dim = 0)
            target = target1+target2
            corrects += compute_corrects(out, target)
        
        val_acc = corrects/len(chunkset_val)
        print('val acc:', val_acc)
        if val_acc > bestacc:
            bestacc = val_acc
            torch.save({
                'model_state_dict': cls.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
                'bestacc': bestacc
            }, 'models/best-sincnet.pth')

HBox(children=(FloatProgress(value=0.0, max=2375.0), HTML(value='')))

[1,   200] loss: 0.336 accuracy: 0.051
[1,   400] loss: 0.264 accuracy: 0.181
[1,   600] loss: 0.240 accuracy: 0.256


In [71]:
class E2ESet(torch.utils.data.Dataset):
    """Dataset of mixtures with a two-hot speaker target.

    Each CSV row names two source files (``first_file``, ``second_file``)
    and their speakers (``first_speaker``, ``second_speaker``). An item is
    the sum of the two waveforms returned by ``load8hz``, together with a
    vector that is 1 at the indices of both speakers.

    Args:
        root: prefix prepended to the file paths from the CSV.
        csv: path of the CSV listing.
        speakers: optional iterable of all speaker labels. Pass the same
            collection to every split so that they share one label-to-index
            mapping. When omitted, the labels found in either speaker column
            of this CSV are used.
    """

    def __init__(self, root, csv, speakers=None):
        super().__init__()
        self.root = root
        self.csv = pd.read_csv(csv)
        if speakers is None:
            # Use both columns: a speaker that only ever appears as the
            # second speaker must still receive an index.
            speakers = set(self.csv['first_speaker']) | set(self.csv['second_speaker'])
        self.speakers = sorted(set(speakers))
        self.spkr2idx = {spkr:i for i, spkr in enumerate(self.speakers)}

    def __len__(self):
        """Number of rows in the listing."""
        return len(self.csv)

    def __getitem__(self, idx):
        """Return ``(mixture, target_vec)`` for row ``idx``."""
        row = self.csv.iloc[idx]
        # Use the root given to this dataset, not the notebook-level global.
        sig1 = load8hz(self.root+row['first_file'])  # original files
        sig2 = load8hz(self.root+row['second_file'])
        spkr1, spkr2 = row['first_speaker'], row['second_speaker']
        target_vec = np.zeros(len(self.speakers))
        target_vec[self.spkr2idx[spkr1]] = 1
        target_vec[self.spkr2idx[spkr2]] = 1
        return sig1+sig2, target_vec
# One mixture dataset per split, all reading audio relative to `root`.
e2eset_train = E2ESet(root, 'overlay-train.csv')
e2eset_val = E2ESet(root, 'overlay-val.csv')
e2eset_test = E2ESet(root, 'overlay-test.csv')

In [74]:
# Restore the classifier weights from the rolling checkpoint before the
# classifier is wired into the end-to-end model.
# NOTE(review): this reads models/sincnet.pth, not models/best-sincnet.pth --
# confirm which checkpoint is intended here.
if os.path.exists('models/sincnet.pth'):
    print('load sincnet model')
    checkpoint = torch.load('models/sincnet.pth')
    cls.load_state_dict(checkpoint['model_state_dict'])
    
class E2Enet(nn.Module):
    """Separator followed by a classifier applied to each separated channel.

    The mixture is passed through ``tasnet``. Channels 0 and 1 of its output
    are each scored by ``cls``, and the two score tensors are merged with an
    element-wise maximum.
    """

    def __init__(self, tasnet, cls):
        super().__init__()
        self.tasnet = tasnet
        self.cls = cls

    def forward(self, sig_mixed):
        """Return the merged per-class scores for a batch of mixtures."""
        separated = self.tasnet(sig_mixed)  # channel axis is dim 1
        # Score channel 0 first, then channel 1.
        per_channel = [self.cls(separated[:, ch]) for ch in range(2)]
        stacked = torch.stack(per_channel, dim = 0)  # 2 x batch x classes
        merged = torch.max(stacked, dim = 0)[0]  # batch x classes
        return merged

# Combine the separator and the classifier into one model and give it a
# fresh optimizer over all of its parameters.
e2enet = E2Enet(tasnet, cls).cuda(device)

optimizer = torch.optim.Adam(e2enet.parameters(), lr=0.001)
if os.path.exists('models/e2enet.pth'):
    # Resume the end-to-end model and optimizer from their own checkpoint.
    print('load model')
    checkpoint = torch.load('models/e2enet.pth')
    e2enet.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    loss = checkpoint['loss']
    # A checkpoint without a stored best accuracy restarts the tracking at zero.
    bestacc = checkpoint.get('bestacc', 0.0)
else:
    # No end-to-end checkpoint: keep the weights already held by tasnet and cls.
    print('using loaded tasnet+sincnet')
    bestacc = 0.0

# Wrapped only after the state dict is loaded; the inner model is reachable
# as e2enet.module from here on.
e2enet = nn.DataParallel(e2enet, device_ids=device_ids)

load sincnet model
using loaded tasnet+sincnet


In [None]:
# Evaluate the end-to-end model on the test mixtures: a prediction counts as
# correct when its two top-scoring classes match the two target speakers.
e2enet.eval()
# e2enet is wrapped in nn.DataParallel, so the classifier is reached via
# .module; evaluation scores every window.
e2enet.module.cls.use_all_chunks = True

batch_size = 32
e2eloader_test  = torch.utils.data.DataLoader(e2eset_test, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 16)
with torch.no_grad():    
    corrects = 0
    # mixed_sig, target and out stay defined after the loop; the sample
    # inspection cell reads mixed_sig from the last batch.
    for batch_idx, (mixed_sig, target) in enumerate(tqdm(e2eloader_test)):
        mixed_sig, target = mixed_sig.float().cuda(device), target.float().cuda(device)
        out = e2enet(mixed_sig)
        corrects += compute_corrects(out, target)
    print('test acc:', corrects/len(e2eset_test))
# No-op statement closing the cell.
pass

HBox(children=(FloatProgress(value=0.0, max=511.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f952dd5c680>
Traceback (most recent call last):
  File "/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 961, in __del__
    self._shutdown_workers()
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f952dd5c680>
  File "/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 941, in _shutdown_workers
Traceback (most recent call last):
    w.join()
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f952dd5c680>
  File "/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 961, in __del__
  File "/home/junzhez2/anaconda3/lib/python3.7/multiprocessing/process.py", line 138, in join
Traceback (most recent call last):
    assert self._parent_pid == os.getpid(), 'can only join a child process'
    self._shutdown_workers()
  File "/home/junzhez2/anacon

  File "/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 961, in __del__
    self._shutdown_workers()
  File "/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 941, in _shutdown_workers
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f952dd5c680>
    w.join()
  File "/home/junzhez2/anaconda3/lib/python3.7/multiprocessing/process.py", line 138, in join
Traceback (most recent call last):
    assert self._parent_pid == os.getpid(), 'can only join a child process'
  File "/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 961, in __del__
AssertionError: can only join a child process
    self._shutdown_workers()
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f952dd5c680>
Traceback (most recent call last):
  File "/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 941, in _

In [58]:
# Pick one mixture from the last evaluated batch and separate it for listening.
# The index is drawn from the batch's real size: the last batch of the loader
# can hold fewer than batch_size examples, which made randint(batch_size)
# index out of range.
idx = np.random.randint(mixed_sig.shape[0])
sig_mixed = mixed_sig[idx].detach().cpu().numpy()
# Inference only: skip building the autograd graph.
with torch.no_grad():
    sig1, sig2 = e2enet.module.tasnet(mixed_sig)[idx].detach().cpu().numpy()

In [59]:
# Play the selected mixture.
Audio(sig_mixed, rate = new_sr)

In [60]:
# Play the first channel separated from the selected mixture.
Audio(sig1, rate = new_sr)

In [61]:
# Play the second channel separated from the selected mixture.
Audio(sig2, rate = new_sr)