# Retraining a pretrained network

Here we'll take an in-depth look at how dtoolAI can help us retrain a pre-trained image recognition network on new types of image. Starting from a pre-trained network makes training much faster than training from scratch.

We're going to load a network trained on the ImageNet <http://www.image-net.org/> dataset, a large collection of images with 1000 different labels. We'll then retrain our network on new data.


## Loading and examining data

To provide a simple example, we've created a small DataSet containing just two categories of images from the Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/> dataset. If you'd like to use your own data, the second half of the dtoolAI documentation on retraining (<https://dtoolai.readthedocs.io/en/latest/retraining.html#part-2-with-raw-data>) explains how to prepare it.

Let's load the dataset:

In [0]:
from dtoolai.data import ImageDataSet, scaled_float_array_to_pil_image

In [0]:
train_ds = ImageDataSet("http://bit.ly/3aRvimq")

We can look at the metadata associated with this training DataSet:

In [0]:
print(train_ds.dataset.get_readme_content())

Now we can extract a single image and label to look at:

In [0]:
imarray, label = train_ds[0]

In [0]:
scaled_float_array_to_pil_image(imarray)

We can check this image's label, both numerically:

In [0]:
label

and, by looking up the categorical encoding, work out what this means:

In [0]:
train_ds.cat_encoding

We can also look at another example:

In [0]:
imarray, label = train_ds[3]
scaled_float_array_to_pil_image(imarray)
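
Assuming ``cat_encoding`` is a dictionary mapping category names to integer labels (check the output above), you can invert it to turn a numeric label back into a name instead of reading it off by eye. A minimal sketch with a hypothetical two-category encoding; ``example_encoding`` and ``decode_label`` are our own names, not part of dtoolAI:

```python
# Hypothetical encoding for illustration; the real one is train_ds.cat_encoding.
example_encoding = {"hedgehog": 0, "llama": 1}

def decode_label(label, encoding):
    """Map an integer label back to its category name."""
    # Invert the name -> integer mapping into integer -> name.
    inverse = {index: name for name, index in encoding.items()}
    # int() also accepts NumPy integers and one-element tensors.
    return inverse[int(label)]

print(decode_label(1, example_encoding))
```

With the real data you would call ``decode_label(label, train_ds.cat_encoding)``.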

## Setting parameters

Before training, we'll need to set some parameters. We do this using dtoolAI's ``Parameters`` class, which provides support for recording these parameters automatically during model training.

In [0]:
from dtoolai.parameters import Parameters

We need to tell the model how many categories it will have to classify. This is
the number of entries in the category encoding of our input dataset.

In [0]:
init_params = {
    'n_outputs': len(train_ds.cat_encoding)
}

params = Parameters(
    batch_size=4,
    learning_rate=0.001,
    n_epochs=1,
    init_params=init_params
)
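
dtoolAI's ``Parameters`` class keeps these values so they can be recorded alongside the trained model. We haven't reproduced its implementation here; as a rough illustration of the idea only, here is a hypothetical minimal container (``RecordedParameters`` is our own name, not dtoolAI's) that exposes values as attributes, the way ``params.learning_rate`` is used further down, and serialises them to JSON for storage as metadata:

```python
import json

class RecordedParameters:
    """Hypothetical stand-in for a parameter container; not dtoolAI's Parameters."""

    def __init__(self, **kwargs):
        # Keep the raw values so they can be written out later as metadata.
        self._values = dict(kwargs)

    def __getattr__(self, name):
        # Only called when normal lookup fails, so this exposes e.g. .batch_size.
        try:
            return self.__dict__["_values"][name]
        except KeyError:
            raise AttributeError(name) from None

    def to_json(self):
        # Serialise every parameter, with sorted keys for stable output.
        return json.dumps(self._values, sort_keys=True)

# Named sketch_params so it doesn't overwrite the notebook's real params.
sketch_params = RecordedParameters(batch_size=4, learning_rate=0.001, n_epochs=1,
                                   init_params={"n_outputs": 2})
print(sketch_params.learning_rate)
print(sketch_params.to_json())
```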

## Loading a pretrained model

Next we load our pretrained model. We're using ResNet-18, an 18-layer ResNet <https://arxiv.org/abs/1512.03385>, with a new classifier, sized by ``n_outputs``, added at the end.

In [0]:
from dtoolai.models import ResNet18Pretrained
model = ResNet18Pretrained(**init_params)
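
To get a sense of how small the new classifier is compared to the one it replaces: ResNet-18's final layer maps 512 pooled features to the 1000 ImageNet classes. Assuming the new classifier is likewise a single fully connected layer on those 512 features (we haven't checked exactly how ``ResNet18Pretrained`` builds it), the parameter counts compare as follows:

```python
# Assumption: the new classifier is one fully connected layer on ResNet-18's
# 512-dimensional pooled features.
RESNET18_FEATURES = 512

def linear_layer_params(n_in, n_out):
    """Number of weights plus biases in a fully connected layer."""
    return n_in * n_out + n_out

imagenet_head = linear_layer_params(RESNET18_FEATURES, 1000)  # original classifier
new_head = linear_layer_params(RESNET18_FEATURES, 2)          # two categories
print(imagenet_head, new_head)
```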

Now we need to set a loss function and an optimiser:

In [0]:
import torch
loss_fn = torch.nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), lr=params.learning_rate)
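
To make these two choices concrete: for a single example, ``CrossEntropyLoss`` takes the model's raw output scores (logits) and returns the negative log of the softmax probability assigned to the correct class (averaging over the batch by default), and ``SGD`` with its default settings (no momentum or weight decay) moves each weight against its gradient, scaled by the learning rate. A plain-Python sketch of both, working on lists rather than tensors:

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one example: -log(softmax(logits)[target])."""
    # Subtract the largest logit before exponentiating, for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def sgd_step(weights, grads, lr):
    """Plain SGD update: w <- w - lr * grad for every weight."""
    return [w - lr * g for w, g in zip(weights, grads)]

example_loss = cross_entropy([2.0, 0.0], target=0)
example_updated = sgd_step([0.5, -0.2], [1.0, -2.0], lr=0.001)
print(example_loss, example_updated)
```

The loss is small when the correct class already has the highest score, and the small ``learning_rate`` keeps each update gentle, which suits fine-tuning pretrained weights.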

## Retraining the model

Now we're ready to retrain the model on our new data.

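First we'll import the training function dtoolAI provides, along with dtoolcore's ``DerivedDataSetCreator``, which we'll use to store the trained model as a new DataSet derived from our training data:
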
In [0]:
from dtoolai.training import train_model_with_metadata_capture
from dtoolcore import DerivedDataSetCreator

We'll need to create a directory to which we can write our trained model:

In [0]:
import os
# exist_ok=True lets this cell be rerun without raising FileExistsError
os.makedirs("../scratch", exist_ok=True)

Now we're ready to train our model. This might take a few minutes:

In [0]:
with DerivedDataSetCreator('twocat.image.model', '../scratch', train_ds) as output_ds:
    train_model_with_metadata_capture(model, train_ds, optim, loss_fn, params, output_ds)
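
We haven't looked inside ``train_model_with_metadata_capture``, but the parameters we set (``batch_size``, ``learning_rate``, ``n_epochs``) are those of a standard mini-batch training loop. As a generic illustration, not dtoolAI's code, here is that loop in plain Python, with a one-dimensional logistic classifier standing in for the ResNet; ``batches`` and ``train_toy`` are our own hypothetical helpers:

```python
import math
import random

def batches(items, batch_size):
    """Yield successive batches; the last one may be smaller."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

def train_toy(data, batch_size=4, learning_rate=0.1, n_epochs=20, seed=0):
    """Mini-batch SGD for a 1-D logistic classifier (a toy stand-in for the real model)."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    data = list(data)                      # copy so the caller's list isn't shuffled
    for _ in range(n_epochs):              # one epoch = one pass over all the data
        rng.shuffle(data)                  # new order each epoch
        for batch in batches(data, batch_size):
            gw = gb = 0.0
            for x, y in batch:
                p = 1 / (1 + math.exp(-(w * x + b)))  # predicted P(y = 1)
                gw += (p - y) * x                     # gradient of the log loss
                gb += p - y
            w -= learning_rate * gw / len(batch)      # average over the batch
            b -= learning_rate * gb / len(batch)
    return w, b

toy_data = [(-2.0, 0), (-1.5, 0), (-1.0, 0), (-0.5, 0),
            (0.5, 1), (1.0, 1), (1.5, 1), (2.0, 1)]
toy_w, toy_b = train_toy(toy_data)
print(toy_w, toy_b)
```

With 8 items and ``batch_size=4``, each epoch makes two weight updates; more epochs mean more updates, which is why training longer can help.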

## Evaluating the retrained model

To evaluate the model, we can take advantage of how the two-category DataSet was created: some of the images are marked as training data and some as test data. When we loaded the data earlier we got the training set; now we can load the test set:

In [0]:
test_ds = ImageDataSet("http://bit.ly/3aRvimq", usetype="test")
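
Conceptually, selecting a split means keeping only the items whose mark matches the requested ``usetype``. Assuming the marks are stored per item, keyed by item identifier (we haven't checked how dtoolAI stores them), the selection could look like this sketch; ``select_by_usetype`` and the identifiers are hypothetical:

```python
def select_by_usetype(identifiers, usetype_marks, usetype):
    """Keep only the items whose per-item mark matches the requested use."""
    return [idn for idn in identifiers if usetype_marks[idn] == usetype]

# Hypothetical per-item marks, keyed by item identifier.
example_marks = {"img01": "train", "img02": "train", "img03": "test",
                 "img04": "train", "img05": "test"}
example_ids = sorted(example_marks)

example_train = select_by_usetype(example_ids, example_marks, "train")
example_test = select_by_usetype(example_ids, example_marks, "test")
print(example_train, example_test)
```

Because each item carries exactly one mark, the two splits never overlap, so the model is never scored on images it was trained on.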

We can check that the train and test DataSets have different sizes:

In [0]:
print(f"Training dataset has {len(train_ds)} items, test dataset has {len(test_ds)}")

Next we import a helper function to evaluate our model, along with PyTorch's ``DataLoader`` to feed it the test data:

In [0]:
from dtoolai.utils import evaluate_model
from torch.utils.data import DataLoader

Then we run the evaluation:

In [0]:
test_dl = DataLoader(test_ds)
correct = evaluate_model(model, test_dl)
print(f"Model correct on {correct} out of {len(test_ds)} items")
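
We haven't looked inside ``evaluate_model``, but a typical way to score a classifier is to take the highest-scoring output for each test item and compare it with the true label. A plain-Python sketch of that counting step on made-up scores; ``count_correct`` is our own hypothetical helper:

```python
def count_correct(all_scores, labels):
    """Count items whose highest-scoring class matches the true label."""
    n_right = 0
    for scores, label in zip(all_scores, labels):
        predicted = max(range(len(scores)), key=lambda i: scores[i])  # argmax
        n_right += int(predicted == label)
    return n_right

# Made-up output scores for four items, two categories each.
example_scores = [[2.1, -0.3], [0.2, 1.7], [0.9, 0.4], [-1.0, 0.5]]
example_labels = [0, 1, 1, 1]
n_correct = count_correct(example_scores, example_labels)
print(f"{n_correct} of {len(example_labels)} correct ({n_correct / len(example_labels):.0%})")
```

Dividing the count by ``len(test_ds)`` in the same way turns the notebook's result into an accuracy.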

## Improving the model

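The model only trained for a single epoch, that is, a single pass through the training data. Let's see if we can improve its performance by training for longer.

First we'll change our parameters to train for 5 epochs rather than one, then retrain. Note that this continues training the same ``model`` object with the same optimiser, so after this cell the model has had six epochs in total; the result is saved as a new DataSet: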

In [0]:
params = Parameters(
    batch_size=4,
    learning_rate=0.001,
    n_epochs=5,
    init_params=init_params
)
with DerivedDataSetCreator('twocat.image.model.5', '../scratch', train_ds) as output_ds:
    train_model_with_metadata_capture(model, train_ds, optim, loss_fn, params, output_ds)

Now we can evaluate our new model:

In [0]:
correct = evaluate_model(model, test_dl)
print(f"Model correct on {correct} out of {len(test_ds)} items")

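Compare this with the single-epoch result: with the extra training, the model should classify noticeably more of the test images correctly.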

## Applying the model to a new image

Let's try applying our model to a new image. You can use the example image below, a photo of a hedgehog from Wikipedia, or find your own. The model only knows two categories, hedgehogs and llamas, so any other kind of image will still be assigned to one of those and is likely to confuse it.

First we'll need some libraries to load the image:

In [0]:
from imageio import imread
from PIL import Image

Now we can load the image from a URL:

In [0]:
imarray = imread("https://upload.wikimedia.org/wikipedia/commons/7/72/Igel.JPG")
image = Image.fromarray(imarray)

Let's take a look at it:

In [0]:
image

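Now let's load the model we trained. Here we load the version saved after the extra epochs of training; use ``../scratch/twocat.image.model`` instead to try the single-epoch version:

In [0]:
from dtoolai.trained import TrainedTorchModel

In [0]:
model = TrainedTorchModel("../scratch/twocat.image.model.5")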

We need to convert the image into the format the model expects. dtoolAI recorded the input dimensions used by the model in the model's metadata, so we can retrieve them for the conversion:

In [0]:
dim = model.model_params['input_dim']
channels = model.model_params['input_channels']
input_format = [channels, dim, dim]

Now we'll import helper functions to coerce the image to those dimensions and convert it to a tensor:

In [0]:
from dtoolai.imageutils import coerce_to_target_dim
from torchvision.transforms.functional import to_tensor

resized_converted = coerce_to_target_dim(image, input_format)
as_tensor = to_tensor(resized_converted)
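
For an 8-bit image, ``to_tensor`` does two things: it moves the channel axis first, turning height × width × channels into channels × height × width (matching the ``[channels, dim, dim]`` format above), and it scales pixel values from 0 to 255 down to 0.0 to 1.0. A plain-Python sketch of the same rearrangement on nested lists; ``hwc_to_chw_scaled`` is our own illustrative helper:

```python
def hwc_to_chw_scaled(pixels):
    """Convert an H x W x C grid of 0-255 ints to C x H x W floats in [0, 1]."""
    height, width, channels = len(pixels), len(pixels[0]), len(pixels[0][0])
    return [
        [[pixels[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]

# A tiny hypothetical 1 x 2 RGB image: one red pixel, one white pixel.
tiny_image = [[(255, 0, 0), (255, 255, 255)]]
tiny_chw = hwc_to_chw_scaled(tiny_image)
print(tiny_chw)
```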

Then we can use the model to categorise the image:

In [0]:
# as_tensor[None] adds a leading batch dimension: (C, H, W) -> (1, C, H, W)
result = model.predict(as_tensor[None])

and check the classification:

In [0]:
print(f"Classified image as {result}")