From a1a3eecb94731e226308a6812f2fbf268d789caf Mon Sep 17 00:00:00 2001 From: ygong Date: Fri, 12 Aug 2022 15:32:11 -0400 Subject: [PATCH] add a comment on the NCE loss implementation --- src/models/ast_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/ast_models.py b/src/models/ast_models.py index b920ea89..351283fc 100644 --- a/src/models/ast_models.py +++ b/src/models/ast_models.py @@ -346,7 +346,7 @@ def mpc(self, x, mask_patch, cluster, show_mask=False): correct = torch.tensor(0.0).to(x.device) for i in np.arange(0, B): # negative samples are from the same batch - # equation (1) of the ssast paper + # 8/12/2022: has a difference with equation (1) in the ssast paper but (likely) performance-wise similar, see https://github.com/YuanGongND/ssast/issues/13 total = torch.mm(encode_samples[i], torch.transpose(pred[i], 0, 1)) # e.g. size 100*100 correct += torch.sum(torch.eq(torch.argmax(self.softmax(total), dim=0), torch.arange(0, mask_patch, device=x.device))) # correct is a tensor nce += torch.sum(torch.diag(self.lsoftmax(total))) # nce is a tensor