<em><sub>This page is available as an executable or viewable <strong>Jupyter Notebook</strong></sub></em>
<br/><br/>
<a href="https://mybinder.org/v2/gh/avan1235/KotlinDL/notebooks?filepath=docs%2Ftraining_a_model.ipynb"
   target="_parent">
   <img align="left"
        src="https://mybinder.org/badge_logo.svg"
        height="20">
</a>
<a href="https://nbviewer.jupyter.org/github/avan1235/KotlinDL/blob/notebooks/docs/training_a_model.ipynb"
   target="_parent">
   <img align="right"
        src="https://raw.githubusercontent.com/jupyter/design/master/logos/Badges/nbviewer_badge.svg"
        height="20">
</a>
<br/><br/>

In [1]:
@file:DependsOn("org.jetbrains.kotlinx:kotlin-deeplearning-api:0.2.0")

# Training a model with KotlinDL

In the [first tutorial](create_your_first_nn.ipynb), we created a basic neural network built from dense layers only. Let's recall that code so that we can use it to actually train the model:

In [2]:
import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.reshaping.Flatten
import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense


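val model = Sequential.of(
    Input(28, 28, 1),  // 28x28 grayscale images (one channel)
    Flatten(),         // flatten each image into a 784-element vector
    Dense(300),
    Dense(100),
    Dense(10)          // one output per Fashion MNIST class
)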

Before you can use data, typically some preprocessing is required.

In this case, it's minimal – all the images are already the same size and are grayscale. 
With the built-in functionality, we can convert the [Fashion MNIST image archives](https://github.com/zalandoresearch/fashion-mnist#get-the-data) into a dataset object that we can use for model training.    

In [3]:
import org.jetbrains.kotlinx.dl.dataset.fashionMnist


val (train, test) = fashionMnist()

Extracting 60000 images of 28x28 from /workspace/cache/datasets/fashionmnist/train-images-idx3-ubyte.gz
Extracting 60000 labels from /workspace/cache/datasets/fashionmnist/train-labels-idx1-ubyte.gz
Extracting 10000 images of 28x28 from /workspace/cache/datasets/fashionmnist/t10k-images-idx3-ubyte.gz
Extracting 10000 labels from /workspace/cache/datasets/fashionmnist/t10k-labels-idx1-ubyte.gz


Note that the data comes split into two sets.
The `test` set stays untouched until we are satisfied with the model and want to confirm its performance on unseen data.
The `train` set is the one we use during the training process.
Usually, the provided `train` set is further split into two parts: the actual training data (which contains most of the samples) and the validation data (which serves as a temporary test set during training, so that the `test` set stays untouched).
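
The train/validation split described above can be sketched in plain Kotlin. This is only an illustration of the idea: `splitByRatio` is a hypothetical helper, not a KotlinDL function, and the list of indices stands in for the 60000 training samples. In practice you would look for a split method on the dataset object itself (check the KotlinDL API reference for your version).

```kotlin
import kotlin.math.roundToInt

// Illustrative helper, not part of KotlinDL: cuts a list into two consecutive parts.
fun <T> splitByRatio(items: List<T>, trainRatio: Double): Pair<List<T>, List<T>> {
    require(trainRatio > 0.0 && trainRatio < 1.0) { "trainRatio must be in (0, 1)" }
    val trainSize = (items.size * trainRatio).roundToInt()
    return items.take(trainSize) to items.drop(trainSize)
}

fun main() {
    // Stand-in for the 60000 training samples: just their indices.
    val sampleIndices = (0 until 60_000).toList()
    val (training, validation) = splitByRatio(sampleIndices, trainRatio = 0.95)
    println("training: ${training.size}, validation: ${validation.size}")
    // prints training: 57000, validation: 3000
}
```

The 0.95 ratio is an arbitrary choice for the example. Real datasets are usually shuffled before splitting, which this sketch leaves out.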

Now everything is ready to train the model. Use the `fit()` method for this:

In [4]:
import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
import org.jetbrains.kotlinx.dl.api.core.loss.Losses
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics

model.compile(
    optimizer = Adam(),
    loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
    metric = Metrics.ACCURACY
)

val trainHistory = model.fit(
    dataset = train,
    epochs = 10,
    batchSize = 100
)

Here are some important parameters that we need to pass to the `fit()` method:
* `epochs` - The number of full passes over the training data. One epoch is one complete iteration over the dataset.
* `batchSize` - The number of examples used for each update of the model's parameters (aka weights and biases).
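
To make these two parameters concrete, here is a small sketch of the arithmetic behind them. It assumes that every batch triggers exactly one parameter update and that a trailing partial batch, if there is one, is also used. With the values from this tutorial the data divides evenly, so that second assumption does not affect the result. The helper names `updatesPerEpoch` and `totalUpdates` are illustrative and not part of KotlinDL.

```kotlin
// Illustrative helpers, not part of KotlinDL.

/** Number of batches (parameter updates) in one pass over the data, counting a final partial batch. */
fun updatesPerEpoch(datasetSize: Int, batchSize: Int): Int {
    require(datasetSize > 0 && batchSize > 0) { "Sizes must be positive" }
    return (datasetSize + batchSize - 1) / batchSize  // ceiling division
}

/** Total number of parameter updates over the whole training run. */
fun totalUpdates(datasetSize: Int, batchSize: Int, epochs: Int): Int =
    updatesPerEpoch(datasetSize, batchSize) * epochs

fun main() {
    // Values from this tutorial: 60000 training images, batchSize = 100, epochs = 10.
    println(updatesPerEpoch(60_000, 100))   // prints 600
    println(totalUpdates(60_000, 100, 10))  // prints 6000
}
```

A smaller `batchSize` means more, noisier updates per epoch, while a larger one means fewer updates that each see more examples.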

After the model has been trained, it's important to evaluate its performance on data it has not seen during training, so that we can check how well it generalizes. Here we use the `test` set:

In [5]:
val accuracy = model.evaluate(dataset = test, batchSize = 100).metrics[Metrics.ACCURACY]

println("Accuracy: $accuracy")

Accuracy: 0.8832000494003296


---
**NOTE**

Training is nondeterministic, so your accuracy value may differ slightly.

---
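
For intuition, accuracy can be read as the fraction of samples whose highest-scoring output matches the true label. The sketch below assumes that definition and uses made-up scores. `argMax` and `accuracy` are illustrative helpers, not KotlinDL functions, and the library computes the metric internally. Under this reading, an accuracy of about 0.8832 on the 10000 test images corresponds to roughly 8832 correct predictions.

```kotlin
// Illustrative helpers, not part of KotlinDL.

/** Index of the largest value, i.e. the predicted class for one row of model outputs. */
fun argMax(scores: FloatArray): Int {
    require(scores.isNotEmpty()) { "scores must not be empty" }
    var best = 0
    for (i in scores.indices) if (scores[i] > scores[best]) best = i
    return best
}

/** Fraction of samples whose predicted class equals the true label. */
fun accuracy(outputs: List<FloatArray>, labels: List<Int>): Double {
    require(outputs.size == labels.size && outputs.isNotEmpty()) { "need one label per output" }
    val correct = outputs.indices.count { argMax(outputs[it]) == labels[it] }
    return correct.toDouble() / outputs.size
}

fun main() {
    // Four made-up samples with three classes each (hypothetical numbers).
    val outputs = listOf(
        floatArrayOf(2.0f, 0.1f, 0.3f),  // predicted class 0
        floatArrayOf(0.2f, 1.5f, 0.1f),  // predicted class 1
        floatArrayOf(0.9f, 0.8f, 0.1f),  // predicted class 0
        floatArrayOf(0.1f, 0.2f, 3.0f)   // predicted class 2
    )
    val labels = listOf(0, 1, 2, 2)
    println(accuracy(outputs, labels))  // prints 0.75
}
```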

When we are happy with the model's evaluation metric, we can save the model for future use in a production environment and then close it to release the resources it holds.

In [6]:
import org.jetbrains.kotlinx.dl.api.core.WritingMode
import java.io.File


model.save(File("src/model/my_first_model"), writingMode = WritingMode.OVERRIDE)
model.close()
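
Because the model is closed explicitly, it behaves like any other closeable resource. On the JVM, Kotlin's standard `use` function is a common alternative to calling `close()` by hand: it closes the resource when the block ends, even if the block throws. The sketch below shows the pattern with a hypothetical `FakeModel` class. It is a stand-in for illustration only, and whether your KotlinDL model type supports `use` depends on it implementing `AutoCloseable`.

```kotlin
// Hypothetical stand-in for a model; any AutoCloseable works with `use`.
class FakeModel : AutoCloseable {
    var closed = false
        private set

    fun save(path: String): String = "saved to $path"

    override fun close() {
        closed = true
    }
}

fun saveAndClose(model: FakeModel, path: String): String =
    // `use` calls close() when the block finishes, even if it throws.
    model.use { it.save(path) }

fun main() {
    val model = FakeModel()
    println(saveAndClose(model, "src/model/my_first_model"))  // prints saved to src/model/my_first_model
    println(model.closed)                                      // prints true
}
```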

And just like that, we have trained, evaluated, and saved a deep learning model that we can now use to generate predictions (aka inference). 
In the [next tutorial](loading_trained_model_for_inference.ipynb), you'll learn how to load and use the model.