In [None]:
import pandas as pd

import os
from PIL import Image
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F

import string
import math

from sklearn.metrics import f1_score, accuracy_score, classification_report

from tqdm import tqdm

import timm
from timm.data import resolve_model_data_config, create_transform

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Authenticate this notebook session with the Hugging Face Hub.
# NOTE(review): presumably needed so the pretrained weights requested further
# down can be downloaded — confirm whether this step is still required.
from huggingface_hub import notebook_login

notebook_login()

In [2]:
# Prefer the GPU whenever CUDA is usable; otherwise everything runs on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)
print(f"Using device: {DEVICE}")

Using device: cuda


In [3]:
# Project root as a relative path; the split CSV and image paths used in this
# notebook are built by prefixing it.
PROJECT_ROOT = ".."

In [4]:
# Load the label-space table; its row count is used as the number of classes.
label_space_csv = f'{PROJECT_ROOT}/data/splits/label_space.csv'
df_labels = pd.read_csv(label_space_csv)
NUM_LABELS = df_labels.shape[0]

# Dataset

In [5]:
class PlantDiseaseDataset(Dataset):
    """Image-classification dataset backed by the project's split CSV files.

    The CSV for the requested split is read from
    ``{PROJECT_ROOT}/data/splits/``. It must provide a ``filepath_rel`` column
    (image path relative to ``PROJECT_ROOT``) and a ``canonical_id`` column
    (the label returned alongside each image).

    Args:
        split: One of ``"train"``, ``"val"``, ``"test_pv"`` or ``"test_pd"``.
        transforms: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned from ``__getitem__``.

    Raises:
        ValueError: If ``split`` is not one of the supported split names.
    """

    # Split name -> CSV file name inside ``{PROJECT_ROOT}/data/splits``.
    SPLIT_FILES = {
        "train": "pv_train.csv",
        "val": "pv_val.csv",
        "test_pv": "pv_test.csv",
        "test_pd": "plantdoc_test_mapped.csv",
    }

    def __init__(self, split, transforms=None):
        # Validate up front so a typo in the split name fails with a clear
        # message instead of an UnboundLocalError when the CSV is consumed.
        if split not in self.SPLIT_FILES:
            valid = ", ".join(repr(name) for name in self.SPLIT_FILES)
            raise ValueError(f"Unknown split {split!r}; expected one of: {valid}")

        self.split = split
        self.transforms = transforms

        # Kept for backward compatibility; nothing in this class populates it.
        self.class_to_idx = {}

        df = pd.read_csv(f'{PROJECT_ROOT}/data/splits/{self.SPLIT_FILES[split]}')

        # (absolute-ish image path, label) pairs. Iterating the two columns
        # directly avoids building a Series per row as ``iterrows`` does.
        self.samples = [
            (f"{PROJECT_ROOT}/{rel_path}", label)
            for rel_path, label in zip(
                df['filepath_rel'].tolist(), df['canonical_id'].tolist()
            )
        ]

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is opened from disk, converted to RGB and, when a
        ``transforms`` callable was supplied, passed through it.
        """
        img_path, label = self.samples[idx]

        # The context manager releases the file handle deterministically;
        # ``convert`` returns an independent image that outlives it.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transforms is not None:
            image = self.transforms(image)

        return image, label

# Training

In [6]:
# Pretrained ViT (base, 16x16 patches, 224px input) whose classifier is sized
# to the label space.
model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=NUM_LABELS)
# Place the model on the selected device.
model = model.to(DEVICE)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [None]:
# Preprocessing settings resolved from the model.
data_config = resolve_model_data_config(model)

# Both pipelines share the same config; only the training flag differs.
train_transform, val_transform = (
    create_transform(**data_config, is_training=training_flag)
    for training_flag in (True, False)
)

NameError: name 'model' is not defined

In [None]:
# Training split, preprocessed with the training-time transform pipeline.
train_dataset = PlantDiseaseDataset(
    split="train",
    transforms=train_transform
)

# Validation split. PlantDiseaseDataset.__init__ only accepts `transforms`;
# the previous `processor=` keyword raised a TypeError.
val_dataset = PlantDiseaseDataset(
    split="val",
    transforms=val_transform,
)

In [None]:
# Settings common to both loaders.
loader_kwargs = dict(batch_size=16, num_workers=0, pin_memory=True)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Mixed precision: bind the gradient scaler to the device actually in use and
# only enable loss scaling on CUDA. The default scaler targets CUDA regardless
# of DEVICE; with `enabled=False` the scale/step/update calls below reduce to a
# plain backward pass and optimizer step.
scaler = torch.amp.GradScaler(DEVICE.type, enabled=DEVICE.type == "cuda")

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

num_epochs = 10

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # -------------------
    # Training
    # -------------------
    model.train()
    train_loss = 0.0

    for images, labels in tqdm(train_loader, desc="Training"):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()

        # Forward pass and loss under autocast (reduced precision where safe).
        with torch.amp.autocast(device_type=DEVICE.type):
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Backward on the scaled loss, then let the scaler drive the optimizer
        # step and refresh its scale factor for the next iteration.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    avg_train_loss = train_loss / len(train_loader)
    print(f"Average Train Loss: {avg_train_loss:.4f}")

    # -------------------
    # Validation
    # -------------------
    model.eval()
    val_loss = 0.0
    # Predictions and targets are accumulated over the whole split so the
    # sklearn metrics are computed once per epoch.
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validation"):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            with torch.amp.autocast(device_type=DEVICE.type):
                outputs = model(images)
                loss = criterion(outputs, labels)

            val_loss += loss.item()

            # Predicted class = index of the largest logit.
            preds = torch.argmax(outputs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = accuracy_score(all_labels, all_preds)
    val_macro_f1 = f1_score(all_labels, all_preds, average='macro')
    val_micro_f1 = f1_score(all_labels, all_preds, average='micro')

    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}, Macro F1 Score: {val_macro_f1:.4f}, Micro F1 Score: {val_micro_f1:.4f}")