Skip to content

Commit

Permalink
add a comment on the NCE loss implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanGongND committed Aug 12, 2022
1 parent b28150e commit a1a3eec
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/models/ast_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def mpc(self, x, mask_patch, cluster, show_mask=False):
correct = torch.tensor(0.0).to(x.device)
for i in np.arange(0, B):
# negative samples are from the same batch
# equation (1) of the ssast paper
# 8/12/2022: has a difference with equation (1) in the ssast paper but (likely) performance-wise similar, see https://github.com/YuanGongND/ssast/issues/13
total = torch.mm(encode_samples[i], torch.transpose(pred[i], 0, 1)) # e.g. size 100*100
correct += torch.sum(torch.eq(torch.argmax(self.softmax(total), dim=0), torch.arange(0, mask_patch, device=x.device))) # correct is a tensor
nce += torch.sum(torch.diag(self.lsoftmax(total))) # nce is a tensor
Expand Down

0 comments on commit a1a3eec

Please sign in to comment.