# Saving and loading models for inference 

In the end, we train models because we want to use them for inference, that is, to generate predictions on new data. The general workflow for doing this in FastAI.jl is to first train a `model` for a `method`, for example using [`fitonecycle!`](#) or [`finetune!`](#), and then save both the model and the learning method configuration to a file using [`savemethodmodel`](#). In another session, you can then use [`loadmethodmodel`](#) to load both. Since the learning method contains all the preprocessing logic, we can then use [`predict`](#) and [`predictbatch`](#) to generate predictions for new inputs.

Let's fine-tune an image classification model (see [here](./fitonecycle.ipynb) for more info) and go through that process.

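```julia
using FastAI
using Metalhead

# Load the training split of the dogs vs. cats dataset
dir = joinpath(datasetpath("dogscats"), "train")
data = loadtaskdata(dir, ImageClassificationTask)
# Class names are taken from the dataset's folder structure
classes = Datasets.getclassesclassification(dir)
# Classify images at a size of 128×128 pixels
method = ImageClassification(classes, (128, 128))
# Pretrained ResNet-50 with its final classification layers removed
backbone = Metalhead.resnet50(pretrain = true)[1:end-3]
# Build a learner that trains on the GPU and tracks accuracy
learner = methodlearner(method, data, backbone, ToGPU(), Metrics(accuracy))
# Fine-tune the pretrained backbone
finetune!(learner, 3)
```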

```
Epoch 1 TrainingPhase(): 100%|██████████████████████████| Time: 0:01:33
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   1.0 │ 0.66782 │  0.70212 │
└───────────────┴───────┴─────────┴──────────┘

Epoch 1 ValidationPhase(): 100%|████████████████████████| Time: 0:00:06
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   1.0 │ 0.52379 │  0.75405 │
└─────────────────┴───────┴─────────┴──────────┘

Epoch 2 TrainingPhase(): 100%|██████████████████████████| Time: 0:01:41
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   2.0 │ 0.47062 │  0.78196 │
└───────────────┴───────┴─────────┴──────────┘

Epoch 2 ValidationPhase(): 100%|████████████████████████| Time: 0:00:06
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   2.0 │ 0.58706 │   0.8261 │
└─────────────────┴───────┴─────────┴──────────┘

Epoch 3 TrainingPhase(): 100%|██████████████████████████| Time: 0:01:42
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   3.0 │ 0.34845 │  0.85136 │
└───────────────┴───────┴─────────┴──────────┘

Epoch 3 ValidationPhase(): 100%|████████████████████████| Time: 0:00:06
┌─────────────────┬───────┬────────┬──────────┐
│           Phase │ Epoch │   Loss │ Accuracy │
├─────────────────┼───────┼────────┼──────────┤
│ ValidationPhase │   3.0 │ 0.2253 │  0.90647 │
└─────────────────┴───────┴────────┴──────────┘

Epoch 4 TrainingPhase(): 100%|██████████████████████████| Time: 0:01:41
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   4.0 │ 0.21838 │   0.9113 │
└───────────────┴───────┴─────────┴──────────┘

Epoch 4 ValidationPhase(): 100%|████████████████████████| Time: 0:00:06
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   4.0 │ 0.16623 │  0.93396 │
└─────────────────┴───────┴─────────┴──────────┘

Learner()
```

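Four epochs are logged even though we called `finetune!(learner, 3)`, because `finetune!` first trains for one epoch with the pretrained backbone frozen before training the whole model for the requested number of epochs. Now we can save the model using [`savemethodmodel`](#).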

```julia
savemethodmodel("catsdogs.jld2", method, learner.model)
```

In another session, we can now use [`loadmethodmodel`](#) to load both the model and the learning method from the file. Since the model weights are moved to the CPU before being saved, we need to move them back to the GPU manually if we want to run inference there.

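```julia
using FastAI  # required when loading in a fresh session

method, model = FastAI.loadmethodmodel("catsdogs.jld2")
model = gpu(model);  # weights were saved on the CPU; move them to the GPU
```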

Finally, let's select the first 8 cat images from the dataset and see if the model classifies them correctly:

```julia
# Load the first 8 images and predict their classes on the GPU
images = [getobs(data.input, i) for i in 1:8]
preds = predictbatch(method, model, images; device = gpu, context = Validation())
```

```
8-element Vector{SubString{String}}:
 "cats"
 "cats"
 "cats"
 "cats"
 "cats"
 "cats"
 "cats"
 "cats"
```
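To check the predictions programmatically rather than by eye, you can compare them element-wise against the expected labels. The following is a minimal sketch in plain Julia; `agreement` is a hypothetical helper, not part of FastAI.jl, and it assumes both predictions and labels are vectors of class-name strings like the output above.

```julia
using Statistics: mean

# Hypothetical helper (not part of FastAI.jl): the fraction of predicted
# class names that equal the expected class names, compared element-wise.
function agreement(preds::AbstractVector{<:AbstractString},
                   labels::AbstractVector{<:AbstractString})
    length(preds) == length(labels) ||
        throw(DimensionMismatch("$(length(preds)) predictions vs $(length(labels)) labels"))
    return mean(preds .== labels)
end

# Toy example: 3 of 4 predictions match the expected label "cats"
demo_preds = ["cats", "cats", "dogs", "cats"]
agreement(demo_preds, fill("cats", length(demo_preds)))  # → 0.75
```

Applied to the `preds` above, `agreement(preds, fill("cats", length(preds)))` returns `1.0`, since all eight images were classified as cats.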