# Custom training loops

How can we compose changes to the training loop together?

As an example, we want to be able to do data-parallel training or GAN training, but also data-parallel GAN training.

I want to share some thoughts on this and propose solutions that would make it possible in *FastAI.jl*.

Currently, custom training behavior can be implemented by subtyping `FitPhase` and implementing `fitbatchphase!(learner, phase::MyPhase)`. However, this approach is not composable: it does not let you combine a `DataParallelTrainingPhase` with a `GANTrainingPhase`.



To illustrate this, the examples below show how the body of `fitbatchphase!` changes for different kinds of training. For simplicity, they omit callbacks and state handling.


#### Regular training (1)

In [8]:
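# inputs: model, xs, ys, lossfn, optim, params
# assumes Flux's implicit-parameter API
using Flux
using Flux.Optimise: update!

grads = gradient(params) do
    return lossfn(model(xs), ys)
end
# `grads` is a single gradient object, so there is nothing to sum here
update!(optim, params, grads)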


#### Data-parallel training on CPU (2)

In [None]:
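# inputs: model, xs, ys, lossfn, optim, params
# `scatter` is a hypothetical helper that splits the batch into slices

# split the batch into one slice per thread (naive pseudocode)
slices = collect(scatter((xs, ys), Threads.nthreads()))
grads = Vector{Grads}(undef, length(slices))

# compute the gradients of all slices in parallel
Threads.@threads for i in eachindex(slices)
    xs_, ys_ = slices[i]
    grads[i] = gradient(params) do
        return lossfn(model(xs_), ys_)
    end
end
# sum the per-slice gradients (assumes `Grads` supports broadcasting)
gs = reduce((a, b) -> a .+ b, grads)

# do parameter update with summed gradients
update!(optim, params, gs)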
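
The `scatter` helper used above is not defined anywhere in these notes. A minimal sketch of what it could look like, assuming that observations lie along the last dimension of every array and that a shorter final slice is acceptable (the name and signature are hypothetical, not an existing FastAI.jl function):

```julia
# Hypothetical helper: split every array in `data` into at most `n` slices
# along its last (observation) dimension. The last slice may be shorter.
function scatter(data::Tuple, n::Integer)
    nobs = size(first(data), ndims(first(data)))
    ranges = Iterators.partition(1:nobs, cld(nobs, n))
    # `selectdim` returns views, so no data is copied
    return [map(x -> selectdim(x, ndims(x), r), data) for r in ranges]
end

xs = reshape(collect(1.0:12.0), 2, 6)   # 6 observations with 2 features each
ys = collect(1:6)                       # 6 targets
slices = scatter((xs, ys), 3)           # 3 tuples `(xs_, ys_)`
```

Because the result is an indexable `Vector`, it can be iterated over with `Threads.@threads for i in eachindex(slices)`.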

#### GAN training (3)

In [None]:
# inputs: mgen, mcrit, paramsgen, paramscrit, xs_true, lossfngen, lossfncrit, optim, batchsize

# critic step: train on a batch of real and a batch of generated samples
xs_fake = mgen(batchsize)
xs = cat(xs_true, xs_fake; dims = ndims(xs_true))
ys = onehotbatch(vcat(trues(batchsize), falses(batchsize)), [true, false])

grads = gradient(paramscrit) do
    return lossfncrit(mcrit(xs), ys)
end
update!(optim, paramscrit, grads)

# generator step: maximize the critic's loss on generated samples
grads = gradient(paramsgen) do
    xs_fake = mgen(batchsize)
    ys_fake = onehotbatch(falses(batchsize), [true, false])
    return -lossfncrit(mcrit(xs_fake), ys_fake)
end
update!(optim, paramsgen, grads)

### Solutions

I have found two approaches to deal with this. **Both move execution logic out of `fitbatchphase!`, so that it composes with custom `Phase`s** like `GANTrainingPhase` that change the *semantics* of the training loop.

Extensions to the training loop fall into two groups: those that change the *execution* (e.g. parallel and distributed CPU and GPU training) and those that change the *semantics* (e.g. GAN training).

The proposed solutions make the assumption that different *semantics* don't need to be composed, but should be composable with different execution contexts.

#### (S1) Abstract gradient step (and possibly others) out

Modifications to the execution of the training loop could be implemented by wrapping individual steps, such as the gradient calculation, in an execution context.

In the example below, `gradientphase` could dispatch to the regular gradient calculation in (1) or the data-parallel approach (2) depending on `executionctx`.

This would mean that only *semantic* changes to the training loop would use overloading of `fitbatchphase!` with a custom `FitPhase`. Changes to the *execution* work by dispatching on execution contexts, e.g. `gradientphase(::Linear, ...)` or `gradientphase(::DataParallel, ...)`.

In [None]:
# gradientphase passes (possibly modified) state to a closure.
# In the case executionctx::DataParallel, xs_ and ys_ will be slices
# of the batch.
grads = gradientphase(executionctx, model, params, xs, ys) do model_, params_, xs_, ys_
    return lossfn(model_(xs_), ys_)
end
update!(optim, params, grads)
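
To make the dispatch idea concrete, here is a minimal, self-contained sketch. It makes several assumptions: the context types `Linear` and `DataParallel` and the exact signature of `gradientphase` are hypothetical, the "model" is just a weight vector `w`, and a finite-difference helper (`fdgradient`) stands in for automatic differentiation so that the example runs without any packages.

```julia
using Base.Threads: @threads

# Hypothetical execution contexts (names are assumptions, not FastAI.jl API)
abstract type ExecutionContext end
struct Linear <: ExecutionContext end
struct DataParallel <: ExecutionContext
    nslices::Int
end

# Stand-in for automatic differentiation: central finite differences
function fdgradient(loss, w; h = 1e-6)
    return map(eachindex(w)) do i
        e = zeros(length(w))
        e[i] = h
        (loss(w .+ e) - loss(w .- e)) / (2 * h)
    end
end

# Regular execution: one gradient computation over the whole batch
gradientphase(lossfn, ::Linear, w, xs, ys) =
    fdgradient(w_ -> lossfn(w_, xs, ys), w)

# Data-parallel execution: one gradient per slice of the batch, computed
# on separate threads and then summed
function gradientphase(lossfn, ctx::DataParallel, w, xs, ys)
    n = size(xs, 2)
    ranges = collect(Iterators.partition(1:n, cld(n, ctx.nslices)))
    grads = Vector{Vector{Float64}}(undef, length(ranges))
    @threads for i in eachindex(ranges)
        r = ranges[i]
        grads[i] = fdgradient(w_ -> lossfn(w_, xs[:, r], ys[r]), w)
    end
    return sum(grads)
end

# Toy data: 8 observations (columns) with 3 features each
w  = [0.5, -1.0, 2.0]
xs = reshape(collect(1.0:24.0), 3, 8) ./ 10
ys = collect(1.0:8.0)

# The closure is identical for both contexts; only the execution differs
sqerr(w_, xs_, ys_) = sum(abs2, xs_' * w_ .- ys_)
grads_linear   = gradientphase(sqerr, Linear(), w, xs, ys)
grads_parallel = gradientphase(DataParallel(4), w, xs, ys) do w_, xs_, ys_
    sqerr(w_, xs_, ys_)
end
```

The two results agree here only because the loss is a sum over observations. With a mean-reduced loss, the per-slice gradients would have to be averaged (weighted by slice size) instead of summed.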

**Advantages**

- implementation definitely doable

**Disadvantages**

- implementation dependent on requirements, i.e. unsure which pieces of the training step need to be overloadable and which state needs to be passed to closures.

#### (S2) Wrapper for `model`

The idea is to wrap the `model` in an execution context, e.g. `DataParallel(model)`. The wrapper is then responsible for exhibiting the correct behavior on the forward and backward pass. This is [what PyTorch does](https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html?highlight=parallel#torch.nn.DataParallel).

No changes to the training loop would be needed. The implementation of the forward pass should be straightforward and similar to sketch (2) above. However, I'm not sure how to ensure that the backward pass is also computed in parallel: can custom gradient definitions include multi-threaded code, and what happens with the loss function, which is not wrapped?

In [None]:
# before training
model = DataParallel(model)

# training step doesn't change at all
grads = gradient(params) do
    return lossfn(model(xs), ys)
end
update!(optim, params, grads)
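
A rough sketch of what the forward pass of such a wrapper might look like, assuming the batch is a matrix with one observation per column. The type is called `DataParallelModel` here only to keep the example self-contained; it and its `nslices` field are hypothetical. The sketch covers the forward pass only and leaves the open question about the backward pass untouched.

```julia
using Base.Threads: @threads

# Hypothetical wrapper standing in for the `DataParallel(model)` idea
struct DataParallelModel{M}
    model::M
    nslices::Int
end

# Forward pass: split the batch along the observation dimension, run the
# wrapped model on every slice in parallel, and concatenate the outputs
function (dp::DataParallelModel)(xs::AbstractMatrix)
    n = size(xs, 2)
    ranges = collect(Iterators.partition(1:n, cld(n, dp.nslices)))
    outs = Vector{Any}(undef, length(ranges))
    @threads for i in eachindex(ranges)
        outs[i] = dp.model(xs[:, ranges[i]])
    end
    return reduce(hcat, outs)
end

# Toy model: a fixed linear map applied to every column of the batch
W = [1.0 2.0 3.0; 0.0 -1.0 1.0]
model = xs -> W * xs
wrapped = DataParallelModel(model, 3)

xs = rand(3, 10)
out = wrapped(xs)   # same call syntax as the unwrapped model
```

This only works as a drop-in replacement for models that treat each observation independently; layers that compute statistics across the batch would see different inputs per slice.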

**Advantages**

- no changes needed to state and event handling

**Disadvantages**

- not sure whether such a simple API can be implemented for all scenarios
- somewhat inelegant; the model is no longer a pure function