# In [None]:
import os
import torch
import torch.nn as nn
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2
import random

# --- Define the Model ---
class ResNetBinaryClassifier(nn.Module):
    """Binary image classifier: randomly initialised ResNet-18 + small MLP head.

    The head ends in a single-unit Linear followed by Sigmoid, so each image
    yields one probability. The attribute name ``base_model`` and the layer
    order inside ``base_model.fc`` define the state_dict keys, so they must
    stay in sync with any saved checkpoint.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=False)
        feature_dim = backbone.fc.in_features
        head_layers = [
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        ]
        # Swap the ImageNet classification layer for the binary head.
        backbone.fc = nn.Sequential(*head_layers)
        self.base_model = backbone

    def forward(self, x):
        probs = self.base_model(x)
        # squeeze() drops every size-1 dim: (B, 1) -> (B,), and a batch of
        # one, (1, 1), becomes a 0-d scalar tensor.
        return probs.squeeze()

# Build the architecture and restore the trained weights onto the CPU.
# NOTE(review): torch.load unpickles the checkpoint — only load trusted files.
CHECKPOINT_PATH = "resnet_tampered_patch_classifier.pth"
model = ResNetBinaryClassifier()
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
model.eval()  # inference mode: Dropout off, BatchNorm uses running statistics

# --- Register Grad-CAM Hooks ---
# Storage appended to by save_grad_hook / save_act_hook.
gradients, activations = [], []

def save_grad_hook(module, grad_input, grad_output):
    """Backward hook: record the first entry of ``grad_output``.

    Per the PyTorch hook signature, ``grad_output`` holds gradients with
    respect to the module's outputs. ``module`` and ``grad_input`` are unused.
    The tensor is appended to the module-level ``gradients`` list.
    """
    output_grad = grad_output[0]
    gradients.append(output_grad)

def save_act_hook(module, input, output):
    """Forward hook: append the hooked module's ``output`` to ``activations``.

    ``module`` and ``input`` are part of the hook signature but are unused.
    """
    activations.extend([output])

# Capture layer4's output on the forward pass and the gradient w.r.t. that
# output on the backward pass. register_full_backward_hook replaces the
# deprecated register_backward_hook, whose grad_input/grad_output are only
# guaranteed for single-operation modules; layer4 is a stack of residual blocks.
model.base_model.layer4.register_forward_hook(save_act_hook)
model.base_model.layer4.register_full_backward_hook(save_grad_hook)

# --- Image Transforms ---
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# --- Load Validation Dataset ---
val_dir = r"C:\Users\imran\OneDrive\Robotics Projects\Image Forgery\Image Data\VALIDATION_CG-1050\VALIDATION"
dataset = datasets.ImageFolder(root=val_dir, transform=transform)
# batch_size=1 and shuffle=False keep loader index == dataset.samples index,
# which the scan below relies on to recover each image's file path.
loader = DataLoader(dataset, batch_size=1, shuffle=False)

# --- Identify Misclassified Images ---
# Each entry is (file path, true label, predicted label), threshold 0.5.
misclassified = []

with torch.no_grad():
    for idx, (img, label) in enumerate(loader):
        output = model(img)
        pred = int(output.item() > 0.5)
        true = int(label.item())
        if pred != true:
            path = dataset.samples[idx][0]
            misclassified.append((path, true, pred))

# The forward hook fired once per validation image above. Drop those
# activations so they are not kept alive until the first Grad-CAM call.
activations.clear()

# --- Grad-CAM Function ---
def gradcam(image_path):
    """Compute a Grad-CAM map for one image and blend it over the image.

    Uses the module-level ``model`` and ``transform``, plus the layer4 hooks
    that fill ``activations`` (forward) and ``gradients`` (backward).

    Args:
        image_path: path to an image file readable by PIL.

    Returns:
        ``(img_np, overlay)``: two uint8 RGB arrays at the model-input
        resolution. The first is the resized image. The second is that image
        blended 60/40 with a JET-coloured Grad-CAM heat-map.

    Raises:
        RuntimeError: if the hooks did not record both an activation and a
            gradient during the forward/backward pass.
    """
    # Open in a context manager so the file handle is released; convert()
    # returns an independent in-memory image.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    input_tensor = transform(img).unsqueeze(0)
    # Display at the same resolution the transform feeds the model.
    height, width = input_tensor.shape[-2:]

    gradients.clear()
    activations.clear()

    output = model(input_tensor)
    model.zero_grad()
    # For a single image the model squeezes (1, 1) to a 0-d tensor, so
    # backward() needs no explicit gradient argument.
    output.backward()

    if not gradients or not activations:
        raise RuntimeError("Grad-CAM hooks did not record activations/gradients")

    grad = gradients[0][0].detach().numpy()   # (C, h, w) for the single image
    act = activations[0][0].detach().numpy()  # (C, h, w)

    # Channel weights are the spatially averaged gradients; the CAM is the
    # weighted sum of the activation maps over channels.
    weights = np.mean(grad, axis=(1, 2))
    cam = (weights[:, None, None] * act).sum(axis=0).astype(np.float32)

    cam = np.maximum(cam, 0)
    # Min-max normalise to [0, 1]: the denominator is the range, not the max.
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    cam = cv2.resize(cam, (width, height))

    img_np = np.array(img.resize((width, height)))
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    # applyColorMap produces BGR; the PIL image and matplotlib are RGB.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(img_np, 0.6, heatmap, 0.4, 0)

    return img_np, overlay

# --- Visualize up to 10 Misclassified Images ---
# random.sample raises ValueError when asked for more items than exist,
# so cap the count at the number of misclassified images.
num_to_show = min(10, len(misclassified))
if num_to_show == 0:
    print("No misclassified images found.")

# NOTE(review): the titles assume label 1 is the tampered class. ImageFolder
# assigns indices by sorted folder name — confirm with dataset.class_to_idx.
for path, actual, pred in random.sample(misclassified, num_to_show):
    original, cam_overlay = gradcam(path)
    fname = os.path.basename(path)

    fig = plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(original)
    plt.axis("off")
    plt.title(f"{fname}\nActual: {'Tampered' if actual else 'Original'}")

    plt.subplot(1, 2, 2)
    plt.imshow(cam_overlay)
    plt.axis("off")
    plt.title(f"Predicted: {'Tampered' if pred else 'Original'} ❌")

    plt.tight_layout()
    plt.show()
    # Release the figure so non-interactive backends do not accumulate them.
    plt.close(fig)
