# Virtus - Deepfake Image Detection Model

In this notebook, we fine-tune a Vision Transformer (ViT) model for detecting deepfakes in images.


In [None]:
import warnings
warnings.filterwarnings('ignore')

import gc  # Garbage collection interface
import itertools
from collections import Counter

# Data handling & plotting
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Scikit-learn libraries
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    confusion_matrix,
    classification_report,
    f1_score
)

# Oversampling for imbalanced datasets
from imblearn.over_sampling import RandomOverSampler

# Hugging Face libraries
import accelerate
import evaluate
from datasets import Dataset, Image, ClassLabel
from transformers import (
    TrainingArguments,
    Trainer,
    ViTImageProcessor,
    ViTForImageClassification,
    DefaultDataCollator
)

# PyTorch core
import torch
from torch.utils.data import DataLoader

# Torchvision transforms for data augmentation
from torchvision.transforms import (
    Compose,
    CenterCrop,
    Normalize,
    RandomRotation,
    RandomHorizontalFlip,
    RandomResizedCrop,
    RandomAdjustSharpness,
    Resize,
    ToTensor
)

# PIL configuration to handle corrupted images gracefully
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True  # Allows PIL to load partially corrupted images

## Dataset Preparation & Preprocessing

This section outlines the complete dataset preprocessing pipeline. It involves loading image paths and corresponding labels from the directory structure, applying `RandomOverSampler` to balance the dataset, and creating a Hugging Face `Dataset` object. The pipeline also includes the mapping of string labels to integer class IDs, followed by splitting the dataset into training and testing sets with a 60:40 ratio, ensuring stratification for balanced class distribution.



In [None]:
# Reference: https://huggingface.co/docs/datasets/en/image_load
from pathlib import Path
from tqdm import tqdm

# Define the path to your dataset
data_path = Path('xyz')

# Collect every entry exactly three levels below data_path. The pattern
# '*/*/*.*' matches <data_path>/<dir>/<class_name>/<file.ext>, so the class
# label is the name of each file's immediate parent directory.
# NOTE(review): the first directory level is presumably a split folder
# (e.g. train/test) -- confirm against the actual dataset layout.
# Sorting keeps the row order deterministic across runs.
image_paths = sorted(data_path.glob('*/*/*.*'))

# Parallel lists: file_names[i] is the path string, labels[i] its class name
file_names = [str(path) for path in tqdm(image_paths)]
labels = [path.parts[-2] for path in image_paths]

# Check length consistency with total number of files & labels
print(len(file_names), len(labels))

# Create a pandas DataFrame with one row per image
df = pd.DataFrame.from_dict({"image": file_names, "label": labels})
print(df.shape)

In [None]:
# Preview the first rows of the DataFrame (columns: "image" path string, "label" class name)
df.head()

In [None]:
# View the distinct class names found on disk; the hard-coded labels_list
# defined further down is looked up against these exact strings.
df['label'].unique()

In [None]:
# Balance the classes by randomly duplicating minority-class rows.
# NOTE(review): this runs before the train/test split performed later, so
# duplicated rows can presumably land in both splits -- confirm that this is
# acceptable for the reported test metrics.
ros = RandomOverSampler(random_state=83)
features = df.drop(columns=['label'])
targets = df[['label']]
df, resampled_targets = ros.fit_resample(features, targets)

# Re-attach the resampled labels so df again carries "image" and "label"
df['label'] = resampled_targets

# Drop the temporaries and reclaim their memory
del features, targets, resampled_targets
gc.collect()

print(df.shape)

In [None]:
# Convert the pandas DataFrame into a Hugging Face Dataset so the rest of the
# pipeline (mapping, splitting, Trainer) can consume it.
dataset = Dataset.from_pandas(df)

# Reinterpret the "image" column (file path strings) as an Image feature
dataset = dataset.cast_column("image", Image())

In [None]:
# Display the first image in the dataset (as a PIL.Image)
dataset[0]["image"]

In [None]:
# Slice the first 5 rows to inspect the dataset's structure.
# Despite its name, labels_subset holds every column of those rows
# ("image" and "label"), not only the labels.
labels_subset = dataset[:5]
print(labels_subset)

In [None]:
# Class names in ID order: index 0 = 'Real', index 1 = 'Fake'.
# NOTE(review): these strings presumably must match the class directory names
# stored in df['label'] exactly (case-sensitive) -- confirm against the data.
labels_list = ['Real', 'Fake']

# Build the label <-> ID lookup tables from the list order
label2id = {label: i for i, label in enumerate(labels_list)}
id2label = {i: label for i, label in enumerate(labels_list)}

# Captions now describe the dictionary actually printed
# (they were swapped in the earlier version).
print(f"Mapping of labels to IDs:{label2id}\n")
print(f"Mapping of IDs to labels:{id2label}")

In [None]:
# ClassLabel feature describing the label column: names[i] is the class with ID i
class_labels = ClassLabel(num_classes=len(labels_list), names=labels_list)


def map_label2id(example):
    """
    Replace string labels with their integer class IDs.

    Used with ``Dataset.map(..., batched=True)``, so ``example["label"]`` is a
    batch (list) of label strings that ``class_labels.str2int`` converts in
    one call.

    Args:
        example (dict): A batch of rows containing a "label" entry.

    Returns:
        dict: The same batch with "label" converted to integer IDs.
    """
    example["label"] = class_labels.str2int(example["label"])
    return example


# Convert the string labels to integer IDs
dataset = dataset.map(map_label2id, batched=True)

# Cast the integer column to the ClassLabel feature. cast_column needs the
# configured feature *instance* (class_labels); passing the bare ClassLabel
# class carries no class names and cannot be used as a column type.
dataset = dataset.cast_column('label', class_labels)

# Split the dataset into training and testing sets (60:40 split), stratified by label
dataset = dataset.train_test_split(test_size=0.4, shuffle=True, stratify_by_column="label")

train_data = dataset['train']
test_data = dataset['test']

#--Dataset processing is done--#

## Model Loading and Training Pipeline

This section covers the process of loading the pre-trained Vision Transformer (ViT) model, configuring the necessary training parameters, and fine-tuning the model on the prepared dataset.




In [None]:
model_str = "xyz"
processor = ViTImageProcessor.from_pretrained(model_str)

# Per-channel normalization statistics stored on the processor
# (presumably the values used when the checkpoint was pretrained)
image_mean = processor.image_mean
image_std = processor.image_std

# Square input resolution, read from the processor's configured height
size = processor.size["height"]
print("Resize target size:", size)

# Standardizes each channel with the processor's mean/std so inputs match
# the value range the checkpoint's processor is configured for
normalize = Normalize(mean=image_mean, std=image_std)

# Training pipeline: resize, augment, convert to tensor, normalize
_train_transforms = Compose([
    Resize((size, size)),
    RandomRotation(90),          # rotation-based augmentation (degrees=90)
    RandomAdjustSharpness(2),    # random sharpness adjustment (factor 2)
    ToTensor(),
    normalize,
])

# Validation pipeline: same resize/normalize steps, no augmentation
_val_transforms = Compose([
    Resize((size, size)),
    ToTensor(),
    normalize,
])

def train_transforms(examples):
    """
    Add augmented model inputs to a batch of training examples.

    Every image in ``examples['image']`` is converted to RGB and passed
    through the training pipeline; the results are stored under
    ``examples['pixel_values']`` and the (mutated) batch is returned.
    """
    augmented = []
    for img in examples['image']:
        rgb_img = img.convert("RGB")
        augmented.append(_train_transforms(rgb_img))
    examples['pixel_values'] = augmented
    return examples

def val_transforms(examples):
    """
    Add non-augmented model inputs to a batch of validation examples.

    Every image in ``examples['image']`` is converted to RGB and passed
    through the validation pipeline (resize, tensor conversion,
    normalization); the results are stored under
    ``examples['pixel_values']`` and the (mutated) batch is returned.
    """
    processed = []
    for img in examples['image']:
        rgb_img = img.convert("RGB")
        processed.append(_val_transforms(rgb_img))
    examples['pixel_values'] = processed
    return examples

In [None]:
# Attach the transform functions to the two splits: the training split gets
# the augmenting pipeline, the test split the plain resize/normalize one.
train_data.set_transform(train_transforms)
test_data.set_transform(val_transforms)

In [None]:
def collate_fn(examples):
    """
    Assemble a list of per-sample dicts into one batch for the Trainer.

    Args:
        examples: Sequence of dicts, each with a "pixel_values" tensor and
            a "label" value.

    Returns:
        dict: A dictionary with keys:
            - 'pixel_values': the image tensors stacked along a new first axis
            - 'labels': a tensor holding each sample's "label"
    """
    images = []
    targets = []
    for item in examples:
        images.append(item["pixel_values"])
        targets.append(item["label"])

    batch = {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
    }
    return batch

In [None]:
# Load the pre-trained ViT checkpoint with a classification head sized to
# the number of classes in labels_list
num_classes = len(labels_list)
model = ViTForImageClassification.from_pretrained(model_str, num_labels=num_classes)

# Store the ID <-> name mappings on the model config
model.config.id2label = id2label
model.config.label2id = label2id

# Report the number of trainable parameters, in millions
trainable_millions = model.num_parameters(only_trainable=True) / 1e6
print(f"Trainable Parameters: {trainable_millions:.2f}M")

In [None]:
# Accuracy metric object from the `evaluate` library; compute_metrics below
# looks it up by this global name on every call
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """
    Turn raw model outputs into the metrics reported during evaluation.

    Args:
        eval_pred (EvalPrediction): Object exposing 'predictions' (per-class
            scores) and 'label_ids' (reference class IDs).

    Returns:
        dict: ``{"accuracy": <score>}``, where the predicted class of each
        sample is the argmax of its scores along axis 1.
    """
    predicted_labels = eval_pred.predictions.argmax(axis=1)
    result = accuracy.compute(
        predictions=predicted_labels,
        references=eval_pred.label_ids,
    )
    return {"accuracy": result['accuracy']}

In [None]:
# Define training hyperparameters and configuration
metric_name = "accuracy"   # Key returned by compute_metrics; used to pick the best checkpoint
model_name = "virtus"
num_train_epochs = 2

# NOTE(review): recent transformers releases rename `evaluation_strategy` to
# `eval_strategy` -- confirm against the installed version.
args = TrainingArguments(
    output_dir=model_name,               # Directory to save model checkpoints
    logging_dir='./logs',                # Directory to save logs
    evaluation_strategy="epoch",         # Evaluate model at the end of each epoch
    learning_rate=1e-6,                  # Low learning rate for stable fine-tuning
    per_device_train_batch_size=32,      # Batch size for training
    per_device_eval_batch_size=8,        # Batch size for evaluation
    num_train_epochs=num_train_epochs,   # Number of training epochs
    weight_decay=0.02,                   # Helps prevent overfitting
    warmup_steps=50,                     # Warm-up steps for learning rate scheduler
    remove_unused_columns=False,         # Keep the raw 'image' column; the on-the-fly transforms read it
    save_strategy="epoch",               # Save checkpoint every epoch (matches the evaluation schedule)
    load_best_model_at_end=True,         # Restore the best checkpoint when training finishes
    metric_for_best_model=metric_name,   # Rank checkpoints by accuracy; without this the eval loss is used
    save_total_limit=1,                  # Limit retained checkpoints to save disk space
    report_to="none"                     # Disable reporting (e.g., to WandB)
)

In [None]:
# Create a Trainer instance for fine-tuning the model
trainer = Trainer(
    model=model,                      # ViT classifier loaded above
    args=args,                        # TrainingArguments defined above
    train_dataset=train_data,         # 60% split, augmenting transform attached
    eval_dataset=test_data,           # 40% split, non-augmenting transform attached
    data_collator=collate_fn,         # Stacks pixel_values / labels into batch tensors
    compute_metrics=compute_metrics,  # Reports accuracy at each evaluation
    # The image processor is handed over through the `tokenizer` argument.
    # NOTE(review): newer transformers releases presumably expect
    # `processing_class` instead -- confirm against the installed version.
    tokenizer=processor,
)

In [None]:
# Run fine-tuning with the Trainer configured above
trainer.train()

In [None]:
# Evaluate the post-training model on test_data (the eval_dataset given to the Trainer)
# Returns the metrics produced by compute_metrics (accuracy) along with the Trainer's own evaluation output
trainer.evaluate()

##  Classification Report & Confusion Matrix

Visualize and print detailed metrics including accuracy, F1 score, and a confusion matrix.


In [None]:
# Perform predictions on the test dataset.
# `outputs` exposes .predictions, .label_ids and .metrics, which the
# following cells read.
outputs = trainer.predict(test_data)

print("Test Metrics:")
print(outputs.metrics)

In [None]:
# Preview predicted vs actual labels for the first few samples
preds = outputs.predictions.argmax(axis=1)   # Predicted class ID per sample
labels = outputs.label_ids                   # Reference class ID per sample (rebinds the earlier `labels` list)

# Iterate over slices rather than indexing range(5), so a test set with
# fewer than 5 samples prints what is available instead of raising IndexError
for pred_id, true_id in zip(preds[:5], labels[:5]):
    print(f" Predicted: {id2label[pred_id]} | Actual: {id2label[true_id]}")

In [None]:
# Reference class IDs and predicted class IDs (argmax over the per-class scores)
y_true = outputs.label_ids
y_pred = outputs.predictions.argmax(axis=1)

def plot_confusion_matrix(cm, classes, title='Confusion Matrix', cmap=plt.cm.Blues, figsize=(10, 8)):
    """
    Draw a confusion matrix as an annotated heatmap and show it.

    Args:
        cm: 2-D array of counts, indexed as cm[true, predicted].
        classes: Class names used as tick labels on both axes.
        title: Figure title.
        cmap: Matplotlib colormap for the heatmap.
        figsize: Figure size passed to ``plt.figure``.
    """
    plt.figure(figsize=figsize)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    # One tick per class on both axes
    positions = np.arange(len(classes))
    plt.xticks(positions, classes, rotation=90)
    plt.yticks(positions, classes)

    # Cells above half of the maximum count get white text for contrast
    midpoint = cm.max() / 2.0

    # Write each count into its cell (x = column/predicted, y = row/true)
    for row in range(cm.shape[0]):
        for col in range(cm.shape[1]):
            count = cm[row, col]
            text_color = "white" if count > midpoint else "black"
            plt.text(col, row, format(count, '.0f'), ha="center",
                     color=text_color)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.show()

# Calculate and print classification metrics.
# The score is bound to `acc`, not `accuracy`: `accuracy` is the evaluate
# metric object that compute_metrics() calls, and overwriting it with a float
# would make any later trainer.evaluate() / trainer.predict() call fail.
acc = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average="macro")
print(f"Accuracy: {acc:.4f}")
print(f" F1 Score (macro): {f1:.4f}")

# Generate confusion matrix and classification report
if len(labels_list) <= 150:  # Skip the plot/report for very large class sets
    cm = confusion_matrix(y_true, y_pred)
    plot_confusion_matrix(cm, classes=labels_list, figsize=(8, 6))

    print("\n🧾 Classification Report:\n")
    print(classification_report(y_true, y_pred, target_names=labels_list, digits=4))


In [None]:
# Persist the fine-tuned model. No path is given, so the Trainer presumably
# writes to the configured output_dir ("virtus") -- confirm if a different
# location is needed.
trainer.save_model()