In [None]:
%pip uninstall torch torchvision torchaudio -y

In [None]:
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

In [None]:
%pip install scikit-learn

In [1]:
import pandas as pd

import os
from PIL import Image
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from sklearn.metrics import f1_score, accuracy_score, classification_report

from tqdm import tqdm

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


In [3]:
PROJECT_ROOT = ".."

In [4]:
df_labels = pd.read_csv(f'{PROJECT_ROOT}/data/splits/label_space.csv')
NUM_LABELS = len(df_labels)

# PyTorch Dataset & DataLoader

In [5]:
class PlantDiseaseDataset(Dataset):
    def __init__(self, split, processor, transform=None):
        self.split = split
        self.processor = processor
        self.transform = transform

        self.samples = []
        self.class_to_idx = {}

        if split == "train":
            df = pd.read_csv(f'{PROJECT_ROOT}/data/splits/pv_train.csv')
        elif split == "val":
            df = pd.read_csv(f'{PROJECT_ROOT}/data/splits/pv_val.csv')
        elif split == "test_pv":
            df = pd.read_csv(f'{PROJECT_ROOT}/data/splits/pv_test.csv')
        elif split == "test_pd":
            df = pd.read_csv(f'{PROJECT_ROOT}/data/splits/plantdoc_test_mapped.csv')

        self.samples = [
            (f"{PROJECT_ROOT}/{row['filepath_rel']}", row['canonical_id'])
            for _, row in df.iterrows()
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]

        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        # HF processor → tensor + normalization
        encoding = self.processor(
            image,
            return_tensors="pt"
        )

        return {
            "pixel_values": encoding["pixel_values"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }

In [6]:
from transformers import ViTImageProcessor

processor = ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)

train_dataset = PlantDiseaseDataset(
    split="train",
    processor=processor
)

val_dataset = PlantDiseaseDataset(
    split="val",
    processor=processor,
)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

In [8]:
# Sanity check
batch = next(iter(train_loader))

print(batch["pixel_values"].shape) 
print(batch["labels"].shape)        

torch.Size([16, 3, 224, 224])
torch.Size([16])


# Fine Tuning

## Normal ViT

In [9]:
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=NUM_LABELS,
    ignore_mismatched_sizes=True
).to(DEVICE)

Loading weights: 100%|██████████| 198/198 [00:00<00:00, 735.58it/s, Materializing param=vit.layernorm.weight]                                 
ViTForImageClassification LOAD REPORT from: google/vit-base-patch16-224-in21k
Key                 | Status     | 
--------------------+------------+-
pooler.dense.weight | UNEXPECTED | 
pooler.dense.bias   | UNEXPECTED | 
classifier.bias     | MISSING    | 
classifier.weight   | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.


In [None]:
# Freeze ViT layers
for param in model.vit.parameters():
    param.requires_grad = False

# Unfreeze classifier
for param in model.classifier.parameters():
    param.requires_grad = True

# Unfreeze last few encoder layers
for block in model.vit.encoder.layer[-4:]:
    for param in block.parameters():
        param.requires_grad = True

In [15]:
# Set up optimizer with parameter groups

# Pick out parameters to apply weight decay to
decay, no_decay = [], []

for name, param in model.named_parameters():
    if not param.requires_grad or "classifier" in name:
        continue
    
    if param.ndim == 1 or "bias" in name or "layernorm" in name.lower():
        no_decay.append(param)
    else:
        decay.append(param)

# Optimizer
optimizer = torch.optim.AdamW([
    {
        "params": decay,
        "lr": 1e-4,
        "weight_decay": 0.05
    },
    {
        "params": no_decay,
        "lr": 1e-4,
        "weight_decay": 0.0
    },
    {
        "params": model.classifier.parameters(),
        "lr": 1e-3,
        "weight_decay": 0.0
    }
])

In [None]:
# Mixed precision
scaler = torch.amp.GradScaler()

num_epochs = 10

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    
    # -------------------
    # Training
    # -------------------
    model.train()
    train_loss = 0.0
    
    for batch in tqdm(train_loader, desc="Training"):
        pixel_values = batch["pixel_values"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        optimizer.zero_grad()

        # Mixed precision forward
        with torch.amp.autocast(device_type=DEVICE.type):
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

        # Backward + step
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    print(f"Average Train Loss: {avg_train_loss:.4f}")

    # -------------------
    # Validation
    # -------------------
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0

    with torch.no_grad():
        all_preds = []
        all_labels = []
    
        for batch in tqdm(val_loader, desc="Validation"):
            pixel_values = batch["pixel_values"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            with torch.amp.autocast(device_type=DEVICE.type):
                outputs = model(pixel_values=pixel_values, labels=labels)
                loss = outputs.loss

            val_loss += loss.item()

            preds = torch.argmax(outputs.logits, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = accuracy_score(all_labels, all_preds)
    val_macro_f1 = f1_score(all_labels, all_preds, average='macro')
    val_micro_f1 = f1_score(all_labels, all_preds, average='micro')

    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}, Macro F1 Score: {val_macro_f1:.4f}, Micro F1 Score: {val_micro_f1:.4f}")


Epoch 1/10


Training: 100%|██████████| 1804/1804 [05:49<00:00,  5.17it/s]


Average Train Loss: 0.4437


Validation: 100%|██████████| 226/226 [00:48<00:00,  4.70it/s]


Validation Loss: 0.1755, Accuracy: 0.9592, Macro F1 Score: 0.9549, Micro F1 Score: 0.9592

Epoch 2/10


Training: 100%|██████████| 1804/1804 [05:00<00:00,  6.01it/s]


Average Train Loss: 0.1267


Validation: 100%|██████████| 226/226 [00:42<00:00,  5.32it/s]


Validation Loss: 0.1110, Accuracy: 0.9739, Macro F1 Score: 0.9706, Micro F1 Score: 0.9739

Epoch 3/10


Training: 100%|██████████| 1804/1804 [04:22<00:00,  6.86it/s]


Average Train Loss: 0.0836


Validation: 100%|██████████| 226/226 [00:31<00:00,  7.13it/s]


Validation Loss: 0.0855, Accuracy: 0.9778, Macro F1 Score: 0.9738, Micro F1 Score: 0.9778

Epoch 4/10


Training: 100%|██████████| 1804/1804 [03:34<00:00,  8.42it/s]


Average Train Loss: 0.0628


Validation: 100%|██████████| 226/226 [00:41<00:00,  5.45it/s]


Validation Loss: 0.0736, Accuracy: 0.9828, Macro F1 Score: 0.9805, Micro F1 Score: 0.9828

Epoch 5/10


Training: 100%|██████████| 1804/1804 [03:12<00:00,  9.38it/s]


Average Train Loss: 0.0499


Validation: 100%|██████████| 226/226 [00:22<00:00, 10.04it/s]


Validation Loss: 0.0623, Accuracy: 0.9842, Macro F1 Score: 0.9811, Micro F1 Score: 0.9842

Epoch 6/10


Training: 100%|██████████| 1804/1804 [03:46<00:00,  7.98it/s]


Average Train Loss: 0.0414


Validation: 100%|██████████| 226/226 [00:23<00:00,  9.45it/s]


Validation Loss: 0.0561, Accuracy: 0.9845, Macro F1 Score: 0.9804, Micro F1 Score: 0.9845

Epoch 7/10


Training: 100%|██████████| 1804/1804 [04:15<00:00,  7.05it/s]


Average Train Loss: 0.0351


Validation: 100%|██████████| 226/226 [00:27<00:00,  8.17it/s]


Validation Loss: 0.0544, Accuracy: 0.9856, Macro F1 Score: 0.9831, Micro F1 Score: 0.9856

Epoch 8/10


Training: 100%|██████████| 1804/1804 [05:01<00:00,  5.99it/s]


Average Train Loss: 0.0307


Validation: 100%|██████████| 226/226 [00:22<00:00, 10.08it/s]


Validation Loss: 0.0466, Accuracy: 0.9881, Macro F1 Score: 0.9856, Micro F1 Score: 0.9881

Epoch 9/10


Training: 100%|██████████| 1804/1804 [04:30<00:00,  6.66it/s]


Average Train Loss: 0.0271


Validation: 100%|██████████| 226/226 [00:40<00:00,  5.60it/s]


Validation Loss: 0.0472, Accuracy: 0.9878, Macro F1 Score: 0.9853, Micro F1 Score: 0.9878

Epoch 10/10


Training:  53%|█████▎    | 963/1804 [02:36<01:34,  8.86it/s]

In [21]:
torch.save(model.state_dict(), f"{PROJECT_ROOT}/models/vit_finetuned.pth")

# Evaluation

In [10]:
test_pv_dataset = PlantDiseaseDataset(
    split = "test_pv",
    processor=processor
)

test_pd_dataset = PlantDiseaseDataset(
    split = "test_pd",
    processor=processor
)

test_pv_loader = DataLoader(
    test_pv_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

test_pd_loader = DataLoader(
    test_pd_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

In [11]:
# Sanity check
batch = next(iter(test_pd_loader))

print(batch["pixel_values"].shape) 
print(batch["labels"].shape)        

torch.Size([16, 3, 224, 224])
torch.Size([16])


In [12]:
model.eval()

def run_test(test_loader):
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            pixel_values = batch["pixel_values"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            outputs = model(pixel_values=pixel_values)
            preds = torch.argmax(outputs.logits, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    accuracy = accuracy_score(all_labels, all_preds)
    macro_f1 = f1_score(all_labels, all_preds, average='macro')
    micro_f1 = f1_score(all_labels, all_preds, average="micro")
    report = classification_report(all_labels, all_preds, target_names=df_labels['canonical_label'].tolist(), zero_division=0)

    print(f"Test Accuracy: {accuracy:.4f}, Macro F1 Score: {macro_f1:.4f}, Micro F1 Score: {micro_f1:.4f}")
    print("Classification Report:")
    print(report)

In [13]:
run_test(test_pv_loader)

Testing: 100%|██████████| 226/226 [01:08<00:00,  3.29it/s]


Test Accuracy: 0.0458, Macro F1 Score: 0.0217, Micro F1 Score: 0.0458
Classification Report:
                                              precision    recall  f1-score   support

                           apple__apple_scab       0.09      0.11      0.10        63
                     apple__cedar_apple_rust       0.00      0.00      0.00        27
                              apple__healthy       0.00      0.00      0.00       165
                          blueberry__healthy       0.04      0.07      0.05       150
                             cherry__healthy       0.00      0.00      0.00        86
   corn__cercospora_leaf_spot_gray_leaf_spot       0.00      0.00      0.00        52
                           corn__common_rust       0.05      0.01      0.01       119
                  corn__northern_leaf_blight       0.01      0.02      0.01        99
                            grape__black_rot       0.02      0.05      0.03       118
                              grape__healthy  

In [18]:
run_test(test_pd_loader)

Testing: 100%|██████████| 14/14 [00:05<00:00,  2.38it/s]

Test Accuracy: 0.4384, Macro F1 Score: 0.3835, Micro F1 Score: 0.4384
Classification Report:
                                              precision    recall  f1-score   support

                           apple__apple_scab       1.00      0.40      0.57        10
                     apple__cedar_apple_rust       0.88      0.70      0.78        10
                              apple__healthy       0.37      0.78      0.50         9
                          blueberry__healthy       0.50      0.45      0.48        11
                             cherry__healthy       0.00      0.00      0.00        10
   corn__cercospora_leaf_spot_gray_leaf_spot       0.10      0.25      0.14         4
                           corn__common_rust       1.00      0.10      0.18        10
                  corn__northern_leaf_blight       0.41      0.58      0.48        12
                            grape__black_rot       0.62      0.62      0.62         8
                              grape__healthy  


