<a href="https://colab.research.google.com/github/a-mhamdi/jlai/blob/main/Codes/Julia/Part-3/transfer-learning/transfer-learning-mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transfer Learning (MNIST Dataset)
---

In [None]:
versioninfo()  # print Julia version, platform and environment details (for reproducibility)

Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i7-8565U CPU @ 1.80GHz
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, skylake)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)
Environment:
  LD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:
  DYLD_LIBRARY_PATH = /home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:/home/mhamdi/torch/install/lib:
  JULIA_NUM_THREADS = 8


In [None]:
using Pkg; pkg"activate ."

[32m[1m  Activating[22m[39m project at `~/Work/git-repos/AI-ML-DL/jlai/Codes/Julia/Part-3/transfer-learning`


[32m[1mStatus[22m[39m `~/Work/git-repos/AI-ML-DL/jlai/Codes/Julia/Part-3/transfer-learning/Project.toml`
  [90m[88a5189c] [39mDataAugmentation v0.3.2
[33m⌅[39m [90m[587475ba] [39mFlux v0.14.25
  [90m[916415d5] [39mImages v0.26.1
  [90m[eb30cadb] [39mMLDatasets v0.7.18
  [90m[dbeba491] [39mMetalhead v0.9.4
  [90m[c3e4b0f8] [39mPluto v0.20.4
  [90m[7f904dfe] [39mPlutoUI v0.7.60
  [90m[d6f4376e] [39mMarkdown v1.11.0
[36m[1mInfo[22m[39m Packages marked with [33m⌅[39m have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated`


In [None]:
Pkg.add(["DataAugmentation", "Flux", "Images", "ProgressMeter", "MLDatasets", "Metalhead", "CUDA", "cuDNN"])

In [None]:
pkg"status"

Load the pre-trained VGG-16 model

In [None]:
using Metalhead

[API Reference](https://fluxml.ai/Metalhead.jl/dev/api/reference/#API-Reference)

In [None]:
# VGG-16 with pre-trained weights (`pretrain = true`); only its layer `Chain` is kept,
# so individual stages can be sliced out when building the transfer-learning model.
vgg = let pretrained = VGG(16; pretrain = true)
    pretrained.layers
end;

In [None]:
using DataAugmentation

In [None]:
# Per-image preprocessing pipeline: center-crop to 224×224, then convert the image to a
# Float32 array with `ImageToTensor` (presumably the input size the VGG stages expect).
# NOTE(review): MNIST digits are 28×28, smaller than the 224×224 crop — confirm how
# `CenterCrop` handles a crop larger than the image; a resize step may be intended.
tfm = CenterCrop((224, 224)) |> ImageToTensor()

Sequence{Tuple{Crop{2, DataAugmentation.FromCenter}, ImageToTensor{Float32}}}((Crop{2, DataAugmentation.FromCenter}((224, 224), DataAugmentation.FromCenter()), ImageToTensor{Float32}()))

In [None]:
using Flux
using Flux: onecold, onehotbatch

In [None]:
# Transfer-learning model: reuse VGG-16's pre-trained stages and attach a new 10-class
# head in place of the original last layer.
model = Chain(
    vgg[1:end-1],        # pre-trained convolutional stages
    vgg[end][1:end-1],   # pre-trained classifier stage without its final Dense layer
    # Replace the last layer. The `relu` between the two new Dense layers keeps them
    # from collapsing into a single linear map.
    Dense(4096, 256, relu),
    Dense(256, 10),      # raw logits for digits 0:9 (consumed by `logitcrossentropy`)
    ) |> gpu

Chain(
  Chain(
    Chain(
      Conv((3, 3), 3 => 64, relu, pad=1),  [90m# 1_792 parameters[39m
      Conv((3, 3), 64 => 64, relu, pad=1),  [90m# 36_928 parameters[39m
      MaxPool((2, 2)),
      Conv((3, 3), 64 => 128, relu, pad=1),  [90m# 73_856 parameters[39m
      Conv((3, 3), 128 => 128, relu, pad=1),  [90m# 147_584 parameters[39m
      MaxPool((2, 2)),
      Conv((3, 3), 128 => 256, relu, pad=1),  [90m# 295_168 parameters[39m
      Conv((3, 3), 256 => 256, relu, pad=1),  [90m# 590_080 parameters[39m
      Conv((3, 3), 256 => 256, relu, pad=1),  [90m# 590_080 parameters[39m
      MaxPool((2, 2)),
      Conv((3, 3), 256 => 512, relu, pad=1),  [90m# 1_180_160 parameters[39m
      Conv((3, 3), 512 => 512, relu, pad=1),  [90m# 2_359_808 parameters[39m
      Conv((3, 3), 512 => 512, relu, pad=1),  [90m# 2_359_808 parameters[39m
      MaxPool((2, 2)),
      Conv((3, 3), 512 => 512, relu, pad=1),  [90m# 2_359_808 parameters[39m
      Conv((3, 3), 512 => 512, relu,

In [None]:
using MLDatasets

Load the MNIST dataset

In [None]:
"""
    get_data(split; batchsize=64, shuffle=true)

Load the MNIST `split` (e.g. `:train` or `:test`) and return a `DataLoader` over
`(X, y)`, wrapped with `gpu` so that batches are moved to the GPU.

- `X`: every digit converted to an RGB image, passed through the global transform
  `tfm`, and stacked along a new last (batch) dimension into one dense array.
- `y`: one-hot matrix of the labels over `0:9` (one column per image).

NOTE(review): all transformed images are materialized up front. If `tfm` yields
224×224×3 `Float32` arrays, that is ~600 KB per image (tens of GB for the training
split). Consider transforming lazily per batch.
"""
function get_data(split; batchsize=64, shuffle=true)
    data = MNIST(split)
    # MLDatasets' default MNIST features are already Float32 in [0, 1]; dividing by
    # 255 again would squash every pixel to near zero.
    imgs = data.features
    y = onehotbatch(data.targets, 0:9)
    # Iterate over the image axis. `length(y)` on the 10×N one-hot matrix is 10N and
    # would index past the last image. `Base.stack` yields one array, so every batch
    # is a single tensor that `Conv` layers accept (not a `Vector{Any}` of images).
    X = Base.stack(
        itemdata(apply(tfm, Image(RGB.(view(imgs, :, :, i))))) for i in axes(imgs, 3)
    )
    return Flux.DataLoader((X, y); batchsize, shuffle) |> gpu
end

get_data (generic function with 1 method)

In [None]:
using Images

In [None]:
# Build the training and test loaders (training split first, then the test split).
train_loader, test_loader = map(get_data, (:train, :test));

LoadError: UndefVarError: `get_data` not defined in `Main`
Suggestion: check for spelling errors or missing imports.

**Define a loss function and an optimizer**

In [None]:
# The model's last layer outputs raw logits, so use the logit-based cross-entropy.
loss_fn = Flux.logitcrossentropy
# Set up optimiser state for the *whole* model. Its tree then matches the `model` and
# gradient passed to `Flux.update!` in the training loop; a state built from
# `model[end]` alone does not match them.
opt_state = Flux.setup(Adam(3e-3), model)
# Freeze the weights of the pre-trained layers: only the new head (model[3:end]) trains.
Flux.freeze!(opt_state.layers[1])
Flux.freeze!(opt_state.layers[2])

In [None]:
using ProgressMeter

In [None]:
epochs = 3  # number of passes over `train_loader` in the fine-tuning loop

Fine-tune the model

In [None]:
# Fine-tune for `epochs` passes over the training data, then report the accuracy over
# the whole test set after each epoch.
for epoch in 1:epochs
    @showprogress for (X, y) in train_loader
        # Compute the gradient of the loss with respect to the model's parameters
        loss, ∇ = Flux.withgradient(model) do m
            ŷ = m(X)
            loss_fn(ŷ, y)
        end
        # Update the model's parameters; frozen parts of `opt_state` are left unchanged
        Flux.update!(opt_state, model, ∇[1])
    end
    @info "Calculate the accuracy on the test set"
    # Accumulate over all test batches so one accuracy figure is reported per epoch,
    # instead of a separate (noisy) value per batch.
    ncorrect, ntotal = 0, 0
    for (X, y) in test_loader
        ys = onecold(cpu(y), 0:9)
        ŷ = onecold(cpu(model(X)), 0:9)
        ncorrect += count(ys .== ŷ)
        ntotal += length(ys)
    end
    accuracy = ncorrect / ntotal
    println("Epoch: $epoch, Accuracy: $accuracy")
end