# TopoEncoder Latent Space Dashboard (Panel + HoloViews)

This notebook builds a Panel dashboard that loads a TopoEncoder experiment folder (checkpoint + `benchmarks.pt`)
and compares the latent space of the TopoEncoder against the StandardVQ and VanillaAE baselines, projected to 3D
with PCA when the latent dimension exceeds three. Use the controls to choose the checkpoint, sampling, coloring,
and styling, so you can switch visualization styles without regenerating static images.

**Usage**
1. Set the experiment folder in the sidebar (for example: `outputs/3d_topoencoder_mnist_cpu_adapt_lr7_bnch`).
2. Click **Scan checkpoints**, choose a checkpoint, then click **Load + Analyze**.
3. Explore the latent plots and metrics.


In [1]:
from __future__ import annotations

from pathlib import Path
import re

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import panel as pn
import holoviews as hv

from fragile.paths import DEFAULT_EXPERIMENT_DIR, OUTPUTS_DIR
from fragile.core.layers import StandardVQ, TopoEncoderPrimitives, VanillaAE
from fragile.core.layers.topology import InvariantChartClassifier
from fragile.datasets import CIFAR10_CLASSES

# The 3D scatter plots need the Plotly backend in both Panel and HoloViews.
# Activate it up front and turn any failure into one actionable message,
# keeping the original exception chained for debugging.
try:
    pn.extension("plotly", sizing_mode="stretch_width")
    hv.extension("plotly")
except Exception as plotly_error:
    raise RuntimeError(
        "Plotly backend is required for this dashboard. Install plotly and rerun."
    ) from plotly_error


In [2]:
def _load_checkpoint(path: Path) -> dict:
    """Load a training checkpoint onto the CPU.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only point this dashboard at checkpoints you trust.
    load_kwargs = {"map_location": "cpu", "weights_only": False}
    try:
        return torch.load(path, **load_kwargs)
    except TypeError:
        # Older torch releases do not know the ``weights_only`` keyword.
        del load_kwargs["weights_only"]
        return torch.load(path, **load_kwargs)


def _load_benchmarks(checkpoint_path: Path) -> dict | None:
    """Load ``benchmarks.pt`` stored next to ``checkpoint_path``.

    Returns ``None`` when the benchmarks file is absent.
    """
    bench_path = checkpoint_path.parent / "benchmarks.pt"
    if not bench_path.exists():
        return None
    # NOTE(security): weights_only=False unpickles arbitrary objects.
    load_kwargs = {"map_location": "cpu", "weights_only": False}
    try:
        return torch.load(bench_path, **load_kwargs)
    except TypeError:
        # Older torch releases do not know the ``weights_only`` keyword.
        del load_kwargs["weights_only"]
        return torch.load(bench_path, **load_kwargs)


def _checkpoint_sort_key(path: Path) -> tuple[int, int, str]:
    """Sort key ordering checkpoints by their first embedded integer.

    Files whose name contains ``final`` sort after every numbered checkpoint;
    names without digits get step ``-1``. The name breaks ties.
    """
    filename = path.name
    if "final" in filename:
        return 1, 0, filename
    digits = re.search(r"(\d+)", filename)
    step = int(digits.group(1)) if digits else -1
    return 0, step, filename


def _collect_checkpoints(exp_dir: Path) -> list[Path]:
    """Return the ``*.pt`` checkpoints in ``exp_dir`` in training order.

    ``benchmarks.pt`` is excluded; a missing directory yields an empty list.
    """
    if not exp_dir.exists():
        return []
    candidates = (p for p in exp_dir.glob("*.pt") if p.name != "benchmarks.pt")
    return sorted(candidates, key=_checkpoint_sort_key)


def _default_experiment_dir() -> Path:
    """Pick the folder pre-filled in the sidebar.

    Preference order: ``DEFAULT_EXPERIMENT_DIR`` if it exists, then the first
    (sorted) subfolder of ``OUTPUTS_DIR`` that holds a ``benchmarks.pt``, then
    ``OUTPUTS_DIR`` itself; ``Path(".")`` when ``OUTPUTS_DIR`` is missing.
    """
    if DEFAULT_EXPERIMENT_DIR.exists():
        return DEFAULT_EXPERIMENT_DIR
    if not OUTPUTS_DIR.exists():
        return Path(".")
    with_benchmarks = (
        candidate
        for candidate in sorted(OUTPUTS_DIR.iterdir())
        if candidate.is_dir() and (candidate / "benchmarks.pt").exists()
    )
    return next(with_benchmarks, OUTPUTS_DIR)


def _as_tensor(x: np.ndarray | torch.Tensor) -> torch.Tensor:
    """Return ``x`` as a float32 tensor.

    Tensors are cast in place of device. NumPy arrays share memory with the
    result when they are already float32. Any other array-like (for example
    nested lists) goes through ``np.asarray`` first, mirroring ``_as_numpy``.
    """
    if isinstance(x, torch.Tensor):
        return x.float()
    return torch.from_numpy(np.asarray(x)).float()


def _as_numpy(x) -> np.ndarray | None:
    """Convert a tensor or array-like to a NumPy array, passing ``None`` through.

    Tensors are detached and moved to the CPU before conversion.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return None if x is None else np.asarray(x)


def _batch_iter(x: torch.Tensor, batch_size: int):
    """Yield ``(start, end, x[start:end])`` chunks along the first axis.

    The last chunk may be shorter than ``batch_size``; an empty ``x`` yields
    nothing.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1. Without this check a
            negative size silently produced zero batches, which only surfaced
            later as an opaque ``np.concatenate`` error.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    n = x.shape[0]
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        yield start, end, x[start:end]


def _pca_fit(z: np.ndarray) -> tuple[np.ndarray, np.ndarray | None]:
    """Fit a projection from ``z`` (samples x dim) to 3 dimensions.

    Returns ``(mean, basis)``:
    - dim < 3: zero mean and ``basis=None`` (``_pca_apply`` zero-pads instead);
    - dim == 3: zero mean and the identity (no centering, axes kept as-is);
    - dim > 3: the column mean and the top-3 principal directions as a
      ``(dim, 3)`` matrix. When there are fewer than 3 samples, SVD yields
      only ``min(n, dim)`` directions, so the missing columns are zero-filled.
      This keeps the projection 3-D and the plotting code can always index
      ``[:, 2]``.
    """
    dim = z.shape[1]
    if dim < 3:
        return np.zeros(dim), None
    if dim == 3:
        return np.zeros(3), np.eye(3)
    if z.shape[0] == 0:
        # Nothing to fit; an all-zero basis maps any later input to the origin.
        return np.zeros(dim), np.zeros((dim, 3))
    mean = z.mean(axis=0, keepdims=True)
    _, _, vt = np.linalg.svd(z - mean, full_matrices=False)
    basis = vt[:3].T
    missing = 3 - basis.shape[1]
    if missing > 0:
        basis = np.pad(basis, ((0, 0), (0, missing)), mode="constant")
    return mean.reshape(-1), basis


def _pca_apply(z: np.ndarray, mean: np.ndarray, basis: np.ndarray | None) -> np.ndarray:
    """Project ``z`` with a ``(mean, basis)`` pair produced by ``_pca_fit``.

    Without a basis, inputs narrower than 3 columns are zero-padded on the
    right and wider inputs are truncated to their first 3 columns.
    """
    if basis is not None:
        return (z - mean) @ basis
    missing = 3 - z.shape[1]
    if missing > 0:
        return np.pad(z, ((0, 0), (0, missing)), mode="constant")
    return z[:, :3]


def _pca_to_3d(z: np.ndarray) -> np.ndarray:
    """Fit a 3-D projection on ``z`` and apply it to ``z`` in one step."""
    return _pca_apply(z, *_pca_fit(z))


def _sample_data(
    x: torch.Tensor, labels: np.ndarray | None, max_samples: int, seed: int
) -> tuple[torch.Tensor, np.ndarray | None, np.ndarray]:
    """Draw up to ``max_samples`` rows of ``x`` (and ``labels``) without replacement.

    ``max_samples <= 0`` or ``>= len(x)`` keeps everything. Returns the
    (possibly subsampled) data, labels and the selected row indices; sampling
    is reproducible for a given ``seed``.
    """
    total = x.shape[0]
    if 0 < max_samples < total:
        rng = np.random.default_rng(seed)
        picked = rng.choice(total, size=max_samples, replace=False)
        x = x[torch.from_numpy(picked)]
        if labels is not None:
            labels = labels[picked]
        return x, labels, picked
    return x, labels, np.arange(total)


def _dataset_specs(config: dict, data: dict) -> tuple[list[str], tuple[int, int, int]]:
    """Return ``(class_names, image_shape)`` for the checkpoint's dataset.

    CIFAR and MNIST (matched case-insensitively on ``config["dataset"]`` or
    ``data["dataset_name"]``) have fixed specs. Otherwise class names are
    ``"0" .. str(num_classes - 1)`` from ``config["num_classes"]``, and the
    shape is ``(side, side, 1)`` when the input dimension is a positive
    perfect square, else ``(input_dim, 1, 1)``. The input dimension comes from
    ``config["input_dim"]``, falling back to the second axis of
    ``data["X_test"]``; if neither is available it is treated as 0 instead of
    crashing on a missing ``X_test``.
    """
    dataset = str(config.get("dataset") or data.get("dataset_name", "")).lower()
    if "cifar" in dataset:
        return list(CIFAR10_CLASSES), (32, 32, 3)
    if "mnist" in dataset:
        return [str(i) for i in range(10)], (28, 28, 1)

    num_classes = int(config.get("num_classes", 0) or 0)
    class_names = [str(i) for i in range(num_classes)]

    input_dim = int(config.get("input_dim", 0) or 0)
    if not input_dim:
        x_test = data.get("X_test")
        if x_test is not None and np.ndim(x_test) >= 2:
            input_dim = int(np.shape(x_test)[1])

    side = int(np.sqrt(input_dim)) if input_dim > 0 else 0
    if input_dim > 0 and side * side == input_dim:
        return class_names, (side, side, 1)
    return class_names, (input_dim, 1, 1)


class BaselineClassifier(nn.Module):
    """Linear -> GELU -> Linear classification head on a latent vector.

    The layers live in ``self.net`` as an ``nn.Sequential`` so saved
    state-dict keys (``net.0.*``, ``net.2.*``) load unchanged.
    """

    def __init__(self, latent_dim: int, num_classes: int, hidden_dim: int = 32) -> None:
        super().__init__()
        layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class logits for latent batch ``z``."""
        logits = self.net(z)
        return logits


def _resolve_hidden_dim(
    key: str, config: dict, metrics: dict, dims: dict | None = None
) -> int:
    """Return a baseline's hidden width: ``dims[key]``, then ``metrics[key]``,
    then ``config["hidden_dim"]``. Falsy entries (missing/None/0) fall through."""
    return int((dims or {}).get(key) or metrics.get(key) or config["hidden_dim"])


def _select_baseline_weights(
    key: str,
    state: dict,
    bench_state: dict | None,
    bench_dims: dict | None,
    disabled: bool,
) -> tuple[dict | None, str | None, dict | None]:
    """Choose the state dict for baseline ``key`` (checkpoint first, then benchmarks).

    Returns ``(weights, source, dims)``; ``dims`` is ``bench_dims`` only when
    the weights come from the benchmarks file, otherwise ``None``. All three
    are ``None`` when the baseline is disabled or has no stored weights.
    """
    if disabled:
        return None, None, None
    if state.get(key) is not None:
        return state[key], "checkpoint", None
    if bench_state is not None and bench_state.get(key) is not None:
        return bench_state[key], "benchmarks", bench_dims
    return None, None, None


def _build_baseline(model_cls, weights: dict, device: str, **kwargs):
    """Instantiate ``model_cls(**kwargs)`` on ``device``, load ``weights``, set eval mode."""
    model = model_cls(**kwargs).to(device)
    model.load_state_dict(weights)
    model.eval()
    return model


def _prepare_models(
    config: dict,
    state: dict,
    metrics: dict,
    bench_state: dict | None,
    bench_dims: dict | None,
    device: str,
) -> tuple[dict[str, object], dict[str, str | None]]:
    """Instantiate the TopoEncoder and the optional StandardVQ / VanillaAE baselines.

    Baseline weights come from the checkpoint ``state`` when present, otherwise
    from the benchmarks ``bench_state``. A baseline is skipped when its
    ``disable_vq`` / ``disable_ae`` config flag is set or no weights exist.
    Hidden widths for checkpoint weights come from ``metrics``; for benchmark
    weights, ``bench_dims`` is consulted first. Missing, ``None`` or ``0``
    values fall back to ``config["hidden_dim"]`` in both cases (previously a
    stored ``None`` in ``metrics`` reached the constructor).

    Returns:
        ``(models, sources)``: ``models`` maps ``"topo"``/``"std"``/``"ae"`` to
        eval-mode modules (``None`` for absent baselines); ``sources`` maps
        ``"std"``/``"ae"`` to ``"checkpoint"``, ``"benchmarks"`` or ``None``.
    """
    # Constructor arguments shared by both baselines.
    baseline_kwargs = {
        "input_dim": config["input_dim"],
        "latent_dim": config["latent_dim"],
        "use_attention": bool(config.get("baseline_attn", False)),
        "attn_tokens": int(config.get("baseline_attn_tokens", 4)),
        "attn_dim": int(config.get("baseline_attn_dim", 32)),
        "attn_heads": int(config.get("baseline_attn_heads", 4)),
        "attn_dropout": float(config.get("baseline_attn_dropout", 0.0)),
        "vision_preproc": bool(config.get("baseline_vision_preproc", False)),
        "vision_in_channels": int(config.get("vision_in_channels", 0)),
        "vision_height": int(config.get("vision_height", 0)),
        "vision_width": int(config.get("vision_width", 0)),
    }

    # Non-positive sizes and missing flags are normalized to the values the
    # TopoEncoder expects.
    bundle_size = config.get("bundle_size")
    if isinstance(bundle_size, int) and bundle_size <= 0:
        bundle_size = None
    soft_equiv_bundle_size = config.get("soft_equiv_bundle_size")
    if isinstance(soft_equiv_bundle_size, int) and soft_equiv_bundle_size <= 0:
        soft_equiv_bundle_size = None
    soft_equiv_soft_assign = config.get("soft_equiv_soft_assign")
    if soft_equiv_soft_assign is None:
        soft_equiv_soft_assign = True
    soft_equiv_temperature = config.get("soft_equiv_temperature")
    if soft_equiv_temperature is None:
        soft_equiv_temperature = 1.0
    covariant_attn = config.get("covariant_attn")
    if covariant_attn is None:
        covariant_attn = True

    model_kwargs = {
        "input_dim": config["input_dim"],
        "hidden_dim": config["hidden_dim"],
        "latent_dim": config["latent_dim"],
        "num_charts": config["num_charts"],
        "codes_per_chart": config["codes_per_chart"],
        "bundle_size": bundle_size,
        "covariant_attn": covariant_attn,
        "covariant_attn_tensorization": config.get("covariant_attn_tensorization", "full"),
        "covariant_attn_rank": config.get("covariant_attn_rank", 8),
        "covariant_attn_tau_min": config.get("covariant_attn_tau_min", 1e-2),
        "covariant_attn_denom_min": config.get("covariant_attn_denom_min", 1e-3),
        "covariant_attn_use_transport": config.get("covariant_attn_use_transport", True),
        "covariant_attn_transport_eps": config.get("covariant_attn_transport_eps", 1e-3),
        "vision_preproc": config.get("vision_preproc", False),
        "vision_in_channels": config.get("vision_in_channels", 0),
        "vision_height": config.get("vision_height", 0),
        "vision_width": config.get("vision_width", 0),
        "vision_num_rotations": config.get("vision_num_rotations", 8),
        "vision_kernel_size": config.get("vision_kernel_size", 5),
        "vision_use_reflections": config.get("vision_use_reflections", False),
        "vision_norm_nonlinearity": config.get("vision_norm_nonlinearity", "n_sigmoid"),
        "vision_norm_bias": config.get("vision_norm_bias", True),
        "soft_equiv_metric": config.get("soft_equiv_metric", False),
        "soft_equiv_bundle_size": soft_equiv_bundle_size,
        "soft_equiv_hidden_dim": config.get("soft_equiv_hidden_dim", 64),
        "soft_equiv_use_spectral_norm": config.get("soft_equiv_use_spectral_norm", True),
        "soft_equiv_zero_self_mixing": config.get("soft_equiv_zero_self_mixing", False),
        "soft_equiv_soft_assign": soft_equiv_soft_assign,
        "soft_equiv_temperature": soft_equiv_temperature,
    }

    model_atlas = TopoEncoderPrimitives(**model_kwargs).to(device)
    model_atlas.load_state_dict(state["atlas"])
    model_atlas.eval()

    sources: dict[str, str | None] = {"std": None, "ae": None}

    model_std = None
    std_weights, std_source, std_dims = _select_baseline_weights(
        "std", state, bench_state, bench_dims, disabled=bool(config.get("disable_vq", False))
    )
    if std_weights is not None:
        model_std = _build_baseline(
            StandardVQ,
            std_weights,
            device,
            hidden_dim=_resolve_hidden_dim("std_hidden_dim", config, metrics, std_dims),
            num_codes=config["num_codes_standard"],
            **baseline_kwargs,
        )
        sources["std"] = std_source

    model_ae = None
    ae_weights, ae_source, ae_dims = _select_baseline_weights(
        "ae", state, bench_state, bench_dims, disabled=bool(config.get("disable_ae", False))
    )
    if ae_weights is not None:
        model_ae = _build_baseline(
            VanillaAE,
            ae_weights,
            device,
            hidden_dim=_resolve_hidden_dim("ae_hidden_dim", config, metrics, ae_dims),
            **baseline_kwargs,
        )
        sources["ae"] = ae_source

    models = {"topo": model_atlas, "std": model_std, "ae": model_ae}
    return models, sources


def _prepare_classifiers(
    config: dict,
    state: dict,
    num_classes: int | None,
    device: str,
) -> dict[str, nn.Module | None]:
    """Rebuild the classification heads saved in ``state``.

    Returns a dict keyed ``"topo"``/``"std"``/``"ae"``; a head is ``None``
    when ``num_classes`` is falsy or its weights (``classifier``,
    ``classifier_std``, ``classifier_ae``) are absent. Loaded heads are moved
    to ``device`` and put in eval mode.
    """
    classifiers: dict[str, nn.Module | None] = dict.fromkeys(("topo", "std", "ae"))
    if not num_classes:
        return classifiers

    bundle_size = config.get("classifier_bundle_size") or None

    topo_weights = state.get("classifier")
    if topo_weights is not None:
        head = InvariantChartClassifier(
            num_charts=config["num_charts"],
            num_classes=int(num_classes),
            latent_dim=config["latent_dim"],
            bundle_size=bundle_size,
        ).to(device)
        head.load_state_dict(topo_weights)
        head.eval()
        classifiers["topo"] = head

    # The two baselines share the same MLP head architecture.
    for model_key, state_key in (("std", "classifier_std"), ("ae", "classifier_ae")):
        weights = state.get(state_key)
        if weights is None:
            continue
        head = BaselineClassifier(
            latent_dim=config["latent_dim"],
            num_classes=int(num_classes),
            hidden_dim=int(config.get("hidden_dim", 32)),
        ).to(device)
        head.load_state_dict(weights)
        head.eval()
        classifiers[model_key] = head

    return classifiers


def _compute_topo_latents(
    model: TopoEncoderPrimitives,
    x: torch.Tensor,
    batch_size: int,
    device: str,
    classifier: nn.Module | None = None,
    labels: np.ndarray | None = None,
) -> dict[str, np.ndarray]:
    """Run the TopoEncoder over ``x`` in batches and gather per-sample outputs.

    Returns a dict with ``z`` (geometric latent ``z_geo``), ``chart``,
    ``symbol`` (code within chart) and ``recon_mse`` (per-row mean squared
    reconstruction error). When both ``classifier`` and ``labels`` are given,
    ``pred`` and ``correct`` (0/1 int64) are added as well.
    """
    gathered: dict[str, list[np.ndarray]] = {"z": [], "chart": [], "symbol": [], "recon_mse": []}
    pred_chunks: list[np.ndarray] = []
    hit_chunks: list[np.ndarray] = []
    score = classifier is not None and labels is not None

    model.eval()
    with torch.no_grad():
        for lo, hi, chunk in _batch_iter(x, batch_size):
            chunk = chunk.to(device)
            # The encoder returns a 10-tuple; only a subset is used here.
            (
                k_chart,
                k_code,
                _unused_z_n,
                z_tex,
                router_weights,
                z_geo,
                _unused_vq_loss,
                _unused_indices,
                _unused_z_n_all,
                _unused_c_bar,
            ) = model.encoder(chunk)
            recon, _ = model.decoder(z_geo, z_tex, None)
            per_row_mse = ((recon - chunk) ** 2).mean(dim=1)

            gathered["z"].append(z_geo.cpu().numpy())
            gathered["chart"].append(k_chart.cpu().numpy())
            gathered["symbol"].append(k_code.cpu().numpy())
            gathered["recon_mse"].append(per_row_mse.cpu().numpy())

            if score:
                guesses = classifier(router_weights, z_geo).argmax(dim=1).cpu().numpy()
                pred_chunks.append(guesses)
                hit_chunks.append(guesses == labels[lo:hi])

    result = {name: np.concatenate(parts, axis=0) for name, parts in gathered.items()}
    if pred_chunks:
        result["pred"] = np.concatenate(pred_chunks, axis=0)
        result["correct"] = np.concatenate(hit_chunks, axis=0).astype(np.int64)
    return result


def _compute_std_latents(
    model: StandardVQ,
    x: torch.Tensor,
    batch_size: int,
    device: str,
    classifier: nn.Module | None = None,
    labels: np.ndarray | None = None,
    num_charts: int | None = None,
) -> dict[str, np.ndarray]:
    """Encode ``x`` with the StandardVQ baseline and quantize each latent.

    Each encoder output is snapped to its nearest codebook row (squared
    Euclidean distance), and the straight-through latent is decoded to measure
    reconstruction error. Returns ``z_e``, ``z_q``, ``code`` and ``recon_mse``;
    adds ``chart_like`` (``code % num_charts``) when ``num_charts`` is truthy,
    and ``pred``/``correct`` when both ``classifier`` and ``labels`` are given.
    """
    gathered: dict[str, list[np.ndarray]] = {"z_e": [], "z_q": [], "code": [], "recon_mse": []}
    pred_chunks: list[np.ndarray] = []
    hit_chunks: list[np.ndarray] = []
    score = classifier is not None and labels is not None

    model.eval()
    with torch.no_grad():
        codebook = model.embeddings.weight
        for lo, hi, chunk in _batch_iter(x, batch_size):
            chunk = chunk.to(device)
            z_e = model.encoder(chunk)
            # ||z - e||^2 expanded as |z|^2 + |e|^2 - 2 z.e for every code e.
            dist = (
                (z_e**2).sum(dim=1, keepdim=True)
                + (codebook**2).sum(dim=1).unsqueeze(0)
                - 2.0 * torch.matmul(z_e, codebook.t())
            )
            codes = torch.argmin(dist, dim=1)
            z_q = codebook[codes]
            recon = model.decoder(z_e + (z_q - z_e).detach())
            per_row_mse = ((recon - chunk) ** 2).mean(dim=1)

            gathered["z_e"].append(z_e.cpu().numpy())
            gathered["z_q"].append(z_q.cpu().numpy())
            gathered["code"].append(codes.cpu().numpy())
            gathered["recon_mse"].append(per_row_mse.cpu().numpy())

            if score:
                guesses = classifier(z_e).argmax(dim=1).cpu().numpy()
                pred_chunks.append(guesses)
                hit_chunks.append(guesses == labels[lo:hi])

    result = {name: np.concatenate(parts, axis=0) for name, parts in gathered.items()}
    if num_charts:
        result["chart_like"] = result["code"] % int(num_charts)
    if pred_chunks:
        result["pred"] = np.concatenate(pred_chunks, axis=0)
        result["correct"] = np.concatenate(hit_chunks, axis=0).astype(np.int64)
    return result


def _compute_ae_latents(
    model: VanillaAE,
    x: torch.Tensor,
    batch_size: int,
    device: str,
    classifier: nn.Module | None = None,
    labels: np.ndarray | None = None,
) -> dict[str, np.ndarray]:
    """Encode and reconstruct ``x`` with the VanillaAE baseline in batches.

    Returns ``z`` and ``recon_mse`` (per-row mean squared error), plus
    ``pred``/``correct`` when both ``classifier`` and ``labels`` are given.
    """
    latent_chunks: list[np.ndarray] = []
    mse_chunks: list[np.ndarray] = []
    pred_chunks: list[np.ndarray] = []
    hit_chunks: list[np.ndarray] = []
    score = classifier is not None and labels is not None

    model.eval()
    with torch.no_grad():
        for lo, hi, chunk in _batch_iter(x, batch_size):
            chunk = chunk.to(device)
            latent = model.encoder(chunk)
            recon = model.decoder(latent)
            latent_chunks.append(latent.cpu().numpy())
            mse_chunks.append(((recon - chunk) ** 2).mean(dim=1).cpu().numpy())

            if score:
                guesses = classifier(latent).argmax(dim=1).cpu().numpy()
                pred_chunks.append(guesses)
                hit_chunks.append(guesses == labels[lo:hi])

    result = {
        "z": np.concatenate(latent_chunks, axis=0),
        "recon_mse": np.concatenate(mse_chunks, axis=0),
    }
    if pred_chunks:
        result["pred"] = np.concatenate(pred_chunks, axis=0)
        result["correct"] = np.concatenate(hit_chunks, axis=0).astype(np.int64)
    return result


def _coerce_color(values: np.ndarray | None) -> np.ndarray | None:
    """Turn color values into a flat numeric array (or pass ``None`` through).

    - Integer/unsigned/float arrays are returned as-is (flattened if needed).
    - Booleans become 0/1. Factorizing them by order of appearance could
      invert the mapping, e.g. ``[True, False]`` became ``[0, 1]``.
    - Anything else (e.g. string class names) is encoded as category codes in
      sorted order. A category therefore keeps the same color no matter which
      sample happens to come first. Unorderable mixtures fall back to
      order-of-appearance codes. Missing values map to -1 (``pd.factorize``).
    """
    if values is None:
        return None
    arr = np.asarray(values)
    if arr.ndim > 1:
        arr = arr.reshape(-1)
    if arr.dtype.kind == "b":
        return arr.astype(np.int64)
    if arr.dtype.kind in "iuf":
        return arr
    try:
        codes, _ = pd.factorize(arr, sort=True)
    except TypeError:
        codes, _ = pd.factorize(arr)
    return codes


def _build_metrics_table(
    metrics: dict,
    latents: dict[str, dict[str, np.ndarray] | None],
    models: dict[str, object],
    sources: dict[str, str | None],
) -> pd.DataFrame:
    """Summarize sampled metrics next to the values stored in the checkpoint.

    One row per available model (TopoEncoder, StandardVQ, VanillaAE). Sample
    MSE is the mean of ``recon_mse``; sample perplexity comes from each
    model's ``compute_perplexity`` on its chart/code assignments. VanillaAE
    has no perplexity.
    """

    def _row(label, sample, mse_key, ami_key, perplexity_sample, perplexity_key):
        return {
            "model": label,
            "mse_sample": float(np.mean(sample["recon_mse"])),
            "mse_checkpoint": metrics.get(mse_key),
            "ami_checkpoint": metrics.get(ami_key),
            "perplexity_sample": perplexity_sample,
            "perplexity_checkpoint": metrics.get(perplexity_key) if perplexity_key else None,
        }

    table_rows = []

    topo = latents.get("topo")
    if topo is not None:
        topo_model = models.get("topo")
        topo_ppl = (
            None
            if topo_model is None
            else topo_model.compute_perplexity(torch.from_numpy(topo["chart"]))
        )
        table_rows.append(
            _row("TopoEncoder", topo, "mse_atlas", "ami_atlas", topo_ppl, "atlas_perplexity")
        )

    std = latents.get("std")
    if std is not None and models.get("std") is not None:
        std_ppl = models["std"].compute_perplexity(torch.from_numpy(std["code"]))
        std_label = f"StandardVQ ({sources.get('std') or 'unknown'})"
        table_rows.append(_row(std_label, std, "mse_std", "ami_std", std_ppl, "std_perplexity"))

    ae = latents.get("ae")
    if ae is not None:
        ae_label = f"VanillaAE ({sources.get('ae') or 'unknown'})"
        table_rows.append(_row(ae_label, ae, "mse_ae", "ami_ae", None, None))

    return pd.DataFrame(table_rows)


In [3]:
# Values offered by the "Color by" selector, in display order.
COLOR_OPTIONS = (
    "label chart meso_symbol code correctness recon_mse latent_norm none".split()
)
# Plotly colorscale names offered by the "Colorscale" selector.
COLOR_SCALES = (
    "Viridis Plasma Cividis Turbo Magma Inferno IceFire Greys Blues Reds".split()
)
# Two-color map for 0/1 correctness values: red for wrong, green for right.
CORRECTNESS_CMAP = ["#d62728", "#2ca02c"]

# Internal model key -> display label (also the "Show models" button labels).
MODEL_LABELS = {"topo": "TopoEncoder", "std": "StandardVQ", "ae": "VanillaAE"}

# Mutable state shared between the load callback and the plotting functions.
# Container entries each get their own fresh dict; scalar entries start as None.
data_cache = {"loaded": False, "labels": None}
data_cache.update(
    {
        key: {}
        for key in ("latents", "config", "metrics", "sources", "classifiers", "models", "pca")
    }
)
data_cache.update(
    dict.fromkeys(
        (
            "codes_per_chart",
            "checkpoint",
            "x_sample",
            "labels_sample",
            "sample_indices",
            "image_shape",
            "class_names",
        )
    )
)

default_dir = _default_experiment_dir()

# --- Experiment / checkpoint selection ---
experiment_dir = pn.widgets.TextInput(
    name="Experiment folder",
    value=str(default_dir),
    placeholder="outputs/...",
)
checkpoint_select = pn.widgets.Select(name="Checkpoint", options=[])
refresh_button = pn.widgets.Button(name="Scan checkpoints", button_type="primary")
load_button = pn.widgets.Button(name="Load + Analyze", button_type="success")

# --- Sampling / inference ---
max_samples = pn.widgets.IntInput(name="Max samples (0=all)", value=3000, step=500)
batch_size = pn.widgets.IntInput(name="Batch size", value=512, step=128)
sample_seed = pn.widgets.IntInput(name="Sample seed", value=7, step=1)

# --- Plot styling ---
color_by = pn.widgets.Select(name="Color by", options=COLOR_OPTIONS, value="label")
colorscale = pn.widgets.Select(name="Colorscale", options=COLOR_SCALES, value="Viridis")
marker_size = pn.widgets.FloatSlider(
    name="Marker size",
    value=4.0,
    start=1.0,
    end=12.0,
    step=0.5,
)
marker_opacity = pn.widgets.FloatSlider(
    name="Marker opacity",
    value=0.7,
    start=0.1,
    end=1.0,
    step=0.05,
)
show_colorbar = pn.widgets.Checkbox(name="Show colorbar", value=True)
vq_latent = pn.widgets.RadioButtonGroup(
    name="StandardVQ latent",
    options=["quantized", "pre-quantized"],
    value="quantized",
)
show_models = pn.widgets.CheckButtonGroup(
    name="Show models",
    options=list(MODEL_LABELS.values()),
    value=list(MODEL_LABELS.values()),
)

# --- Output panes updated by the load callback ---
status_pane = pn.pane.Markdown("Load an experiment to begin.")
metrics_pane = pn.pane.DataFrame(pd.DataFrame())


def _refresh_checkpoints(_=None) -> None:
    """Rescan the experiment folder and repopulate the checkpoint selector.

    Preselects the first checkpoint whose name contains ``final``, otherwise
    the last one in training order; clears the selection when none exist.
    """
    folder = Path(experiment_dir.value).expanduser()
    names = [ckpt.name for ckpt in _collect_checkpoints(folder)]
    checkpoint_select.options = names
    if not names:
        checkpoint_select.value = None
        return
    finals = [name for name in names if "final" in name]
    checkpoint_select.value = finals[0] if finals else names[-1]


def _resolve_latent(model_key: str, vq_choice: str) -> np.ndarray | None:
    """Return the cached 3-D latent for ``model_key``, or ``None`` if absent.

    For StandardVQ, ``vq_choice == "pre-quantized"`` selects the encoder
    output ``z_e_3d``; anything else selects the quantized ``z_q_3d``.
    """
    entry = data_cache["latents"].get(model_key)
    if entry is None:
        return None
    if model_key != "std":
        return entry.get("z_3d")
    field = "z_e_3d" if vq_choice == "pre-quantized" else "z_q_3d"
    return entry.get(field)


def _resolve_color(model_key: str, color_choice: str, vq_choice: str) -> np.ndarray | None:
    """Return the per-point color values for ``model_key`` under ``color_choice``.

    ``None`` means the model is not loaded, the choice is ``"none"``, or the
    choice does not apply to this model (e.g. charts for VanillaAE).
    Unrecognized choices fall back to the class labels.
    """
    labels = data_cache.get("labels")
    entry = data_cache["latents"].get(model_key)
    if entry is None:
        return None

    # Choices whose cached field name depends on the model; models not
    # listed have no such field.
    per_model_field = {
        "chart": {"topo": "chart", "std": "chart_like"},
        "meso_symbol": {"topo": "symbol", "std": "code"},
        "code": {"topo": "code_global", "std": "code"},
    }
    # Choices stored under the same field name for every model.
    shared_field = {"correctness": "correct", "recon_mse": "recon_mse"}

    if color_choice == "label":
        return labels
    if color_choice in per_model_field:
        field = per_model_field[color_choice].get(model_key)
        return None if field is None else entry.get(field)
    if color_choice in shared_field:
        return entry.get(shared_field[color_choice])
    if color_choice == "latent_norm":
        z = _resolve_latent(model_key, vq_choice)
        return None if z is None else np.linalg.norm(z, axis=1)
    if color_choice == "none":
        return None
    return labels


def _plot_model(
    model_key: str,
    color_choice: str,
    colorscale_choice: str,
    size_value: float,
    opacity_value: float,
    vq_choice: str,
    show_colorbar_value: bool,
):
    """Render one model's 3-D latent scatter, or a Markdown pane explaining why not.

    Args:
        model_key: ``"topo"``, ``"std"`` or ``"ae"``.
        color_choice: one of ``COLOR_OPTIONS``.
        colorscale_choice: Plotly colorscale name (ignored for correctness).
        size_value, opacity_value: marker size and alpha.
        vq_choice: StandardVQ latent variant; also shown in its title.
        show_colorbar_value: whether to draw a colorbar (never for correctness).

    Returns:
        A styled ``hv.Scatter3D`` or a ``pn.pane.Markdown`` placeholder.
    """
    if not data_cache["loaded"]:
        return pn.pane.Markdown("Load an experiment to render plots.")

    z = _resolve_latent(model_key, vq_choice)
    if z is None:
        return pn.pane.Markdown(f"{MODEL_LABELS[model_key]} not available.")

    kdims = ["x", "y", "z"]
    df = pd.DataFrame({"x": z[:, 0], "y": z[:, 1], "z": z[:, 2]})

    if color_choice == "none":
        scatter = hv.Scatter3D(df, kdims=kdims).opts(color="#1f77b4")
    else:
        color_values = _coerce_color(_resolve_color(model_key, color_choice, vq_choice))
        if color_values is None:
            return pn.pane.Markdown(
                f"{MODEL_LABELS[model_key]}: `{color_choice}` coloring unavailable."
            )
        df["color"] = color_values
        scatter = hv.Scatter3D(df, kdims=kdims, vdims=["color"])
        if color_choice == "correctness":
            # Pin the color range to the 0/1 encoding. Otherwise an all-correct
            # (or all-wrong) sample collapses the auto range, so points no
            # longer map reliably to red (wrong) / green (right).
            scatter = scatter.opts(
                color="color",
                cmap=CORRECTNESS_CMAP,
                clim=(0, 1),
                colorbar=False,
            )
        else:
            scatter = scatter.opts(
                color="color",
                cmap=colorscale_choice,
                colorbar=show_colorbar_value,
            )

    title = MODEL_LABELS[model_key]
    if model_key == "std":
        title = f"{title} ({vq_choice})"

    return scatter.opts(
        width=420,
        height=420,
        size=size_value,
        alpha=opacity_value,
        title=title,
        xlabel="z1",
        ylabel="z2",
        zlabel="z3",
    )


@pn.depends(
    show_models.param.value,
    color_by.param.value,
    colorscale.param.value,
    marker_size.param.value,
    marker_opacity.param.value,
    vq_latent.param.value,
    show_colorbar.param.value,
)
def plot_row(
    show_models_value,
    color_choice,
    colorscale_choice,
    size_value,
    opacity_value,
    vq_choice,
    show_colorbar_value,
):
    """Lay out one scatter per selected model (TopoEncoder, StandardVQ, VanillaAE order).

    Re-runs whenever any styling widget changes; shows a placeholder until an
    experiment has been loaded.
    """
    if not data_cache["loaded"]:
        return pn.pane.Markdown("Load an experiment to render plots.")

    shared_args = (
        color_choice,
        colorscale_choice,
        size_value,
        opacity_value,
        vq_choice,
        show_colorbar_value,
    )
    panels = [
        pn.panel(_plot_model(model_key, *shared_args))
        for model_key, display_name in MODEL_LABELS.items()
        if display_name in show_models_value
    ]
    return pn.FlexBox(*panels, flex_wrap="wrap", gap="16px")


def _load_experiment(_=None) -> None:
    """Load the selected checkpoint, encode a test subset, and refresh the dashboard.

    Registered as the "Load + Analyze" button callback; the ignored positional argument
    receives the click event. On success ``data_cache`` is repopulated with models,
    classifiers, latents, PCA fits and the sampled test data, and the metrics pane,
    model selector and status pane are updated. On a recoverable failure (missing
    folder, no checkpoint selected, checkpoint file missing, checkpoint without a
    ``"data"`` entry) the status pane shows the reason and ``data_cache["loaded"]`` is
    set to ``False``.
    """

    def _fail(message: str) -> None:
        # Surface the problem in the UI and make the plot callbacks show placeholders.
        status_pane.object = message
        data_cache["loaded"] = False

    exp_dir = Path(experiment_dir.value).expanduser()
    if not exp_dir.exists():
        _fail(f"Experiment folder not found: `{exp_dir}`")
        return

    if not checkpoint_select.value:
        _fail("No checkpoint selected.")
        return

    checkpoint_path = exp_dir / checkpoint_select.value
    try:
        checkpoint = _load_checkpoint(checkpoint_path)
    except FileNotFoundError:
        # The file can disappear between "Scan checkpoints" and "Load + Analyze";
        # report it in the status pane instead of raising out of the button callback.
        _fail(f"Checkpoint not found: `{checkpoint_path}`")
        return
    if "data" not in checkpoint:
        _fail(f"Checkpoint has no data: `{checkpoint_path.name}`")
        return

    config = checkpoint["config"]
    metrics = checkpoint.get("metrics", {})
    state = checkpoint["state"]

    # Baseline (StandardVQ / VanillaAE) weights live in an optional sibling file.
    benchmarks = _load_benchmarks(checkpoint_path)
    bench_state = benchmarks.get("state", {}) if benchmarks else None
    bench_dims = benchmarks.get("dims", {}) if benchmarks else None

    device = "cpu"
    models, sources = _prepare_models(config, state, metrics, bench_state, bench_dims, device)

    data = checkpoint["data"]
    class_names, image_shape = _dataset_specs(config, data)
    x_test = _as_tensor(data["X_test"])
    all_labels = _as_numpy(data.get("labels_test"))
    x_test, labels, sample_indices = _sample_data(
        x_test, all_labels, max_samples.value, sample_seed.value
    )

    # Infer the class count from the FULL label set: a random subsample can miss the
    # highest class id, which would build classifiers with too few outputs. Empty
    # label arrays fall back to the config value instead of crashing on ``.max()``.
    num_classes = config.get("num_classes")
    if all_labels is not None and all_labels.size > 0:
        num_classes = int(all_labels.max()) + 1
    classifiers = _prepare_classifiers(config, state, num_classes, device)

    latents = {}
    pca = {}
    topo = _compute_topo_latents(
        models["topo"],
        x_test,
        batch_size.value,
        device,
        classifier=classifiers.get("topo"),
        labels=labels,
    )
    topo_pca = _pca_fit(topo["z"])
    topo["z_3d"] = _pca_apply(topo["z"], *topo_pca)
    codes_per_chart = int(config.get("codes_per_chart", 1))
    # Flatten (chart, symbol) into one global code id.
    topo["code_global"] = topo["chart"] * codes_per_chart + topo["symbol"]
    latents["topo"] = topo
    pca["topo"] = topo_pca

    if models["std"] is not None:
        std = _compute_std_latents(
            models["std"],
            x_test,
            batch_size.value,
            device,
            classifier=classifiers.get("std"),
            labels=labels,
            num_charts=config.get("num_charts"),
        )
        # Separate PCA fits for pre-quantized (z_e) and quantized (z_q) latents.
        std_pca_e = _pca_fit(std["z_e"])
        std_pca_q = _pca_fit(std["z_q"])
        std["z_e_3d"] = _pca_apply(std["z_e"], *std_pca_e)
        std["z_q_3d"] = _pca_apply(std["z_q"], *std_pca_q)
        latents["std"] = std
        pca["std"] = {"z_e": std_pca_e, "z_q": std_pca_q}
    else:
        latents["std"] = None

    if models["ae"] is not None:
        ae = _compute_ae_latents(
            models["ae"],
            x_test,
            batch_size.value,
            device,
            classifier=classifiers.get("ae"),
            labels=labels,
        )
        ae_pca = _pca_fit(ae["z"])
        ae["z_3d"] = _pca_apply(ae["z"], *ae_pca)
        latents["ae"] = ae
        pca["ae"] = ae_pca
    else:
        latents["ae"] = None

    data_cache.update(
        {
            "loaded": True,
            "labels": labels,
            "latents": latents,
            "config": config,
            "metrics": metrics,
            "sources": sources,
            "classifiers": classifiers,
            "models": models,
            "pca": pca,
            "codes_per_chart": codes_per_chart,
            "checkpoint": checkpoint_path,
            "x_sample": x_test,
            "labels_sample": labels,
            "sample_indices": sample_indices,
            "image_shape": image_shape,
            "class_names": class_names,
        }
    )

    metrics_pane.object = _build_metrics_table(metrics, latents, models, sources)

    # Restrict the model selector to models that were actually loaded, keeping the
    # user's current selection where possible.
    available = ["TopoEncoder"]
    if latents.get("std") is not None:
        available.append("StandardVQ")
    if latents.get("ae") is not None:
        available.append("VanillaAE")
    show_models.options = available
    show_models.value = [name for name in show_models.value if name in available] or available

    dataset = config.get("dataset", "unknown")
    classifier_note = (
        f"topo={classifiers.get('topo') is not None}, "
        f"std={classifiers.get('std') is not None}, "
        f"ae={classifiers.get('ae') is not None}"
    )
    status_pane.object = (
        f"**Checkpoint:** `{checkpoint_path}`\n"
        f"- dataset: `{dataset}`\n"
        f"- samples: {x_test.shape[0]}\n"
        f"- latent_dim: {config.get('latent_dim')}\n"
        f"- charts: {config.get('num_charts')}\n"
        f"- benchmarks: std={sources.get('std')}, ae={sources.get('ae')}\n"
        f"- classifiers: {classifier_note}"
    )


# Wire the load controls. The checkpoint refresh runs on the button, on every edit
# of the experiment-folder path, and once immediately for the initial folder value.
refresh_button.on_click(_refresh_checkpoints)
load_button.on_click(_load_experiment)
experiment_dir.param.watch(lambda event: _refresh_checkpoints(), "value")
_refresh_checkpoints()

# Sidebar: load, sampling and display controls grouped under Markdown headings.
controls = pn.Column(
    "### Load",
    experiment_dir,
    refresh_button,
    checkpoint_select,
    load_button,
    "### Sampling",
    max_samples,
    batch_size,
    sample_seed,
    "### Display",
    color_by,
    colorscale,
    marker_size,
    marker_opacity,
    show_colorbar,
    vq_latent,
    show_models,
    width=340,
)

# Top row: sidebar next to the status/metrics panes; below it the reactive plot row.
dashboard = pn.Column(
    pn.Row(controls, pn.Column(status_pane, metrics_pane, sizing_mode="stretch_width")),
    plot_row,
    sizing_mode="stretch_width",
)

# Bare expression so the notebook renders the layout as the cell output.
dashboard


In [5]:
from PIL import Image

# Inverse of MODEL_LABELS (defined elsewhere in the notebook): display label -> model key.
IMAGE_MODEL_KEYS = {label: key for key, label in MODEL_LABELS.items()}

# Sample selection. `image_index` is a position inside the pool of samples for the
# selected class, not a raw row index into the cached sample.
image_model = pn.widgets.Select(name="Model", options=list(IMAGE_MODEL_KEYS.keys()))
image_class = pn.widgets.Select(name="Class", options=["all"], value="all")
image_index = pn.widgets.IntInput(name="Sample index", value=0, start=0, end=0)
image_random = pn.widgets.Button(name="Random sample", button_type="primary")

# Perturbations. Patch slider bounds default to a 28x28 image and are re-synced to the
# loaded dataset's image shape by _ensure_image_controls().
image_rotation = pn.widgets.Select(name="Rotate", options=[0, 90, 180, 270], value=0)
image_noise = pn.widgets.FloatSlider(
    name="Noise sigma", value=0.0, start=0.0, end=0.5, step=0.01
)
image_patch = pn.widgets.IntSlider(name="Patch size", value=0, start=0, end=28, step=1)
image_patch_x = pn.widgets.IntSlider(name="Patch x", value=0, start=0, end=27, step=1)
image_patch_y = pn.widgets.IntSlider(name="Patch y", value=0, start=0, end=27, step=1)
# Which StandardVQ latent (codebook vector vs encoder output) is plotted.
image_vq_latent = pn.widgets.RadioButtonGroup(
    name="StandardVQ latent", options=["quantized", "pre-quantized"], value="quantized"
)
image_show_latent = pn.widgets.Checkbox(name="Show latent cloud", value=True)

# Shared status pane; render_image_dashboard rewrites its text on every render.
image_status = pn.pane.Markdown("Load an experiment in the first dashboard.")


def _available_models() -> list[str]:
    """Return display labels of the models whose latents are cached, in fixed order.

    When no latents are cached at all, every label from ``IMAGE_MODEL_KEYS`` is returned
    so the selector is never empty.
    """
    cached = data_cache.get("latents", {})
    present = [
        label
        for label, key in (("TopoEncoder", "topo"), ("StandardVQ", "std"), ("VanillaAE", "ae"))
        if cached.get(key) is not None
    ]
    return present or list(IMAGE_MODEL_KEYS.keys())


def _ensure_image_controls() -> None:
    """Sync the image-explorer widgets with the currently loaded experiment.

    Does nothing until an experiment is loaded. It then:

    - restricts the model and class selectors to what is available, resetting their
      values when the old selection disappeared;
    - resizes the patch sliders to the dataset's image shape, defaulting to (28, 28, 1);
    - clamps each patch slider's value to its new upper bound.
    """
    if not data_cache.get("loaded"):
        return

    models_now = _available_models()
    if image_model.options != models_now:
        image_model.options = models_now
        if image_model.value not in models_now:
            image_model.value = models_now[0]

    names = data_cache.get("class_names") or []
    class_options = ["all", *names] if names else ["all"]
    if image_class.options != class_options:
        image_class.options = class_options
        if image_class.value not in class_options:
            image_class.value = class_options[0]

    rows, cols, _channels = data_cache.get("image_shape") or (28, 28, 1)
    # Update every bound first, then clamp values, mirroring the original ordering.
    image_patch.end = max(1, min(rows, cols))
    image_patch_x.end = max(0, cols - 1)
    image_patch_y.end = max(0, rows - 1)
    for slider in (image_patch, image_patch_x, image_patch_y):
        if slider.value > slider.end:
            slider.value = slider.end


def _sample_pool(labels: np.ndarray | None, class_value: str) -> np.ndarray:
    if labels is None:
        return np.arange(data_cache["x_sample"].shape[0])
    if class_value == "all":
        return np.arange(labels.shape[0])
    class_names = data_cache.get("class_names") or []
    if class_value in class_names:
        class_idx = class_names.index(class_value)
    else:
        class_idx = int(class_value)
    return np.where(labels == class_idx)[0]


def _sync_sample_index() -> np.ndarray:
    """Fit the sample-index widget to the pool of the selected class and return the pool.

    For an empty pool the widget's bound and value are both reset to 0. Otherwise the
    bound becomes the last valid position, and the value is lowered onto it if it
    now exceeds it.
    """
    pool = _sample_pool(data_cache.get("labels_sample"), image_class.value)
    if pool.size:
        last_position = int(pool.size - 1)
        image_index.end = last_position
        if image_index.value > image_index.end:
            image_index.value = image_index.end
    else:
        image_index.end = 0
        image_index.value = 0
    return pool


def _choose_random_sample(_=None) -> None:
    """Button callback: jump the sample-index widget to a random position in the pool.

    This is a no-op before an experiment is loaded or when the selected class has no
    samples. A fresh, unseeded generator is used on each click.
    """
    if data_cache.get("loaded"):
        pool = _sync_sample_index()
        if pool.size:
            image_index.value = int(np.random.default_rng().integers(0, pool.size))


def _to_image_array(x: np.ndarray | torch.Tensor, image_shape: tuple[int, int, int]) -> np.ndarray:
    """Reshape one stored sample into a displayable image array.

    ``image_shape`` is ``(height, width, channels)``. Single-channel images come back
    as ``(height, width)``, all others as ``(height, width, channels)``. The sample's
    values are laid out in row-major (C) order.

    NOTE(review): this assumes flattened samples are stored in H, W, C order. Data
    stored channels-first (C, H, W) would be scrambled; confirm against _dataset_specs.
    """
    height, width, channels = image_shape
    target_shape = (height, width) if channels == 1 else (height, width, channels)
    return _as_numpy(x).reshape(target_shape)


def _flatten_image(img: np.ndarray, image_shape: tuple[int, int, int]) -> np.ndarray:
    height, width, channels = image_shape
    if channels == 1 and img.ndim == 2:
        return img.reshape(height * width)
    return img.reshape(height * width * channels)


def _apply_rotation(img: np.ndarray, rotation: int) -> np.ndarray:
    if rotation % 360 == 0:
        return img
    k = rotation // 90
    return np.rot90(img, k, axes=(0, 1))


def _apply_noise(img: np.ndarray, sigma: float) -> np.ndarray:
    if sigma <= 0:
        return img
    rng = np.random.default_rng()
    noisy = img + rng.normal(0.0, sigma, size=img.shape)
    return np.clip(noisy, 0.0, 1.0)


def _apply_patch(img: np.ndarray, size: int, x: int, y: int) -> np.ndarray:
    if size <= 0:
        return img
    height, width = img.shape[:2]
    x0 = max(0, min(x, width - 1))
    y0 = max(0, min(y, height - 1))
    x1 = min(width, x0 + size)
    y1 = min(height, y0 + size)
    img = img.copy()
    if img.ndim == 2:
        img[y0:y1, x0:x1] = 0.0
    else:
        img[y0:y1, x0:x1, :] = 0.0
    return img


def _to_pil_image(img: np.ndarray) -> Image.Image:
    """Convert a float image with values in [0, 1] into an 8-bit PIL image.

    Accepts (H, W), (H, W, 1), (H, W, 3) or (H, W, 4) arrays.

    - Values are clipped to [0, 1], scaled to 0..255 and rounded rather than
      truncated, so 0.5 maps to 128, not 127.
    - The PIL mode ("L", "RGB", "RGBA") is inferred from the uint8 array's shape.
      The explicit ``mode=`` argument of ``Image.fromarray`` is deprecated in recent
      Pillow, and forcing "RGB" mislabeled 4-channel input.
    - NaNs are rendered as 0 instead of going through an undefined float->uint8 cast.
    """
    arr = np.nan_to_num(np.asarray(img, dtype=np.float64), nan=0.0)
    arr = np.clip(arr, 0.0, 1.0)
    if arr.ndim == 3 and arr.shape[-1] == 1:
        # A trailing singleton channel is grayscale; Pillow does not accept (H, W, 1) uint8.
        arr = arr[..., 0]
    return Image.fromarray(np.round(arr * 255.0).astype(np.uint8))


def _encode_single(
    model_key: str,
    model: object,
    classifier: nn.Module | None,
    x_flat: np.ndarray,
    vq_choice: str,
) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
    x_tensor = torch.from_numpy(x_flat).float().unsqueeze(0)
    if model_key == "topo":
        (
            _k_chart,
            _k_code,
            _z_n,
            z_tex,
            router_weights,
            z_geo,
            _vq_loss,
            _indices,
            _z_n_all,
            _c_bar,
        ) = model.encoder(x_tensor)
        recon, _ = model.decoder(z_geo, z_tex, None)
        logits = classifier(router_weights, z_geo) if classifier is not None else None
        latent = z_geo
    elif model_key == "std":
        z_e = model.encoder(x_tensor)
        embed = model.embeddings.weight
        z_sq = (z_e**2).sum(dim=1, keepdim=True)
        e_sq = (embed**2).sum(dim=1).unsqueeze(0)
        dot = torch.matmul(z_e, embed.t())
        dist = z_sq + e_sq - 2.0 * dot
        indices = torch.argmin(dist, dim=1)
        z_q = embed[indices]
        z_st = z_e + (z_q - z_e).detach()
        recon = model.decoder(z_st)
        logits = classifier(z_e) if classifier is not None else None
        latent = z_q if vq_choice == "quantized" else z_e
    else:
        z = model.encoder(x_tensor)
        recon = model.decoder(z)
        logits = classifier(z) if classifier is not None else None
        latent = z
    recon_np = recon.squeeze(0).detach().cpu().numpy()
    latent_np = latent.squeeze(0).detach().cpu().numpy()
    probs = None
    if logits is not None:
        probs = torch.softmax(logits, dim=1).squeeze(0).detach().cpu().numpy()
    return recon_np, latent_np, probs


def _latent_to_3d(model_key: str, latent: np.ndarray, vq_choice: str) -> np.ndarray:
    """Project one latent vector into the same 3D PCA frame as the background cloud.

    Uses the PCA fit cached at load time for ``model_key``. For StandardVQ this is the
    quantized or pre-quantized fit, chosen by ``vq_choice``. When no fit is cached,
    ``_pca_to_3d`` is applied to the single row instead.
    """
    fits = data_cache.get("pca", {})
    row = latent.reshape(1, -1)
    if model_key == "std":
        std_fits = fits.get("std", {})
        params = std_fits.get("z_q" if vq_choice == "quantized" else "z_e")
    else:
        params = fits.get(model_key)
    if not params:
        return _pca_to_3d(row)
    mean, basis = params
    return _pca_apply(row, mean, basis)


def _latent_background(model_key: str, vq_choice: str) -> np.ndarray | None:
    """Return the cached 3D latent cloud for ``model_key``, or ``None`` if unavailable.

    StandardVQ has two clouds (``z_q_3d`` / ``z_e_3d``), chosen by ``vq_choice``. The
    other models expose a single ``z_3d``.
    """
    entry = data_cache.get("latents", {}).get(model_key)
    if entry is None:
        return None
    if model_key == "std":
        return entry.get("z_q_3d" if vq_choice == "quantized" else "z_e_3d")
    return entry.get("z_3d")


def _probability_plot(probs: np.ndarray | None, class_names: list[str]) -> pn.viewable.Viewable:
    """Bar chart of the classifier's output distribution for one sample.

    Bar ``i`` is labelled ``class_names[i]`` when that name exists and ``str(i)``
    otherwise. The dataset's class list and the classifier's output width can
    therefore differ without the DataFrame constructor raising on unequal column
    lengths. Returns a Markdown note when ``probs`` is ``None`` (no classifier).
    """
    if probs is None:
        return pn.pane.Markdown("Classifier unavailable for this model.")
    names = list(class_names or [])
    labels = [names[i] if i < len(names) else str(i) for i in range(len(probs))]
    df = pd.DataFrame({"class": labels, "prob": probs})
    bars = hv.Bars(df, kdims=["class"], vdims=["prob"]).opts(
        width=420,
        height=220,
        ylabel="p(class)",
        xlabel="class",
        title="Output distribution",
    )
    return bars


def _latent_plot(
    z_background: np.ndarray | None,
    z_orig: np.ndarray,
    z_pert: np.ndarray,
    show_cloud: bool,
) -> pn.viewable.Viewable:
    """3D overlay of the original (blue) and perturbed (orange) latent points.

    ``z_orig`` and ``z_pert`` are read at row 0, columns 0-2. When ``show_cloud`` is
    set and ``z_background`` is given, its rows (columns 0-2) are drawn underneath as
    a faint grey cloud.
    """

    def _first_point(z: np.ndarray) -> pd.DataFrame:
        # One-row frame holding the first 3D point of ``z``.
        return pd.DataFrame({"x": [z[0, 0]], "y": [z[0, 1]], "z": [z[0, 2]]})

    layers = []
    if show_cloud and z_background is not None:
        cloud = pd.DataFrame(
            {"x": z_background[:, 0], "y": z_background[:, 1], "z": z_background[:, 2]}
        )
        layers.append(
            hv.Scatter3D(cloud, kdims=["x", "y", "z"]).opts(color="#888888", size=2, alpha=0.12)
        )
    for z_point, colour in ((z_orig, "#1f77b4"), (z_pert, "#ff7f0e")):
        layers.append(
            hv.Scatter3D(_first_point(z_point), kdims=["x", "y", "z"]).opts(color=colour, size=8)
        )

    return hv.Overlay(layers).opts(
        width=420,
        height=420,
        title="Latent position (orig vs perturbed)",
        xlabel="z1",
        ylabel="z2",
        zlabel="z3",
    )


@pn.depends(
    image_model.param.value,
    image_class.param.value,
    image_index.param.value,
    image_rotation.param.value,
    image_noise.param.value_throttled,
    image_patch.param.value_throttled,
    image_patch_x.param.value_throttled,
    image_patch_y.param.value_throttled,
    image_vq_latent.param.value,
    image_show_latent.param.value,
)
def render_image_dashboard(
    model_label,
    class_value,
    index_value,
    rotation_value,
    noise_value,
    patch_value,
    patch_x_value,
    patch_y_value,
    vq_choice,
    show_latent_value,
):
    """Render the perturbation explorer for one sample.

    The sample is chosen by class pool and position. It is perturbed (rotation, then
    noise, then patch occlusion), and both the original and perturbed versions are
    encoded with the selected model. The result shows the original, perturbed and
    reconstructed images, the classifier's output distribution, and where both
    latents fall in the model's 3D PCA frame. Markdown placeholders are returned when
    nothing is loaded, the class pool is empty, or the model is unavailable.
    """
    if not data_cache.get("loaded"):
        return pn.pane.Markdown("Load an experiment in the first dashboard.")

    _ensure_image_controls()
    pool = _sync_sample_index()
    if pool.size == 0:
        return pn.pane.Markdown("No samples available for the selected class.")

    # `index_value` was captured when this render was scheduled. _sync_sample_index()
    # may just have shrunk the pool (e.g. after switching to a smaller class), so clamp
    # locally rather than indexing past the end of `pool`.
    position = min(max(int(index_value), 0), int(pool.size) - 1)
    sample_idx = int(pool[position])
    x_sample = data_cache["x_sample"][sample_idx]
    labels = data_cache.get("labels_sample")
    true_label = int(labels[sample_idx]) if labels is not None else None
    class_names = data_cache.get("class_names") or []
    image_shape = data_cache.get("image_shape") or (28, 28, 1)

    orig_img = _to_image_array(x_sample, image_shape)
    pert_img = _apply_rotation(orig_img, rotation_value)
    pert_img = _apply_noise(pert_img, noise_value)
    pert_img = _apply_patch(pert_img, patch_value, patch_x_value, patch_y_value)

    orig_flat = _flatten_image(orig_img, image_shape)
    pert_flat = _flatten_image(pert_img, image_shape)

    model_key = IMAGE_MODEL_KEYS.get(model_label)
    model = data_cache.get("models", {}).get(model_key)
    classifier = data_cache.get("classifiers", {}).get(model_key)
    if model is None:
        return pn.pane.Markdown(f"Model `{model_label}` unavailable for this checkpoint.")

    recon, latent_pert, probs = _encode_single(
        model_key, model, classifier, pert_flat, vq_choice
    )
    # Only the original's latent is needed, to compare positions in latent space.
    _, latent_orig, _ = _encode_single(model_key, model, classifier, orig_flat, vq_choice)

    recon_img = _to_image_array(recon, image_shape)

    z_orig = _latent_to_3d(model_key, latent_orig, vq_choice)
    z_pert = _latent_to_3d(model_key, latent_pert, vq_choice)
    z_bg = _latent_background(model_key, vq_choice)

    # Fall back to the numeric label when it has no matching class name.
    if true_label is not None and 0 <= true_label < len(class_names):
        class_text = class_names[true_label]
    else:
        class_text = true_label
    image_status.object = (
        f"**Sample index:** {sample_idx}\n"
        f"- class: {class_text}\n"
        f"- model: {model_label}"
    )

    prob_plot = _probability_plot(probs, class_names)
    latent_plot = _latent_plot(z_bg, z_orig, z_pert, show_latent_value)

    images = pn.Row(
        pn.pane.Image(_to_pil_image(orig_img), width=180, height=180, caption="Original"),
        pn.pane.Image(_to_pil_image(pert_img), width=180, height=180, caption="Perturbed"),
        pn.pane.Image(_to_pil_image(recon_img), width=180, height=180, caption="Reconstruction"),
    )

    return pn.Column(image_status, images, prob_plot, latent_plot)


# "Random sample" jumps the index widget to a random position within the current pool.
image_random.on_click(_choose_random_sample)

# Sidebar: sample selection followed by perturbation and display controls.
image_controls = pn.Column(
    "### Sample",
    image_model,
    image_class,
    image_index,
    image_random,
    "### Perturbations",
    image_rotation,
    image_noise,
    image_patch,
    image_patch_x,
    image_patch_y,
    image_vq_latent,
    image_show_latent,
    width=340,
)

# Controls on the left; the reactive render (status, images, bars, latent plot) on the right.
image_dashboard = pn.Column(
    pn.Row(image_controls, render_image_dashboard),
    sizing_mode="stretch_width",
)

# Bare expression so the notebook renders the layout as the cell output.
image_dashboard
