In [1]:
import sys
from librosa.core import resample
import pandas as pd
import numpy as np
from IPython.display import Audio
import torch
import pathlib
def create_dir(filename):
    """Make sure the directory that would contain ``filename`` exists.

    Everything before the last ``'/'`` in ``filename`` is treated as the
    directory part. It is created together with any missing parents, and an
    already existing directory is left untouched. A name without any ``'/'``
    has an empty directory part, which resolves to the current directory.
    """
    parent, _, _ = filename.rpartition('/')
    pathlib.Path(parent).mkdir(parents=True, exist_ok=True)
from tqdm.notebook import tqdm
sys.path.append('Conv-TasNet/src/')
from conv_tasnet import *
from pit_criterion import cal_loss
from collections import Counter
import os
# Index of the primary CUDA device; used for every `.cuda(device)` call in this notebook.
device = 0
# GPU indices handed to nn.DataParallel when the model is wrapped.
device_ids = [0, 1, 2, 3]
# Prefix prepended to the CSV and segment-file paths read by the dataset.
root = './'
# Audio sample rate in Hz (passed to Audio(rate=...) and used to size the 2-second segments).
sr = 8000
# Make `device` the current CUDA device for this process.
torch.cuda.set_device(device)

# Pre-Train Baseline

In [2]:
# Warm-start a fresh Conv-TasNet (C = 3 outputs, matching the 3-talker mixtures
# used in this notebook) from the pretrained checkpoint `final.pth.tar`.
tas2 = ConvTasNet.load_model('final.pth.tar').cuda(device)
tasnet = ConvTasNet(N = 256, L = 20, B = 256, H = 512, P = 3, X = 8, R = 4, C = 3, norm_type="gLN", causal=0,
             mask_nonlinear='relu').cuda(device)

# Copy every pretrained tensor whose name also exists in the new model.
# Tensors whose shapes are incompatible are reported and keep their fresh
# initialisation instead of aborting the whole transfer.
own_state = tasnet.state_dict()
for name, param in tqdm(tas2.state_dict().items()):
    if name not in own_state:
        continue
    if isinstance(param, torch.nn.Parameter):
        # Defensive: unwrap a Parameter to its underlying tensor before copying.
        param = param.data
    try:
        own_state[name].copy_(param)
    except RuntimeError:
        # Only the expected failure (incompatible shapes in copy_) is tolerated;
        # anything else, including KeyboardInterrupt, now propagates.
        print('shape mismatch: %s (pretrained %s vs. model %s)'
              % (name, tuple(param.shape), tuple(own_state[name].shape)))

# The optimizer is built on the bare model's parameters; wrapping in
# DataParallel afterwards keeps referring to the same parameter tensors.
optimizer = torch.optim.Adam(tasnet.parameters(), lr = 0.001)
tasnet = torch.nn.DataParallel(tasnet, device_ids = device_ids)

# Resume model and optimizer state from the fine-tuning checkpoint if one exists.
checkpoint_path = 'models/tasnet.pth'
if os.path.exists(checkpoint_path):
    print('load model')
    checkpoint = torch.load(checkpoint_path)
    tasnet.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    loss = checkpoint['loss']

# Trailing semicolon suppresses the module repr that train() would otherwise display.
tasnet.train();

HBox(children=(FloatProgress(value=0.0, max=294.0), HTML(value='')))

shape mismatch

load model


In [3]:
class OverlayDataSet(torch.utils.data.Dataset):
    """Dataset that builds multi-talker mixtures on the fly.

    The segment table is read from ``root + csv`` and must provide a
    ``'speaker'`` column and a ``'segfile'`` column; each ``segfile`` is a path
    (relative to the notebook-level ``root``) that ``np.load`` can read.

    Indexing returns ``(mixture, sources)``: ``sources`` stacks ``num_talker``
    individually normalized segments, each from a different speaker, and
    ``mixture`` is their sample-wise sum. The first source is the segment at
    the requested index; the remaining ones are drawn at random with
    ``np.random``, so repeated indexing gives different mixtures.

    Args:
        csv: path of the segment CSV, relative to ``root``.
        num_talker: number of distinct speakers overlaid in every mixture.
    """

    def __init__(self, csv, num_talker = 3):
        super().__init__()
        self.segments = pd.read_csv(root+csv)
        self.speakers = sorted(set(self.segments['speaker']))
        self.spkr2idx = {spkr:i for i, spkr in enumerate(self.speakers)}
        self.num_talker = num_talker

    def __len__(self):
        """Number of segments in the table (one mixture per anchor segment)."""
        return len(self.segments)

    def __getitem__(self, idx):
        """Build one mixture anchored on segment ``idx``.

        Returns:
            ``(mixture, sources)`` where ``sources`` has one row per talker and
            ``mixture`` is the sum of those rows.

        Raises:
            ValueError: if the table has fewer distinct speakers than
                ``num_talker`` (the rejection loop below could never finish).
        """
        if self.num_talker > len(self.speakers):
            raise ValueError(
                'need at least %d distinct speakers to build a mixture, found %d'
                % (self.num_talker, len(self.speakers)))
        talkers = []
        sigs = []
        for i in range(self.num_talker):
            if i != 0:
                idx = np.random.randint(len(self.segments))
            seg = self.segments.iloc[idx]
            # Re-draw until the segment belongs to a speaker not yet in the mixture.
            while seg['speaker'] in talkers:
                idx = np.random.randint(len(self.segments))
                seg = self.segments.iloc[idx]
            sig = np.load(root+seg['segfile'])
            # Out-of-place arithmetic so integer-typed segments are promoted to
            # float instead of raising on an in-place cast.
            sig = sig - np.mean(sig)
            std = np.std(sig)
            # A silent (constant) segment has zero std; leave it at zero rather
            # than dividing by zero and filling the batch with NaN/inf.
            if std > 0:
                sig = sig / std
            sigs.append(sig)
            talkers.append(seg['speaker'])
        sigs = np.array(sigs)
        return np.sum(sigs, axis = 0), sigs


# Build one on-the-fly mixture dataset per split (default: 3 talkers per mixture).
trainset = OverlayDataSet('files/train-segments.csv')
valset = OverlayDataSet('files/val-segments.csv')
testset = OverlayDataSet('files/test-segments.csv')
# Pick a random test item to sanity-check by ear.
idx = np.random.randint(len(testset))

# `mixture` is the sum of the sources; `sources` holds the individual normalized signals.
mixture, sources = testset[idx]
# Play the mixture at the notebook's sample rate.
Audio(mixture, rate = sr)

In [4]:
# Listen to the first individual source (the anchor segment) of the sampled test mixture.
Audio(sources[0], rate = sr)

In [5]:
# Listen to the second individual source of the sampled test mixture.
Audio(sources[1], rate = sr)

In [6]:
# Listen to the third individual source of the sampled test mixture.
Audio(sources[2], rate = sr)

In [None]:
# --- Training configuration -------------------------------------------------
batch_size = 32
num_workers = 32
# Epoch range is hard-coded: the checkpoint does not record the epoch, so the
# start value was presumably set by hand when resuming an earlier session.
start_epoch, num_epochs = 4, 64
max_grad_norm = 0.5           # gradient-norm clipping threshold
log_every = 200               # report and checkpoint every `log_every` mini-batches
checkpoint_file = 'models/tasnet.pth'


def seed_worker(worker_id):
    """Give each DataLoader worker its own NumPy seed.

    The dataset draws the extra talkers with ``np.random``. Depending on the
    PyTorch version, forked workers can inherit an identical NumPy RNG state
    and then produce the same "random" mixtures. Inside a worker,
    ``torch.initial_seed()`` is already distinct per worker, so derive the
    NumPy seed from it.
    """
    np.random.seed(torch.initial_seed() % 2**32)


trainloader  = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory = True,
                                           num_workers = num_workers, worker_init_fn = seed_worker)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=True, pin_memory = True,
                                        num_workers = num_workers, worker_init_fn = seed_worker)

# Make sure the checkpoint directory exists before the first save, instead of
# failing only after `log_every` batches of training.
create_dir(checkpoint_file)

for epoch in range(start_epoch, num_epochs):
    # eval() is called for validation at the end of every epoch, so switch back
    # to training mode at the start of each one.
    tasnet.train()
    running_loss = 0.0
    for batch_idx, (mixture, sources) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        mixture, sources = mixture.float().cuda(device), sources.float().cuda(device)
        out = tasnet(mixture)
        # Every segment is treated as 2 seconds long, i.e. 2*sr samples.
        lengths = torch.ones(mixture.shape[0], dtype = torch.int32).cuda(device)*2*sr
        loss, max_snr, estimate_source, reorder_estimate_source = \
            cal_loss(sources, out, lengths)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(tasnet.parameters(), max_grad_norm)
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % log_every == log_every - 1:    # report every `log_every` mini-batches
            print('[%d, %5d] loss: %.3f ' %
                  (epoch + 1, batch_idx + 1, running_loss / log_every))
            running_loss = 0.0
            torch.save({
                'model_state_dict': tasnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Store a plain float rather than a graph-attached CUDA tensor.
                'loss': loss.item()
            }, checkpoint_file)

    # Validation: sample-weighted mean loss over the validation set.
    with torch.no_grad():
        running_loss = 0.0
        tasnet.eval()
        for batch_idx, (mixture, sources) in enumerate(tqdm(valloader)):
            mixture, sources = mixture.float().cuda(device), sources.float().cuda(device)
            out = tasnet(mixture)
            lengths = torch.ones(mixture.shape[0], dtype = torch.int32).cuda(device)*2*sr
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(sources, out, lengths)

            running_loss += loss.item()*sources.shape[0]
        print(running_loss/len(valset))

HBox(children=(FloatProgress(value=0.0, max=7060.0), HTML(value='')))

  frame = signal.new_tensor(frame).long()  # signal may in GPU or CPU


# TODO: maybe also try recursive defogging