# Finetuning a pretrained model

To access pretrained computer vision models, you'll have to install a work-in-progress branch of Metalhead.jl:

```julia
using Pkg
Pkg.add(Pkg.PackageSpec(url="https://github.com/darsnack/Metalhead.jl", rev="darsnack/vision-refactor"))
```

```julia
using FastAI
using Metalhead
using Zygote
```

Let's load the Imagenette image classification dataset. You're free to replace it with any of the other classification datasets in `FastAI.DATASETS`.

```julia
DATASETNAME = "imagenette2-160";
```

```julia
taskdata = Datasets.loadtaskdata(Datasets.datasetpath(DATASETNAME), ImageClassificationTask)
classes = Datasets.getclassesclassification(DATASETNAME);
# learning method: classify images into `classes` at a size of 128×128 pixels
method = ImageClassification(classes, (128, 128))
```

```
ImageClassification() with 10 classes
```
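`ImageClassification` turns each image and label into arrays the model can consume at the requested size. The notebook doesn't show how it does this (for example, whether it crops, resizes, or augments). Purely as a hypothetical sketch of two typical steps, the plain-Julia functions below center-crop an H × W × C image array to 128×128 and one-hot encode a label against the list of classes. `centercrop_sketch` and `onehot_sketch` are illustrative names, not FastAI.jl functions.

```julia
# Hypothetical encoding steps, not FastAI.jl's actual pipeline.

# Take the central sz[1] × sz[2] region of an H × W × C image array.
function centercrop_sketch(img::AbstractArray{<:Any,3}, sz::Tuple{Int,Int})
    h, w = size(img, 1), size(img, 2)
    h >= sz[1] && w >= sz[2] || throw(ArgumentError("image is smaller than the crop size"))
    top = (h - sz[1]) ÷ 2 + 1
    left = (w - sz[2]) ÷ 2 + 1
    return img[top:top+sz[1]-1, left:left+sz[2]-1, :]
end

# One-hot encode `label` against the ordered list `classes`.
function onehot_sketch(label, classes)
    idx = findfirst(==(label), classes)
    idx === nothing && throw(ArgumentError("unknown class: $label"))
    return [i == idx for i in eachindex(classes)]
end

img = rand(Float32, 160, 213, 3)  # made-up 160×213 RGB image
x = centercrop_sketch(img, (128, 128))
size(x)  # (128, 128, 3)
y = onehot_sketch("dog", ["cat", "dog", "fish"])  # [false, true, false]
```

A center crop keeps the aspect ratio of the retained region. The trade-off is that it discards the image borders, which is one reason training pipelines often use random crops instead.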

Now we load a pretrained model backbone:

```julia
# load ResNet-50 with pretrained weights and drop its last three layers,
# keeping the pretrained layers to use as the backbone
backbone = Metalhead.resnet50(pretrain = true)[1:end-3];
```
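Indexing with `[1:end-3]` keeps every layer of the pretrained network except the last three, which presumably form ResNet-50's original ImageNet classifier. The toy sketch below shows the same slicing idea without Metalhead. Here the network is a plain vector of functions, which are hypothetical stand-ins for real layers, applied in order the way a Flux `Chain` applies its layers.

```julia
# Toy stand-ins for network layers (hypothetical, not Metalhead's ResNet).
layers = Function[
    x -> 2 .* x,          # "feature" layer 1
    x -> x .+ 1,          # "feature" layer 2
    x -> sum(x; dims=1),  # stand-in for pooling
    x -> vec(x),          # stand-in for flattening
    x -> 10 .* x,         # stand-in for the final dense classifier
]

# Apply a sequence of layers left to right, like a Flux `Chain`.
runlayers(ls, x) = foldl((acc, f) -> f(acc), ls; init = x)

# Keep everything except the last three layers: the "backbone".
backbone_layers = layers[1:end-3]

x = [1.0 2.0; 3.0 4.0]
features = runlayers(backbone_layers, x)  # 2 .* x .+ 1
```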

We pass it to `methodlearner`, which calls `methodmodel` to stack a classification head on top of the backbone:

```julia
# `ToGPU()` trains on the GPU; `Metrics(accuracy)` tracks accuracy during training
learner = methodlearner(method, taskdata, backbone, ToGPU(), Metrics(accuracy))
```

```
Learner()
```
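The notebook doesn't show what the head built by `methodmodel` looks like. As a rough illustration only, a common minimal head averages each of the backbone's feature maps over its spatial dimensions. It then maps the result to one score per class with a dense layer. The plain-Julia sketch below assumes features in Flux's width × height × channels × batch layout and uses a hypothetical `classification_head` function. The head FastAI.jl actually builds may differ. The 2048 channels match ResNet-50's final convolutional stage, and a 128×128 input gives 4×4 feature maps at ResNet's overall stride of 32.

```julia
using Random, Statistics

# Minimal classification head (illustrative, not FastAI.jl's):
# global average pooling over width and height, then a linear layer.
function classification_head(features::AbstractArray{<:Real,4},
                             W::AbstractMatrix, b::AbstractVector)
    pooled = dropdims(mean(features; dims = (1, 2)); dims = (1, 2))  # channels × batch
    return W * pooled .+ b                                           # classes × batch logits
end

rng = MersenneTwister(0)
nchannels, nclasses, batchsize = 2048, 10, 4
features = randn(rng, Float32, 4, 4, nchannels, batchsize)  # stand-in backbone output
W = randn(rng, Float32, nclasses, nchannels) ./ sqrt(Float32(nchannels))
b = zeros(Float32, nclasses)
logits = classification_head(features, W, b)
size(logits)  # (10, 4)
```

Because the pooling averages over whatever spatial size arrives, the same head works for other input resolutions without changing `W`.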

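The fine-tuning itself is done with `finetune!`, which follows the same protocol as the [fastai implementation](https://github.com/fastai/fastai/blob/f2ab8ba78b63b2f4ebd64ea440b9886a2b9e7b6f/fastai/callback/schedule.py#L153). It first trains only the new head for one epoch while the pretrained backbone is frozen. It then unfreezes the backbone and trains the whole model for the requested number of epochs, so `finetune!(learner, 3)` runs four epochs in total.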
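The protocol can be sketched in plain Julia as a two-stage plan plus a one-cycle learning-rate curve. This is a hypothetical sketch. `finetune_plan`, `onecycle_lr`, and `cosine_interp` are illustrative names. The default values mirror fastai's Python `fine_tune` and `fit_one_cycle`:

- `fine_tune`: `base_lr = 2e-3`, `freeze_epochs = 1`, `lr_mult = 100`, `pct_start = 0.3`, `div = 5.0`.
- `fit_one_cycle`: `div = 25`, `div_final = 1e5`.

These are not necessarily the defaults of FastAI.jl's `finetune!`. Discriminative learning rates are simplified to two values, one for the backbone and one for the head.

```julia
# Hypothetical sketch of fastai's fine-tuning protocol; defaults follow
# fastai's Python `fine_tune`, not necessarily FastAI.jl's `finetune!`.

# Cosine interpolation from `from` (pos = 0) to `to` (pos = 1).
cosine_interp(from, to, pos) = from + (1 + cos(pi * (1 - pos))) * (to - from) / 2

# One-cycle learning rate at training progress `pct` in [0, 1]: warm up from
# lr_max/div to lr_max over the first `pct_start` fraction of training, then
# anneal down to lr_max/div_final.
function onecycle_lr(pct, lr_max; pct_start = 0.25, div = 25.0, div_final = 1e5)
    if pct < pct_start
        return cosine_interp(lr_max / div, lr_max, pct / pct_start)
    else
        return cosine_interp(lr_max, lr_max / div_final, (pct - pct_start) / (1 - pct_start))
    end
end

# Two stages: train only the head with the backbone frozen, then unfreeze and
# train everything, giving the backbone an `lr_mult` times smaller peak rate.
function finetune_plan(epochs::Integer; base_lr = 2e-3, freeze_epochs = 1,
                       lr_mult = 100, pct_start = 0.3, div = 5.0)
    frozen = (backbone_frozen = true, epochs = freeze_epochs,
              backbone_lr = 0.0, head_lr = base_lr, pct_start = 0.99, div = 25.0)
    lr = base_lr / 2  # the base rate is halved after the frozen stage
    unfrozen = (backbone_frozen = false, epochs = epochs,
                backbone_lr = lr / lr_mult, head_lr = lr, pct_start = pct_start, div = div)
    return [frozen, unfrozen]
end

plan = finetune_plan(3)
total_epochs = sum(s.epochs for s in plan)  # 1 frozen + 3 unfrozen = 4

# Head learning rate across the unfrozen stage, sampled at 11 points.
stage2 = plan[2]
head_lrs = [onecycle_lr(p, stage2.head_lr; pct_start = stage2.pct_start, div = stage2.div)
            for p in range(0, 1; length = 11)]
```

Giving the backbone a much smaller rate than the head reflects the idea that pretrained features need only small adjustments, while the freshly initialized head has to learn from scratch.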

```julia
finetune!(learner, 3)
```

```
Epoch 1 TrainingPhase(): 100%|██████████████████████████| Time: 0:01:56
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   1.0 │ 1.69407 │  0.48564 │
└───────────────┴───────┴─────────┴──────────┘

Epoch 1 ValidationPhase(): 100%|████████████████████████| Time: 0:00:08
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   1.0 │ 1.86775 │  0.54736 │
└─────────────────┴───────┴─────────┴──────────┘

Epoch 2 TrainingPhase(): 100%|██████████████████████████| Time: 0:01:15
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   2.0 │ 1.16304 │  0.62684 │
└───────────────┴───────┴─────────┴──────────┘

Epoch 2 ValidationPhase(): 100%|████████████████████████| Time: 0:00:07
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   2.0 │ 4.08772 │   0.5708 │
└─────────────────┴───────┴─────────┴──────────┘

Epoch 3 TrainingPhase(): 100%|██████████████████████████| Time: 0:01:16
┌───────────────┬───────┬────────┬──────────┐
│         Phase │ Epoch │   Loss │ Accuracy │
├───────────────┼───────┼────────┼──────────┤
│ TrainingPhase │   3.0 │ 0.9757 │  0.68877 │
└───────────────┴───────┴────────┴──────────┘

Epoch 3 ValidationPhase(): 100%|████████████████████████| Time: 0:00:07
┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   3.0 │ 0.65138 │  0.79778 │
└─────────────────┴───────┴─────────┴──────────┘

Epoch 4 TrainingPhase(): 100%|██████████████████████████| Time: 0:01:16
┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   4.0 │ 0.55522 │  0.82361 │
└───────────────┴───────┴─────────┴──────────┘

Epoch 4 ValidationPhase(): 100%|████████████████████████| Time: 0:00:08
┌─────────────────┬───────┬────────┬──────────┐
│           Phase │ Epoch │   Loss │ Accuracy │
├─────────────────┼───────┼────────┼──────────┤
│ ValidationPhase │   4.0 │ 0.4671 │  0.84763 │
└─────────────────┴───────┴────────┴──────────┘

Learner()
```
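The Accuracy column reports the fraction of samples whose predicted class matches the label. The hypothetical `accuracy_sketch` below is not the library's `accuracy`, but it shows the computation. It assumes logits are stored one column per sample and labels are given as class indices.

```julia
# Fraction of samples (columns) whose highest-scoring class equals the target index.
function accuracy_sketch(logits::AbstractMatrix, targets::AbstractVector{<:Integer})
    @assert size(logits, 2) == length(targets) "one target per column"
    predictions = [argmax(col) for col in eachcol(logits)]
    return count(predictions .== targets) / length(targets)
end

logits = [0.1 2.0 0.3;
          1.5 0.2 0.1;
          0.0 0.1 3.0]
accuracy_sketch(logits, [2, 1, 1])  # 2 of 3 correct
```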