From e20572c3141314dce72e7418056a8d7ad814ed90 Mon Sep 17 00:00:00 2001 From: Kin Wai Cheuk Date: Tue, 6 Jun 2023 03:38:30 +0800 Subject: [PATCH] added SDR calculation --- models/conv128.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/models/conv128.py b/models/conv128.py index 1dffc4e..fbed970 100644 --- a/models/conv128.py +++ b/models/conv128.py @@ -4,14 +4,18 @@ import torch.optim as optim def calculate_sdr(ref, est): + """ + ref: (B, L) + est: (B, L) + """ assert ref.dim()==est.dim(), f"ref {ref.shape} has a different size than est {est.shape}" s_true = ref s_artif = est - ref sdr = 10. * ( - torch.log10(torch.clip(torch.mean(s_true ** 2, 1), 1e-8, torch.inf)) \ - - torch.log10(torch.clip(torch.mean(s_artif ** 2, 1), 1e-8, torch.inf))) + torch.log10(torch.clip(torch.mean(s_true ** 2, -1), 1e-8, torch.inf)) \ + - torch.log10(torch.clip(torch.mean(s_artif ** 2, -1), 1e-8, torch.inf))) return sdr class Conv128(pl.LightningModule): @@ -55,8 +59,17 @@ def test_step(self, batch, batch_idx): label = batch[1] # (batch, 4, 2, len) loss = torch.nn.functional.mse_loss(pred, label.flatten(1,2)) - self.log('Test/mse_loss', recon_loss) - return loss + self.log('Test/mse_loss', loss) + + sdr = calculate_sdr(label.flatten(1,2), pred) + sdr1, sdr2, sdr3, sdr4 = \ + torch.split(sdr,2, dim=1) + self.log('Test/sdr', sdr.mean()) + self.log('Test/sdr1', sdr1.mean()) + self.log('Test/sdr2', sdr2.mean()) + self.log('Test/sdr3', sdr3.mean()) + self.log('Test/sdr4', sdr4.mean()) + return loss, sdr, sdr1, sdr2, sdr3, sdr4