Skip to content

Commit

Permalink
df: Adjust metrics due to reallfft change
Browse files Browse the repository at this point in the history
  • Loading branch information
Rikorose committed Feb 10, 2022
1 parent 9040057 commit 05da995
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions DeepFilterNet/df/scripts/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ def try_eval_composite(clean, enhanced, sr):
).to(torch.float32)
logger.info(f"Got {m_enh}")
m_target = torch.as_tensor(
[2.3057448863983, 3.832368850708, 2.3624868392944, 3.054983377456, -2.792978048324]
[2.30616855621338, 3.832779407501221, 2.362725973129273, 3.05537247657776, -2.7911112308502]
)
assert torch.isclose(
m_enh, m_target, atol=__a_tol
).all(), f"Metric output not close. Expected {m_target}, diff: {m_target-m_enh}"
).all(), f"Metric output not close. Expected {m_target}, got {m_enh}, diff: {m_target-m_enh}"


def eval_pystoi(clean, enhanced, sr):
Expand All @@ -38,17 +38,17 @@ def eval_pystoi(clean, enhanced, sr):
logger.info(f"Got {m_enh:.4f}")
assert np.isclose(
[m_enh], [m_target], atol=__a_tol
), f"Metric output not close. Expected {m_target}, diff: {m_target-m_enh}"
), f"Metric output not close. Expected {m_target}, got {m_enh}, diff: {m_target-m_enh}"


def eval_sdr(clean, enhanced):
logger.info("Computing SI-SDR")
m_enh = si_sdr_speechmetrics(clean.numpy(), enhanced.numpy())
m_target = 18.878527879714966
m_target = 18.88543128967285
logger.info(f"Got {m_enh:.4f}")
assert np.isclose(
[m_enh], [m_target]
), f"Metric output not close. Expected {m_target}, diff: {m_target-m_enh}"
), f"Metric output not close. Expected {m_target}, got {m_enh}, diff: {m_target-m_enh}"


if __name__ == "__main__":
Expand Down

0 comments on commit 05da995

Please sign in to comment.