# In [None]:
# Import PyTorch core library
import torch

# Import DataLoader utility (not used directly but common in datasets)
from torch.utils.data import DataLoader

# Import Hugging Face ViT model, image processor, trainer utilities
from transformers import (
    ViTForImageClassification,
    ViTImageProcessor,
    Trainer,
    TrainingArguments
)

# Import accuracy metric from sklearn
from sklearn.metrics import accuracy_score

# Import OS module for file and directory operations
import os

# Import Dataset base class for custom dataset creation
from torch.utils.data import Dataset  

# Import PIL Image for loading images
from PIL import Image


# =========================
# MODEL & PROCESSOR
# =========================

# Pretrained ViT checkpoint on the Hugging Face Hub; the processor and the
# model are both loaded from this same identifier.
model_name = "dima806/ai_vs_real_image_detection"

# Preprocessing pipeline (resizing / normalization as defined by the checkpoint).
processor = ViTImageProcessor.from_pretrained(model_name)

# Two-way image classifier. The dataset class in this file maps REAL -> 0 and
# FAKE -> 1.
# NOTE(review): confirm the checkpoint's config.id2label uses the same order.
model = ViTForImageClassification.from_pretrained(model_name, num_labels=2)


# =========================
# FREEZE ViT BACKBONE
# =========================

# Turn off gradient tracking for every parameter of the `model.vit` submodule
# so the pretrained encoder weights stay fixed during training. Parameters
# outside `model.vit` (presumably just the classification head) remain
# trainable.
model.vit.requires_grad_(False)


# =========================
# CUSTOM DATASET
# =========================

# Define a custom PyTorch Dataset for image classification
class imagedataset(Dataset):
    """Binary image-classification dataset read from ``root_dir/REAL`` and ``root_dir/FAKE``.

    Every image file in ``REAL`` gets label 0 and every image file in ``FAKE``
    gets label 1. Each item is the processor's output for one image with the
    leading batch dimension removed, plus a ``"labels"`` tensor.

    Args:
        root_dir: Directory that contains the ``REAL`` and ``FAKE`` sub-folders.
        processor: Image processor called as
            ``processor(image, return_tensors="pt")``; its result is iterated
            as a mapping of tensors.
        extensions: Filename suffixes (compared case-insensitively) treated
            as images. Defaults to ``IMAGE_EXTENSIONS``. Files with other
            suffixes, and sub-directories, are skipped.

    Raises:
        FileNotFoundError: If ``root_dir/REAL`` or ``root_dir/FAKE`` does not exist.
    """

    # Filename suffixes accepted when `extensions` is not given.
    IMAGE_EXTENSIONS = (
        ".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tif", ".tiff",
    )

    def __init__(self, root_dir, processor, extensions=None):
        # Full paths of the image files, aligned index-by-index with `labels`.
        self.image_path = []
        # Integer class of each image: 0 = REAL, 1 = FAKE.
        self.labels = []
        self.processor = processor

        if extensions is None:
            extensions = self.IMAGE_EXTENSIONS
        suffixes = tuple(ext.lower() for ext in extensions)

        for label, folder in enumerate(['REAL', 'FAKE']):
            folder_path = os.path.join(root_dir, folder)

            # os.listdir returns entries in arbitrary order; sort so the
            # index -> sample mapping is reproducible across runs/machines.
            for img_file in sorted(os.listdir(folder_path)):
                full_path = os.path.join(folder_path, img_file)

                # Skip sub-directories (e.g. .ipynb_checkpoints) and non-image
                # files (e.g. .DS_Store) that Image.open would fail on mid-training.
                if not os.path.isfile(full_path):
                    continue
                if not img_file.lower().endswith(suffixes):
                    continue

                self.image_path.append(full_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of images found under REAL and FAKE."""
        return len(self.image_path)

    def __getitem__(self, idx):
        """Load image `idx`, run the processor on it and attach its label."""
        # Use a context manager so the underlying file is closed as soon as the
        # RGB copy exists, instead of waiting for garbage collection.
        with Image.open(self.image_path[idx]) as img:
            image = img.convert("RGB")

        inputs = self.processor(image, return_tensors="pt")

        # The processor returns batched tensors for this single image; drop
        # that leading batch dimension so the collator can build real batches.
        item = {key: inputs[key].squeeze(0) for key in inputs}

        item["labels"] = torch.tensor(self.labels[idx])
        return item


# =========================
# DATASETS
# =========================

# Training split: images are read from /content/dataset/train/{REAL,FAKE}.
train_dataset = imagedataset("/content/dataset/train", processor)

# Held-out split, passed to the Trainer as its evaluation dataset.
test_dataset = imagedataset("/content/dataset/test", processor)


# =========================
# TRAINING ARGUMENTS
# =========================

# Define training configuration
training_args = TrainingArguments(
    output_dir="./vit_finetune",            # Directory to save checkpoints
    per_device_train_batch_size=16,         # Batch size for training
    per_device_eval_batch_size=16,          # Batch size for evaluation
    num_train_epochs=2,                     # Number of training epochs
    learning_rate=1e-4,                     # Learning rate
    eval_strategy="steps",                  # Evaluate on a step schedule...
    eval_steps=500,                         # ...every 500 optimizer steps
    # load_best_model_at_end requires checkpoints to be saved on the same
    # strategy as evaluation, with save_steps a multiple of eval_steps.
    # State the schedule explicitly rather than relying on library defaults,
    # so changing eval_steps later does not silently break this coupling.
    save_strategy="steps",
    save_steps=500,
    logging_dir="./logs",                   # Directory for logs
    # Reload the best checkpoint after training.
    # NOTE(review): metric_for_best_model is not set, so "best" is judged by
    # eval loss rather than by the accuracy returned from compute_metrics.
    load_best_model_at_end=True,
    report_to="none"                        # Disable all logging integrations (W&B, TensorBoard, ...)
)


# =========================
# METRICS
# =========================

# Function to compute evaluation metrics
def compute_metrics(eval_pred):
    """Compute classification accuracy for the Hugging Face Trainer.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair (e.g. an ``EvalPrediction``)
            supplied by the Trainer at evaluation time.

    Returns:
        dict: ``{"accuracy": value}``, where value is the fraction of samples
        whose arg-max class equals the label.
    """
    logits, labels = eval_pred

    # When the model returns extra outputs (e.g. hidden states or attentions),
    # the predictions can arrive as a tuple/list whose first element is the
    # logits; unwrap it so argmax works in both cases.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]

    # The class scores are on the last axis.
    preds = logits.argmax(axis=-1)

    return {
        "accuracy": accuracy_score(labels, preds)
    }


# =========================
# TRAINER
# =========================

# Assemble model, configuration, data and metric function into a Trainer.
# The image processor is handed over through the `tokenizer` argument.
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
# `processing_class=`; switch once the pinned version supports it.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

# Run fine-tuning with the schedule defined in training_args.
trainer.train()
