From 05b2b1a80a4dad0343bb230ef5e3b8eaf20e163d Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Tue, 1 Mar 2022 13:53:48 +0800 Subject: [PATCH] enhance contrastive loss Signed-off-by: Yiheng Wang --- monai/losses/contrastive.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py index aef596a492..cd5b261acf 100644 --- a/monai/losses/contrastive.py +++ b/monai/losses/contrastive.py @@ -65,14 +65,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != input.shape: raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") - temperature_tensor = torch.tensor(self.temperature).to(input.device) + temperature_tensor = torch.as_tensor(self.temperature).to(input.device) norm_i = F.normalize(input, dim=1) norm_j = F.normalize(target, dim=1) negatives_mask = ~torch.eye(self.batch_size * 2, self.batch_size * 2, dtype=torch.bool) - negatives_mask = torch.tensor(negatives_mask, dtype=torch.float) - negatives_mask = torch.clone(torch.as_tensor(negatives_mask)).to(input.device) + negatives_mask = torch.clone(negatives_mask.type(torch.float)).to(input.device) repr = torch.cat([norm_i, norm_j], dim=0) sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2)