![Practicum AI Logo image](https://github.com/PracticumAI/practicumai.github.io/blob/main/images/logo/PracticumAI_logo_250x50.png?raw=true) <img src="https://github.com/PracticumAI/practicumai.github.io/blob/84b04be083ca02e5c7e92850f9afd391fc48ae2a/images/icons/practicumai_computer_vision.png?raw=true" alt="Practicum AI: Computer Vision icon" align="right" width=50>
***

# Transfer Learning Helper

Training the models we're going to be looking at in this course can be *very* time consuming. To make the course more manageable, we've separated the training of the models from the rest of the course. Think of it like a cooking show: we'll show you how to make the dish, but we won't make you wait for it to bake in the oven!

Below you'll find the logic to train each of the models we use in this course. When you have time, you can experiment with the code and train the models yourself!

## Import Libraries

First, let's import the libraries we'll need. The models are trained with PyTorch (`torch`) and its companion library `torchvision`, which provides the dataset loaders, image transforms, and pre-trained models used below. We'll also use `tqdm` to show progress bars during training.

PyTorch is a popular open-source machine learning library for Python. It was originally developed by Facebook's AI Research lab (FAIR) and is now governed by the PyTorch Foundation.

In [None]:
# Import Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import r2_score

import requests
import zipfile
import tqdm  # Progress bars for the training loops below

# Import Computer Vision Libraries
import os
from PIL import Image, ImageFile
import glob
import torch
import torchvision
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Check for GPU availability, and if not available, use CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")  # Define the device on CPU-only machines too
    print("Using CPU")

## 1.0 Transfer Learning Concepts - Helper

We'll train two models: a simple CNN trained from scratch on our curated Agrinet dataset, and a VGG19 model fine-tuned on the same dataset. We'll use the `torchvision` library to load the models and datasets, and `torch` to train the models.

### 1.1 Load the Data

First, we'll download, unpack and load the data. The data is unpacked into the `data` directory, with the training, validation and test sets loaded into `agri_net_train`, `agri_net_val` and `agri_net_test` respectively.

In [None]:
# Download the dataset, extract it to the data folder and remove the zip file
download_path = "https://data.rc.ufl.edu/pub/practicum-ai/Transfer_Learning_Intermediate/agrinet_curated.zip"
zip_path = "data/agrinet_curated.zip"
data_path = "data"

# Paths to dataset
train_dir = os.path.join(data_path, "agri_net_train")
val_dir = os.path.join(data_path, "agri_net_val")
test_dir = os.path.join(data_path, "agri_net_test")

# Check if the data is already loaded
if not (
    os.path.exists(train_dir) and os.path.exists(val_dir) and os.path.exists(test_dir)
):
    # Create the data directory if it does not exist
    if not os.path.exists(data_path):
        os.makedirs(data_path)

    # Download the zip file, stopping with an error if the request fails
    r = requests.get(download_path)
    r.raise_for_status()
    with open(zip_path, "wb") as f:
        f.write(r.content)

    # Extract the zip file
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(data_path)

    # Remove the zip file
    os.remove(zip_path)
else:
    print("Data is already loaded.")
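
The cell above only checks that the three split folders exist. As a minimal sketch, assuming each split folder holds one subfolder per class (the layout `datasets.ImageFolder` expects), a stricter check could also confirm that the folders are not empty. The helper names `has_class_folders` and `missing_splits` are hypothetical and not part of the course code.

```python
import os
import tempfile

# Split folder names taken from the paths defined in the cell above
EXPECTED_SPLITS = ("agri_net_train", "agri_net_val", "agri_net_test")


def has_class_folders(split_dir):
    """Return True if split_dir exists and contains at least one subfolder."""
    if not os.path.isdir(split_dir):
        return False
    with os.scandir(split_dir) as entries:
        return any(entry.is_dir() for entry in entries)


def missing_splits(data_path, splits=EXPECTED_SPLITS):
    """Return the names of splits that are absent or have no class subfolders."""
    return [
        split
        for split in splits
        if not has_class_folders(os.path.join(data_path, split))
    ]


# Demonstration on a throwaway directory: only the training split is populated
with tempfile.TemporaryDirectory() as demo_root:
    os.makedirs(os.path.join(demo_root, "agri_net_train", "some_class"))
    os.makedirs(os.path.join(demo_root, "agri_net_val"))
    demo_missing = missing_splits(demo_root)

print(demo_missing)
```

An empty list from `missing_splits("data")` would suggest the download and extraction completed; a non-empty list names the splits worth re-downloading.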

### 1.2 Create the Data Loaders

Next, we'll create the data loaders for the training, validation and test sets. We'll use the `DataLoader` class from `torch.utils.data` to create the data loaders, and the `transforms` module from `torchvision` to apply transformations to the images.

In [None]:
# Define PyTorch data transforms
data_transforms = {
    "train": transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "val": transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "test": transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
}

# Load PyTorch datasets
image_datasets = {
    "train": datasets.ImageFolder(train_dir, data_transforms["train"]),
    "val": datasets.ImageFolder(val_dir, data_transforms["val"]),
    "test": datasets.ImageFolder(test_dir, data_transforms["test"]),
}

# Create PyTorch data loaders
dataloaders = {
    "train": torch.utils.data.DataLoader(
        image_datasets["train"],
        batch_size=128,
        shuffle=True,
        pin_memory=True,
        num_workers=2,
    ),
    "val": torch.utils.data.DataLoader(
        image_datasets["val"],
        batch_size=128,
        shuffle=False,
        pin_memory=True,
        num_workers=2,
    ),
    "test": torch.utils.data.DataLoader(
        image_datasets["test"],
        batch_size=128,
        shuffle=False,
        pin_memory=True,
        num_workers=2,
    ),
}
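
To make the `Normalize` step concrete, here is a small worked example in plain Python. `ToTensor` scales 8-bit pixel values to the range [0, 1], and `Normalize` then applies `(value - mean) / std` to each channel. The mean and standard deviation are the ImageNet values used in the cell above; the function name `normalize_pixel` and the sample pixel are illustrative assumptions.

```python
# Per-channel ImageNet statistics, as passed to transforms.Normalize above
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply (value - mean) / std to each channel of one RGB pixel."""
    return [(value - m) / s for value, m, s in zip(rgb, mean, std)]


# A mid-gray pixel after ToTensor has the value 0.5 in every channel
mid_gray = normalize_pixel([0.5, 0.5, 0.5])
print([round(value, 3) for value in mid_gray])
```

Because the three channels have different means and standard deviations, the same gray value maps to a different number in each channel after normalization.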

### 1.3 Define the Model

We'll define the model we're going to train. The first model we'll train is a CNN model, which is a simple convolutional neural network. We'll define the layers of the model using the `nn` module from `torch`.

📝 **Note:**
If you'd like more information on how CNNs work, we explored them as part of the Deep Learning Foundations (DLF) course, and we have a full Computer Vision Intermediate course. The final notebook of the DLF course, `DLF_01.1_bees_vs_wasps.ipynb`, is included in this repository if you'd like to review the material.

In [None]:
# Handle truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True


# Define baseline model using PyTorch
class BaselineModel(nn.Module):
    def __init__(self, num_classes):
        super(BaselineModel, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 112 * 112, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


num_classes = len(image_datasets["train"].classes)
baseline_model_pt = BaselineModel(num_classes).to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(baseline_model_pt.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Define early stopping parameters
early_stopping_patience = 3
best_loss = float("inf")
patience_counter = 0
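
The `32 * 112 * 112` in the first `nn.Linear` layer follows from the input size and the layers before it. The sketch below works through that arithmetic using the standard output-size formula for convolution and pooling layers (assuming a dilation of 1); the function name `conv_output_size` is a hypothetical helper for illustration.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (dilation of 1 assumed)."""
    return (size + 2 * padding - kernel_size) // stride + 1


side = 224  # Images are resized to 224 x 224 by the transforms
side = conv_output_size(side, kernel_size=3, stride=1, padding=1)  # nn.Conv2d
after_conv = side
side = conv_output_size(side, kernel_size=2, stride=2)  # nn.MaxPool2d
after_pool = side

# 32 output channels from the convolution, flattened for the classifier
flattened_features = 32 * after_pool * after_pool

# Weights plus biases of nn.Linear(flattened_features, 128)
first_linear_params = flattened_features * 128 + 128

print(after_conv, after_pool, flattened_features, first_linear_params)
```

Almost all of the baseline model's parameters sit in that first linear layer, so changing the input size or adding another pooling layer changes the model size considerably.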

### 1.4 Train the CNN Model

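Next, we'll train the CNN with an explicit training loop. Calling `baseline_model_pt.train()` only puts the model in training mode (which enables dropout); it does not run any training itself. For each batch, the loop runs a forward pass, computes the loss, backpropagates, and updates the weights with the optimizer. After each epoch, it steps the learning-rate scheduler and checks the early-stopping condition.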

In [None]:
# Train the baseline model using PyTorch
num_epochs = 5  # Number of epochs to train. Increase this value for better results
for epoch in range(num_epochs):
    baseline_model_pt.train()
    running_loss = 0.0
    progress_bar = tqdm.tqdm(  # tqdm is used to show the progress of the training
        dataloaders["train"], desc=f"Epoch {epoch+1}/{num_epochs}", leave=False
    )
    for inputs, labels in progress_bar:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = baseline_model_pt(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        progress_bar.set_postfix(loss=loss.item())
    scheduler.step()
    epoch_loss = running_loss / len(image_datasets["train"])
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

    # Early stopping check. If the training loss does not improve for 'early_stopping_patience' consecutive epochs, stop training
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print("Early stopping triggered")
            break
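
The loop above applies early stopping to the training loss. A common variation, sketched here under stated assumptions, tracks the validation loss instead (for example, computed over `dataloaders["val"]` at the end of each epoch). The class name `EarlyStopping`, its `min_delta` parameter, and the loss values are made up for illustration.

```python
class EarlyStopping:
    """Signal a stop once the loss has not improved for `patience` epochs in a row."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # Smallest decrease that counts as an improvement
        self.best_loss = float("inf")
        self.counter = 0

    def step(self, loss):
        """Record one epoch's loss and return True when training should stop."""
        if loss < self.best_loss - self.min_delta:
            self.best_loss = loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


# Made-up validation losses, one per epoch, only to exercise the logic
validation_losses = [0.90, 0.70, 0.72, 0.71, 0.73, 0.60]

stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, val_loss in enumerate(validation_losses, start=1):
    if stopper.step(val_loss):
        stopped_at = epoch
        break

print(f"Stopped at epoch {stopped_at}, best loss {stopper.best_loss}")
```

In this made-up run, the last value (0.60) is never reached because three epochs in a row failed to beat 0.70. That is the usual trade-off with a short patience.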

### 1.5 Save the CNN Model

Finally, we'll save the trained model to a file. `torch.save()` writes the model's `state_dict` (its learned weights) to disk. To reuse the model later, rebuild the same architecture and restore the weights with `torch.load()` and `load_state_dict()`.

In [None]:
# Create a folder to save the models if it does not exist
if not os.path.exists("models"):
    os.makedirs("models")

# Save the trained CNN model
torch.save(baseline_model_pt.state_dict(), "models/baseline_model.pt")
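
Loading the weights back is not shown in this notebook, so here is a minimal round-trip sketch. It uses a tiny stand-in model rather than `BaselineModel` so that it runs quickly on any machine; `build_model` and the file name `demo_model.pt` are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model():
    """Stand-in architecture; in practice this would be BaselineModel(num_classes)."""
    return nn.Sequential(nn.Linear(4, 2))


trained = build_model()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "demo_model.pt")

    # Save only the weights, as the cell above does
    torch.save(trained.state_dict(), path)

    # Rebuild the same architecture, then copy the saved weights into it
    restored = build_model()
    state_dict = torch.load(path, map_location="cpu")
    restored.load_state_dict(state_dict)

restored.eval()  # Switch to evaluation mode before making predictions

weights_match = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)
print(weights_match)
```

`map_location="cpu"` lets weights saved on a GPU load on a machine without one. The model can be moved to another device afterwards with `.to(device)`.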

### 1.6 Load the VGG19 Model and Define our Hyperparameters

Per the paper ["The Power of Transfer Learning for Agricultural Applications: Agrinet"](https://arxiv.org/abs/2207.03881), the best result the research team achieved used an ImageNet pre-trained VGG19 model fine-tuned on the Agrinet dataset. VGG19 (named for the Visual Geometry Group, with 19 weight layers) is a specific, well-known computer vision architecture, unlike the very generic CNN we just trained as a baseline. We'll load the pre-trained VGG19 model from `torchvision.models`, and fine-tune it on the Agrinet dataset!

In [None]:
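# Load VGG19 model pre-trained on ImageNet dataset
# (the `weights` argument replaces the deprecated `pretrained=True`; requires torchvision 0.13 or newer)
vgg19 = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)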

# Freeze the pre-trained layers
for param in vgg19.parameters():
    param.requires_grad = False

# Modify the classifier to match the number of classes in the dataset (there are 86 classes in the dataset!)
num_features = vgg19.classifier[6].in_features
vgg19.classifier[6] = nn.Linear(num_features, num_classes)

# Move the model to the device (GPU if available)
vgg19 = vgg19.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vgg19.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Define early stopping parameters
early_stopping_patience = 3
best_loss = float("inf")
patience_counter = 0
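
To see what freezing and replacing the final layer do to the set of trainable parameters, here is a small sketch on a stand-in network (not VGG19 itself, to avoid a large download). The layer sizes and the name `num_classes_demo` are illustrative assumptions.

```python
import torch.nn as nn

# Tiny stand-in for a pre-trained network
model = nn.Sequential(
    nn.Linear(8, 4),  # Plays the role of the pre-trained layers
    nn.ReLU(),
    nn.Linear(4, 3),  # Plays the role of the original output layer
)

# Freeze every existing parameter, as the cell above does for VGG19
for param in model.parameters():
    param.requires_grad = False

# Replace the output layer; a newly created layer is trainable by default
num_classes_demo = 5
model[2] = nn.Linear(model[2].in_features, num_classes_demo)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"Trainable parameters: {trainable}, frozen parameters: {frozen}")
```

The same two sums applied to `vgg19` after the cell above should show that only the new `classifier[6]` layer is trainable, which is why fine-tuning is much cheaper than training the whole network.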

### 1.7 Train the VGG19 Model

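Next, we'll train the VGG19 model with the same explicit training loop used for the CNN. Calling `vgg19.train()` only puts the model in training mode (which enables dropout in the classifier); it does not run any training itself. Because the pre-trained layers are frozen, only the new final classifier layer is updated. After each epoch, the loop steps the learning-rate scheduler and checks the early-stopping condition.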

In [None]:
# Train the VGG19 model using PyTorch
num_epochs = 5  # Number of epochs to train. Increase this value for better results

for epoch in range(num_epochs):
    vgg19.train()
    running_loss = 0.0
    progress_bar = tqdm.tqdm(  # tqdm is used to show the progress of the training
        dataloaders["train"], desc=f"Epoch {epoch+1}/{num_epochs}", leave=False
    )
    for inputs, labels in progress_bar:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = vgg19(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        progress_bar.set_postfix(loss=loss.item())
    scheduler.step()
    epoch_loss = running_loss / len(image_datasets["train"])
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

    # Early stopping check. If the training loss does not improve for 'early_stopping_patience' consecutive epochs, stop training
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print("Early stopping triggered")
            break
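
Both training loops call `scheduler.step()` once per epoch, with `step_size=7` and `gamma=0.1`. The sketch below reproduces the resulting `StepLR` schedule in plain Python to show when the learning rate changes. The function name `step_lr` is a hypothetical helper, and the formula assumes the scheduler is stepped exactly once per epoch.

```python
def step_lr(initial_lr, epoch, step_size=7, gamma=0.1):
    """Learning rate used during `epoch` (0-indexed) under a StepLR schedule."""
    return initial_lr * gamma ** (epoch // step_size)


# Learning rate for each of the first 15 epochs, starting from lr=0.001
schedule = [step_lr(0.001, epoch) for epoch in range(15)]

for epoch, lr in enumerate(schedule, start=1):
    print(f"Epoch {epoch:2d}: lr = {lr:.6f}")
```

With `num_epochs = 5`, the learning rate stays at 0.001 for the whole run; the first drop (to 0.0001) would only take effect from the eighth epoch. If you increase `num_epochs`, the scheduler starts to matter.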

### 1.8 Save the VGG19 Model

Finally, we'll save the fine-tuned model to a file. As with the CNN, `torch.save()` writes the model's `state_dict` to disk. To load it later with `torch.load()`, first rebuild VGG19 with the same replaced final classifier layer, then call `load_state_dict()`.

In [None]:
# Create a folder to save the models if it does not exist
if not os.path.exists("models"):
    os.makedirs("models")

# Save the fine-tuned VGG19 model
torch.save(vgg19.state_dict(), "models/vgg19_model.pt")

### 1.9 Transfer Learning Concepts - Helper: Conclusion

That's it! We've trained a CNN model and a VGG19 model on the Agrinet dataset. We've saved the models to files, and we can now use them to make predictions on new images. For evaluations and predictions, please see the `01.0_Transfer_Learning_Concepts.ipynb` notebook.