**Cell 1 — Purpose & Usage**

Configuration: None; this is the opening markdown cell.

Explanation: The following markdown cell titles the notebook and summarizes the graph attention architecture built in the later cells.

Use: Read for orientation; there is nothing to execute.

# Graph Attention Network for Histopathology Attribution
## Complete Graph with Attention-based Edge Pruning

**Architecture:**
- Patches as nodes with ViT features (768-D)
- Complete graph initialization (all nodes connected)
- Edge weights = 1 / Euclidean distance between patch centers, scaled so the largest weight is 1 (spatial proximity)
- Node attention score = average of the attention scores on the node's outgoing edges
- Dynamic edge pruning based on attention threshold
- Classification from graph-level features
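
The bullets above can be written compactly as follows. This is a sketch read off the code cells later in the notebook, not a formal specification; the symbols ($c_i$, $d_{ij}$, $w_{ij}$, $z_i$, $e_{ij}$, $\tau$, $\alpha_{ij}$, $s_i$, $\mathcal{K}(i)$) are notation introduced here for illustration. It assumes $N \ge 2$ patches with center coordinates $c_i$ and ViT features $h_i$ (row vectors), and it ignores dropout and the ELU applied after aggregation.

$$
d_{ij} = \lVert c_i - c_j \rVert_2, \qquad
w_{ij} = \frac{1/(d_{ij} + \epsilon)}{\max_{k \ne l} \, 1/(d_{kl} + \epsilon)}
$$

$$
z_i = h_i W, \qquad
e_{ij} = \mathrm{LeakyReLU}\big([\, z_i \,\Vert\, z_j \,]\, a\big) \cdot w_{ij}
$$

$$
\mathcal{K}(i) = \{\, j \ne i : e_{ij} > \tau \,\}, \qquad
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{K}(i)} \exp(e_{ik})}, \qquad
z_i' = \sum_{j \in \mathcal{K}(i)} \alpha_{ij} \, z_j
$$

$$
s_i = \frac{1}{N - 1} \sum_{j \ne i} e_{ij}
$$

Note that the node score $s_i$ averages the raw edge scores over all outgoing edges, before pruning, while the softmax weights $\alpha_{ij}$ are computed over the kept edges only.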

**Cell 2 — Purpose & Usage**

Configuration: None required; this is the first code cell.

Explanation: The following code cell imports PyTorch, the Hugging Face ViT classes, and the plotting and metrics libraries, then selects the compute device.

Use: Run this cell first; later code cells depend on these imports and on `device`.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from transformers import ViTModel, ViTImageProcessor
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


**Cell 3 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell opens section 1, patch extraction with spatial coordinates.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## 1. Patch Extraction with Spatial Coordinates

**Cell 4 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell defines `extract_patches_with_coords`, which tiles an image into overlapping patches and records the center coordinate of each patch.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [2]:
def extract_patches_with_coords(image, patch_size=224, stride=112, target_size=224):
    """
    Extract patches with their spatial coordinates for graph construction.
    
    Returns:
        patches: List of PIL images, each resized to (target_size, target_size)
        coordinates: (N, 2) NumPy array of (x, y) patch-center coordinates in the source image
    """
    if isinstance(image, str):
        image = Image.open(image).convert('RGB')
    
    width, height = image.size
    patches = []
    coordinates = []
    
    for y in range(0, height - patch_size + 1, stride):
        for x in range(0, width - patch_size + 1, stride):
            patch = image.crop((x, y, x + patch_size, y + patch_size))
            patch = patch.resize((target_size, target_size), Image.LANCZOS)
            patches.append(patch)
            
            # Store center coordinates
            center_x = x + patch_size // 2
            center_y = y + patch_size // 2
            coordinates.append((center_x, center_y))
    
    return patches, np.array(coordinates)

print("Patch extraction with spatial coordinates defined.")

Patch extraction with spatial coordinates defined.
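
As a quick check on the sliding-window logic, the sketch below reproduces only the grid arithmetic of `extract_patches_with_coords` without loading an image. The helper name `patch_grid` is hypothetical and not part of the notebook; it assumes the same defaults (`patch_size=224`, `stride=112`).

```python
def patch_grid(width, height, patch_size=224, stride=112):
    """Return (columns, rows, centers) for the sliding-window grid."""
    # Same ranges as the nested loops in extract_patches_with_coords.
    xs = list(range(0, width - patch_size + 1, stride))
    ys = list(range(0, height - patch_size + 1, stride))
    # Row-major order: y is the outer loop, x the inner loop.
    centers = [(x + patch_size // 2, y + patch_size // 2) for y in ys for x in xs]
    return len(xs), len(ys), centers


cols, rows, centers = patch_grid(448, 448)
print(cols, rows, len(centers))   # 3 3 9
print(centers[0], centers[-1])    # (112, 112) (336, 336)
```

Windows that would extend past the image border are skipped, so a 500-pixel-wide image still yields three columns and its rightmost 52 pixels are not covered. An image smaller than `patch_size` yields no patches at all.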


**Cell 5 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell opens section 2, the frozen ViT feature extractor.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## 2. ViT Feature Extractor (Frozen Backbone)

**Cell 6 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell installs scikit-image, which the segmentation utilities in later cells import.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [3]:
! pip install scikit-image



**Cell 7 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell defines `ViTFeatureExtractor`, a frozen ViT backbone that returns the 768-D CLS token feature for each patch.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [4]:
class ViTFeatureExtractor(nn.Module):
    """Frozen ViT backbone for patch feature extraction."""
    
    def __init__(self, model_name='google/vit-base-patch16-224'):
        super().__init__()
        self.processor = ViTImageProcessor.from_pretrained(model_name)
        self.vit = ViTModel.from_pretrained(model_name)
        
        # Freeze all parameters
        for param in self.vit.parameters():
            param.requires_grad = False
        
        self.vit.eval()
    
    def forward(self, images):
        """Extract CLS token features from patches."""
        inputs = self.processor(images=images, return_tensors="pt")
        inputs = {k: v.to(next(self.vit.parameters()).device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = self.vit(**inputs)
            features = outputs.last_hidden_state[:, 0, :]  # CLS token
        
        return features

print("ViT Feature Extractor defined.")

ViT Feature Extractor defined.
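
One detail worth noting: `self.vit.eval()` is called once in `__init__`, but calling `.train()` on a parent module switches every child back to training mode. Whether that matters depends on whether the backbone uses dropout or batch norm. The sketch below shows one possible way to pin a frozen sub-module in eval mode by overriding `train`. `FrozenBackbone` and the toy backbone are hypothetical stand-ins, not part of the notebook.

```python
import torch
import torch.nn as nn


class FrozenBackbone(nn.Module):
    """Hypothetical wrapper that keeps a frozen sub-module in eval mode."""

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

    def train(self, mode=True):
        # Let nn.Module update this wrapper's own flag, then force the
        # frozen backbone back to eval so dropout stays disabled.
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        with torch.no_grad():
            return self.backbone(x)


# Stand-in backbone with dropout; a real ViT would take its place.
toy = nn.Sequential(nn.Linear(8, 4), nn.Dropout(p=0.5))
model = nn.Sequential(FrozenBackbone(toy), nn.Linear(4, 2))
model.train()
print(model[0].backbone.training)  # False
print(model[1].training)           # True
```

The same `train` override could be added directly to `ViTFeatureExtractor` if the surrounding classifier is ever put into training mode.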


**Cell 8 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell imports the libraries used by the nuclei segmentation utilities (torchvision transforms and scikit-image region tools).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [5]:
# Cell 1 — imports
import torch, torch.nn as nn, torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision import transforms
from skimage.measure import label, regionprops, find_contours
from skimage.morphology import remove_small_objects
import matplotlib.pyplot as plt


**Cell 9 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell defines helper functions for attention rollout, upsampling, instance binarization, per-instance type assignment, and per-instance feature pooling.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [6]:
# Cell 2 — minimal utilities

def attention_rollout(attn_stack, add_residual=True, eps=1e-6):
    """
    attn_stack: list of L post-softmax attention tensors, one per layer, each of shape (H, N, N).
    Returns the (N,) rollout row for the CLS token (its attention to every token), or None
    for an empty stack. The caller reshapes the result to the patch grid.
    """
    if len(attn_stack) == 0:
        return None
    A = None
    for A_l in attn_stack:
        # average heads -> (N, N)
        A_l = A_l.mean(0)  # (N, N)
        if add_residual:
            A_l = A_l + torch.eye(A_l.size(-1), device=A_l.device)
            A_l = A_l / A_l.sum(dim=-1, keepdim=True).clamp_min(eps)
        A = A_l if A is None else A_l @ A
    # CLS is token 0. Return its influence on all tokens.
    return A[0]  # (N,)

def upsample_to(img_like_hw, fmap_hw, arr_chw):
    """Bilinear upsample array shaped (C, h, w) to target (H, W)."""
    t = torch.from_numpy(arr_chw).unsqueeze(0)  # (1,C,h,w)
    t = F.interpolate(t, size=img_like_hw, mode='bilinear', align_corners=False)
    return t.squeeze(0).cpu().numpy()

def binarize_instances(seg_logits, thresh=0.5, min_area=20):
    """seg_logits: (C,H,W) with either 1 channel (foreground) or 2+ for softmax."""
    if seg_logits.shape[0] == 1:
        prob = torch.sigmoid(seg_logits[0]).cpu().numpy()
    else:
        prob = F.softmax(seg_logits, dim=0)[1].cpu().numpy()  # assume channel 1 = nuclei
    mask = prob >= thresh
    mask = remove_small_objects(mask, min_size=min_area)
    lab = label(mask.astype(np.uint8), connectivity=2)
    return lab, prob

def assign_types_per_instance(type_logits, inst_map):
    """type_logits: (K,H,W), inst_map: (H,W) -> dict inst_id -> class_id"""
    if type_logits is None:
        return {}
    cls = type_logits.softmax(0).argmax(0).cpu().numpy()
    out = {}
    for k in np.unique(inst_map):
        if k == 0:  # label 0 is background; do not assume it is always present
            continue
        vals = cls[inst_map == k]
        out[int(k)] = int(np.bincount(vals).argmax())  # majority vote over the instance's pixels
    return out

def pool_features_over_instances(feat_map_chw, inst_map):
    """feat_map_chw: (C,H,W) numpy; returns dict inst_id -> (C,) mean pooled feature."""
    out = {}
    C, H, W = feat_map_chw.shape
    for k in np.unique(inst_map):
        if k == 0:  # label 0 is background; do not assume it is always present
            continue
        out[int(k)] = feat_map_chw[:, inst_map == k].mean(axis=1)
    return out
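
To make the rollout step concrete, here is a NumPy sketch of the same computation as `attention_rollout` on random attention matrices. The function name `rollout_numpy` and the toy sizes (3 layers, 2 heads, 5 tokens) are assumptions chosen for illustration. Because every per-layer matrix is row-normalized after the identity is added, the rolled-out CLS row remains a probability distribution over tokens.

```python
import numpy as np


def rollout_numpy(attn_layers):
    """Attention rollout with residual; returns the CLS row of the product."""
    joint = None
    for layer in attn_layers:
        a = layer.mean(axis=0)                  # average heads -> (N, N)
        a = a + np.eye(a.shape[-1])             # add the residual connection
        a = a / a.sum(axis=-1, keepdims=True)   # re-normalize each row
        joint = a if joint is None else a @ joint
    return joint[0]                             # token 0 is CLS


rng = np.random.default_rng(0)
layers = []
for _ in range(3):
    logits = rng.normal(size=(2, 5, 5))         # (heads, tokens, tokens)
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    layers.append(e / e.sum(axis=-1, keepdims=True))

cls_rollout = rollout_numpy(layers)
print(cls_rollout.shape)   # (5,)
```

Dropping entry 0 (the CLS token itself) and reshaping the remainder to the patch grid gives the heatmap that the caller upsamples.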


**Cell 11 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell holds a fully commented-out `CellViTSegmentor` wrapper (CellViT instance segmentation with attention rollout); run as-is, it defines nothing.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [7]:
# # Cell 3 — CellViT wrapper
# from cellvit.modeling import CellViT256x40
# class CellViTSegmentor(nn.Module):
#     """
#     Wraps a CellViT-256x40 instance segmentation model and exposes:
#       - nuclei instances with centroids and class ids
#       - per-instance pooled embeddings (from last encoder tokens)
#       - ViT attention heatmap rolled back to pixels
#     You must adapt model construction and internal module paths to your repo.
#     """
#     def __init__(self, weights_path, device='cuda', in_size=256, patch_size=16):
#         super().__init__()
#         self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
#         self.in_size = in_size
#         self.patch_size = patch_size

#         # --- build model
#         # <<< ADAPT: import/construct your exact CellViT-256x40 class
#         # from cellvit.modeling import CellViT256x40
#         self.model = CellViT256x40(num_types=5)
#         # self.model = ...  # <<< ADAPT

#         sd = torch.load(weights_path, map_location='cpu')
#         self.model.load_state_dict(sd, strict=False)
#         self.model.eval().to(self.device)

#         # preprocessing; use model's own normalization if provided
#         self.tf = transforms.Compose([
#             transforms.Resize((in_size, in_size), interpolation=Image.BILINEAR),
#             transforms.ToTensor(),
#             transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
#         ])

#         # hooks
#         self._last_tokens = None         # (B, N, D) before head/decoder
#         self._attn_blocks = []           # list of (H, N, N) tensors across layers

#         # <<< ADAPT: attach to your ViT encoder blocks and to the token output you want to pool
#         def token_hook(_, __, out):
#             # out expected (B, N, D). If it's a tuple, take the token sequence.
#             self._last_tokens = out if isinstance(out, torch.Tensor) else out[0]
#         def attn_hook(_, __, out):
#             # out expected (B, H, N, N) *after* softmax. If you get pre-softmax, apply softmax(dim=-1).
#             if isinstance(out, torch.Tensor) and out.dim() == 4:
#                 self._attn_blocks.append(out[0])  # store first in batch; adjust if batching
#         # Examples of likely paths; change to match your code:
#         # self.model.encoder.blocks[-1].register_forward_hook(token_hook)        # <<< ADAPT
#         # for blk in self.model.encoder.blocks:                                  # <<< ADAPT
#         #     blk.attn.attn_drop.register_forward_hook(attn_hook)                # <<< ADAPT

#     @torch.no_grad()
#     def _forward_raw(self, pil_images):
#         x = torch.stack([self.tf(im) for im in pil_images]).to(self.device)
#         self._last_tokens = None
#         self._attn_blocks = []
#         y = self.model(x)
#         # <<< ADAPT: unpack outputs
#         # Expecting: seg_logits (B, Cseg, H, W); optional type_logits (B, Ctype, H, W)
#         if isinstance(y, dict):
#             seg_logits = y['seg_logits']
#             type_logits = y.get('type_logits', None)
#         elif isinstance(y, (list, tuple)):
#             seg_logits = y[0]
#             type_logits = y[1] if len(y) > 1 else None
#         else:
#             seg_logits = y
#             type_logits = None
#         tokens = self._last_tokens  # (B, N, D) or None
#         attn_stack = self._attn_blocks.copy()  # list of (H, N, N)
#         return seg_logits, type_logits, tokens, attn_stack

#     @torch.no_grad()
#     def infer_on_patch(self, pil_patch, return_overlay=True):
#         """
#         pil_patch: PIL.Image (any size). Will be resized to in_size for the model.
#         Returns:
#           nuclei: list of dicts with centroid_xy (in original patch coords), bbox, area, class_id, feat(np.ndarray)
#           attn_map_up: HxW numpy attention heatmap aligned to original patch size
#           prob_map_up: nuclei probability map aligned to original patch size
#         """
#         W0, H0 = pil_patch.size
#         seg_logits, type_logits, tokens, attn_stack = self._forward_raw([pil_patch])
#         seg_logits, type_logits = seg_logits[0], (type_logits[0] if type_logits is not None else None)

#         # 1) instance map at model resolution
#         inst_map, prob_map = binarize_instances(seg_logits, thresh=0.5, min_area=20)

#         # 2) token feature map -> pixel map
#         feat_map = None
#         if tokens is not None:
#             # tokens: (1, N, D) with N = 1 + (in_size/patch)^2. Discard CLS.
#             tokens_ = tokens[0]  # (N, D)
#             N, D = tokens_.shape
#             grid = int((N - 1) ** 0.5)
#             patch_tokens = tokens_[1:1 + grid*grid].reshape(grid, grid, D).permute(2,0,1)  # (D, h, w)
#             feat_map = patch_tokens.cpu().numpy()  # (D, h, w)
#             feat_map = upsample_to((inst_map.shape[0], inst_map.shape[1]), (grid, grid), feat_map)  # (D,H,W)

#         # 3) attention rollout -> token grid -> upsample to model H,W -> upsample to original H0,W0
#         attn_map_up = None
#         if len(attn_stack) > 0:
#             a = attention_rollout([t.softmax(-1) if t.max() > 1 else t for t in attn_stack], add_residual=True)
#             N = a.shape[0]
#             grid = int((N - 1) ** 0.5)
#             a_img = a[1:1 + grid*grid].reshape(grid, grid)  # discard CLS
#             a_img = a_img / (a_img.max() + 1e-6)
#             a_img = a_img.unsqueeze(0).unsqueeze(0)  # (1,1,h,w)
#             a_img = F.interpolate(a_img, size=inst_map.shape, mode='bilinear', align_corners=False)[0,0].cpu().numpy()
#             a_img = (a_img - a_img.min()) / (a_img.max() - a_img.min() + 1e-8)
#             # upsample to original patch size
#             a_img_t = torch.from_numpy(a_img)[None,None]
#             a_img_t = F.interpolate(a_img_t, size=(H0, W0), mode='bilinear', align_corners=False)
#             attn_map_up = a_img_t[0,0].cpu().numpy()

#         # 4) upsample prob map to original patch size
#         prob_t = torch.from_numpy(prob_map)[None,None]
#         prob_up = F.interpolate(prob_t, size=(H0, W0), mode='bilinear', align_corners=False)[0,0].cpu().numpy()

#         # 5) per-instance stats and pooled features at model H,W, then map centroids to original coords
#         nuclei = []
#         scale_y = H0 / inst_map.shape[0]
#         scale_x = W0 / inst_map.shape[1]
#         types_map = assign_types_per_instance(type_logits, inst_map)
#         pooled = pool_features_over_instances(feat_map if feat_map is not None else np.zeros((1,)+inst_map.shape), inst_map)

#         for r in regionprops(inst_map):
#             y, x = r.centroid
#             cx, cy = float(x*scale_x), float(y*scale_y)
#             k = int(r.label)
#             item = {
#                 "inst_id": k,
#                 "centroid_xy": (cx, cy),               # in original patch coords
#                 "bbox_modelhw": tuple(r.bbox),         # at model resolution
#                 "area_modelpx": int(r.area),
#                 "class_id": types_map.get(k, None),
#                 "feat": pooled.get(k, None)            # np.ndarray or None
#             }
#             nuclei.append(item)

#         # optional overlay
#         overlay = None
#         if return_overlay:
#             overlay = self._make_overlay(np.array(pil_patch), inst_map, scale_x, scale_y)

#         return nuclei, attn_map_up, prob_up, overlay

#     @staticmethod
#     def _make_overlay(rgb_uint8, inst_map, sx, sy):
#         H0, W0 = rgb_uint8.shape[:2]
#         fig, ax = plt.subplots(figsize=(5,5), dpi=120)
#         ax.imshow(rgb_uint8)
#         # draw contours for each instance after scaling
#         for k in np.unique(inst_map)[1:]:
#             m = (inst_map == k)
#             contours = find_contours(m.astype(float), 0.5)
#             for c in contours:
#                 c[:, 0] *= sy
#                 c[:, 1] *= sx
#                 ax.plot(c[:,1], c[:,0])
#         ax.set_axis_off()
#         fig.tight_layout()
#         return fig


**Cell 12 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

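Explanation: The following code cell is an example-usage script written for `CellViTSegmentor.infer_on_patch`. As recorded, it constructs `ViTFeatureExtractor` instead, so it stops at step 3 with the `AttributeError` shown in the output.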

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [7]:
# Cell 4 — example usage
# 1) set your checkpoint path and construct the model
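weights_path = "CellViT-256-x40.pth"  # <<< change
# NOTE: `infer_on_patch` is defined only on the CellViTSegmentor wrapper, which is
# commented out in the previous cell. ViTFeatureExtractor has no such method, so
# step 3 raises the AttributeError recorded in this cell's output. To run the cell,
# uncomment and adapt the wrapper, then construct it here instead:
# seg = CellViTSegmentor(weights_path=weights_path, device='cuda', in_size=256, patch_size=16)
seg = ViTFeatureExtractor()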
# 2) load a patch image (PIL). Use your own patch here.
patch = Image.open("Patches/Benign/b001_p0.png").convert("RGB")

# 3) run inference
nuclei, attn_map, prob_map, overlay_fig = seg.infer_on_patch(patch, return_overlay=True)

# 4) visualize: segmentation overlay
display(overlay_fig)
plt.close(overlay_fig)

# 5) visualize attention heatmap
if attn_map is not None:
    fig2, ax2 = plt.subplots(figsize=(5,5), dpi=120)
    ax2.imshow(patch)
    ax2.imshow(attn_map, alpha=0.4)  # default colormap
    ax2.set_axis_off()
    fig2.tight_layout()
    plt.show()

# 6) print a few nuclei records
print(f"Detected nuclei: {len(nuclei)}")
for rec in nuclei[:5]:
    print({
        "inst_id": rec["inst_id"],
        "centroid_xy": tuple(round(v,2) for v in rec["centroid_xy"]),
        "area_modelpx": rec["area_modelpx"],
        "class_id": rec["class_id"],
        "feat_shape": None if rec["feat"] is None else tuple(rec["feat"].shape)
    })


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


AttributeError: 'ViTFeatureExtractor' object has no attribute 'infer_on_patch'

**Cell 15 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell opens section 3, complete graph construction from patch coordinates.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## 3. Complete Graph Construction with Euclidean Distance

**Cell 16 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell defines `build_complete_graph`, which connects every pair of patches and weights each edge by the inverse Euclidean distance between patch centers.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
def build_complete_graph(coordinates, epsilon=1e-6):
    """
    Build complete graph with edge weights = 1 / Euclidean distance.
    
    Args:
        coordinates: (N, 2) array of (x, y) coordinates
        epsilon: Small constant to avoid division by zero
    
    Returns:
        edge_index: (2, E) tensor of edges (fully connected)
        edge_weights: (E,) tensor of inverse-distance edge weights, scaled so the maximum is 1
    """
    N = len(coordinates)
    
    # Create all pairwise edges (complete graph)
    sources = []
    targets = []
    distances = []
    
    for i in range(N):
        for j in range(N):
            if i != j:  # No self-loops
                sources.append(i)
                targets.append(j)
                
                # Euclidean distance
                dist = np.sqrt(np.sum((coordinates[i] - coordinates[j])**2))
                distances.append(dist)
    
    edge_index = torch.tensor([sources, targets], dtype=torch.long)
    
    # Edge weights = 1 / distance (inverse distance weighting)
    distances = np.array(distances)
    edge_weights = 1.0 / (distances + epsilon)
    edge_weights = torch.tensor(edge_weights, dtype=torch.float32)
    
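    # Normalize edge weights to (0, 1]; skip when there are no edges (N < 2),
    # because max() of an empty tensor raises
    if edge_weights.numel() > 0:
        edge_weights = edge_weights / edge_weights.max()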
    
    return edge_index, edge_weights

print("Complete graph construction defined.")
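
The nested Python loop above visits all N·(N−1) ordered pairs, which gets slow as the patch count grows. The sketch below is a vectorized NumPy variant that produces the same edge order and weights. The name `complete_graph_numpy` is hypothetical, and it returns NumPy arrays rather than tensors, so a caller would still need `torch.from_numpy` to use it in the model.

```python
import numpy as np


def complete_graph_numpy(coordinates, epsilon=1e-6):
    """Vectorized complete graph: all ordered pairs (i, j) with i != j."""
    coords = np.asarray(coordinates, dtype=np.float64)
    n = len(coords)
    # Off-diagonal positions in row-major order match the nested-loop order.
    src, dst = np.nonzero(~np.eye(n, dtype=bool))
    if src.size == 0:  # fewer than two nodes: no edges to weight
        return np.zeros((2, 0), dtype=np.int64), np.zeros(0, dtype=np.float32)
    diff = coords[src] - coords[dst]
    dist = np.sqrt((diff ** 2).sum(axis=1))
    weights = 1.0 / (dist + epsilon)
    weights = weights / weights.max()          # scale so the largest weight is 1
    return np.stack([src, dst]), weights.astype(np.float32)


# Three patch centers forming a 5-5-8 triangle.
edge_index, edge_weights = complete_graph_numpy([(0, 0), (3, 4), (0, 8)])
print(edge_index.shape)   # (2, 6)
print(edge_index[:, 1])   # [0 2]
```

In this example the two nearest pairs (distance 5) get weight 1, and the pair at distance 8 gets a weight of about 0.625.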

**Cell 17 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell opens section 4, the graph attention layer with edge pruning.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## 4. Graph Attention Layer with Edge Pruning

**Cell 18 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

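Explanation: The following code cell defines `GraphAttentionLayer`, which scores edges, prunes those at or below a threshold, normalizes the remaining scores per source node, and aggregates neighbor features.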

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
class GraphAttentionLayer(nn.Module):
    """
    Graph Attention Layer with:
    - Edge attention scores
    - Node attention scores (average of connected edges)
    - Dynamic edge pruning based on threshold
    """
    
    def __init__(self, in_features, out_features, dropout=0.2, alpha=0.2, concat=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.alpha = alpha
        self.concat = concat
        
        # Learnable transformation
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        
        # Attention mechanism
        self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        
        self.leakyrelu = nn.LeakyReLU(self.alpha)
    
    def forward(self, h, edge_index, edge_weights, attention_threshold=0.0):
        """
        Args:
            h: (N, in_features) node features
            edge_index: (2, E) edge connections
            edge_weights: (E,) spatial edge weights
            attention_threshold: Pruning threshold, compared against the raw (pre-softmax)
                edge scores; edges scoring at or below it are dropped
        
        Returns:
            h_prime: (N, out_features) transformed node features
            edge_attention: (E,) attention scores for edges
            node_attention: (N,) attention scores for nodes
            pruned_edge_mask: (E,) boolean mask of kept edges
        """
        # Linear transformation
        Wh = torch.mm(h, self.W)  # (N, out_features)
        
        # Compute edge attention scores
        edge_attention = self._compute_edge_attention(Wh, edge_index, edge_weights)
        
        # Apply attention threshold for edge pruning
        pruned_edge_mask = edge_attention > attention_threshold
        
        # Prune edges below threshold
        if pruned_edge_mask.sum() > 0:
            edge_index_pruned = edge_index[:, pruned_edge_mask]
            edge_attention_pruned = edge_attention[pruned_edge_mask]
        else:
            # No edge passed the threshold: fall back to the top 10% of edges by score
            top_k = max(1, int(0.1 * len(edge_attention)))
            _, top_indices = torch.topk(edge_attention, top_k)
            edge_index_pruned = edge_index[:, top_indices]
            edge_attention_pruned = edge_attention[top_indices]
            pruned_edge_mask = torch.zeros_like(edge_attention, dtype=torch.bool)
            pruned_edge_mask[top_indices] = True
        
        # Apply softmax per source node
        edge_attention_normalized = self._normalize_attention(edge_attention_pruned, edge_index_pruned)
        
        # Apply dropout
        edge_attention_normalized = F.dropout(edge_attention_normalized, self.dropout, training=self.training)
        
        # Aggregate features
        h_prime = self._aggregate(Wh, edge_index_pruned, edge_attention_normalized)
        
        # Compute node attention scores: mean raw score over each node's outgoing edges,
        # taken over all edges (before pruning)
        node_attention = self._compute_node_attention(edge_attention, edge_index, h.size(0))
        
        if self.concat:
            return F.elu(h_prime), edge_attention, node_attention, pruned_edge_mask
        else:
            return h_prime, edge_attention, node_attention, pruned_edge_mask
    
    def _compute_edge_attention(self, Wh, edge_index, edge_weights):
        """Compute attention coefficients for each edge."""
        source, target = edge_index[0], edge_index[1]
        
        # Concatenate source and target features
        edge_features = torch.cat([Wh[source], Wh[target]], dim=1)  # (E, 2*out_features)
        
        # Compute attention logits
        e = self.leakyrelu(torch.matmul(edge_features, self.a).squeeze(-1))  # (E,)
        
        # Incorporate spatial edge weights
        e = e * edge_weights.to(e.device)
        
        return e
    
    def _normalize_attention(self, attention, edge_index):
        """Apply softmax normalization per source node."""
        source = edge_index[0]
        
        # Compute max per source for numerical stability
        max_attention = torch.zeros(source.max() + 1, device=attention.device)
        max_attention.index_reduce_(0, source, attention, 'amax', include_self=False)
        
        # Subtract max and exponentiate
        attention_shifted = attention - max_attention[source]
        attention_exp = torch.exp(attention_shifted)
        
        # Sum per source
        attention_sum = torch.zeros(source.max() + 1, device=attention.device)
        attention_sum.index_add_(0, source, attention_exp)
        
        # Normalize
        attention_normalized = attention_exp / (attention_sum[source] + 1e-16)
        
        return attention_normalized
    
    def _aggregate(self, Wh, edge_index, attention):
        """Aggregate neighbor features weighted by attention."""
        source, target = edge_index[0], edge_index[1]
        
        # Weighted features
        weighted_features = Wh[target] * attention.unsqueeze(-1)
        
        # Sum aggregation per node
        h_prime = torch.zeros(Wh.size(0), Wh.size(1), device=Wh.device)
        h_prime.index_add_(0, source, weighted_features)
        
        return h_prime
    
    def _compute_node_attention(self, edge_attention, edge_index, num_nodes):
        """Compute node attention as average of connected edge attentions."""
        source = edge_index[0]
        
        # Sum of edge attentions per node
        node_attention_sum = torch.zeros(num_nodes, device=edge_attention.device)
        node_attention_sum.index_add_(0, source, edge_attention)
        
        # Count edges per node
        node_degree = torch.zeros(num_nodes, device=edge_attention.device)
        node_degree.index_add_(0, source, torch.ones_like(edge_attention))
        
        # Average attention
        node_attention = node_attention_sum / (node_degree + 1e-16)
        
        return node_attention

print("Graph Attention Layer with edge pruning defined.")
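
The per-source softmax in `_normalize_attention` is the least obvious step of the layer, so here is a small NumPy mirror of it on a toy edge list. The function name `softmax_per_source` and the example scores are assumptions for illustration. It uses `np.maximum.at` and `np.add.at` where the layer uses `index_reduce_` and `index_add_`.

```python
import numpy as np


def softmax_per_source(scores, source, num_nodes):
    """Softmax over the edges that share a source node."""
    scores = np.asarray(scores, dtype=np.float64)
    source = np.asarray(source, dtype=np.int64)
    # Per-source maximum, subtracted for numerical stability.
    max_per_src = np.full(num_nodes, -np.inf)
    np.maximum.at(max_per_src, source, scores)
    exp = np.exp(scores - max_per_src[source])
    # Per-source sum of exponentials.
    denom = np.zeros(num_nodes)
    np.add.at(denom, source, exp)
    return exp / denom[source]


source = np.array([0, 0, 1, 1, 1])            # node 2 has no outgoing edges
scores = np.array([2.0, 2.0, 0.0, 1.0, 2.0])
alpha = softmax_per_source(scores, source, num_nodes=3)
print(alpha[:2])   # [0.5 0.5]
```

Each group of edges with a common source sums to 1 after normalization. A node that loses all its outgoing edges to pruning simply receives no aggregated message, so its output row stays zero.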

**Cell 19 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell opens section 5, the multi-layer GAT classifier.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## 5. Multi-layer GAT Model with Classification

**Cell 20 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell defines `GATClassifier`, which stacks two multi-head GAT layers on the ViT features, pools the nodes with global attention, and classifies the resulting graph vector.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
class GATClassifier(nn.Module):
    """
    Graph Attention Network for Histopathology Classification.
    
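    Architecture (dimensions for the defaults hidden_dim=256, num_heads=4):
    - Input: (N, 768) node features from ViT
    - GAT Layer 1: num_heads heads, each 768 → 256, concatenated to 1024
    - GAT Layer 2: num_heads heads, each 1024 → 128, concatenated to 512
    - Global pooling: attention-weighted sum over nodes → (512,)
    - Classifier: 512 → 128 → num_classes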
    """
    
    def __init__(self, feature_extractor, num_classes=4, hidden_dim=256, 
                 num_heads=4, dropout=0.3):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.num_classes = num_classes
        self.num_heads = num_heads
        
        # Multi-head GAT layers
        self.gat1_heads = nn.ModuleList([
            GraphAttentionLayer(768, hidden_dim, dropout=dropout, concat=True)
            for _ in range(num_heads)
        ])
        
        self.gat2_heads = nn.ModuleList([
            GraphAttentionLayer(hidden_dim * num_heads, hidden_dim // 2, 
                              dropout=dropout, concat=True)
            for _ in range(num_heads)
        ])
        
        # Final feature dimension after concatenating heads
        final_dim = (hidden_dim // 2) * num_heads
        
        # Global attention pooling
        self.global_attention = nn.Sequential(
            nn.Linear(final_dim, hidden_dim // 2),
            nn.Tanh(),
            nn.Linear(hidden_dim // 2, 1)
        )
        
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(final_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )
        
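        # Attention threshold parameter. It is only used in a hard comparison
        # (edge_attention > threshold), so no gradient reaches it and it keeps its
        # initial value unless the pruning step is made differentiable.
        self.attention_threshold = nn.Parameter(torch.tensor(0.1))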
    
    def forward(self, images, coordinates, return_attention=False):
        """
        Args:
            images: List of PIL images (patches)
            coordinates: (N, 2) numpy array of patch coordinates
            return_attention: Whether to return attention maps
        
        Returns:
            logits: (num_classes,) classification logits
            attention_info: Dict with attention scores (if return_attention=True)
        """
        # Extract features
        node_features = self.feature_extractor(images)  # (N, 768)
        
        # Build complete graph
        edge_index, edge_weights = build_complete_graph(coordinates)
        edge_index = edge_index.to(node_features.device)
        edge_weights = edge_weights.to(node_features.device)
        
        # Map the threshold parameter into (0, 1); the initial value 0.1 gives sigmoid(0.1) ≈ 0.525
        threshold = torch.sigmoid(self.attention_threshold)
        
        # GAT Layer 1 (multi-head)
        gat1_outputs = []
        edge_attentions_1 = []
        node_attentions_1 = []
        edge_masks_1 = []
        
        for head in self.gat1_heads:
            h, edge_att, node_att, edge_mask = head(
                node_features, edge_index, edge_weights, threshold
            )
            gat1_outputs.append(h)
            edge_attentions_1.append(edge_att)
            node_attentions_1.append(node_att)
            edge_masks_1.append(edge_mask)
        
        # Concatenate multi-head outputs
        h1 = torch.cat(gat1_outputs, dim=1)  # (N, hidden_dim * num_heads)
        
        # GAT Layer 2 (multi-head)
        gat2_outputs = []
        edge_attentions_2 = []
        node_attentions_2 = []
        edge_masks_2 = []
        
        for head in self.gat2_heads:
            h, edge_att, node_att, edge_mask = head(
                h1, edge_index, edge_weights, threshold
            )
            gat2_outputs.append(h)
            edge_attentions_2.append(edge_att)
            node_attentions_2.append(node_att)
            edge_masks_2.append(edge_mask)
        
        # Concatenate multi-head outputs
        h2 = torch.cat(gat2_outputs, dim=1)  # (N, final_dim)
        
        # Global attention pooling
        global_att_scores = self.global_attention(h2)  # (N, 1)
        global_att_weights = F.softmax(global_att_scores, dim=0)  # (N, 1)
        
        # Graph-level representation
        graph_features = torch.sum(global_att_weights * h2, dim=0)  # (final_dim,)
        
        # Classification
        logits = self.classifier(graph_features)  # (num_classes,)
        
        if return_attention:
            # Average node attention across heads
            node_attention_avg = torch.stack(node_attentions_2).mean(dim=0)
            
            attention_info = {
                'node_attention': node_attention_avg.detach().cpu().numpy(),
                'global_attention': global_att_weights.squeeze().detach().cpu().numpy(),
                'edge_attentions_layer1': [e.detach().cpu().numpy() for e in edge_attentions_1],
                'edge_attentions_layer2': [e.detach().cpu().numpy() for e in edge_attentions_2],
                'edge_masks_layer1': [m.detach().cpu().numpy() for m in edge_masks_1],
                'edge_masks_layer2': [m.detach().cpu().numpy() for m in edge_masks_2],
                'threshold': threshold.item(),
                'num_edges_kept': [m.sum().item() for m in edge_masks_2]
            }
            return logits, attention_info
        
        return logits

print("GAT Classifier model defined.")
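
The global attention pooling step in `forward` (softmax over the per-node scores, then a weighted sum of node features) can be illustrated with a minimal pure-Python sketch. The function name `attention_pool` and the toy inputs are assumptions made for illustration; the model itself does this with `F.softmax` and `torch.sum` on tensors.

```python
import math

def attention_pool(node_features, scores):
    """Softmax the per-node scores, then return the weighted sum of node features."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(node_features[0])
    pooled = [
        sum(w * feat[d] for w, feat in zip(weights, node_features))
        for d in range(dim)
    ]
    return pooled, weights

# Three nodes with 2-D features and equal scores: every node gets weight 1/3.
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
scores = [0.0, 0.0, 0.0]
pooled, weights = attention_pool(features, scores)
```

Because the weights sum to 1 over the nodes, the pooled vector has the same size regardless of how many patches an image produces, which is what lets the classifier accept graphs of varying size.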

**Cell 21 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## 6. Dataset and DataLoader

**Cell 22 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
class HistoGraphDataset(Dataset):
    """Dataset for histopathology images as graphs."""
    
    def __init__(self, root_dir, class_names=('Benign', 'InSitu', 'Invasive', 'Normal')):
        # A tuple default avoids sharing one mutable list across instances
        self.root_dir = root_dir
        self.class_names = list(class_names)
        self.class_to_idx = {name: idx for idx, name in enumerate(class_names)}
        
        # Collect all image paths
        self.samples = []
        for class_name in class_names:
            class_dir = os.path.join(root_dir, class_name)
            if os.path.exists(class_dir):
                # os.listdir order is arbitrary; sort so sample indices are stable across runs
                for img_name in sorted(os.listdir(class_dir)):
                    if img_name.endswith(('.tif', '.png', '.jpg', '.jpeg')):
                        img_path = os.path.join(class_dir, img_name)
                        self.samples.append((img_path, self.class_to_idx[class_name]))
        
        print(f"Found {len(self.samples)} images across {len(class_names)} classes.")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        
        # Extract patches with coordinates
        patches, coordinates = extract_patches_with_coords(img_path)
        
        return patches, coordinates, label, img_path

print("Dataset class defined.")
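
The folder layout the dataset expects (one sub-directory per class, label taken from the directory name) can be checked without any image data. The sketch below is a standalone stand-in for the scanning loop, not the class itself: the helper name `collect_samples`, the sorted listing, and the case-insensitive extension match are choices made for this sketch.

```python
import os
import tempfile

def collect_samples(root_dir, class_names, extensions=(".tif", ".png", ".jpg", ".jpeg")):
    """Return (path, label) pairs, with labels given by position in class_names."""
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    samples = []
    for name in class_names:
        class_dir = os.path.join(root_dir, name)
        if not os.path.isdir(class_dir):
            continue  # missing class folders are skipped silently
        for fname in sorted(os.listdir(class_dir)):
            if fname.lower().endswith(extensions):
                samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return samples

# Build a throwaway directory tree with two of the four classes present.
with tempfile.TemporaryDirectory() as root:
    layout = {"Benign": ["b1.tif", "notes.txt"], "Normal": ["n1.png", "n2.TIF"]}
    for name, files in layout.items():
        os.makedirs(os.path.join(root, name))
        for fname in files:
            open(os.path.join(root, name, fname), "w").close()
    demo_samples = collect_samples(root, ["Benign", "InSitu", "Invasive", "Normal"])
    demo_labels = [label for _, label in demo_samples]
```

Note that a missing class folder produces no error, so a typo in a class name shows up only as a lower image count.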

**Cell 23 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## 7. Training Configuration

**Cell 24 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
# Hyperparameters
NUM_CLASSES = 4
HIDDEN_DIM = 256
NUM_HEADS = 4
DROPOUT = 0.3
LEARNING_RATE = 1e-4
NUM_EPOCHS = 15
BATCH_SIZE = 1  # Process one image at a time (variable number of patches)

# Dataset paths
TRAIN_DIR = '/home/pclab/Desktop/WORK/histoAttribution/Photos'
CLASS_NAMES = ['Benign', 'InSitu', 'Invasive', 'Normal']

print("Configuration set.")

**Cell 25 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## 8. Initialize Model and Dataset

**Cell 26 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
# Initialize feature extractor
print("Loading ViT feature extractor...")
feature_extractor = ViTFeatureExtractor().to(device)

# Initialize GAT model
print("Initializing GAT Classifier...")
model = GATClassifier(
    feature_extractor=feature_extractor,
    num_classes=NUM_CLASSES,
    hidden_dim=HIDDEN_DIM,
    num_heads=NUM_HEADS,
    dropout=DROPOUT
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Load dataset
print("\nLoading dataset...")
dataset = HistoGraphDataset(TRAIN_DIR, CLASS_NAMES)

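# Split into train/val (seeded so the same split is produced when the notebook is re-run,
# e.g. before re-evaluating a saved checkpoint)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(42)  # the seed value itself is arbitrary
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)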

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
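
The 80/20 split arithmetic can be sketched in plain Python. This is an illustration of the idea only: `split_indices` and its `seed` argument are hypothetical names, and `random.Random` stands in for the generator that `random_split` uses.

```python
import random

def split_indices(num_samples, train_fraction=0.8, seed=0):
    """Shuffle indices with a fixed seed and cut them into train and validation parts."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    train_size = int(train_fraction * num_samples)
    return indices[:train_size], indices[train_size:]

train_idx, val_idx = split_indices(10)
train_idx_again, val_idx_again = split_indices(10)
```

With a fixed seed the two calls return identical index lists, so a model evaluated after a restart sees the same validation images it was selected on.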

**Cell 27 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## 9. Training Loop

**Cell 28 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LEARNING_RATE
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=3
)

# Training history
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [],
    'threshold': [], 'avg_edges_kept': []
}

print("Starting training...\n")
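
To see what `mode='max', factor=0.5, patience=3` means in practice, here is a simplified simulation of a plateau scheduler. It is a sketch under stated assumptions, not PyTorch's implementation: it ignores the improvement threshold, cooldown, and minimum learning rate, and `simulate_plateau` is a hypothetical helper.

```python
def simulate_plateau(metrics, lr=1e-4, factor=0.5, patience=3):
    """Return the learning rate in effect after each epoch's metric is reported."""
    best = float("-inf")
    bad_epochs = 0
    lrs = []
    for metric in metrics:
        if metric > best:          # mode='max': higher is better
            best = metric
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:  # more than `patience` epochs without improvement
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs

# Validation accuracy improves once, then stalls for four epochs.
lrs = simulate_plateau([0.50, 0.60, 0.60, 0.60, 0.60, 0.60])
```

In this toy run the rate stays at 1e-4 through the third stalled epoch and is halved on the fourth, which is why the metric passed to `scheduler.step` must be one where larger is better.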

**Cell 29 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
def train_epoch(model, dataset, criterion, optimizer, device):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for patches, coordinates, label, _ in tqdm(dataset, desc="Training"):
        optimizer.zero_grad()
        
        # Forward pass
        logits = model(patches, coordinates)
        
        # Compute loss
        label = torch.tensor([label], dtype=torch.long).to(device)
        loss = criterion(logits.unsqueeze(0), label)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Metrics
        total_loss += loss.item()
        pred = torch.argmax(logits)
        correct += (pred == label).sum().item()
        total += 1
    
    avg_loss = total_loss / total
    accuracy = correct / total
    
    return avg_loss, accuracy

def validate_epoch(model, dataset, criterion, device):
    """Validate for one epoch."""
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    total_edges_kept = 0
    
    with torch.no_grad():
        for patches, coordinates, label, _ in tqdm(dataset, desc="Validation"):
            # Forward pass
            logits, attention_info = model(patches, coordinates, return_attention=True)
            
            # Compute loss
            label_tensor = torch.tensor([label], dtype=torch.long).to(device)
            loss = criterion(logits.unsqueeze(0), label_tensor)
            
            # Metrics
            total_loss += loss.item()
            pred = torch.argmax(logits).item()
            correct += (pred == label)
            total += 1
            
            all_preds.append(pred)
            all_labels.append(label)
            
            # Track edges kept
            total_edges_kept += np.mean(attention_info['num_edges_kept'])
    
    avg_loss = total_loss / total
    accuracy = correct / total
    avg_edges_kept = total_edges_kept / total
    
    return avg_loss, accuracy, all_preds, all_labels, avg_edges_kept

print("Training and validation functions defined.")
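
Both functions compute the loss for one image at a time from a `(num_classes,)` logit vector. The quantity that `nn.CrossEntropyLoss` returns for a single sample is the negative log-softmax at the true label, which the following pure-Python sketch reproduces (the function name `cross_entropy` and the example logits are illustrative).

```python
import math

def cross_entropy(logits, label):
    """Negative log-probability of `label` under softmax(logits)."""
    m = max(logits)  # stabilise the log-sum-exp
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# Four classes, as in this notebook.
uniform_loss = cross_entropy([0.0, 0.0, 0.0, 0.0], 2)    # model has no preference
confident_loss = cross_entropy([0.0, 0.0, 8.0, 0.0], 2)  # model strongly favours class 2
```

An untrained four-class model should therefore start near a loss of ln(4), about 1.386, which is a useful sanity check on the first epoch's printed loss.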

**Cell 30 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
import os
import torch

# Ensure checkpoint dir
os.makedirs('checkpoints', exist_ok=True)
checkpoint_path = 'checkpoints/GAT_best.pth'

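# If a checkpoint exists, load and re-evaluate it (training is skipped);
# otherwise train from scratch
best_val_acc = 0.0
if os.path.exists(checkpoint_path):
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt['model_state_dict'])
    if 'optimizer_state_dict' in ckpt:
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    # Restore the saved history so later cells (plots, JSON export) are not left empty
    history = ckpt.get('history', history)
    print(f"Loaded checkpoint: {checkpoint_path}")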
    print(f"Stored Best Val Acc: {ckpt.get('val_acc', 0.0):.4f}")
    # Re-evaluate current validation accuracy
    val_loss0, val_acc0, _, _, avg_edges0 = validate_epoch(model, val_dataset, criterion, device)
    print(f"Re-evaluated Val Loss: {val_loss0:.4f} | Val Acc: {val_acc0:.4f}")
    print(f"Attention Threshold: {torch.sigmoid(model.attention_threshold).item():.4f}")
    print(f"Average Edges Kept: {avg_edges0:.1f}")
    best_val_acc = max(ckpt.get('val_acc', 0.0), val_acc0)
else:
    print("No previous checkpoint found. Starting from scratch.")

    # Training loop
    for epoch in range(NUM_EPOCHS):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
        print(f"{'='*60}")
        
        # Train
        train_loss, train_acc = train_epoch(model, train_dataset, criterion, optimizer, device)
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        
        # Validate
        val_loss, val_acc, val_preds, val_labels, avg_edges = validate_epoch(
            model, val_dataset, criterion, device
        )
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
        print(f"Attention Threshold: {torch.sigmoid(model.attention_threshold).item():.4f}")
        print(f"Average Edges Kept: {avg_edges:.1f}")
        
        # Update learning rate
        scheduler.step(val_acc)
        
        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['threshold'].append(torch.sigmoid(model.attention_threshold).item())
        history['avg_edges_kept'].append(avg_edges)
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'history': history
            }, checkpoint_path)
            print(f"✓ Best model saved! (Val Acc: {val_acc:.4f})")

    print("\n" + "="*60)
    print(f"Training completed! Best Val Acc: {best_val_acc:.4f}")
    print("="*60)
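
The save-on-improvement pattern in the loop above has one consequence worth seeing concretely: the `history` stored inside the checkpoint stops at the best epoch, because the file is only rewritten when validation accuracy improves. The sketch below demonstrates this with JSON standing in for `torch.save`; `run_epochs` and the accuracy values are made up for illustration.

```python
import json
import os
import tempfile

def run_epochs(val_accs, checkpoint_path):
    """Append each epoch to history, but write the checkpoint only on improvement."""
    best_val_acc = 0.0
    history = {"val_acc": []}
    for epoch, val_acc in enumerate(val_accs):
        history["val_acc"].append(val_acc)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            with open(checkpoint_path, "w") as f:
                json.dump({"epoch": epoch, "val_acc": val_acc, "history": history}, f)
    return best_val_acc

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "best.json")
    best = run_epochs([0.55, 0.70, 0.65, 0.68], path)
    with open(path) as f:
        saved = json.load(f)
```

Here the best epoch is the second one, so the saved history holds two entries even though four epochs ran. Curves plotted from the checkpoint alone will be truncated in the same way.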


**Cell 33 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## 10. Training Visualization

**Cell 34 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
# ===== Plot using history loaded from checkpoint (standalone) =====
import os, torch
import matplotlib.pyplot as plt

checkpoint_path = 'checkpoints/GAT_best.pth'

def load_history(path):
    if not os.path.exists(path):
        raise FileNotFoundError(f"No checkpoint at {path}")
    ckpt = torch.load(path, map_location=device, weights_only=False)
    hist = ckpt.get('history', {})
    # ensure all keys exist
    for k in ['train_loss','train_acc','val_loss','val_acc','threshold','avg_edges_kept']:
        hist.setdefault(k, [])
    return hist

# prefer in-memory `history` if present and non-empty, else load from disk
try:
    # If history exists and has data, use it; otherwise load from checkpoint
    if 'history' in locals() and sum(len(v) for v in history.values()) > 0:
        use_hist = history
    else:
        use_hist = load_history(checkpoint_path)
except NameError:
    use_hist = load_history(checkpoint_path)

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

axes[0, 0].plot(use_hist['train_loss'], label='Train Loss', marker='o')
axes[0, 0].plot(use_hist['val_loss'], label='Val Loss', marker='s')
axes[0, 0].set_xlabel('Epoch'); axes[0, 0].set_ylabel('Loss'); axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].legend(); axes[0, 0].grid(True, alpha=0.3)

axes[0, 1].plot(use_hist['train_acc'], label='Train Acc', marker='o')
axes[0, 1].plot(use_hist['val_acc'], label='Val Acc', marker='s')
axes[0, 1].set_xlabel('Epoch'); axes[0, 1].set_ylabel('Accuracy'); axes[0, 1].set_title('Training and Validation Accuracy')
axes[0, 1].legend(); axes[0, 1].grid(True, alpha=0.3)

axes[1, 0].plot(use_hist['threshold'], marker='o')
axes[1, 0].set_xlabel('Epoch'); axes[1, 0].set_ylabel('Threshold Value'); axes[1, 0].set_title('Attention Threshold Evolution')
axes[1, 0].grid(True, alpha=0.3)

axes[1, 1].plot(use_hist['avg_edges_kept'], marker='o')
axes[1, 1].set_xlabel('Epoch'); axes[1, 1].set_ylabel('Number of Edges'); axes[1, 1].set_title('Average Edges Kept After Pruning')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
os.makedirs('plots', exist_ok=True)  # savefig does not create missing directories
plt.savefig('plots/GAT_training_curves.png', dpi=300, bbox_inches='tight')
plt.show()
print("Training curves saved to plots/GAT_training_curves.png")


**Cell 35 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## 11. Confusion Matrix and Classification Report

**Cell 36 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
# Load best model
checkpoint = torch.load('checkpoints/GAT_best.pth', map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']+1}")

# Get predictions on validation set
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for patches, coordinates, label, _ in tqdm(val_dataset, desc="Final Evaluation"):
        logits = model(patches, coordinates)
        pred = torch.argmax(logits).item()
        all_preds.append(pred)
        all_labels.append(label)

# Confusion matrix
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
plt.title('GAT Confusion Matrix (Validation Set)')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
os.makedirs('plots', exist_ok=True)  # savefig does not create missing directories
plt.savefig('plots/GAT_confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

# Classification report
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, target_names=CLASS_NAMES))

print(f"\nFinal Validation Accuracy: {accuracy_score(all_labels, all_preds):.4f}")
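
For reading the heatmap, it helps to recall how the matrix is laid out: rows are true labels and columns are predictions, the same convention `confusion_matrix` uses. A minimal pure-Python sketch with made-up labels (the helper names are illustrative):

```python
def confusion(true_labels, pred_labels, num_classes):
    """Count matrix with rows = true class and columns = predicted class."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_labels, pred_labels):
        matrix[t][p] += 1
    return matrix

def per_class_recall(matrix):
    """Diagonal entry divided by the row sum, for each class."""
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(matrix)]

# Toy labels for the four classes (0=Benign, 1=InSitu, 2=Invasive, 3=Normal).
y_true = [0, 0, 1, 1, 2, 2, 3, 3]
y_pred = [0, 1, 1, 1, 2, 0, 3, 3]
cm_demo = confusion(y_true, y_pred, 4)
recall_demo = per_class_recall(cm_demo)
```

In this toy case one Benign image is predicted as InSitu (row 0, column 1) and one Invasive image as Benign (row 2, column 0), giving recalls of 0.5 for those two classes.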

**Cell 37 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## 12. Attention Visualization

**Cell 38 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
def visualize_graph_attention(model, image_path, class_name, device):
    """
    Visualize graph structure and attention for a sample image.
    """
    model.eval()
    
    # Extract patches
    patches, coordinates = extract_patches_with_coords(image_path)
    
    # Get predictions and attention
    with torch.no_grad():
        logits, attention_info = model(patches, coordinates, return_attention=True)
    
    pred_class = torch.argmax(logits).item()
    confidence = F.softmax(logits, dim=0)[pred_class].item()
    
    # Create visualization
    fig = plt.figure(figsize=(20, 10))
    gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)
    
    # Original image
    ax1 = fig.add_subplot(gs[0, 0])
    img = Image.open(image_path).convert('RGB')
    ax1.imshow(img)
    ax1.set_title(f'Original Image\nTrue: {class_name} | Pred: {CLASS_NAMES[pred_class]} ({confidence:.2f})')
    ax1.axis('off')
    
    # Node attention heatmap
    ax2 = fig.add_subplot(gs[0, 1])
    node_attention = attention_info['node_attention']
    scatter = ax2.scatter(coordinates[:, 0], coordinates[:, 1], 
                         c=node_attention, cmap='hot', s=100, alpha=0.7)
    ax2.set_title('Node Attention Scores\n(Avg of Edge Attentions)')
    ax2.invert_yaxis()
    plt.colorbar(scatter, ax=ax2)
    
    # Global attention heatmap
    ax3 = fig.add_subplot(gs[0, 2])
    global_attention = attention_info['global_attention']
    scatter = ax3.scatter(coordinates[:, 0], coordinates[:, 1], 
                         c=global_attention, cmap='viridis', s=100, alpha=0.7)
    ax3.set_title('Global Pooling Attention')
    ax3.invert_yaxis()
    plt.colorbar(scatter, ax=ax3)
    
    # Graph structure (pruned edges)
    ax4 = fig.add_subplot(gs[1, :])
    
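    # Build edge list from kept edges (using first head)
    edge_index, _ = build_complete_graph(coordinates)
    # The mask is returned as a NumPy array; convert it to a boolean tensor before indexing
    edge_mask = torch.as_tensor(attention_info['edge_masks_layer2'][0], dtype=torch.bool)
    kept_edges = edge_index[:, edge_mask].cpu().numpy()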
    
    # Plot nodes
    scatter = ax4.scatter(coordinates[:, 0], coordinates[:, 1], 
                         c=node_attention, cmap='hot', s=150, alpha=0.8, 
                         edgecolors='black', linewidths=2, zorder=3)
    
    # Plot kept edges
    for i in range(kept_edges.shape[1]):
        src, tgt = kept_edges[0, i], kept_edges[1, i]
        ax4.plot([coordinates[src, 0], coordinates[tgt, 0]], 
                [coordinates[src, 1], coordinates[tgt, 1]], 
                'gray', alpha=0.9, linewidth=1, zorder=1)
    
    ax4.set_title(f'Graph Structure (Pruned)\nThreshold: {attention_info["threshold"]:.3f} | '
                 f'Edges Kept: {np.mean(attention_info["num_edges_kept"]):.0f}')
    ax4.invert_yaxis()
    plt.colorbar(scatter, ax=ax4, label='Node Attention')
    
    plt.savefig(f'attention_viz_GAT_{class_name}.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"\nAttention Statistics:")
    print(f"  Node attention range: [{node_attention.min():.4f}, {node_attention.max():.4f}]")
    print(f"  Global attention range: [{global_attention.min():.4f}, {global_attention.max():.4f}]")
    print(f"  Edges kept: {np.mean(attention_info['num_edges_kept']):.0f} / {len(edge_mask)}")
    print(f"  Top-5 nodes by attention: {np.argsort(node_attention)[-5:][::-1]}")

print("Visualization function defined.")
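
The two quantities drawn in the bottom panel, the kept edges and the per-node attention, can be sketched on a three-node toy graph. This is an assumption-laden illustration rather than the `GraphAttentionLayer` logic: it keeps edges whose attention is strictly above the threshold, and it scores each node as the mean attention of all edges touching it, before pruning.

```python
def prune_and_score(edges, edge_attention, threshold, num_nodes):
    """Return kept edges, the boolean mask, and mean incident-edge attention per node."""
    mask = [att > threshold for att in edge_attention]
    kept = [edge for edge, keep in zip(edges, mask) if keep]
    sums = [0.0] * num_nodes
    counts = [0] * num_nodes
    for (src, tgt), att in zip(edges, edge_attention):
        for node in (src, tgt):
            sums[node] += att
            counts[node] += 1
    node_att = [s / c if c else 0.0 for s, c in zip(sums, counts)]
    return kept, mask, node_att

edges = [(0, 1), (0, 2), (1, 2)]
edge_attention = [0.9, 0.2, 0.6]
kept, mask, node_att = prune_and_score(edges, edge_attention, threshold=0.5, num_nodes=3)
```

With a threshold of 0.5 the weak edge between nodes 0 and 2 is dropped, and node 1, which sits on both strong edges, receives the highest score.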

**Cell 39 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
# Visualize samples from each class
for class_name in CLASS_NAMES:
    class_dir = os.path.join(TRAIN_DIR, class_name)
    if os.path.exists(class_dir):
        # Same extensions as HistoGraphDataset; sorted so the chosen sample is stable
        images = sorted(f for f in os.listdir(class_dir) if f.endswith(('.tif', '.png', '.jpg', '.jpeg')))
        if images:
            sample_image = os.path.join(class_dir, images[0])
            print(f"\n{'='*60}")
            print(f"Visualizing: {class_name}")
            print(f"{'='*60}")
            visualize_graph_attention(model, sample_image, class_name, device)

**Cell 40 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## 13. Save Final Results

**Cell 41 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
import json

# Save training history
with open('checkpoints/GAT_history.json', 'w') as f:
    json.dump(history, f, indent=4)

print("Training history saved to checkpoints/GAT_history.json")

# Print final summary
print("\n" + "="*60)
print("FINAL RESULTS - GAT with Attention-based Edge Pruning")
print("="*60)
print(f"Best Validation Accuracy: {best_val_acc:.4f}")
if history['threshold']:  # guard: history can be empty if no epochs ran in this session
    print(f"Final Attention Threshold: {history['threshold'][-1]:.4f}")
    print(f"Final Edges Kept (avg): {history['avg_edges_kept'][-1]:.1f}")
print("\nModel Architecture:")
print(f"  - Input features: 768 (ViT)")
print(f"  - Hidden dimension: {HIDDEN_DIM}")
print(f"  - Number of heads: {NUM_HEADS}")
print(f"  - Trainable parameters: {trainable_params:,}")
print("="*60)

**Cell 42 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## Summary

### Key Features Implemented:

1. **Complete Graph Initialization**
   - All patches (nodes) connected to each other
   - Edge weights = 1 / Euclidean distance (spatial proximity)

2. **Graph Attention Mechanism**
   - Multi-head attention (4 heads)
   - Edge-level attention scores
   - Node attention = average of connected edge attentions

3. **Dynamic Edge Pruning**
   - Learnable attention threshold
   - Edges below threshold are removed
   - Graph can become sparser or denser as the threshold changes during training

4. **Two-layer GAT**
   - Layer 1: 768 → 256 per head (4 heads concatenated: 1024)
   - Layer 2: 1024 → 128 per head (4 heads concatenated: 512)
   - Global attention pooling for classification

5. **Interpretability**
   - Visualize node attention scores
   - Visualize graph structure evolution
   - Track edge pruning over epochs

### Advantages over MIL:
- **Spatial relationships**: Edges encode patch proximity
- **Message passing**: Information flows between neighboring patches
- **Adaptive structure**: Graph topology learned during training
- **Richer representations**: Multi-hop reasoning through graph layers
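
The complete-graph initialization listed under Key Features (every patch connected to every other, edge weight equal to inverse Euclidean distance) can be sketched as follows. The function below is an illustration, not the notebook's `build_complete_graph`: excluding self-loops and adding a small `eps` to the distance are assumptions of this sketch.

```python
import math

def complete_graph(coords, eps=1e-8):
    """Directed edges between all distinct node pairs, weighted by 1 / distance."""
    edges, weights = [], []
    for i, (xi, yi) in enumerate(coords):
        for j, (xj, yj) in enumerate(coords):
            if i == j:
                continue  # assumption: no self-loops
            edges.append((i, j))
            weights.append(1.0 / (math.hypot(xi - xj, yi - yj) + eps))
    return edges, weights

# Three patch centres on a line, 5 pixels apart.
coords = [(0, 0), (3, 4), (6, 8)]
graph_edges, graph_weights = complete_graph(coords)
```

For N patches this yields N·(N−1) directed edges, so the edge count grows quadratically with the number of patches per image. That growth is the practical motivation for pruning.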

**Cell 43 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## 14. ViT + GAT Attention Fusion for Enhanced Attribution

This section extracts:
- **ViT Internal Attention**: Per-patch attention weights from transformer blocks
- **GAT Node Attention**: Graph-based spatial attention
- **Fused Attention**: Combined multi-level attribution map

**Cell 44 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
class ViTFeatureExtractorWithAttention(nn.Module):
    """
    ViT Feature Extractor that also returns internal attention weights.
    This allows us to see which parts of each patch the ViT focuses on.
    """
    
    def __init__(self, model_name='google/vit-base-patch16-224'):
        super().__init__()
        self.processor = ViTImageProcessor.from_pretrained(model_name)
        self.vit = ViTModel.from_pretrained(model_name, output_attentions=True)
        
        # Freeze all parameters
        for param in self.vit.parameters():
            param.requires_grad = False
        
        self.vit.eval()
    
    def forward(self, images, return_attention=False):
        """
        Extract features and optionally attention weights.
        
        Returns:
            features: (N, 768) CLS token features
            attentions: Tuple with one tensor per layer, each of shape
                        (N, num_heads, seq_len, seq_len), where seq_len = 1 + num_patches
                        (returned only if return_attention=True)
        """
        inputs = self.processor(images=images, return_tensors="pt")
        inputs = {k: v.to(next(self.vit.parameters()).device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = self.vit(**inputs)
            features = outputs.last_hidden_state[:, 0, :]  # CLS token
            
            if return_attention:
                # Get attention from all layers: a tuple of num_layers tensors,
                # each of shape (batch_size, num_heads, seq_len, seq_len)
                attentions = outputs.attentions
                return features, attentions
            
            return features
    
    def get_patch_attention_scores(self, images):
        """
        Get aggregated attention scores for each whole-slide patch.
        
        Args:
            images: List of PIL images (the whole-slide patches)
        
        Returns:
            patch_attention: (N,) average attention score per whole-slide patch
                            where N is the number of whole-slide patches
        """
        _, attentions = self.forward(images, return_attention=True)
        
        # Aggregate attention across layers and heads for EACH whole-slide patch
        # attentions is a tuple of (num_layers,) where each element has shape:
        # (batch_size, num_heads, seq_len, seq_len)
        
        patch_attention_scores = []
        
        # Process each whole-slide patch (each element in the batch)
        num_whole_patches = len(images)
        
        for patch_idx in range(num_whole_patches):
            all_layer_attentions = []
            
            for layer_attention in attentions:
                # layer_attention shape: (batch_size, num_heads, seq_len, seq_len)
                # Get attention for this specific whole-slide patch
                patch_att = layer_attention[patch_idx]  # (num_heads, seq_len, seq_len)
                
                # Average across heads
                head_avg = patch_att.mean(dim=0)  # (seq_len, seq_len)
                
                # Get attention from CLS token (position 0) to all internal patches
                # and sum/average to get overall importance of this whole-slide patch
                cls_to_patches = head_avg[0, 1:]  # (num_internal_patches,)
                
                # Aggregate: mean attention from CLS to all internal patches.
                # Each attention row sums to 1, so this equals
                # (1 - CLS self-attention) / num_internal_patches.
                patch_importance = cls_to_patches.mean()
                
                all_layer_attentions.append(patch_importance)
            
            # Average across all layers for this whole-slide patch
            avg_attention = torch.stack(all_layer_attentions).mean()
            patch_attention_scores.append(avg_attention.item())
        
        return np.array(patch_attention_scores)

print("ViT Feature Extractor with Attention defined.")
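
The aggregation in `get_patch_attention_scores` reduces each layer to the mean of the CLS row (excluding the CLS column) and then averages over layers. The sketch below reproduces that reduction on tiny hand-written matrices; the function name and the numbers are illustrative, and the matrices stand for attention that has already been averaged over heads.

```python
def cls_patch_importance(layer_attentions):
    """Mean CLS-to-patch attention per layer, averaged over layers."""
    per_layer = []
    for att in layer_attentions:
        cls_to_patches = att[0][1:]  # row 0 is the CLS query; drop the CLS column
        per_layer.append(sum(cls_to_patches) / len(cls_to_patches))
    return sum(per_layer) / len(per_layer)

uniform_row = [0.25, 0.25, 0.25, 0.25]
# seq_len = 4: one CLS token plus three internal patches; every row sums to 1.
layer_a = [[0.4, 0.3, 0.2, 0.1], uniform_row, uniform_row, uniform_row]
layer_b = [[0.1, 0.3, 0.3, 0.3], uniform_row, uniform_row, uniform_row]
score = cls_patch_importance([layer_a, layer_b])
```

Layer A contributes (1 − 0.4) / 3 = 0.2 and layer B contributes (1 − 0.1) / 3 = 0.3, so the score depends only on how much attention the CLS token keeps for itself, not on how the remainder is distributed among internal patches.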

**Cell 45 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
def fuse_vit_gat_attention(vit_attention, gat_attention, fusion_method='multiplicative', alpha=0.5):
    """
    Fuse ViT internal attention with GAT node attention.
    
    Args:
        vit_attention: (N,) ViT patch attention scores
        gat_attention: (N,) GAT node attention scores
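        fusion_method: 'additive' (weighted sum), 'multiplicative' (element-wise
                       product, re-normalized), or 'weighted' (weighted geometric mean)
        alpha: Weight on the ViT term for 'additive' and 'weighted' fusion (0-1);
               ignored by 'multiplicative'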
    
    Returns:
        fused_attention: (N,) Combined attention scores
    """
    # Normalize both attention maps to [0, 1]
    vit_norm = (vit_attention - vit_attention.min()) / (vit_attention.max() - vit_attention.min() + 1e-8)
    gat_norm = (gat_attention - gat_attention.min()) / (gat_attention.max() - gat_attention.min() + 1e-8)
    
    if fusion_method == 'additive':
        # Weighted sum
        fused = alpha * vit_norm + (1 - alpha) * gat_norm
    
    elif fusion_method == 'multiplicative':
        # Element-wise multiplication (highlights regions with high attention in BOTH)
        fused = vit_norm * gat_norm
        # Normalize again
        fused = (fused - fused.min()) / (fused.max() - fused.min() + 1e-8)
    
    elif fusion_method == 'weighted':
        # Weighted geometric mean
        fused = (vit_norm ** alpha) * (gat_norm ** (1 - alpha))
    
    else:
        raise ValueError(f"Unknown fusion method: {fusion_method}")
    
    return fused

print("Attention fusion function defined.")
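
To see how the three fusion modes differ, the following standalone sketch applies the same formulas to four made-up patch scores. The helper `minmax` and the input values are hypothetical; the formulas mirror the cell above.

```python
import numpy as np

def minmax(x, eps=1e-8):
    """Min-max normalize to [0, 1], as in the fusion function above."""
    x = np.asarray(x, dtype=float)
    return (x - x.min()) / (x.max() - x.min() + eps)

vit = np.array([0.1, 0.4, 0.9, 0.2])   # hypothetical ViT scores
gat = np.array([0.8, 0.5, 0.7, 0.1])   # hypothetical GAT scores
alpha = 0.5

v, g = minmax(vit), minmax(gat)
additive = alpha * v + (1 - alpha) * g          # weighted sum
multiplicative = minmax(v * g)                  # product, re-normalized
geometric = (v ** alpha) * (g ** (1 - alpha))   # the 'weighted' mode

for name, fused in [("additive", additive),
                    ("multiplicative", multiplicative),
                    ("weighted", geometric)]:
    print(f"{name:<15} top patch: {int(np.argmax(fused))}  scores: {np.round(fused, 3)}")
```

In this example patch 0 has the highest GAT score but the lowest ViT score, so the product and the geometric mean assign it zero, while the weighted sum still gives it about 0.5. This is why the multiplicative mode only highlights patches that both models rate highly.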

**Cell 46 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

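Explanation: The following code cell defines `visualize_fused_attention`, which extracts ViT and GAT attention for one image, fuses them, plots the per-patch maps, saves the figure, and prints summary statistics.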

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
def visualize_fused_attention(model, vit_with_attention, image_path, class_name, device,
                              fusion_methods=('additive', 'multiplicative', 'weighted')):
    """
    Extract and visualize ViT attention, GAT attention, and their fusion.
    
    This creates a comprehensive attribution map showing:
    1. ViT's internal patch attention
    2. GAT's graph-based node attention
    3. Multiple fusion strategies
    """
    model.eval()
    vit_with_attention.eval()
    
    # Load and extract patches
    patches, coordinates = extract_patches_with_coords(image_path)
    
    print(f"\n{'='*60}")
    print(f"Processing: {class_name}")
    print(f"Number of patches: {len(patches)}")
    print(f"{'='*60}\n")
    
    # ========================================
    # 1. Get ViT Internal Attention
    # ========================================
    print("Extracting ViT internal attention...")
    vit_patch_attention = vit_with_attention.get_patch_attention_scores(patches)
    print(f"  ViT attention shape: {vit_patch_attention.shape}")
    print(f"  ViT attention range: [{vit_patch_attention.min():.4f}, {vit_patch_attention.max():.4f}]")
    
    # ========================================
    # 2. Get GAT Node Attention
    # ========================================
    print("\nExtracting GAT node attention...")
    with torch.no_grad():
        logits, attention_info = model(patches, coordinates, return_attention=True)
    
    gat_node_attention = attention_info['node_attention']
    gat_global_attention = attention_info['global_attention']
    
    print(f"  GAT node attention shape: {gat_node_attention.shape}")
    print(f"  GAT node attention range: [{gat_node_attention.min():.4f}, {gat_node_attention.max():.4f}]")
    
    pred_class = torch.argmax(logits).item()
    confidence = F.softmax(logits, dim=0)[pred_class].item()
    
    print(f"\nPrediction: {CLASS_NAMES[pred_class]} (confidence: {confidence:.2%})")
    
    # ========================================
    # 3. Fuse Attentions
    # ========================================
    print("\nFusing attention maps...")
    fused_attentions = {}
    for method in fusion_methods:
        fused = fuse_vit_gat_attention(vit_patch_attention, gat_node_attention, 
                                       fusion_method=method, alpha=0.5)
        fused_attentions[method] = fused
        print(f"  {method.capitalize()} fusion range: [{fused.min():.4f}, {fused.max():.4f}]")
    
    # ========================================
    # 4. Create Comprehensive Visualization
    # ========================================
    print("\nGenerating visualizations...")
    
    # Load original image
    img = Image.open(image_path).convert('RGB')
    img_array = np.array(img)
    
    # Create figure with multiple subplots
    num_fusions = len(fusion_methods)
    fig = plt.figure(figsize=(20, 5 * (num_fusions + 2)))
    gs = fig.add_gridspec(num_fusions + 2, 4, hspace=0.4, wspace=0.3)
    
    # Row 0: Original image and ViT attention
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.imshow(img_array)
    ax1.set_title(f'Original Image\nTrue: {class_name} | Pred: {CLASS_NAMES[pred_class]} ({confidence:.2%})', 
                  fontsize=12, fontweight='bold')
    ax1.axis('off')
    
    ax2 = fig.add_subplot(gs[0, 1])
    scatter = ax2.scatter(coordinates[:, 0], coordinates[:, 1], 
                         c=vit_patch_attention, cmap='Reds', s=150, alpha=0.8,
                         edgecolors='black', linewidths=1)
    ax2.set_title('ViT Internal Attention\n(CLS token → patches)', fontsize=12, fontweight='bold')
    ax2.invert_yaxis()
    plt.colorbar(scatter, ax=ax2, fraction=0.046, pad=0.04)
    
    # Overlay heatmap on original image (ViT)
    ax3 = fig.add_subplot(gs[0, 2:])
    ax3.imshow(img_array, alpha=0.5)
    scatter = ax3.scatter(coordinates[:, 0], coordinates[:, 1], 
                         c=vit_patch_attention, cmap='Reds', s=200, alpha=0.9)
    ax3.set_title('ViT Attention Overlay', fontsize=12, fontweight='bold')
    ax3.axis('off')
    plt.colorbar(scatter, ax=ax3, fraction=0.046, pad=0.04)
    
    # Row 1: GAT attention
    ax4 = fig.add_subplot(gs[1, 0])
    scatter = ax4.scatter(coordinates[:, 0], coordinates[:, 1], 
                         c=gat_node_attention, cmap='Blues', s=150, alpha=0.8,
                         edgecolors='black', linewidths=1)
    ax4.set_title('GAT Node Attention\n(Avg of edge attentions)', fontsize=12, fontweight='bold')
    ax4.invert_yaxis()
    plt.colorbar(scatter, ax=ax4, fraction=0.046, pad=0.04)
    
    ax5 = fig.add_subplot(gs[1, 1])
    scatter = ax5.scatter(coordinates[:, 0], coordinates[:, 1], 
                         c=gat_global_attention, cmap='Greens', s=150, alpha=0.8,
                         edgecolors='black', linewidths=1)
    ax5.set_title('GAT Global Pooling\n(Final classification)', fontsize=12, fontweight='bold')
    ax5.invert_yaxis()
    plt.colorbar(scatter, ax=ax5, fraction=0.046, pad=0.04)
    
    # Overlay heatmap on original image (GAT)
    ax6 = fig.add_subplot(gs[1, 2:])
    ax6.imshow(img_array, alpha=0.5)
    scatter = ax6.scatter(coordinates[:, 0], coordinates[:, 1], 
                         c=gat_node_attention, cmap='Blues', s=200, alpha=0.9)
    ax6.set_title('GAT Node Attention Overlay', fontsize=12, fontweight='bold')
    ax6.axis('off')
    plt.colorbar(scatter, ax=ax6, fraction=0.046, pad=0.04)
    
    # Rows 2+: Fused attentions
    for idx, (method, fused_att) in enumerate(fused_attentions.items()):
        row = idx + 2
        
        # Scatter plot
        ax_scatter = fig.add_subplot(gs[row, 0])
        scatter = ax_scatter.scatter(coordinates[:, 0], coordinates[:, 1], 
                                    c=fused_att, cmap='hot', s=150, alpha=0.8,
                                    edgecolors='black', linewidths=1)
        ax_scatter.set_title(f'Fused Attention ({method.capitalize()})\n(ViT ⊗ GAT)', 
                            fontsize=12, fontweight='bold')
        ax_scatter.invert_yaxis()
        plt.colorbar(scatter, ax=ax_scatter, fraction=0.046, pad=0.04)
        
        # Heatmap on original
        ax_overlay = fig.add_subplot(gs[row, 1:3])
        ax_overlay.imshow(img_array, alpha=0.4)
        scatter = ax_overlay.scatter(coordinates[:, 0], coordinates[:, 1], 
                                    c=fused_att, cmap='hot', s=250, alpha=0.95)
        ax_overlay.set_title(f'Fused Attribution Map ({method.capitalize()})', 
                            fontsize=12, fontweight='bold')
        ax_overlay.axis('off')
        plt.colorbar(scatter, ax=ax_overlay, fraction=0.046, pad=0.04)
        
        # Top-k patches
        ax_top = fig.add_subplot(gs[row, 3])
        top_k = min(10, len(fused_att))  # guard against images with fewer than 10 patches
        top_indices = np.argsort(fused_att)[-top_k:][::-1]
        colors = plt.cm.hot(np.linspace(0.9, 0.3, top_k))
        
        ax_top.barh(range(top_k), fused_att[top_indices], color=colors)
        ax_top.set_yticks(range(top_k))
        ax_top.set_yticklabels([f"Patch {i}" for i in top_indices])
        ax_top.set_xlabel('Attention Score')
        ax_top.set_title(f'Top-{top_k} Important Patches', fontsize=10, fontweight='bold')
        ax_top.invert_yaxis()
        ax_top.grid(axis='x', alpha=0.3)
    
    plt.suptitle(f'Multi-Level Attention Analysis: {class_name}', 
                 fontsize=16, fontweight='bold', y=0.995)
    
    save_path = f'attention_fusion_{class_name}.png'
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"\n✓ Visualization saved: {save_path}")
    
    # ========================================
    # 5. Print Detailed Statistics
    # ========================================
    print(f"\n{'='*60}")
    print("ATTENTION STATISTICS")
    print(f"{'='*60}")
    
    print(f"\n1. ViT Internal Attention:")
    print(f"   Mean: {vit_patch_attention.mean():.4f}")
    print(f"   Std:  {vit_patch_attention.std():.4f}")
    print(f"   Top-5 patches: {np.argsort(vit_patch_attention)[-5:][::-1]}")
    
    print(f"\n2. GAT Node Attention:")
    print(f"   Mean: {gat_node_attention.mean():.4f}")
    print(f"   Std:  {gat_node_attention.std():.4f}")
    print(f"   Top-5 patches: {np.argsort(gat_node_attention)[-5:][::-1]}")
    
    print(f"\n3. Fused Attentions:")
    for method, fused_att in fused_attentions.items():
        print(f"   {method.capitalize()}:")
        print(f"     Mean: {fused_att.mean():.4f}")
        print(f"     Std:  {fused_att.std():.4f}")
        print(f"     Top-5: {np.argsort(fused_att)[-5:][::-1]}")
    
    print(f"\n4. Graph Structure:")
    print(f"   Threshold: {attention_info['threshold']:.4f}")
    print(f"   Edges kept: {np.mean(attention_info['num_edges_kept']):.0f}")
    print(f"   Sparsity: {(1 - np.mean(attention_info['num_edges_kept']) / len(attention_info['edge_masks_layer2'][0])):.2%}")
    
    print(f"{'='*60}\n")
    
    return {
        'vit_attention': vit_patch_attention,
        'gat_attention': gat_node_attention,
        'gat_global': gat_global_attention,
        'fused_attentions': fused_attentions,
        'prediction': CLASS_NAMES[pred_class],
        'confidence': confidence,
        'coordinates': coordinates
    }

print("Fused attention visualization function defined.")
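
The top-k ranking used for the bar charts and the "Top-5" printouts relies on the `np.argsort(...)[-k:][::-1]` idiom. A minimal sketch of that idiom, with a clamp for inputs that have fewer than `k` entries, is shown below; `top_k_patches` and the sample scores are hypothetical names and values.

```python
import numpy as np

def top_k_patches(scores, k=10):
    """Return indices of the k highest scores, highest first."""
    scores = np.asarray(scores)
    k = min(k, scores.size)               # fewer than k patches: return them all
    return np.argsort(scores)[-k:][::-1]  # ascending sort, take tail, reverse

scores = np.array([0.05, 0.90, 0.30, 0.75, 0.10])
top = top_k_patches(scores, k=3)
print(top, scores[top])
```

Without the clamp, slicing silently returns fewer than `k` indices, which then mismatches any plotting call that assumes exactly `k` bars.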

**Cell 47 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell instantiates `ViTFeatureExtractorWithAttention` and moves it to `device`.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
# Initialize ViT with attention extraction capability
print("Initializing ViT with attention extraction...")
vit_with_attention = ViTFeatureExtractorWithAttention().to(device)
print("✓ Ready to extract multi-level attention!\n")

**Cell 48 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell re-initializes the ViT attention extractor and runs `visualize_fused_attention` on one sample image per class, taken at index 10 of each class directory listing.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
# RERUN: Complete Fused Attention Extraction Pipeline
print("="*80)
print("INITIALIZING AND RUNNING FUSED ATTENTION EXTRACTION")
print("="*80 + "\n")

# Step 1: Initialize ViT with attention extraction
print("Step 1: Initializing ViT with attention extraction...")
vit_with_attention = ViTFeatureExtractorWithAttention().to(device)
print("✓ ViT initialized!\n")

# Step 2: Run fused attention extraction on all classes
print("Step 2: Extracting fused attention for all classes...")
attention_results = {}

for class_name in CLASS_NAMES:
    class_dir = os.path.join(TRAIN_DIR, class_name)
    if os.path.exists(class_dir):
        images = [f for f in os.listdir(class_dir) if f.endswith(('.tif', '.png', '.jpg'))]
        if images:
            # Take the image at index 10 as the sample (clamped to the last image for small classes)
            sample_image = os.path.join(class_dir, images[min(10, len(images) - 1)])
            
            # Extract and visualize fused attention
            result = visualize_fused_attention(
                model=model,
                vit_with_attention=vit_with_attention,
                image_path=sample_image,
                class_name=class_name,
                device=device,
                fusion_methods=['additive', 'multiplicative', 'weighted']
            )
            
            attention_results[class_name] = result

print("\n" + "="*80)
print("✓ COMPLETED: Fused attention analysis for all classes")
print("="*80)

**Cell 49 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell runs `visualize_fused_attention` on the first listed image of each class and stores the results in `attention_results`, replacing any results from an earlier run.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
# Test on one sample from each class
print("="*80)
print("TESTING FUSED ATTENTION EXTRACTION")
print("="*80)

attention_results = {}

for class_name in CLASS_NAMES:
    class_dir = os.path.join(TRAIN_DIR, class_name)
    if os.path.exists(class_dir):
        images = [f for f in os.listdir(class_dir) if f.endswith(('.tif', '.png', '.jpg'))]
        if images:
            # Take the first image as sample
            sample_image = os.path.join(class_dir, images[0])
            
            # Extract and visualize fused attention
            result = visualize_fused_attention(
                model=model,
                vit_with_attention=vit_with_attention,
                image_path=sample_image,
                class_name=class_name,
                device=device,
                fusion_methods=['additive', 'multiplicative', 'weighted']
            )
            
            attention_results[class_name] = result

print("\n" + "="*80)
print("✓ COMPLETED: Fused attention analysis for all classes")
print("="*80)
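
The sample-selection loops above index directly into `os.listdir`, whose ordering is not guaranteed. One possible way to make the choice reproducible is sketched below; `pick_sample` and `IMAGE_EXTENSIONS` are hypothetical helpers, not part of the notebook, and sorting would change which image is selected compared with the cells above.

```python
import os
import tempfile

IMAGE_EXTENSIONS = ('.tif', '.png', '.jpg')

def pick_sample(class_dir, index=0):
    """Return the path of the index-th image in sorted order, or None.

    Sorting makes the choice reproducible, since os.listdir returns
    entries in arbitrary order. The index is clamped to the last image.
    """
    images = sorted(f for f in os.listdir(class_dir)
                    if f.lower().endswith(IMAGE_EXTENSIONS))
    if not images:
        return None
    return os.path.join(class_dir, images[min(index, len(images) - 1)])

# Demonstrate on a throwaway directory with two images and one other file.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("b.png", "a.tif", "notes.txt"):
        open(os.path.join(tmp, name), "w").close()
    first = os.path.basename(pick_sample(tmp, index=0))
    clamped = os.path.basename(pick_sample(tmp, index=10))
    print(first, clamped)
```

Storing the returned path alongside each result would also let later cells reuse the exact image instead of re-deriving it from the directory listing.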

**Cell 50 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell compares the mean and standard deviation of ViT, GAT, and multiplicative-fused attention across classes as bar charts and prints a summary table.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
# Compare attention patterns across classes
print("\n" + "="*80)
print("CROSS-CLASS ATTENTION ANALYSIS")
print("="*80 + "\n")

comparison_data = {
    'class': [],
    'vit_mean': [],
    'vit_std': [],
    'gat_mean': [],
    'gat_std': [],
    'fusion_mult_mean': [],
    'fusion_mult_std': []
}

for class_name, result in attention_results.items():
    comparison_data['class'].append(class_name)
    comparison_data['vit_mean'].append(result['vit_attention'].mean())
    comparison_data['vit_std'].append(result['vit_attention'].std())
    comparison_data['gat_mean'].append(result['gat_attention'].mean())
    comparison_data['gat_std'].append(result['gat_attention'].std())
    comparison_data['fusion_mult_mean'].append(result['fused_attentions']['multiplicative'].mean())
    comparison_data['fusion_mult_std'].append(result['fused_attentions']['multiplicative'].std())

# Create comparison plot
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

x_pos = np.arange(len(comparison_data['class']))
width = 0.35

# ViT Attention
axes[0].bar(x_pos, comparison_data['vit_mean'], width, 
           yerr=comparison_data['vit_std'], capsize=5, 
           color='coral', alpha=0.8, edgecolor='black')
axes[0].set_xlabel('Class', fontweight='bold')
axes[0].set_ylabel('Mean Attention', fontweight='bold')
axes[0].set_title('ViT Internal Attention by Class', fontweight='bold')
axes[0].set_xticks(x_pos)
axes[0].set_xticklabels(comparison_data['class'], rotation=45)
axes[0].grid(axis='y', alpha=0.3)

# GAT Attention
axes[1].bar(x_pos, comparison_data['gat_mean'], width, 
           yerr=comparison_data['gat_std'], capsize=5, 
           color='skyblue', alpha=0.8, edgecolor='black')
axes[1].set_xlabel('Class', fontweight='bold')
axes[1].set_ylabel('Mean Attention', fontweight='bold')
axes[1].set_title('GAT Node Attention by Class', fontweight='bold')
axes[1].set_xticks(x_pos)
axes[1].set_xticklabels(comparison_data['class'], rotation=45)
axes[1].grid(axis='y', alpha=0.3)

# Fused Attention
axes[2].bar(x_pos, comparison_data['fusion_mult_mean'], width, 
           yerr=comparison_data['fusion_mult_std'], capsize=5, 
           color='lightgreen', alpha=0.8, edgecolor='black')
axes[2].set_xlabel('Class', fontweight='bold')
axes[2].set_ylabel('Mean Attention', fontweight='bold')
axes[2].set_title('Fused Attention (Multiplicative) by Class', fontweight='bold')
axes[2].set_xticks(x_pos)
axes[2].set_xticklabels(comparison_data['class'], rotation=45)
axes[2].grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig('attention_comparison_across_classes.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n✓ Cross-class comparison saved: attention_comparison_across_classes.png\n")

# Print summary table
print("Summary Table:")
print("-" * 80)
print(f"{'Class':<12} | {'ViT Mean':<10} | {'GAT Mean':<10} | {'Fused Mean':<10} | {'Prediction':<10}")
print("-" * 80)
for class_name, result in attention_results.items():
    print(f"{class_name:<12} | {result['vit_attention'].mean():<10.4f} | "
          f"{result['gat_attention'].mean():<10.4f} | "
          f"{result['fused_attentions']['multiplicative'].mean():<10.4f} | "
          f"{result['prediction']:<10}")
print("-" * 80)

**Cell 51 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following markdown cell introduces section 15, on heatmap generation.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

## 15. Clinical-Grade Heatmap Generation

Generate smooth, interpolated heatmaps overlaid on the original image, intended to support clinical interpretation.

**Cell 52 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

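Explanation: The following code cell defines `generate_smooth_heatmap`, which interpolates per-patch attention scores onto the pixel grid with `scipy.interpolate.griddata` and overlays the result on the original image.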

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
from scipy.interpolate import griddata
from matplotlib.colors import LinearSegmentedColormap

def generate_smooth_heatmap(image_path, coordinates, attention_scores, 
                            output_path=None, alpha=0.6, method='cubic'):
    """
    Generate smooth, interpolated heatmap overlay on original image.
    
    Args:
        image_path: Path to original image
        coordinates: (N, 2) patch center coordinates
        attention_scores: (N,) attention values
        output_path: Where to save the result
        alpha: Transparency of heatmap overlay
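        method: Interpolation method ('linear', 'cubic', 'nearest')

    Returns:
        heatmap: (height, width) array of interpolated scores, min-max
            normalized to [0, 1]; pixels outside the convex hull of the
            patch centers are filled with 0 before normalization
    """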
    # Load image
    img = Image.open(image_path).convert('RGB')
    img_array = np.array(img)
    height, width = img_array.shape[:2]
    
    # Create grid for interpolation
    grid_x, grid_y = np.mgrid[0:width:1, 0:height:1]
    
    # Interpolate attention scores to create smooth heatmap
    heatmap = griddata(
        points=coordinates,
        values=attention_scores,
        xi=(grid_x, grid_y),
        method=method,
        fill_value=0
    ).T
    
    # Normalize heatmap
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
    
    # Create figure
    fig, axes = plt.subplots(1, 3, figsize=(20, 7))
    
    # Original image
    axes[0].imshow(img_array)
    axes[0].set_title('Original Image', fontsize=14, fontweight='bold')
    axes[0].axis('off')
    
    # Heatmap only
    im = axes[1].imshow(heatmap, cmap='hot', interpolation='bilinear')
    axes[1].set_title('Attention Heatmap', fontsize=14, fontweight='bold')
    axes[1].axis('off')
    plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)
    
    # Overlay
    axes[2].imshow(img_array, alpha=1.0)
    im2 = axes[2].imshow(heatmap, cmap='hot', alpha=alpha, interpolation='bilinear')
    axes[2].set_title('Heatmap Overlay (Attribution Map)', fontsize=14, fontweight='bold')
    axes[2].axis('off')
    plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)
    
    plt.tight_layout()
    
    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved: {output_path}")
    
    plt.show()
    
    return heatmap

print("Smooth heatmap generator defined.")
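
The interpolation step can be tried on a tiny grid to confirm the array orientation and the effect of `fill_value`. The sketch below assumes four hypothetical patch centers inside a 6x4 pixel image and uses linear interpolation; all sizes and scores are made up for illustration.

```python
import numpy as np
from scipy.interpolate import griddata

width, height = 6, 4   # hypothetical image size in pixels
# Hypothetical patch centers as (x, y), strictly inside the image border.
points = np.array([[1, 1], [4, 1], [1, 2], [4, 2]], dtype=float)
values = np.array([1.0, 2.0, 2.0, 3.0])   # made-up attention scores

# np.mgrid yields arrays shaped (width, height), so the result is
# transposed to the (height, width) layout that imshow expects.
grid_x, grid_y = np.mgrid[0:width, 0:height]
heatmap = griddata(points, values, (grid_x, grid_y),
                   method='linear', fill_value=0).T

# Same min-max normalization as in generate_smooth_heatmap.
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
print(heatmap.shape)
print(np.round(heatmap, 2))
```

Pixels outside the convex hull of the patch centers receive `fill_value`, so the image border is set to 0 and anchors the bottom of the normalized scale. If the lowest-scoring patch should map to 0 instead, one option is to normalize the scores before interpolation.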

**Cell 53 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell generates smooth heatmaps for each class in `attention_results` from the ViT, GAT, multiplicative-fused, and additive-fused attention scores.

Use: Run this cell in sequence; intended to be executed after earlier setup cells.

In [None]:
# Generate smooth heatmaps for all classes and fusion methods
print("="*80)
print("GENERATING CLINICAL-GRADE HEATMAPS")
print("="*80 + "\n")

for class_name, result in attention_results.items():
    print(f"\nProcessing {class_name}...")
    
    # Re-derive the sample image path; index 0 must match the image used to build attention_results
    class_dir = os.path.join(TRAIN_DIR, class_name)
    images = [f for f in os.listdir(class_dir) if f.endswith(('.tif', '.png', '.jpg'))]
    sample_image = os.path.join(class_dir, images[0])
    
    coordinates = result['coordinates']
    
    # Generate heatmaps for different attention sources
    print(f"  1. ViT Internal Attention...")
    generate_smooth_heatmap(
        sample_image, coordinates, result['vit_attention'],
        output_path=f'heatmap_vit_{class_name}.png',
        alpha=0.6
    )
    
    print(f"  2. GAT Node Attention...")
    generate_smooth_heatmap(
        sample_image, coordinates, result['gat_attention'],
        output_path=f'heatmap_gat_{class_name}.png',
        alpha=0.6
    )
    
    print(f"  3. Fused Attention (Multiplicative)...")
    generate_smooth_heatmap(
        sample_image, coordinates, result['fused_attentions']['multiplicative'],
        output_path=f'heatmap_fused_multiplicative_{class_name}.png',
        alpha=0.7
    )
    
    print(f"  4. Fused Attention (Additive)...")
    generate_smooth_heatmap(
        sample_image, coordinates, result['fused_attentions']['additive'],
        output_path=f'heatmap_fused_additive_{class_name}.png',
        alpha=0.7
    )

print("\n" + "="*80)
print("✓ ALL HEATMAPS GENERATED")
print("="*80)

**Cell 55 — Purpose & Usage**

Configuration: Uses prior definitions and variables in the notebook.

Explanation: The following code cell performs the next logical step in the pipeline (see its code/content).

Use: Run this cell in sequence; intended to be executed after earlier setup cells.