In [None]:
%pip uninstall torch torchvision torchaudio -y

In [None]:
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

In [None]:
%pip install scikit-learn

In [17]:
import pandas as pd

import os
from PIL import Image
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from sklearn.metrics import f1_score, accuracy_score, classification_report

from tqdm import tqdm

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
PROJECT_ROOT = ".."

In [4]:
df_labels = pd.read_csv('../data/splits/label_space.csv')

In [5]:
NUM_LABELS = len(df_labels)

# PyTorch Dataset & DataLoader

In [6]:
class PlantDiseaseDataset(Dataset):
    def __init__(self, split, processor, transform=None):
        self.split = split
        self.processor = processor
        self.transform = transform

        self.samples = []
        self.class_to_idx = {}

        if split == "train":
            df = pd.read_csv(f'{PROJECT_ROOT}/data/splits/pv_train.csv')
        elif split == "val":
            df = pd.read_csv(f'{PROJECT_ROOT}/data/splits/pv_val.csv')
        elif split == "test_pv":
            df = pd.read_csv(f'{PROJECT_ROOT}/data/splits/pv_test.csv')
        elif split == "test_pd":
            df = pd.read_csv(f'{PROJECT_ROOT}/data/splits/plantdoc_test_mapped.csv')

        self.samples = [
            (f"{PROJECT_ROOT}/{row['filepath_rel']}", row['canonical_id'])
            for _, row in df.iterrows()
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]

        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        # HF processor → tensor + normalization
        encoding = self.processor(
            image,
            return_tensors="pt"
        )

        return {
            "pixel_values": encoding["pixel_values"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }

In [7]:
from transformers import ViTImageProcessor

processor = ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)

train_dataset = PlantDiseaseDataset(
    split="train",
    processor=processor
)

val_dataset = PlantDiseaseDataset(
    split="val",
    processor=processor,
)

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

In [9]:
# Sanity check
batch = next(iter(train_loader))

print(batch["pixel_values"].shape) 
print(batch["labels"].shape)        

torch.Size([16, 3, 224, 224])
torch.Size([16])


# Fine Tuning

In [10]:
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=NUM_LABELS,
    ignore_mismatched_sizes=True
).to(device)

Loading weights: 100%|██████████| 198/198 [00:00<00:00, 1344.35it/s, Materializing param=vit.layernorm.weight]                                 
ViTForImageClassification LOAD REPORT from: google/vit-base-patch16-224-in21k
Key                 | Status     | 
--------------------+------------+-
pooler.dense.bias   | UNEXPECTED | 
pooler.dense.weight | UNEXPECTED | 
classifier.bias     | MISSING    | 
classifier.weight   | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.


In [11]:
# Freeze ViT layers
for param in model.vit.parameters():
    param.requires_grad = False

# Unfreeze classifier
for param in model.classifier.parameters():
    param.requires_grad = True

# Unfreeze last few encoder layers
for block in model.vit.encoder.layer[-4:]:
    for param in block.parameters():
        param.requires_grad = True

In [12]:
# Set up optimizer with parameter groups

# Pick out parameters to apply weight decay to
decay, no_decay = [], []

for name, param in model.named_parameters():
    if not param.requires_grad or "classifier" in name:
        continue
    
    if param.ndim == 1 or "bias" in name or "layernorm" in name.lower():
        no_decay.append(param)
    else:
        decay.append(param)

# Optimizer
optimizer = torch.optim.AdamW([
    {
        "params": decay,
        "lr": 1e-4,
        "weight_decay": 0.05
    },
    {
        "params": no_decay,
        "lr": 1e-4,
        "weight_decay": 0.0
    },
    {
        "params": model.classifier.parameters(),
        "lr": 1e-3,
        "weight_decay": 0.0
    }
])

In [13]:
# Mixed precision
scaler = torch.amp.GradScaler()

num_epochs = 10

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    
    # -------------------
    # Training
    # -------------------
    model.train()
    train_loss = 0.0
    
    for batch in tqdm(train_loader, desc="Training"):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()

        # Mixed precision forward
        with torch.amp.autocast(device_type=device.type):
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

        # Backward + step
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    print(f"Average Train Loss: {avg_train_loss:.4f}")

    # -------------------
    # Validation
    # -------------------
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            with torch.amp.autocast(device_type=device.type):
                outputs = model(pixel_values=pixel_values, labels=labels)
                loss = outputs.loss

            val_loss += loss.item()
            preds = torch.argmax(outputs.logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = correct / total
    val_macro_f1 = f1_score(labels.cpu().numpy(), preds.cpu().numpy(), average='macro')

    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}, Macro F1 Score: {val_macro_f1:.4f}")


Epoch 1/10


Training: 100%|██████████| 1804/1804 [03:32<00:00,  8.49it/s]


Average Train Loss: 0.1386


Validation: 100%|██████████| 226/226 [00:19<00:00, 11.30it/s]


Validation Loss: 0.0368, Accuracy: 0.9895, Macro F1 Score: 1.0000

Epoch 2/10


Training: 100%|██████████| 1804/1804 [03:43<00:00,  8.06it/s]


Average Train Loss: 0.0201


Validation: 100%|██████████| 226/226 [00:20<00:00, 10.98it/s]


Validation Loss: 0.0286, Accuracy: 0.9911, Macro F1 Score: 1.0000

Epoch 3/10


Training: 100%|██████████| 1804/1804 [03:41<00:00,  8.14it/s]


Average Train Loss: 0.0125


Validation: 100%|██████████| 226/226 [00:19<00:00, 11.55it/s]


Validation Loss: 0.0216, Accuracy: 0.9945, Macro F1 Score: 1.0000

Epoch 4/10


Training: 100%|██████████| 1804/1804 [03:40<00:00,  8.17it/s]


Average Train Loss: 0.0154


Validation: 100%|██████████| 226/226 [00:20<00:00, 11.01it/s]


Validation Loss: 0.0215, Accuracy: 0.9942, Macro F1 Score: 1.0000

Epoch 5/10


Training: 100%|██████████| 1804/1804 [03:42<00:00,  8.10it/s]


Average Train Loss: 0.0094


Validation: 100%|██████████| 226/226 [00:19<00:00, 11.44it/s]


Validation Loss: 0.0252, Accuracy: 0.9931, Macro F1 Score: 1.0000

Epoch 6/10


Training: 100%|██████████| 1804/1804 [03:41<00:00,  8.15it/s]


Average Train Loss: 0.0085


Validation: 100%|██████████| 226/226 [00:20<00:00, 11.26it/s]


Validation Loss: 0.0145, Accuracy: 0.9967, Macro F1 Score: 1.0000

Epoch 7/10


Training: 100%|██████████| 1804/1804 [03:41<00:00,  8.15it/s]


Average Train Loss: 0.0107


Validation: 100%|██████████| 226/226 [00:19<00:00, 11.47it/s]


Validation Loss: 0.0199, Accuracy: 0.9953, Macro F1 Score: 1.0000

Epoch 8/10


Training: 100%|██████████| 1804/1804 [03:41<00:00,  8.16it/s]


Average Train Loss: 0.0021


Validation: 100%|██████████| 226/226 [00:20<00:00, 11.15it/s]


Validation Loss: 0.0218, Accuracy: 0.9950, Macro F1 Score: 1.0000

Epoch 9/10


Training: 100%|██████████| 1804/1804 [03:40<00:00,  8.16it/s]


Average Train Loss: 0.0085


Validation: 100%|██████████| 226/226 [00:19<00:00, 11.50it/s]


Validation Loss: 0.0358, Accuracy: 0.9925, Macro F1 Score: 1.0000

Epoch 10/10


Training: 100%|██████████| 1804/1804 [03:41<00:00,  8.14it/s]


Average Train Loss: 0.0075


Validation: 100%|██████████| 226/226 [00:19<00:00, 11.40it/s]

Validation Loss: 0.0676, Accuracy: 0.9859, Macro F1 Score: 1.0000





# Evaluation

In [16]:
test_pv_dataset = PlantDiseaseDataset(
    split = "test_pv",
    processor=processor
)

test_pd_dataset = PlantDiseaseDataset(
    split = "test_pd",
    processor=processor
)

test_pv_loader = DataLoader(
    test_pv_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

test_pd_loader = DataLoader(
    test_pd_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

In [27]:
model.eval()

def run_test(test_loader):
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(pixel_values=pixel_values)
            preds = torch.argmax(outputs.logits, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    accuracy = accuracy_score(all_labels, all_preds)
    macro_f1 = f1_score(all_labels, all_preds, average='macro')
    micro_f1 = f1_score(all_labels, all_preds, average="micro")
    report = classification_report(all_labels, all_preds, target_names=df_labels['canonical_label'].tolist(), zero_division=0)

    print(f"Test Accuracy: {accuracy:.4f}, Macro F1 Score: {macro_f1:.4f}, Micro F1 Score: {micro_f1:.4f}")
    print("Classification Report:")
    print(report)

In [26]:
run_test(test_pv_loader)

Testing: 100%|██████████| 226/226 [00:41<00:00,  5.44it/s]

Test Accuracy: 0.9884, Macro F1 Score: 0.9787, Micro F1 Score: 0.9884
Classification Report:
                                              precision    recall  f1-score   support

                           apple__apple_scab       1.00      0.89      0.94        63
                     apple__cedar_apple_rust       0.79      1.00      0.89        27
                              apple__healthy       1.00      1.00      1.00       165
                          blueberry__healthy       1.00      1.00      1.00       150
                             cherry__healthy       1.00      1.00      1.00        86
   corn__cercospora_leaf_spot_gray_leaf_spot       0.91      1.00      0.95        52
                           corn__common_rust       1.00      1.00      1.00       119
                  corn__northern_leaf_blight       1.00      0.95      0.97        99
                            grape__black_rot       1.00      1.00      1.00       118
                              grape__healthy  




In [28]:
run_test(test_pd_loader)

Testing: 100%|██████████| 14/14 [00:05<00:00,  2.34it/s]

Test Accuracy: 0.4612, Macro F1 Score: 0.4077, Micro F1 Score: 0.4612
Classification Report:
                                              precision    recall  f1-score   support

                           apple__apple_scab       1.00      0.10      0.18        10
                     apple__cedar_apple_rust       0.69      0.90      0.78        10
                              apple__healthy       0.33      0.56      0.42         9
                          blueberry__healthy       0.67      0.55      0.60        11
                             cherry__healthy       0.17      0.10      0.12        10
   corn__cercospora_leaf_spot_gray_leaf_spot       0.22      0.50      0.31         4
                           corn__common_rust       1.00      0.10      0.18        10
                  corn__northern_leaf_blight       0.47      0.58      0.52        12
                            grape__black_rot       0.80      0.50      0.62         8
                              grape__healthy  


