In [None]:
# Install kagglehub.
# NOTE(review): kagglehub is not imported anywhere in the visible notebook (the dataset is
# fetched with the `kaggle` CLI below), so this install looks unnecessary. If it is kept,
# prefer `%pip install -q kagglehub==<version>` so it targets the kernel and is reproducible.
!pip install kagglehub


In [None]:
from google.colab import files

# Upload kaggle.json (your Kaggle API token).
# Bind the result instead of leaving `files.upload()` as the cell's last expression:
# Colab would otherwise display the returned {filename: bytes} dict, which includes
# the secret API key, in the saved notebook output.
uploaded_token = files.upload()
print("Uploaded:", ", ".join(uploaded_token))


In [None]:
# Install the uploaded Kaggle API token where the `kaggle` CLI looks for it.
# Colab names a repeated upload "kaggle (1).json", "kaggle (2).json", ..., so the file name
# depends on upload history. Take the most recently modified kaggle*.json instead of
# hard-coding one name, and skip the move if the token is already installed (re-runs).
import glob
import os
import shutil

kaggle_dir = os.path.expanduser("~/.kaggle")
os.makedirs(kaggle_dir, exist_ok=True)
kaggle_token_path = os.path.join(kaggle_dir, "kaggle.json")

token_candidates = sorted(glob.glob("kaggle*.json"), key=os.path.getmtime)
if token_candidates:
    shutil.move(token_candidates[-1], kaggle_token_path)
elif not os.path.exists(kaggle_token_path):
    raise FileNotFoundError(
        "No kaggle*.json found in the working directory - upload your Kaggle API token first."
    )

# Restrict the token to the current user (equivalent of `chmod 600`).
os.chmod(kaggle_token_path, 0o600)


In [None]:
# Download the CASIA v2.0 tampering-detection dataset archive from Kaggle into the
# working directory (uses the API token installed in ~/.kaggle by the previous cell).
!kaggle datasets download -d divg07/casia-20-image-tampering-detection-dataset


In [None]:
!unzip casia-20-image-tampering-detection-dataset.zip -d casia_dataset


In [None]:
import os

# Summarize the extracted tree: number of files directly inside each directory.
# The loop variable must not be named `files`: that would rebind the
# `google.colab.files` module imported earlier in the notebook to a plain list.
for root, _subdirs, filenames in os.walk("casia_dataset"):
    print(root, len(filenames), "files")


In [None]:
# Libraries for the ViT backbone (timm), plotting (matplotlib) and heatmap rendering (OpenCV).
# NOTE(review): versions are unpinned; pin the ones used for training (a later commented-out
# cell mentions timm 1.0.21) so results and saved checkpoints stay reproducible.
!pip install timm matplotlib opencv-python


In [None]:
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

data_dir = "/content/casia_dataset/CASIA2"

# CASIA2 also contains the ground-truth mask folder ("CASIA 2 Groundtruth"; renamed to
# "_masks" later in this notebook). Plain ImageFolder would treat it as a third class
# and, on a fresh run, shift Tp to label 2, so it is excluded from the class list.
MASK_DIR_NAMES = {"CASIA 2 Groundtruth", "_masks"}


class ImageFolderWithoutMasks(datasets.ImageFolder):
    """ImageFolder whose classes are the sub-directories of `root` minus the mask folders."""

    def find_classes(self, directory):
        classes = sorted(
            entry.name
            for entry in os.scandir(directory)
            if entry.is_dir() and entry.name not in MASK_DIR_NAMES
        )
        if not classes:
            raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
        return classes, {name: idx for idx, name in enumerate(classes)}


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

train_dataset = ImageFolderWithoutMasks(
    root=data_dir,
    transform=transform
)

# Quick sanity-check loader only: `train_loader` is rebuilt from a stratified split
# in the training cell further down.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

print("Classes:", train_dataset.classes)


In [None]:
# Enumerate all images, then drop the ground-truth mask folder.
# On a fresh run the folder is still named "CASIA 2 Groundtruth" at this point (the rename
# to "_masks" only happens in the training cell further down), so exclude both names.
EXCLUDED_DIRS = {"_masks", "CASIA 2 Groundtruth"}
full_dataset = datasets.ImageFolder(root=data_dir, transform=None)
kept_classes = [c for c in full_dataset.classes if c not in EXCLUDED_DIRS]
# ImageFolder numbers *all* folders alphabetically; re-index the kept ones to 0..k-1.
label_remap = {full_dataset.class_to_idx[c]: new_idx for new_idx, c in enumerate(kept_classes)}
filtered_samples = [(p, label_remap[l]) for p, l in full_dataset.samples if l in label_remap]
paths = [p for p, _ in filtered_samples]
labels = [l for _, l in filtered_samples]

print("Filtered dataset samples:", len(paths))


In [None]:
import timm
import torch.nn as nn

# ImageNet-pretrained ViT-B/16 whose classification head is swapped for a 2-way linear layer.
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.head = nn.Linear(in_features=model.head.in_features, out_features=2)


In [None]:
import torch

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device=device)


In [None]:
# -------------------------
# Imports & Config
# -------------------------
import os, random, shutil
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import timm
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

# -------------------------
# Basic settings
# -------------------------
seed = 42
# Seed every RNG used below: Python (split shuffling), NumPy, and PyTorch on CPU and all GPUs.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# Dataset root. After the rename below it holds Au/, Tp/ and the renamed _masks/ folder.
data_dir = "/content/casia_dataset/CASIA2"

# -------------------------
# Rename the ground-truth mask folder to "_masks".
# NOTE: the masks stay inside data_dir; they are only excluded from training by the
# "_masks" filter applied to the ImageFolder samples further down in this cell.
# -------------------------
mask_dir = os.path.join(data_dir, "CASIA 2 Groundtruth")
if os.path.exists(mask_dir):
    shutil.move(mask_dir, os.path.join(data_dir, "_masks"))
    print("Moved Groundtruth masks out of training path")

# -------------------------
# Hyperparameters
# -------------------------
num_epochs, batch_size = 12, 24
img_size, num_classes = 224, 2
best_val_f1 = 0.0  # best validation F1 seen so far; raised by the training loop

# -------------------------
# Transforms
# -------------------------
# ImageNet channel statistics expected by the pretrained timm ViT (shared by both pipelines).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0), ratio=(0.9, 1.1)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0.05),
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    # Erase AFTER normalizing (as torchvision's reference training recipe does): the default
    # fill value 0 then equals the mean colour, whereas erasing before Normalize produced
    # pure-black patches far outside the normalized input range.
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3), inplace=False),
])

# Deterministic evaluation pipeline: resize the shorter side to ~1.12x the model input,
# then take the central img_size x img_size crop.
val_transform = transforms.Compose([
    transforms.Resize(int(img_size * 1.12)),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

 # -------------------------
# Dataset & Dataloaders (filtered)
# -------------------------
full_dataset = datasets.ImageFolder(root=data_dir, transform=None)

# Drop the ground-truth mask folder. It is normally renamed to "_masks" above, but the
# original name is accepted too in case the rename was skipped.
EXCLUDED_DIRS = {"_masks", "CASIA 2 Groundtruth"}
kept_classes = [c for c in full_dataset.classes if c not in EXCLUDED_DIRS]
assert len(kept_classes) == num_classes, f"Expected {num_classes} classes, found {kept_classes}"
# Re-index the kept classes to contiguous labels 0..k-1 so they are valid CrossEntropyLoss
# targets regardless of where the excluded folder sorts alphabetically.
label_remap = {full_dataset.class_to_idx[c]: new_idx for new_idx, c in enumerate(kept_classes)}

filtered_samples = [(p, label_remap[l]) for p, l in full_dataset.samples if l in label_remap]
paths = [p for p, _ in filtered_samples]
labels = [l for _, l in filtered_samples]

# Stratified train/val split: shuffle each class independently and hold out val_ratio of it.
by_label = defaultdict(list)
for p, l in zip(paths, labels):
    by_label[l].append(p)

train_paths, train_labels, val_paths, val_labels = [], [], [], []
val_ratio = 0.15

for l, items in by_label.items():
    random.shuffle(items)
    n_val = max(1, int(len(items) * val_ratio))  # at least one validation image per class
    val_items = items[:n_val]
    train_items = items[n_val:]
    train_paths += train_items
    train_labels += [l] * len(train_items)
    val_paths += val_items
    val_labels += [l] * len(val_items)

print("Train samples:", len(train_paths))
print("Val samples:", len(val_paths))

# Custom Dataset
from torch.utils.data import Dataset
from PIL import Image

class CASIADataset(Dataset):
    """Image dataset over explicit parallel lists of file paths and integer labels.

    Each item is ``(image, label, path)``. The image is loaded with PIL, converted to RGB
    and passed through ``transform`` when one is given. The path is returned as well so
    callers can trace predictions back to files.
    """

    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        # Open in a context manager so the file handle is released deterministically;
        # .convert() returns a new in-memory image that outlives the `with` block.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[idx], path

train_ds = CASIADataset(train_paths, train_labels, transform=train_transform)
val_ds = CASIADataset(val_paths, val_labels, transform=val_transform)

# Two worker processes per loader, kept small for Colab.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,   # new training order every epoch
    num_workers=2,
    pin_memory=True,
)
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,  # fixed order for evaluation
    num_workers=2,
    pin_memory=True,
)


# Model: ViT-B/16
# -------------------------
# Let timm build the classifier head at the right size. This is timm's documented
# fine-tuning API: passing num_classes replaces the ImageNet head with a fresh
# `num_classes`-way head while loading the pretrained backbone. Unlike the manual
# `model.head.in_features` swap, it works for any timm architecture.
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=num_classes)
model = model.to(device)

# -------------------------
# Optimizer, Scheduler, Loss
# -------------------------
# AdamW with a small fine-tuning learning rate, cosine-annealed over `num_epochs` scheduler
# steps (the training loop calls scheduler.step() once per epoch).
optimizer = torch.optim.AdamW(params=model.parameters(), lr=3e-5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=num_epochs)
criterion = nn.CrossEntropyLoss()
# Loss scaler for mixed-precision training (paired with autocast in train_one_epoch).
scaler = torch.amp.GradScaler()

# -------------------------
# Metrics
# -------------------------
def epoch_metrics_all(preds, labels):
    """Accuracy plus binary precision / recall / F1 for one epoch of predictions.

    Args:
        preds: predicted class indices (any array-like).
        labels: ground-truth class indices, same length as `preds`.

    Returns:
        (accuracy, precision, recall, f1). The binary scores use sklearn's
        average='binary' and report 0 instead of warning when a score is undefined.
    """
    y_pred = np.asarray(preds)
    y_true = np.asarray(labels)
    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, _support = precision_recall_fscore_support(
        y_true, y_pred, average='binary', zero_division=0
    )
    return accuracy, precision, recall, f1

# -------------------------
# Training & Validation
# -------------------------
def train_one_epoch(model, loader, optimizer, criterion, device, scaler):
    """Run one optimization pass over `loader`, with mixed precision on CUDA devices.

    Args:
        model: network returning class logits.
        loader: yields ``(images, labels, paths)`` batches; paths are ignored.
        optimizer: optimizer stepping `model`'s parameters.
        criterion: loss function ``criterion(logits, labels)``.
        device: torch.device (or device string) to run on.
        scaler: torch.amp.GradScaler used to scale the loss before backward.

    Returns:
        (mean_loss, accuracy, precision, recall, f1) over the epoch. The metrics come
        from the training-time (augmented) predictions.
    """
    model.train()
    # Autocast was hard-coded to device_type='cuda'. On a CPU runtime that only warns on
    # every batch and disables itself, so derive the type from `device` and enable
    # mixed precision only on CUDA.
    dev_type = torch.device(device).type
    use_amp = dev_type == "cuda"
    losses, all_preds, all_labels = [], [], []
    for imgs, labels, _ in tqdm(loader, desc="Train", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.amp.autocast(device_type=dev_type, enabled=use_amp):
            logits = model(imgs)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        losses.append(loss.item())
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    mean_loss = float(np.mean(losses))
    return mean_loss, *epoch_metrics_all(all_preds, all_labels)

@torch.no_grad()
def validate(model, loader, criterion, device):
    """Evaluate `model` on `loader` without gradient tracking.

    Mixed-precision autocast is enabled only when `device` is a CUDA device.

    Args:
        model: network returning class logits.
        loader: yields ``(images, labels, paths)`` batches; paths are ignored.
        criterion: loss function ``criterion(logits, labels)``.
        device: torch.device (or device string) to run on.

    Returns:
        (mean_loss, accuracy, precision, recall, f1) over all batches.
    """
    model.eval()
    # Same device-aware autocast as train_one_epoch: no per-batch warning on CPU runtimes.
    dev_type = torch.device(device).type
    use_amp = dev_type == "cuda"
    losses, all_preds, all_labels = [], [], []
    for imgs, labels, _ in tqdm(loader, desc="Val", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        with torch.amp.autocast(device_type=dev_type, enabled=use_amp):
            logits = model(imgs)
            loss = criterion(logits, labels)
        losses.append(loss.item())
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    mean_loss = float(np.mean(losses))
    return mean_loss, *epoch_metrics_all(all_preds, all_labels)


def show_img_and_heatmap(img_path, cam_map, alpha=0.5):
    """Plot the model's view of an image, the CAM heatmap, and their overlay side by side.

    The image goes through the same PIL-stage geometry as `val_transform` (every step
    before ToTensor, i.e. Resize then CenterCrop). The CAM is computed on that crop, so
    this keeps the heatmap aligned with the pixels shown. Previously the whole image was
    squashed to img_size x img_size, which misaligned the overlay.

    Args:
        img_path: path to the image file.
        cam_map: 2-D CAM array, expected in [0, 1] (values outside are clipped).
        alpha: heatmap weight in the overlay; the image gets weight 1 - alpha.
    """
    pil_img = Image.open(img_path).convert("RGB")
    for step in val_transform.transforms:
        if isinstance(step, transforms.ToTensor):
            break
        pil_img = step(pil_img)
    img = np.asarray(pil_img)
    h, w = img.shape[:2]

    cam = cv2.resize(np.asarray(cam_map, dtype=np.float32), (w, h))
    heatmap = cv2.applyColorMap(np.uint8(255 * np.clip(cam, 0, 1)), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV colormaps are BGR
    overlay = np.uint8(heatmap * alpha + img * (1 - alpha))

    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    ax[0].imshow(img); ax[0].set_title("Image"); ax[0].axis('off')
    ax[1].imshow(heatmap); ax[1].set_title("Heatmap"); ax[1].axis('off')
    ax[2].imshow(overlay); ax[2].set_title("Overlay"); ax[2].axis('off')
    plt.show()

# -------------------------
# Training loop with checkpoint
# -------------------------
save_dir = Path("/content/checkpoints_casia_vit")
save_dir.mkdir(parents=True, exist_ok=True)

for epoch in range(1, num_epochs + 1):
    print(f"\n=== Epoch {epoch}/{num_epochs} ===")
    train_loss, train_acc, train_p, train_r, train_f1 = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
    val_loss, val_acc, val_p, val_r, val_f1 = validate(model, val_loader, criterion, device)
    scheduler.step()
    print(f"Train: loss {train_loss:.4f} acc {train_acc:.4f} p {train_p:.4f} r {train_r:.4f} f1 {train_f1:.4f}")
    print(f" Val : loss {val_loss:.4f} acc {val_acc:.4f} p {val_p:.4f} r {val_r:.4f} f1 {val_f1:.4f}")
    if val_f1 > best_val_f1:
        best_val_f1 = float(val_f1)
        # Zero-pad the epoch so file names sort chronologically ("epoch09" < "epoch10").
        ckpt_path = save_dir / f"best_vit_epoch{epoch:02d}_f1{val_f1:.4f}.pt"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Store a plain Python float. sklearn may return a NumPy scalar, which
            # torch.load's weights_only mode (the default in recent PyTorch) refuses to unpickle.
            'val_f1': float(val_f1),
        }, ckpt_path)
        print("Saved best checkpoint:", ckpt_path)



In [None]:
# ===========================================
# 🔧 Stable Grad-CAM Hooks for ViT-B/16 (timm)
# ===========================================
# Grad-CAM state, overwritten by the hooks below on every forward/backward pass of target_layer.
activations, gradients = None, None

# Re-running this cell used to stack another pair of hooks on the same module, and the old
# handles were overwritten so they could never be removed. Drop any previous pair first.
for _old_handle in (globals().get("hook_fwd"), globals().get("hook_bwd")):
    if _old_handle is not None:
        _old_handle.remove()

# Hook the attention output of the last transformer block.
# NOTE(review): assumes a timm VisionTransformer exposing `blocks[i].attn`; confirm for other
# timm versions/architectures.
target_layer = model.blocks[-1].attn

def forward_hook(module, input, output):
    """Keep a detached copy of the attention output (values only; no graph needed)."""
    global activations
    activations = output.detach()

def backward_hook(module, grad_in, grad_out):
    """Keep the gradient flowing back into the attention output."""
    global gradients
    gradients = grad_out[0].detach()

hook_fwd = target_layer.register_forward_hook(forward_hook)
hook_bwd = target_layer.register_full_backward_hook(backward_hook)


In [None]:
def compute_vit_gradcam(model, input_tensor, target_class=None):
    """Grad-CAM heatmap for a timm ViT, using the hooks registered on `target_layer`.

    The module-level `forward_hook` / `backward_hook` fill the globals `activations` and
    `gradients` during the forward/backward pass run here, so `model` must be the model
    those hooks were attached to.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        input_tensor: preprocessed image batch; only element 0 is explained.
        target_class: class index to explain; defaults to the predicted class.

    Returns:
        (cam, class_idx): cam is a float32 (H, W) array in [0, 1] matching the spatial
        size of `input_tensor`; class_idx is the explained class.

    Raises:
        RuntimeError: if the hooks did not capture activations or gradients.
        ValueError: if the patch tokens cannot be arranged into a square grid.
    """
    global activations, gradients
    model.eval()
    activations, gradients = None, None

    input_tensor = input_tensor.to(device)
    # Grad-CAM needs autograd even when the caller is inside torch.no_grad().
    with torch.enable_grad():
        logits = model(input_tensor)
        pred_class = logits.argmax(dim=1).item() if target_class is None else int(target_class)
        score = logits[0, pred_class]
        model.zero_grad()
        # The hooks keep everything needed, so the graph does not have to be retained.
        score.backward()

    if activations is None or gradients is None:
        raise RuntimeError("Hooks didn't capture activations or gradients!")

    acts = activations[0]       # (tokens, C) for the first image
    grads = gradients[0]

    # Drop the leading CLS token so only patch tokens remain.
    if acts.shape[0] > 1:
        acts = acts[1:]
        grads = grads[1:]

    # Channel weights = gradients averaged over tokens; CAM = weighted sum over channels.
    weights = grads.mean(dim=0)
    # float32 keeps cv2.resize happy even if the hooks saw half-precision tensors.
    cam = (acts * weights.unsqueeze(0)).sum(dim=1).detach().float().cpu().numpy()

    n_tokens = cam.shape[0]
    grid_size = int(round(np.sqrt(n_tokens)))
    if grid_size * grid_size != n_tokens:
        raise ValueError(
            f"Cannot arrange {n_tokens} patch tokens into a square grid "
            "(the model may have extra prefix tokens)."
        )
    cam = cam.reshape(grid_size, grid_size)
    cam = np.maximum(cam, 0)

    # Normalize to [0, 1]
    cam -= cam.min()
    if cam.max() != 0:
        cam /= cam.max()

    # Resize to the input resolution (cv2 takes (width, height)).
    cam = cv2.resize(cam, (input_tensor.shape[-1], input_tensor.shape[-2]))
    return cam, pred_class


In [None]:
class_names = ['Au', 'Tp']

# Six random validation images (order drawn from Python's global RNG).
examples = list(zip(val_paths, val_labels))
random.shuffle(examples)
examples = examples[:6]

for img_path, gt_label in examples:
    batch = val_transform(Image.open(img_path).convert("RGB")).unsqueeze(0)
    cam_map, pred = compute_vit_gradcam(model, batch)
    print(f"GT: {class_names[gt_label]} | Pred: {class_names[pred]}")
    show_img_and_heatmap(img_path, cam_map, alpha=0.45)


In [None]:


import torch
from PIL import Image
import matplotlib.pyplot as plt
import cv2

# Define class names
class_names = ['Au', 'Tp']  # 0 -> Authentic, 1 -> Tampered

# 1. Load the best checkpoint of the most recent training run.
# The old `sorted(...)[-1]` compared names lexicographically, so "epoch9" beat "epoch10".
# Every save within a run is a new best, so the most recently written file is that run's best.
ckpt_candidates = list(Path("/content/checkpoints_casia_vit").glob("best_vit_epoch*_f1*.pt"))
if not ckpt_candidates:
    raise FileNotFoundError(
        "No checkpoints found in /content/checkpoints_casia_vit - run the training cell first."
    )
best_ckpt = max(ckpt_candidates, key=lambda ckpt: ckpt.stat().st_mtime)
# weights_only=False: older checkpoints store 'val_f1' as a NumPy scalar, which the
# weights-only unpickler rejects. Full unpickling can execute code, so only load
# checkpoints this notebook wrote itself.
checkpoint = torch.load(best_ckpt, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print(f"‚úÖ Loaded model from: {best_ckpt}")

# 2. Predict one image and show its Grad-CAM
def predict_image(img_path):
    """Print the predicted class and softmax confidence for one image, then plot its Grad-CAM.

    The image is preprocessed with `val_transform`. The heatmap comes from
    `compute_vit_gradcam` (explaining the predicted class) and is drawn by
    `show_img_and_heatmap`.
    """
    batch = val_transform(Image.open(img_path).convert("RGB")).unsqueeze(0).to(device)

    # Plain inference pass for the label/confidence; no autograd needed here.
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)[0]
    pred_idx = probs.argmax().item()
    confidence = probs[pred_idx].item()
    print(f"Prediction: {class_names[pred_idx]} ({confidence*100:.2f}% confidence)")

    # Second pass with gradients for the heatmap.
    cam_map, _ = compute_vit_gradcam(model, batch)
    show_img_and_heatmap(img_path, cam_map, alpha=0.45)

# 3. Upload one or more images and analyze each of them
from google.colab import files
uploaded = files.upload()

for fn in uploaded:
    print(f"\nüîé Analyzing {fn} ...")
    predict_image(fn)


In [None]:
# # ===============================
# # Grad-CAM debug & visualization cell (timm 1.0.21, ViT-B/16)
# # Run this AFTER training and after you loaded the best checkpoint into `model`.
# # It requires: model, device, val_transform, val_paths, val_labels
# # ===============================
# import torch, cv2, random, numpy as np
# import matplotlib.pyplot as plt
# from PIL import Image
# from IPython.display import display

# plt.rcParams['figure.dpi'] = 100

# # ---------- utility: show image / heatmap ----------
# def show_img_and_heatmap_inline(img_path, cam_map, alpha=0.45, title=None):
#     img = cv2.imread(img_path)
#     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#     img = cv2.resize(img, (img.shape[1], img.shape[0]))  # keep original aspect
#     # resize cam_map to displayed size (use img size)
#     cam_resized = cv2.resize(cam_map, (img.shape[1], img.shape[0]))
#     heatmap = cv2.applyColorMap(np.uint8(255*cam_resized), cv2.COLORMAP_JET)
#     heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
#     overlay = np.uint8(heatmap*alpha + img*(1-alpha))
#     fig, ax = plt.subplots(1,3, figsize=(12,4))
#     ax[0].imshow(img); ax[0].set_title("Image"); ax[0].axis('off')
#     ax[1].imshow(heatmap); ax[1].set_title("Heatmap"); ax[1].axis('off')
#     ax[2].imshow(overlay); ax[2].set_title("Overlay"); ax[2].axis('off')
#     if title:
#         plt.suptitle(title)
#     plt.show()

# # ---------- helper: register hooks on a layer object ----------
# def make_hooks(target_module):
#     activations = {'val': None}
#     gradients = {'val': None}
#     def forward_hook(module, input, output):
#         # output may be tensor or tuple; keep tensor
#         out = output if isinstance(output, torch.Tensor) else output[0]
#         activations['val'] = out
#         # ensure requires_grad if needed for backward
#         try:
#             activations['val'].requires_grad_(True)
#         except Exception:
#             pass
#     def backward_hook(module, grad_in, grad_out):
#         g = grad_out[0] if isinstance(grad_out, tuple) else grad_out
#         gradients['val'] = g
#     fh = target_module.register_forward_hook(forward_hook)
#     bh = target_module.register_full_backward_hook(backward_hook)
#     return activations, gradients, fh, bh

# # ---------- robust grad-cam function that handles qkv/proj/attn outputs ----------
# def compute_gradcam_from_hooks(activations, gradients, input_tensor):
#     """
#     activations['val'], gradients['val'] expected to exist and be Tensors.
#     Handles cases:
#       - activations shape: (tokens, C) where C may be 768 or 2304 (qkv concat).
#       - gradients same shape.
#     Returns cam_map resized to (H,W) in [0,1].
#     """
#     acts = activations['val']   # shape (1, tokens, C) or (tokens, C)
#     grads = gradients['val']
#     if acts is None or grads is None:
#         raise RuntimeError("Activations or gradients are None")
#     # align shapes to (tokens, C)
#     if acts.dim() == 3:
#         acts = acts[0]
#     if grads.dim() == 3:
#         grads = grads[0]
#     # If channels are qkv concat (C == 3*embed_dim), split and use 'value' part if possible.
#     C = acts.shape[1]
#     if C % 3 == 0 and C >= 768:
#         embed = C // 3
#         # Use value component (3rd chunk)
#         v = acts[:, 2*embed:3*embed]
#         gv = grads[:, 2*embed:3*embed]
#         acts_patches = v
#         grads_patches = gv
#     else:
#         # default: use entire channels
#         acts_patches = acts
#         grads_patches = grads
#     # remove CLS token if present
#     if acts_patches.shape[0] > 1:
#         acts_patches = acts_patches[1:]
#         grads_patches = grads_patches[1:]
#     # compute weights and CAM
#     weights = grads_patches.mean(dim=0)            # (C,) importance per channel
#     cam_patches = (acts_patches * weights.unsqueeze(0)).sum(dim=1).detach().cpu().numpy()  # (N_patches,)
#     # reshape to grid
#     n_patches = cam_patches.shape[0]
#     grid_size = int(np.round(np.sqrt(n_patches)))
#     if grid_size*grid_size != n_patches:
#         # fallback: try to crop/pad to nearest square
#         gs = grid_size
#         n_needed = gs*gs
#         if n_needed > n_patches:
#             # pad with zeros
#             cam_patches = np.pad(cam_patches, (0, n_needed - n_patches), 'constant')
#         else:
#             cam_patches = cam_patches[:n_needed]
#         grid_size = gs
#     cam_map = cam_patches.reshape(grid_size, grid_size)
#     cam_map = np.maximum(cam_map, 0)
#     cam_map -= cam_map.min()
#     if cam_map.max() != 0:
#         cam_map /= cam_map.max()
#     # final map will be resized by caller to image size
#     return cam_map

# # ---------- Try multiple candidate layers and visualize results ----------
# def try_layers_and_show(image_path, candidate_layer_names=None, top_k=3, alpha=0.45):
#     """
#     candidate_layer_names: list of strings pointing to elements in model, e.g.
#       - 'blocks[-1].attn.qkv'  (may not exist as tensor output)
#       - 'blocks[-1].attn'      (attention module output)
#       - 'blocks[-1].attn.proj'
#       - 'norm'
#       - 'blocks[-1]'
#     The function locates attributes automatically and tries them.
#     """
#     if candidate_layer_names is None:
#         candidate_layer_names = [
#             'blocks[-1].attn.qkv',
#             'blocks[-1].attn',
#             'blocks[-1].attn.proj',
#             'blocks[-1].norm1',
#             'blocks[-1].norm'
#         ]
#     # prepare input tensor
#     img = Image.open(image_path).convert("RGB")
#     img_t = val_transform(img).unsqueeze(0).to(device)
#     # get model output (for pred)
#     with torch.no_grad():
#         logits = model(img_t)
#         probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
#         pred_idx = int(np.argmax(probs))
#     print(f"Model prediction: {pred_idx} ({probs[pred_idx]:.3f})")
#     results = []
#     handles = []
#     for name in candidate_layer_names:
#         # try to resolve attribute path
#         try:
#             # support negative index like blocks[-1]
#             parts = name.split('.')
#             obj = model
#             for p in parts:
#                 if p.endswith(']'):
#                     # handle index like blocks[-1]
#                     base, idx = p[:-1].split('[')
#                     idx = int(idx)
#                     obj = getattr(obj, base)[idx]
#                 else:
#                     if hasattr(obj, p):
#                         obj = getattr(obj, p)
#                     else:
#                         obj = obj  # leave as is; may be method etc.
#             target = obj
#         except Exception as e:
#             # skip if resolution fails
#             # print(f"Skip layer {name}: {e}")
#             continue
#         # register hooks
#         try:
#             activations, gradients, fh, bh = make_hooks(target)
#         except Exception as e:
#             # print(f"Could not hook {name}: {e}")
#             continue
#         # forward+backward to collect grads
#         model.zero_grad()
#         logits = model(img_t)
#         score = logits[0, pred_idx]
#         score.backward(retain_graph=True)
#         # diagnostics
#         a = activations['val']
#         g = gradients['val']
#         a_mean = float(a.abs().mean().detach().cpu().numpy()) if a is not None else None
#         g_mean = float(g.abs().mean().detach().cpu().numpy()) if g is not None else None
#         print(f"Layer '{name}': act_mean={a_mean:.6f}, grad_mean={g_mean:.6g}")
#         # if grads extremely small, still try but note it
#         try:
#             cam = compute_gradcam_from_hooks(activations, gradients, img_t)
#         except Exception as e:
#             cam = None
#             print(f"  -> compute failed for {name}: {e}")
#         results.append((name, cam, a_mean, g_mean))
#         # remove hooks
#         fh.remove(); bh.remove()
#     # show top_k visually (non-None cams)
#     shown = 0
#     for name, cam, a_mean, g_mean in results:
#         if cam is None:
#             continue
#         print(f"\nShowing layer {name} (act_mean={a_mean:.4f}, grad_mean={g_mean:.6g})")
#         show_img_and_heatmap_inline(image_path, cam, alpha=alpha, title=f"{name} act={a_mean:.3f} grad={g_mean:.6g}")
#         shown += 1
#         if shown >= top_k:
#             break
#     if shown == 0:
#         print("No usable CAMs produced. Try training more / different layers / verify gradients.")

# # ---------- small helper: single image predict + best-layer CAM ----------
# def predict_and_show(image_path, chosen_layer_name='blocks[-1].attn'):
#     # hook chosen layer and show single cam
#     parts = chosen_layer_name.split('.')
#     obj = model
#     for p in parts:
#         if p.endswith(']'):
#             base, idx = p[:-1].split('[')
#             idx = int(idx)
#             obj = getattr(obj, base)[idx]
#         else:
#             obj = getattr(obj, p)
#     activations, gradients, fh, bh = make_hooks(obj)
#     img = Image.open(image_path).convert("RGB")
#     img_t = val_transform(img).unsqueeze(0).to(device)
#     with torch.no_grad():
#         logits = model(img_t)
#         probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
#         pred_idx = int(np.argmax(probs))
#     # compute cam (need backward so remove no_grad)
#     model.zero_grad()
#     logits = model(img_t)
#     score = logits[0, pred_idx]
#     score.backward(retain_graph=True)
#     a_mean = float(activations['val'].abs().mean().detach().cpu().numpy())
#     g_mean = float(gradients['val'].abs().mean().detach().cpu().numpy())
#     cam = compute_gradcam_from_hooks(activations, gradients, img_t)
#     fh.remove(); bh.remove()
#     print(f"Prediction: {pred_idx} ({probs[pred_idx]:.3f}), act_mean={a_mean:.6f}, grad_mean={g_mean:.6g}")
#     show_img_and_heatmap_inline(image_path, cam, alpha=0.45, title=f"{chosen_layer_name}")
#     return pred_idx, probs[pred_idx]

# # ---------- USAGE ----------
# # Example: try automatic candidates on a single image (use any path from your val set)
# example_image = val_paths[0]  # change index if you want another image
# print("Testing image:", example_image)
# try_layers_and_show(example_image, candidate_layer_names=[
#     'blocks[-1].attn.qkv',
#     'blocks[-1].attn',
#     'blocks[-1].attn.proj',
#     'blocks[-1].norm1'
# ], top_k=3, alpha=0.45)

# # If one name shows a good heatmap, use:
# # predict_and_show(example_image, chosen_layer_name='blocks[-1].attn.qkv')


In [None]:
# =====================================================
# 🔍 Final ViT-B/16 Grad-CAM (for Deepfake/Tampered Detection)
# =====================================================
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from google.colab import files

# Class names (0 = Authentic, 1 = Tampered)
# Label names by class index (0 = authentic, 1 = tampered), and the layer to explain:
# the attention module of the third-from-last block, chosen by the author as a mid-level
# block that gives spatially clearer heatmaps.
class_names, chosen_layer_name = ['Au', 'Tp'], 'blocks[-3].attn'

# -----------------------------------------------------
# Hook registration helpers
# -----------------------------------------------------
def make_hooks(target_module):
    """Register hooks that capture `target_module`'s output and the gradient w.r.t. it.

    Args:
        target_module: the torch.nn.Module to observe.

    Returns:
        (activations, gradients, fh, bh): two dicts whose "val" entry is set on the next
        forward / backward pass through the module, plus the two hook handles. Call
        ``fh.remove()`` and ``bh.remove()`` when done.
    """
    activations, gradients = {}, {}

    def forward_hook(module, input, output):
        # Some modules return tuples; keep the first (main) output. Only its values are
        # needed, so detach it instead of keeping the whole autograd graph alive (the old
        # try/requires_grad_/except-pass dance was a no-op at best).
        out = output[0] if isinstance(output, tuple) else output
        activations["val"] = out.detach()

    def backward_hook(module, grad_in, grad_out):
        g = grad_out[0] if isinstance(grad_out, tuple) else grad_out
        gradients["val"] = None if g is None else g.detach()

    fh = target_module.register_forward_hook(forward_hook)
    bh = target_module.register_full_backward_hook(backward_hook)
    return activations, gradients, fh, bh


# -----------------------------------------------------
# Grad-CAM computation
# -----------------------------------------------------
def compute_gradcam_from_hooks(activations, gradients, input_tensor, gamma=0.7, num_prefix_tokens=1):
    """Turn hooked token activations/gradients into a [0, 1] Grad-CAM map.

    Args:
        activations, gradients: dicts filled by `make_hooks`. "val" is indexed with [0]
            to take the first batch element, giving (tokens, channels) for an attention
            output (TODO confirm for other hooked layers).
        input_tensor: the model input; its last two dims give the output map size.
        gamma: exponent applied after normalization (< 1 boosts weak responses).
        num_prefix_tokens: leading non-patch tokens to drop (1 = the CLS token of ViT-B/16).

    Returns:
        float32 array of shape (H, W) with values in [0, 1].

    Raises:
        RuntimeError: if the hooks captured nothing.
        ValueError: if the remaining tokens cannot form a square patch grid.
    """
    acts = activations.get("val")
    grads = gradients.get("val")
    if acts is None or grads is None:
        raise RuntimeError("Hooks did not capture activations/gradients - was backward() called?")
    # float32: cv2.resize cannot handle half-precision arrays.
    acts = acts[0].detach().float()
    grads = grads[0].detach().float()
    if acts.shape[0] > num_prefix_tokens:  # drop CLS / prefix tokens
        acts, grads = acts[num_prefix_tokens:], grads[num_prefix_tokens:]

    weights = grads.mean(dim=0)
    cam = (acts * weights.unsqueeze(0)).sum(dim=1).cpu().numpy()
    cam = np.maximum(cam, 0)

    if cam.max() > 1e-6:
        cam /= cam.max()
    cam = cam ** gamma  # gamma correction for better contrast

    n_tokens = cam.shape[0]
    grid = int(round(np.sqrt(n_tokens)))
    if grid * grid != n_tokens:
        raise ValueError(
            f"Cannot reshape {n_tokens} patch tokens into a square grid; "
            "check num_prefix_tokens for this model."
        )
    cam = cam.reshape(grid, grid).astype(np.float32)
    # cv2.resize takes (width, height).
    return cv2.resize(cam, (input_tensor.shape[-1], input_tensor.shape[-2]))


# -----------------------------------------------------
# Visualization helper
# -----------------------------------------------------
def show_img_and_heatmap(img_path, cam_map, alpha=0.45):
    """Plot the model's view of an image, the Grad-CAM heatmap, and their overlay.

    The CAM is computed on the crop produced by `val_transform` (Resize + CenterCrop).
    Stretching it over the full, uncropped image misaligned it, so the displayed image
    is passed through the same PIL-stage steps (everything before ToTensor) first.

    Args:
        img_path: path to the image file.
        cam_map: 2-D CAM array, expected in [0, 1] (values outside are clipped).
        alpha: heatmap weight in the overlay; the image gets weight 1 - alpha.
    """
    from torchvision import transforms as T  # local import keeps this cell self-contained

    pil_img = Image.open(img_path).convert("RGB")
    for step in val_transform.transforms:
        if isinstance(step, T.ToTensor):
            break
        pil_img = step(pil_img)
    img = np.asarray(pil_img)
    h, w = img.shape[:2]

    cam_resized = cv2.resize(np.asarray(cam_map, dtype=np.float32), (w, h))
    heatmap = cv2.applyColorMap(np.uint8(255 * np.clip(cam_resized, 0, 1)), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV colormaps are BGR
    overlay = np.uint8(heatmap * alpha + img * (1 - alpha))

    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    ax[0].imshow(img); ax[0].set_title("Image"); ax[0].axis("off")
    ax[1].imshow(heatmap); ax[1].set_title("Heatmap"); ax[1].axis("off")
    ax[2].imshow(overlay); ax[2].set_title("Overlay"); ax[2].axis("off")
    plt.show()


# -----------------------------------------------------
# Prediction + Grad-CAM Visualization
# -----------------------------------------------------
def _resolve_layer(root, dotted_path):
    """Resolve a path like 'blocks[-3].attn' to the corresponding submodule of `root`."""
    obj = root
    for part in dotted_path.split('.'):
        if part.endswith(']'):
            attr, idx = part[:-1].split('[')
            obj = getattr(obj, attr)[int(idx)]
        else:
            obj = getattr(obj, part)
    return obj


def predict_and_visualize(image_path, chosen_layer_name='blocks[-3].attn'):
    """Classify one image and show a Grad-CAM heatmap taken from `chosen_layer_name`.

    Args:
        image_path: path of the image to analyze.
        chosen_layer_name: dotted path (with optional [i] indexing) of the module inside
            the global `model` whose output is explained.

    Prints the predicted label and softmax confidence, then plots the heatmap.
    """
    activations, gradients, fh, bh = make_hooks(_resolve_layer(model, chosen_layer_name))
    try:
        img = Image.open(image_path).convert("RGB")
        img_t = val_transform(img).unsqueeze(0).to(device)

        model.eval()
        # Gradients are required for Grad-CAM, even if the caller disabled them.
        with torch.enable_grad():
            outputs = model(img_t)
            probs = torch.softmax(outputs, dim=1)[0]
            pred_idx = int(probs.argmax())
            confidence = probs[pred_idx].item()

            # Backpropagate the predicted-class score to fill the gradient hook.
            model.zero_grad(set_to_none=True)
            outputs[0, pred_idx].backward()

        cam_map = compute_gradcam_from_hooks(activations, gradients, img_t)
    finally:
        # Always detach the hooks. Previously an exception (bad file, OOM, ...) left them
        # registered on the model, where they accumulated across calls.
        fh.remove()
        bh.remove()

    label = class_names[pred_idx].upper()
    print(f"\nPrediction: {label} ({confidence*100:.2f}% confidence)")
    show_img_and_heatmap(image_path, cam_map, alpha=0.45)


# -----------------------------------------------------
# Upload and test new images
# -----------------------------------------------------
uploaded = files.upload()

# Run prediction + Grad-CAM on every uploaded file.
for fn in uploaded:
    print(f"\nüîç Analyzing {fn} ...")
    predict_and_visualize(fn, chosen_layer_name=chosen_layer_name)


In [None]:
# =====================================================
# 💾 Save trained model weights for inference
# =====================================================
save_path = "/content/vit_casia_final.pth"

# Persist only the parameters (state_dict), not the pickled module object.
torch.save(obj=model.state_dict(), f=save_path)
print(f"‚úÖ Model weights saved successfully at: {save_path}")


In [None]:
# Download the exported weights file to the local machine (Colab only).
# NOTE(review): this path duplicates `save_path` from the previous cell; keep the two in sync.
from google.colab import files
files.download("/content/vit_casia_final.pth")