# XAI-ChangeNet Visualization
Interactively inspect change-detection predictions, Grad-CAM heatmaps, and overlays for a pre/post image pair.


In [None]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from models.siamese_unet import SiameseResNet18UNet
from scripts.explain import grad_cam as grad_cam_fn


In [None]:
# Runtime device: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Manifest of pre/post image pairs (update to the actual manifest) and the
# checkpoint whose predictions are visualized below.
pairs_file = Path('data/xbd/pairs_example.json')
ckpt_path = Path('checkpoints/latest.pth')


In [None]:
# Load the first pair listed in the manifest. The manifest is read as a JSON
# list whose entries carry 'pre_image', 'post_image' and 'mask' paths relative
# to the manifest's own directory.
with pairs_file.open('r', encoding='utf-8') as f:
    records = json.load(f)
if not records:
    raise ValueError(f'No pair records found in manifest {pairs_file}')
record = records[0]

root = pairs_file.parent
pre_path = root / record['pre_image']
post_path = root / record['post_image']
mask_path = root / record['mask']


def _load_array(path, mode):
    """Open the image at *path*, convert it to PIL *mode*, and return it as a NumPy array.

    The context manager closes the underlying file handle deterministically
    instead of relying on Pillow's lazy-loading cleanup.
    """
    with Image.open(path) as img:
        return np.array(img.convert(mode))


pre_img = _load_array(pre_path, 'RGB')
post_img = _load_array(post_path, 'RGB')
# Scale the 8-bit grayscale mask to [0, 1].
# NOTE(review): assumes the mask is stored with 0/255 values; if it stores
# class indices (0/1, ...) this scaling makes it nearly all zeros — confirm.
mask_img = _load_array(mask_path, 'L') / 255.0


In [None]:
from albumentations import Compose, Resize, Normalize

# Resize inputs to 512x512 and normalize them with albumentations' default
# Normalize() parameters.
transform = Compose([Resize(512, 512), Normalize()])


def _to_model_input(image):
    """Apply ``transform`` to *image* and return a 1xCxHxW float tensor on ``device``."""
    # assumes image is an HxWxC array (RGB from the loading cell); torch wants CHW
    chw = transform(image=image)['image'].transpose(2, 0, 1)
    return torch.from_numpy(chw).unsqueeze(0).float().to(device)


pre_tensor = _to_model_input(pre_img)
post_tensor = _to_model_input(post_img)

# Put the ground-truth mask on the same 512x512 grid as the resized inputs so
# it is spatially comparable with them; nearest-neighbour interpolation keeps
# the original mask values instead of blending them.
mask_tensor = torch.nn.functional.interpolate(
    torch.from_numpy(mask_img).unsqueeze(0).unsqueeze(0).float(),
    size=pre_tensor.shape[-2:],
    mode='nearest',
).to(device)

model = SiameseResNet18UNet().to(device)

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
ckpt = torch.load(ckpt_path, map_location=device)
# Accept both training checkpoints of the form {'model': state_dict, ...}
# and checkpoints saved as a bare state_dict.
state_dict = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
model.load_state_dict(state_dict)
model.eval()

# Forward pass without autograd bookkeeping; sigmoid maps logits to [0, 1].
with torch.no_grad():
    logits = model(pre_tensor, post_tensor)
    probs = torch.sigmoid(logits)
# Drop the singleton dimensions for plotting.
pred_mask = probs.squeeze().cpu().numpy()


In [None]:
# (image, title, extra imshow kwargs) for each panel, left to right.
panels = [
    (pre_img, 'Pre', {}),
    (post_img, 'Post', {}),
    (mask_img, 'Ground Truth', {'cmap': 'gray'}),
    # pred_mask holds sigmoid probabilities, so pin the colour scale to [0, 1];
    # imshow's default min/max autoscaling would make a weak prediction look
    # as saturated as a confident one.
    (pred_mask, 'Prediction', {'cmap': 'magma', 'vmin': 0.0, 'vmax': 1.0}),
]

fig, axes = plt.subplots(1, len(panels), figsize=(18, 6))
for ax, (panel_img, title, imshow_kwargs) in zip(axes, panels):
    ax.imshow(panel_img, **imshow_kwargs)
    ax.set_title(title)
    ax.axis('off')
plt.show()


In [None]:
# Where the Grad-CAM figure is written and then read back for display.
grad_cam_path = Path('outputs') / 'viz_grad_cam.png'
# Make sure the output directory exists on a fresh checkout.
# NOTE(review): assumes grad_cam_fn does not create parent directories itself.
grad_cam_path.parent.mkdir(parents=True, exist_ok=True)
grad_cam_fn(model, pre_tensor, post_tensor, device, grad_cam_path, post_img)
# Last expression of the cell, so the notebook renders the saved image.
Image.open(grad_cam_path)
