# Transfer Learning (MNIST DATASET)
---

In [None]:
# Grayscale images, i.e., number of channels = 1

In [2]:
# Print the Julia/platform information this notebook was run with (Julia v"1.11.1").
versioninfo(stdout; verbose=false)

Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i7-8565U CPU @ 1.80GHz
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, skylake)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)
Environment:
  LD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:
  DYLD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:
  JULIA_NUM_THREADS = 8


In [None]:
using Pkg; pkg"activate ."

In [None]:
using Metalhead

Load the pre-trained VGG-16 model and keep only its chain of layers

[API Reference](https://fluxml.ai/Metalhead.jl/dev/api/reference/#API-Reference)

In [None]:
# Pre-trained VGG-16 from Metalhead; only its `layers` chain is kept for surgery below.
vgg = let pretrained = VGG(16; pretrain=true)
    pretrained.layers
end;

In [None]:
using DataAugmentation

In [None]:
# Per-image preprocessing for the pre-trained VGG-16:
# - `CenterResizeCrop` rescales the 28×28 digit up to the 224×224 input size
#   (a bare `CenterCrop((224, 224))` does not resize a 28×28 image);
# - `ImageToTensor` turns the RGB image into a Float32 array (224×224×3 here);
# - `Normalize` applies the per-channel ImageNet mean/std used in the
#   Metalhead.jl pre-trained-model examples (Float32 to keep the tensor eltype).
tfm = CenterResizeCrop((224, 224)) |>
      ImageToTensor() |>
      Normalize((0.485f0, 0.456f0, 0.406f0), (0.229f0, 0.224f0, 0.225f0))

In [None]:
using Flux
using Flux: onecold, onehotbatch

In [None]:
# Transfer-learning model: the pre-trained VGG-16 layers with its final output
# layer removed, followed by a new head that predicts the 10 MNIST digits.
model = Chain(
    vgg[1:end-1],       # every pre-trained sub-chain except the last one
    vgg[end][1:end-1],  # last pre-trained sub-chain without its final layer (4096 outputs)
    # New head replacing the last layer. The `relu` is needed: two stacked Dense
    # layers without a non-linearity collapse into one linear map.
    Dense(4096, 256, relu),
    Dense(256, 10),     # raw logits, consumed by `Flux.logitcrossentropy`
)

In [None]:
using MLDatasets

Load the MNIST dataset

In [None]:
"""
    get_data(split; batchsize=64, shuffle=true)

Return an iterable of `(X, y)` mini-batches for the MNIST `split` (e.g. `:train`, `:test`).

- `X`: the per-image tensors produced by the global transform `tfm`, stacked along a
  4th (batch) dimension, i.e. the WHCN layout expected by the convolutional layers.
- `y`: a `10 × B` one-hot matrix over the digits `0:9`.

Only the small raw 28×28 images are kept in memory; each grayscale digit is converted
to RGB and passed through `tfm` lazily, one batch at a time, while iterating.
Pre-transforming the whole training split to 224×224×3 Float32 would need tens of GB.
"""
function get_data(split; batchsize=64, shuffle=true)
    data = MNIST(split)
    # MLDatasets already yields Float32 pixels scaled to [0, 1]: no extra `./ 255`.
    imgs = data.features
    labels = onehotbatch(data.targets, 0:9)
    loader = Flux.DataLoader((imgs, labels); batchsize, shuffle)
    return ((_to_tensor_batch(Xb), yb) for (Xb, yb) in loader)
end

# Convert one 28×28 grayscale digit into the tensor produced by `tfm`.
_to_tensor(img) = apply(tfm, Image(RGB.(img))) |> itemdata

# Transform every digit of a raw `28 × 28 × B` batch and stack them along a new last dim.
_to_tensor_batch(imgs) = stack(_to_tensor, eachslice(imgs; dims=3))

In [None]:
using Images

In [None]:
# Batched iterators for both MNIST splits (training split first, then test split).
train_loader, test_loader = get_data(:train), get_data(:test);

**Define a loss function and an optimizer**

In [None]:
# Cross-entropy on raw logits (the model's last layer has no softmax).
loss_fn = Flux.logitcrossentropy
# Optimiser state for the whole model, so that it matches the structure of
# `model` and of its gradient passed to `Flux.update!(opt_state, model, ∇[1])`.
opt_state = Flux.setup(Adam(3e-3), model)
# Freeze the weights of the pre-trained layers: only the new head
# (`model[3]` and `model[4]`) is updated during fine-tuning.
Flux.freeze!(opt_state.layers[1:2])

In [None]:
using ProgressMeter

In [None]:
# Number of passes over `train_loader` in the fine-tuning loop below.
epochs = 3

Fine-tune the model

In [None]:
# Fine-tune for `epochs` passes, then report the mean training loss and the accuracy
# over the whole test set after each epoch.
for epoch in 1:epochs
    train_loss, nbatches = 0.0, 0
    @showprogress for (X, y) in train_loader
        # Gradient of the loss w.r.t. the model's parameters; parameters whose
        # optimiser state is frozen are left untouched by `Flux.update!`.
        loss, ∇ = Flux.withgradient(model) do m
            loss_fn(m(X), y)
        end
        # Update the model's parameters using the optimizer
        Flux.update!(opt_state, model, ∇[1])
        train_loss += loss
        nbatches += 1
    end
    @info "Calculate the accuracy on the test set"
    # Accumulate over all test batches to get a single dataset-wide accuracy.
    ncorrect, ntotal = 0, 0
    for (X, y) in test_loader
        ncorrect += sum(onecold(model(X)) .== onecold(y))
        # `y` is a 10×B one-hot matrix: the sample count is its column count,
        # not `length(y)` (which would be 10·B).
        ntotal += size(y, 2)
    end
    accuracy = ncorrect / ntotal
    println("Epoch: $epoch, Train loss: $(train_loss / max(nbatches, 1)), Accuracy: $accuracy")
end