Closed
Labels
feature (Is an improvement or enhancement), help wanted (Open to be worked on)
Description
I realized the ddp2 implementation I put together doesn't allow the user to operate on the outputs of all DP processes.
For instance, you calculate some logits on each process, and the current ddp2 forces the loss to be calculated in each process individually. However, if you wanted to, say, normalize across all examples in the batch, you'd need to somehow share the output of each process.
Currently:

```python
total_loss = []
for process:
    # training step
    out = model(x)
    loss = loss(out)
    total_loss.append(loss)
loss = total_loss.mean()
```

Proposed:
```python
outs = []
for process:
    # training step
    out = model(x)
    outs.append(out)

# allow training_end to operate (softmax, for instance) on ALL outputs
loss = model.training_end(outs)
loss.backward()
```

The implication is adding an optional `def training_end(...)` to the model which, when defined, gives you all the outputs of `training_step`.
If you don't need anything advanced like this, you return a loss from `training_step` and don't implement `training_end`. If you need more advanced control, you implement `training_end`.
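For concreteness, here is a minimal sketch of what user code could look like under this proposal. The `LitModel` class, its layer sizes, and the assumption that `training_end` receives a list of per-process dicts are illustrative, not part of the proposal itself:

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Illustrative module: training_step returns raw outputs and
    training_end turns the collected outputs into a single loss."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(32, 10)

    def training_step(self, batch, batch_idx):
        # return the per-process outputs instead of a per-process loss
        x, y = batch
        return {'logits': self.net(x), 'y': y}

    def training_end(self, outputs):
        # hypothetical hook from this proposal: receives the training_step
        # outputs of every DP/ddp2 process (assumed here to be a list of
        # dicts), so the loss can be computed over the full batch
        logits = torch.cat([o['logits'] for o in outputs])
        y = torch.cat([o['y'] for o in outputs])
        # e.g. normalize across ALL examples before taking the loss
        return F.cross_entropy(logits, y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```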