In [3]:
import sys
from pathlib import Path

PROJECT_ROOT = Path.cwd().parents[1]  # adjust if needed
SRC_PATH = PROJECT_ROOT / "src"

sys.path.append(str(SRC_PATH))

In [5]:
import pandas as pd

import os
from PIL import Image
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F

import string
import math

from sklearn.metrics import f1_score, accuracy_score, classification_report

from tqdm import tqdm

import timm
from timm.data import resolve_model_data_config, create_transform

from utils.loader_cnn import PlantDiseaseDataset

In [None]:
from huggingface_hub import notebook_login

notebook_login()

In [6]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


In [7]:
df_labels = pd.read_csv(f'{PROJECT_ROOT}/data/splits/label_space.csv')
NUM_LABELS = len(df_labels)

# Training

In [18]:
model = timm.create_model(
    "vit_base_patch16_224",
    pretrained=True,
    num_classes=NUM_LABELS
).to(DEVICE)

In [19]:
data_config = resolve_model_data_config(model)

train_transform = create_transform(
    **data_config,
    is_training=True
)

val_transform = create_transform(
    **data_config,
    is_training=False
)

In [20]:
train_dataset = PlantDiseaseDataset(
    csv_filepath=f'{PROJECT_ROOT}/data/splits/pv_train.csv',
    root_dir=f'{PROJECT_ROOT}',
    transform=train_transform
)

val_dataset = PlantDiseaseDataset(
    csv_filepath=f'{PROJECT_ROOT}/data/splits/pv_val.csv',
    root_dir=f'{PROJECT_ROOT}',
    transform=val_transform
)

In [21]:
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

In [24]:
for name, param in model.named_parameters():
    print(name)

cls_token
pos_embed
patch_embed.proj.weight
patch_embed.proj.bias
blocks.0.norm1.weight
blocks.0.norm1.bias
blocks.0.attn.qkv.weight
blocks.0.attn.qkv.bias
blocks.0.attn.proj.weight
blocks.0.attn.proj.bias
blocks.0.norm2.weight
blocks.0.norm2.bias
blocks.0.mlp.fc1.weight
blocks.0.mlp.fc1.bias
blocks.0.mlp.fc2.weight
blocks.0.mlp.fc2.bias
blocks.1.norm1.weight
blocks.1.norm1.bias
blocks.1.attn.qkv.weight
blocks.1.attn.qkv.bias
blocks.1.attn.proj.weight
blocks.1.attn.proj.bias
blocks.1.norm2.weight
blocks.1.norm2.bias
blocks.1.mlp.fc1.weight
blocks.1.mlp.fc1.bias
blocks.1.mlp.fc2.weight
blocks.1.mlp.fc2.bias
blocks.2.norm1.weight
blocks.2.norm1.bias
blocks.2.attn.qkv.weight
blocks.2.attn.qkv.bias
blocks.2.attn.proj.weight
blocks.2.attn.proj.bias
blocks.2.norm2.weight
blocks.2.norm2.bias
blocks.2.mlp.fc1.weight
blocks.2.mlp.fc1.bias
blocks.2.mlp.fc2.weight
blocks.2.mlp.fc2.bias
blocks.3.norm1.weight
blocks.3.norm1.bias
blocks.3.attn.qkv.weight
blocks.3.attn.qkv.bias
blocks.3.attn.proj.wei

In [32]:
# Freeze ViT layers
for param in model.blocks.parameters():
    param.requires_grad = False

# Unfreeze classifier
for param in model.head.parameters():
    param.requires_grad = True

# Unfreeze last few encoder layers
for block in model.blocks[-4:]:
    for param in block.parameters():
        param.requires_grad = True

In [34]:
# Set up optimizer with parameter groups

# Pick out parameters to apply weight decay to
decay, no_decay = [], []

for name, param in model.named_parameters():
    if not param.requires_grad or "head" in name:
        continue
    
    if param.ndim == 1 or "bias" in name or "norm" in name.lower():
        no_decay.append(param)
    else:
        decay.append(param)

# Optimizer
optimizer = torch.optim.AdamW([
    {
        "params": decay,
        "lr": 1e-4,
        "weight_decay": 0.05
    },
    {
        "params": no_decay,
        "lr": 1e-4,
        "weight_decay": 0.0
    },
    {
        "params": model.head.parameters(),
        "lr": 1e-3,
        "weight_decay": 0.0
    }
])

In [35]:
# Mixed precision
scaler = torch.amp.GradScaler()

criterion = nn.CrossEntropyLoss()

num_epochs = 10

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    
    # -------------------
    # Training
    # -------------------
    model.train()
    train_loss = 0.0
    
    for images, labels in tqdm(train_loader, desc="Training"):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()

        # Mixed precision forward
        with torch.amp.autocast(device_type=DEVICE.type):
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Backward + step
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    print(f"Average Train Loss: {avg_train_loss:.4f}")

    # -------------------
    # Validation
    # -------------------
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0

    with torch.no_grad():
        all_preds = []
        all_labels = []
    
        for images, labels in tqdm(val_loader, desc="Validation"):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            with torch.amp.autocast(device_type=DEVICE.type):
                outputs = model(images)
                loss = criterion(outputs, labels)

            val_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = accuracy_score(all_labels, all_preds)
    val_macro_f1 = f1_score(all_labels, all_preds, average='macro')
    val_micro_f1 = f1_score(all_labels, all_preds, average='micro')

    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}, Macro F1 Score: {val_macro_f1:.4f}, Micro F1 Score: {val_micro_f1:.4f}")


Epoch 1/10


Training: 100%|██████████| 1804/1804 [05:48<00:00,  5.17it/s]


Average Train Loss: 0.2478


Validation: 100%|██████████| 226/226 [00:18<00:00, 12.35it/s]


Validation Loss: 0.0734, Accuracy: 0.9770, Macro F1 Score: 0.9715, Micro F1 Score: 0.9770

Epoch 2/10


Training: 100%|██████████| 1804/1804 [07:19<00:00,  4.11it/s]


Average Train Loss: 0.1296


Validation: 100%|██████████| 226/226 [00:15<00:00, 14.66it/s]


Validation Loss: 0.0529, Accuracy: 0.9845, Macro F1 Score: 0.9741, Micro F1 Score: 0.9845

Epoch 3/10


Training: 100%|██████████| 1804/1804 [1:21:51<00:00,  2.72s/it]     


Average Train Loss: 0.1258


Validation: 100%|██████████| 226/226 [00:18<00:00, 12.30it/s]


Validation Loss: 0.0334, Accuracy: 0.9884, Macro F1 Score: 0.9876, Micro F1 Score: 0.9884

Epoch 4/10


Training: 100%|██████████| 1804/1804 [05:13<00:00,  5.76it/s]


Average Train Loss: 0.1087


Validation: 100%|██████████| 226/226 [00:18<00:00, 12.35it/s]


Validation Loss: 0.0860, Accuracy: 0.9770, Macro F1 Score: 0.9692, Micro F1 Score: 0.9770

Epoch 5/10


Training: 100%|██████████| 1804/1804 [05:13<00:00,  5.76it/s]


Average Train Loss: 0.0987


Validation: 100%|██████████| 226/226 [00:17<00:00, 12.56it/s]


Validation Loss: 0.0296, Accuracy: 0.9917, Macro F1 Score: 0.9892, Micro F1 Score: 0.9917

Epoch 6/10


Training: 100%|██████████| 1804/1804 [05:29<00:00,  5.47it/s]


Average Train Loss: 0.0973


Validation: 100%|██████████| 226/226 [00:20<00:00, 11.18it/s]


Validation Loss: 0.0315, Accuracy: 0.9917, Macro F1 Score: 0.9884, Micro F1 Score: 0.9917

Epoch 7/10


Training: 100%|██████████| 1804/1804 [05:29<00:00,  5.47it/s]


Average Train Loss: 0.0853


Validation: 100%|██████████| 226/226 [00:18<00:00, 12.36it/s]


Validation Loss: 0.0447, Accuracy: 0.9861, Macro F1 Score: 0.9835, Micro F1 Score: 0.9861

Epoch 8/10


Training: 100%|██████████| 1804/1804 [05:12<00:00,  5.78it/s]


Average Train Loss: 0.0920


Validation: 100%|██████████| 226/226 [00:18<00:00, 12.52it/s]


Validation Loss: 0.0275, Accuracy: 0.9950, Macro F1 Score: 0.9930, Micro F1 Score: 0.9950

Epoch 9/10


Training: 100%|██████████| 1804/1804 [05:09<00:00,  5.83it/s]


Average Train Loss: 0.0806


Validation: 100%|██████████| 226/226 [00:18<00:00, 12.35it/s]


Validation Loss: 0.0426, Accuracy: 0.9875, Macro F1 Score: 0.9864, Micro F1 Score: 0.9875

Epoch 10/10


Training: 100%|██████████| 1804/1804 [05:11<00:00,  5.80it/s]


Average Train Loss: 0.0808


Validation: 100%|██████████| 226/226 [00:18<00:00, 12.40it/s]

Validation Loss: 0.0275, Accuracy: 0.9931, Macro F1 Score: 0.9911, Micro F1 Score: 0.9931





In [36]:
torch.save(model.state_dict(), f'{PROJECT_ROOT}/models/vit_base_patch16_224_timm_finetuned.pth')