In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Define the ViT model with dropout
class ViTWithDropout(nn.Module):
    """ViT-L/16 classifier whose final head is ``Dropout -> Linear``.

    The backbone is built by ``models.vit_l_16``; only its ``heads.head``
    layer is replaced. The replacement is always a two-element
    ``nn.Sequential``, so the state-dict key names do not depend on the
    constructor arguments.

    Args:
        num_classes: Number of output logits produced by the new head.
        dropout: Drop probability applied right before the final linear layer.
        weights: Forwarded unchanged to ``vit_l_16(weights=...)``. The default
            keeps the original behaviour (ImageNet-pretrained backbone); pass
            ``None`` to build the architecture without requesting pretrained
            weights.
    """

    def __init__(self, num_classes: int = 2, dropout: float = 0.5, weights="IMAGENET1K_V1"):
        super().__init__()
        self.model = models.vit_l_16(weights=weights)
        # Read the input width of the stock head before it is overwritten.
        in_features = self.model.heads.head.in_features
        self.model.heads.head = nn.Sequential(
            nn.Dropout(dropout),  # Dropout before the final layer
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output of the wrapped ViT for ``x``."""
        return self.model(x)

# Load the trained model
# load_state_dict() is strict by default, so the checkpoint must supply every
# parameter of the network. Requesting the ImageNet weights first would only
# fetch values that are overwritten immediately, hence weights=None.
model = ViTWithDropout(weights=None).to(device)
checkpoint_path = "/input/deepfake_model_session5.pth"  # Path to your best model checkpoint
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # Set model to evaluation mode

# Test-time preprocessing only: resize, tensor conversion, per-channel
# normalization (no augmentation)
test_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Build the test dataset and an in-order (unshuffled) loader over it
from torchvision.datasets import ImageFolder  # Update based on your dataset loading method
test_dataset = ImageFolder("/input/test", transform=test_transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

# Evaluate the model on the test set, collecting true and predicted labels
y_true = []
y_pred = []

with torch.no_grad():
    for images, labels in test_loader:
        # Only the images take part in the forward pass. The labels are merely
        # recorded, so they are no longer copied to the device and back.
        images = images.to(device)
        outputs = model(images)
        # Index of the largest output along dim 1, i.e. the predicted class.
        preds = outputs.argmax(dim=1)
        # .cpu() is a no-op for tensors that already live on the host.
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

# Class names, confusion matrix, and the printed classification report
class_names = test_dataset.classes  # Assuming class names are available in dataset
conf_matrix = confusion_matrix(y_true=y_true, y_pred=y_pred)

print(
    "Classification Report:",
    classification_report(y_true=y_true, y_pred=y_pred, target_names=class_names),
    sep="\n",
)

# Draw the confusion matrix as an annotated heatmap (rows: true, columns: predicted)
plt.figure(figsize=(8, 6))
_cm_axes = sns.heatmap(
    conf_matrix,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
)
# Label the Axes returned by seaborn directly instead of going through pyplot state.
_cm_axes.set_xlabel("Predicted Label")
_cm_axes.set_ylabel("True Label")
_cm_axes.set_title("Confusion Matrix")
plt.show()