![Practicum AI Logo image](https://github.com/PracticumAI/practicumai.github.io/blob/main/images/logo/PracticumAI_logo_250x50.png?raw=true) <img src="https://github.com/PracticumAI/practicumai.github.io/blob/84b04be083ca02e5c7e92850f9afd391fc48ae2a/images/icons/practicumai_computer_vision.png?raw=true" alt="Practicum AI: Computer Vision icon" align="right" width=50>
***

# Transfer Learning Concepts

You may recall *Practicum AI*'s heroine Amelia, the AI-savvy nutritionist. At the end of our *[Deep Learning Foundations course](https://practicumai.org/courses/deep_learning/)*, Amelia was helping with a computer vision project. If only she had known about transfer learning, it could have saved her a lot of time! In this notebook, we will get some hands-on experience with transfer learning and show you how to use it to improve your workflows.

![Figure 2 of the AgriNet paper used as the cover image for this notebook. Figure 2 depicts using transfer learning to make a computer vision model more efficient](images/agrinet_figure-cover.jpg)


## AI Pathway review for Transfer Learning & AgriNet 

If you have taken our [Getting Started with AI course](https://practicumai.org/courses/getting_started/), you may remember this figure of the AI Application Development Pathway. Let's quickly review how we will apply it to our case study of AgriNet and its use of transfer learning.

![AI Application Development Pathway image showing the 7 steps in developing an AI application](https://practicumai.org/getting_started/images/application_dev_pathway.png)

1. **Choose a problem to solve:** In this example, we will be trying to make a computer vision model that can recognize images of plants, and categorize them as "healthy", or in one or more disease classes. As you will see when you explore the data, there are over 20 crops and some have as many as 11 disease states. In total, there are 73 categories!
2. **Gather data:** The data for the example comes from [HuggingFace](https://huggingface.co/datasets/zahraa/AgriNet), a great repository of datasets, code, and models.
3. **Clean and prepare the data:** While the dataset is already fairly good, we have done some further processing of the data:
    * We have removed some crops with no "healthy" category. We have combined lemon into the Citrus category. We dropped some crops with relatively few images. 
    * Most importantly, we have made the dataset a bit more realistic and manageable for the course. We have randomly subsampled the images so that each category has no more than 100 images.
4. **Choose a model:** In the *Deep Learning Foundations* course, we presented the model with little detail. Here, we will mostly use the EfficientNet-B5 model (more on that below). The EfficientNet models are Convolutional Neural Networks (CNNs). We will train from scratch as a baseline (using only the model architecture, not the pre-trained weights), and compare that to starting with a pre-trained model that we fine-tune on the AgriNet dataset. We will use an EfficientNet-B5 model trained on the ImageNet dataset, which is a large dataset of images with 1000 classes. By fine-tuning this model, we can leverage the knowledge it has learned from the ImageNet dataset and apply it to our specific problem of plant disease classification. Since the *domain* of ImageNet (1000 everyday objects) is distinctly different from the domain of AgriNet (healthy and diseased plants), this is a **Domain Transfer** example.
   * There are many ways to approach choosing a model. For this notebook, we'll mention just two:
      * **Train from scratch:** This is where you start with a randomly initialized model and train it on your data. This can be computationally expensive and time-consuming.
      * **Domain Transfer via Fine-Tuning:** This is where you start with a pre-trained model and fine-tune it on your data. This is often faster and requires less data.
5. **Train the model:** As mentioned in step 4, we'll demonstrate two approaches in this notebook:
      - Training the EfficientNet model architecture from scratch.
      - Fine-tuning an EfficientNet model pre-trained on ImageNet, using our domain-specific AgriNet dataset to achieve Domain Transfer.
6. **Evaluate the model:** We will use the metrics we gather to make decisions about the model. 
7. **Deploy the model:** We won't get to this stage in this course, but ideally we would end up with a model that could be deployed and achieve relatively good accuracy at solving crop classification problems.
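
The per-category subsampling described in step 3 can be sketched in a few lines of plain Python. This is a minimal illustration, not the script used to build the course dataset: the helper name `subsample_per_class`, the dictionary of class names to file lists, the made-up file names, and the fixed seed are all assumptions for the example.

```python
import random


def subsample_per_class(files_by_class, max_per_class=100, seed=42):
    """Return a copy of files_by_class with at most max_per_class files per class.

    Hypothetical helper for illustration; not part of the course code.
    """
    rng = random.Random(seed)  # fixed seed so the subsample is reproducible
    subsampled = {}
    for class_name, files in files_by_class.items():
        if len(files) > max_per_class:
            # Draw max_per_class files without replacement
            subsampled[class_name] = rng.sample(files, max_per_class)
        else:
            # Classes at or below the cap are kept whole
            subsampled[class_name] = list(files)
    return subsampled


# Toy example with made-up class and file names
toy = {
    "Apple_healthy": [f"apple_{i}.jpg" for i in range(250)],
    "Citrus_canker": [f"citrus_{i}.jpg" for i in range(40)],
}
capped = subsample_per_class(toy)
print({name: len(files) for name, files in capped.items()})
# prints {'Apple_healthy': 100, 'Citrus_canker': 40}
```

Capping the large classes while keeping small ones whole reduces class imbalance and keeps training times short, at the cost of discarding some data.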


### A Quick Primer on the Baseline Model
We've trained a simple convolutional neural network (CNN) from scratch as a baseline for comparison. If you have time and want to see how the CNN is set up, check out the [00.5_transfer_learning_helper.ipynb](00.5_transfer_learning_helper.ipynb) notebook that is included in this repository. 

Strictly speaking, a thorough knowledge of CNNs is not required for this notebook, but if you're interested in learning more, we recommend our [PracticumAI: Computer Vision](https://github.com/PracticumAI/computer_vision) Intermediate course. That said, with *any* machine learning work, the better you understand the model, the better you can tune it to your needs.


## PyTorch Lightning 

![PyTorch Lightning logo](images/Lightning_logo-with-text-dark.svg)

[PyTorch Lightning](https://lightning.ai/) is an open-source framework built on top of PyTorch that makes training deep learning models more straightforward. It abstracts many common tasks like managing training loops, logging, checkpointing, and handling hardware setups, allowing you to focus on the core aspects of your model and experimentation. 
Rather than writing repetitive code, you define key methods—such as `training_step` and `validation_step`—to describe the model's behavior while the Lightning trainer automates optimization details, synchronization, and even distributed training. This separation between scientific code and engineering routines leads to cleaner, more maintainable projects that are easier to scale.

Additionally, PyTorch Lightning integrates smoothly with popular tools such as TensorBoard and WandB, which simplifies tracking experiment metrics and visualizing performance. Overall, Lightning streamlines the training workflow, boosts reproducibility, and helps both beginners and seasoned researchers concentrate on innovation, not boilerplate coding.

This course will make use of Lightning to simplify training.


## Import Libraries

First, let's import the libraries we'll need. PyTorch is a popular open-source machine learning library for Python, originally developed by Facebook's AI Research lab (FAIR). We'll also use PyTorch Lightning and other common libraries.

In [1]:
# Import Libraries
import numpy as np
import pandas as pd
import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import r2_score

import requests
import zipfile

# Import Computer Vision Libraries
import os
from PIL import Image, ImageFile

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Import EfficientNet-B5
from torchvision.models import efficientnet_b5  

# Import pytorch lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import RichProgressBar



#### Check for GPU availability

This cell will check that everything is configured correctly to use your GPU. If everything is correct, you should see something like: 

    Using GPU: [type of GPU]

If you see:
    
    Using CPU
    
Either you do not have a GPU or the kernel is not correctly configured to use it. You might be able to run this notebook, but some sections will take a loooooong time!


In [2]:
# Check for GPU availability
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name()}")
else:
    device = torch.device("cpu")  # Fall back to the CPU so `device` is always defined
    print("Using CPU")

Using GPU: NVIDIA A100-SXM4-80GB



## 1.1 Load the Data

First, we'll download, unpack, and load the data. The data is unpacked into the `data` directory, with the training, validation, and test sets in the `agri_net_train100`, `agri_net_val`, and `agri_net_test` subdirectories, respectively. The same cell also downloads previously trained models into the `models` directory.

In [3]:
# Download the dataset, extract it to the data folder and remove the zip file
download_path = "https://data.rc.ufl.edu/pub/practicum-ai/Transfer_Learning_Intermediate/agrinet_curated.zip"
zip_path = "data/agrinet_curated.zip"
data_path = "data"

# Paths to dataset
train_dir = os.path.join(data_path, "agri_net_train100")
val_dir = os.path.join(data_path, "agri_net_val")
test_dir = os.path.join(data_path, "agri_net_test")

# Check if the data is already loaded
if not (
    os.path.exists(train_dir) and os.path.exists(val_dir) and os.path.exists(test_dir)
):
    # Create the data directory if it does not exist
    if not os.path.exists(data_path):
        os.makedirs(data_path)

    # Download the zip file
    r = requests.get(download_path)
    r.raise_for_status()  # Stop here if the download failed
    with open(zip_path, "wb") as f:
        f.write(r.content)

    # Extract the zip file
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(data_path)

    # Remove the zip file
    os.remove(zip_path)
    print(f'Data has been downloaded and unzipped at {data_path}')
else:
    print("Data is already downloaded.")

# Download the models, extract them to the models folder and remove the zip file
model_download_path = "https://data.rc.ufl.edu/pub/practicum-ai/Transfer_Learning_Intermediate/transfer_learning_concepts_models.zip"
model_zip_path = "models/transfer_learning_concepts_models.zip"
model_data_path = "models"

# Paths to models
baseline_model_trained = os.path.join(model_data_path, "baseline_model.pt")
vgg19_model_ft = os.path.join(model_data_path, "vgg19_model.pt")

# Check if the models are already downloaded
if not (os.path.exists(baseline_model_trained) and os.path.exists(vgg19_model_ft)):
    # Create the models directory if it does not exist
    if not os.path.exists(model_data_path):
        os.makedirs(model_data_path)

    # Download the zip file
    r = requests.get(model_download_path)
    r.raise_for_status()  # Stop here if the download failed
    with open(model_zip_path, "wb") as f:
        f.write(r.content)

    # Extract the zip file
    with zipfile.ZipFile(model_zip_path, "r") as zip_ref:
        zip_ref.extractall(model_data_path)

    # Remove the zip file
    os.remove(model_zip_path)
    print(f'Models have been downloaded and unzipped at {model_data_path}')
else:
    print("Models are already loaded.")

Data is already downloaded.
Models are already loaded.


## 1.2 Create the Data Loaders

Next, we'll create the data loaders for the training, validation and test sets. We'll apply some simple data augmentation to the training data. This is another thing that you could experiment with.

While the GPU does most of the calculations in training AI models, the CPU is responsible for loading images from disk, applying any transformations, and sending the data to the GPU. PyTorch's `DataLoader` handles all of this and can do the work in parallel. To keep the GPU busy, multiple CPU cores are needed to feed it data continuously. The `num_workers` argument controls how many parallel worker processes load data.

The code block below detects whether your notebook is running in a Slurm job by checking for the environment variable `SLURM_CPUS_PER_TASK`, and uses that value to set the number of workers if it is defined. If it isn't defined, the code checks the number of cores on your computer and uses that value. To set the number of workers manually, replace `None` on the first line with the integer you want to use.

In [15]:
# Set the number of workers to use for data loading
num_workers = None # To manually set the number of workers change this to an integer

if num_workers is None:
    # Check if Slurm is being used
    # If Slurm is being used, set the number of workers to SLURM_CPUS_PER_TASK
    # If Slurm is not being used, set the number of workers to the number of available CPU cores
    if os.getenv("SLURM_CPUS_PER_TASK") is not None:
        num_workers = int(os.getenv("SLURM_CPUS_PER_TASK"))
        print(f"Using {num_workers} workers for data loading.")
    else:
        num_workers = os.cpu_count()
        print(f"Using {num_workers} workers for data loading.")

Using 16 workers for data loading.


In [10]:
# Allow PIL to load truncated (incompletely written) image files instead of raising an error
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Define PyTorch data transforms
data_transforms = {
    "train": transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),  # Randomly flip images horizontally
            transforms.RandomRotation(degrees=10),  # Randomly rotate images by up to 10 degrees
            transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),  # Randomly crop and resize
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Random color jitter
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "val": transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "test": transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
}

# Load PyTorch datasets
image_datasets = {
    "train": datasets.ImageFolder(train_dir, data_transforms["train"]),
    "val": datasets.ImageFolder(val_dir, data_transforms["val"]),
    "test": datasets.ImageFolder(test_dir, data_transforms["test"]),
}

# Create PyTorch data loaders
dataloaders = {
    "train": torch.utils.data.DataLoader(
        image_datasets["train"],
        batch_size=32,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
    ),
    "val": torch.utils.data.DataLoader(
        image_datasets["val"],
        batch_size=32,
        shuffle=False,
        pin_memory=True,
        num_workers=num_workers,
    ),
    "test": torch.utils.data.DataLoader(
        image_datasets["test"],
        batch_size=32,
        shuffle=False,
        pin_memory=True,
        num_workers=num_workers,
    ),
}
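
To make the `Normalize` step above concrete: after `ToTensor()` scales each pixel value to the range [0, 1], `Normalize(mean, std)` computes `(value - mean) / std` separately for each color channel. The mean and standard deviation values used above are the ImageNet channel statistics commonly used with ImageNet pre-trained models. The sketch below reproduces that arithmetic in plain Python for a single made-up pixel; `normalize_pixel` is a hypothetical helper for illustration, not a torchvision function.

```python
# ImageNet per-channel statistics (red, green, blue), as used in the transforms above
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply (value - mean) / std to each channel of one RGB pixel in [0, 1]."""
    return [(value - m) / s for value, m, s in zip(rgb, mean, std)]


# A made-up mid-gray pixel, already scaled to [0, 1] as ToTensor() would do
mid_gray = [0.5, 0.5, 0.5]
normalized = normalize_pixel(mid_gray)
print([round(value, 3) for value in normalized])
```

The same gray value maps to a different number in each channel because each channel has its own mean and standard deviation.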

## 1.3 Define the Model


Our approach is to adapt an existing model architecture to our new task. [EfficientNet](https://en.wikipedia.org/wiki/EfficientNet) is a family of models released by Google Research in 2019 ([Tan and Le, 2019](https://arxiv.org/abs/1905.11946)). Figure 1 from the Tan and Le (2019) paper is below, showing classification accuracy vs model size. The red line connecting the EfficientNet models highlights their good performance with relatively few parameters, especially compared to the competing models of the time.

![Figure 1 from Tan and Le, 2019](images/EfficientNet_fig1.png)

We will start with the EfficientNet-B5 model as a good tradeoff between accuracy and model size, but feel free to try different models. The model is imported in the imports cell above. See the [PyTorch documentation](https://pytorch.org/vision/main/models.html) on available models and pre-trained weights.


📝 **Note:**
If you'd like more information on how CNNs work, we explored them as part of our Deep Learning Foundations (DLF) course, and we have a full Computer Vision Intermediate course. The final notebook of the DLF course, `DLF_01.1_bees_vs_wasps.ipynb`, is included in this repository if you'd like to review the material.

In [22]:
# Define the number of classes in your dataset
num_classes = len(image_datasets["train"].classes)

# Load the EfficientNet model, without pretrained weights
effNet_random_wt = efficientnet_b5(weights=None)

# Replace the classifier (final layer) to match the number of classes
effNet_random_wt.classifier[1] = nn.Linear(
    effNet_random_wt.classifier[1].in_features, num_classes
)

# Move the model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
effNet_random_wt.to(device)

# Define criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(effNet_random_wt.parameters(), lr=0.001)

# Print the model summary (optional)
print(effNet_random_wt)

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
            (1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(48, 12, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(12, 48, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormAct

In [23]:
# Set Hyperparameters

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(effNet_random_wt.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Define early stopping parameters
early_stopping_patience = 3
best_loss = float("inf")
patience_counter = 0
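
The scheduler and early-stopping values above are easier to reason about with the arithmetic written out. `StepLR` multiplies the learning rate by `gamma` once every `step_size` epochs, so with `lr=0.001`, `step_size=7`, and `gamma=0.1` the rate stays at 0.001 for epochs 0 to 6 and drops to 0.0001 at epoch 7. The sketch below reproduces that schedule, along with a simple patience counter, in plain Python. The helper names and the validation losses are made up for illustration, and the sketch is standalone rather than part of the training code.

```python
def step_lr(base_lr, epoch, step_size=7, gamma=0.1):
    """Learning rate at a given epoch under a StepLR-style schedule."""
    return base_lr * gamma ** (epoch // step_size)


def epoch_of_early_stop(val_losses, patience=3):
    """Return the epoch at which early stopping would trigger, or None."""
    best_loss = float("inf")
    patience_counter = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best validation loss: remember it and reset the counter
            best_loss = loss
            patience_counter = 0
        else:
            # No improvement this epoch
            patience_counter += 1
            if patience_counter >= patience:
                return epoch
    return None


for epoch in (0, 6, 7, 14):
    print(f"epoch {epoch:2d}: lr = {step_lr(0.001, epoch):.6f}")

# Made-up validation losses: the best value is at epoch 2, then three epochs without improvement
made_up_losses = [1.20, 0.95, 0.90, 0.92, 0.91, 0.93, 0.80]
print("stop at epoch:", epoch_of_early_stop(made_up_losses))  # prints: stop at epoch: 5
```

Note that with a patience of 3, training stops at epoch 5 and never reaches the lower loss at epoch 6. A larger patience tolerates longer plateaus at the cost of more training time.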

## 1.4 Train the EfficientNet-B5 Model

Next we'll train the EfficientNet-B5 model, using the architecture, but starting with random weights.

In [13]:
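# Set up PyTorch Lightning to train the model
# Note: the next cell redefines this class in a more compact form, and that
# later definition is the one used for training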
class BaselineModelLightning(pl.LightningModule):
    def __init__(self, model, criterion, optimizer):
        super(BaselineModelLightning, self).__init__()
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.best_loss = float("inf")  # Initialize best_loss
        self.patience_counter = 0      # Initialize patience counter

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)

        # Calculate accuracy
        preds = torch.argmax(outputs, dim=1)
        acc = (preds == labels).float().mean()

        # Log training loss and accuracy
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return {"loss": loss, "accuracy": acc}

    def validation_step(self, batch):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)

        # Calculate accuracy
        preds = torch.argmax(outputs, dim=1)
        acc = (preds == labels).float().mean()

        # Log validation loss and accuracy
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True)

        return {"val_loss": loss, "val_accuracy": acc}

    def on_train_epoch_end(self):
        # Lightning 2.x does not pass step outputs to epoch-end hooks, so read
        # the epoch-level values that self.log() aggregated instead
        metrics = self.trainer.callback_metrics
        if "train_loss" in metrics and "train_acc" in metrics:
            print(f"Epoch {self.current_epoch}: Train Loss: {metrics['train_loss']:.4f}, Train Accuracy: {metrics['train_acc']:.4f}")

    def on_validation_epoch_end(self):
        # Read the aggregated validation metrics logged in validation_step
        metrics = self.trainer.callback_metrics
        if "val_loss" in metrics and "val_acc" in metrics:
            print(f"Epoch {self.current_epoch}: Val Loss: {metrics['val_loss']:.4f}, Val Accuracy: {metrics['val_acc']:.4f}")

    def configure_optimizers(self):
        # No learning-rate scheduler is used, so return the optimizer directly
        return self.optimizer

In [None]:
class BaselineModelLightning(pl.LightningModule):
    def __init__(self, model, criterion, optimizer):
        super(BaselineModelLightning, self).__init__()
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        preds = torch.argmax(outputs, dim=1)
        acc = (preds == labels).float().mean()
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        preds = torch.argmax(outputs, dim=1)
        acc = (preds == labels).float().mean()
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return self.optimizer

# The trainer
trainer = Trainer(max_epochs=10, accelerator="gpu" if torch.cuda.is_available() else "cpu")
trainer.fit(BaselineModelLightning(effNet_random_wt, criterion, optimizer), dataloaders["train"], dataloaders["val"])

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | EfficientNet     | 28.5 M | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
28.5 M    Trainable params
0         Non-trainable params
28.5 M    Total params
113.961   Total estimated model params size (MB)
790       Modules in train mode
0         Modules in eval mode
SLURM auto-requeueing enabled. Setting signal handlers.



`Trainer.fit` stopped: `max_epochs=10` reached.


## 1.5 Save the EfficientNet-B5 Model

Finally, we'll save the trained model to a file. We'll use the `torch.save()` method to save the model to a file, and the `torch.load()` method to load the model from a file.

In [24]:
# Create a folder to save the models if it does not exist
if not os.path.exists("models"):
    os.makedirs("models")

# Save the trained CNN model
torch.save(effNet_random_wt.state_dict(), "models/effNet_random_wt.pt")
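
Because only the `state_dict` (the weights) is saved, loading it back means first rebuilding the same architecture and then calling `load_state_dict()`. The sketch below shows the round trip on a tiny stand-in model so that it runs quickly. The `make_model` helper, its layer sizes, and the temporary file path are assumptions for illustration. For the model saved above, you would instead rebuild `efficientnet_b5(weights=None)` with the same replaced classifier before loading.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model(num_classes=73):
    """Tiny stand-in for the real network (hypothetical, for illustration only)."""
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, num_classes))


original = make_model()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "stand_in_model.pt")

    # Save only the weights, as in the cell above
    torch.save(original.state_dict(), path)

    # Rebuild the same architecture, then load the saved weights into it
    restored = make_model()
    state_dict = torch.load(path, map_location="cpu")
    restored.load_state_dict(state_dict)

# Switch to evaluation mode before making predictions
restored.eval()

# The restored model should give the same outputs as the original
x = torch.randn(4, 8)
with torch.no_grad():
    same = torch.allclose(original(x), restored(x))
print("Outputs match after reload:", same)
```

Using `map_location="cpu"` lets weights saved on a GPU be loaded on a machine without one.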

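## 1.6 Load the EfficientNet-B5 Model with pre-trained weights

Next we'll load the EfficientNet-B5 model, this time starting with the weights from training the model on ImageNet. Passing `weights="DEFAULT"` selects the best pre-trained weights available for the model in your installed version of torchvision. This is a convenient way to get the current best weights for a model, but you should check which weights that option resolves to when you use it. Note that version suffixes in torchvision weight names (such as `IMAGENET1K_V1` or `IMAGENET1K_V2`) identify the training recipe, not a version of the ImageNet dataset. The code also freezes the pre-trained layers, so only the new classifier layer is trained.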

In [None]:

# Load the EfficientNet model, this time keeping the pretrained weights
effNet_pretrain_wt = efficientnet_b5(weights="DEFAULT")
# Freeze all layers except the classifier
for param in effNet_pretrain_wt.parameters():
    param.requires_grad = False
# Replace the classifier (final layer) to match the number of classes
effNet_pretrain_wt.classifier[1] = nn.Linear(
    effNet_pretrain_wt.classifier[1].in_features, num_classes
)

# Move the model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
effNet_pretrain_wt.to(device)

# Define criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(effNet_pretrain_wt.parameters(), lr=0.001)

# Print the model summary (optional)
print(effNet_pretrain_wt)

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
            (1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(48, 12, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(12, 48, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormAct
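
To confirm that freezing worked, you can count the parameters that still require gradients. The sketch below uses a tiny stand-in network rather than EfficientNet-B5 so that it runs instantly. The `StandInNet` class, its `backbone`/`classifier` layout, and its layer sizes are made up, but the freeze-then-replace pattern is the same as in the cell above.

```python
import torch.nn as nn
import torch.optim as optim


def count_parameters(model):
    """Return (trainable, total) parameter counts for a model."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


class StandInNet(nn.Module):
    """Hypothetical stand-in: a small 'pre-trained' backbone plus a 1000-class head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
        self.classifier = nn.Linear(64, 1000)

    def forward(self, x):
        return self.classifier(self.backbone(x))


model = StandInNet()
num_classes = 73

# Freeze every existing parameter
for param in model.parameters():
    param.requires_grad = False

# Replace the head; a newly created layer is trainable by default
model.classifier = nn.Linear(model.classifier.in_features, num_classes)

trainable, total = count_parameters(model)
print(f"Trainable: {trainable:,} of {total:,} parameters")
# prints: Trainable: 4,745 of 6,857 parameters

# Optionally hand only the trainable parameters to the optimizer
optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=0.001)
```

A linear layer with `in_features` inputs and `out_features` outputs has `in_features * out_features + out_features` parameters. If the EfficientNet-B5 classifier has 2048 inputs and there are 73 classes, that gives 149,577 trainable parameters, which is consistent with the `149 K` trainable parameters reported in the fine-tuning log.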

## 1.7 Fine-tune the EfficientNet-B5 Model

Next we'll fine-tune the EfficientNet-B5 model, starting with the weights from training the model on ImageNet. Because the pre-trained layers are frozen, only the new classifier layer is updated during training.

In [None]:
trainer = Trainer(max_epochs=10, accelerator="gpu" if torch.cuda.is_available() else "cpu")
trainer.fit(BaselineModelLightning(effNet_pretrain_wt, criterion, optimizer), dataloaders["train"], dataloaders["val"])

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/lustre/fs0/bsc4892/share/conda/envs/practicum_pl/lib/python3.13/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /lustre/fs0/bsc4892/magitz/transfer_learning/lightning_logs/version_1875/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | EfficientNet     | 28.5 M | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
149 K     Trainable params
28.3 M    Non-trainable params
28.5 M    Total params
113.961   Total estimated model params size (MB)
790       Modules in train mode
0         Modules in 


`Trainer.fit` stopped: `max_epochs=10` reached.


## 1.8 Save the Fine-tuned EfficientNet-B5 Model

Finally, we'll save the trained model to a file. We'll use the `torch.save()` method to save the model to a file, and the `torch.load()` method to load the model from a file.

In [25]:
# Create a folder to save the models if it does not exist
if not os.path.exists("models"):
    os.makedirs("models")

# Save the fine-tuned EfficientNet-B5 model
torch.save(effNet_pretrain_wt.state_dict(), "models/effNet_pretrain_wt.pt")


## 1.9 Transfer Learning Concepts - Helper: Conclusion

That's it! We've trained an EfficientNet-B5 model from scratch and fine-tuned a pre-trained EfficientNet-B5 model on the AgriNet dataset. We've saved the models to files, and we can now use them to make predictions on new images. For evaluations and predictions, please see the `01.0_Transfer_Learning_Concepts.ipynb` notebook.