In [1]:
import sys
from librosa.core import resample
import pandas as pd
import numpy as np
from IPython.display import Audio
import torch
import pathlib
def create_dir(filename):
    pathlib.Path('/'.join(filename.split('/')[:-1])).mkdir(parents=True, exist_ok=True)
from tqdm.notebook import tqdm
sys.path.append('Conv-TasNet/src/')
from conv_tasnet_base import *
from pit_criterion import cal_loss
from collections import Counter
import os
device = 0
device_ids = [0, 1, 2, 3]
root = './'
sr = 8000
torch.cuda.set_device(device)
iterations = 3
print(torch.__version__)

1.5.0


In [5]:
base = ConvTasNet(N = 256, L = 20, B = 256, H = 512, P = 3, X = 8, R = 4, C = 3, norm_type="gLN", causal=0,
             mask_nonlinear='relu').cuda(device)
optimizer = torch.optim.Adam(base.parameters(), lr = 0.0001)

checkpoint = torch.load('models/base.pth')
base.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

#base = nn.DataParallel(base, device_ids = device_ids)

# own_state = base.state_dict()
# for name, param in tqdm(checkpoint['model_state_dict'].items()):
#     if name not in own_state:
#          continue
#     if isinstance(param, torch.nn.Parameter):
#         param = param.data
#     try:
#         own_state[name].copy_(param)
#     except:
#         print(name, 'shape mismatch')



class Wrapper(torch.nn.Module):
    def __init__(self, base, iterations):
        super().__init__()
        self.base = base
        self.iterations = iterations
        
    def unit_norm(self, tensor):
        # tensor.shape = [batch_size, C, length]
        mean = tensor.mean(dim = -1, keepdim = True) # [batch_size, C, 1]
        var = torch.var(tensor, dim = -1, keepdim=True, unbiased=False)  # [batch_size, C, 1]
        tensor_normalized = (tensor - mean) / torch.pow(var + EPS, 0.5)
        return tensor_normalized
        
    def forward(self, mixture):
        out = torch.rand(mixture.shape).cuda(device)
        out.requires_grad = False
        out_list = []
        for it in range(self.iterations):
            mixture_part = mixture*(1-it/self.iterations)
            out_part = self.unit_norm(out.clone())*(it/self.iterations)
            input = mixture_part + out_part
            out = self.base(input)
            out_list.append(out)
        out_list = tuple(out_list)
        return out_list
       
wrapper = Wrapper(base, iterations).eval()

example_forward_input = torch.rand(1, 3, 16000).cuda(device)
wrapper = torch.jit.trace(wrapper, example_forward_input)

In [6]:
class OverlayDataSet(torch.utils.data.Dataset):
    def __init__(self, csv, num_talker = 3):
        super().__init__()
        self.segments = pd.read_csv(root+csv)
        self.speakers = list(set(self.segments['speaker']))
        self.speakers.sort()
        self.spkr2idx = {spkr:i for i, spkr in enumerate(self.speakers)}
        self.num_talker = num_talker
        
    def __len__(self):
        return len(self.segments)
    def __getitem__(self, idx):
        talkers = []
        sigs = []
        for i in range(self.num_talker):
            if i!=0:
                idx = np.random.randint(len(self.segments))
            seg = self.segments.iloc[idx]
            while seg['speaker'] in talkers:
                idx = np.random.randint(len(self.segments))
                seg = self.segments.iloc[idx]
            sig = np.load(root+seg['segfile'])
            sig-=np.mean(sig)
            sig/=np.std(sig)
            sigs.append(sig)
            talkers.append(seg['speaker'])
        sigs = np.array(sigs)
        mixture = np.sum(sigs, axis = 0)[None, ...]
        mixture = np.repeat(mixture, 3, 0)
        return mixture, sigs


trainset = OverlayDataSet('files/train-segments.csv')
valset = OverlayDataSet('files/val-segments.csv')
testset = OverlayDataSet('files/test-segments.csv')

torch.manual_seed(0)
np.random.seed(0)
idx = np.random.randint(len(testset))

mixture, sources = testset[idx]
print(mixture.shape)
Audio(mixture[0], rate = sr)

(3, 16000)


In [7]:
torch.autograd.set_detect_anomaly(True)
batch_size = 3
trainloader  = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 32)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 32)

for epoch in range(0, 64):
    running_loss = 0.0
    base.train()
    for batch_idx, (mixture, sources) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        mixture, sources = mixture.float().cuda(device), sources.float().cuda(device)
        out_list = wrapper(mixture)
        loss = 0.0
        for stage_num, out in enumerate(out_list):
            stage_loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(sources, out, torch.ones(mixture.shape[0], dtype = torch.int32).cuda(device)*2*sr) # 2 seconds has 2*sr samples
            loss += stage_loss*(stage_num+1)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(base.parameters(), 0.5*iterations)
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 200 == 199:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f ' % 
                  (epoch + 1, batch_idx + 1, running_loss / 200))
            running_loss = 0.0
            torch.save({
            'model_state_dict': base.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
            }, 'models/base.pth')
            
    
    torch.manual_seed(0)
    np.random.seed(0)
    with torch.no_grad():
        running_loss = 0.0
        base.eval()
        for batch_idx, (mixture, sources) in enumerate(tqdm(valloader)):
            mixture, sources = mixture.float().cuda(device), sources.float().cuda(device)
            out_list = wrapper(mixture)
            final_out = out_list[-1]
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(sources, final_out, torch.ones(mixture.shape[0], dtype = torch.int32).cuda(device)*2*sr)

            running_loss += loss.item()*sources.shape[0]
        print(running_loss/len(valset))

HBox(children=(FloatProgress(value=0.0, max=75300.0), HTML(value='')))

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py(208): forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(534): _slow_forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(548): __call__
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py(100): forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(534): _slow_forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(548): __call__
Conv-TasNet/src/conv_tasnet_base.py(272): forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(534): _slow_forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(548): __call__
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py(100): forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(534): _slow_forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(548): __call__
Conv-TasNet/src/conv_tasnet_base.py(235): forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(534): _slow_forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(548): __call__
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py(100): forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(534): _slow_forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(548): __call__
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py(100): forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(534): _slow_forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(548): __call__
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py(100): forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(534): _slow_forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(548): __call__
Conv-TasNet/src/conv_tasnet_base.py(201): forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(534): _slow_forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(548): __call__
Conv-TasNet/src/conv_tasnet_base.py(53): forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(534): _slow_forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(548): __call__
<ipython-input-5-1e54a8e9d37a>(45): forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(534): _slow_forward
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py(548): __call__
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py(1027): trace_module
/home/junzhez2/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py(875): trace
<ipython-input-5-1e54a8e9d37a>(53): <module>
/home/junzhez2/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py(3331): run_code
/home/junzhez2/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py(3254): run_ast_nodes
/home/junzhez2/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py(3063): run_cell_async
/home/junzhez2/anaconda3/lib/python3.7/site-packages/IPython/core/async_helpers.py(68): _pseudo_sync_runner
/home/junzhez2/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py(2886): _run_cell
/home/junzhez2/anaconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py(2858): run_cell
/home/junzhez2/anaconda3/lib/python3.7/site-packages/ipykernel/zmqshell.py(536): run_cell
/home/junzhez2/anaconda3/lib/python3.7/site-packages/ipykernel/ipkernel.py(300): do_execute
/home/junzhez2/anaconda3/lib/python3.7/site-packages/tornado/gen.py(209): wrapper
/home/junzhez2/anaconda3/lib/python3.7/site-packages/ipykernel/kernelbase.py(541): execute_request
/home/junzhez2/anaconda3/lib/python3.7/site-packages/tornado/gen.py(209): wrapper
/home/junzhez2/anaconda3/lib/python3.7/site-packages/ipykernel/kernelbase.py(268): dispatch_shell
/home/junzhez2/anaconda3/lib/python3.7/site-packages/tornado/gen.py(209): wrapper
/home/junzhez2/anaconda3/lib/python3.7/site-packages/ipykernel/kernelbase.py(361): process_one
/home/junzhez2/anaconda3/lib/python3.7/site-packages/tornado/gen.py(748): run
/home/junzhez2/anaconda3/lib/python3.7/site-packages/tornado/gen.py(787): inner
/home/junzhez2/anaconda3/lib/python3.7/site-packages/tornado/ioloop.py(743): _run_callback
/home/junzhez2/anaconda3/lib/python3.7/site-packages/tornado/ioloop.py(690): <lambda>
/home/junzhez2/anaconda3/lib/python3.7/asyncio/events.py(88): _run
/home/junzhez2/anaconda3/lib/python3.7/asyncio/base_events.py(1782): _run_once
/home/junzhez2/anaconda3/lib/python3.7/asyncio/base_events.py(538): run_forever
/home/junzhez2/anaconda3/lib/python3.7/site-packages/tornado/platform/asyncio.py(153): start
/home/junzhez2/anaconda3/lib/python3.7/site-packages/ipykernel/kernelapp.py(583): start
/home/junzhez2/anaconda3/lib/python3.7/site-packages/traitlets/config/application.py(664): launch_instance
/home/junzhez2/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py(16): <module>
/home/junzhez2/anaconda3/lib/python3.7/runpy.py(85): _run_code
/home/junzhez2/anaconda3/lib/python3.7/runpy.py(193): _run_module_as_main
RuntimeError: 


In [59]:
torch.manual_seed(0)
np.random.seed(0)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=True, pin_memory = True, num_workers = 32)
with torch.no_grad():
    running_loss = 0.0
    base.eval()
    for batch_idx, (mixture, sources) in enumerate(tqdm(testloader)):
        mixture, sources = mixture.float().cuda(device), sources.float().cuda(device)
        out_list = wrapper(mixture)
        final_out = out_list[1]
        loss, max_snr, estimate_source, reorder_estimate_source = \
            cal_loss(sources, final_out, torch.ones(mixture.shape[0], dtype = torch.int32).cuda(device)*2*sr)

        running_loss += loss.item()*sources.shape[0]
    print(running_loss/len(testset))

HBox(children=(FloatProgress(value=0.0, max=393.0), HTML(value='')))


-8.12657971572178


In [63]:
Audio(final_out[0][0].detach().cpu().numpy(), rate = sr)

In [64]:
Audio(final_out[0][1].detach().cpu().numpy(), rate = sr)

In [65]:
Audio(final_out[0][2].detach().cpu().numpy(), rate = sr)