diff --git a/torchcomp/__init__.py b/torchcomp/__init__.py index e7e216c..8e54622 100644 --- a/torchcomp/__init__.py +++ b/torchcomp/__init__.py @@ -89,7 +89,7 @@ def avg(rms: torch.Tensor, avg_coef: Union[torch.Tensor, float]): assert torch.all(avg_coef > 0) and torch.all(avg_coef <= 1) return sample_wise_lpc( - rms * avg_coef, + rms * avg_coef.unsqueeze(1), avg_coef[:, None, None].broadcast_to(rms.shape + (1,)) - 1, )