# Transfer Learning on the CIFAR-10 Dataset
---

In [1]:
versioninfo(stdout)  # print Julia, platform and environment details (the recorded output below is from Julia v1.11.1)

Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i7-8565U CPU @ 1.80GHz
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, skylake)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)
Environment:
  LD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:
  DYLD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:
  JULIA_NUM_THREADS = 8


In [None]:
using Metalhead

Load the pre-trained ResNet-18 model from Metalhead

[API Reference](https://fluxml.ai/Metalhead.jl/dev/api/reference/#API-Reference)

In [None]:
# Build ResNet-18 with pretrained weights (`pretrain = true`) and keep only its inner
# `layers` Chain, so it can be sliced into a backbone and a head.
resnet = let pretrained_model = ResNet(18; pretrain = true)
    pretrained_model.layers
end;

In [None]:
using Flux
using Flux: onecold, onehotbatch

In [None]:
# Transfer-learning model: reuse the pretrained ResNet layers and swap the final
# classifier layer for a small two-layer head with 10 outputs (one per CIFAR-10 class).
mdl = let
    backbone = resnet[1:end-1]           # every pretrained stage except the last element
    trimmed_head = resnet[end][1:end-1]  # last element minus its final layer (assumed to be the original classifier Dense)
    new_head = (
        Dense(512 => 256, relu),         # 512 = feature width produced by `trimmed_head`
        Dense(256 => 10),                # raw logits for the 10 classes
    )
    Chain(backbone, trimmed_head, new_head...)
end

In [None]:
using MLDatasets: CIFAR10

Load the CIFAR-10 dataset

In [None]:
"""
    get_data(split, lm::Integer=1024; batchsize=16, shuffle=true)

Load the first `lm` observations of the CIFAR-10 `split` (e.g. `:train` or `:test`)
and wrap them in a `Flux.DataLoader` that yields `(X, y)` mini-batches. `X` holds the
images and `y` the one-hot encoded labels over the classes `0:9`.

`lm` is capped at the number of observations available in the split.
Throws `ArgumentError` if `lm < 1`.
"""
function get_data(split, lm::Integer=1024; batchsize::Integer=16, shuffle::Bool=true)
    lm >= 1 || throw(ArgumentError("lm must be positive, got $lm"))
    data = CIFAR10(split)
    # Avoid a BoundsError when more samples are requested than the split contains.
    n = min(lm, length(data.targets))
    # MLDatasets' CIFAR10 (default `Tx = Float32`) already scales pixels to [0, 1];
    # dividing by 255 again would squash every pixel into [0, 1/255].
    X = data.features[:, :, :, 1:n]
    y = onehotbatch(data.targets[1:n], 0:9)
    return Flux.DataLoader((X, y); batchsize, shuffle)
end

In [None]:
# Use small subsets to keep the demo fast: 512 training and 128 test images.
train_loader, test_loader = get_data(:train, 512), get_data(:test, 128);

Define the loss function, the optimizer, and the trainable parameters (only the new Dense head)

In [None]:
# Cross-entropy computed directly on the raw logits returned by `mdl`.
function loss(X, y)
    logits = mdl(X)
    return Flux.Losses.logitcrossentropy(logits, y)
end
opt = Adam(0.003)
# Collect only the parameters of the new Dense head (layers 3 onward). Training with
# these implicit parameters therefore leaves the pretrained layers untouched.
ps = Flux.params(mdl[3:end])

In [None]:
# Run 5 epochs of implicit-parameter training over `ps`.
# The callback prints "Training" at most once every 10 s; its throttle restarts each epoch.
foreach(1:5) do _
    progress_cb = Flux.throttle(() -> println("Training"), 10)
    Flux.train!(loss, ps, train_loader, opt; cb = progress_cb)
end

In [None]:
# Explicit-style training: build optimiser state for the whole `mdl`, then freeze the
# pretrained parts so only the new Dense head (mdl[3:end]) receives updates.
opt_state = Flux.setup(Adam(3e-3), mdl)
Flux.freeze!(opt_state.layers[1])  # pretrained ResNet backbone
Flux.freeze!(opt_state.layers[2])  # remaining layers of the original head
# NOTE(review): freezing only stops weight updates. BatchNorm running statistics in
# mdl[1] may still change during training; consider `Flux.testmode!(mdl[1])` if unwanted.
for epoch in 1:100
    Flux.train!(mdl, train_loader, opt_state) do m, x, y
        # Compute the loss from the model argument `m`. Calling the global 2-arg
        # `loss` here would run `mdl` on the logits a second time.
        Flux.Losses.logitcrossentropy(m(x), y)
    end
end

In [None]:
using ImageShow, ImageInTerminal

In [None]:
# Show a random CIFAR-10 training image together with its integer label (0-9).
# The dataset is loaded here because no earlier cell defines `d`.
using MLDatasets: convert2image
d = CIFAR10(:train)
idx = rand(eachindex(d.targets))  # valid index range regardless of dataset size
# Only a cell's last value is auto-displayed, so show the image explicitly.
display(convert2image(d, idx))
printstyled("Label is $(d.targets[idx])"; bold=true, color=:red)

In [None]:
#=
# Alternative fine-tuning loop using the explicit Optimisers.jl API (kept disabled).
# Fixes compared with the earlier draft:
# - the optimiser state is built for the whole `mdl`, matching the model passed to `update!`;
# - the loss is computed from the model argument;
# - the `mdll` typo is gone;
# - accuracy is aggregated over the full test set and divided by the sample count
#   (`size(y, 2)`), not by `length(y)` (the number of one-hot entries).
using Optimisers
using ProgressMeter
# Qualify `Adam` to avoid an ambiguity between the `Adam` exported by Flux and by Optimisers.
opt_state = Optimisers.setup(Optimisers.Adam(3e-3), mdl)
# Freeze the weights of the pre-trained layers; only mdl[3:end] is updated.
Optimisers.freeze!(opt_state.layers[1])
Optimisers.freeze!(opt_state.layers[2])
epochs = 5
# Fine-tune the model
for epoch in 1:epochs
    @showprogress for (X, y) in train_loader
        # Compute the gradient of the loss with respect to the model's parameters
        ∇ = Flux.gradient(m -> Flux.Losses.logitcrossentropy(m(X), y), mdl)
        # Update the unfrozen parameters of `mdl` in place
        Optimisers.update!(opt_state, mdl, ∇[1])
    end
    @info "Calculate the accuracy on the test set"
    correct, total = 0, 0
    for (X, y) in test_loader
        correct += sum(onecold(mdl(X)) .== onecold(y))
        total += size(y, 2)
    end
    println("Epoch: $epoch, Accuracy: $(correct / total)")
end
=#