                                                        1️ Imports & Global Config                                                       

In [21]:
# Core
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# HuggingFace ViT
from transformers import ViTForImageClassification, ViTConfig

# Metrics & utils
from sklearn.metrics import accuracy_score, confusion_matrix
from tqdm import tqdm
import numpy as np
import random
from collections import Counter

                                                        2️ Device & Reproducibility                                                        

In [22]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)

Device: cuda


                                                        3️ Training Configuration                                                    

In [None]:
BATCH_SIZE = 16
EPOCHS = 5                    # Increased slightly from 3 to 5
LR = 5e-6                     # Lower LR
NUM_CLASSES = 2

MODEL_NAME = "google/vit-base-patch16-224"

TRAIN_DIR = "dataset/real-vs-fake/train"
VALID_DIR = "dataset/real-vs-fake/valid"
TEST_DIR  = "dataset/real-vs-fake/test"

MODEL_OUTPUT = "model/pretrained_vit_model_finetuned.pt"

                                                    4️ Transforms (MATCH INFERENCE EXACTLY)                                              

In [24]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ColorJitter(0.2, 0.2),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

                                                        5️ Datasets & Loaders                                                        

In [25]:
train_dataset = datasets.ImageFolder(TRAIN_DIR, train_transform)
valid_dataset = datasets.ImageFolder(VALID_DIR, val_transform)
test_dataset  = datasets.ImageFolder(TEST_DIR,  val_transform)

print("Class mapping:", train_dataset.class_to_idx)
# {'fake': 0, 'real': 1}

train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
valid_loader = DataLoader(valid_dataset, BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_dataset,  BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

Class mapping: {'fake': 0, 'real': 1}


                                                6️ Class Weights (IMPORTANT – Even if Balanced)                                          

In [26]:
label_counts = Counter([label for _, label in train_dataset.samples])
print("Train class counts:", label_counts)

total = sum(label_counts.values())
weights = [
    total / label_counts[0],   # fake
    total / label_counts[1],   # real
]

class_weights = torch.tensor(weights).to(DEVICE)
print("Class weights:", class_weights)

Train class counts: Counter({0: 50000, 1: 50000})
Class weights: tensor([2., 2.], device='cuda:0')


                                                    7️ Load ViT & Fine-Tuning Strategy                                                   

In [27]:
config = ViTConfig.from_pretrained(MODEL_NAME)
config.num_labels = NUM_CLASSES

model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    config=config,
    ignore_mismatched_sizes=True
).to(DEVICE)

# Freeze everything
for p in model.parameters():
    p.requires_grad = False

# Unfreeze LAST 3 transformer blocks
for p in model.vit.encoder.layer[-3:].parameters():
    p.requires_grad = True

# Unfreeze classifier
for p in model.classifier.parameters():
    p.requires_grad = True

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


                                                8️ Optimizer, Loss (WITH LABEL SMOOTHING)                                                

In [28]:
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LR
)

criterion = nn.CrossEntropyLoss(
    weight=class_weights,
    label_smoothing=0.1        # CRITICAL for fine-tuning
)

scaler = torch.cuda.amp.GradScaler()


  scaler = torch.cuda.amp.GradScaler()


                                                    9️ Training Loop (Stable & Correct)                                              

In [29]:
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for images, labels in loop:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
            logits = model(images).logits
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    # Validation
    model.eval()
    preds, trues, confs = [], [], []

    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            logits = model(images).logits
            probs = torch.softmax(logits, dim=1)

            conf, pred = torch.max(probs, 1)
            preds.extend(pred.cpu().numpy())
            trues.extend(labels.cpu().numpy())
            confs.extend(conf.cpu().numpy())

    print(
        f"Epoch {epoch+1} | "
        f"Train Loss: {running_loss/len(train_loader):.4f} | "
        f"Val Acc: {accuracy_score(trues, preds):.4f} | "
        f"Avg Conf: {np.mean(confs):.4f}"
    )

  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
Epoch 1/5: 100%|██████████| 6250/6250 [16:48<00:00,  6.20it/s, loss=0.425]


Epoch 1 | Train Loss: 0.4149 | Val Acc: 0.9338 | Avg Conf: 0.8639


  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
Epoch 2/5: 100%|██████████| 6250/6250 [13:17<00:00,  7.84it/s, loss=0.316]


Epoch 2 | Train Loss: 0.3004 | Val Acc: 0.9563 | Avg Conf: 0.8941


  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
Epoch 3/5: 100%|██████████| 6250/6250 [11:21<00:00,  9.17it/s, loss=0.224]


Epoch 3 | Train Loss: 0.2708 | Val Acc: 0.9673 | Avg Conf: 0.9125


  with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
Epoch 4/5:   9%|▉         | 576/6250 [01:10<11:34,  8.17it/s, loss=0.21]  


KeyboardInterrupt: 

                                            10 Test Evaluation + Confusion Matrix (MANDATORY CHECK)                                     

In [30]:
model.eval()
test_preds, test_trues = [], []

with torch.no_grad():
    for images, labels in test_loader:
        logits = model(images.to(DEVICE)).logits
        test_preds.extend(logits.argmax(1).cpu().numpy())
        test_trues.extend(labels.numpy())

print("Test Accuracy:", accuracy_score(test_trues, test_preds))
print("Confusion Matrix:")
print(confusion_matrix(test_trues, test_preds))


Test Accuracy: 0.96995
Confusion Matrix:
[[9714  286]
 [ 315 9685]]


                                                        1️1 Save Model (.pt)                                                     

In [32]:
torch.save(
    {
        "epoch": 3,
        "model_state_dict": model.state_dict(),
        "class_mapping": {"fake": 0, "real": 1},
        "architecture": "vit-base-patch16-224",
        "val_accuracy": 0.9673
    },
    MODEL_OUTPUT
)

print("Epoch 3 model saved at:", MODEL_OUTPUT)


Epoch 3 model saved at: pretrained_vit_model_finetuned.pt


In [None]:
torch.save(model.state_dict(), MODEL_OUTPUT)
print("Saved model to:", MODEL_OUTPUT)

                                                        1️2️ Verify Model Load                                                            

In [None]:
loaded_model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    config=config,
    ignore_mismatched_sizes=True
)

loaded_model.load_state_dict(torch.load(MODEL_OUTPUT, map_location="cpu"))
loaded_model.eval()
print("PT model loaded successfully")