diff --git a/pl_bolts/callbacks/byol_updates.py b/pl_bolts/callbacks/byol_updates.py index 1c4d3ba7d4..729b4158a9 100644 --- a/pl_bolts/callbacks/byol_updates.py +++ b/pl_bolts/callbacks/byol_updates.py @@ -69,5 +69,4 @@ def update_weights(self, online_net: Union[Module, Tensor], target_net: Union[Mo online_net.named_parameters(), # type: ignore[union-attr] target_net.named_parameters() # type: ignore[union-attr] ): - if 'weight' in name: - target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data + target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data