
Add a way to operate on all outputs from training_step #446

@williamFalcon

Description


I realized the ddp2 implementation I put together doesn't allow the user to operate on the outputs of all DP processes.

For instance, you calculate some logits on each process, and the current ddp2 forces the loss to be calculated in each process individually. However, if you wanted to, say, normalize across all examples in the batch, you'd need to somehow share the output of each process.

Currently:

total_loss = []
for process in processes:
    # training step runs independently on each process
    out = model(x)
    loss = loss_fn(out)   # loss is computed per process
    total_loss.append(loss)

loss = torch.stack(total_loss).mean()

Proposed:

outs = []
for process in processes:
    # training step runs on each process
    out = model(x)
    outs.append(out)

# allow training_end to operate (softmax, for instance) on ALL outputs
loss = model.training_end(outs)
loss.backward()
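
Under the hood this means the per-process outputs have to end up on one device before training_end can run. A minimal sketch of what that could look like, purely illustrative (the run_training_batch helper, the replicas/sub_batches arguments and the sequential loop are not the actual trainer code, and each output is assumed to be a tensor):

import torch

def run_training_batch(model, replicas, sub_batches, output_device=0):
    # run the training step on every DP replica and keep every raw output
    # instead of reducing each one to a loss
    outs = [replica(x) for replica, x in zip(replicas, sub_batches)]

    # move every per-device output onto one device so training_end can
    # operate on all of them together
    outs = [out.to(torch.device('cuda', output_device)) for out in outs]

    # training_end sees the outputs from ALL replicas and returns one loss
    loss = model.training_end(outs)
    loss.backward()
    return loss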

The implication is adding an optional:

def training_end(...):

to the model which, when defined, gives you all the outputs of training_step.
If you don't need anything advanced like this, you return a loss from training_step and don't implement training_end. If you need more advanced control, you implement training_end.
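
To make the split concrete, here is a minimal sketch of what a user's module could look like with such a hook. Only the training_end name follows the proposal above; the class name, the dict returned from training_step, and the batch-wide softmax loss are illustrative assumptions:

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class ExampleSystem(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.net(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        # return raw outputs instead of a loss; each DP process produces one of these
        x, y = batch
        return {'logits': self.forward(x), 'y': y}

    def training_end(self, outputs):
        # outputs holds the training_step results from ALL processes, so the
        # softmax here normalizes over the full batch rather than per process
        logits = torch.cat([o['logits'] for o in outputs], dim=0)
        y = torch.cat([o['y'] for o in outputs], dim=0)
        return F.nll_loss(F.log_softmax(logits, dim=-1), y)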

@tullie @neggert Any thoughts on this?
