
Commit e20572c

added SDR calculation
KinWaiCheuk committed Jun 5, 2023
1 parent 3010efa commit e20572c
Showing 1 changed file with 17 additions and 4 deletions.
models/conv128.py (21 changes: 17 additions & 4 deletions)
@@ -4,14 +4,18 @@
 import torch.optim as optim
 
 def calculate_sdr(ref, est):
+    """
+    ref: (B, L)
+    est: (B, L)
+    """
     assert ref.dim()==est.dim(), f"ref {ref.shape} has a different size than est {est.shape}"
 
     s_true = ref
     s_artif = est - ref
 
     sdr = 10. * (
-        torch.log10(torch.clip(torch.mean(s_true ** 2, 1), 1e-8, torch.inf)) \
-        - torch.log10(torch.clip(torch.mean(s_artif ** 2, 1), 1e-8, torch.inf)))
+        torch.log10(torch.clip(torch.mean(s_true ** 2, -1), 1e-8, torch.inf)) \
+        - torch.log10(torch.clip(torch.mean(s_artif ** 2, -1), 1e-8, torch.inf)))
     return sdr
 
 class Conv128(pl.LightningModule):
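
For reference, the SDR computed above is 10 * log10(signal power / artifact power) in decibels, with the artifact defined as est - ref and both powers averaged over the last axis and clipped at 1e-8 before the log. The switch from dim 1 to dim -1 keeps that average on the time axis even when extra channel dimensions are present. Below is a standalone sanity check of the formula; the sine reference and the 1:100 power ratio are made-up test values, not anything from this repository.

import torch

# Build an estimate whose artifact power is exactly 1/100 of the signal power,
# so the expected SDR is 10 * log10(100) = 20 dB.
ref = torch.sin(torch.linspace(0, 100, 16000)).unsqueeze(0)   # (1, L), illustrative signal
artifact = torch.randn_like(ref)
artifact = artifact * torch.sqrt(ref.pow(2).mean() / (100 * artifact.pow(2).mean()))
est = ref + artifact

sdr = 10. * (torch.log10(torch.clip(ref.pow(2).mean(-1), 1e-8, torch.inf))
             - torch.log10(torch.clip((est - ref).pow(2).mean(-1), 1e-8, torch.inf)))
print(sdr)  # tensor([20.0000]) up to floating-point error
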
@@ -55,8 +59,17 @@ def test_step(self, batch, batch_idx):
         label = batch[1] # (batch, 4, 2, len)
         loss = torch.nn.functional.mse_loss(pred, label.flatten(1,2))
 
-        self.log('Test/mse_loss', recon_loss)
-        return loss
+        self.log('Test/mse_loss', loss)
+
+        sdr = calculate_sdr(label.flatten(1,2), pred)
+        sdr1, sdr2, sdr3, sdr4 = \
+            torch.split(sdr, 2, dim=1)
+        self.log('Test/sdr', sdr.mean())
+        self.log('Test/sdr1', sdr1.mean())
+        self.log('Test/sdr2', sdr2.mean())
+        self.log('Test/sdr3', sdr3.mean())
+        self.log('Test/sdr4', sdr4.mean())
+        return loss, sdr, sdr1, sdr2, sdr3, sdr4
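
Shape-wise, test_step flattens the (batch, 4, 2, len) label to (batch, 8, len), so calculate_sdr (averaging over dim -1) returns one SDR value per flattened channel, and torch.split(sdr, 2, dim=1) recovers four pairs, presumably one two-channel pair per source. The sketch below only checks those shapes with random stand-in tensors; the sizes are illustrative assumptions.

import torch

batch, n_src, n_ch, length = 5, 4, 2, 16000            # assumed sizes, for illustration only
label = torch.randn(batch, n_src, n_ch, length)        # mimics batch[1] in test_step
ref = label.flatten(1, 2)                              # (5, 8, 16000), as passed to calculate_sdr
pred = ref + 0.1 * torch.randn_like(ref)               # stand-in for the network output

sdr = 10. * (torch.log10(torch.clip(ref.pow(2).mean(-1), 1e-8, torch.inf))
             - torch.log10(torch.clip((pred - ref).pow(2).mean(-1), 1e-8, torch.inf)))
print(sdr.shape)                                       # torch.Size([5, 8]): one SDR per flattened channel

sdr1, sdr2, sdr3, sdr4 = torch.split(sdr, 2, dim=1)
print(sdr1.shape)                                      # torch.Size([5, 2]): the first source's channel pair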



