In [1]:
import sys
sys.path.append('..')

import hydra
from hydra.utils import instantiate

hydra.initialize(version_base=None, config_path="../src/configs")
config = hydra.compose(config_name="conv-tasnet-baseline")

In [2]:
run_name = "conv-tasnet-baseline"
checkpoint_name = "checkpoint-epoch50.pth"
device = "cuda:1"

In [3]:
import torch
from torch import nn

from src.datasets import DLADataset
from src.datasets.data_utils import get_dataloaders

model = nn.DataParallel(instantiate(config.model))
dataloaders, batch_transforms = get_dataloaders(config, device)
dataloader = dataloaders["val"]

weights = torch.load(f"../saved/{run_name}/{checkpoint_name}", map_location=device)
model.load_state_dict(weights["state_dict"])
model = model.module.to(device)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from IPython.display import Audio
from torchmetrics.audio import ScaleInvariantSignalNoiseRatio

sisnr = ScaleInvariantSignalNoiseRatio().to(device)

def _sisnr(preds, target):
    def dot(x, y, axis):
        return torch.sum(x * y, axis=axis).unsqueeze(axis)

    def norm(x, axis):
        return torch.sum(x**2, axis=axis).unsqueeze(axis)

    signal = (
        dot(preds, target, axis=1) * target / norm(target, axis=1)
    )
    noise = preds - signal

    snr = 10 * torch.log10(norm(signal, axis=1) / norm(noise, axis=1))
    return snr


ValueError: Unexpected keyword arguments: `average`

In [5]:
from tqdm import tqdm

values = []

for batch in tqdm(dataloader):
    batch = {
        k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)
    }
    with torch.no_grad():
        batch.update(model(**batch))

    mix = batch["mix_wav"][:, 0]
    output1 = batch["output_wav"][:, 0]
    output2 = batch["output_wav"][:, 1]
    target1 = batch["speaker_1_wav"][:, 0]
    target2 = batch["speaker_2_wav"][:, 0]

    # print('TORCHMETRICS')
    # print('output1 and target1', sisnr(preds=output1, target=target1))
    # print('output1 and target2', sisnr(preds=output1, target=target2))
    # print('output2 and target1', sisnr(preds=output2, target=target1))
    # print('output2 and target2', sisnr(preds=output2, target=target2))

    # print('MINE')
    # print('output1 and target1', _sisnr(preds=output1, target=target1))
    # print('output1 and target2', _sisnr(preds=output1, target=target2))
    # print('output2 and target1', _sisnr(preds=output2, target=target1))
    # print('output2 and target2', _sisnr(preds=output2, target=target2))

    for output1, output2, target1, target2 in zip(output1, output2, target1, target2):
        v11 = sisnr(preds=output1, target=target1)
        v12 = sisnr(preds=output1, target=target2)
        v21 = sisnr(preds=output2, target=target1)
        v22 = sisnr(preds=output2, target=target2)

        v1 = (v11 + v22) / 2
        v2 = (v12 + v21) / 2

        best = max(v1, v2)
        values.append(best)

100%|██████████| 313/313 [01:07<00:00,  4.62it/s]


In [9]:
import numpy as np

print(torch.mean(torch.Tensor(values)))

tensor(7.4389)


In [6]:
Audio(output1[0].cpu().numpy(), rate=16000)

ValueError: Array audio input must be a 1D or 2D array

In [None]:
Audio(output2[0].cpu().numpy(), rate=16000)

In [None]:
Audio(target1[0].cpu().numpy(), rate=16000)

In [None]:
Audio(target2[0].cpu().numpy(), rate=16000)