Contrastive loss implementation discrepancy between the paper and codebase #8

Closed
Gyat opened this issue May 12, 2021 · 1 comment

Gyat commented May 12, 2021

Hello,

This is in relation to the losses described in the paper and implemented in the codebase. I need your help understanding the following:

  1. Page 4 of the paper states: "the contrastive alignment loss enforces alignment between the embedded representations of the object at the output of the decoder, and the text representation at the output of the cross encoder." However, in transformer.py, the following snippet is used for the loss calculation:

"text_pooled_op": encoded_text.pooler_output if self.CLS is not None else None,

"img_pooled_op": img_memory[0] if self.CLS is not None else None, # Return the CLS token

which means that the embedded representation of the text is taken from the classification (pooler) token of the BERT-based text backbone encoder, and the embedded representation of the image is taken from the output of the transformer encoder, not the decoder. Is this genuinely a discrepancy? If not, could you kindly point me to the snippet for these loss calculations where you tap into the decoder output?

  2. Also, is the following understanding correct: the 'Soft token prediction' loss from the paper is actually called 'contrastive_align_loss' in the codebase, and the 'Contrastive alignment' loss from the paper is actually named 'contrastive_loss' in the codebase?

Thank you.

@ashkamath (Owner) commented

Hi,
It looks like you're confusing the contrastive_align_loss with the contrastive_loss.
In our paper and published results, we do not use the contrastive loss (which is akin to the image-text matching loss used in other vision+language pre-training papers). We left it in the code for completeness, since it is something we tried at some point and thought might be useful for other users of our codebase who want to experiment with it. For the two losses that we do use, see the following:
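
To make that distinction concrete, here is a minimal sketch of what such a batch-level image-text contrastive loss typically looks like, assuming the pooled CLS outputs from the question ('img_pooled_op' and 'text_pooled_op') are (batch, dim) tensors. This is an illustration of the general idea, not the exact implementation in this repository:

    import torch
    import torch.nn.functional as F

    def contrastive_loss_sketch(img_pooled, text_pooled, temperature=0.07):
        # Illustrative batch-level image-text contrastive (InfoNCE) loss.
        # Assumes img_pooled and text_pooled are (batch, dim) pooled CLS features;
        # this is a hypothetical helper, not the repository's implementation.
        img = F.normalize(img_pooled, dim=-1)
        txt = F.normalize(text_pooled, dim=-1)
        logits = img @ txt.t() / temperature  # (batch, batch) similarity matrix
        targets = torch.arange(logits.size(0), device=logits.device)
        # Image i should match caption i; every other caption in the batch is a negative.
        loss_i2t = F.cross_entropy(logits, targets)
        loss_t2i = F.cross_entropy(logits.t(), targets)
        return (loss_i2t + loss_t2i) / 2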

  1. Contrastive align loss, which is calculated between the object embeddings at the output of the decoder and the text embeddings at the output of the cross encoder. Relevant lines in the code (an illustrative sketch follows after this list):

    if contrastive_align_loss:
    if self.contrastive_align_loss:
    def loss_contrastive_align(self, outputs, targets, positive_map, indices, num_boxes):
  2. Contrastive alignment -> loss_contrastive_align, which we just discussed above. Soft token prediction is loss_labels:

    def loss_labels(self, outputs, targets, positive_map, indices, num_boxes):

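To illustrate the distinction between these two losses in code, here is a minimal, non-batched sketch. The shapes, names, temperature, and the symmetric InfoNCE formulation are assumptions made for exposition; the actual batched implementations are in loss_contrastive_align and loss_labels referenced above:

    import torch
    import torch.nn.functional as F

    def contrastive_align_loss_sketch(query_embed, token_embed, positive_map, temperature=0.07):
        # Illustrative token-level alignment loss. Assumed shapes:
        #   query_embed:  (num_queries, dim) projected decoder outputs (object queries)
        #   token_embed:  (num_tokens, dim) projected text tokens from the cross encoder
        #   positive_map: (num_queries, num_tokens) 0/1 mask marking, for each matched
        #                 query, the text tokens it should align with
        q = F.normalize(query_embed, dim=-1)
        t = F.normalize(token_embed, dim=-1)
        logits = q @ t.t() / temperature  # (num_queries, num_tokens)
        pos = positive_map.float()

        # Query -> tokens: pull each query towards its positive tokens,
        # relative to all the other tokens in the sentence.
        log_p_tokens = logits.log_softmax(dim=-1)
        loss_q2t = -(log_p_tokens * pos).sum(-1) / pos.sum(-1).clamp(min=1)

        # Tokens -> queries: the symmetric direction.
        log_p_queries = logits.log_softmax(dim=0)
        loss_t2q = -(log_p_queries * pos).sum(0) / pos.sum(0).clamp(min=1)

        return (loss_q2t.mean() + loss_t2q.mean()) / 2

    def soft_token_prediction_loss_sketch(token_logits, positive_map):
        # Illustrative soft token prediction loss: each query predicts a distribution
        # over text token positions (token_logits: (num_queries, num_token_positions)),
        # trained with a soft cross-entropy whose target is the normalized span of
        # tokens that refers to the matched object.
        target = positive_map.float()
        target = target / target.sum(-1, keepdim=True).clamp(min=1e-6)
        return -(target * token_logits.log_softmax(dim=-1)).sum(-1).mean()
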
Hope this makes it more clear! :)
