# MCIANet Attention Map Visualisation

Visualises the internal attention representations stored by `MCIANet` during a forward pass:

| Attribute | Shape | Description |
|---|---|---|
| `backbone.last_spatial_masks` | `[B, C, 1, 12, 12]` | Per-channel spatial focus maps |
| `backbone.attn1.last_attn` | `[C, C]` | Cross-channel attention after stage 1 (6×6) |
| `backbone.attn2.last_attn` | `[C, C]` | Cross-channel attention after stage 2 (3×3) |

**Requires**: a trained `CODEX_cHL_ATT_MASK_VP` checkpoint.

## 1  Setup

In [None]:
%load_ext autoreload
%autoreload 2

import sys
import h5py
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from pathlib import Path

# Run parameters -- adjust these to the local environment before executing.

# Number of sample patches to load into the visualisation batch.
N_PATCHES = 6

# Inference device: CUDA when torch reports it as available, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Experiment config and trained weights.
CFG_PATH = '/home/simon_g/isilon_images_mnt/10_MetaSystems/MetaSystemsData/_simon/src/MCA/configs/_experiments_/CODEX_cHL_ATT_MASK_VP.py'
# `iter_XXXX` is a placeholder: point this at an actual checkpoint file.
CKPT_PATH = '/home/simon_g/isilon_images_mnt/10_MetaSystems/MetaSystemsData/_simon/src/MCA/z_RUNS/CODEX_cHL_ATT_MASK_VP/iter_XXXX.pth'

# Dataset files: the HDF5 container, the list of markers to use, and the
# list of validation indices.
H5_PATH = '/home/simon_g/isilon_images_mnt/10_MetaSystems/MetaSystemsData/_simon/data/MCI_data/h5_files/CODEX_cHL/CODEX_cHL.h5'
MARKERS_PATH = '/home/simon_g/isilon_images_mnt/10_MetaSystems/MetaSystemsData/_simon/data/MCI_data/h5_files/CODEX_cHL/used_markers.txt'
VAL_IDX_PATH = '/home/simon_g/isilon_images_mnt/10_MetaSystems/MetaSystemsData/_simon/data/MCI_data/h5_files/CODEX_cHL/val.txt'

print(f'Using device: {DEVICE}')

In [None]:
# Read every marker name stored in the HDF5 file.
with h5py.File(H5_PATH, 'r') as f:
    all_markers = f['marker_names'][:].astype(str)

# Restrict to the markers listed in MARKERS_PATH (one name per line) when that
# file exists; otherwise fall back to everything the HDF5 file provides.
used_markers_path = Path(MARKERS_PATH)
if used_markers_path.exists():
    # Parse the file line by line instead of np.loadtxt(..., delimiter='\n'):
    # recent NumPy releases reject a newline delimiter, and loadtxt returns a
    # 0-d array (not iterable) when the file holds a single name.
    used_marker_names = []
    for raw_line in used_markers_path.read_text(encoding='utf-8').splitlines():
        # Mirror loadtxt's defaults: drop '#' comments and surrounding
        # whitespace, and skip lines that end up empty.
        name = raw_line.split('#', 1)[0].strip()
        if name:
            used_marker_names.append(name)
    # Keep only markers that appear in the HDF5, preserving the file's order.
    marker_names = [m for m in used_marker_names if m in all_markers]
else:
    marker_names = list(all_markers)

print(f'Markers ({len(marker_names)}):', marker_names)

## 2  Load model

In [None]:
from mmengine import Config
from mmengine.registry import MODELS, init_default_scope
from mmengine.runner import load_checkpoint

# Parse the experiment config first; per the original note this is what
# triggers the config's custom_imports so that MCIANet, MVSimCLR, etc. get
# registered before the model is built.
cfg = Config.fromfile(CFG_PATH)
scope_name = cfg.get('default_scope', 'mmselfsup')   # default when the config sets none
init_default_scope(scope_name)

# Build the model from the config, restore the weights on CPU, and only then
# move it to the target device.
model = MODELS.build(cfg.model)
load_checkpoint(model, CKPT_PATH, map_location='cpu')
model = model.to(DEVICE)
model.eval()   # evaluation mode for the visualisation forward pass

# The attention attributes inspected later are read from the backbone.
backbone = model.backbone
backbone_cls_name = type(backbone).__name__
print('Model loaded. Backbone type:', backbone_cls_name)

## 3  Load sample patches

In [None]:
from mmengine.registry import DATASETS
from copy import deepcopy

# Build a minimal validation dataset with a pipeline declared inline here.
# Only `cutter_size` and `patch_size` are read from the config, with local
# defaults (24 and 32) when the config does not define them.
# The pipeline emits a single view (n_views=[1]) with no transform attached.
val_pipeline = [
    dict(type='C_CentralCutter', size=cfg.get('cutter_size', 24)),
    dict(type='C_ToTensor'),
    dict(type='C_MultiView', n_views=[1], transforms=[None]),
    dict(type='C_PackInputs'),
]

# NOTE: `used_indicies` is the keyword spelling this call passes to the
# dataset; do not "correct" it here without changing the dataset class too.
ds = DATASETS.build(dict(
    type='MCIDataset',
    h5_filepath=H5_PATH,
    patch_size=cfg.get('patch_size', 32),
    used_markers=MARKERS_PATH,
    used_indicies=VAL_IDX_PATH,
    pipeline=val_pipeline,
    ignore_annotation=['Seg Artifact'],
))

print(f'Val dataset: {len(ds)} cells')

In [None]:
# Pick up to N_PATCHES distinct random indices (fixed seed for repeatability)
# and stack the corresponding patches into one batch tensor.
rng = np.random.default_rng(42)
n_available = len(ds)
if n_available == 0:
    raise ValueError('Validation dataset is empty; cannot sample patches.')
# Sample WITHOUT replacement so the same cell never shows up twice in the
# batch, and never ask for more samples than the dataset holds.
n_samples = min(N_PATCHES, n_available)
sample_indices = rng.choice(n_available, size=n_samples, replace=False).tolist()

patches = []
annotations = []
for idx in sample_indices:
    item = ds[idx]
    # item['inputs'] is a list of tensors (one per view); take view 0
    patches.append(item['inputs'][0])
    # Use '?' as the label when the sample carries no 'annotation' entry.
    annotations.append(item['data_samples'].get('annotation', ['?'])[0])

x = torch.stack(patches).to(DEVICE)   # presumably [B, C, H, W] — confirm against the pipeline
print(f'Patch batch: {x.shape}  |  labels: {annotations}')

## 4  Forward pass (populates stored attention attributes)

In [None]:
# Run the backbone once without gradient tracking. The output itself is not
# needed; the attributes read below are taken from the backbone afterwards.
with torch.no_grad():
    _ = backbone(x)

# Copy the stored tensors to CPU so they can be converted for plotting.
spatial_masks = backbone.last_spatial_masks.cpu()
# Expected [C, C] per the notebook header — TODO confirm against the
# attention module (batch / head reduction is not visible from here).
attn1, attn2 = (
    stage.last_attn.cpu() for stage in (backbone.attn1, backbone.attn2)
)

# The unpacking fixes spatial_masks as 5-D: batch, channel, a dimension that
# is ignored here, then mask height and width.
B, C, _, H_m, W_m = spatial_masks.shape
print(f'spatial_masks : {spatial_masks.shape}')
print(f'attn1         : {attn1.shape}')
print(f'attn2         : {attn2.shape}')

## 5  Spatial focus maps

For each biological channel, show the raw patch intensity alongside the 12×12 attention mask (upsampled to 24×24 for overlay).

In [None]:
import torch.nn.functional as F

PATCH_IDX = 0   # index of the batch sample to display; change as needed

patch_np = x[PATCH_IDX].cpu().numpy()        # one image plane per channel
masks_np = spatial_masks[PATCH_IDX, :, 0]    # one low-resolution mask per channel

# Bilinearly resize every mask to the spatial size of the input patch so it
# can be drawn on top of the image.
target_size = tuple(patch_np.shape[-2:])
masks_up = (
    F.interpolate(
        masks_np.unsqueeze(1),               # add the channel dim interpolate expects
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )
    .squeeze(1)
    .numpy()
)

# Grid layout: `ncols` channels per row, two adjacent panels per channel
# (raw image on the left, mask overlay on the right).
ncols = 6
nrows = int(np.ceil(C / ncols))
fig, axes = plt.subplots(nrows, ncols * 2, figsize=(ncols * 4, nrows * 2.5))
axes = axes.flatten()

for c in range(C):
    ax_img, ax_mask = axes[c * 2], axes[c * 2 + 1]
    img, msk = patch_np[c], masks_up[c]
    # Per-channel colour scale; the epsilon keeps vmax above vmin for
    # all-zero channels.
    upper = img.max() + 1e-6
    panel_title = marker_names[c] if c < len(marker_names) else f'Ch{c}'

    ax_img.imshow(img, cmap='inferno', vmin=0, vmax=upper)
    ax_img.set_title(panel_title, fontsize=7)
    ax_img.axis('off')

    # Greyscale image underneath, semi-transparent mask on top.
    ax_mask.imshow(img, cmap='gray', vmin=0, vmax=upper)
    ax_mask.imshow(msk, cmap='hot', alpha=0.6, vmin=0, vmax=1)
    ax_mask.set_title('focus', fontsize=7)
    ax_mask.axis('off')

# Blank out grid cells left over after the last channel.
for spare_ax in axes[C * 2:]:
    spare_ax.set_visible(False)

fig.suptitle(
    f'Spatial focus maps  |  patch {PATCH_IDX}: {annotations[PATCH_IDX]}',
    fontsize=11, y=1.01
)
plt.tight_layout()
plt.show()

## 6  Channel cross-attention heatmaps

`attn1` (6×6 spatial scale) and `attn2` (3×3 spatial scale): row = query channel, column = key channel.  
High values → the query channel attends strongly to the key channel.

In [None]:
def plot_attn_heatmap(attn: torch.Tensor, marker_names: list, title: str, ax):
    """Draw a square attention matrix as a labelled heatmap on ``ax``.

    Args:
        attn: Attention matrix; its first dimension gives the channel count.
            Must be convertible with ``.numpy()`` (i.e. a CPU tensor).
        marker_names: Channel names used as tick labels on both axes. When
            fewer names than channels are supplied, generic ``Ch<i>`` labels
            are used for all channels instead.
        title: Title placed above the heatmap.
        ax: Matplotlib axes to draw into.
    """
    n_channels = attn.shape[0]
    if len(marker_names) >= n_channels:
        tick_labels = marker_names[:n_channels]
    else:
        tick_labels = [f'Ch{i}' for i in range(n_channels)]

    heatmap_options = dict(
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        ax=ax,
        cmap='viridis',
        square=True,
        cbar_kws={'shrink': 0.5},
        linewidths=0,
    )
    sns.heatmap(attn.numpy(), **heatmap_options)

    ax.set_title(title, fontsize=11)
    # Vertical labels along x, horizontal labels along y, both in small type.
    for axis_name, rotation in (('x', 90), ('y', 0)):
        ax.tick_params(axis=axis_name, labelrotation=rotation, labelsize=6)


# One heatmap per attention stage, side by side.
fig, axes = plt.subplots(1, 2, figsize=(22, 10))
stage_panels = (
    (attn1, 'Cross-channel attention  (stage 1 – 6×6)', axes[0]),
    (attn2, 'Cross-channel attention  (stage 2 – 3×3)', axes[1]),
)
for stage_attn, stage_title, stage_ax in stage_panels:
    plot_attn_heatmap(stage_attn, marker_names, stage_title, stage_ax)
plt.tight_layout()
plt.show()

## 7  Channel importance ranking

Column-mean of each attention matrix (averaged over the query rows) = mean attention weight received by each key channel.  
Channels with high column-means are consistently attended to by other channels.

In [None]:
def channel_importance(attn: torch.Tensor) -> np.ndarray:
    """Return the mean attention weight each channel receives.

    The matrix is averaged along dimension 0, so entry ``j`` of the result is
    the mean of column ``j`` of ``attn``.

    Args:
        attn: Attention matrix on the CPU (it is converted with ``.numpy()``).

    Returns:
        NumPy array holding one value per column of ``attn``.
    """
    received = torch.mean(attn, dim=0)
    return received.numpy()


def plot_importance(importance: np.ndarray, marker_names: list, title: str, ax, color: str):
    """Draw per-channel importance as a bar chart sorted from high to low.

    Args:
        importance: 1-D array with one importance value per channel.
        marker_names: Channel names used as tick labels. When fewer names
            than channels are supplied, generic ``Ch<i>`` labels are used for
            all channels instead.
        title: Title placed above the chart.
        ax: Matplotlib axes to draw into.
        color: Fill colour of the bars.
    """
    n_channels = len(importance)
    if len(marker_names) >= n_channels:
        tick_labels = marker_names[:n_channels]
    else:
        tick_labels = [f'Ch{i}' for i in range(n_channels)]

    # Channel indices ordered from the largest importance to the smallest.
    ranking = np.argsort(importance)[::-1]
    ranked_labels = [tick_labels[i] for i in ranking]
    positions = range(n_channels)

    ax.bar(positions, importance[ranking], color=color, alpha=0.8)
    ax.set_xticks(positions)
    ax.set_xticklabels(ranked_labels, rotation=90, fontsize=7)
    ax.set_ylabel('Mean attention weight')
    ax.set_title(title, fontsize=11)


# Per-channel importance for both attention stages.
imp1, imp2 = channel_importance(attn1), channel_importance(attn2)

# One sorted bar chart per stage, side by side, each in its own colour.
fig, axes = plt.subplots(1, 2, figsize=(20, 5))
importance_panels = (
    (imp1, 'Channel importance — stage 1 (6×6)', axes[0], '#4C8EDA'),
    (imp2, 'Channel importance — stage 2 (3×3)', axes[1], '#E8735A'),
)
for stage_imp, stage_title, stage_ax, stage_color in importance_panels:
    plot_importance(stage_imp, marker_names, stage_title, stage_ax, stage_color)
plt.tight_layout()
plt.show()

## 8  Multi-sample spatial mask overlay

For a single selected biomarker channel, compare the spatial focus mask across all loaded patches to check for consistency.

In [None]:
CHANNEL_NAME = 'CD20'   # change to any marker in marker_names
if CHANNEL_NAME in marker_names:
    CHANNEL_IDX = marker_names.index(CHANNEL_NAME)
else:
    # Unknown marker: fall back to the first channel and report which name
    # is actually being shown.
    CHANNEL_IDX = 0
    CHANNEL_NAME = marker_names[CHANNEL_IDX] if marker_names else f'Ch{CHANNEL_IDX}'
    print(f'Channel not found; falling back to {CHANNEL_NAME}')

patch_np_all = x.cpu().numpy()
# Resize the selected channel's mask of every sample to the spatial size of
# the input patches so it can be overlaid on the image.
masks_up_all = F.interpolate(
    spatial_masks[:, CHANNEL_IDX, 0].unsqueeze(1),  # add the channel dim interpolate expects
    size=(patch_np_all.shape[-2], patch_np_all.shape[-1]),
    mode='bilinear',
    align_corners=False,
).squeeze(1).numpy()

# squeeze=False keeps `axes` two-dimensional even for a single sample, so no
# special case is needed for B == 1.
fig, axes = plt.subplots(2, B, figsize=(B * 3, 6), squeeze=False)

for b in range(B):
    img = patch_np_all[b, CHANNEL_IDX]
    msk = masks_up_all[b]
    # Per-sample colour scale; the epsilon keeps vmax above vmin for
    # all-zero images.
    upper = img.max() + 1e-6

    # Top row: raw channel intensity, titled with the sample's annotation.
    axes[0, b].imshow(img, cmap='inferno', vmin=0, vmax=upper)
    axes[0, b].set_title(f'{annotations[b]}', fontsize=8)

    # Bottom row: greyscale image with the semi-transparent mask on top.
    axes[1, b].imshow(img, cmap='gray', vmin=0, vmax=upper)
    axes[1, b].imshow(msk, cmap='hot', alpha=0.6, vmin=0, vmax=1)

    for panel in (axes[0, b], axes[1, b]):
        # Hide ticks and spines instead of calling axis('off'): turning the
        # axis off also suppresses axis labels, which would make the row
        # labels set below invisible.
        panel.set_xticks([])
        panel.set_yticks([])
        for spine in panel.spines.values():
            spine.set_visible(False)

# Row labels, attached to the first column only.
axes[0, 0].set_ylabel('Raw', fontsize=9)
axes[1, 0].set_ylabel('Focus overlay', fontsize=9)

fig.suptitle(f'Spatial focus masks for "{CHANNEL_NAME}" across {B} patches', fontsize=12)
plt.tight_layout()
plt.show()