Proper way to do contrastive learning with DDP & PT-Lightning #14390

@kkarrancsu You are definitely on the right track here. In the LightningModule, you have this method for gathering a tensor from all processes:

tensors_from_all = self.all_gather(my_tensor)
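
Note that all_gather stacks the per-process tensors along a new leading dimension: with a world size of N and a local tensor of shape (batch, dim), the result has shape (N, batch, dim).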

What you want is to back-propagate through this all_gather function, and this is possible if you set

tensors_from_all = self.all_gather(my_tensor, sync_grads=True)
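
By default sync_grads=False, and the gathered tensors come back detached from the autograd graph, so the loss would only back-propagate through each process's own local outputs.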

In your case, your training_step method could look something like this:

    def training_step(self, batch, batch_idx):
        # Compute this process's embeddings for its local shard of the batch.
        outputs = self(batch)
        ...

        # Gather embeddings from every process; sync_grads=True keeps the
        # gathered copies attached to the autograd graph so gradients flow
        # back through the all_gather.
        all_outputs = self.all_gather(outputs, sync_grads=True)

        # Compute the contrastive loss over the full, global batch.
        loss = contrastive_loss_fn(all_outputs, ...)
        return loss
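
The answer above leaves contrastive_loss_fn abstract. As one illustration only, here is a minimal InfoNCE-style sketch; the function name, the two-views-per-sample layout, and the temperature value are assumptions for this example, and the reshape must match whatever your model actually returns:

    import torch
    import torch.nn.functional as F

    def contrastive_loss_fn(all_outputs, temperature=0.1):
        # Assumed layout: all_outputs comes from
        # self.all_gather(outputs, sync_grads=True) and has shape
        # (world_size, local_batch, 2, dim) -- two augmented views per sample.
        world_size, local_batch, n_views, dim = all_outputs.shape

        # Collapse world_size and local_batch into one global batch.
        z = all_outputs.reshape(world_size * local_batch, n_views, dim)
        z1 = F.normalize(z[:, 0], dim=-1)
        z2 = F.normalize(z[:, 1], dim=-1)

        # Cosine-similarity logits between every view-1 and view-2
        # embedding in the global batch.
        logits = z1 @ z2.t() / temperature

        # Positives sit on the diagonal: sample i's first view should
        # match sample i's second view.
        targets = torch.arange(logits.size(0), device=logits.device)
        return F.cross_entropy(logits, targets)

Because the loss is computed over the full global batch on every process, each process sees the same loss value, and gradients reach the local embeddings through the gathered copies.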

Answer selected by kkarrancsu