In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision import models
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import os
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

# ==============================
# PATHS
# ==============================
BASE_PATH = "/content/authentica-ai"                # results/ sub-folders are written under here
MODEL_PATH = "/content/cifake_resnet18_latest.pth"   # state_dict of the trained ResNet-18
TEST_DIR = "/content/cifake_data/test"              # ImageFolder root: one sub-folder per class

# ==============================
# DEVICE
# ==============================
# Use the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ==============================
# LOAD DATASET
# ==============================
# `stats` is unpacked into transforms.Normalize(mean, std[, inplace]) but is
# NOT defined in this cell. NOTE(review): presumably it comes from the training
# notebook; it must match the normalization used during training, otherwise
# every metric below is computed on mis-scaled inputs. Fail with an actionable
# message instead of a bare NameError.
if "stats" not in globals():
    raise NameError(
        "`stats` is not defined: set it to the (mean, std) normalization "
        "used during training before running this evaluation cell."
    )

transform = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor(),
    transforms.Normalize(*stats)
])

# ImageFolder expects one sub-folder per class under TEST_DIR and assigns class
# indices in sorted folder-name order; later plots/metrics rely on that mapping.
test_dataset = datasets.ImageFolder(TEST_DIR, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print("Classes:", test_dataset.classes)
print("Class to index:", test_dataset.class_to_idx)

# ==============================
# LOAD TRAINED MODEL
# ==============================
# Rebuild the trained architecture: a torchvision ResNet-18 whose final fully
# connected layer is swapped for a 2-output linear head, so that the saved
# state_dict's keys and tensor shapes match.
model = models.resnet18()
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=2)

model.load_state_dict(
    torch.load(MODEL_PATH, map_location=device)
)
model = model.to(device)  # nn.Module.to moves parameters in place and returns the module
model.eval()  # BatchNorm layers use their running statistics during inference

print("Model loaded successfully")

# ==============================
# EVALUATION METRICS
# ==============================
# One pass over the test split, collecting predicted and true class indices
# (indices follow test_dataset.class_to_idx).
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images)
        preds = torch.argmax(outputs, 1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.numpy())

# precision/recall/F1 are binary scores for a single "positive" class index.
# ImageFolder assigns indices by sorted folder name, so which class is index 1
# depends on the folder names -- make the choice explicit and report it so the
# numbers can be interpreted.
POS_LABEL = 1
acc = accuracy_score(all_labels, all_preds)
prec = precision_score(all_labels, all_preds, pos_label=POS_LABEL)
rec = recall_score(all_labels, all_preds, pos_label=POS_LABEL)
f1 = f1_score(all_labels, all_preds, pos_label=POS_LABEL)

print(f"Positive class (index {POS_LABEL}):", test_dataset.classes[POS_LABEL])
print("Accuracy:",acc)
print("Precision:",prec)
print("Recall:",rec)
print("F1:",f1)

# ==============================
# CONFUSION MATRIX
# ==============================
# ImageFolder assigns class indices in sorted folder-name order, so the tick
# labels must come from test_dataset.classes rather than being hard-coded.
# For example, folders named "FAKE"/"REAL" map to 0/1, which is the reverse of
# a hard-coded ["Real", "Fake"].
class_names = test_dataset.classes
# Pass the full label set so the matrix stays square (and aligned with
# class_names) even if a class never appears in labels or predictions.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))

metrics_dir = f"{BASE_PATH}/results/metrics"
os.makedirs(metrics_dir, exist_ok=True)
plt.figure(figsize=(5,4))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names,
            yticklabels=class_names)
# sklearn convention: rows are true labels, columns are predicted labels.
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.savefig(f"{metrics_dir}/confusion_matrix.png")
plt.show()

# ==============================
# GRADCAM
# ==============================
# Grad-CAM works in three steps:
#   1. weight each channel of a convolutional feature map by the spatially
#      averaged gradient of the target class score;
#   2. sum the weighted channels;
#   3. apply ReLU.
# The previous version had two problems. It averaged raw activations without
# using any gradients. It also hooked layer4, whose output is 1x1 for the 32x32
# inputs produced by `transform` (torchvision's ResNet-18 downsamples by 32),
# so its "heatmap" was a constant. layer3 (2x2 at this resolution) is the
# deepest stage that still has spatial extent, so expect a very coarse map.
GRADCAM_LAYER_NAME = "layer3"
GRADCAM_LAYER = getattr(model, GRADCAM_LAYER_NAME)

features = None


def hook_fn(module, input, output):
    """Forward hook: store the hooked module's output in the global `features`."""
    global features
    features = output


# Keep the handle so the hook can be removed. Otherwise every re-run of this
# cell stacks another hook on the model.
hook_handle = GRADCAM_LAYER.register_forward_hook(hook_fn)
try:
    images, labels = next(iter(test_loader))
    images = images.to(device)
    # Gradients are needed for Grad-CAM, so this pass is NOT under no_grad().
    output = model(images)
finally:
    hook_handle.remove()

# Explain the model's predicted class for the first image of the batch.
class_idx = int(output[0].argmax())
grads = torch.autograd.grad(output[0, class_idx], features)[0][0]  # (C, H, W)
acts = features[0].detach()                                          # (C, H, W)
channel_weights = grads.mean(dim=(1, 2), keepdim=True)               # (C, 1, 1)
heatmap = torch.relu((channel_weights * acts).sum(dim=0)).cpu().numpy()
heatmap /= np.max(heatmap) + 1e-8
heatmap = cv2.resize(heatmap, (32, 32))

# Undo Normalize so imshow receives RGB values in [0, 1]. Showing normalized
# tensors directly gets clipped and distorts the colours.
normalize = next(t for t in transform.transforms if isinstance(t, transforms.Normalize))
mean = torch.as_tensor(normalize.mean, dtype=torch.float32).view(-1, 1, 1)
std = torch.as_tensor(normalize.std, dtype=torch.float32).view(-1, 1, 1)
img = (images[0].detach().cpu() * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()

os.makedirs(f"{BASE_PATH}/results/gradcam", exist_ok=True)
plt.figure()
plt.imshow(img)
plt.imshow(heatmap, cmap="jet", alpha=0.4)
plt.title(f"Grad-CAM ({GRADCAM_LAYER_NAME}) - predicted: {test_dataset.classes[class_idx]}")
plt.axis("off")
plt.savefig(f"{BASE_PATH}/results/gradcam/heatmap.png")
plt.show()

print("GradCAM saved!")