##### Copyright 2020 The TensorFlow Authors. [Licensed under the Apache License, Version 2.0](#scrollTo=ByZjmtFgB_Y5).

In [None]:
%install '.package(url: "https://github.com/tensorflow/swift-models", .branch("master"))' Datasets ImageClassificationModels TrainingLoop
print("\u{001B}[2J")

In [None]:
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

<table class="tfo-notebook-buttons" align="left">
 <td>
  <a target="_blank" href="https://www.tensorflow.org/swift/tutorials/simple_model_training"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
 </td>
 <td>
  <a target="_blank" href="https://colab.research.google.com/github/tensorflow/swift/blob/master/docs/site/tutorials/simple_model_training.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
 </td>
 <td>
  <a target="_blank" href="https://github.com/tensorflow/swift/blob/master/docs/site/tutorials/simple_model_training.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
 </td>
</table>

## Training an image classification model

Let's look at how to set up and train an image classification model using the models, datasets, and general-purpose training loop provided by [the swift-models repository](https://github.com/tensorflow/swift-models).

In this example, we'll be using the simple LeNet-5 model, the MNIST handwritten digit classification dataset, and a callback-based training loop.

First, we'll import the necessary modules:

In [None]:
import Datasets
import ImageClassificationModels
import TensorFlow
import TrainingLoop

Then we'll specify the training parameters:

In [None]:
let epochCount = 12
let batchSize = 128
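
To get a feel for what these values imply, here is a rough back-of-the-envelope sketch in plain Swift. It assumes the standard MNIST training split of 60,000 images. Whether the final partial batch is kept or dropped depends on the dataset implementation, so the sketch reports full batches and the leftover separately. All names below are illustrative and are not part of swift-models.

```swift
// Illustrative arithmetic only; these names are not part of swift-models.
let sketchEpochCount = 12
let sketchBatchSize = 128
let mnistTrainingImageCount = 60_000  // assumed size of the standard MNIST training split

// Number of full batches per epoch, and the size of the leftover partial batch.
let fullBatchesPerEpoch = mnistTrainingImageCount / sketchBatchSize
let leftoverImages = mnistTrainingImageCount % sketchBatchSize

// Total optimizer steps across all epochs if only full batches are used.
let totalFullBatchSteps = fullBatchesPerEpoch * sketchEpochCount

print("Full batches per epoch: \(fullBatchesPerEpoch)")
print("Images left over per epoch: \(leftoverImages)")
print("Optimizer steps over \(sketchEpochCount) epochs: \(totalFullBatchSteps)")
```

Under these assumptions, each epoch has 468 full batches with 96 images left over, for 5,616 full-batch steps across 12 epochs.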

Training can be performed using either the default eager-mode runtime or the XLA-based X10 backend. For performance, and to support TPUs, we'll use an XLA-based X10 device:

In [None]:
// The following is a workaround needed until X10 can set log levels and memory growth parameters.
let _ = _ExecutionContext.global

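// Select the default XLA-backed (X10) device.
let device = Device.defaultXLA
// Evaluating `device` on its own displays it in the notebook output.
device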

Then we'll download and configure the MNIST dataset:

In [None]:
let dataset = MNIST(batchSize: batchSize, on: device)

Next, we'll create the LeNet-5 model, along with an SGD optimizer:

In [None]:
var model = LeNet()
var optimizer = SGD(for: model, learningRate: 0.1)
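
As a rough illustration of what a plain stochastic gradient descent step does with a learning rate like the `0.1` above, here is a framework-free sketch. It assumes the simplest form of the update (no momentum, weight decay, or Nesterov acceleration), and `sgdStep` is a hypothetical helper written for this example, not the `SGD` type used in the cell above.

```swift
// A toy, framework-free version of a plain SGD step (no momentum or weight decay).
// `sgdStep` is a hypothetical helper, not the swift-apis `SGD` type.
func sgdStep(parameters: [Float], gradients: [Float], learningRate: Float) -> [Float] {
    precondition(parameters.count == gradients.count, "One gradient per parameter is required")
    // Move each parameter a small step against its gradient.
    return zip(parameters, gradients).map { $0 - learningRate * $1 }
}

// Minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
var toyWeights: [Float] = [0]
for _ in 0..<50 {
    let toyGradients = toyWeights.map { 2 * ($0 - 3) }
    toyWeights = sgdStep(parameters: toyWeights, gradients: toyGradients, learningRate: 0.1)
}
print(toyWeights[0])  // approaches the minimizer at 3
```

Each step shrinks the distance to the minimum by a constant factor here, which is why a fixed learning rate is enough for this toy problem.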

The general-purpose training loop uses a callback mechanism to respond to actions and customize model training. In this example, we'll use an animated progress bar to display training status and simple statistics:

In [None]:
let trainingProgress = TrainingProgress()
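
To make the callback idea concrete, here is a minimal plain-Swift sketch of a loop that emits events to registered callbacks. Every name in it (`ToyEvent`, `ToyCallback`, `ToyLoop`, `AverageLossReporter`) is hypothetical and only loosely mirrors the concept; the actual event set and callback signature in swift-models differ.

```swift
// A simplified stand-in for a callback-driven loop. All names here are hypothetical
// and are not the swift-models TrainingLoop API.
enum ToyEvent {
    case epochStart(Int)
    case batchEnd(loss: Float)
    case epochEnd(Int)
}

typealias ToyCallback = (ToyEvent) -> Void

struct ToyLoop {
    var callbacks: [ToyCallback]

    // Notify every registered callback about an event.
    func emit(_ event: ToyEvent) {
        for callback in callbacks { callback(event) }
    }

    // Walk through epochs and batches, emitting events at each stage.
    // `batchLosses` stands in for the losses a real training step would produce.
    func run(epochs: Int, batchLosses: [Float]) {
        for index in 0..<epochs {
            let epoch = index + 1
            emit(.epochStart(epoch))
            for loss in batchLosses { emit(.batchEnd(loss: loss)) }
            emit(.epochEnd(epoch))
        }
    }
}

// A callback that accumulates the losses it sees and records an average per epoch.
final class AverageLossReporter {
    private var total: Float = 0
    private var count = 0
    private(set) var reports: [String] = []

    func update(_ event: ToyEvent) {
        switch event {
        case .epochStart:
            total = 0
            count = 0
        case .batchEnd(let loss):
            total += loss
            count += 1
        case .epochEnd(let epoch):
            reports.append("Epoch \(epoch): average loss \(total / Float(count))")
        }
    }
}

let reporter = AverageLossReporter()
let toyLoop = ToyLoop(callbacks: [reporter.update])
toyLoop.run(epochs: 2, batchLosses: [0.5, 0.25, 0.75])
reporter.reports.forEach { print($0) }
```

The loop itself knows nothing about reporting; it only announces what is happening. That separation is what lets a progress display, or any other behavior, be plugged in by passing a method such as `reporter.update` in the callback list.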

The training loop takes the training and validation datasets, our optimizer, a loss function, and our custom callbacks. From these, it automatically handles iterating over epochs, shuffling batches, and placing the model and optimizer on the right device:

In [None]:
var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  callbacks: [trainingProgress.update])
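
For intuition about the loss function passed to the loop, here is a small plain-Swift sketch of softmax cross-entropy for a single example. It assumes the label is given as a class index and skips batching and averaging entirely. `toySoftmaxCrossEntropy` is a hypothetical helper for illustration, not the library's `softmaxCrossEntropy`.

```swift
import Foundation

// Hypothetical helper: softmax cross-entropy for one example, given raw scores (logits)
// and the index of the correct class. Not the swift-apis implementation.
func toySoftmaxCrossEntropy(logits: [Double], label: Int) -> Double {
    precondition(logits.indices.contains(label), "Label must index into logits")
    // Subtract the maximum logit before exponentiating for numerical stability.
    let maxLogit = logits.max()!
    let logSumExp = maxLogit + log(logits.map { exp($0 - maxLogit) }.reduce(0, +))
    // The loss is the negative log-probability assigned to the correct class.
    return logSumExp - logits[label]
}

// Ten equally scored classes: the prediction is uniform over the digits.
let uniformLoss = toySoftmaxCrossEntropy(logits: [Double](repeating: 0, count: 10), label: 3)
// A strong score for the correct class: the loss is close to zero.
let confidentLoss = toySoftmaxCrossEntropy(logits: [0, 0, 0, 8, 0, 0, 0, 0, 0, 0], label: 3)
print(uniformLoss, confidentLoss)
```

With ten equally scored classes the loss equals ln 10, about 2.30, which is what a classifier that assigns equal probability to every digit would produce. As the score for the correct class grows relative to the others, the loss approaches zero.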

Finally, we can train our model using the loop:

In [None]:
try! trainingLoop.fit(&model, epochs: epochCount, on: device)

Note that loss decreases and accuracy increases over time for both training and validation, as we'd expect. If you executed this notebook on a GPU- or TPU-backed instance, the training should have run transparently on an accelerator.

`model` now holds parameters that have been trained against the MNIST dataset and can be used for additional work or serialized to disk.