In [4]:
import torch
import os
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTForImageClassification
from sklearn.metrics import classification_report, balanced_accuracy_score, confusion_matrix

In [5]:
# =============================
# CONFIG (Must match Training)
# =============================
BATCH_SIZE = 16
NUM_CLASSES = 2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "google/vit-base-patch16-224"

# PATHS
# The defaults below are the original machine-specific locations. Set the
# THULIYAM_DATASET_ROOT / THULIYAM_MODEL_PATH environment variables to point
# the notebook at another checkout without editing this cell.
DATASET_ROOT = os.environ.get(
    "THULIYAM_DATASET_ROOT",
    r"D:\ZT\Thuliyam AI\thuliyam_AI\model_training\final_dataset",
)
TEST_DIR = os.path.join(DATASET_ROOT, "test")
MODEL_PATH = os.environ.get(
    "THULIYAM_MODEL_PATH",
    r"D:\ZT\Thuliyam AI\thuliyam_AI\backend\model_weights\best_vit_real_vs_fake.pt",
)


In [6]:
# =============================
# TRANSFORMS
# =============================
# Deterministic evaluation preprocessing: resize to 224x224, convert to a
# tensor, then normalise each channel with the mean/std values below.
# NOTE(review): presumably identical to the training-time validation
# transforms -- confirm against the training notebook.
val_tfms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def run_test():
    """Evaluate the fine-tuned ViT checkpoint on the test split.

    Reads the images under ``TEST_DIR`` through ``val_tfms``, builds a
    ``ViTForImageClassification`` with ``NUM_CLASSES`` outputs, restores the
    weights stored under the ``"model_state_dict"`` key of ``MODEL_PATH``,
    and runs inference on ``DEVICE``. It then prints the balanced accuracy
    and a per-class classification report, and shows a confusion-matrix
    heatmap.

    Returns:
        None. Results are printed and plotted.

    Raises:
        ValueError: If the number of class folders found under ``TEST_DIR``
            differs from ``NUM_CLASSES``; the model's output indices would
            then not line up with the folder-derived labels and every metric
            would be silently wrong.
    """
    # 1. Load Dataset
    test_ds = datasets.ImageFolder(TEST_DIR, val_tfms)
    target_names = test_ds.classes
    if len(target_names) != NUM_CLASSES:
        raise ValueError(
            f"Expected {NUM_CLASSES} class folders in {TEST_DIR}, "
            f"found {len(target_names)}: {target_names}"
        )
    # ImageFolder labels are the indices into `classes`; pinning this list
    # keeps the report rows and heatmap ticks aligned with `target_names`
    # even when a class never appears among the truths or predictions.
    label_ids = list(range(len(target_names)))
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    print(f"Testing on {len(test_ds)} images from: {TEST_DIR}")

    # 2. Initialize and Load Model
    # ignore_mismatched_sizes lets the pretrained classification head be
    # replaced by one with NUM_CLASSES outputs before the fine-tuned weights
    # are loaded on top.
    model = ViTForImageClassification.from_pretrained(
        MODEL_NAME, num_labels=NUM_CLASSES, ignore_mismatched_sizes=True
    ).to(DEVICE)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. If the checkpoint holds nothing but tensors and
    # plain containers, consider passing weights_only=True.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    # 3. Inference Loop
    preds, trues = [], []
    with torch.no_grad():
        for imgs, labels in test_loader:
            logits = model(imgs.to(DEVICE)).logits
            # .tolist() moves the values to the host as plain Python ints.
            preds.extend(logits.argmax(dim=1).tolist())
            trues.extend(labels.tolist())

    # 4. Final Metrics Report
    print("\n" + "="*30)
    print("FINAL TEST RESULTS")
    print("="*30)
    print(f"Balanced Accuracy: {balanced_accuracy_score(trues, preds):.4f}")
    print("\nDetailed Classification Report:")
    print(classification_report(trues, preds, labels=label_ids, target_names=target_names))

    # 5. Visual Confusion Matrix
    cm = confusion_matrix(trues, preds, labels=label_ids)
    _, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=target_names, yticklabels=target_names, ax=ax)
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    ax.set_title('Deepfake Detection: Confusion Matrix')
    plt.show()