# Training Utilities

This notebook presents a design for training utilities.

In [7]:
import TensorFlow

## Training example data structure

A training example data structure consists of training data and a label.

In [8]:
/// One training example: a piece of training `data` paired with its `label`.
///
/// Whether a value represents a single sample or a whole batch is decided by the
/// concrete `Data` and `Label` types it is instantiated with.
public struct Example<Data: Differentiable, Label> {
    /// The training input.
    public var data: Data
    /// The label associated with `data`.
    public var label: Label

    /// Creates an example from the given training input and label.
    public init(data input: Data, label target: Label) {
        data = input
        label = target
    }
}

## Trainer (learner)

A `Trainer` is responsible for initializing a model and training it on a given dataset. It can be thought of as both the controller and the environment for model training.

### Core properties

`Trainer` contains three kinds of properties:
* Core units: `model`, `dataset`, `optimizer`, `lossFunction`
* Training state: `epochCount`, `currentEpoch`, `currentGradient`, `currentLoss`
* Event handlers: User-configurable callback functions that are called on various events during model training.

In [16]:
/// Control signals that an event handler can throw to steer the training loop.
///
/// - `skipEpoch`: signals that the rest of the current epoch should be skipped.
/// - `skipBatch`: signals that the current batch should be skipped.
/// - `stop`: signals that training should stop.
public enum TrainerAction: Error {
    case skipEpoch, skipBatch, stop
}

In [9]:
/// Trains a model on a dataset. The trainer owns the model, its optimizer, and a loss
/// function, and it reports progress through optional event handlers.
// NOTE: When TF-421 is fixed, make `Label` not constrained to `Differentiable`.
public final class Trainer<Dataset: Collection, Label: Differentiable,
                           Loss: Differentiable & BinaryFloatingPoint,
                           Optimizer: TensorFlow.Optimizer & AnyObject>
    where Dataset.Element == Example<Optimizer.Model.Input, Label>,
          Optimizer.Scalar: Differentiable,
          Loss == Loss.CotangentVector
{
    // MARK: - Type aliases

    public typealias Model = Optimizer.Model
    public typealias Data = Model.Input
    public typealias Variables = Model.AllDifferentiableVariables
    // NOTE: When TF-421 is fixed, replace with:
    //   public typealias LossFunction = @differentiable (Model.Output, @nondiff Label) -> Loss
    /// A differentiable function that maps a model output and a label to a loss.
    public typealias LossFunction = @differentiable (Model.Output, Label) -> Loss
    /// A hook that receives the trainer itself. It may throw to interrupt training.
    public typealias EventHandler = (Trainer) throws -> Void

    // MARK: - Core units

    /// The model under training.
    public var model: Model
    /// The data the model is trained on.
    public let dataset: Dataset
    /// Updates model parameters along gradient vectors.
    public var optimizer: Optimizer
    /// Computes a loss value from a prediction and its label.
    public var lossFunction: LossFunction

    // MARK: - Training state (read-only outside this type)

    /// The total number of epochs.
    public private(set) var epochCount: Int = 0
    /// The index of the current epoch.
    public private(set) var currentEpoch: Int = 0
    /// The current gradient.
    public private(set) var currentGradient: Model.CotangentVector = .zero
    /// The current loss.
    public private(set) var currentLoss: Loss = .zero

    // MARK: - Event handlers

    /// Invoked when model fitting begins.
    public var fittingStartHandler: EventHandler?
    /// Invoked when model fitting finishes.
    public var fittingCompletionHandler: EventHandler?
    /// Invoked when an epoch begins.
    public var epochStartHandler: EventHandler?
    /// Invoked when an epoch finishes.
    public var epochCompletionHandler: EventHandler?
    /// Invoked when model validation begins.
    // NOTE(review): not called by the `fit(epochCount:)` loop visible in this notebook — confirm intended use.
    public var validationStartHandler: EventHandler?
    /// Invoked when training on a batch begins.
    public var batchStartHandler: EventHandler?
    /// Invoked when training on a batch finishes.
    public var batchCompletionHandler: EventHandler?
    /// Invoked once a new loss value has been computed.
    public var newLossHandler: EventHandler?
    /// Invoked once a new gradient has been computed.
    public var newGradientHandler: EventHandler?
    /// Invoked after the optimizer has applied an update.
    public var optimizerUpdateCompletionHandler: EventHandler?

    /// The context used when applying layers. It is fixed to the training learning phase.
    private let context = Context(learningPhase: .training)

    /// Creates a trainer and builds its model with the supplied closure.
    ///
    /// - Parameters:
    ///   - dataset: The dataset to train on.
    ///   - lossFunction: The loss function.
    ///   - optimizer: The optimizer that updates model parameters along gradient vectors.
    ///   - makeModel: A closure that produces the model to be trained.
    public init(dataset: Dataset,
                lossFunction: @escaping LossFunction,
                optimizer: Optimizer,
                initializingWith makeModel: () -> Model) {
        self.model = makeModel()
        self.dataset = dataset
        self.lossFunction = lossFunction
        self.optimizer = optimizer
    }
}

### Methods

The core method on `Trainer` is `fit(epochCount:)`.

In [19]:
extension Trainer {
    /// Runs one optimization step on a single batch.
    ///
    /// This method computes the loss and its gradient with respect to the model, and
    /// stores them in `currentLoss` and `currentGradient` before notifying the
    /// corresponding handlers. It then lets the optimizer update the model's variables.
    ///
    /// - Parameter batch: The batch of input data and labels to be trained on.
    /// - Throws: Any error thrown by the loss, gradient, or optimizer-update handlers.
    private func train(on batch: Dataset.Element) throws {
        // NOTE: When the "subset of parameters" bug is fixed, replace with:
        //   let (loss, grad) = model.valueWithGradient { model -> Loss in
        //      let y = model.applied(to: batch.data, in: context)
        //      return lossFunction(y, batch.label)
        //   }
        let (loss, (grad, _)) = model.valueWithGradient(at: batch.label) { (model, label) -> Loss in
            let y = model.applied(to: batch.data, in: context)
            return lossFunction(y, label)
        }
        // NOTE: Put the handler calls inside `valueWithGradient`'s trailing closure when
        // differentiation supports throwing functions.
        // Record the new values *before* notifying handlers. Otherwise handlers would
        // only ever observe the stale initial `.zero` state.
        currentLoss = loss
        try newLossHandler?(self)
        currentGradient = grad
        try newGradientHandler?(self)
        optimizer.update(&model.allDifferentiableVariables, along: grad)
        try optimizerUpdateCompletionHandler?(self)
    }

    /// Runs a single epoch over `dataset`.
    ///
    /// If `batchStartHandler` or the training step throws `TrainerAction.skipBatch`,
    /// only the current batch is abandoned: the loop moves on to the next batch, and
    /// `batchCompletionHandler` is not called for the skipped one.
    ///
    /// - Parameter i: The zero-based epoch index, stored in `currentEpoch`.
    /// - Throws: Any other error thrown by an event handler.
    private func performEpoch(_ i: Int) throws {
        currentEpoch = i
        try epochStartHandler?(self)
        for batch in dataset {
            do {
                try batchStartHandler?(self)
                try train(on: batch)
            } catch TrainerAction.skipBatch {
                // `continue`, not `break`: an unlabeled `break` here would exit the
                // batch loop and silently end the whole epoch.
                continue
            }
            try batchCompletionHandler?(self)
        }
        try epochCompletionHandler?(self)
    }

    /// Fits the model's parameters by training for `epochCount` epochs.
    ///
    /// Event handlers can steer training by throwing a `TrainerAction`:
    /// - `.skipEpoch` abandons the rest of the current epoch and continues with the
    ///   next one.
    /// - `.stop` ends fitting immediately, without calling `fittingCompletionHandler`.
    ///
    /// - Parameter epochCount: The number of epochs that will be run. Must be non-negative.
    /// - Throws: Any error thrown by an event handler, except the `TrainerAction`
    ///   signals that are handled here.
    public func fit(epochCount: Int) throws {
        precondition(epochCount >= 0, "epochCount must be non-negative")
        self.epochCount = epochCount
        self.currentEpoch = 0
        do {
            try fittingStartHandler?(self)
            for i in 0..<epochCount {
                do {
                    try performEpoch(i)
                } catch TrainerAction.skipEpoch {
                    // Skip to the next epoch. An unlabeled `break` would end training entirely.
                    continue
                }
            }
            try fittingCompletionHandler?(self)
        } catch TrainerAction.stop {
            return
        }
    }
}

For now, `Trainer` is also available under the name `Learner`.

In [11]:
/// Alternate name for `Trainer`, provided for the short term.
public typealias Learner = Trainer

## Examples

### Simple training loop