# Finetuning a pretrained model

To access pretrained computer vision models, you'll have to install a work-in-progress branch of Metalhead.jl as detailed on the [setup page](../docs/setup.md).

In [None]:
using FastAI
using Metalhead
using Zygote

Let's load the image classification dataset ImageNette. You're free to replace it with any of the other classification datasets in `FastAI.DATASETS`.

In [3]:
DATASETNAME = "imagenette2-160";

In [4]:
taskdata = Datasets.loadtaskdata(Datasets.datasetpath(DATASETNAME), ImageClassification)
classes = Datasets.getclassesclassification(DATASETNAME);
method = ImageClassification(classes, (128, 128))

ImageClassification{2}(
    classes = ["n01440764", "n02102040", "n02979186", "n03000684", "n03028079", "n03394916", "…], 
    projections = ProjectiveTransforms{2}(
    sz = (128, 128), 
    buffered = true, 
    augmentations = DataAugmentation.Identity()
), 
    imageprepocessing = ImagePreprocessing(
    C = ColorTypes.RGB{FixedPointNumbers.N0f8}, 
    T = Float32, 
    buffered = true, 
    augmentations = DataAugmentation.Identity()
)
)

Now we load a pretrained model backbone:

In [5]:
backbone = Metalhead.ResNet50(pretrain=true).layers[1:end-3];
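
The `[1:end-3]` slice above keeps every layer except the last three. The following sketch shows the same indexing on a plain vector. The layer names are hypothetical placeholders chosen for illustration; the actual layout of the Metalhead.jl model may differ.

```julia
# Hypothetical layer names, used only to illustrate the slicing.
# The real Metalhead.jl ResNet50 may be organised differently.
layers = ["conv", "stage1", "stage2", "stage3", "stage4", "pool", "flatten", "dense"]

# `end` refers to the last index, so `1:end-3` drops the final three entries.
backbone_layers = layers[1:end-3]

# The three entries that were dropped (assumed here to form the original head).
dropped_layers = layers[end-2:end]

println(backbone_layers)
println(dropped_layers)
```

Under that assumption, what remains is a feature extractor with no dataset-specific output layer, so a new classification head can be attached to it.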

We pass it to `methodlearner`, which calls `methodmodel` to stack a classification head on top of the backbone:

In [13]:
learner = methodlearner(method, taskdata, backbone, ToGPU(), Metrics(accuracy))

Learner()

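The fine-tuning itself is done with [`finetune!`](#). It follows the same protocol as the [fastai implementation](https://github.com/fastai/fastai/blob/f2ab8ba78b63b2f4ebd64ea440b9886a2b9e7b6f/fastai/callback/schedule.py#L153): the backbone is first frozen while the new head trains (for one epoch by default in fastai), then the whole model is unfrozen and trained for the requested number of epochs. This is consistent with the log below, which shows five epochs for `finetune!(learner, 4)`.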
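
As a rough illustration of that protocol, here is a minimal sketch in plain Julia of the epoch plan that fastai's `fine_tune` follows. The helper `finetune_plan` and its keyword names are hypothetical and are not part of FastAI.jl. The defaults (one frozen epoch, base learning rate `0.002`, halved after unfreezing) are taken from the linked fastai code. The fastai version also applies one-cycle scheduling and discriminative learning rates, which this sketch omits.

```julia
# Hypothetical helper, not FastAI.jl's implementation of `finetune!`.
# It only lists which epochs run frozen and the peak learning rate of each phase.
function finetune_plan(nepochs::Integer; freezeepochs::Integer = 1, base_lr::Real = 0.002)
    # Phase 1: backbone frozen, only the new head is trained.
    frozen = [(epoch = e, frozen = true, lr = base_lr) for e in 1:freezeepochs]
    # Phase 2: all layers trainable, with the base learning rate halved.
    unfrozen = [(epoch = freezeepochs + e, frozen = false, lr = base_lr / 2) for e in 1:nepochs]
    return vcat(frozen, unfrozen)
end

plan = finetune_plan(4)
foreach(println, plan)
```

With `nepochs = 4` and the default single frozen epoch, the plan contains five entries in total.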

In [14]:
finetune!(learner, 4)

Epoch 1 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:47


┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   1.0 │ 1.71635 │  0.48276 │
└───────────────┴───────┴─────────┴──────────┘


Epoch 1 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02


┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   1.0 │ 1.35517 │  0.59728 │
└─────────────────┴───────┴─────────┴──────────┘


Epoch 2 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:52


┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   2.0 │ 1.05656 │  0.66247 │
└───────────────┴───────┴─────────┴──────────┘


Epoch 2 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02


┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   2.0 │ 0.78783 │  0.75917 │
└─────────────────┴───────┴─────────┴──────────┘


Epoch 3 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:49


┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   3.0 │ 0.85804 │  0.72977 │
└───────────────┴───────┴─────────┴──────────┘


Epoch 3 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02


┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   3.0 │ 0.71917 │  0.78484 │
└─────────────────┴───────┴─────────┴──────────┘


Epoch 4 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:49


┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   4.0 │ 0.62329 │   0.7972 │
└───────────────┴───────┴─────────┴──────────┘


Epoch 4 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02


┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   4.0 │ 0.66632 │  0.80195 │
└─────────────────┴───────┴─────────┴──────────┘


Epoch 5 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:49


┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   5.0 │ 0.42718 │  0.86343 │
└───────────────┴───────┴─────────┴──────────┘


Epoch 5 ValidationPhase(): 100%|████████████████████████| Time: 0:00:02


┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   5.0 │ 0.62322 │  0.81028 │
└─────────────────┴───────┴─────────┴──────────┘


Learner()