In [1]:
import os
import random
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTForImageClassification
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from pathlib import Path
from collections import Counter

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# ---------------------------------------------------------------------------
# Run configuration: every value a reader might want to tune lives here.
# ---------------------------------------------------------------------------
SEED = 42          # seed passed to the `random`, NumPy and PyTorch RNGs
BATCH_SIZE = 16    # batch size used by the train/validation/test loaders
EPOCHS = 20        # upper bound on epochs; early stopping may end training sooner
LR = 5e-6          # learning rate handed to AdamW
PATIENCE = 3       # epochs without validation improvement before stopping
NUM_CLASSES = 2    # size of the classification head

MODEL_NAME = "google/vit-base-patch16-224"

# The dataset location can be overridden through the DATASET_ROOT environment
# variable so the notebook also runs on machines other than the author's.
# When the variable is unset (or empty) the original path is used unchanged.
DATASET_ROOT = os.environ.get("DATASET_ROOT") or (
    r"D:\ZT\Thuliyam AI\thuliyam_AI\model_training\final_dataset"
)

# One sub-directory per split underneath the dataset root.
TRAIN_DIR = os.path.join(DATASET_ROOT, "train")
VAL_DIR   = os.path.join(DATASET_ROOT, "validation")
TEST_DIR  = os.path.join(DATASET_ROOT, "test")

# Checkpoint destination (relative to the working directory); the parent
# folder is created up front so saving cannot fail on a missing directory.
MODEL_OUT = Path("model/best_vit_real_vs_fake.pt")
MODEL_OUT.parent.mkdir(parents=True, exist_ok=True)

# Use the GPU when one is available. AMP doubles as the "running on CUDA"
# flag: it enables autocast/GradScaler and pinned-memory data loading.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
AMP = DEVICE == "cuda"

In [None]:
# Seed each random number generator the notebook relies on (Python's
# `random`, NumPy and PyTorch) with the same value.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

# On GPU, also seed every CUDA device and configure cuDNN: deterministic
# mode on, benchmark auto-tuning off.
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
# ImageNet channel statistics. Kept so the names stay defined, but they are
# no longer used for normalisation below (see VIT_MEAN / VIT_STD).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# google/vit-base-patch16-224 was pretrained on images normalised with
# mean 0.5 / std 0.5 per channel (pixel values mapped to [-1, 1]). Most of
# the backbone stays frozen during fine-tuning, so the inputs must follow the
# distribution those frozen layers were trained on.
# NOTE: any inference code consuming the saved checkpoint must apply the same
# statistics.
VIT_MEAN = [0.5, 0.5, 0.5]
VIT_STD  = [0.5, 0.5, 0.5]

# Training pipeline: fixed-size resize plus light augmentation
# (random horizontal flip, brightness/contrast/saturation jitter).
train_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2),
    transforms.ToTensor(),
    transforms.Normalize(VIT_MEAN, VIT_STD),
])

# Evaluation pipeline (validation and test): no augmentation, same resize
# and normalisation as training.
val_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(VIT_MEAN, VIT_STD),
])

In [None]:
# One ImageFolder per split; only the training split gets augmentation.
train_ds = datasets.ImageFolder(TRAIN_DIR, train_tfms)
val_ds   = datasets.ImageFolder(VAL_DIR, val_tfms)
test_ds  = datasets.ImageFolder(TEST_DIR, val_tfms)

print("Class mapping:", train_ds.class_to_idx)

# Each ImageFolder builds its label indices from its own sub-folder names.
# If a split has differently named (or extra/missing) class folders, its
# labels would silently disagree with the training labels, so fail loudly.
for split_name, split_ds in (("validation", val_ds), ("test", test_ds)):
    if split_ds.class_to_idx != train_ds.class_to_idx:
        raise ValueError(
            f"Class mapping of the {split_name} split {split_ds.class_to_idx} "
            f"does not match the training split {train_ds.class_to_idx}"
        )

# Settings shared by all three loaders; memory is pinned only when running
# on CUDA (AMP is True exactly in that case).
loader_kwargs = dict(num_workers=2, pin_memory=AMP)

# Only the training data is shuffled; evaluation order is kept fixed.
train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, BATCH_SIZE, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, BATCH_SIZE, shuffle=False, **loader_kwargs)

In [None]:
# Count how many training samples carry each label index.
labels = [lbl for _, lbl in train_ds.samples]
counts = Counter(labels)
total = sum(counts.values())

# The weights below are indexed 0..NUM_CLASSES-1. A class without samples
# would divide by zero and an unexpected extra class would leave the weight
# vector too short for the loss, so verify the label set up front.
if set(counts) != set(range(NUM_CLASSES)):
    raise ValueError(
        f"Expected training labels {sorted(range(NUM_CLASSES))} "
        f"(NUM_CLASSES={NUM_CLASSES}) but found {sorted(counts)}"
    )

# Inverse-frequency weights (total / per-class count): rarer classes get a
# larger weight in the loss.
class_weights = torch.tensor(
    [total / counts[i] for i in range(NUM_CLASSES)],
    dtype=torch.float32
).to(DEVICE)

In [None]:
# Load the pretrained ViT with a NUM_CLASSES-way head. Size mismatches
# between the checkpoint and the requested head are explicitly tolerated.
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)
model = model.to(DEVICE)

# Start from a fully frozen network ...
model.requires_grad_(False)

# ... then re-enable gradients for the last 3 transformer blocks and the
# classifier head only.
for trainable_module in (model.vit.encoder.layer[-3:], model.classifier):
    trainable_module.requires_grad_(True)

# The optimizer only receives the parameters that are actually trainable.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=LR)

# Class-weighted cross entropy with label smoothing.
criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

# Gradient scaler for mixed precision; disabled when AMP is False.
scaler = torch.cuda.amp.GradScaler(enabled=AMP)

In [None]:
# Train for up to EPOCHS epochs. After each epoch the model is scored on the
# validation set; the best-scoring weights are checkpointed to MODEL_OUT and
# training stops early after PATIENCE epochs without improvement.
best_val_acc = 0.0
patience_counter = 0

for epoch in range(EPOCHS):
    # =============================
    # TRAINING
    # =============================
    model.train()
    running_loss = 0.0

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    # `targets` (not `labels`) so the label list built for the class weights
    # is not overwritten by a batch tensor.
    for imgs, targets in loop:
        imgs = imgs.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        # Forward pass under autocast (a no-op when AMP is False).
        with torch.cuda.amp.autocast(enabled=AMP):
            logits = model(imgs).logits
            loss = criterion(logits, targets)

        # Scaled backward pass + optimizer step via the GradScaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the loss value once; each .item() call forces a device sync.
        batch_loss = loss.item()
        running_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    # Mean per-batch training loss (guarded against an empty loader).
    avg_train_loss = running_loss / max(len(train_loader), 1)

    # =============================
    # VALIDATION
    # =============================
    model.eval()
    preds, trues = [], []

    with torch.no_grad():
        for imgs, targets in val_loader:
            imgs = imgs.to(DEVICE, non_blocking=True)
            logits = model(imgs).logits
            preds.extend(logits.argmax(1).cpu().numpy())
            trues.extend(targets.numpy())

    # Depending on the scikit-learn version the score may be a NumPy scalar;
    # convert to a plain float so the checkpoint holds only basic Python
    # types and stays loadable with torch.load(weights_only=True).
    val_acc = float(accuracy_score(trues, preds))
    print(f"\nEpoch {epoch+1} | Val Accuracy: {val_acc:.4f}")
    print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f}")

    # =============================
    # EARLY STOPPING + CHECKPOINT
    # =============================
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0

        torch.save({
            "model_state_dict": model.state_dict(),
            "architecture": MODEL_NAME,
            "num_classes": NUM_CLASSES,
            "class_mapping": train_ds.class_to_idx,
            "val_accuracy": best_val_acc,
            "epoch": epoch + 1
        }, MODEL_OUT)

        print("Best model saved")
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print("Early stopping triggered")
            break

In [None]:
# Final evaluation: reload the checkpoint written to MODEL_OUT and score it
# on the held-out test split.
print("\nTesting BEST model...")

# NOTE(review): torch.load unpickles the file - only load checkpoints that
# were produced by a trusted run of this notebook.
checkpoint = torch.load(MODEL_OUT, map_location=DEVICE)
best_state = checkpoint["model_state_dict"]
model.load_state_dict(best_state)
model.eval()

preds, trues = [], []

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    for imgs, labels in test_loader:
        logits = model(imgs.to(DEVICE, non_blocking=True)).logits
        batch_preds = logits.argmax(dim=1).cpu().numpy()
        preds.extend(batch_preds)
        trues.extend(labels.numpy())

test_accuracy = accuracy_score(trues, preds)
print("FINAL Test Accuracy:", test_accuracy)