In [None]:
!pip install -q torch-audiomentations
!pip install -q torchaudio
!pip install -q julius

In [1]:
# !sudo apt-get update
# !sudo apt-get -y install python3-pyaudio
# import sounddevice as sd
# sd.query_devices()

In [1]:
import numpy as np
from glob import glob
from IPython.display import Audio
import torch
import matplotlib.pyplot as plt

import librosa
import librosa.display

In [2]:
speak_files=glob("../../datas/cv-corpus-9.0-2022-04-27/zh-TW/clips/*.mp3")
firetruck_files=glob("../../datas/sounds/firetruck/*.wav")
construction_files=glob("../../datas/sounds/splited_construction/*.wav")
print(
    "#人聲:",len(speak_files),
    "\n#消防車聲:",len(firetruck_files),
    "\n#工地聲:",len(construction_files)
)

#人聲: 116969 
#消防車聲: 200 
#工地聲: 11997


In [3]:
def set_seed(seed=2022):
    np.random.seed(seed)
    torch.manual_seed(seed)

In [4]:
from sklearn.model_selection import train_test_split
def split(x,test_val_ratio):
    x0,x1=train_test_split(x,test_size=sum(test_val_ratio))
    x1,x2=train_test_split(x1,test_size=test_val_ratio[-1])
    return x0,x1,x2
set_seed(2022)
sig_train,sig_val,sig_test=split(speak_files,(0.1,0.1))
art_train,art_val,art_test=split(firetruck_files,(0.1,0.1))
noise_train,noise_val,noise_test=split(construction_files,(0.1,0.1))

In [5]:
SEED=28
SAMPLE_RATE=16000
BATCH_SIZE=4

In [6]:
with torch.cuda.device(0):
    import torch.utils.data as tud
    import sirenns.datasets.loader as ldr
    L=16000*4

    common_kwargs=dict(
        signal_len=L,
        transform=ldr.transform,
        artifact_transform=ldr.artifact_noise_transform,
        noise_transform=ldr.artifact_noise_transform
    )

    test_ds=ldr.SyntheticCallDataset(sig_test,
                                  artifact_files=art_test,
                                  noise_files=noise_test,
                                  **common_kwargs)
    test_dl=tud.DataLoader(test_ds,batch_size=BATCH_SIZE,collate_fn=test_ds.collate_fn)

In [34]:
%%time
with torch.cuda.device(0):
    for  batch in test_dl:
        x,signal,artifact,noise= map(lambda x: x,batch)
        break

CPU times: user 350 ms, sys: 8.72 ms, total: 359 ms
Wall time: 153 ms


In [8]:
# BATCH_SEED=3
# Audio(x[BATCH_SEED].cpu() ,rate=8000,normalize=False)
# Audio(signal[BATCH_SEED].cpu() ,rate=16000,normalize=False)
# Audio(artifact[BATCH_SEED].cpu() ,rate=16000,normalize=False)
# Audio(noise[BATCH_SEED].cpu() ,rate=16000,normalize=False)
# plt.plot(artifact[BATCH_SEED,0].cpu().numpy().T,alpha=0.9)
# plt.plot(noise[BATCH_SEED,0].cpu().numpy().T,alpha=0.7)
# plt.plot(signal[BATCH_SEED,0].cpu().numpy().T,alpha=0.7)

In [8]:
with torch.cuda.device(0):
#     from torchaudio.models import ConvTasNet
    from sirenns.models.residual import ResidualConvTas
    import torch.nn as nn
    import sirenns.utils.losses as lsf 
    net=ResidualConvTas(enc_num_feats=128,
                       msk_num_hidden_feats=64,
                       device=torch.device("cuda"))
    net.load_state_dict(torch.load("../snapshots/2convtas/best.pth"))
#     net.cuda()
    # Loss cocmbo 1: MSE+SGD
    # criterion = nn.MSELoss()
    # optimizer = torch.optim.SGD(net.parameters(),lr=1e-3,momentum=0.9,weight_decay=0.0005)
    # Loss cocmbo 2: SDR+clip norm+Adam
    criterion = lambda y,pred: lsf.cal_loss(y,pred,float(L))
    # Clip Norm
    for p in net.parameters():
        p.register_hook(lambda grad: torch.clamp(grad, -5, 5))
    optimizer = torch.optim.Adam(net.parameters(),lr=1e-3)

In [9]:
!pip install -q torch-summary

In [10]:
from torchsummary import summary
with torch.cuda.device(0):
    summary(net,input_size=(1,L),batch_size=2)

Layer (type:depth-idx)                   Param #
├─ConvTasNet: 1-1                        --
|    └─Conv1d: 2-1                       2,048
|    └─MaskGenerator: 2-2                --
|    |    └─GroupNorm: 3-1               256
|    |    └─Conv1d: 3-2                  16,512
|    |    └─ModuleList: 3-3              601,520
|    |    └─PReLU: 3-4                   1
|    |    └─Conv1d: 3-5                  33,024
|    └─ConvTranspose1d: 2-3              2,048
├─ConvTasNet: 1-2                        --
|    └─Conv1d: 2-4                       2,048
|    └─MaskGenerator: 2-5                --
|    |    └─GroupNorm: 3-6               256
|    |    └─Conv1d: 3-7                  16,512
|    |    └─ModuleList: 3-8              601,520
|    |    └─PReLU: 3-9                   1
|    |    └─Conv1d: 3-10                 16,512
|    └─ConvTranspose1d: 2-6              2,048
├─ResampleFrac: 1-3                      --
Total params: 1,294,306
Trainable params: 1,294,306
Non-trainable params: 0


In [40]:
with torch.no_grad():
    import torch.nn.functional as F
    
    pred=net(net.resamp_8Kto16K(x))

In [36]:
BATCH_SEED=0
Audio(x[BATCH_SEED,0].detach().cpu().numpy(),rate=8000,normalize=True)

In [37]:
Audio(signal[BATCH_SEED,0].detach().cpu().numpy(),rate=16000,normalize=True)

In [41]:
Audio(pred[BATCH_SEED,0].detach().cpu().numpy(),rate=16000,normalize=True)

In [26]:
with torch.cuda.device(0):

    import torch.nn.functional as F
    lamb=[0.8,0.1,0.1]
    def one_batch(i_iter,log,sample_batched,model,criterion,optimizer):
        # Prep input
        x,signal,artifact,noise=[_ for _ in sample_batched]
        x=model.resamp_8Kto16K(x)
        pred=model(x)
        loss = criterion(pred[:,0:1], signal)*lamb[0]+\
               criterion(pred[:,1:2], artifact)*lamb[1]+\
               criterion(pred[:,2:3], noise)*lamb[2]
        if model.training:
            #Update
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        with torch.no_grad():
            loss_sig=criterion(pred[:,0:1], signal).item()
            log['loss_sig'] = (loss_sig+log['loss_sig']*i_iter)/(i_iter + 1)
            loss_rec = loss.item()
            log['loss'] = (loss_rec+log['loss']*i_iter)/(i_iter + 1)
        return pred,loss

In [27]:
with torch.cuda.device(0):

    import pesq
    def mean_pesq(good,bad):
        result=[]
        for g,b in zip(np.squeeze(good),np.squeeze(bad)):
            result.append(pesq.pesq(16000,g,b,mode="nb",on_error=1))
        return np.mean(result)

In [28]:
with torch.cuda.device(0):

    from tqdm import tqdm
    log = {'step':'val',
           'loss': 0,
           'loss_sig': 0,
           'pesq_pre': 0,
           'pesq_post': 0,
           'pesq_diff':999}
    net.eval()
    with torch.no_grad():
        session=tqdm(enumerate(test_dl))
        for i_iter,sample_batched in session:
            x,signal,artifact,noise=[_ for _ in sample_batched]
            x=net.resamp_8Kto16K(x)
            pred,loss=one_batch(i_iter,log,sample_batched,net,criterion,optimizer)
            pesq_pre=mean_pesq(signal.cpu().numpy(),x.cpu().numpy())
            pesq_post=mean_pesq(signal.cpu().numpy(),pred[:,0:1].cpu().numpy())
            log['pesq_pre'] = (pesq_pre+log['pesq_pre']*i_iter)/(i_iter + 1)
            log['pesq_post'] = (pesq_post+log['pesq_post']*i_iter)/(i_iter + 1)
            log['pesq_diff'] =log['pesq_post']-log['pesq_pre']
            if (i_iter + 1) % 5 == 0:
                session.set_postfix(log)

585it [09:35,  1.02it/s, step=val, loss=-8.62, loss_sig=-10.7, pesq_pre=1.49, pesq_post=2.21, pesq_diff=0.717]


In [17]:
with torch.cuda.device(1):
    from torchaudio.models import ConvTasNet
    import torch.nn as nn
    nets=[ConvTasNet(num_sources=1,
                   enc_num_feats=128,
                   msk_num_hidden_feats=128
                  ).cuda()]
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(net.parameters(),lr=1e-3,momentum=0.9,weight_decay=0.0005)

In [21]:
import torch.nn.functional as F
def one_batch(i_iter,log,sample_batched,model,criterion,optimizer):
    # Prep input
    x,signal,artifact,noise=[_ for _ in sample_batched]
    x=F.interpolate(x,scale_factor=2,mode="linear")
    pred=model(x)
    loss = criterion(pred[:,0:1], signal)+\
           criterion(pred[:,1:2], artifact)+\
           criterion(pred[:,2:3], noise)
    if model.training:
        #Update
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    with torch.no_grad():
        loss_sig=criterion(pred[:,0:1], signal).item()
        if i_iter==0:
            log['loss_sig'] = loss_sig
        else:
            log['loss_sig'] = (loss_sig+log['loss_sig']*i_iter)/(i_iter + 1)
    # Record
    loss_rec = loss.item()
    log['loss'] = (loss_rec+log['loss']*i_iter)/(i_iter + 1)
    return pred,loss

In [34]:
!mkdir snapshots

mkdir: cannot create directory ‘snapshots’: File exists


In [12]:
from tqdm import tqdm

In [14]:
with torch.cuda.device(1):

    EPOCH=100
    net.train()
    best_loss=np.inf
    PAITIENCE=5
    count=0
    try:
        for e in range(EPOCH):
            log= {'epoch':e,'step':0,'loss': 0, 'loss_sig': 0}
            net.train()
            session=tqdm(enumerate(train_dl))
            for i_iter, sample_batched in session:
                pred,loss=one_batch(i_iter,log,sample_batched,net,criterion,optimizer)
                # print loss and take snapshots
                if (i_iter + 1) % 5 == 0:
                    log['step']=i_iter+1
                    session.set_postfix(log)
            # validate
            if (e + 1) % 2 == 0:
                log = {'epoch':e,'step':'val','loss': 0, 'acc': 0}
                net.eval()
                with torch.no_grad():
                    session=tqdm(enumerate(val_dl))
                    for i_iter,sample_batched in session:
                        pred,loss=one_batch(i_iter,log,sample_batched,net,criterion,optimizer)
                        if (i_iter + 1) % 5 == 0:
                            session.set_postfix(log)
                # early stop        
                if log["loss"]<best_loss:
                    best_loss=log["loss"]
                    torch.save(net.state_dict(), 'snapshots1/best.pth')
                elif count<=PAITIENCE: count+=1
                else:
                    count=0
                    best_loss=np.inf
                    break 

    except KeyboardInterrupt:
        print("\nHuman Interrupted")
    torch.save(net.state_dict(), 'snapshots1/latest.pth')

11it [00:05,  1.96it/s, epoch=0, step=10, loss=0.0764, loss_sig=0.0439]



Human Interrupted
