# Transfer Learning with Flux

This article is a general guide to how transfer learning works in the Flux ecosystem.
We assume some familiarity with the concept of transfer learning, but we will still start
with a basic definition of the setup and what we are trying to achieve. Many online
resources explain in depth why transfer learning is an effective tool for many ML
problems, and we recommend checking some of those out.

Machine learning today often relies on models that have been trained extensively on a general task
and are then tuned to perform especially well on a subset of that problem.

This is one of the key ways in which models, large or small, are used in practice. They are trained on
a general problem until they achieve good results on the test set, and are then tuned on specialised datasets.

In this process, the model is already well trained on the general problem, so we don't need to train it
again from scratch. Often we only need to tune the last couple of layers to get the most performance
out of the model. The exact number of layers depends on the problem setup and the expected outcome,
but a common tip is to train the last few `Dense` layers of a more complicated model.

So let's try to simulate the problem in Flux.

In [5]:
]activate .

[32m[1m Activating[22m[39m environment at `~/odsceurpoe/ODSCEurope2020/Project.toml`


In [7]:
using Pkg; Pkg.add("Metalhead")

[32m[1m   Updating[22m[39m registry at `~/.julia/registries/General`


[?25l    

[32m[1m   Updating[22m[39m git-repo `https://github.com/JuliaRegistries/General.git`




[32m[1m  Resolving[22m[39m package versions...
[32m[1m  Installed[22m[39m ArrayLayouts ─ v0.4.8
[32m[1mUpdating[22m[39m `~/odsceurpoe/ODSCEurope2020/Project.toml`
 [90m [dbeba491] [39m[92m+ Metalhead v0.5.1[39m
 [90m [91a5bcdd] [39m[95m↓ Plots v1.6.4 ⇒ v1.0.14[39m
[32m[1mUpdating[22m[39m `~/odsceurpoe/ODSCEurope2020/Manifest.toml`
 [90m [4c555306] [39m[93m↑ ArrayLayouts v0.4.7 ⇒ v0.4.8[39m
 [90m [fbb218c0] [39m[92m+ BSON v0.2.6[39m
 [90m [35d6a980] [39m[91m- ColorSchemes v3.9.0[39m
 [90m [3da002f7] [39m[95m↓ ColorTypes v0.10.9 ⇒ v0.9.1[39m
 [90m [5ae59095] [39m[95m↓ Colors v0.12.4 ⇒ v0.11.2[39m
 [90m [5ae413db] [39m[91m- EarCut_jll v2.1.5+0[39m
 [90m [c87230d0] [39m[95m↓ FFMPEG v0.4.0 ⇒ v0.3.0[39m
 [90m [53c48c17] [39m[95m↓ FixedPointNumbers v0.8.4 ⇒ v0.7.1[39m
 [90m [28b8d3ca] [39m[95m↓ GR v0.52.0 ⇒ v0.48.0[39m
 [90m [5c1252a2] [39m[91m- GeometryBasics v0.3.1[39m
 [90m [cd3eb016] [39m[91m- HTTP v0.8.19[39m
 [90m [83e

In [8]:
using Flux, Metalhead
using Flux: @epochs
using Metalhead.Images
resnet = ResNet().layers

┌ Info: Precompiling Metalhead [dbeba491-748d-5e0e-a39e-b530a07fa0cc]
└ @ Base loading.jl:1278


Chain(Conv((7, 7), 3=>64), MaxPool((3, 3), pad = (1, 1), stride = (2, 2)), Metalhead.ResidualBlock((Conv((1, 1), 64=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), Chain(Conv((1, 1), 64=>256), BatchNorm(256))), Metalhead.ResidualBlock((Conv((1, 1), 256=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), identity), Metalhead.ResidualBlock((Conv((1, 1), 256=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), identity), Metalhead.ResidualBlock((Conv((1, 1), 256=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), Chain(Conv((1, 1), 256=>512), BatchNorm(512))), Metalhead.ResidualBlock((Conv((1, 1), 512=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), identity), Metalhead.ResidualBlock((Conv((1, 1), 512=>128), Conv((3, 3), 128=>128), Con

If we wanted the model to classify a different set of classes, we would only need to replace the final layers and `reshape` the output of the earlier layers to match.
Our model would look something like this:

```julia
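model = Chain(
  resnet[1:end-2],               # keep the pretrained layers, dropping the last two (the old classifier head)
  x -> reshape(x, size_we_want), # placeholder: flatten the features (or use a global average pooling layer)
  Dense(reshaped_input_features, n_classes) # placeholder sizes: the new classifier for our classes
)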
```
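The pooling-and-reshape step in the sketch above can be made concrete with a toy example. This is only a minimal illustration: the small `backbone` below is a made-up stand-in for the pretrained ResNet layers, and the input size, channel count, and `n_classes` are assumptions chosen to keep it fast to run.

```julia
using Flux
using Statistics: mean

# Toy stand-in for a convolutional backbone (NOT the real ResNet).
backbone = Chain(
  Conv((3, 3), 3 => 8, relu, pad = 1),
  MaxPool((2, 2)),
)

n_classes = 2  # hypothetical number of target classes

model = Chain(
  backbone,
  x -> mean(x, dims = (1, 2)),     # global average pooling: W×H×C×N -> 1×1×C×N
  x -> reshape(x, :, size(x, 4)),  # flatten to C×N so that `Dense` can consume it
  Dense(8, n_classes),             # new classifier on top of the 8 pooled features
)

x = rand(Float32, 32, 32, 3, 4)    # four random 32×32 RGB "images"
y = model(x)
size(y)                            # (2, 4): one score per class, per image
```

The `Dense` input size has to match the number of channels that come out of the pooled backbone, which is why the real model further down uses 2048.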

We will use the [Dogs vs. Cats](https://www.kaggle.com/c/dogs-vs-cats/data) dataset from Kaggle here.
Make sure to extract the images into a `train` folder.

The `dataloader.jl` script contains some functions that will help us load batches of images, shuffled between
dogs and cats, along with their correct labels.
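To give a feel for what such a batch might look like, here is a hypothetical stand-in that produces random data instead of reading image files. The function name `fake_batch`, the 224×224 image size, and the one-hot label layout are all assumptions for illustration; the actual `load_batch` in `dataloader.jl` may differ in its details.

```julia
using Flux: onehotbatch

# Hypothetical stand-in for `load_batch`: returns random "images" and labels
# with plausible shapes instead of reading files from the `train` folder.
function fake_batch(n; imsize = (224, 224))
  imgs = rand(Float32, imsize..., 3, n)         # width × height × channels × batch
  names = rand(["cat", "dog"], n)               # a random class name per image
  labels = onehotbatch(names, ["cat", "dog"])   # 2×n one-hot matrix
  return imgs, labels
end

imgs, labels = fake_batch(10)
size(imgs), size(labels)                        # ((224, 224, 3, 10), (2, 10))
```

Returning the images and labels as a tuple matches how batches are consumed later on, where each element of the dataset is passed to the loss as `(x, y)`.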

In [9]:
include("dataloader.jl")


Finally, the model looks something like:

In [10]:
model = Chain(
  resnet[1:end-2],      # pretrained layers; we get 2048 features out
  Dense(2048, 1000),
  Dense(1000, 256),
  Dense(256, 2),        # we have 2 classes
  softmax
)

Chain(Chain(Conv((7, 7), 3=>64), MaxPool((3, 3), pad = (1, 1), stride = (2, 2)), Metalhead.ResidualBlock((Conv((1, 1), 64=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), Chain(Conv((1, 1), 64=>256), BatchNorm(256))), Metalhead.ResidualBlock((Conv((1, 1), 256=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), identity), Metalhead.ResidualBlock((Conv((1, 1), 256=>64), Conv((3, 3), 64=>64), Conv((1, 1), 64=>256)), (BatchNorm(64), BatchNorm(64), BatchNorm(256)), identity), Metalhead.ResidualBlock((Conv((1, 1), 256=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), Chain(Conv((1, 1), 256=>512), BatchNorm(512))), Metalhead.ResidualBlock((Conv((1, 1), 512=>128), Conv((3, 3), 128=>128), Conv((1, 1), 128=>512)), (BatchNorm(128), BatchNorm(128), BatchNorm(512)), identity), Metalhead.ResidualBlock((Conv((1, 1), 512=>128), Conv((3, 3), 128=>128

To speed up training, let’s move everything over to the GPU.

In [None]:
model = model |> gpu
dataset = [gpu.(load_batch(10)) for i in 1:10]

After this, we only need to define the other parts of the training pipeline like we usually do.

In [None]:
opt = ADAM()
loss(x,y) = Flux.crossentropy(model(x), y)
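Because the model ends in `softmax`, its outputs are already probabilities, which is what `Flux.crossentropy` expects. As a small sanity check on made-up numbers (the scores and labels below are purely illustrative), the same loss can be computed from raw scores with `Flux.logitcrossentropy`:

```julia
using Flux
using Flux: onehotbatch, crossentropy, logitcrossentropy

scores = Float32[2.0 0.1; 0.5 1.5]   # raw scores: 2 classes × 2 samples
y = onehotbatch([1, 2], 1:2)         # true classes of the two samples

probs = softmax(scores)              # each column now sums to 1
l1 = crossentropy(probs, y)          # loss computed from probabilities
l2 = logitcrossentropy(scores, y)    # same loss computed from raw scores
l1 ≈ l2                              # true
```

If the trailing `softmax` were removed from the model, `logitcrossentropy` would be the matching loss.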

Now to train.
As discussed earlier, we don’t need to pass all the parameters to our training loop, only the ones we need to
fine-tune. Note that we could also have picked the layers we want to train individually, but this
is sufficient for our use for now.

In [None]:
ps = Flux.params(model[2:end])  # ignore the already trained layers of the ResNet

And now, let's train!

In [None]:
@epochs 2 Flux.train!(loss, ps, dataset, opt)
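To see that only the parameters handed to `Flux.train!` are updated, here is a minimal sketch on a toy model. The two `Dense` layers and the random data are stand-ins rather than the ResNet and the image batches, and the code assumes a Flux version contemporary with this article, where `Flux.params`, `ADAM`, and `train!(loss, ps, data, opt)` are available.

```julia
using Flux

backbone = Dense(4, 3, relu)   # toy stand-in for the pretrained layers
head = Dense(3, 2)             # new layer that we want to fine-tune
model = Chain(backbone, head, softmax)

loss(x, y) = Flux.crossentropy(model(x), y)

# Five random batches of 8 samples, each with one-hot labels for 2 classes.
data = [(rand(Float32, 4, 8), Flux.onehotbatch(rand(1:2, 8), 1:2)) for _ in 1:5]

# Snapshot all parameters before training.
backbone_before = [copy(p) for p in Flux.params(backbone)]
head_before = [copy(p) for p in Flux.params(head)]

ps = Flux.params(model[2:end])  # only the head's parameters
Flux.train!(loss, ps, data, ADAM())

backbone_after = [copy(p) for p in Flux.params(backbone)]
head_after = [copy(p) for p in Flux.params(head)]

backbone_before == backbone_after  # true: the "pretrained" layer is untouched
head_before != head_after          # true: the head has been updated
```

Gradients are only computed for the arrays in `ps`, so the earlier layers are used in the forward pass but never modified.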

And there you have it: a pretrained model, fine-tuned to tell the dogs from the cats.

We can verify this too.

In [None]:
imgs, labels = gpu.(load_batch(10))
display(model(imgs))

labels
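One way to turn this comparison into a single number is to decode both the predictions and the labels with `Flux.onecold` and measure how often they agree. The sketch below uses hand-written probabilities in place of `model(imgs)`, so the values are illustrative only. With GPU arrays it may be necessary to move the outputs back with `cpu` first.

```julia
using Flux: onecold, onehotbatch
using Statistics: mean

probs = Float32[0.9 0.2 0.4;
                0.1 0.8 0.6]            # stand-in for `model(imgs)`: 2 classes × 3 images
labels = onehotbatch([1, 2, 1], 1:2)    # stand-in for the true labels

predicted = onecold(probs)              # [1, 2, 2]: most likely class per column
actual = onecold(labels)                # [1, 2, 1]
accuracy = mean(predicted .== actual)   # 2 of the 3 predictions match
```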