# 04 — Transfer Learning with ResNet18

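This notebook uses transfer learning to train a plant disease classifier.
Instead of training a CNN from scratch, we load a ResNet18 model pretrained on ImageNet.
We replace its final fully connected layer with one sized for the PlantVillage classes, freeze the early layers, and fine-tune the remaining layers on our PlantVillage dataset.
Compared with training from scratch, this approach typically reaches higher accuracy in far less training time. In the run recorded below, validation accuracy reached 99.2% after 5 epochs.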


In [1]:
#basic PyTorch imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

#pretrained models + image transforms
from torchvision import models, transforms

import sys, os
sys.path.append("..")

from dataset import PlantVillageDataset
from train_utils import train_one_epoch, validate

import matplotlib.pyplot as plt

# Reproducibility: seed torch's global RNG before anything random happens.
# Later cells draw from it: random init of the new classifier head, DataLoader
# shuffling and the random flip/rotation augmentations.
SEED = 42
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [2]:
# Preprocessing pipelines for the pretrained ResNet18.
# Every image is resized to a fixed square and normalised with the ImageNet channel
# statistics the pretrained weights were trained with.
# The torchvision ImageNet recipe uses 224x224. torchvision's ResNet ends in an adaptive
# average pool, so it also accepts smaller inputs such as 160x160, which is cheaper to
# train on.
IMAGE_SIZE = (160, 160)

# ImageNet per-channel (R, G, B) mean and std, used for z-score normalisation.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    # Augmentation: random flips/rotations so the model recognises disease
    # regardless of how the leaf is oriented.
    transforms.RandomHorizontalFlip(),  # mirror left-right with probability 0.5
    transforms.RandomRotation(10),      # rotate by a random angle in [-10, +10] degrees
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# No augmentation for validation: metrics must reflect real, unmodified images.
val_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [3]:
# Build the train/validation datasets; each split applies its own transform pipeline.
# NOTE(review): the relative path assumes the notebook runs from its own directory.
DATA_DIR = "../data/PlantVillage"
BATCH_SIZE = 16

train_dataset = PlantVillageDataset(f"{DATA_DIR}/train", transform=train_transforms)
val_dataset = PlantVillageDataset(f"{DATA_DIR}/val", transform=val_transforms)

# The classifier head is sized from the training classes. Both splits must expose the
# same class list, otherwise validation labels would not line up with what the model
# learned.
assert list(train_dataset.classes) == list(val_dataset.classes), (
    "train/val class lists differ; label indices would not match"
)

# Loaders feed the data in mini-batches; only the training set is shuffled.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [4]:
# Load ResNet18 with ImageNet-pretrained weights (IMAGENET1K_V1).
# Its layers already encode visual features learned from ~1.2M ImageNet images.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Number of plant disease classes, taken from the training dataset.
num_classes = len(train_dataset.classes)

# Replace the ImageNet classification head with a freshly initialised layer sized for
# the plant disease labels. Reading in_features (512 for ResNet18) instead of hard-coding
# it keeps this line correct if the backbone is swapped for another ResNet variant.
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Move the model to the GPU when available, otherwise keep it on the CPU.
model = model.to(device)

# Freeze the early layers so their pretrained weights are not updated.
# These are the parameterised modules among ResNet18's first six top-level children:
# conv1, bn1, relu, maxpool, layer1, layer2 (relu and maxpool have no parameters).
# Early layers capture generic low-level features (edges, curves, textures) that
# transfer to almost any image.
FROZEN_CHILDREN = ("conv1", "bn1", "layer1", "layer2")

# Guard against typos: every name listed must be a real top-level child of the model.
child_names = {name for name, _ in model.named_children()}
assert set(FROZEN_CHILDREN) <= child_names, (
    f"unknown layer names: {set(FROZEN_CHILDREN) - child_names}"
)

for name, child in model.named_children():
    if name in FROZEN_CHILDREN:
        for param in child.parameters():
            param.requires_grad = False

# NOTE(review): requires_grad=False only stops gradient updates. BatchNorm layers inside
# the frozen blocks still update their running mean/var while the model is in train mode.
# If train_one_epoch calls model.train(), those statistics will drift. Confirm this is
# intended; otherwise put the frozen BatchNorm modules in eval mode each epoch.

# Only layer3, layer4 and the new fc head remain trainable.
# This reduces training time a lot while keeping accuracy high.

In [5]:
# Optimizer: Adam receives only the parameters still marked trainable after the
# freezing step, so the frozen backbone layers are never updated.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)  # lr = learning rate

# Loss: cross-entropy on the model's raw outputs, the standard multi-class classification loss.
criterion = nn.CrossEntropyLoss()


In [6]:
# Fine-tune for a fixed number of epochs, remembering the epoch with the best validation
# accuracy. After training, those best weights are loaded back into `model`, so a later
# `model.state_dict()` refers to the best epoch rather than simply the last one.
# NOTE(review): `device` is not passed to train_one_epoch/validate. Presumably they move
# batches to the model's device themselves; confirm before running on a GPU.
num_epochs = 5

best_val_acc = float("-inf")
best_epoch = None
best_state = None

for epoch in range(1, num_epochs + 1):
    print(f"Epoch {epoch}/{num_epochs}")

    print("Running: Training Epoch")
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)

    print("Running: Validation epoch")
    val_loss, val_acc = validate(model, val_loader, criterion)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    print("-" * 50)

    if val_acc > best_val_acc:
        best_val_acc, best_epoch = val_acc, epoch
        # Clone: state_dict() tensors share storage with the live parameters, which keep
        # changing as training continues.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

# Restore the best-performing weights. This is a no-op if the final epoch was the best.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored weights from epoch {best_epoch} (Val Acc: {best_val_acc:.4f})")
  

Epoch 1/5
Running: Training Epoch
Running: Validation epoch
Train Loss: 0.2412, Train Acc: 0.9357
Val Loss: 0.0533, Val Acc: 0.9843
--------------------------------------------------
Epoch 2/5
Running: Training Epoch
Running: Validation epoch
Train Loss: 0.0712, Train Acc: 0.9784
Val Loss: 0.0428, Val Acc: 0.9852
--------------------------------------------------
Epoch 3/5
Running: Training Epoch
Running: Validation epoch
Train Loss: 0.0525, Train Acc: 0.9834
Val Loss: 0.0316, Val Acc: 0.9893
--------------------------------------------------
Epoch 4/5
Running: Training Epoch
Running: Validation epoch
Train Loss: 0.0430, Train Acc: 0.9866
Val Loss: 0.0295, Val Acc: 0.9906
--------------------------------------------------
Epoch 5/5
Running: Training Epoch
Running: Validation epoch
Train Loss: 0.0347, Train Acc: 0.9888
Val Loss: 0.0266, Val Acc: 0.9917
--------------------------------------------------


In [7]:
# Persist the weights currently held by `model`. Only the state_dict is saved, so rebuild
# a ResNet18 with the same num_classes fc head before loading it. Create the output
# directory first so the save does not fail on a fresh checkout without ../models.
os.makedirs("../models", exist_ok=True)
torch.save(model.state_dict(), "../models/resnet18_best.pth")