In [1]:
import sys
from librosa.core import resample
import pandas as pd
import numpy as np
from IPython.display import Audio
import torch
import pathlib
def create_dir(filename):
    """Make sure the directory that will hold ``filename`` exists.

    The directory is everything before the last '/' in ``filename`` (an empty
    string, i.e. the current directory, when there is no '/'). Missing parent
    directories are created; an already-existing directory is not an error.
    """
    parent_dir, _sep, _basename = filename.rpartition('/')
    pathlib.Path(parent_dir).mkdir(parents=True, exist_ok=True)
from tqdm.notebook import tqdm
sys.path.append('Conv-TasNet/src/')
sys.path.append('SincNet/')
from conv_tasnet import *
from pit_criterion import cal_loss
from dnn_models import *
from data_io import ReadList,read_conf_inp,str_to_bool
from collections import Counter
import os
# Notebook-wide configuration shared by every later cell.
device = 0  # CUDA device index used by the .cuda(device) calls throughout the notebook
device_ids = [0, 1, 2, 3]  # GPUs handed to nn.DataParallel for both models
root = '../'  # prefix concatenated onto CSV names and each row's 'segfile' path
sr = 8000  # sample rate in Hz: Audio playback rate and the 2 * sr segment lengths given to cal_loss
torch.cuda.set_device(device)

In [3]:
# Pretrained Conv-TasNet separator, optionally resumed from our fine-tuning checkpoint.
tasnet = ConvTasNet.load_model('final.pth.tar').cuda(device)
optimizer = torch.optim.Adam(tasnet.parameters(), lr=0.001)
if os.path.exists('models/tasnet.pth'):
    print('load model')
    # map_location keeps the load on our configured GPU regardless of where it was saved.
    checkpoint = torch.load('models/tasnet.pth', map_location=torch.device('cuda', device))
    tasnet.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    loss = checkpoint['loss']
# Wrap after building the optimizer: DataParallel reuses the same Parameter objects.
tasnet = nn.DataParallel(tasnet, device_ids=device_ids)

load model


In [3]:
class OverlayDataSet(torch.utils.data.Dataset):
    """Pairs each listed segment with a random segment from a different speaker.

    The CSV (path = ``root + csv``) must provide at least the columns
    ``speaker`` and ``segfile``; ``segfile`` is a path, relative to ``root``,
    loaded with ``np.load``. Each item is ``(sig1, sig2, one_hot1, one_hot2)``:
    the two raw signals (callers add them to form a mixture) and a one-hot
    speaker vector for each, indexed by the sorted speaker list.
    """

    def __init__(self, csv):
        super().__init__()
        self.segments = pd.read_csv(root + csv)
        self.speakers, self.spkr2idx = self._index_speakers(self.segments['speaker'])

    @staticmethod
    def _index_speakers(speaker_column):
        """Return ``(sorted unique speakers, {speaker: index})`` for a speaker column.

        Raises:
            ValueError: if fewer than two distinct speakers are present, because
                ``__getitem__`` rejection-samples a partner from a *different*
                speaker and would otherwise loop forever.
        """
        speakers = sorted(set(speaker_column))
        if len(speakers) < 2:
            raise ValueError(
                'OverlayDataSet needs at least 2 distinct speakers, got %d' % len(speakers))
        return speakers, {spkr: i for i, spkr in enumerate(speakers)}

    def __len__(self):
        """Number of rows (segments) in the CSV."""
        return len(self.segments)

    def __getitem__(self, idx):
        """Return ``(sig1, sig2, one_hot1, one_hot2)`` for row ``idx`` and a random partner.

        The partner row is re-drawn on every call until its speaker differs from
        row ``idx``'s speaker, so the same ``idx`` gives a new mixture each epoch.
        """
        n_segments = len(self.segments)
        seg1 = self.segments.iloc[idx]
        # NOTE(review): this draws from NumPy's global RNG. Whether DataLoader
        # worker processes get distinct NumPy seeds depends on the PyTorch
        # version -- confirm workers are not producing identical partner sequences.
        seg2 = self.segments.iloc[np.random.randint(n_segments)]
        while seg1['speaker'] == seg2['speaker']:
            seg2 = self.segments.iloc[np.random.randint(n_segments)]

        sig1 = np.load(root + seg1['segfile'])
        sig2 = np.load(root + seg2['segfile'])

        # One-hot speaker targets (maybe try PIT training too).
        out_vec1 = np.zeros(len(self.speakers))
        out_vec2 = np.zeros(len(self.speakers))
        out_vec1[self.spkr2idx[seg1['speaker']]] = 1
        out_vec2[self.spkr2idx[seg2['speaker']]] = 1

        return sig1, sig2, out_vec1, out_vec2


#mean, std = compute_mean_std('overlay-train.csv')


trainset = OverlayDataSet('train-segments.csv')
valset = OverlayDataSet('val-segments.csv')
testset = OverlayDataSet('test-segments.csv')

# Listen to one random test mixture separated by the loaded model.
idx = np.random.randint(len(testset))
sig1, sig2, target1, target2 = testset[idx]

mixture = torch.Tensor(sig1 + sig2).cuda(device)[None, ...]  # add a leading batch dim of 1
with torch.no_grad():  # inference only: don't build an autograd graph
    out = tasnet(mixture)
new_sig1, new_sig2 = out[0].cpu().numpy()  # out[0] unpacks into the two separated sources
Audio(new_sig1, rate=sr)

  frame = signal.new_tensor(frame).long()  # signal may in GPU or CPU


In [4]:
# Play the second separated source from the demo cell above.
Audio(data=new_sig2, rate=sr)

In [None]:
batch_size = 8
LOG_EVERY = 200  # mini-batches between progress prints
SAVE_TASNET = False  # set True to checkpoint the fine-tuned separator while training
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=16)
# Built once, not per epoch; validation losses are re-weighted per example below.
valloader = torch.utils.data.DataLoader(valset, batch_size=32, shuffle=True, pin_memory=True, num_workers=16)

for epoch in range(64):
    # Must be inside the loop: the validation pass at the end of each epoch
    # calls tasnet.eval(), which would otherwise persist into later epochs.
    tasnet.train()
    running_loss = 0.0
    for batch_idx, (sig1, sig2, _, _) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        sig1, sig2 = sig1.float().cuda(device), sig2.float().cuda(device)
        out = tasnet(sig1 + sig2)
        source = torch.stack([sig1, sig2], dim=1)
        # Per-example length of 2 * sr samples (assumes every segment is 2 s -- TODO confirm).
        lengths = torch.ones(sig1.shape[0], dtype=torch.int32).cuda(device) * 2 * sr
        loss, max_snr, estimate_source, reorder_estimate_source = cal_loss(source, out, lengths)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(tasnet.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % LOG_EVERY == LOG_EVERY - 1:
            print('[%d, %5d] loss: %.3f ' %
                  (epoch + 1, batch_idx + 1, running_loss / LOG_EVERY))
            running_loss = 0.0
            if SAVE_TASNET:
                create_dir('models/tasnet.pth')
                torch.save({
                    'model_state_dict': tasnet.module.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss.item(),
                }, 'models/tasnet.pth')

    # Validation pass: no gradients, so no gradient clipping either. Each batch
    # loss is weighted by its size (assuming cal_loss returns a batch mean) to
    # print a per-example average over the whole validation set.
    tasnet.eval()
    running_loss = 0.0
    with torch.no_grad():
        for sig1, sig2, _, _ in tqdm(valloader):
            sig1, sig2 = sig1.float().cuda(device), sig2.float().cuda(device)
            out = tasnet(sig1 + sig2)
            source = torch.stack([sig1, sig2], dim=1)
            lengths = torch.ones(sig1.shape[0], dtype=torch.int32).cuda(device) * 2 * sr
            loss, max_snr, estimate_source, reorder_estimate_source = cal_loss(source, out, lengths)
            running_loss += loss.item() * sig1.shape[0]
    print(running_loss / len(valset))

In [None]:
# Stand-alone validation pass of the separator on the validation set.
tasnet.eval()
running_loss = 0.0
valloader = torch.utils.data.DataLoader(valset, batch_size=8, shuffle=True, pin_memory=True, num_workers=16)
with torch.no_grad():
    for sig1, sig2, _, _ in tqdm(valloader):
        sig1, sig2 = sig1.float().cuda(device), sig2.float().cuda(device)
        out = tasnet(sig1 + sig2)
        source = torch.stack([sig1, sig2], dim=1)
        # Per-example length of 2 * sr samples (assumes every segment is 2 s -- TODO confirm).
        lengths = torch.ones(sig1.shape[0], dtype=torch.int32).cuda(device) * 2 * sr
        loss, max_snr, estimate_source, reorder_estimate_source = cal_loss(source, out, lengths)
        # Weight by batch size (assuming cal_loss returns a batch mean) for a per-example average.
        running_loss += loss.item() * sig1.shape[0]
print(running_loss / len(valset))

In [5]:
def find_max2(tensor):
    """Return, for each row of a 2-D tensor, the indices of its two largest entries.

    The result is a NumPy array with one row per input row, largest index
    first (shape [rows, 2] when each row has at least two entries).
    """
    scores = tensor.cpu().detach().numpy()
    return np.array([np.argsort(scores_row)[::-1][:2] for scores_row in scores])


def compute_corrects(tensor1, tensor2):
    """Count rows whose top-2 index sets agree between ``tensor1`` and ``tensor2``.

    Order inside each pair is ignored (multiset comparison), so a row counts
    when both predicted speakers match both reference speakers.
    """
    top_pred = find_max2(tensor1)
    top_true = find_max2(tensor2)
    return sum(
        Counter(top_pred[row]) == Counter(top_true[row])
        for row in range(top_pred.shape[0])
    )

In [6]:
# SincNet analysis-window settings, converted from milliseconds to samples.
fs=sr  # sampling rate (Hz) passed to SincNet
cw_len=200  # window length in ms
cw_shift=10  # hop between consecutive windows in ms

wlen=int(fs*cw_len/1000.00)  # window length in samples (1600 at 8 kHz)
wshift=int(fs*cw_shift/1000.00)  # hop in samples (80 at 8 kHz)




class Classifier(nn.Module):
    """SincNet speaker classifier that scores a waveform by averaging window posteriors.

    The signal is cut into windows of ``wlen`` samples every ``wshift`` samples
    (module-level globals). Each window goes through SincNet -> MLP -> linear
    MLP -> softmax, and the per-window posteriors are averaged into a single
    ``[batch, n_spkr]`` distribution.

    Args:
        n_spkr: size of the final linear layer (number of speakers). Defaults to
            20, the value the saved checkpoints use.
        chunk_stride: when ``use_all_chunks`` is False, only every
            ``chunk_stride``-th window is scored (a cheaper approximation).
    """

    def __init__(self, n_spkr=20, chunk_stride=20):
        super().__init__()
        cnn_arch = {
                'input_dim':wlen,
                'fs':fs,
                'cnn_N_filt':[80,60,60],
                'cnn_len_filt':[251,5,5],
                'cnn_max_pool_len':[3,3,3],
                'cnn_use_laynorm_inp':True,
                'cnn_use_batchnorm_inp':False,
                'cnn_use_laynorm':[True,True,True],
                'cnn_use_batchnorm':[False,False,False],
                'cnn_act':['leaky_relu','leaky_relu','leaky_relu'],
                'cnn_drop':[0.0,0.0,0.0]
                }
        self.cnn_net = SincNet(cnn_arch)

        dnn1_arch = {'input_dim': self.cnn_net.out_dim,
                  'fc_lay': [2048,2048,2048],
                  'fc_drop': [0.0,0.0,0.0],
                  'fc_use_batchnorm': [True,True,True],
                  'fc_use_laynorm': [False,False,False],
                  'fc_use_laynorm_inp': False,
                  'fc_use_batchnorm_inp': False,
                  'fc_act': ['leaky_relu','leaky_relu','leaky_relu']
                  }
        self.dnn1 = MLP(dnn1_arch)

        dnn2_arch = {'input_dim': dnn1_arch['fc_lay'][-1],
                  'fc_lay': [n_spkr],
                  'fc_drop': [0.0],
                  'fc_use_batchnorm': [False],
                  'fc_use_laynorm': [False],
                  'fc_use_laynorm_inp': False,
                  'fc_use_batchnorm_inp': False,
                  'fc_act': ['linear'] # leakyrelu(1) is just identity mapping
                  }
        self.dnn2 = MLP(dnn2_arch)

        self.softmax = nn.Softmax(dim = 1)
        self.chunk_stride = chunk_stride
        # True right after construction (nn.Module starts in training mode).
        # When wrapped in nn.DataParallel, set this on `.module`: an attribute
        # assigned on the wrapper is not seen by the replicas that run forward().
        self.use_all_chunks = self.training

    def chop_chunk(self, signal):
        """Split ``signal`` [batch, T] into a list of [batch, wlen] windows hopped by wshift.

        Produces ``(T - wlen) // wshift`` windows (the SincNet frame count, which
        can leave the last full window unused); the list is empty when
        ``T < wlen + wshift``.
        """
        _, signal_len = signal.shape
        n_fr = (signal_len - wlen) // wshift
        return [signal[..., i * wshift:i * wshift + wlen] for i in range(n_fr)]

    def estimate(self, chunks):
        """Average softmax speaker posteriors over a list of windows.

        Args:
            chunks: non-empty list of [batch, wlen] tensors (see chop_chunk).
        Returns:
            [batch, n_spkr] tensor of mean posteriors. Each posterior is clamped
            to >= 1e-8 before averaging (presumably so torch.log in the
            cross-entropy loss stays finite).
        Raises:
            ValueError: if ``chunks`` is empty (signal shorter than one window plus hop).
        """
        if len(chunks) == 0:
            raise ValueError('signal too short: no analysis windows to classify')
        if not self.use_all_chunks:
            # Cheap mode: score only every chunk_stride-th window.
            chunks = chunks[::self.chunk_stride]
        n_fr = len(chunks)  # count AFTER striding -- it is not a fixed number
        batch_size, win_len = chunks[0].shape
        stacked = torch.stack(chunks, dim=1).view(-1, win_len)  # [batch_size*n_fr, wlen]
        probs = self.softmax(self.dnn2(self.dnn1(self.cnn_net(stacked)))).clamp(min=1e-8)
        # [batch_size*n_fr, n_spkr] -> [batch_size, n_fr, n_spkr] -> mean over windows
        return probs.view(batch_size, n_fr, -1).mean(dim=1)

    def forward(self, signal):
        """Return the [batch, n_spkr] averaged speaker posterior for ``signal`` [batch, T]."""
        return self.estimate(self.chop_chunk(signal))

In [7]:
load_model = True  # set False to train the classifier from scratch

cls = Classifier().cuda(device)
optimizer = torch.optim.Adam(cls.parameters(), 0.0001)

if load_model and os.path.exists('models/sincnet.pth'):
    print('load model')
    checkpoint = torch.load('models/sincnet.pth', map_location=torch.device('cuda', device))
    # A state_dict taken from the nn.DataParallel wrapper has 'module.'-prefixed
    # keys; strip the prefix so both kinds of checkpoint load into the bare model.
    state_dict = {
        (key[len('module.'):] if key.startswith('module.') else key): value
        for key, value in checkpoint['model_state_dict'].items()
    }
    cls.load_state_dict(state_dict)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if 'bestacc' in checkpoint:
        bestacc = checkpoint['bestacc']
cls = nn.DataParallel(cls, device_ids=device_ids)

In [None]:
def cross_entropy(input, target, size_average=True):
    """Cross-entropy between probability rows ``input`` and soft/one-hot ``target``.

    Per sample the loss is ``sum_k -target[k] * log(input[k])`` over dim 1.
    Returns the batch mean when ``size_average`` is truthy, otherwise the batch sum.
    ``input`` must already be probabilities (no softmax is applied here).
    """
    per_sample = (-target * torch.log(input)).sum(dim=1)
    return per_sample.mean() if size_average else per_sample.sum()

batch_size = 32
LOG_EVERY = 200  # mini-batches between progress prints / checkpoints
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=16)

criterion = cross_entropy

for epoch in range(64):
    tasnet.eval()  # separator is frozen; only the classifier is trained
    cls.train()
    # NOTE(review): `cls` is an nn.DataParallel wrapper, so this assignment does
    # not reach the replicated Classifier (it keeps scoring every window).
    # Before switching this to `cls.module.use_all_chunks`, confirm
    # Classifier.estimate handles the shorter strided chunk list.
    cls.use_all_chunks = False
    running_loss = 0.0
    running_accuracy = 0.0
    for batch_idx, (sig1, sig2, target1, target2) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        sig1 = sig1.float().cuda(device)
        sig2 = sig2.float().cuda(device)
        mixture = sig1 + sig2
        target1 = target1.float().cuda(device)
        target2 = target2.float().cuda(device)
        n_batch = sig1.shape[0]  # the last batch can be smaller than batch_size
        with torch.no_grad():
            estimate_source = tasnet(mixture)
            # cal_loss is used only for its reordered estimates, which are paired
            # with target1/target2 below (assumes it permutes them to match the
            # ground-truth source order). Lengths: 2 * sr samples per example.
            lengths = torch.ones(n_batch, dtype=torch.int32).cuda(device) * 2 * sr
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(torch.stack([sig1, sig2], dim=1), estimate_source, lengths)
        pred1 = cls(reorder_estimate_source[:, 0])
        pred2 = cls(reorder_estimate_source[:, 1])
        loss = criterion(pred1, target1) + criterion(pred2, target2)
        loss.backward()
        optimizer.step()

        # Combined prediction: per-speaker max over the two separated sources.
        pred, _ = torch.max(torch.stack([pred1, pred2], dim=0), dim=0)
        running_loss += loss.item()
        running_accuracy += compute_corrects(pred, target1 + target2) / n_batch

        if batch_idx % LOG_EVERY == LOG_EVERY - 1:
            print('[%d, %5d]  loss: %.3f accuracy: %.3f' %
                  (epoch + 1, batch_idx + 1, running_loss / LOG_EVERY, running_accuracy / LOG_EVERY))
            torch.save({
                # Save the unwrapped module so keys carry no 'module.' prefix and
                # load directly into a bare Classifier.
                'model_state_dict': getattr(cls, 'module', cls).state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'acc': running_accuracy / LOG_EVERY,
            }, 'models/sincnet.pth')
            running_loss = 0.0
            running_accuracy = 0.0

In [9]:
def compute_half_correct(tensor1, tensor2):
    """Count rows where exactly one of the top-2 predicted indices is a true label.

    ``tensor1`` holds scores; ``tensor2`` holds the reference labels (assumed
    0/1 multi-hot). A row counts when the labels at the two highest-scoring
    positions of ``tensor1`` sum to exactly 1.
    """
    top2 = find_max2(tensor1)
    labels = tensor2.cpu().detach().numpy()
    return sum(1 for row in range(top2.shape[0]) if sum(labels[row][top2[row]]) == 1)

N_EVAL_RUNS = 5  # partners are re-drawn on every access, so each run scores fresh mixtures
batch_size = 32
# The dataset has no cached pairing, so the CSV only needs to be read once.
testset = OverlayDataSet('test-segments.csv')
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=16)
run_accuracies = []
for run in range(N_EVAL_RUNS):
    with torch.no_grad():
        corrects = 0
        half_corrects = 0
        count = 0
        tasnet.eval()
        cls.eval()
        # Set on the wrapped Classifier: an attribute on the DataParallel wrapper
        # is never seen by the replicas that actually run forward().
        getattr(cls, 'module', cls).use_all_chunks = True
        for sig1, sig2, target1, target2 in tqdm(testloader):
            count += sig1.shape[0]
            sig1 = sig1.float().cuda(device)
            sig2 = sig2.float().cuda(device)
            mixture = sig1 + sig2
            target1 = target1.float().cuda(device)
            target2 = target2.float().cuda(device)

            estimate_source = tasnet(mixture)
            lengths = torch.ones(sig1.shape[0], dtype=torch.int32).cuda(device) * 2 * sr
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(torch.stack([sig1, sig2], dim=1), estimate_source, lengths)
            pred1 = cls(reorder_estimate_source[:, 0])
            pred2 = cls(reorder_estimate_source[:, 1])

            # Max over the two sources is symmetric, so source order does not affect accuracy.
            pred, _ = torch.max(torch.stack([pred1, pred2], dim=0), dim=0)
            corrects += compute_corrects(pred, target1 + target2)
            half_corrects += compute_half_correct(pred, target1 + target2)
        print('test accuracy: %.4f, half corrects: %.4f' % (corrects / count, half_corrects / count))
        run_accuracies.append(corrects / count)
print('mean test accuracy over %d runs: %.4f' % (N_EVAL_RUNS, np.mean(run_accuracies)))

HBox(children=(FloatProgress(value=0.0, max=153.0), HTML(value='')))


test accuracy: 0.9134, half corrects: 0.0833


HBox(children=(FloatProgress(value=0.0, max=153.0), HTML(value='')))


test accuracy: 0.9122, half corrects: 0.0853


HBox(children=(FloatProgress(value=0.0, max=153.0), HTML(value='')))


test accuracy: 0.9061, half corrects: 0.0911


HBox(children=(FloatProgress(value=0.0, max=153.0), HTML(value='')))


test accuracy: 0.9106, half corrects: 0.0868


HBox(children=(FloatProgress(value=0.0, max=153.0), HTML(value='')))


test accuracy: 0.9083, half corrects: 0.0892


In [11]:
# Count MACs / parameters on CPU for a single test mixture.
# NOTE: this moves both models to CPU; call .cuda(device) on them before
# re-running any GPU cell.
sig1, sig2, target1, target2 = testset[0]
input = torch.Tensor(sig1 + sig2)[None, ...]  # [1, T], already on CPU
from thop import profile
test_extractor = tasnet.module  # profile the bare separator, not the DataParallel wrapper
test_extractor.cpu()
macs1, params1 = profile(test_extractor, inputs=(input, ))
print(macs1/10**9, params1/10**6)
# Likewise unwrap the classifier: DataParallel.forward raises once its module's
# parameters are no longer on device_ids[0].
test_discriminator = getattr(cls, 'module', cls)
test_discriminator.cpu()
macs2, params2 = profile(test_discriminator, inputs=(input,))
print(macs2/10**9, params2/10**6)
# Pipeline cost: the separator once, plus the classifier once per separated source (x2).
print((macs1+macs2*2)/10**9, (params1+params2)/10**6)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[91m[WARN] Cannot find rule for <class 'conv_tasnet.Encoder'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'conv_tasnet.ChannelwiseLayerNorm'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.activation.PReLU'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'conv_tasnet.GlobalLayerNorm'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'conv_tasnet.DepthwiseSeparableConv'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'conv_tasnet.TemporalBlock'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'conv_tasnet.TemporalConvNet'>. Treat it as zero Macs a