### Inference

In [8]:
# Run inference.py on one sample file with the Dual-Path RNN config and a saved checkpoint.
# The recorded output below shows one result file written per speaker (spk1, spk2).
# NOTE(review): this cell uses the epoch-395 checkpoint, while the Evaluation section
# loads the epoch-321 checkpoint — confirm which one is intended.
!python inference.py -m samples/5ed8a26d77f2a_5f59d27615178.flac -c configs/dualpathrnn.yml -w weights/dualpathrnn/DualPath_RNN_epoch_395_-9.2936.pt -s samples

saved in: samples/DualPath_RNN/spk1/5ed8a26d77f2a.flac
saved in: samples/DualPath_RNN/spk2/5f59d27615178.flac


### Plotting loss

In [None]:
from utils.plot import save_graph_tb_log_metrics

**Dualpath-RNN**

In [None]:
# Hand the train and validation loss CSVs from checkpoints/train_rnn to the
# plotting helper; the figure path is pics/Loss_Train_Val.png.
save_graph_tb_log_metrics(
    first_csv_path='checkpoints/train_rnn/Loss_Train.csv',
    second_csv_path='checkpoints/train_rnn/Loss_Validation.csv',
    name_ox='Epoch',
    name_oy='Loss',
    loc='upper right',
    pth_save='pics/Loss_Train_Val.png',
)

**Conv-TasNet**

**Sepformer**

### Evaluation

In [1]:
import argparse
import sys

import torch
from torchmetrics.audio import PermutationInvariantTraining as PIT
from torchmetrics.functional.audio import signal_distortion_ratio as sdr
from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio as sisnr
from tqdm.notebook import tqdm

from utils.load_config import load_config  
from models import MODELS
from data import DiarizationDataset

In [2]:
def evaluate(cfg, test_dataloader, weight):
    """Print permutation-invariant SDR and SI-SNR of a trained model on a test set.

    The model class is looked up in ``MODELS`` by ``cfg['trainer']['model_name']``
    and constructed with ``cfg['model']`` as keyword arguments. Its parameters are
    restored from the ``'model_state_dict'`` entry of the checkpoint, the model is
    switched to eval mode, and both metrics are accumulated over every batch of
    ``test_dataloader``. The aggregated values are printed; nothing is returned.

    Args:
        cfg: Config mapping. Reads ``cfg['trainer']['model_name']``,
            ``cfg['trainer']['device']`` and ``cfg['model']``.
        test_dataloader: Iterable yielding ``(inputs, labels)`` pairs, where
            ``labels`` is a sequence of per-speaker tensors.
        weight: Checkpoint path passed to ``torch.load``.

    Shapes (carried over from the original author's note — TODO confirm against
    the dataset): ``inputs`` is ``[batch, time]``; ``labels`` and the model
    output are sequences of ``spk`` tensors of shape ``[batch, time]``. Both are
    stacked along dim 1 to ``[batch, spk, time]`` before being passed to the
    metrics.
    """
    model_class = MODELS[cfg['trainer']['model_name']]
    model = model_class(**cfg['model'])
    device = cfg['trainer']['device']
    model.to(device)

    # NOTE(review): weights_only=False allows arbitrary unpickling; only load
    # checkpoints from a trusted source.
    dicts = torch.load(weight, map_location=device, weights_only=False)
    model.load_state_dict(dicts['model_state_dict'])

    model.eval()
    pit_sdr = PIT(sdr).to(device)
    pit_sisnr = PIT(sisnr).to(device)

    # One no_grad context for the whole pass: no autograd graph is recorded, so
    # the per-tensor detach() calls of the earlier version are unnecessary.
    with torch.no_grad():
        for inputs, labels in tqdm(test_dataloader):
            inputs = inputs.to(device)
            # Stacking tensors that already live on `device` yields a tensor on
            # `device`, so no further transfer is needed afterwards.
            targets = torch.stack([label.to(device) for label in labels], dim=1)
            estimates = torch.stack(list(model(inputs)), dim=1)
            pit_sdr.update(estimates, targets)
            pit_sisnr.update(estimates, targets)

    print('sdr', pit_sdr.compute().item())
    print('sisnr', pit_sisnr.compute().item())

In [3]:
# Resolve the test-dataset hparams file (default: ./configs/test_dataset.yml).
parser = argparse.ArgumentParser()
parser.add_argument(
    "-p", "--hparams",
    type=str,
    default="./configs/test_dataset.yml",
    help="hparams config file",
)
# parse_known_args returns unrecognized arguments separately instead of
# raising, so options this parser does not define are tolerated.
args, unknown = parser.parse_known_args()
testdataset_cfg = load_config(args.hparams)

# Build the dataset from the 'data' section and prepare it for evaluation;
# the result of setup() is what provides test_dataloader().
datamodule = DiarizationDataset(**testdataset_cfg['data']).setup(stage='eval')
test_dataloader = datamodule.test_dataloader()

Size of test set: 9404
Elapsed time 'setup': 00:00:08.86


**Dualpath-RNN**

In [4]:
# Dual-Path RNN: pick the hparams file (default: ./configs/dualpathrnn.yml)
# and the checkpoint to evaluate.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-p", "--hparams",
    type=str,
    default="./configs/dualpathrnn.yml",
    help="hparams config file",
)
# Unrecognized arguments are returned in `unknown` rather than raising.
args, unknown = parser.parse_known_args()

dualpathrnn_weight = './weights/dualpathrnn/DualPath_RNN_epoch_321_-9.0376.pt'
dualpathrnn_cfg = load_config(args.hparams)

In [5]:
# Evaluate the Dual-Path RNN checkpoint on the test set; prints 'sdr' and 'sisnr'.
evaluate(dualpathrnn_cfg, test_dataloader, dualpathrnn_weight)

  0%|          | 0/9404 [00:00<?, ?it/s]

sdr 9.450394630432129
sisnr 8.572038650512695


**Conv-TasNet**

In [None]:
# Conv-TasNet: pick the hparams file (default: ./configs/convtasnet.yml)
# and the checkpoint to evaluate.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-p", "--hparams",
    type=str,
    default="./configs/convtasnet.yml",
    help="hparams config file",
)
# Unrecognized arguments are returned in `unknown` rather than raising.
args, unknown = parser.parse_known_args()

# NOTE(review): the checkpoint filename starts with "DualPath_RNN" although it
# sits under weights/convtasnet — confirm this is the intended file.
convtasnet_weight = './weights/convtasnet/best/DualPath_RNN_14_-7.3494.pt'
convtasnet_cfg = load_config(args.hparams)

In [None]:
# Evaluate the Conv-TasNet checkpoint on the test set; prints 'sdr' and 'sisnr'.
evaluate(convtasnet_cfg, test_dataloader, convtasnet_weight)

**Sepformer**

In [None]:
# SepFormer: pick the hparams file (default: ./configs/sepformer.yml)
# and the checkpoint to evaluate.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-p", "--hparams",
    type=str,
    default="./configs/sepformer.yml",
    help="hparams config file",
)
# Unrecognized arguments are returned in `unknown` rather than raising.
args, unknown = parser.parse_known_args()

# NOTE(review): this filename is identical to the one used in the Conv-TasNet
# cell (same "DualPath_RNN" prefix, epoch and score) — looks like a copy-paste
# placeholder; confirm the real SepFormer checkpoint name.
sepformer_weight = './weights/sepformer/best/DualPath_RNN_14_-7.3494.pt'
sepformer_cfg = load_config(args.hparams)

In [None]:
# Evaluate the SepFormer checkpoint on the test set; prints 'sdr' and 'sisnr'.
evaluate(sepformer_cfg, test_dataloader, sepformer_weight)

In [5]:
# sdr -2.0657341480255127
# sdr second 5.21800422668457
# sisnr -4.563594581563693

In [None]:
# sdr 5.21800422668457
# sisnr 4.563263416290283