From 7ab8ed1149286ef3b934d52a4a4335da776af0c7 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Wed, 27 Nov 2024 18:38:38 +0800 Subject: [PATCH] fix: correct tensor broadcasting in avg function --- torchcomp/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchcomp/__init__.py b/torchcomp/__init__.py index e7e216c..8e54622 100644 --- a/torchcomp/__init__.py +++ b/torchcomp/__init__.py @@ -89,7 +89,7 @@ def avg(rms: torch.Tensor, avg_coef: Union[torch.Tensor, float]): assert torch.all(avg_coef > 0) and torch.all(avg_coef <= 1) return sample_wise_lpc( - rms * avg_coef, + rms * avg_coef.unsqueeze(1), avg_coef[:, None, None].broadcast_to(rms.shape + (1,)) - 1, )