In [1]:
# Import libraries
import warnings
warnings.filterwarnings("ignore")
import gc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, f1_score
from datasets import Dataset, Features
from datasets.features import Image as DatasetImage
from datasets.features import ClassLabel
from transformers import ViTImageProcessor, ViTForImageClassification, Trainer, TrainingArguments
import torch
from torchvision.transforms import Compose, Normalize, ToTensor
from PIL import Image as PILImage
from PIL import ImageFile
from pathlib import Path
from tqdm import tqdm

In [2]:
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [3]:
# Check MPS availability for M3 chip
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


In [4]:
# Load preprocessed data
dataset_dir = Path('/Users/diksha/Desktop/PROJECT/ProcessedDataset')
file_names, labels = [], []

for file in tqdm(sorted(dataset_dir.glob('*/*.*')), desc="Loading files"):
    label = str(file).split('/')[-2]  # Extracts 'Fake' or 'Real'
    labels.append(label)
    file_names.append(str(file))

# Debug: Check total files
print(f"Total files found: {len(file_names)}")

Loading files: 100%|██████████| 200/200 [00:00<00:00, 696728.24it/s]

Total files found: 200





In [5]:
# Create DataFrame (no sampling since only 200 images exist)
df = pd.DataFrame({"image": file_names, "label": labels})
df['image'] = df['image'].astype(str)  # Ensure paths are strings
print(f"Dataset shape: {df.shape}")
print("Label distribution:\n", df['label'].value_counts())

Dataset shape: (200, 2)
Label distribution:
 label
Fake    100
Real    100
Name: count, dtype: int64


In [6]:
# Create Dataset with explicit features
dataset = Dataset.from_pandas(df, features=Features({
    'image': DatasetImage(decode=False),
    'label': ClassLabel(names=['Real', 'Fake'])  # Automatically maps 'Real' -> 0, 'Fake' -> 1
}))

In [7]:
# Label mapping (for model compatibility)
labels_list = ['Real', 'Fake']
label2id = {label: i for i, label in enumerate(labels_list)}
id2label = {i: label for i, label in enumerate(labels_list)}

In [8]:
# No need for map_label2id since ClassLabel already handles it
dataset = dataset.train_test_split(test_size=0.4, shuffle=True)
train_data = dataset['train']
test_data = dataset['test']

In [9]:
# Transformations
model_str = "dima806/deepfake_vs_real_image_detection"
processor = ViTImageProcessor.from_pretrained(model_str)
image_mean, image_std = processor.image_mean, processor.image_std
normalize = Normalize(mean=image_mean, std=image_std)
_train_transforms = Compose([ToTensor(), normalize])
_val_transforms = Compose([ToTensor(), normalize])

In [18]:
def train_transforms(examples):
    examples['pixel_values'] = [_train_transforms(PILImage.open(image if isinstance(image, str) else image['path']).convert("RGB")) for image in examples['image']]
    return examples
    

def val_transforms(examples):
    examples['pixel_values'] = [_val_transforms(PILImage.open(image if isinstance(image, str) else image['path']).convert("RGB")) for image in examples['image']]
    return examples

train_data.set_transform(train_transforms)
test_data.set_transform(val_transforms)

In [19]:
print(train_data[0])  # Should include 'pixel_values' as a tensor

{'image': {'bytes': None, 'path': '/Users/diksha/Desktop/PROJECT/ProcessedDataset/Real/real_50314.jpg'}, 'label': 0, 'pixel_values': tensor([[[ 0.7569,  0.7412,  0.7255,  ...,  0.9451,  0.9451,  0.9451],
         [ 0.7569,  0.7412,  0.7176,  ...,  0.9451,  0.9451,  0.9451],
         [ 0.7490,  0.7412,  0.7098,  ...,  0.9451,  0.9451,  0.9451],
         ...,
         [-0.9608, -0.9529, -0.9529,  ...,  0.5137,  0.6157,  0.6706],
         [-0.9294, -0.9451, -0.9529,  ...,  0.6314,  0.7255,  0.7569],
         [-0.8745, -0.8980, -0.9294,  ...,  0.6784,  0.7490,  0.7569]],

        [[ 0.7882,  0.7725,  0.7569,  ...,  0.9922,  0.9922,  0.9922],
         [ 0.7882,  0.7725,  0.7490,  ...,  0.9922,  0.9922,  0.9922],
         [ 0.7804,  0.7725,  0.7412,  ...,  0.9922,  0.9922,  0.9922],
         ...,
         [-0.9216, -0.9137, -0.9137,  ...,  0.4980,  0.6000,  0.6549],
         [-0.8902, -0.9059, -0.9137,  ...,  0.6000,  0.6941,  0.7255],
         [-0.8353, -0.8588, -0.8902,  ...,  0.6471,  0.7

In [11]:
# Collate function
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example['label'] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

In [12]:
# Model setup
model = ViTForImageClassification.from_pretrained(model_str, num_labels=len(labels_list)).to(device)
model.config.id2label = id2label
model.config.label2id = label2id

In [13]:
# Freeze backbone for efficiency on small dataset
for param in model.vit.parameters():
    param.requires_grad = False
for param in model.classifier.parameters():
    param.requires_grad = True

print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M")

Trainable parameters: 0.00M


In [14]:
# Metrics
def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    label_ids = eval_pred.label_ids
    accuracy = accuracy_score(label_ids, predictions)
    f1 = f1_score(label_ids, predictions, average='weighted')
    cm = confusion_matrix(label_ids, predictions)
    report = classification_report(label_ids, predictions, target_names=labels_list)
    print("Confusion Matrix:\n", cm)
    print("Classification Report:\n", report)
    return {"accuracy": accuracy, "f1": f1}

In [15]:
# Training arguments
args = TrainingArguments(
    output_dir="./deepfake_vs_real_image_detection",
    logging_dir='./logs',
    evaluation_strategy="epoch",
    logging_strategy="steps",
    logging_steps=10,  # Log more frequently for small dataset
    learning_rate=1e-5,  # Slightly higher than 1e-6 for faster convergence
    per_device_train_batch_size=4,  # Reduced from 16 for M3 memory
    per_device_eval_batch_size=4,
    num_train_epochs=1,
    weight_decay=0.02,
    warmup_steps=5,  # Reduced for small dataset
    remove_unused_columns=False,
    save_strategy='epoch',
    load_best_model_at_end=True,
    save_total_limit=1,
    report_to="none",
    fp16=False,  # MPS doesn't support fp16 well
)    

In [16]:
# Trainer
trainer = Trainer(
    model,
    args,
    train_dataset=train_data,
    eval_dataset=test_data,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

In [21]:
trainer.save_model("./model/deepfake_vs_real_image_detection")
processor.save_pretrained("./model/deepfake_vs_real_image_detection")
print("✅ Model and Processor saved successfully!")

✅ Model and Processor saved successfully!
