In [1]:
import sys
from librosa.core import resample
import pandas as pd
import numpy as np
from IPython.display import Audio
import torch
import pathlib
def create_dir(filename):
    """Create the parent directory of ``filename`` if it does not exist yet.

    Parameters
    ----------
    filename : str or os.PathLike
        Path of a file that is about to be written. Only its directory part
        is created (including missing intermediate directories); the file
        itself is never touched.

    Notes
    -----
    The directory part is taken with ``os.path.dirname`` rather than by
    splitting on ``'/'`` by hand, so ``pathlib.Path`` arguments and the
    platform's own path separator are handled as well. A bare file name
    (no directory part) resolves to the current directory and is a no-op.
    """
    parent = os.path.dirname(filename)
    # Path('') is the current directory, so an empty parent is harmless here.
    pathlib.Path(parent).mkdir(parents=True, exist_ok=True)
from tqdm.notebook import tqdm
sys.path.append('Conv-TasNet/src/')
from conv_tasnet_base import *
from pit_criterion import cal_loss
from collections import Counter
import os
# Index of the primary CUDA device; models and batches are moved there with .cuda(device).
device = 0
# GPU indices handed to torch.nn.DataParallel when the wrapped model is replicated.
device_ids = [0, 1, 2, 3]
# Prefix prepended to every CSV and segment path read by OverlayDataSet.
root = './'
# Audio sample rate: used as the playback rate for Audio(...) and to build the
# 2*sr source length passed to cal_loss.
sr = 8000
# Make `device` the current CUDA device for the rest of the notebook.
torch.cuda.set_device(device)
# Number of passes through the base network performed by Wrapper.forward.
iterations = 3
print(torch.__version__)

1.5.0


In [2]:
# Separation network that Wrapper runs repeatedly; built here and moved to GPU `device`.
# NOTE(review): the single-letter hyperparameters presumably follow the Conv-TasNet
# paper's notation (C = 3 would then be the number of separated sources, matching the
# 3-talker mixtures used below) - confirm against conv_tasnet_base.
base = ConvTasNet(N = 256, L = 20, B = 256, H = 512, P = 3, X = 8, R = 2, C = 3, norm_type="gLN", causal=0,
             mask_nonlinear='relu').cuda(device)

# checkpoint = torch.load('models/stable/tasnet-baseline.pth')


# own_state = base.state_dict()
# for name, param in tqdm(checkpoint['model_state_dict'].items()):
#     if name not in own_state:
#          continue
#     if isinstance(param, torch.nn.Parameter):
#         param = param.data
#     try:
#         own_state[name].copy_(param)
#     except:
#         print(name, 'shape mismatch')



class Wrapper(torch.nn.Module):
    """Apply a separation network several times, feeding each estimate back in.

    The first pass runs ``base`` on the raw input. Every later pass runs
    ``base`` on a blend of the original input and the normalised previous
    output, weighted by the learnable scalar ``alpha`` (initialised to 0.5).

    Parameters
    ----------
    base : torch.nn.Module
        Network applied at every pass. Its output must have the same shape
        as its input, because the output is blended with the original input.
    iterations : int
        Number of passes performed by ``forward``.
    """

    def __init__(self, base, iterations):
        super().__init__()
        self.base = base
        self.iterations = iterations
        # Learnable mixing weight between the original input and the fed-back estimate.
        self.alpha = torch.nn.Parameter(torch.Tensor([0.5]))

    def unit_norm(self, tensor):
        """Normalise ``tensor`` to zero mean and unit variance along its last axis.

        The caller's comment gives the expected layout as
        ``[batch_size, C, length]``; statistics are computed per ``[batch, C]``
        row using the biased variance, with ``EPS`` added for stability.
        """
        centered = tensor - tensor.mean(dim=-1, keepdim=True)  # subtract [batch_size, C, 1] mean
        variance = torch.var(tensor, dim=-1, keepdim=True, unbiased=False)  # [batch_size, C, 1]
        return centered / torch.pow(variance + EPS, 0.5)

    def forward(self, x):
        """Return a list with the output of every pass, first pass first."""
        original_input = x
        stage_outputs = []
        for stage in range(self.iterations):
            if stage > 0:
                # Later passes see the input blended with the previous (normalised) estimate.
                fed_back = self.unit_norm(x) * (1 - self.alpha)
                x = original_input * self.alpha + fed_back
            x = self.base(x)
            stage_outputs.append(x)
        return stage_outputs
       
# Build the iterative model and its optimizer. The optimizer is created before the
# DataParallel wrap; it still sees the same parameter objects afterwards.
wrapper = Wrapper(base, iterations).cuda(device)
optimizer = torch.optim.Adam(wrapper.parameters(), lr = 0.001)

wrapper = torch.nn.DataParallel(wrapper, device_ids = device_ids)

# Resume from the rolling checkpoint written by the training loop when it exists.
# On a fresh run there is nothing to resume from, so start from the freshly
# initialised weights instead of failing with FileNotFoundError.
checkpoint_path = 'models/base.pth'
create_dir(checkpoint_path)  # make sure the training loop's torch.save has a directory to write to
if os.path.isfile(checkpoint_path):
    # map_location keeps the loaded tensors on the primary device no matter which
    # device they were saved from.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cuda', device))
    wrapper.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
else:
    checkpoint = None
    print('No checkpoint found at %s - starting from freshly initialised weights' % checkpoint_path)


#example_forward_input = torch.rand(1, 3, 16000).cuda(device)
#wrapper = torch.jit.trace(wrapper, example_forward_input)

In [3]:
class OverlayDataSet(torch.utils.data.Dataset):
    """Dataset that builds multi-talker mixtures on the fly.

    Each item overlays ``num_talker`` segments from distinct speakers. The
    first segment is the one at the requested index; the remaining ones are
    drawn at random with ``np.random``.

    Parameters
    ----------
    csv : str
        Path (relative to the module-level ``root``) of a CSV file with at
        least the columns ``speaker`` and ``segfile``. ``segfile`` is a path
        (also relative to ``root``) loadable with ``np.load``.
    num_talker : int, default 3
        Number of simultaneous talkers per mixture.

    Raises
    ------
    ValueError
        If ``num_talker`` is smaller than 1 or larger than the number of
        distinct speakers (the speaker-rejection loop in ``__getitem__``
        could otherwise never terminate).
    """

    def __init__(self, csv, num_talker = 3):
        super().__init__()
        self.segments = pd.read_csv(root+csv)
        self.speakers = sorted(set(self.segments['speaker']))
        self.spkr2idx = {spkr:i for i, spkr in enumerate(self.speakers)}
        if not 1 <= num_talker <= len(self.speakers):
            raise ValueError(
                'num_talker must be between 1 and the number of distinct speakers (%d), got %r'
                % (len(self.speakers), num_talker))
        self.num_talker = num_talker

    def __len__(self):
        """Number of segments listed in the CSV (one mixture per segment)."""
        return len(self.segments)

    def __getitem__(self, idx):
        """Return ``(mixture, sigs)`` for the segment at ``idx``.

        ``sigs`` stacks the ``num_talker`` normalised source signals, one per
        row. ``mixture`` is their sum, repeated ``num_talker`` times along the
        first axis so that it has the same shape as ``sigs``.
        """
        talkers = []
        sigs = []
        for i in range(self.num_talker):
            if i!=0:
                idx = np.random.randint(len(self.segments))
            seg = self.segments.iloc[idx]
            # Redraw until the segment comes from a speaker not yet in the mixture.
            while seg['speaker'] in talkers:
                idx = np.random.randint(len(self.segments))
                seg = self.segments.iloc[idx]
            sig = np.load(root+seg['segfile'])
            sig-=np.mean(sig)
            std = np.std(sig)
            # A constant (e.g. silent) segment has zero std; dividing would fill the
            # whole mixture with NaN, so leave such a segment at zero instead.
            if std > 0:
                sig/=std
            sigs.append(sig)
            talkers.append(seg['speaker'])
        sigs = np.array(sigs)
        mixture = np.sum(sigs, axis = 0)[None, ...]
        # One copy of the mixture per talker (was hard-coded to 3 regardless of num_talker).
        mixture = np.repeat(mixture, self.num_talker, 0)
        return mixture, sigs


# One dataset per split; each CSV path is resolved relative to `root` inside OverlayDataSet.
trainset = OverlayDataSet('files/train-segments.csv')
valset = OverlayDataSet('files/val-segments.csv')
testset = OverlayDataSet('files/test-segments.csv')

# Fix both RNGs so the sanity-check sample below (and its randomly drawn
# interfering talkers) is the same on every run.
torch.manual_seed(0)
np.random.seed(0)
idx = np.random.randint(len(testset))

# Draw one test mixture, show its shape and play the first (repeated) mixture channel.
mixture, sources = testset[idx]
print(mixture.shape)
Audio(mixture[0], rate = sr)

(3, 16000)


In [None]:
batch_size = 20


def seed_worker(worker_id):
    """Give every DataLoader worker its own NumPy random stream.

    OverlayDataSet draws the interfering talkers with ``np.random``. Forked
    workers otherwise all start from the same inherited NumPy state and
    produce the same "random" draws. ``torch.initial_seed()`` is already
    distinct per worker, so NumPy is re-seeded from it.
    """
    np.random.seed(torch.initial_seed() % 2**32)


trainloader  = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory = True,
                                           num_workers = 32, worker_init_fn = seed_worker)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=True, pin_memory = True,
                                        num_workers = 32, worker_init_fn = seed_worker)

for epoch in range(0, 64):
    # ---- training ----
    running_loss = 0.0
    wrapper.train()
    for batch_idx, (mixture, sources) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        mixture, sources = mixture.float().cuda(device), sources.float().cuda(device)
        # Every source is 2 seconds long, i.e. 2*sr samples; built once per batch
        # and shared by all stages.
        source_lengths = torch.ones(mixture.shape[0], dtype = torch.int32).cuda(device)*2*sr
        out_list = wrapper(mixture)
        loss = 0.0
        for stage_num, out in enumerate(out_list):
            stage_loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(sources, out, source_lengths)
            # Later stages are weighted more heavily (1x, 2x, 3x, ...).
            loss += stage_loss*(stage_num+1)
        loss.backward()
        # NOTE(review): only the base network's gradients are clipped; the wrapper's
        # `alpha` is not included - confirm this is intentional.
        torch.nn.utils.clip_grad_norm_(base.parameters(), 0.5)
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 200 == 199:    # report and checkpoint every 200 mini-batches
            print('[%d, %5d] loss: %.3f alpha: %.3f' % 
                  (epoch + 1, batch_idx + 1, running_loss / 200, wrapper.module.alpha[0].item()))
            running_loss = 0.0
            torch.save({
            'model_state_dict': wrapper.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Plain float rather than a live GPU tensor, so the checkpoint stays
            # small and loads on any machine.
            'loss': loss.item()
            }, 'models/base.pth')

    # ---- validation ----
    # Validation is seeded so that every epoch is scored on the same mixtures.
    # The seeding is confined to this block and the previous RNG states are
    # restored afterwards; re-seeding the global RNGs here would rewind them and
    # make every following training epoch replay the same shuffle order.
    numpy_rng_state = np.random.get_state()
    with torch.random.fork_rng(devices=device_ids), torch.no_grad():
        torch.manual_seed(0)
        np.random.seed(0)
        running_loss = 0.0
        wrapper.eval()
        for batch_idx, (mixture, sources) in enumerate(tqdm(valloader)):
            mixture, sources = mixture.float().cuda(device), sources.float().cuda(device)
            source_lengths = torch.ones(mixture.shape[0], dtype = torch.int32).cuda(device)*2*sr
            out_list = wrapper(mixture)
            # Only the last stage's output is scored.
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(sources, out_list[-1], source_lengths)

            # Weight by batch size so the final figure is a per-example mean.
            running_loss += loss.item()*sources.shape[0]
        print(running_loss/len(valset))
    np.random.set_state(numpy_rng_state)

HBox(children=(FloatProgress(value=0.0, max=11295.0), HTML(value='')))

[1,   200] loss: 11.394 alpha: 0.616
[1,   400] loss: 10.916 alpha: 0.650
[1,   600] loss: 9.615 alpha: 0.683
[1,   800] loss: 9.353 alpha: 0.716
[1,  1000] loss: 9.564 alpha: 0.761
[1,  1200] loss: 7.814 alpha: 0.791
[1,  1400] loss: 7.068 alpha: 0.817
[1,  1600] loss: 6.947 alpha: 0.845
[1,  1800] loss: 6.317 alpha: 0.872
[1,  2000] loss: 5.620 alpha: 0.882
[1,  2200] loss: 4.729 alpha: 0.905
[1,  2400] loss: 5.548 alpha: 0.903
[1,  2600] loss: 5.008 alpha: 0.895
[1,  2800] loss: 4.609 alpha: 0.920
[1,  3000] loss: 3.817 alpha: 0.932
[1,  3200] loss: 3.274 alpha: 0.938
[1,  3400] loss: 4.077 alpha: 0.944
[1,  3600] loss: 3.245 alpha: 0.942
[1,  3800] loss: 3.554 alpha: 0.942
[1,  4000] loss: 2.278 alpha: 0.930
[1,  4200] loss: 3.210 alpha: 0.940
[1,  4400] loss: 2.070 alpha: 0.963
[1,  4600] loss: 1.510 alpha: 0.962
[1,  4800] loss: 0.800 alpha: 0.968
[1,  5000] loss: 1.250 alpha: 0.979
[1,  5200] loss: 0.929 alpha: 0.959
[1,  5400] loss: 0.828 alpha: 0.958
[1,  5600] loss: -0.565 al

HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


-0.5705340645590007


HBox(children=(FloatProgress(value=0.0, max=11295.0), HTML(value='')))

[2,   200] loss: -5.699 alpha: 0.731
[2,   400] loss: -4.997 alpha: 0.722
[2,   600] loss: -6.388 alpha: 0.727
[2,   800] loss: -5.803 alpha: 0.718
[2,  1000] loss: -5.133 alpha: 0.722
[2,  1200] loss: -6.229 alpha: 0.720
[2,  1400] loss: -6.945 alpha: 0.720
[2,  1600] loss: -6.840 alpha: 0.714
[2,  1800] loss: -6.434 alpha: 0.717
[2,  2000] loss: -7.106 alpha: 0.722
[2,  2200] loss: -8.358 alpha: 0.714
[2,  2400] loss: -6.496 alpha: 0.710
[2,  2600] loss: -7.807 alpha: 0.703
[2,  2800] loss: -7.077 alpha: 0.706
[2,  3000] loss: -7.914 alpha: 0.706
[2,  3200] loss: -9.018 alpha: 0.707
[2,  3400] loss: -8.279 alpha: 0.710
[2,  3600] loss: -8.782 alpha: 0.713
[2,  3800] loss: -7.038 alpha: 0.695
[2,  4000] loss: -9.873 alpha: 0.698
[2,  4200] loss: -8.251 alpha: 0.701
[2,  4400] loss: -8.716 alpha: 0.703
[2,  4600] loss: -9.861 alpha: 0.710
[2,  4800] loss: -9.948 alpha: 0.709
[2,  5000] loss: -10.125 alpha: 0.703
[2,  5200] loss: -9.174 alpha: 0.709
[2,  5400] loss: -8.416 alpha: 0.697


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


-1.8325794757124911


HBox(children=(FloatProgress(value=0.0, max=11295.0), HTML(value='')))

[3,   200] loss: -12.818 alpha: 0.684
[3,   400] loss: -12.762 alpha: 0.686
[3,   600] loss: -13.675 alpha: 0.701
[3,   800] loss: -12.935 alpha: 0.693
[3,  1000] loss: -12.244 alpha: 0.720
[3,  1200] loss: -13.890 alpha: 0.708
[3,  1400] loss: -13.926 alpha: 0.696
[3,  1600] loss: -13.787 alpha: 0.699
[3,  1800] loss: -13.831 alpha: 0.699
[3,  2000] loss: -13.368 alpha: 0.699
[3,  2200] loss: -14.974 alpha: 0.693
[3,  2400] loss: -13.296 alpha: 0.691
[3,  2600] loss: -13.790 alpha: 0.692
[3,  2800] loss: -13.968 alpha: 0.716
[3,  3000] loss: -13.886 alpha: 0.704
[3,  3200] loss: -14.947 alpha: 0.697
[3,  3400] loss: -14.485 alpha: 0.691
[3,  3600] loss: -14.666 alpha: 0.696
[3,  3800] loss: -14.016 alpha: 0.691
[3,  4000] loss: -16.720 alpha: 0.697
[3,  4200] loss: -14.115 alpha: 0.698
[3,  4400] loss: -14.667 alpha: 0.701
[3,  4600] loss: -15.130 alpha: 0.703
[3,  4800] loss: -15.333 alpha: 0.703
[3,  5000] loss: -16.258 alpha: 0.697
[3,  5200] loss: -15.476 alpha: 0.696
[3,  5400] l

HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


-2.718311654049087


HBox(children=(FloatProgress(value=0.0, max=11295.0), HTML(value='')))

[4,   200] loss: -17.919 alpha: 0.692
[4,   400] loss: -16.384 alpha: 0.687
[4,   600] loss: -18.497 alpha: 0.692
[4,   800] loss: -18.124 alpha: 0.685
[4,  1000] loss: -16.285 alpha: 0.691
[4,  1200] loss: -18.833 alpha: 0.693
[4,  1400] loss: -18.801 alpha: 0.699
[4,  1600] loss: -18.482 alpha: 0.693
[4,  1800] loss: -17.813 alpha: 0.705
[4,  2000] loss: -18.883 alpha: 0.705
[4,  2200] loss: -19.194 alpha: 0.693
[4,  2400] loss: -17.238 alpha: 0.699
[4,  2600] loss: -18.415 alpha: 0.699
[4,  2800] loss: -17.894 alpha: 0.696
[4,  3000] loss: -19.121 alpha: 0.693
[4,  3200] loss: -19.213 alpha: 0.703
[4,  3400] loss: -18.664 alpha: 0.702
[4,  3600] loss: -20.116 alpha: 0.702
[4,  3800] loss: -18.375 alpha: 0.694
[4,  4000] loss: -20.822 alpha: 0.688
[4,  4200] loss: -19.141 alpha: 0.701
[4,  4400] loss: -18.793 alpha: 0.698
[4,  4600] loss: -19.542 alpha: 0.699
[4,  4800] loss: -19.637 alpha: 0.695
[4,  5000] loss: -19.761 alpha: 0.702
[4,  5200] loss: -19.241 alpha: 0.684
[4,  5400] l

HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


-3.3666031548701434


HBox(children=(FloatProgress(value=0.0, max=11295.0), HTML(value='')))

[5,   200] loss: -20.866 alpha: 0.705
[5,   400] loss: -19.825 alpha: 0.699
[5,   600] loss: -21.822 alpha: 0.689
[5,   800] loss: -20.872 alpha: 0.697
[5,  1000] loss: -19.691 alpha: 0.702
[5,  1200] loss: -21.193 alpha: 0.695
[5,  1400] loss: -21.828 alpha: 0.700
[5,  1600] loss: -21.597 alpha: 0.691
[5,  1800] loss: -21.223 alpha: 0.702
[5,  2000] loss: -21.572 alpha: 0.698
[5,  2200] loss: -22.485 alpha: 0.703
[5,  2400] loss: -20.969 alpha: 0.699
[5,  2600] loss: -20.874 alpha: 0.695
[5,  2800] loss: -21.701 alpha: 0.727
[5,  3000] loss: -21.633 alpha: 0.717
[5,  3200] loss: -22.523 alpha: 0.714
[5,  3400] loss: -21.550 alpha: 0.717
[5,  3600] loss: -22.477 alpha: 0.710
[5,  3800] loss: -21.102 alpha: 0.708
[5,  4000] loss: -22.825 alpha: 0.708
[5,  4200] loss: -21.909 alpha: 0.707
[5,  4400] loss: -21.656 alpha: 0.716
[5,  4600] loss: -22.483 alpha: 0.713
[5,  4800] loss: -22.101 alpha: 0.708
[5,  5000] loss: -22.861 alpha: 0.710
[5,  5200] loss: -22.032 alpha: 0.710
[5,  5400] l

HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


-3.9914750946470465


HBox(children=(FloatProgress(value=0.0, max=11295.0), HTML(value='')))

[6,   200] loss: -23.656 alpha: 0.711
[6,   400] loss: -22.739 alpha: 0.717
[6,   600] loss: -24.327 alpha: 0.712
[6,   800] loss: -23.209 alpha: 0.710
[6,  1000] loss: -22.339 alpha: 0.722
[6,  1200] loss: -24.351 alpha: 0.713
[6,  1400] loss: -24.435 alpha: 0.700
[6,  1600] loss: -23.628 alpha: 0.705
[6,  1800] loss: -22.964 alpha: 0.708
[6,  2000] loss: -23.690 alpha: 0.708
[6,  2200] loss: -25.818 alpha: 0.702
[6,  2400] loss: -23.638 alpha: 0.709
[6,  2600] loss: -24.154 alpha: 0.699
[6,  2800] loss: -23.791 alpha: 0.704
[6,  3000] loss: -24.169 alpha: 0.697
[6,  3200] loss: -25.007 alpha: 0.716
[6,  3400] loss: -23.961 alpha: 0.707
[6,  3600] loss: -24.474 alpha: 0.709
[6,  3800] loss: -23.659 alpha: 0.712
[6,  4000] loss: -26.210 alpha: 0.707
[6,  4200] loss: -24.338 alpha: 0.707
[6,  4400] loss: -23.970 alpha: 0.709
[6,  4600] loss: -25.271 alpha: 0.700
[6,  4800] loss: -25.974 alpha: 0.698
[6,  5000] loss: -25.289 alpha: 0.705
[6,  5200] loss: -24.708 alpha: 0.704
[6,  5400] l

In [None]:
# Evaluate on the test split. Both RNGs are seeded so the shuffled order and the
# randomly drawn interfering talkers are repeatable.
torch.manual_seed(0)
np.random.seed(0)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=True, pin_memory = True, num_workers = 32)
with torch.no_grad():
    running_loss = 0.0
    # Put the whole wrapped model (not just `base`) in eval mode, as the validation loop does.
    wrapper.eval()
    for batch_idx, (mixture, sources) in enumerate(tqdm(testloader)):
        mixture, sources = mixture.float().cuda(device), sources.float().cuda(device)
        out_list = wrapper(mixture)
        # Score the output of the LAST stage, consistent with validation. The previous
        # `out_list[1]` picked the second of the `iterations` stages.
        # NOTE(review): switch back to a fixed index only if an intermediate stage is
        # deliberately being evaluated.
        final_out = out_list[-1]
        # Every source is 2 seconds long, i.e. 2*sr samples.
        source_lengths = torch.ones(mixture.shape[0], dtype = torch.int32).cuda(device)*2*sr
        loss, max_snr, estimate_source, reorder_estimate_source = \
            cal_loss(sources, final_out, source_lengths)

        # Weight by batch size so the final figure is a per-example mean.
        running_loss += loss.item()*sources.shape[0]
    print(running_loss/len(testset))

In [13]:
# Listen to entry [2][0] of `final_out`, which is left over from the last batch of the
# test loop - presumably batch item 2, separated source 0 (layout to be confirmed).
Audio(final_out[2][0].detach().cpu().numpy(), rate = sr)

In [14]:
# Listen to entry [2][1] of `final_out`, which is left over from the last batch of the
# test loop - presumably batch item 2, separated source 1 (layout to be confirmed).
Audio(final_out[2][1].detach().cpu().numpy(), rate = sr)

In [15]:
# Listen to entry [2][2] of `final_out`, which is left over from the last batch of the
# test loop - presumably batch item 2, separated source 2 (layout to be confirmed).
Audio(final_out[2][2].detach().cpu().numpy(), rate = sr)

In [65]:
def softmax(a, b):
    """Return the two-way softmax weight of ``a`` against ``b``.

    Mathematically this is ``exp(a) / (exp(a) + exp(b))``, so
    ``softmax(a, b) + softmax(b, a) == 1``. Scalars and NumPy arrays are
    accepted (arrays are handled element-wise).

    The value is computed as ``exp(a - logaddexp(a, b))``. The exponent is
    never positive, so large inputs cannot overflow to ``inf / inf = nan``
    as the direct formula does.
    """
    return np.exp(a - np.logaddexp(a, b))
def unit_norm(tensor):
    """Standardise ``tensor`` along its last axis (zero mean, unit variance).

    The expected layout is ``[batch_size, C, length]``; mean and biased
    variance are taken per ``[batch, C]`` row, and ``EPS`` is added to the
    variance before the square root for numerical stability.
    """
    row_mean = torch.mean(tensor, dim=-1, keepdim=True)  # [batch_size, C, 1]
    row_var = tensor.var(dim=-1, keepdim=True, unbiased=False)  # [batch_size, C, 1]
    row_std = torch.pow(row_var + EPS, 0.5)
    return (tensor - row_mean) / row_std
# Rebuild by hand the kind of blended input a later Wrapper stage sees: the mixture and
# the normalised network output, combined with two softmax weights that sum to 1.
# `mixture` and `final_out` are left over from the last batch of the test loop.
# NOTE(review): `input` shadows the Python builtin of the same name for the rest of the
# session - consider renaming if no later cell relies on it.
input = mixture*softmax(0.33, 0.66)+unit_norm(final_out)*softmax(0.66,0.33)
# Play entry [1][0] of the blended tensor.
Audio(input[1][0].detach().cpu().numpy(), rate = sr)

0.035571189272636174
22026.465794806718
