In [None]:
!pip install -q torch-audiomentations
!pip install -q torchaudio
!pip install -q julius
!pip install -q pesq

In [1]:
import numpy as np
import os
from glob import glob
from IPython.display import Audio
import torch
import matplotlib.pyplot as plt

import librosa
import librosa.display

In [2]:
speak_files=glob("../../datas/cv-corpus-9.0-2022-04-27/zh-TW/clips/*.mp3")
firetruck_files=glob("../../datas/sounds/firetruck/*.wav")
construction_files=glob("../../datas/sounds/splited_construction/*.wav")
print(
    "#人聲:",len(speak_files),
    "\n#消防車聲:",len(firetruck_files),
    "\n#工地聲:",len(construction_files)
)

#人聲: 116969 
#消防車聲: 200 
#工地聲: 11997


In [3]:
def set_seed(seed=2022):
    np.random.seed(seed)
    torch.manual_seed(seed)

In [4]:
from sklearn.model_selection import train_test_split
def split(x,test_val_ratio):
    x0,x1=train_test_split(x,test_size=sum(test_val_ratio))
    x1,x2=train_test_split(x1,test_size=test_val_ratio[-1])
    return x0,x1,x2
set_seed(2022)
sig_train,sig_val,sig_test=split(speak_files,(0.1,0.1))
art_train,art_val,art_test=split(firetruck_files,(0.1,0.1))
noise_train,noise_val,noise_test=split(construction_files,(0.1,0.1))

In [5]:
SEED=28
SAMPLE_RATE=16000
BATCH_SIZE=2

In [6]:
import torch.utils.data as tud
import sirenns.datasets.loader as ldr
L=16000*4
common_kwargs=dict(
    signal_len=L,
    transform=ldr.transform,
    artifact_transform=ldr.artifact_noise_transform,
    noise_transform=ldr.artifact_noise_transform
)

train_ds=ldr.SyntheticCallDataset(sig_train,
                              artifact_files=art_train,
                              noise_files=noise_train,
                              **common_kwargs)
train_dl=tud.DataLoader(train_ds,batch_size=BATCH_SIZE,collate_fn=train_ds.collate_fn)

val_ds=ldr.SyntheticCallDataset(sig_val,
                              artifact_files=art_val,
                              noise_files=noise_val,
                              **common_kwargs)
val_dl=tud.DataLoader(val_ds,batch_size=BATCH_SIZE,collate_fn=val_ds.collate_fn)

In [7]:
%%time
for  batch in val_dl:
    x,signal,artifact,noise= batch
    break

CPU times: user 179 ms, sys: 34.4 ms, total: 214 ms
Wall time: 115 ms


In [8]:
# BATCH_SEED=3
# Audio(x[BATCH_SEED].cpu() ,rate=8000,normalize=False)
# Audio(signal[BATCH_SEED].cpu() ,rate=16000,normalize=False)
# Audio(artifact[BATCH_SEED].cpu() ,rate=16000,normalize=False)
# Audio(noise[BATCH_SEED].cpu() ,rate=16000,normalize=False)
# plt.plot(artifact[BATCH_SEED,0].cpu().numpy().T,alpha=0.9)
# plt.plot(noise[BATCH_SEED,0].cpu().numpy().T,alpha=0.7)
# plt.plot(signal[BATCH_SEED,0].cpu().numpy().T,alpha=0.7)

In [9]:
from torchaudio.models import ConvTasNet
import torch.nn as nn
import sirenns.utils.losses as lsf 
net=ConvTasNet(num_sources=3,
               enc_num_feats=256,
               msk_num_hidden_feats=128
              ).cuda()
# Loss cocmbo 1: MSE+SGD
# criterion = nn.MSELoss()
# optimizer = torch.optim.SGD(net.parameters(),lr=1e-3,momentum=0.9,weight_decay=0.0005)
# Loss cocmbo 2: SDR+clip norm+Adam
criterion = lambda y,pred: lsf.cal_loss(y,pred,float(L))
# Clip Norm
for p in net.parameters():
    p.register_hook(lambda grad: torch.clamp(grad, -5, 5))
optimizer = torch.optim.Adam(net.parameters(),lr=1e-3)

In [10]:
import torch.nn.functional as F
lamb=[0.8,0.1,0.1]
def one_batch(i_iter,log,sample_batched,model,criterion,optimizer):
    # Prep input
    x,signal,artifact,noise=[_ for _ in sample_batched]
    x=F.interpolate(x,scale_factor=2,mode="linear")
    pred=model(x)
    loss = criterion(pred[:,0:1], signal)*lamb[0]+\
           criterion(pred[:,1:2], artifact)*lamb[1]+\
           criterion(pred[:,2:3], noise)*lamb[2]
    if model.training:
        #Update
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    with torch.no_grad():
        loss_sig=criterion(pred[:,0:1], signal).item()
        if i_iter==0:
            log['loss_sig'] = loss_sig
        else:
            log['loss_sig'] = (loss_sig+log['loss_sig']*i_iter)/(i_iter + 1)
    # Record
    loss_rec = loss.item()
    log['loss'] = (loss_rec+log['loss']*i_iter)/(i_iter + 1)
    return pred,loss

In [11]:
from tqdm import tqdm
os.makedirs("../snapshots",exist_ok=True)

In [None]:
EPOCH=100
net.train()
best_loss=np.inf
PAITIENCE=5
count=0
try:
    for e in range(EPOCH):
        log= {'epoch':e,'step':0,'loss': 0, 'loss_sig': 0}
        net.train()
        session=tqdm(enumerate(train_dl))
        for i_iter, sample_batched in session:
            pred,loss=one_batch(i_iter,log,sample_batched,net,criterion,optimizer)
            # print loss and take snapshots
            if (i_iter + 1) % 5 == 0:
                log['step']=i_iter+1
                session.set_postfix(log)
        # validate
        log = {'epoch':e,'step':'val','loss': 0, 'loss_sig': 0}
        net.eval()
        with torch.no_grad():
            session=tqdm(enumerate(val_dl))
            for i_iter,sample_batched in session:
                
                pred,loss=one_batch(i_iter,log,sample_batched,net,criterion,optimizer)
                if (i_iter + 1) % 5 == 0:
                    session.set_postfix(log)
        # early stop        
        if log["loss_sig"]<best_loss:
            best_loss=log["loss_sig"]
            torch.save(net.state_dict(), '../snapshots/best.pth')
        elif count<=PAITIENCE: count+=1
        else:
            count=0
            best_loss=np.inf
            break 
            
except KeyboardInterrupt:
    print("\nHuman Interrupted")
torch.save(net.state_dict(), '../snapshots/latest.pth')

46788it [2:47:14,  4.66it/s, epoch=0, step=46785, loss=-6.82, loss_sig=-8.78]
10527it [21:46,  8.06it/s, epoch=0, step=val, loss=-7.47, loss_sig=-9.42]
46788it [2:47:26,  4.66it/s, epoch=1, step=46785, loss=-7.64, loss_sig=-9.55]
10527it [21:47,  8.05it/s, epoch=1, step=val, loss=-7.87, loss_sig=-9.8] 
46788it [2:47:43,  4.65it/s, epoch=2, step=46785, loss=-7.95, loss_sig=-9.84]
10527it [21:46,  8.06it/s, epoch=2, step=val, loss=-8.12, loss_sig=-10] 
46788it [2:47:40,  4.65it/s, epoch=3, step=46785, loss=-8.16, loss_sig=-10]  
10527it [21:44,  8.07it/s, epoch=3, step=val, loss=-8.18, loss_sig=-10.1]
46788it [2:47:32,  4.65it/s, epoch=4, step=46785, loss=-8.29, loss_sig=-10.2]
10527it [21:44,  8.07it/s, epoch=4, step=val, loss=-8.27, loss_sig=-10.2]
46788it [2:47:26,  4.66it/s, epoch=5, step=46785, loss=-8.39, loss_sig=-10.3]
10527it [21:44,  8.07it/s, epoch=5, step=val, loss=-8.41, loss_sig=-10.3]
46788it [2:47:31,  4.65it/s, epoch=6, step=46785, loss=-8.49, loss_sig=-10.4]
10527it [21

In [40]:
test_ds=ldr.SyntheticCallDataset(sig_test,
                              artifact_files=art_test,
                              noise_files=noise_test,
                              **common_kwargs)
test_dl=tud.DataLoader(test_ds,batch_size=BATCH_SIZE,collate_fn=test_ds.collate_fn)

In [42]:
import pesq
def mean_pesq(good,bad):
    result=[]
    for g,b in zip(np.squeeze(good),np.squeeze(bad)):
        result.append(pesq.pesq(16000,g,b,mode="wb",on_error=1))
    return np.mean(result)

In [43]:
log = {'epoch':e,
       'step':'val',
       'loss': 0,
       'loss_sig': 0,
       'pesq_pre': 0,
       'pesq_post': 0,
       'pesq_diff':999}
net.eval()
with torch.no_grad():
    session=tqdm(enumerate(test_dl))
    for i_iter,sample_batched in session:
        x,signal,artifact,noise=[_ for _ in sample_batched]
        x=F.interpolate(x,scale_factor=2,mode="linear")
        pred,loss=one_batch(i_iter,log,sample_batched,net,criterion,optimizer)
        pesq_pre=mean_pesq(signal.cpu().numpy(),x.cpu().numpy())
        pesq_post=mean_pesq(signal.cpu().numpy(),pred[:,0:1].cpu().numpy())
        log['pesq_pre'] = (pesq_pre+log['pesq_pre']*i_iter)/(i_iter + 1)
        log['pesq_post'] = (pesq_post+log['pesq_post']*i_iter)/(i_iter + 1)
        log['pesq_diff'] =log['pesq_post']-log['pesq_pre']
        if (i_iter + 1) % 5 == 0:
            session.set_postfix(log)

229it [06:30,  1.71s/it, epoch=11, step=val, loss=0.00114, loss_sig=0.00174, pesq_pre=1.24, pesq_post=1.24, pesq_diff=0.00601]


KeyboardInterrupt: 

In [34]:
log

{'epoch': 11,
 'step': 'val',
 'loss': 0.001225284591782838,
 'loss_sig': 0.0020091817423235625,
 'pesq_pre': 1.5318121959765751,
 'pesq_post': 1.58041051030159,
 'pesq_diff': 0.048598314325014824}