In [1]:
#!/usr/bin/env python
"""
Fine-tune DeiT-Base (facebook/deit-base-patch16-224) on the diabetic-retinopathy
dataset stored as a single CSV (id_code, diagnosis) + images.
Save as train_deit_hf.py and `python train_deit_hf.py`
"""

# ---------------------------------------------------------------------
# 0. Std / 3rd-party imports
# ---------------------------------------------------------------------
import os, random, json
from pathlib import Path
from dataclasses import dataclass
import numpy as np
import pandas as pd
from PIL import Image

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from sklearn.model_selection import train_test_split
from transformers import TrainingArguments

from transformers import (
    AutoImageProcessor,          # a.k.a. feature extractor
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
    default_data_collator,
    set_seed,
)
import evaluate                 # for accuracy metric
# ---------------------------------------------------------------------
# 1. Custom Dataset that returns dict(pixel_values, labels)
# ---------------------------------------------------------------------
class DRRetinaDataset(Dataset):
    def __init__(self, df: pd.DataFrame, img_dir: Path, image_processor,
                 train: bool = True):
        self.df = df.reset_index(drop=True)
        self.img_dir = Path(img_dir)
        self.iproc = image_processor
        # --- minimal augments ---
        if train:
            self.augment = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
            ])
        else:
            self.augment = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
            ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = self.img_dir / f"{row.id_code}.png"
        img = Image.open(img_path).convert("RGB")
        img = self.augment(img)

        # `return_tensors="pt"` gives shape (1,3,224,224); we squeeze to (3,224,224)
        pixel_values = self.iproc(img, return_tensors="pt").pixel_values.squeeze(0)
        label = int(row.diagnosis)
        return {"pixel_values": pixel_values, "labels": label}


# ---------------------------------------------------------------------
# 2. Paths, hyper-params, splits
# ---------------------------------------------------------------------
CSV_PATH = "data/aptos2019-blindness-detection/train.csv"   # your CSV
IMG_DIR  = "data/aptos2019-blindness-detection/train_images" 
OUTPUT_DIR = Path("deit_retina_ckpt")
VAL_FRAC  = 0.15          # 85 % train / 15 % val
SEED      = 42
BATCH     = 32
EPOCHS    = 30
LR        = 5e-4
WD        = 1e-4
set_seed(SEED)

# ---------------------------------------------------------------------
# 3. Prepare DataFrames & label info
# ---------------------------------------------------------------------
df = pd.read_csv(CSV_PATH)
num_labels = df["diagnosis"].nunique()
train_df, val_df = train_test_split(
    df,
    test_size   = VAL_FRAC,
    stratify    = df["diagnosis"],
    random_state= SEED,
)

# ---------------------------------------------------------------------
# 4. Load processor & model
# ---------------------------------------------------------------------
processor = AutoImageProcessor.from_pretrained(
    "facebook/deit-base-patch16-224",
    # use_safetensors=True
)
model = AutoModelForImageClassification.from_pretrained(
    "facebook/deit-base-patch16-224",
    use_safetensors=True,
    # num_labels = 5,
    # id2label   = {i: str(i) for i in range(num_labels)},
    # label2id   = {str(i): i for i in range(num_labels)},
)

# ---------------------------------------------------------------------
# 5. Build Dataset objects
# ---------------------------------------------------------------------
train_ds = DRRetinaDataset(train_df, IMG_DIR, processor, train=True)
val_ds   = DRRetinaDataset(val_df,   IMG_DIR, processor, train=False)

# ---------------------------------------------------------------------
# 6. Metric
# ---------------------------------------------------------------------
acc_metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return acc_metric.compute(predictions=preds, references=labels)

# ---------------------------------------------------------------------
# 7. TrainingArguments & Trainer
# ---------------------------------------------------------------------
args = TrainingArguments(
    output_dir="./deit_results",
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        logging_dir="./deit_logs",
        logging_steps=500,
        eval_strategy ="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
)

trainer = Trainer(
    model           = model,
    args            = args,
    train_dataset   = train_ds,
    eval_dataset    = val_ds,
    data_collator   = default_data_collator,   # stacks pixel_values + labels
    compute_metrics = compute_metrics,
)

# ---------------------------------------------------------------------
# 8. Train & save the best model
# ---------------------------------------------------------------------
train_results = trainer.train()
print(train_results.metrics)

print("\n✅ Training complete. Evaluating best checkpoint…")
metrics = trainer.evaluate()
print(metrics)

# best model is already loaded (default). Save weights + processor for later use
processor.save_pretrained(OUTPUT_DIR / "preprocessor")
trainer.save_model(OUTPUT_DIR / "best_model")  # will create best_model/ with pytorch_model.bin


Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.771303,0.76
2,No log,0.571279,0.790909
3,0.691600,0.562675,0.790909


{'train_runtime': 1710.6254, 'train_samples_per_second': 5.458, 'train_steps_per_second': 0.342, 'total_flos': 7.299364365505659e+17, 'train_loss': 0.6569902860201322, 'epoch': 3.0}

✅ Training complete. Evaluating best checkpoint…


{'eval_loss': 0.5626751184463501, 'eval_accuracy': 0.7909090909090909, 'eval_runtime': 69.8438, 'eval_samples_per_second': 7.875, 'eval_steps_per_second': 0.501, 'epoch': 3.0}


In [17]:
!pip install accelerate



In [10]:
!pip install transformers==4.54.1



In [2]:
!pip install evaluate

Collecting evaluate
  Using cached evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
Collecting datasets>=2.0.0 (from evaluate)
  Using cached datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting dill (from evaluate)
  Using cached dill-0.4.0-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from evaluate)
  Using cached xxhash-3.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.18-py39-none-any.whl.metadata (7.5 kB)
Collecting pyarrow>=15.0.0 (from datasets>=2.0.0->evaluate)
  Using cached pyarrow-21.0.0-cp39-cp39-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill (from evaluate)
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting multiprocess (from evaluate)
  Using cached multiprocess-0.70.16-py39-none-any.whl.metadata (7.2 kB)
Collecting fsspec>=2021.05.0 (from fsspec[http]>=2021.05.0->evaluate)
  Using cached fsspec-2025.3.0-py3-none-any.whl.met