# Attention Map Inspection
Utility helpers to load, print, and visualize multiple-instance learning (MIL) attention weights saved during prototype training runs.

In [7]:
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
import matplotlib.pyplot as plt
from PIL import Image, ImageOps

# Type accepted wherever a runs-root or run location is passed in: a plain string or a pathlib.Path.
RunSelector = Union[str, Path]


def _resolve_run(root: Path, run: Optional[RunSelector]) -> Path:
    """Pick the run directory to inspect under ``root``.

    Selection order:
      * ``run is None``: the sub-directory with the newest modification time.
      * ``root / run`` exists: that path is returned as-is.
      * otherwise: the single sub-directory whose name contains ``str(run)``.

    Raises:
        FileNotFoundError: ``root`` is missing, has no sub-directories, or nothing matches ``run``.
        ValueError: more than one sub-directory name contains ``str(run)``.
    """
    if not root.exists():
        raise FileNotFoundError(f'Run root {root} does not exist')
    subdirs = [child for child in root.iterdir() if child.is_dir()]
    if not subdirs:
        raise FileNotFoundError(f'No runs found under {root}')
    if run is None:
        return max(subdirs, key=lambda child: child.stat().st_mtime)

    direct = root / Path(run)
    if direct.exists():
        return direct

    needle = str(run)
    hits = [child for child in subdirs if needle in child.name]
    if len(hits) > 1:
        raise ValueError(f'Multiple runs matched {run}: {[child.name for child in hits]}')
    if not hits:
        raise FileNotFoundError(f'Run {run} not found under {root}')
    return hits[0]


def _resolve_epoch(run_dir: Path, epoch: Optional[int]) -> Path:
    """Return the ``epoch_*`` sub-directory of ``run_dir`` to inspect.

    Args:
        run_dir: Run directory containing ``epoch_<N>`` sub-directories.
        epoch: Epoch number to select. ``None`` selects the latest epoch, ordered by the
            numeric suffix so that ``epoch_1000`` comes after ``epoch_999``. Directories whose
            suffix is not a number are ordered before all numbered ones.

    Returns:
        ``run_dir / f'epoch_{epoch:03d}'`` when that directory exists. Otherwise, the
        ``epoch_*`` directory whose numeric suffix equals ``epoch`` (e.g. ``epoch_5``).

    Raises:
        FileNotFoundError: no epoch directories exist, or none matches ``epoch``.
    """
    prefix = 'epoch_'

    def epoch_number(path: Path) -> Optional[int]:
        suffix = path.name[len(prefix):]
        # isascii() rejects Unicode digits that str.isdigit() accepts but int() cannot parse.
        return int(suffix) if suffix.isascii() and suffix.isdigit() else None

    epoch_dirs = sorted(d for d in run_dir.glob(prefix + '*') if d.is_dir())
    if not epoch_dirs:
        raise FileNotFoundError(f'No epoch directories found under {run_dir}')

    if epoch is None:
        def order_key(path: Path) -> Tuple[bool, int, str]:
            number = epoch_number(path)
            return (number is not None, number if number is not None else -1, path.name)

        return max(epoch_dirs, key=order_key)

    epoch_path = run_dir / f'epoch_{epoch:03d}'
    if epoch_path.is_dir():
        return epoch_path
    # Fall back to directories that are not zero-padded to three digits (epoch_5, epoch_1000, ...).
    for candidate in epoch_dirs:
        if epoch_number(candidate) == epoch:
            return candidate
    raise FileNotFoundError(f'Epoch {epoch} not found under {run_dir}')


def load_attention_records(
    runs_root: RunSelector = '../prototype/runs',
    run: Optional[RunSelector] = None,
    epoch: Optional[int] = None,
) -> Tuple[Sequence[dict], Path, Path]:
    """Locate and load ``attention_maps.pt`` for one run and epoch.

    Args:
        runs_root: Directory containing run sub-directories; ``~`` is expanded and the path resolved.
        run: Run name, partial name, or path. ``None`` picks the most recently modified run.
        epoch: Epoch number. ``None`` picks the latest epoch directory.

    Returns:
        ``(records, run_dir, epoch_dir)``, where ``records`` is the list or tuple stored in the file.

    Raises:
        FileNotFoundError: the run, epoch, or attention file cannot be found.
        TypeError: the file does not contain a list or tuple.
    """
    run_dir = _resolve_run(Path(runs_root).expanduser().resolve(), run)
    epoch_dir = _resolve_epoch(run_dir, epoch)
    payload_file = epoch_dir.joinpath('attention_maps.pt')
    if not payload_file.exists():
        raise FileNotFoundError(f'Missing attention map file at {payload_file}')
    # NOTE(security): depending on the installed torch version, torch.load may fully unpickle the
    # file (weights_only=False was the default before torch 2.6). Only point this at trusted run output.
    payload = torch.load(payload_file, map_location='cpu')
    if isinstance(payload, (list, tuple)):
        return payload, run_dir, epoch_dir
    raise TypeError(f'Unexpected attention payload type: {type(payload).__name__}')


def _filter_records(
    records: Sequence[dict],
    bag_ids: Optional[Sequence[Union[str, int]]],
    limit: Optional[int],
) -> Sequence[dict]:
    """Select attention records by bag id, keeping at most ``limit`` of them.

    Args:
        records: Attention records. Each is read with ``record.get('bag_id')``.
        bag_ids: Bag identifiers to keep. They are compared as strings, so ``1`` and ``'1'``
            match. ``None`` keeps every record.
        limit: Maximum number of records to return. ``None`` means unlimited; values ``<= 0``
            return an empty list.

    Returns:
        A new list with the matching records in their original order.
    """
    # Without this guard the post-append check below would still let one record through.
    if limit is not None and limit <= 0:
        return []
    wanted = None if bag_ids is None else {str(b) for b in bag_ids}
    selected: list = []
    for record in records:
        if wanted is not None and str(record.get('bag_id')) not in wanted:
            continue
        selected.append(record)
        if limit is not None and len(selected) >= limit:
            break
    return selected


def print_attention_maps(
    runs_root: RunSelector = '../prototype/runs',
    run: Optional[RunSelector] = None,
    epoch: Optional[int] = None,
    bag_ids: Optional[Sequence[Union[str, int]]] = None,
    top_k: Optional[int] = 5,
    limit: Optional[int] = None,
) -> None:
    """
    Print the largest attention weights stored for each bag of a training checkpoint.

    Args:
        runs_root: Directory holding the prototype training runs.
        run: Run directory name, partial name, or path. The newest run is used when omitted.
        epoch: Epoch to read. The latest epoch directory is used when omitted.
        bag_ids: Bag identifiers to keep, matched as strings. Every bag is printed when omitted.
        top_k: How many of the largest weights to list per bag. Values below 1 list a single
            weight; ``None`` prints the full weight vector instead.
        limit: Upper bound on the number of bags printed.
    """
    records, run_dir, epoch_dir = load_attention_records(runs_root=runs_root, run=run, epoch=epoch)
    selected = _filter_records(records, bag_ids=bag_ids, limit=limit)
    print(f'Run: {run_dir.name} | Epoch: {epoch_dir.name}')
    print(f'Total attention records: {len(records)}')
    if not selected:
        print('No attention records matched the provided filters.')
        return

    for position, entry in enumerate(selected, start=1):
        name = entry.get('bag_id', f'bag_{position}')
        raw_attention = entry.get('attention')
        if raw_attention is None:
            print(f'- Bag {name}: missing attention tensor')
            continue

        # Collapse whatever shape was saved into a flat CPU vector of per-tile scores.
        scores = torch.as_tensor(raw_attention).detach().flatten().cpu()
        label_text = entry.get('label', 'NA')
        prob_text = entry.get('probability', 'NA')
        print(f'- Bag {name} | label={label_text} | p={prob_text}')

        n_scores = scores.numel()
        if n_scores == 0:
            print('  (empty attention vector)')
        elif top_k is None:
            print(f'  weights: {scores.tolist()}')
        else:
            k = min(max(top_k, 1), n_scores)
            top_values, top_indices = torch.topk(scores, k)
            parts = [f"idx={i} score={v:.4f}" for i, v in zip(top_indices.tolist(), top_values.tolist())]
            print(f'  top-{k}: {", ".join(parts)}')


@lru_cache(maxsize=None)
def _load_bag_index(embeddings_path: Union[str, Path]) -> Dict[str, Dict[str, Any]]:
    """Build a bag-id -> tile/slide metadata lookup from a saved embeddings file.

    Each entry of the loaded sequence may be one of two forms:
      * a dict:
          - id: read from 'bag_id', 'id', 'slide_id' or 'name'
          - coords: 'coords' or 'tile_coords'
          - tile paths: 'tile_paths' or 'tiles'
          - slide: 'slide_path', 'image_path' or 'img_path'
          - also 'tile_size' and 'stride'
        Missing fields fall back to the same names inside an optional nested 'meta' dict.
      * a list/tuple: positions 2..7 are read as bag_id, coords, slide_path, tile_paths,
        tile_size and stride. Positions 0-1 are not used here.
    Entries without an id (or with an empty-string id) are named ``bag_{position:05d}``.

    Besides the canonical ``str(bag_id)`` key, the file name and stem of the id are added as
    aliases unless those keys are already taken.

    The result is memoised per ``embeddings_path`` value by ``lru_cache``. Later changes to the
    file are not seen until ``_load_bag_index.cache_clear()`` is called. The returned dict is
    shared between callers, so treat it as read-only.

    Raises:
        FileNotFoundError: the embeddings file does not exist.
    """

    def first_present(*values: Any) -> Any:
        # Return the first value that is not None. Unlike an ``or`` chain this never calls
        # bool() on the value. bool() raises for multi-element tensors/arrays (e.g. coords)
        # and would also discard legitimate falsy values such as a bag id of 0.
        for value in values:
            if value is not None:
                return value
        return None

    path = Path(embeddings_path).expanduser().resolve()
    if not path.exists():
        raise FileNotFoundError(f'Embeddings file not found at {path}')
    # NOTE(security): depending on the torch version, torch.load may unpickle arbitrary objects;
    # only load trusted dataset files.
    raw_entries = torch.load(path, map_location='cpu')
    index: Dict[str, Dict[str, Any]] = {}
    for idx, entry in enumerate(raw_entries):
        bag_id = coords = slide_path = tile_paths = tile_size = stride = None

        if isinstance(entry, dict):
            nested = entry.get('meta')
            if not isinstance(nested, dict):
                nested = {}
            bag_id = first_present(entry.get('bag_id'), entry.get('id'), entry.get('slide_id'), entry.get('name'))
            coords = first_present(entry.get('coords'), entry.get('tile_coords'), nested.get('coords'))
            tile_paths = first_present(entry.get('tile_paths'), entry.get('tiles'), nested.get('tile_paths'))
            slide_path = first_present(
                entry.get('slide_path'), entry.get('image_path'), entry.get('img_path'), nested.get('slide_path')
            )
            tile_size = first_present(entry.get('tile_size'), nested.get('tile_size'))
            stride = first_present(entry.get('stride'), nested.get('stride'))
        elif isinstance(entry, (list, tuple)):
            # Pad so short entries leave the trailing fields as None.
            padded = list(entry[2:8]) + [None] * 6
            bag_id, coords, slide_path, tile_paths, tile_size, stride = padded[:6]

        if bag_id is None or (isinstance(bag_id, str) and not bag_id):
            bag_id = f'bag_{idx:05d}'

        canonical = str(bag_id)
        meta = {
            'bag_id': canonical,
            'coords': coords,
            'slide_path': slide_path,
            'tile_paths': tile_paths,
            'tile_size': tile_size,
            'stride': stride,
        }
        index[canonical] = meta
        # Add convenient aliases based on the file name and stem of the id.
        as_path = Path(canonical)
        for alias in (as_path.name, as_path.stem):
            index.setdefault(alias, meta)
    return index


def _lookup_bag_meta(index: Dict[str, Dict[str, Any]], bag_id: Union[str, int]) -> Optional[Dict[str, Any]]:
    """Find the metadata entry for ``bag_id`` in an index built by ``_load_bag_index``.

    The id is compared as a string. If there is no exact key, the file-name form
    (``Path(id).name``) is tried first, then the stem form (``Path(id).stem``). The order is
    fixed so the result does not depend on set iteration order (and thus the hash seed).

    Returns:
        The matching metadata dict, or ``None`` when no key matches.
    """
    key = str(bag_id)
    if key in index:
        return index[key]
    as_path = Path(key)
    for variant in (as_path.name, as_path.stem):
        if variant in index:
            return index[variant]
    return None


def _resolve_slide_path(meta: Dict[str, Any], image_root: Optional[Path], bag_id: Union[str, int], image_pattern: Optional[str]) -> Optional[Path]:
    """Locate the slide image for a bag, or return ``None`` if nothing on disk matches.

    Candidates, in order:
      1. ``meta['slide_path']`` when it is truthy.
      2. ``image_pattern`` formatted with ``bag_id=<stem of str(bag_id)>`` when a pattern is given.
    Each candidate is probed as-is (relative to the working directory) and then under
    ``image_root``; absolute candidates are probed once. ``~`` is expanded before checking.
    A candidate with no suffix is also tried with common image extensions.
    """
    explicit = meta.get('slide_path')
    sources = [Path(explicit)] if explicit else []
    if image_pattern:
        sources.append(Path(image_pattern.format(bag_id=Path(str(bag_id)).stem)))

    seen = set()
    for source in sources:
        for base in (None, image_root):
            probe = source if (base is None or source.is_absolute()) else base / source
            if probe in seen:
                continue
            seen.add(probe)

            expanded = probe.expanduser()
            if expanded.exists():
                return expanded
            if not expanded.suffix:
                for ext in ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp'):
                    with_ext = expanded.with_suffix(ext)
                    if with_ext.exists():
                        return with_ext
    return None


def _load_tile_images_for_indices(
    meta: Dict[str, Any],
    indices: Sequence[int],
    image_root: Optional[Path],
    bag_id: Union[str, int],
    image_pattern: Optional[str],
    fallback_tile_size: int,
) -> Tuple[Sequence[Image.Image], int]:
    """Load RGB tile images for the requested tile indices of one bag.

    Two strategies are tried in order:
      1. If ``meta['tile_paths']`` is non-empty, each index selects a per-tile image file.
         The file is looked up as given, then relative to ``image_root``.
      2. Otherwise ``meta['coords']`` (converted to a 2-D tensor with at least two columns)
         supplies crop origins in the slide image found by ``_resolve_slide_path``. Each crop
         is ``tile_size`` pixels square, clipped to the slide's right and bottom edges.

    Indices are skipped when they are out of range, when their file cannot be found, or when
    their crop box does not overlap the slide. The returned list may therefore be shorter than
    ``indices`` and is not positionally aligned with it.

    Returns:
        ``(images, tile_size)``, where ``tile_size`` is ``meta['tile_size']`` or
        ``fallback_tile_size``.
    """
    tile_paths = meta.get('tile_paths')
    tile_size = int(meta.get('tile_size') or fallback_tile_size)
    images: list = []

    if tile_paths:
        paths = list(tile_paths)
        for idx in indices:
            if idx < 0 or idx >= len(paths):
                continue
            raw = Path(paths[idx])
            candidates = [raw] if image_root is None else [raw, image_root / raw]
            loaded = None
            for path in candidates:
                resolved = path.expanduser()
                if resolved.exists():
                    loaded = resolved
                    break
            if loaded is None:
                continue
            # Open in a context manager so the file handle is released. convert() returns a
            # new, fully decoded image that stays valid after the handle is closed.
            with Image.open(loaded) as handle:
                images.append(ImageOps.exif_transpose(handle).convert('RGB'))
        return images, tile_size

    coords = meta.get('coords')
    if coords is None:
        return [], tile_size

    coords_tensor = torch.as_tensor(coords).detach().cpu()
    if coords_tensor.ndim != 2 or coords_tensor.shape[1] < 2:
        return [], tile_size

    slide_path = _resolve_slide_path(meta, image_root, bag_id, image_pattern)
    if slide_path is None or not slide_path.exists():
        return [], tile_size

    with Image.open(slide_path) as handle:
        slide = ImageOps.exif_transpose(handle).convert('RGB')
    width, height = slide.size
    for idx in indices:
        if idx < 0 or idx >= coords_tensor.shape[0]:
            continue
        # NOTE(review): the first two coordinate columns are read as (y, x), i.e. row then
        # column. Confirm this against the code that produced the coords.
        y, x = coords_tensor[idx][:2].tolist()
        x0, y0 = int(x), int(y)
        x1 = min(x0 + tile_size, width)
        y1 = min(y0 + tile_size, height)
        # Skip boxes with no overlap with the slide. Otherwise right < left (or bottom < top)
        # would make crop() fail or produce an empty image.
        if x0 >= x1 or y0 >= y1 or x1 <= 0 or y1 <= 0:
            continue
        images.append(slide.crop((x0, y0, x1, y1)))
    return images, tile_size


def display_attention_maps(
    runs_root: RunSelector = '../prototype/runs',
    run: Optional[RunSelector] = None,
    epoch: Optional[int] = None,
    bag_ids: Optional[Sequence[Union[str, int]]] = None,
    top_k: Optional[int] = 10,
    limit: Optional[int] = 4,
    figsize: Tuple[int, int] = (3, 3),
    embeddings_path: Union[str, Path] = '../data/tensors/embeddings.pt',
    image_root: Union[str, Path, None] = '../data',
    image_pattern: str = 'core_data/{bag_id}.jpg',
    fallback_tile_size: int = 224,
) -> None:
    """
    Visualize the highest-attention tiles for selected bags as a row of image panels per bag.

    runs_root, run, epoch, bag_ids and limit mirror `print_attention_maps`. ``top_k=None``
    shows at most 12 tiles rather than every weight. Additional parameters:
        figsize: (width, height) of each tile subplot. The figure is one row of subplots.
        embeddings_path: Location of the dataset used during training (for tile metadata).
        image_root: Base directory used to resolve slide or tile file locations.
        image_pattern: Fallback pattern for locating slides when metadata lacks explicit paths.
            ``{bag_id}`` is replaced by the stem of the bag id.
        fallback_tile_size: Crop size to use when tile size metadata is missing.
    """
    records, run_dir, epoch_dir = load_attention_records(runs_root=runs_root, run=run, epoch=epoch)
    filtered = _filter_records(records, bag_ids=bag_ids, limit=limit)
    print(f'Run: {run_dir.name} | Epoch: {epoch_dir.name}')
    print(f'Total attention records: {len(records)}')
    if not filtered:
        print('No attention records matched the provided filters.')
        return

    image_root_path = Path(image_root).expanduser().resolve() if image_root is not None else None
    bag_index = _load_bag_index(embeddings_path)

    for idx, record in enumerate(filtered, start=1):
        bag_id = record.get('bag_id', f'bag_{idx}')
        metadata = _lookup_bag_meta(bag_index, bag_id)
        if metadata is None:
            print(f'- Bag {bag_id}: metadata not found in embeddings file; skipping image display.')
            continue

        attention = record.get('attention')
        if attention is None:
            print(f'- Bag {bag_id}: missing attention tensor')
            continue
        attention_tensor = torch.as_tensor(attention).detach().flatten().cpu()
        count = attention_tensor.numel()
        if count == 0:
            print(f'- Bag {bag_id}: attention tensor empty')
            continue

        # top_k=None is capped at 12 panels to keep the figure readable.
        k = min(count, 12) if top_k is None else max(1, min(top_k, count))
        values, indices = torch.topk(attention_tensor, k)
        index_list = indices.tolist()
        weight_list = values.tolist()
        loader_args = (image_root_path, bag_id, image_pattern, fallback_tile_size)
        tiles, tile_size = _load_tile_images_for_indices(metadata, index_list, *loader_args)
        panels = list(zip(tiles, index_list, weight_list))
        if tiles and len(tiles) != len(index_list):
            # The loader silently skips indices it cannot resolve. Zipping its output with
            # index_list would then caption images with the wrong tile index and weight, so
            # reload one index at a time to keep each panel paired with its own index/weight.
            panels = []
            for tile_idx, weight in zip(index_list, weight_list):
                single, tile_size = _load_tile_images_for_indices(metadata, [tile_idx], *loader_args)
                if single:
                    panels.append((single[0], tile_idx, weight))
        if not panels:
            print(f'- Bag {bag_id}: unable to locate tile imagery for the selected indices.')
            continue

        label = record.get('label', 'NA')
        probability = record.get('probability', 'NA')
        try:
            prob_display = f'{float(probability):.3f}'
        except (TypeError, ValueError):
            prob_display = str(probability)

        n_panels = len(panels)
        # squeeze=False always returns a 2-D array of Axes with shape (1, n_panels). With the
        # default squeeze, one panel gives a bare Axes and several give a NumPy array, and a
        # NumPy array is not a typing.Sequence, so it cannot be detected that way.
        fig, axes = plt.subplots(1, n_panels, figsize=(figsize[0] * n_panels, figsize[1]), squeeze=False)
        for ax, (tile_img, tile_idx, weight) in zip(axes[0], panels):
            ax.imshow(tile_img)
            ax.set_title(f'idx {tile_idx} | w={weight:.3f}', fontsize=9)
            ax.axis('off')
        fig.suptitle(f'Bag {bag_id} | label={label} | p={prob_display} | tile={tile_size}', fontsize=12)
        plt.tight_layout()
        plt.show()


In [8]:
# Example usage: adjust the run name, epoch, and filters to match your own training runs.
print_attention_maps(run='training_run_mil_1', epoch=5, top_k=5, limit=3)
display_attention_maps(run='training_run_mil_1', epoch=5, top_k=6, limit=2)


  records = torch.load(attention_path, map_location='cpu')
  raw_entries = torch.load(path, map_location='cpu')


Run: training_run_mil_1 | Epoch: epoch_005
Total attention records: 8
- Bag bag_00000 | label=0.0 | p=0.5962618589401245
  top-5: idx=3 score=0.3144, idx=9 score=0.1409, idx=0 score=0.1304, idx=6 score=0.1177, idx=8 score=0.0735
- Bag bag_00001 | label=0.0 | p=0.7020145058631897
  top-5: idx=9 score=0.4004, idx=6 score=0.1529, idx=3 score=0.0835, idx=5 score=0.0685, idx=4 score=0.0676
- Bag bag_00002 | label=0.0 | p=0.5735396146774292
  top-5: idx=5 score=0.3026, idx=2 score=0.1322, idx=4 score=0.1057, idx=9 score=0.0882, idx=3 score=0.0844
Run: training_run_mil_1 | Epoch: epoch_005
Total attention records: 8
- Bag bag_00000: unable to locate tile imagery for the selected indices.
- Bag bag_00001: unable to locate tile imagery for the selected indices.
