In [16]:
import sys
from librosa.core import resample
import pandas as pd
import numpy as np
from IPython.display import Audio
import torch
import pathlib
def create_dir(filename):
    """Create the directory part of `filename` (and any missing parents).

    The directory is everything before the last '/' in `filename`; a name
    without any '/' maps to the current directory. Directories that already
    exist are left untouched.
    """
    parent_dir, _, _ = filename.rpartition('/')
    target = pathlib.Path(parent_dir)
    target.mkdir(parents=True, exist_ok=True)
from tqdm.notebook import tqdm
sys.path.append('Conv-TasNet/src/')
sys.path.append('SincNet/')
from conv_tasnet import *
from pit_criterion import cal_loss
from dnn_models import *
from data_io import ReadList,read_conf_inp,str_to_bool
from collections import Counter
import os
# CUDA device index passed to every `.cuda(device)` call in this notebook.
device = 0
# GPU ids for the nn.DataParallel wrappers (those lines are currently commented out).
device_ids = [0, 1, 2, 3]
# Prefix prepended to the CSV and segment-file paths read by OverlayDataSet.
root = '../'
# Audio sample rate in samples per second; also used as the playback rate.
sr = 8000
# Make `device` the current CUDA device.
torch.cuda.set_device(device)

In [17]:
# Build a 3-output Conv-TasNet (C = 3), warm-start it from the pretrained
# checkpoint 'final.pth.tar', then resume from our own checkpoint if one exists.
tas2 = ConvTasNet.load_model('final.pth.tar').cuda(device)
tasnet = ConvTasNet(N = 256, L = 20, B = 256, H = 512, P = 3, X = 8, R = 4, C = 3, norm_type="gLN", causal=0,
             mask_nonlinear='relu').cuda(device)

# Copy every pretrained tensor whose name AND shape match the new model.
# Shapes are compared explicitly instead of relying on a bare `except:` so that
# (a) unrelated errors are no longer silently reported as a shape mismatch and
# (b) `copy_` cannot silently broadcast a differently-shaped tensor.
own_state = tasnet.state_dict()
for name, param in tqdm(tas2.state_dict().items()):
    if name not in own_state:
        continue
    if isinstance(param, torch.nn.Parameter):
        param = param.data
    if own_state[name].shape != param.shape:
        # Keep the freshly initialised weights for this tensor and say which one it was.
        print('shape mismatch', name, tuple(param.shape), '->', tuple(own_state[name].shape))
        continue
    own_state[name].copy_(param)

optimizer = torch.optim.Adam(tasnet.parameters(), lr = 0.001)

# Resume model and optimizer state from a previous run of the training cell.
if os.path.exists('models/tasnet.pth'):
    print('load model')
    checkpoint = torch.load('models/tasnet.pth')
    tasnet.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    loss = checkpoint['loss']

#tasnet = nn.DataParallel(tasnet, device_ids = device_ids)
tasnet.train()
# Presumably here so the value returned by `tasnet.train()` is not echoed as the cell output.
pass

HBox(children=(FloatProgress(value=0.0, max=294.0), HTML(value='')))

shape mismatch

load model


In [18]:
class OverlayDataSet(torch.utils.data.Dataset):
    """Dataset that yields three segments spoken by three different speakers.

    Item `idx` consists of the segment in row `idx` of the CSV plus two
    randomly drawn segments, each from a speaker not used yet, together with
    one one-hot speaker vector per segment. The caller is expected to sum the
    three signals to build the overlapped mixture.

    Args:
        csv: CSV file name, relative to the data root. The file must provide
            the columns 'speaker' and 'segfile'.
        root_dir: path prefix prepended to `csv` and to every 'segfile' entry.
            Defaults to the notebook-level `root`, so existing
            `OverlayDataSet(csv)` calls behave exactly as before.

    NOTE(review): partners are drawn from NumPy's global RNG. When this dataset
    is used with a multi-worker DataLoader, confirm that each worker ends up
    with a different NumPy seed.
    """

    def __init__(self, csv, root_dir=None):
        super().__init__()
        self.root = root if root_dir is None else root_dir
        self.segments = pd.read_csv(self.root + csv)
        # Sorted so the speaker -> index mapping is stable across runs.
        self.speakers = sorted(set(self.segments['speaker']))
        self.spkr2idx = {spkr: i for i, spkr in enumerate(self.speakers)}

    def __len__(self):
        return len(self.segments)

    def _draw_segment(self, taken_speakers):
        """Rejection-sample a random row whose speaker is not in `taken_speakers`."""
        while True:
            candidate = self.segments.iloc[np.random.randint(len(self.segments))]
            if candidate['speaker'] not in taken_speakers:
                return candidate

    def _one_hot(self, speaker):
        """Return a float vector of length len(self.speakers) with a 1 at `speaker`."""
        vec = np.zeros(len(self.speakers))  # maybe try PIT training too
        vec[self.spkr2idx[speaker]] = 1
        return vec

    def __getitem__(self, idx):
        """Return (sig1, sig2, sig3, out_vec1, out_vec2, out_vec3) for row `idx`.

        Raises:
            ValueError: if the CSV holds fewer than three distinct speakers;
                without this guard the rejection sampling would never finish.
        """
        if len(self.speakers) < 3:
            raise ValueError(
                'OverlayDataSet needs at least 3 distinct speakers, found %d'
                % len(self.speakers))

        seg1 = self.segments.iloc[idx]
        seg2 = self._draw_segment((seg1['speaker'],))
        seg3 = self._draw_segment((seg1['speaker'], seg2['speaker']))

        sig1 = np.load(self.root + seg1['segfile'])
        sig2 = np.load(self.root + seg2['segfile'])
        sig3 = np.load(self.root + seg3['segfile'])

        out_vec1 = self._one_hot(seg1['speaker'])
        out_vec2 = self._one_hot(seg2['speaker'])
        out_vec3 = self._one_hot(seg3['speaker'])

        return sig1, sig2, sig3, out_vec1, out_vec2, out_vec3


#mean, std = compute_mean_std('overlay-train.csv')


# One dataset per split; each item holds three single-speaker segments.
trainset = OverlayDataSet('train-segments.csv')
valset = OverlayDataSet('val-segments.csv')
testset = OverlayDataSet('test-segments.csv')

# Sanity check: separate one random test mixture and listen to the outputs.
idx = np.random.randint(len(testset))
sig1, sig2, sig3, target1, target2, target3 = testset[idx]

mixture = torch.Tensor(sig1+sig2+sig3).cuda(device)
mixture = mixture[None, ...]  # add a leading batch axis
# Inference only: skip building an autograd graph that would be thrown away.
with torch.no_grad():
    out = tasnet(mixture)
# The three-way unpacking relies on the separator producing three sources.
new_sig1, new_sig2, new_sig3 = out[0].cpu().detach().numpy()
Audio(new_sig1, rate = sr)

In [19]:
# Play back the second separated source at the notebook sample rate.
Audio(new_sig2, rate = sr)

In [20]:
# Play back the third separated source at the notebook sample rate.
Audio(new_sig3, rate = sr)

In [15]:
# Fine-tune the separator on three-speaker mixtures using the loss returned by cal_loss.
batch_size = 8
# The lengths handed to cal_loss assume 2-second segments: 2 seconds has 2*sr samples.
segment_len = 2*sr
trainloader  = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 16)
# Built once: the loader does not change between epochs. Validation needs no
# gradients, so it uses a larger batch than training.
valloader  = torch.utils.data.DataLoader(valset, batch_size=32, shuffle=True, pin_memory = True, num_workers = 16)

for epoch in range(64):
    # Validation below switches to eval mode, so re-enter train mode every epoch.
    tasnet.train()
    running_loss = 0.0
    for batch_idx, (sig1, sig2, sig3, _, _, _) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        sig1, sig2, sig3 = sig1.float().cuda(device), sig2.float().cuda(device), sig3.float().cuda(device)
        out = tasnet(sig1+sig2+sig3)
        source = torch.stack([sig1, sig2, sig3], dim = 1).detach()
        source_lengths = torch.ones(sig1.shape[0], dtype = torch.int32).cuda(device)*segment_len
        loss, max_snr, estimate_source, reorder_estimate_source = \
            cal_loss(source, out, source_lengths)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(tasnet.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 200 == 199:    # print every 200 mini-batches
            print('[%d, %5d] loss: %.3f ' %
                  (epoch + 1, batch_idx + 1, running_loss / 200))
            running_loss = 0.0
            # Store a plain float rather than the CUDA loss tensor.
            torch.save({
            'model_state_dict': tasnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item()
            }, 'models/tasnet.pth')

    # Validation: sample-weighted mean loss, no gradient work at all.
    tasnet.eval()
    with torch.no_grad():
        val_loss_sum = 0.0
        for batch_idx, (sig1, sig2, sig3, _, _, _) in enumerate(tqdm(valloader)):
            sig1, sig2, sig3 = sig1.float().cuda(device), sig2.float().cuda(device), sig3.float().cuda(device)
            out = tasnet(sig1+sig2+sig3)
            source = torch.stack([sig1, sig2, sig3], dim = 1).detach()
            source_lengths = torch.ones(sig1.shape[0], dtype = torch.int32).cuda(device)*segment_len
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(source, out, source_lengths)

            # Weight by the real batch size so a short final batch is counted correctly.
            val_loss_sum += loss.item()*sig1.shape[0]
        print(val_loss_sum/len(valset))

HBox(children=(FloatProgress(value=0.0, max=4803.0), HTML(value='')))

[1,   200] loss: -0.307 
[1,   400] loss: -0.739 
[1,   600] loss: -1.386 
[1,   800] loss: -1.587 
[1,  1000] loss: -1.818 
[1,  1200] loss: -1.979 
[1,  1400] loss: -2.426 
[1,  1600] loss: -2.185 
[1,  1800] loss: -2.419 
[1,  2000] loss: -2.258 
[1,  2200] loss: -2.564 
[1,  2400] loss: -2.780 
[1,  2600] loss: -2.796 
[1,  2800] loss: -2.841 
[1,  3000] loss: -3.021 
[1,  3200] loss: -3.028 
[1,  3400] loss: -2.900 
[1,  3600] loss: -2.972 
[1,  3800] loss: -2.838 
[1,  4000] loss: -3.245 
[1,  4200] loss: -3.491 
[1,  4400] loss: -3.342 
[1,  4600] loss: -3.281 
[1,  4800] loss: -3.584 



HBox(children=(FloatProgress(value=0.0, max=151.0), HTML(value='')))


-3.366440252639243


HBox(children=(FloatProgress(value=0.0, max=4803.0), HTML(value='')))

[2,   200] loss: -3.655 
[2,   400] loss: -3.307 
[2,   600] loss: -3.760 
[2,   800] loss: -3.655 
[2,  1000] loss: -3.709 
[2,  1200] loss: -3.579 
[2,  1400] loss: -3.944 
[2,  1600] loss: -3.678 
[2,  1800] loss: -3.826 
[2,  2000] loss: -3.730 
[2,  2200] loss: -3.686 
[2,  2400] loss: -3.871 
[2,  2600] loss: -4.037 
[2,  2800] loss: -3.982 
[2,  3000] loss: -4.128 
[2,  3200] loss: -4.055 
[2,  3400] loss: -3.889 
[2,  3600] loss: -4.012 
[2,  3800] loss: -3.778 
[2,  4000] loss: -4.126 
[2,  4200] loss: -4.550 
[2,  4400] loss: -4.174 
[2,  4600] loss: -4.071 
[2,  4800] loss: -4.450 



HBox(children=(FloatProgress(value=0.0, max=151.0), HTML(value='')))


-4.054787210591158


HBox(children=(FloatProgress(value=0.0, max=4803.0), HTML(value='')))

[3,   200] loss: -4.546 
[3,   400] loss: -4.148 
[3,   600] loss: -4.624 
[3,   800] loss: -4.511 
[3,  1000] loss: -4.264 
[3,  1200] loss: -4.473 
[3,  1400] loss: -4.602 
[3,  1600] loss: -4.425 
[3,  1800] loss: -4.372 
[3,  2000] loss: -4.010 
[3,  2200] loss: -4.366 
[3,  2400] loss: -4.883 
[3,  2600] loss: -4.670 
[3,  2800] loss: -4.616 
[3,  3000] loss: -4.641 
[3,  3200] loss: -4.708 
[3,  3400] loss: -4.503 
[3,  3600] loss: -4.625 
[3,  3800] loss: -4.528 
[3,  4000] loss: -4.765 
[3,  4200] loss: -4.949 
[3,  4400] loss: -4.765 
[3,  4600] loss: -4.622 
[3,  4800] loss: -4.810 



HBox(children=(FloatProgress(value=0.0, max=151.0), HTML(value='')))


-4.627735663727376


HBox(children=(FloatProgress(value=0.0, max=4803.0), HTML(value='')))

[4,   200] loss: -5.036 
[4,   400] loss: -4.618 
[4,   600] loss: -5.004 
[4,   800] loss: -4.851 
[4,  1000] loss: -4.790 
[4,  1200] loss: -4.795 
[4,  1400] loss: -5.090 
[4,  1600] loss: -4.770 
[4,  1800] loss: -4.853 
[4,  2000] loss: -4.530 
[4,  2200] loss: -4.899 
[4,  2400] loss: -5.101 
[4,  2600] loss: -5.027 
[4,  2800] loss: -5.042 
[4,  3000] loss: -5.033 
[4,  3200] loss: -4.990 
[4,  3400] loss: -4.959 
[4,  3600] loss: -4.961 
[4,  3800] loss: -4.842 
[4,  4000] loss: -4.978 
[4,  4200] loss: -5.464 
[4,  4400] loss: -4.935 
[4,  4600] loss: -4.946 
[4,  4800] loss: -5.226 



HBox(children=(FloatProgress(value=0.0, max=151.0), HTML(value='')))


-4.883883290704826


HBox(children=(FloatProgress(value=0.0, max=4803.0), HTML(value='')))

[5,   200] loss: -5.367 
[5,   400] loss: -5.025 
[5,   600] loss: -5.350 
[5,   800] loss: -5.199 
[5,  1000] loss: -5.144 
[5,  1200] loss: -5.175 
[5,  1400] loss: -5.318 
[5,  1600] loss: -5.147 
[5,  1800] loss: -5.255 
[5,  2000] loss: -4.958 
[5,  2200] loss: -5.266 
[5,  2400] loss: -5.454 
[5,  2600] loss: -5.300 
[5,  2800] loss: -5.401 
[5,  3000] loss: -5.384 
[5,  3200] loss: -5.356 
[5,  3400] loss: -5.131 
[5,  3600] loss: -5.354 
[5,  3800] loss: -5.131 
[5,  4000] loss: -5.381 
[5,  4200] loss: -5.773 
[5,  4400] loss: -5.321 
[5,  4600] loss: -5.208 
[5,  4800] loss: -5.504 



HBox(children=(FloatProgress(value=0.0, max=151.0), HTML(value='')))


-5.173843941737184


HBox(children=(FloatProgress(value=0.0, max=4803.0), HTML(value='')))

[6,   200] loss: -5.566 
[6,   400] loss: -5.407 
[6,   600] loss: -5.728 
[6,   800] loss: -5.566 
[6,  1000] loss: -5.329 
[6,  1200] loss: -5.414 
[6,  1400] loss: -5.722 
[6,  1600] loss: -5.396 
[6,  1800] loss: -5.515 
[6,  2000] loss: -5.166 
[6,  2200] loss: -5.471 
[6,  2400] loss: -5.787 
[6,  2600] loss: -5.498 
[6,  2800] loss: -5.504 
[6,  3000] loss: -5.588 
[6,  3200] loss: -5.606 
[6,  3400] loss: -5.364 
[6,  3600] loss: -5.486 
[6,  3800] loss: -5.345 
[6,  4000] loss: -5.635 
[6,  4200] loss: -5.945 
[6,  4400] loss: -5.561 
[6,  4600] loss: -5.488 
[6,  4800] loss: -5.761 



HBox(children=(FloatProgress(value=0.0, max=151.0), HTML(value='')))


-5.494492834418013


HBox(children=(FloatProgress(value=0.0, max=4803.0), HTML(value='')))

[7,   200] loss: -5.943 
[7,   400] loss: -5.552 
[7,   600] loss: -5.942 
[7,   800] loss: -5.741 
[7,  1000] loss: -5.445 
[7,  1200] loss: -5.582 
[7,  1400] loss: -5.813 
[7,  1600] loss: -5.557 
[7,  1800] loss: -5.604 
[7,  2000] loss: -5.424 
[7,  2200] loss: -5.792 
[7,  2400] loss: -5.894 
[7,  2600] loss: -5.831 
[7,  2800] loss: -5.699 
[7,  3000] loss: -5.828 
[7,  3200] loss: -5.714 
[7,  3400] loss: -5.816 
[7,  3600] loss: -5.673 
[7,  3800] loss: -5.608 
[7,  4000] loss: -5.757 
[7,  4200] loss: -6.027 
[7,  4400] loss: -5.666 
[7,  4600] loss: -5.617 
[7,  4800] loss: -6.027 



HBox(children=(FloatProgress(value=0.0, max=151.0), HTML(value='')))


-5.503629755333466


HBox(children=(FloatProgress(value=0.0, max=4803.0), HTML(value='')))

[8,   200] loss: -6.031 
[8,   400] loss: -5.632 
[8,   600] loss: -6.117 



KeyboardInterrupt: 

In [21]:
def find_max3(tensor):
    """Return, for every row of `tensor`, the indices of its three largest values.

    The tensor is moved to the CPU, detached and converted to NumPy first.
    Indices within a row are ordered from the largest value downwards.
    """
    rows = tensor.cpu().detach().numpy()
    # Ascending argsort, flipped to descending, then cut to the first three.
    top_indices = [np.flip(np.argsort(values))[:3] for values in rows]
    return np.array(top_indices)

def compute_corrects(tensor1, tensor2):
    """Count the rows for which both tensors share the same top-3 indices.

    The order of the three indices inside a row is ignored: the two index
    triples are compared as multisets.
    """
    top_a = find_max3(tensor1)
    top_b = find_max3(tensor2)
    n_rows = top_a.shape[0]
    return sum(1 for row in range(n_rows)
               if Counter(top_a[row]) == Counter(top_b[row]))

In [29]:
# Framing configuration for the speaker classifier.
# Sample rate, taken from the notebook-level `sr`.
fs=sr
# Length of one chunk in milliseconds.
cw_len=200
# Shift between consecutive chunks in milliseconds.
cw_shift=10

# The same two quantities converted from milliseconds to samples
# (1600 and 80 samples when sr is 8000).
wlen=int(fs*cw_len/1000.00)
wshift=int(fs*cw_shift/1000.00)




class Classifier(nn.Module):
    """Chunk-based speaker classifier: SincNet front end followed by two MLPs.

    A waveform is cut into overlapping chunks, every chunk is classified on
    its own, and the per-chunk softmax outputs are averaged into one score
    vector per waveform.

    Args:
        num_spkr: size of the output layer. Defaults to 20, the value that
            used to be hard-coded, so `Classifier()` is unchanged.

    Attributes:
        use_all_chunks: when False, only every `chunk_stride`-th chunk is
            scored. The notebook cells set this flag explicitly.
    """

    # Sub-sampling factor applied to the chunk list when `use_all_chunks` is False.
    chunk_stride = 20

    def __init__(self, num_spkr=20):
        super().__init__()
        # Snapshot the framing configuration so later edits to the notebook
        # globals cannot disagree with the input size the network was built for.
        self.wlen = wlen
        self.wshift = wshift
        self.num_spkr = num_spkr

        cnn_arch = {
                'input_dim':self.wlen,
                'fs':fs,
                'cnn_N_filt':[80,60,60],
                'cnn_len_filt':[251,5,5],
                'cnn_max_pool_len':[3,3,3],
                'cnn_use_laynorm_inp':True,
                'cnn_use_batchnorm_inp':False,
                'cnn_use_laynorm':[True,True,True],
                'cnn_use_batchnorm':[False,False,False],
                'cnn_act':['leaky_relu','leaky_relu','leaky_relu'],
                'cnn_drop':[0.0,0.0,0.0]
                }
        self.cnn_net = SincNet(cnn_arch)

        dnn1_arch = {'input_dim': self.cnn_net.out_dim,
                  'fc_lay': [2048,2048,2048],
                  'fc_drop': [0.0,0.0,0.0],
                  'fc_use_batchnorm': [True,True,True],
                  'fc_use_laynorm': [False,False,False],
                  'fc_use_laynorm_inp': False,
                  'fc_use_batchnorm_inp': False,
                  'fc_act': ['leaky_relu','leaky_relu','leaky_relu']
                  }
        self.dnn1 = MLP(dnn1_arch)

        dnn2_arch = {'input_dim':2048 ,
                  'fc_lay': [num_spkr],
                  'fc_drop': [0.0],
                  'fc_use_batchnorm': [False],
                  'fc_use_laynorm': [False],
                  'fc_use_laynorm_inp': False,
                  'fc_use_batchnorm_inp': False,
                  'fc_act': ['linear'] # leakyrelu(1) is just identity mapping
                  }
        self.dnn2 = MLP(dnn2_arch)

        self.softmax = nn.Softmax(dim = 1)
        # nn.Module starts in training mode, so this flag starts out True.
        self.use_all_chunks = self.training

    def chop_chunk(self, signal):
        """Slice `signal` ([batch_size, signal_len]) into overlapping chunks.

        Returns a list of (signal_len - wlen) // wshift tensors, each of shape
        [batch_size, wlen]. The list is empty when the signal is too short.
        """
        batch_size, signal_len = signal.shape
        n_fr = (signal_len - self.wlen) // self.wshift
        starts = range(0, n_fr * self.wshift, self.wshift)
        return [signal[..., start:start + self.wlen] for start in starts]

    def estimate(self, chunks):
        """Average the per-chunk class probabilities.

        Args:
            chunks: non-empty list of [batch_size, wlen] tensors.

        Returns:
            Tensor of shape [batch_size, num_spkr].

        Raises:
            ValueError: if `chunks` is empty, i.e. the signal was too short to
                yield a single chunk.
        """
        if not chunks:
            raise ValueError(
                'no chunks to classify: the signal must be longer than '
                'wlen + wshift = %d samples' % (self.wlen + self.wshift))
        batch_size, chunk_len = chunks[0].shape
        if not self.use_all_chunks: # ~200 chunk, each shift by 10ms, length 200 ms
            chunks = chunks[::self.chunk_stride]
        n_fr = len(chunks)
        stacked = torch.stack(chunks, dim = 1) # [batch_size , n_fr , wlen]
        stacked = stacked.view(-1, chunk_len) # [batch_size*n_fr, wlen]
        # The clamp keeps every probability strictly positive, which keeps a later log() finite.
        probs = self.softmax(self.dnn2(self.dnn1(self.cnn_net(stacked)))).clamp(min=1e-8) # [batch_size*n_fr, num_spkr]
        probs = probs.view(batch_size, n_fr, -1) # [batch_size, n_fr, num_spkr]
        return probs.mean(dim = 1) # [batch_size, num_spkr]

    def forward(self, signal):
        """Chunk `signal` and return the averaged class probabilities."""
        return self.estimate(self.chop_chunk(signal))

In [30]:
# Set to False to start the classifier from scratch instead of restoring
# 'models/sincnet.pth'. (The flag used to be defined but never consulted.)
load_model = True

cls = Classifier().cuda(device)
# NOTE(review): this rebinds the notebook-level `optimizer`, which the Conv-TasNet
# training cell also uses; re-run the Conv-TasNet setup cell before training the
# separator again.
optimizer = torch.optim.Adam(cls.parameters(), 0.0001)

if load_model:
    checkpoint = torch.load('models/sincnet.pth')
    cls.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # 'bestacc' is optional; the training cell in this notebook stores 'acc' instead.
    if 'bestacc' in checkpoint:
        bestacc = checkpoint['bestacc']
#cls = nn.DataParallel(cls, device_ids = device_ids)

In [31]:
def cross_entropy(input, target, size_average=True):
    """Cross entropy between predicted probabilities and target distributions.

    Args:
        input: probabilities (not logits); the log is taken here.
        target: target weights with the same shape as `input`.
        size_average: if True return the mean of the per-row losses,
            otherwise their sum.
    """
    per_row = torch.sum(-target * torch.log(input), dim=1)
    reduce = torch.mean if size_average else torch.sum
    return reduce(per_row)

# Train the speaker classifier on the outputs of the frozen separator.
batch_size = 32
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 16)

criterion = cross_entropy
log_every = 200  # mini-batches between progress reports / checkpoints

for epoch in range(64):
    tasnet.eval()   # the separator is used for inference only
    cls.train()
    cls.use_all_chunks = False  # score only a subset of the chunks while training
    running_loss = 0.0
    running_accuracy = 0.0
    for batch_idx, (sig1, sig2, sig3, target1, target2, target3) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()
        sig1 = sig1.float().cuda(device)
        sig2 = sig2.float().cuda(device)
        sig3 = sig3.float().cuda(device)

        mixture = sig1+sig2+sig3
        target1 = target1.float().cuda(device)
        target2 = target2.float().cuda(device)
        target3 = target3.float().cuda(device)

        # Separate without tracking gradients. cal_loss is called only for the
        # reordered estimates it returns (presumably matched to sig1..sig3 order).
        with torch.no_grad():
            estimate_source = tasnet(mixture).detach()
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(torch.stack([sig1, sig2, sig3], dim = 1), estimate_source, torch.ones(sig1.shape[0], dtype = torch.int32).cuda(device)*2*sr)
        pred1, pred2, pred3 = cls(reorder_estimate_source[:, 0]), cls(reorder_estimate_source[:, 1]), cls(reorder_estimate_source[:, 2])
        loss = criterion(pred1, target1)+criterion(pred2, target2)+criterion(pred3, target3)
        loss.backward()
        optimizer.step()

        # A mixture counts as correct when the top-3 entries of the element-wise
        # maximum over the three predictions are exactly the three true speakers.
        pred = torch.stack([pred1, pred2, pred3], dim = 0)
        pred, _ = torch.max(pred, dim = 0)
        running_loss += loss.item()
        # Divide by the real batch size: the last batch of an epoch can be smaller.
        running_accuracy += compute_corrects(pred, target1+target2+target3)/sig1.shape[0]

        if batch_idx % log_every == log_every - 1:    # print every 200 mini-batches
            print('[%d, %5d]  loss: %.3f accuracy: %.3f' %
                  (epoch + 1, batch_idx + 1, running_loss / log_every, running_accuracy / log_every))
            # Store the mean accuracy of the window, not the raw sum.
            torch.save({
            'model_state_dict': cls.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'acc': running_accuracy / log_every,
            }, 'models/sincnet.pth')
            running_loss = 0.0
            running_accuracy = 0.0

HBox(children=(FloatProgress(value=0.0, max=1201.0), HTML(value='')))

[1,   200]  loss: 6.575 accuracy: 0.399
[1,   400]  loss: 5.813 accuracy: 0.486
[1,   600]  loss: 5.990 accuracy: 0.496
[1,   800]  loss: 5.899 accuracy: 0.535
[1,  1000]  loss: 5.847 accuracy: 0.507
[1,  1200]  loss: 5.817 accuracy: 0.555



HBox(children=(FloatProgress(value=0.0, max=1201.0), HTML(value='')))

[2,   200]  loss: 5.780 accuracy: 0.553
[2,   400]  loss: 5.682 accuracy: 0.562
[2,   600]  loss: 5.735 accuracy: 0.570
[2,   800]  loss: 5.721 accuracy: 0.583
[2,  1000]  loss: 5.656 accuracy: 0.551
[2,  1200]  loss: 5.604 accuracy: 0.583



HBox(children=(FloatProgress(value=0.0, max=1201.0), HTML(value='')))

[3,   200]  loss: 5.664 accuracy: 0.573
[3,   400]  loss: 5.568 accuracy: 0.583
[3,   600]  loss: 5.798 accuracy: 0.569
[3,   800]  loss: 5.580 accuracy: 0.585
[3,  1000]  loss: 5.640 accuracy: 0.569
[3,  1200]  loss: 5.707 accuracy: 0.592



HBox(children=(FloatProgress(value=0.0, max=1201.0), HTML(value='')))

[4,   200]  loss: 5.537 accuracy: 0.601
[4,   400]  loss: 5.520 accuracy: 0.605
[4,   600]  loss: 5.629 accuracy: 0.597
[4,   800]  loss: 5.524 accuracy: 0.603
[4,  1000]  loss: 5.479 accuracy: 0.604
[4,  1200]  loss: 5.611 accuracy: 0.615



HBox(children=(FloatProgress(value=0.0, max=1201.0), HTML(value='')))

[5,   200]  loss: 5.600 accuracy: 0.610
[5,   400]  loss: 5.462 accuracy: 0.610
[5,   600]  loss: 5.581 accuracy: 0.624
[5,   800]  loss: 5.544 accuracy: 0.602
[5,  1000]  loss: 5.463 accuracy: 0.613
[5,  1200]  loss: 5.612 accuracy: 0.620



HBox(children=(FloatProgress(value=0.0, max=1201.0), HTML(value='')))

[6,   200]  loss: 5.478 accuracy: 0.620
[6,   400]  loss: 5.504 accuracy: 0.618
[6,   600]  loss: 5.566 accuracy: 0.636
[6,   800]  loss: 5.420 accuracy: 0.636
[6,  1000]  loss: 5.401 accuracy: 0.632
[6,  1200]  loss: 5.540 accuracy: 0.641



HBox(children=(FloatProgress(value=0.0, max=1201.0), HTML(value='')))

[7,   200]  loss: 5.321 accuracy: 0.645
[7,   400]  loss: 5.403 accuracy: 0.636
[7,   600]  loss: 5.480 accuracy: 0.649
[7,   800]  loss: 5.448 accuracy: 0.641
[7,  1000]  loss: 5.438 accuracy: 0.622
[7,  1200]  loss: 5.456 accuracy: 0.637



HBox(children=(FloatProgress(value=0.0, max=1201.0), HTML(value='')))

[8,   200]  loss: 5.437 accuracy: 0.651
[8,   400]  loss: 5.298 accuracy: 0.653
[8,   600]  loss: 5.535 accuracy: 0.644
[8,   800]  loss: 5.420 accuracy: 0.647
[8,  1000]  loss: 5.336 accuracy: 0.646
[8,  1200]  loss: 5.392 accuracy: 0.661



HBox(children=(FloatProgress(value=0.0, max=1201.0), HTML(value='')))

[9,   200]  loss: 5.280 accuracy: 0.658
[9,   400]  loss: 5.344 accuracy: 0.653
[9,   600]  loss: 5.504 accuracy: 0.655
[9,   800]  loss: 5.401 accuracy: 0.651
[9,  1000]  loss: 5.260 accuracy: 0.657
[9,  1200]  loss: 5.395 accuracy: 0.651



HBox(children=(FloatProgress(value=0.0, max=1201.0), HTML(value='')))

[10,   200]  loss: 5.342 accuracy: 0.655
[10,   400]  loss: 5.293 accuracy: 0.651
[10,   600]  loss: 5.437 accuracy: 0.652



KeyboardInterrupt: 

In [32]:
# Evaluate the full pipeline (separator + classifier) on the test split.
# The dataset and loader do not change between repetitions, so build them once.
batch_size = 32
testset = OverlayDataSet('test-segments.csv')
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, pin_memory = True, num_workers = 16)

tasnet.eval()
cls.eval()
cls.use_all_chunks = True  # score every chunk at test time

# The two partner segments of each mixture are drawn at random, so the
# evaluation is repeated to show the spread.
for i in range(5):
    corrects = 0
    with torch.no_grad():
        for batch_idx, (sig1, sig2, sig3, target1, target2, target3) in enumerate(tqdm(testloader)):
            sig1 = sig1.float().cuda(device)
            sig2 = sig2.float().cuda(device)
            sig3 = sig3.float().cuda(device)
            mixture = sig1+sig2+sig3
            target1 = target1.float().cuda(device)
            target2 = target2.float().cuda(device)
            target3 = target3.float().cuda(device)

            estimate_source = tasnet(mixture).detach()
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(torch.stack([sig1, sig2, sig3], dim = 1), estimate_source, torch.ones(sig1.shape[0], dtype = torch.int32).cuda(device)*2*sr)
            pred1, pred2, pred3 = cls(reorder_estimate_source[:, 0]), cls(reorder_estimate_source[:, 1]), cls(reorder_estimate_source[:, 2])

            pred = torch.stack([pred1, pred2, pred3], dim = 0)
            pred, _ = torch.max(pred, dim = 0)
            corrects += compute_corrects(pred, target1+target2+target3)
    print('test accuracy: %.4f' % (corrects/len(testset)))

HBox(children=(FloatProgress(value=0.0, max=153.0), HTML(value='')))


test accuracy: 0.5794


HBox(children=(FloatProgress(value=0.0, max=153.0), HTML(value='')))


test accuracy: 0.5778


HBox(children=(FloatProgress(value=0.0, max=153.0), HTML(value='')))


test accuracy: 0.5866


HBox(children=(FloatProgress(value=0.0, max=153.0), HTML(value='')))


test accuracy: 0.5872


HBox(children=(FloatProgress(value=0.0, max=153.0), HTML(value='')))


test accuracy: 0.5767


In [11]:
# Count multiply-accumulates and parameters of the separator and the classifier with thop.
from thop import profile

sig1, sig2, sig3, target1, target2, target3 = testset[0]
# Named `profile_input` so the builtin `input` is not shadowed for the rest of the session.
profile_input = torch.Tensor(sig1+sig2+sig3)[None, ...].cpu()

# `tasnet` only has a `.module` attribute when it is wrapped in nn.DataParallel,
# which is currently commented out; accept both forms.
test_extractor = getattr(tasnet, 'module', tasnet)
test_discriminator = cls

try:
    test_extractor.cpu()
    macs1, params1 = profile(test_extractor, inputs=(profile_input, ))
    print(macs1/10**9, params1/10**6)

    test_discriminator.cpu()
    macs2, params2 = profile(test_discriminator, inputs=(profile_input,))
    print(macs2/10**9, params2/10**6)
finally:
    # Module.cpu() moves the live models in place; put them back on the GPU so
    # the training and evaluation cells keep working after profiling.
    test_extractor.cuda(device)
    test_discriminator.cuda(device)

# Cost of one full pass: the separator runs once and the classifier once per
# separated source; the training and evaluation cells call it three times
# (pred1, pred2, pred3).
print((macs1+macs2*3)/10**9, (params1+params2)/10**6)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[91m[WARN] Cannot find rule for <class 'conv_tasnet.Encoder'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'conv_tasnet.ChannelwiseLayerNorm'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.activation.PReLU'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'conv_tasnet.GlobalLayerNorm'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'conv_tasnet.DepthwiseSeparableConv'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'conv_tasnet.TemporalBlock'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class 'conv_tasnet.TemporalConvNet'>. Treat it as zero Macs a