# Training a network

Let's take a detailed look at how we train a new deep learning model from pre-prepared data with dtoolAI. We're going to train a convolutional neural network (CNN) to recognise handwritten digits from the MNIST dataset.

## Loading our data

First, we'll need to load our data, using the appropriate class:

```python
from dtoolai.data import TensorDataSet
```

We can now load our data from a persistent identifier (URI):

```python
# URI of the MNIST training data, packaged as a dtool dataset
train_dataset_uri = "http://bit.ly/2uqXxrk"
train_ds = TensorDataSet(train_dataset_uri)
```
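`train_ds` can be indexed like a list, and because a `TensorDataSet` is later passed to PyTorch's `DataLoader`, it behaves as a map-style dataset. In a map-style dataset, `len()` gives the number of examples and indexing returns a `(data, label)` pair. The pure-Python sketch below shows that protocol. `ToyTensorDataSet` and its in-memory lists are hypothetical stand-ins, not dtoolAI's implementation, which loads its arrays from the dataset at the URI.

```python
class ToyTensorDataSet:
    """Hypothetical stand-in for a map-style dataset of (image, label) pairs."""

    def __init__(self, images, labels):
        # Images and labels must line up one-to-one.
        if len(images) != len(labels):
            raise ValueError("images and labels must have the same length")
        self.images = images
        self.labels = labels

    def __len__(self):
        # Number of examples; a sequential DataLoader relies on this.
        return len(self.images)

    def __getitem__(self, index):
        # Return a single (data, label) pair, as train_ds[9000] does below.
        return self.images[index], self.labels[index]


# Three fake 1x2x2 "images" (channel, height, width) with digit labels.
images = [[[[0.0, 0.5], [1.0, 0.25]]] for _ in range(3)]
labels = [7, 2, 1]
ds = ToyTensorDataSet(images, labels)
data, label = ds[1]
print(len(ds), label)  # 3 2
```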

Let's look at what we've loaded:

```python
# Each item is a (data, label) pair; look at example number 9000
data, label = train_ds[9000]
print(data.shape)
print(label)
```

The data are 1×28×28 arrays (one greyscale channel of 28×28 pixels), each representing a single digit. We can import a helper function to visualise them:

```python
from dtoolai.data import scaled_float_array_to_pil_image
```

Then visualise a single digit like this:

```python
scaled_float_array_to_pil_image(data)
```
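As its name suggests, `scaled_float_array_to_pil_image` turns an array of scaled floats into an image. The core step is mapping each float onto an 8-bit grey level. The pure-Python sketch below shows that mapping, assuming the values are scaled to the range [0, 1]. `to_grey_levels` is a hypothetical helper: the real function may scale differently, and it also builds the PIL image itself.

```python
def to_grey_levels(image_chw):
    """Map a 1xHxW nested list of floats in [0, 1] to HxW ints in [0, 255]."""
    (channel,) = image_chw  # exactly one channel for greyscale digits
    return [
        # Scale to 0-255, round, and clamp any out-of-range values.
        [min(255, max(0, round(value * 255))) for value in row]
        for row in channel
    ]


image = [[[0.0, 0.5], [1.0, 0.25]]]
print(to_grey_levels(image))  # [[0, 128], [255, 64]]
```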

## Setting parameters

Next we'll set the parameters used to train our model. We'll use dtoolAI's `Parameters` class for this, so that these parameters can be recorded automatically during training.

```python
from dtoolai.parameters import Parameters

params = Parameters(
    batch_size=128,      # training examples per mini-batch
    learning_rate=0.01,  # step size for the optimiser
    n_epochs=1           # full passes through the training data
)
```
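Collecting settings in one object means they can be written out alongside the trained model. The sketch below shows the idea with a hypothetical `SimpleParameters` class that supports both attribute access (`params.learning_rate`) and item assignment (`params['init_params'] = ...`), as this tutorial uses, and serialises itself to JSON. It illustrates the pattern only; it is not dtoolAI's `Parameters` implementation.

```python
import json


class SimpleParameters:
    """Hypothetical parameter container; not dtoolAI's Parameters class."""

    def __init__(self, **kwargs):
        self.parameter_dict = dict(kwargs)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        try:
            return self.parameter_dict[name]
        except KeyError:
            raise AttributeError(name) from None

    def __getitem__(self, key):
        return self.parameter_dict[key]

    def __setitem__(self, key, value):
        self.parameter_dict[key] = value

    def to_json(self):
        # A serialisable record of every parameter, e.g. for provenance.
        return json.dumps(self.parameter_dict, sort_keys=True)


params = SimpleParameters(batch_size=128, learning_rate=0.01, n_epochs=1)
params["init_params"] = dict(input_channels=1, input_dim=28)
print(params.learning_rate)
print(params.to_json())
```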

## Model, optimiser, loss function

In general, to train a deep learning model we need three things:

1. A suitable model architecture.
2. A loss function (which measures how far the model's predictions are from the target labels).
3. Training data.

We've already loaded the data, so now we need the model and a loss function. We'll also need to choose an optimiser, which updates the model's parameters to reduce the loss.
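To see how these pieces fit together without any deep learning machinery, here is a toy sketch in pure Python. The "model" is a single weight `w` in `y = w * x`, the loss is mean squared error, the data are three hand-made points, and the optimiser is plain full-batch gradient descent (SGD differs by using a random mini-batch for each step). Everything here is illustrative: in this tutorial the model is `GenNet`, the loss is `NLLLoss`, and PyTorch computes the gradients automatically.

```python
# Toy training data generated from y = 3x (an assumption for illustration).
xs = [1.0, 2.0, 3.0]
ys = [3.0, 6.0, 9.0]

w = 0.0               # model: a single parameter
learning_rate = 0.01  # same value as params.learning_rate above


def mse_loss(w):
    # Loss function: mean squared difference between predictions and targets.
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def mse_grad(w):
    # d(loss)/dw, derived by hand: mean of 2 * (w*x - y) * x.
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


for _ in range(200):
    # Optimiser: move the parameter a small step against the gradient.
    w -= learning_rate * mse_grad(w)

print(round(w, 3), mse_loss(w))
```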

First, we'll import dtoolAI's generic classifier model, `GenNet`:

```python
from dtoolai.models import GenNet
```

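Then we set the model's initialisation parameters from what we know about the dataset (its number of input channels and image dimensions) and initialise the model. Storing these in `params` means they are recorded along with the other training parameters: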

```python
params['init_params'] = dict(input_channels=train_ds.input_channels, input_dim=train_ds.dim)
model = GenNet(**params['init_params'])
```
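Why does the model need `input_dim`? In a typical CNN classifier, convolution and pooling layers shrink the image, and the size of the first fully connected layer depends on the resulting feature map. The sketch below computes that size with the standard formula floor((n + 2p - k) / s) + 1, for input size n, kernel size k, padding p and stride s. It assumes a LeNet-style stack of two 5×5 convolutions (stride 1, no padding), each followed by 2×2 max pooling. This architecture is an assumption for illustration; `GenNet`'s actual layers may differ.

```python
def conv_output_size(n, kernel, stride=1, padding=0):
    """Spatial size after a conv or pooling layer: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1


def feature_map_size(input_dim):
    # Assumed LeNet-style stack: (5x5 conv -> 2x2 max pool) twice.
    n = input_dim
    for _ in range(2):
        n = conv_output_size(n, kernel=5)            # 5x5 conv, stride 1
        n = conv_output_size(n, kernel=2, stride=2)  # 2x2 max pool
    return n


print(feature_map_size(28))  # 4
```

For 28×28 MNIST digits the sizes go 28 → 24 → 12 → 8 → 4, so a different `input_dim` changes the size of every later layer.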

Now we can create a loss function and optimiser:

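```python
import torch

# Negative log-likelihood loss, which expects the model to output log-probabilities
loss_fn = torch.nn.NLLLoss()
# Stochastic gradient descent over all of the model's trainable parameters
optim = torch.optim.SGD(model.parameters(), lr=params.learning_rate)
```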
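`NLLLoss` expects log-probabilities as input (as produced by a network ending in `log_softmax`, which we assume of `GenNet` here). With its default `'mean'` reduction, it returns the average negative log-probability assigned to each example's correct class. The pure-Python sketch below reproduces that calculation for a toy batch of two examples and three classes. `log_softmax_row` and `mean_nll` are hypothetical helpers, not PyTorch functions.

```python
import math


def log_softmax_row(scores):
    # Numerically stable log-softmax: subtract the max before exponentiating.
    m = max(scores)
    log_total = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_total for s in scores]


def mean_nll(log_probs, targets):
    # Average negative log-probability of the correct class.
    return -sum(row[t] for row, t in zip(log_probs, targets)) / len(targets)


raw_scores = [[2.0, 0.5, -1.0],  # example 0: confident in class 0
              [0.1, 0.2, 0.3]]   # example 1: nearly uniform
targets = [0, 2]
log_probs = [log_softmax_row(row) for row in raw_scores]
print(round(mean_nll(log_probs, targets), 4))
```

The confident, correct first example contributes a small loss; the nearly uniform second example contributes about 1.0, which is close to log(3), the loss for a uniform guess over three classes.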

## Training

Now we're ready to train our model. First we import a helper function from dtoolAI:

```python
from dtoolai.training import train_model_with_metadata_capture
```

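We also need `DerivedDataSetCreator` from dtoolcore. This class is used as a context manager and packages our new model as a dataset with useful metadata, including a record of the dataset it was derived from: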

```python
from dtoolcore import DerivedDataSetCreator
```

Now we can train the model. For this to work, you'll need to create a directory to which the trained model will be written. In this case, we can create a scratch directory (so named because we can delete it when we've finished):

```python
import os

# Create the output directory; exist_ok avoids an error if it already exists
os.makedirs("../scratch", exist_ok=True)
```

```python
# Train the model, writing it and its metadata to a new dataset in ../scratch
with DerivedDataSetCreator('mnist.example.model', '../scratch', train_ds) as output_ds:
    train_model_with_metadata_capture(model, train_ds, optim, loss_fn, params, output_ds)
```
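With `batch_size=128` and the 60,000 MNIST training examples, each epoch involves a fixed number of optimiser steps. The arithmetic below works this out, assuming the last partial batch is kept (PyTorch's `DataLoader` default, `drop_last=False`). Whether dtoolAI's training function batches exactly this way is an assumption.

```python
import math

n_train = 60_000
batch_size = 128
n_epochs = 1

# Number of mini-batches (and therefore optimiser steps) per epoch
batches_per_epoch = math.ceil(n_train / batch_size)
# Size of the final, partial batch
last_batch = n_train - (batches_per_epoch - 1) * batch_size
total_steps = batches_per_epoch * n_epochs

print(batches_per_epoch, last_batch, total_steps)  # 469 96 469
```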

## Evaluating the model

The MNIST dataset contains 60,000 training examples (which we used to train our model) and 10,000 test examples. Because the model never saw the test examples during training, we can use them to evaluate how well it generalises.

```python
mnist_test_uri = "http://bit.ly/2NVFGQd"
test_ds = TensorDataSet(mnist_test_uri)
```

```python
from torch.utils.data import DataLoader

# Iterate over the test set in batches of 4 examples
test_dl = DataLoader(test_ds, batch_size=4)
```
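Without `shuffle=True`, `DataLoader` walks through the dataset in order and groups consecutive examples into batches of `batch_size`, so the 10,000 test examples give 2,500 batches of 4. The pure-Python sketch below shows that grouping. `iter_batches` is a hypothetical helper, and unlike `DataLoader` it does not collate examples into tensors.

```python
def iter_batches(dataset, batch_size):
    """Yield lists of consecutive items, keeping any short final batch."""
    batch = []
    for index in range(len(dataset)):
        batch.append(dataset[index])
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


# A list of (data, label) pairs stands in for test_ds.
toy_dataset = [(f"image{i}", i % 10) for i in range(10)]
batches = list(iter_batches(toy_dataset, batch_size=4))
print([len(b) for b in batches])  # [4, 4, 2]
```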

Now that we've loaded the test data, we can use a utility function to count how many test examples the model classifies correctly:

```python
from dtoolai.utils import evaluate_model

correct = evaluate_model(model, test_dl)
# Number of correct predictions, and the fraction of the test set that represents
print(correct, correct / len(test_ds))
```
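Counting correct predictions comes down to taking the highest-scoring class for each example and comparing it with the label. The pure-Python sketch below does that for a toy batch. `count_correct` is a hypothetical helper that mirrors the idea; it is not dtoolAI's `evaluate_model`, which also runs the model over every batch from the `DataLoader`.

```python
def argmax(values):
    # Index of the largest value (the first one wins on ties).
    return max(range(len(values)), key=values.__getitem__)


def count_correct(batch_scores, labels):
    # One score per class per example; compare the predicted class with the label.
    return sum(argmax(scores) == label for scores, label in zip(batch_scores, labels))


scores = [[-2.3, -0.1, -3.0],  # predicts class 1
          [-0.2, -1.9, -2.5],  # predicts class 0
          [-1.2, -1.1, -0.9]]  # predicts class 2
labels = [1, 2, 2]
correct = count_correct(scores, labels)
print(correct, correct / len(labels))
```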

To improve the model's performance, try training it for more epochs: go back and rerun the parameter-setting code with a higher value of `n_epochs`. You'll also need to either remove the model you created earlier or give the new one a different name, since a model called `mnist.example.model` already exists in `../scratch`.
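One way to remove the earlier model before retraining is to delete its directory. The sketch below does this with `shutil.rmtree`, assuming the model was written to `../scratch/mnist.example.model` as above. `remove_model_dir` is a hypothetical helper. Deletion is permanent, so double-check the path first.

```python
import shutil
from pathlib import Path


def remove_model_dir(path):
    """Delete a model directory if it exists; return True if something was removed."""
    model_dir = Path(path)
    if model_dir.is_dir():
        shutil.rmtree(model_dir)
        return True
    return False


# Usage (assumes the output location used earlier in this tutorial):
# remove_model_dir("../scratch/mnist.example.model")
```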

## Checking model provenance

Let's look at how we can extract provenance information (details of how it was created) from our model.

First we need to load the model:

```python
from dtoolai.trained import TrainedTorchModel

ttm = TrainedTorchModel("../scratch/mnist.example.model/")
```

Now we can use the API to determine the URI of the data used to train the model:

```python
ttm.dataset.get_annotation("source_dataset_uri")
```

We can then follow this URI to load the training data itself:

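```python
import dtoolcore

# Load the source dataset from the URI recorded in the model's annotations
source_uri = ttm.dataset.get_annotation("source_dataset_uri")
source_dataset = dtoolcore.DataSet.from_uri(source_uri)
```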

This gives us access to the training data and its associated metadata, such as its README:

```python
print(source_dataset.get_readme_content())
```
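Because a derived dataset records the URI of its source, provenance can in principle be followed back through several steps: from a model to its training data, and from there to an earlier dataset, if the training data was itself derived from one. The sketch below walks such a chain using a plain dictionary of annotations keyed by URI. The URIs and the `trace_provenance` helper are hypothetical. With real data, each lookup would instead load a dataset (e.g. with `dtoolcore.DataSet.from_uri`) and read its `source_dataset_uri` annotation, if present.

```python
# Hypothetical annotations: each URI maps to the annotations of that dataset.
ANNOTATIONS = {
    "file:///scratch/mnist.example.model": {"source_dataset_uri": "https://example.org/mnist-train"},
    "https://example.org/mnist-train": {"source_dataset_uri": "https://example.org/mnist-raw"},
    "https://example.org/mnist-raw": {},  # original data: no recorded source
}


def trace_provenance(uri, annotations, max_depth=100):
    """Return the chain of URIs from `uri` back to a dataset with no recorded source."""
    chain = [uri]
    while len(chain) <= max_depth:
        source = annotations.get(chain[-1], {}).get("source_dataset_uri")
        if source is None or source in chain:  # stop at the root or on a cycle
            break
        chain.append(source)
    return chain


for uri in trace_provenance("file:///scratch/mnist.example.model", ANNOTATIONS):
    print(uri)
```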