From 7b10bb7017a94effefc04fa09f0597ee2f8a5531 Mon Sep 17 00:00:00 2001
From: Marten Lienen
Date: Wed, 31 May 2023 15:31:58 +0200
Subject: [PATCH] Fix norm for differently sized features across ranks

---
 .../SE3Transformer/se3_transformer/model/layers/norm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/norm.py b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/norm.py
index ba83aee06..915963c98 100644
--- a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/norm.py
+++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/norm.py
@@ -87,6 +87,6 @@ def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Ten
                 for degree, feat in features.items():
                     norm = clamped_norm(feat, self.NORM_CLAMP)
                     new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
-                    output[degree] = rescale(new_norm, feat, norm)
+                    output[degree] = rescale(feat, norm, new_norm)
 
             return output
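
Note: the removed line passed rescale's arguments out of order. A minimal
standalone sketch of why the order matters, assuming the helper in the same
file is defined as rescale(x, norm, new_norm) = x / norm * new_norm and that
NORM_CLAMP = 2 ** -24 (both inferred from the call sites; neither appears in
this hunk):

    import torch
    from torch import Tensor

    NORM_CLAMP = 2 ** -24  # assumed clamp value, not shown in the hunk

    def rescale(x: Tensor, norm: Tensor, new_norm: Tensor) -> Tensor:
        # Assumed helper: scale x so its (clamped) L2 norm `norm`
        # is replaced by `new_norm`.
        return x / norm * new_norm

    feat = torch.randn(4, 8, 3)  # stand-in for one degree's features
    norm = feat.norm(p=2, dim=-1, keepdim=True).clamp(min=NORM_CLAMP)
    new_norm = torch.rand_like(norm)  # stand-in for the transformed norms

    # Buggy order: computes new_norm / feat * norm, mixing up all three roles.
    buggy = rescale(new_norm, feat, norm)
    # Fixed order: rescales feat so its norm becomes new_norm,
    # leaving its direction unchanged.
    fixed = rescale(feat, norm, new_norm)
    assert torch.allclose(fixed.norm(p=2, dim=-1, keepdim=True), new_norm,
                          atol=1e-5)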