diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py index aef596a492..cd5b261acf 100644 --- a/monai/losses/contrastive.py +++ b/monai/losses/contrastive.py @@ -65,14 +65,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != input.shape: raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") - temperature_tensor = torch.tensor(self.temperature).to(input.device) + temperature_tensor = torch.as_tensor(self.temperature).to(input.device) norm_i = F.normalize(input, dim=1) norm_j = F.normalize(target, dim=1) negatives_mask = ~torch.eye(self.batch_size * 2, self.batch_size * 2, dtype=torch.bool) - negatives_mask = torch.tensor(negatives_mask, dtype=torch.float) - negatives_mask = torch.clone(torch.as_tensor(negatives_mask)).to(input.device) + negatives_mask = torch.clone(negatives_mask.type(torch.float)).to(input.device) repr = torch.cat([norm_i, norm_j], dim=0) sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2)