# TopoEncoder Latent Space Dashboard (Panel + HoloViews)

This notebook builds a Panel dashboard that loads a TopoEncoder experiment folder (checkpoint + benchmarks)
and compares 3D latent spaces for the TopoEncoder, StandardVQ, and VanillaAE baselines. Use the controls
to select checkpoints, sampling, coloring, and styling so you can swap visualization styles without
regenerating static images.

**Usage**
1. Set the experiment folder in the sidebar (for example: `outputs/3d_topoencoder_mnist_cpu_adapt_lr7_bnch`).
2. Click **Scan checkpoints**, choose a checkpoint from the dropdown, then click **Load + Analyze**.
3. Explore the latent plots and metrics.


In [1]:
from __future__ import annotations

from pathlib import Path
import re

import numpy as np
import pandas as pd
import torch
import panel as pn
import holoviews as hv

from fragile.core.layers import StandardVQ, TopoEncoderPrimitives, VanillaAE

# Plotly backend is required for 3D scatter. Raise a clear error if missing.
# Both Panel and HoloViews are switched to the Plotly extension here; any
# exception raised while doing so is re-raised as a RuntimeError (chained to
# the original exception) carrying an actionable install hint.
try:
    pn.extension("plotly", sizing_mode="stretch_width")
    hv.extension("plotly")
except Exception as exc:
    raise RuntimeError(
        "Plotly backend is required for this dashboard. Install plotly and rerun."
    ) from exc


In [3]:
def _load_checkpoint(path: Path) -> dict:
    if not path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        return torch.load(path, map_location="cpu")


def _load_benchmarks(checkpoint_path: Path) -> dict | None:
    bench_path = checkpoint_path.parent / "benchmarks.pt"
    if not bench_path.exists():
        return None
    try:
        return torch.load(bench_path, map_location="cpu", weights_only=False)
    except TypeError:
        return torch.load(bench_path, map_location="cpu")


def _checkpoint_sort_key(path: Path) -> tuple[int, int, str]:
    name = path.name
    if "final" in name:
        return (1, 0, name)
    match = re.search(r"(\d+)", name)
    if match:
        return (0, int(match.group(1)), name)
    return (0, -1, name)


def _collect_checkpoints(exp_dir: Path) -> list[Path]:
    if not exp_dir.exists():
        return []
    return sorted(
        [p for p in exp_dir.glob("*.pt") if p.name != "benchmarks.pt"],
        key=_checkpoint_sort_key,
    )


def _default_experiment_dir() -> Path:
    preferred = Path("outputs/3d_topoencoder_mnist_cpu_adapt_lr7_bnch")
    if preferred.exists():
        return preferred
    outputs = Path("outputs")
    if not outputs.exists():
        return Path(".")
    for candidate in sorted(outputs.iterdir()):
        if candidate.is_dir() and (candidate / "benchmarks.pt").exists():
            return candidate
    return outputs


def _as_tensor(x: np.ndarray | torch.Tensor) -> torch.Tensor:
    if isinstance(x, np.ndarray):
        return torch.from_numpy(x).float()
    return x.float()


def _as_numpy(x) -> np.ndarray | None:
    if x is None:
        return None
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _batch_iter(x: torch.Tensor, batch_size: int):
    n = x.shape[0]
    for start in range(0, n, batch_size):
        yield x[start : start + batch_size]


def _pca_to_3d(z: np.ndarray) -> np.ndarray:
    if z.shape[1] == 3:
        return z
    if z.shape[1] < 3:
        pad = 3 - z.shape[1]
        return np.pad(z, ((0, 0), (0, pad)), mode="constant")
    z_centered = z - z.mean(axis=0, keepdims=True)
    _, _, vt = np.linalg.svd(z_centered, full_matrices=False)
    return z_centered @ vt[:3].T


def _sample_data(
    x: torch.Tensor, labels: np.ndarray | None, max_samples: int, seed: int
) -> tuple[torch.Tensor, np.ndarray | None]:
    if max_samples <= 0 or max_samples >= x.shape[0]:
        return x, labels
    rng = np.random.default_rng(seed)
    indices = rng.choice(x.shape[0], size=max_samples, replace=False)
    indices_t = torch.from_numpy(indices)
    x = x[indices_t]
    if labels is not None:
        labels = labels[indices]
    return x, labels


def _prepare_models(
    config: dict,
    state: dict,
    metrics: dict,
    bench_state: dict | None,
    bench_dims: dict | None,
    device: str,
) -> tuple[dict[str, object], dict[str, str | None]]:
    """Instantiate the models, load their weights and switch them to eval mode.

    The TopoEncoder is always built from ``state["atlas"]``. Each baseline
    (StandardVQ under key ``"std"``, VanillaAE under key ``"ae"``) is built
    only when it is not disabled in ``config`` (``disable_vq`` /
    ``disable_ae``). Its weights come from ``state`` when present there and
    otherwise from ``bench_state``.

    Returns:
        ``(models, sources)``. ``models`` maps ``"topo"``, ``"std"`` and
        ``"ae"`` to a model or ``None``. ``sources`` maps ``"std"`` and
        ``"ae"`` to ``"checkpoint"``, ``"benchmarks"`` or ``None``, recording
        where the weights were taken from.
    """
    topo_model = TopoEncoderPrimitives(
        input_dim=config["input_dim"],
        hidden_dim=config["hidden_dim"],
        latent_dim=config["latent_dim"],
        num_charts=config["num_charts"],
        codes_per_chart=config["codes_per_chart"],
    ).to(device)
    topo_model.load_state_dict(state["atlas"])
    topo_model.eval()

    origin = {"std": None, "ae": None}
    dims_from_bench = bench_dims or {}

    # --- StandardVQ baseline: pick the weights and the hidden width first.
    vq_model = None
    vq_weights = None
    vq_width = None
    vq_origin = None
    if not config.get("disable_vq", False):
        if state.get("std") is not None:
            vq_weights = state["std"]
            vq_width = metrics.get("std_hidden_dim", config["hidden_dim"])
            vq_origin = "checkpoint"
        elif bench_state is not None and bench_state.get("std") is not None:
            vq_weights = bench_state["std"]
            # Width lookup order: benchmark dims, then metrics, then config.
            vq_width = int(
                (dims_from_bench.get("std_hidden_dim") or metrics.get("std_hidden_dim"))
                or config["hidden_dim"]
            )
            vq_origin = "benchmarks"
    if vq_weights is not None:
        vq_model = StandardVQ(
            input_dim=config["input_dim"],
            hidden_dim=vq_width,
            latent_dim=config["latent_dim"],
            num_codes=config["num_codes_standard"],
        ).to(device)
        vq_model.load_state_dict(vq_weights)
        vq_model.eval()
        origin["std"] = vq_origin

    # --- VanillaAE baseline: same selection logic with the "ae" keys.
    ae_model = None
    ae_weights = None
    ae_width = None
    ae_origin = None
    if not config.get("disable_ae", False):
        if state.get("ae") is not None:
            ae_weights = state["ae"]
            ae_width = metrics.get("ae_hidden_dim", config["hidden_dim"])
            ae_origin = "checkpoint"
        elif bench_state is not None and bench_state.get("ae") is not None:
            ae_weights = bench_state["ae"]
            ae_width = int(
                (dims_from_bench.get("ae_hidden_dim") or metrics.get("ae_hidden_dim"))
                or config["hidden_dim"]
            )
            ae_origin = "benchmarks"
    if ae_weights is not None:
        ae_model = VanillaAE(
            input_dim=config["input_dim"],
            hidden_dim=ae_width,
            latent_dim=config["latent_dim"],
        ).to(device)
        ae_model.load_state_dict(ae_weights)
        ae_model.eval()
        origin["ae"] = ae_origin

    return {"topo": topo_model, "std": vq_model, "ae": ae_model}, origin


def _compute_topo_latents(
    model: TopoEncoderPrimitives, x: torch.Tensor, batch_size: int, device: str
) -> dict[str, np.ndarray]:
    """Run the TopoEncoder over ``x`` in batches with gradients disabled.

    Returns:
        A dict of arrays concatenated over all batches:
        ``"z"`` (the encoder's ``z_geo`` output), ``"chart"`` and ``"code"``
        (the first two encoder outputs), and ``"recon_mse"`` (squared
        reconstruction error averaged over dim 1 for each sample).
    """
    collected = {"z": [], "chart": [], "code": [], "recon_mse": []}

    model.eval()
    with torch.no_grad():
        for chunk in _batch_iter(x, batch_size):
            chunk = chunk.to(device)
            # The encoder returns a 10-tuple; only four entries are used here.
            chart_idx, code_idx, _, z_tex, _, z_geo, _, _, _, _ = model.encoder(chunk)
            recon, _ = model.decoder(z_geo, z_tex, None)
            per_sample = ((recon - chunk) ** 2).mean(dim=1)
            collected["z"].append(z_geo.cpu().numpy())
            collected["chart"].append(chart_idx.cpu().numpy())
            collected["code"].append(code_idx.cpu().numpy())
            collected["recon_mse"].append(per_sample.cpu().numpy())

    return {key: np.concatenate(parts, axis=0) for key, parts in collected.items()}


def _compute_std_latents(
    model: StandardVQ, x: torch.Tensor, batch_size: int, device: str
) -> dict[str, np.ndarray]:
    """Encode ``x`` with the StandardVQ baseline in batches, gradients disabled.

    Each encoder output is assigned to the row of ``model.embeddings.weight``
    with the smallest squared Euclidean distance, and the decoder is run on
    the quantized vector.

    Returns:
        A dict of arrays concatenated over all batches: ``"z_e"`` (encoder
        output), ``"z_q"`` (selected embedding rows), ``"code"`` (selected row
        indices) and ``"recon_mse"`` (squared reconstruction error averaged
        over dim 1 for each sample).
    """
    collected = {"z_e": [], "z_q": [], "code": [], "recon_mse": []}

    model.eval()
    with torch.no_grad():
        codebook = model.embeddings.weight
        # Squared norm of each codebook row; it does not change between batches.
        codebook_sq = (codebook**2).sum(dim=1).unsqueeze(0)
        for chunk in _batch_iter(x, batch_size):
            chunk = chunk.to(device)
            z_e = model.encoder(chunk)
            # ||z - e||^2 expanded as ||z||^2 + ||e||^2 - 2 z.e
            sq_dist = (
                (z_e**2).sum(dim=1, keepdim=True)
                + codebook_sq
                - 2.0 * torch.matmul(z_e, codebook.t())
            )
            nearest = torch.argmin(sq_dist, dim=1)
            z_q = codebook[nearest]
            decoder_input = z_e + (z_q - z_e).detach()
            recon = model.decoder(decoder_input)
            per_sample = ((recon - chunk) ** 2).mean(dim=1)
            collected["z_e"].append(z_e.cpu().numpy())
            collected["z_q"].append(z_q.cpu().numpy())
            collected["code"].append(nearest.cpu().numpy())
            collected["recon_mse"].append(per_sample.cpu().numpy())

    return {key: np.concatenate(parts, axis=0) for key, parts in collected.items()}


def _compute_ae_latents(
    model: VanillaAE, x: torch.Tensor, batch_size: int, device: str
) -> dict[str, np.ndarray]:
    """Encode and reconstruct ``x`` with the VanillaAE baseline in batches.

    Returns:
        A dict of arrays concatenated over all batches: ``"z"`` (encoder
        output) and ``"recon_mse"`` (squared reconstruction error averaged
        over dim 1 for each sample).
    """
    codes = []
    errors = []

    model.eval()
    with torch.no_grad():
        for chunk in _batch_iter(x, batch_size):
            chunk = chunk.to(device)
            latent = model.encoder(chunk)
            per_sample = ((model.decoder(latent) - chunk) ** 2).mean(dim=1)
            codes.append(latent.cpu().numpy())
            errors.append(per_sample.cpu().numpy())

    return {
        "z": np.concatenate(codes, axis=0),
        "recon_mse": np.concatenate(errors, axis=0),
    }


def _coerce_color(values: np.ndarray | None) -> np.ndarray | None:
    if values is None:
        return None
    arr = np.asarray(values)
    if arr.ndim > 1:
        arr = arr.reshape(-1)
    if arr.dtype.kind in "iuf":
        return arr
    return pd.factorize(arr)[0]


def _build_metrics_table(
    metrics: dict,
    latents: dict[str, dict[str, np.ndarray] | None],
    models: dict[str, object],
    sources: dict[str, str | None],
) -> pd.DataFrame:
    rows = []

    topo = latents.get("topo")
    if topo is not None:
        topo_perplexity = None
        if models.get("topo") is not None:
            topo_perplexity = models["topo"].compute_perplexity(
                torch.from_numpy(topo["chart"])
            )
        rows.append(
            {
                "model": "TopoEncoder",
                "mse_sample": float(np.mean(topo["recon_mse"])),
                "mse_checkpoint": metrics.get("mse_atlas"),
                "ami_checkpoint": metrics.get("ami_atlas"),
                "perplexity_sample": topo_perplexity,
                "perplexity_checkpoint": metrics.get("atlas_perplexity"),
            }
        )

    std = latents.get("std")
    if std is not None and models.get("std") is not None:
        std_perplexity = models["std"].compute_perplexity(torch.from_numpy(std["code"]))
        rows.append(
            {
                "model": f"StandardVQ ({sources.get('std') or 'unknown'})",
                "mse_sample": float(np.mean(std["recon_mse"])),
                "mse_checkpoint": metrics.get("mse_std"),
                "ami_checkpoint": metrics.get("ami_std"),
                "perplexity_sample": std_perplexity,
                "perplexity_checkpoint": metrics.get("std_perplexity"),
            }
        )

    ae = latents.get("ae")
    if ae is not None:
        rows.append(
            {
                "model": f"VanillaAE ({sources.get('ae') or 'unknown'})",
                "mse_sample": float(np.mean(ae["recon_mse"])),
                "mse_checkpoint": metrics.get("mse_ae"),
                "ami_checkpoint": metrics.get("ami_ae"),
                "perplexity_sample": None,
                "perplexity_checkpoint": None,
            }
        )

    return pd.DataFrame(rows)


In [4]:
# Choices offered by the "Color by" selector; each one is handled in
# _resolve_color.
COLOR_OPTIONS = ["label", "chart", "code", "recon_mse", "latent_norm", "none"]
# Colormap names offered by the "Colorscale" selector; the chosen name is
# passed to the plot as its ``cmap`` option.
COLOR_SCALES = [
    "Viridis",
    "Plasma",
    "Cividis",
    "Turbo",
    "Magma",
    "Inferno",
    "IceFire",
    "Greys",
    "Blues",
    "Reds",
]

# Internal model key -> display name used in plot titles and in the
# "Show models" button group.
MODEL_LABELS = {"topo": "TopoEncoder", "std": "StandardVQ", "ae": "VanillaAE"}

# Module-level cache shared by the callbacks below. _load_experiment fills it
# and the plotting helpers read from it; "loaded" gates rendering.
data_cache = {
    "loaded": False,
    "labels": None,
    "latents": {},
    "config": {},
    "metrics": {},
    "sources": {},
    "codes_per_chart": None,
    "checkpoint": None,
}

# Initial value for the experiment folder input.
default_dir = _default_experiment_dir()

# --- Load controls -----------------------------------------------------------
experiment_dir = pn.widgets.TextInput(
    name="Experiment folder", value=str(default_dir), placeholder="outputs/..."
)
checkpoint_select = pn.widgets.Select(name="Checkpoint", options=[])
refresh_button = pn.widgets.Button(name="Scan checkpoints", button_type="primary")
load_button = pn.widgets.Button(name="Load + Analyze", button_type="success")

# --- Sampling controls (read when an experiment is loaded) --------------------
max_samples = pn.widgets.IntInput(name="Max samples (0=all)", value=3000, step=500)
batch_size = pn.widgets.IntInput(name="Batch size", value=512, step=128)
sample_seed = pn.widgets.IntInput(name="Sample seed", value=7, step=1)

# --- Display controls (plot_row depends on the value of each of these) --------
color_by = pn.widgets.Select(name="Color by", options=COLOR_OPTIONS, value="label")
colorscale = pn.widgets.Select(name="Colorscale", options=COLOR_SCALES, value="Viridis")
marker_size = pn.widgets.FloatSlider(
    name="Marker size", value=4.0, start=1.0, end=12.0, step=0.5
)
marker_opacity = pn.widgets.FloatSlider(
    name="Marker opacity", value=0.7, start=0.1, end=1.0, step=0.05
)
show_colorbar = pn.widgets.Checkbox(name="Show colorbar", value=True)
# Selects which StandardVQ latent is plotted: "quantized" maps to the z_q
# projection and "pre-quantized" to the z_e projection (see _resolve_latent).
vq_latent = pn.widgets.RadioButtonGroup(
    name="StandardVQ latent", options=["quantized", "pre-quantized"], value="quantized"
)
show_models = pn.widgets.CheckButtonGroup(
    name="Show models", options=list(MODEL_LABELS.values()), value=list(MODEL_LABELS.values())
)

# --- Output panes updated by _load_experiment ----------------------------------
status_pane = pn.pane.Markdown("Load an experiment to begin.")
metrics_pane = pn.pane.DataFrame(pd.DataFrame())


def _refresh_checkpoints(_=None) -> None:
    """Rescan the experiment folder and repopulate the checkpoint selector.

    The selection defaults to the first file name containing ``"final"``, or
    to the last entry when there is none. With no checkpoints found, the
    selection is cleared. The positional argument is ignored; it lets the
    function be used as a button callback.
    """
    folder = Path(experiment_dir.value).expanduser()
    names = [ckpt.name for ckpt in _collect_checkpoints(folder)]
    checkpoint_select.options = names
    if not names:
        checkpoint_select.value = None
        return
    checkpoint_select.value = next(
        (candidate for candidate in names if "final" in candidate), names[-1]
    )


def _resolve_latent(model_key: str, vq_choice: str) -> np.ndarray | None:
    """Fetch the cached 3-D coordinates to plot for ``model_key``.

    For ``"std"`` the ``"pre-quantized"`` choice selects ``z_e_3d`` and any
    other choice selects ``z_q_3d``; the other models use ``z_3d``. Returns
    ``None`` when the model has no cached latents.
    """
    entry = data_cache["latents"].get(model_key)
    if entry is None:
        return None
    if model_key != "std":
        return entry.get("z_3d")
    field = "z_e_3d" if vq_choice == "pre-quantized" else "z_q_3d"
    return entry.get(field)


def _resolve_color(model_key: str, color_choice: str, vq_choice: str) -> np.ndarray | None:
    """Pick the per-point colour values for one model and colour mode.

    Returns ``None`` when the model has no cached latents or the mode is
    ``"none"``. Modes that do not apply to a model (``"chart"`` outside the
    TopoEncoder, ``"code"`` for the autoencoder), as well as unknown modes,
    fall back to the cached labels.
    """
    entry = data_cache["latents"].get(model_key)
    if entry is None:
        return None
    labels = data_cache.get("labels")

    if color_choice == "none":
        return None
    if color_choice == "recon_mse":
        return entry.get("recon_mse")
    if color_choice == "latent_norm":
        coords = _resolve_latent(model_key, vq_choice)
        return labels if coords is None else np.linalg.norm(coords, axis=1)
    if color_choice == "chart" and model_key == "topo":
        return entry.get("chart")
    if color_choice == "code":
        if model_key == "topo":
            return entry.get("code_global")
        if model_key == "std":
            return entry.get("code")
    # "label", inapplicable modes and unrecognised modes all end up here.
    return labels


def _plot_model(
    model_key: str,
    color_choice: str,
    colorscale_choice: str,
    size_value: float,
    opacity_value: float,
    vq_choice: str,
    show_colorbar_value: bool,
):
    """Build the 3-D scatter for one model from the cached latents.

    Args:
        model_key: ``"topo"``, ``"std"`` or ``"ae"``.
        color_choice: colour mode, resolved through ``_resolve_color``.
        colorscale_choice: colormap name applied when points are coloured.
        size_value: marker size.
        opacity_value: marker opacity.
        vq_choice: which StandardVQ latent to plot; also shown in its title.
        show_colorbar_value: whether to draw a colorbar for coloured points.

    Returns:
        A ``hv.Scatter3D`` element, or a Markdown pane with a short message
        when nothing is loaded or the model has no latents.
    """
    if not data_cache["loaded"]:
        return pn.pane.Markdown("Load an experiment to render plots.")

    z = _resolve_latent(model_key, vq_choice)
    if z is None:
        return pn.pane.Markdown(f"{MODEL_LABELS[model_key]} not available.")

    df = pd.DataFrame({"x": z[:, 0], "y": z[:, 1], "z": z[:, 2]})
    color_values = _coerce_color(_resolve_color(model_key, color_choice, vq_choice))

    if color_values is not None:
        df["color"] = color_values
        scatter = hv.Scatter3D(df, kdims=["x", "y", "z"], vdims=["color"]).opts(
            color="color",
            cmap=colorscale_choice,
            colorbar=show_colorbar_value,
        )
    else:
        # No colour data for this mode: draw every point in one fixed colour.
        scatter = hv.Scatter3D(df, kdims=["x", "y", "z"]).opts(color="#1f77b4")

    title = MODEL_LABELS[model_key]
    if model_key == "std":
        title = f"{title} ({vq_choice})"

    # ``tools`` is intentionally not passed: it is a Bokeh-backend option and
    # HoloViews rejects it for Scatter3D under the Plotly backend. Plotly
    # figures show hover information without it.
    return scatter.opts(
        width=420,
        height=420,
        size=size_value,
        alpha=opacity_value,
        title=title,
        xlabel="z1",
        ylabel="z2",
        zlabel="z3",
    )


@pn.depends(
    show_models.param.value,
    color_by.param.value,
    colorscale.param.value,
    marker_size.param.value,
    marker_opacity.param.value,
    vq_latent.param.value,
    show_colorbar.param.value,
)
def plot_row(
    show_models_value,
    color_choice,
    colorscale_choice,
    size_value,
    opacity_value,
    vq_choice,
    show_colorbar_value,
):
    """Render one plot per selected model, laid out in a wrapping FlexBox.

    Re-evaluated whenever one of the display widgets listed in the decorator
    changes value. Before any experiment is loaded a placeholder message is
    returned instead.
    """
    if not data_cache["loaded"]:
        return pn.pane.Markdown("Load an experiment to render plots.")

    # (model key, name as it appears in the "Show models" selection), in
    # display order.
    candidates = (
        ("topo", "TopoEncoder"),
        ("std", "StandardVQ"),
        ("ae", "VanillaAE"),
    )
    panes = [
        pn.panel(
            _plot_model(
                key,
                color_choice,
                colorscale_choice,
                size_value,
                opacity_value,
                vq_choice,
                show_colorbar_value,
            )
        )
        for key, display_name in candidates
        if display_name in show_models_value
    ]

    return pn.FlexBox(*panes, flex_wrap="wrap", gap="16px")


def _load_experiment(_=None) -> None:
    """Load the selected checkpoint, embed the test samples and refresh the UI.

    Steps: read the checkpoint (plus ``benchmarks.pt`` when present), build
    the models, subsample ``X_test`` according to the sampling widgets,
    compute latents for every available model, store everything in
    ``data_cache``, then update the metrics table, the "Show models" options,
    the status text and the plots.

    Problems with the inputs (missing folder, no selection, missing file, or
    a checkpoint without ``data``/``config``/``state``) are reported in the
    status pane and mark the cache as not loaded instead of raising. The
    positional argument is ignored; it lets the function be used as a button
    callback.
    """

    def _abort(message: str) -> None:
        # Show the problem to the user and make the plots fall back to the
        # placeholder on their next render.
        status_pane.object = message
        data_cache["loaded"] = False

    exp_dir = Path(experiment_dir.value).expanduser()
    if not exp_dir.exists():
        _abort(f"Experiment folder not found: `{exp_dir}`")
        return

    if not checkpoint_select.value:
        _abort("No checkpoint selected.")
        return

    checkpoint_path = exp_dir / checkpoint_select.value
    try:
        checkpoint = _load_checkpoint(checkpoint_path)
    except FileNotFoundError:
        # The selector can be stale, e.g. the file was removed after scanning.
        _abort(f"Checkpoint not found: `{checkpoint_path}`")
        return
    if "data" not in checkpoint:
        _abort(f"Checkpoint has no data: `{checkpoint_path.name}`")
        return
    missing_keys = [key for key in ("config", "state") if key not in checkpoint]
    if missing_keys:
        _abort(
            f"Checkpoint `{checkpoint_path.name}` is missing: {', '.join(missing_keys)}"
        )
        return

    config = checkpoint["config"]
    metrics = checkpoint.get("metrics", {})
    state = checkpoint["state"]

    benchmarks = _load_benchmarks(checkpoint_path)
    bench_state = benchmarks.get("state", {}) if benchmarks else None
    bench_dims = benchmarks.get("dims", {}) if benchmarks else None

    device = "cpu"
    models, sources = _prepare_models(config, state, metrics, bench_state, bench_dims, device)

    data = checkpoint["data"]
    x_test = _as_tensor(data["X_test"])
    labels = _as_numpy(data.get("labels_test"))
    # A cleared numeric input is treated as 0 so the sampling code still gets
    # integers.
    x_test, labels = _sample_data(
        x_test, labels, int(max_samples.value or 0), int(sample_seed.value or 0)
    )

    # Batching needs a positive step; clamp instead of failing on 0/negative.
    step = max(1, int(batch_size.value or 1))

    latents = {}
    topo = _compute_topo_latents(models["topo"], x_test, step, device)
    topo["z_3d"] = _pca_to_3d(topo["z"])
    codes_per_chart = int(config.get("codes_per_chart", 1))
    # Combine chart and per-chart code into one index that is unique across
    # charts.
    topo["code_global"] = topo["chart"] * codes_per_chart + topo["code"]
    latents["topo"] = topo

    if models["std"] is not None:
        std = _compute_std_latents(models["std"], x_test, step, device)
        std["z_e_3d"] = _pca_to_3d(std["z_e"])
        std["z_q_3d"] = _pca_to_3d(std["z_q"])
        latents["std"] = std
    else:
        latents["std"] = None

    if models["ae"] is not None:
        ae = _compute_ae_latents(models["ae"], x_test, step, device)
        ae["z_3d"] = _pca_to_3d(ae["z"])
        latents["ae"] = ae
    else:
        latents["ae"] = None

    data_cache.update(
        {
            "loaded": True,
            "labels": labels,
            "latents": latents,
            "config": config,
            "metrics": metrics,
            "sources": sources,
            "codes_per_chart": codes_per_chart,
            "checkpoint": checkpoint_path,
        }
    )

    metrics_pane.object = _build_metrics_table(metrics, latents, models, sources)

    available = ["TopoEncoder"]
    if latents.get("std") is not None:
        available.append("StandardVQ")
    if latents.get("ae") is not None:
        available.append("VanillaAE")
    show_models.options = available
    show_models.value = [name for name in show_models.value if name in available] or available

    dataset = config.get("dataset", "unknown")
    # Markdown needs real line breaks (and a blank line before the list) for
    # the bullets to render as a list.
    status_pane.object = "\n".join(
        [
            f"**Checkpoint:** `{checkpoint_path}`",
            "",
            f"- dataset: `{dataset}`",
            f"- samples: {x_test.shape[0]}",
            f"- latent_dim: {config.get('latent_dim')}",
            f"- charts: {config.get('num_charts')}",
            f"- benchmarks: std={sources.get('std')}, ae={sources.get('ae')}",
        ]
    )

    # The plots only re-render when a display widget changes value. If the
    # model selection is unchanged after a load, nothing would fire, so
    # notify the watchers explicitly to pick up the new cache contents.
    show_models.param.trigger("value")


# Wire the buttons to their callbacks, and rescan automatically whenever the
# experiment folder text changes.
refresh_button.on_click(_refresh_checkpoints)
load_button.on_click(_load_experiment)
experiment_dir.param.watch(lambda event: _refresh_checkpoints(), "value")
# Populate the checkpoint selector once for the initial folder value.
_refresh_checkpoints()

# Sidebar: widgets grouped under "Load", "Sampling" and "Display" headings.
controls = pn.Column(
    "### Load",
    experiment_dir,
    refresh_button,
    checkpoint_select,
    load_button,
    "### Sampling",
    max_samples,
    batch_size,
    sample_seed,
    "### Display",
    color_by,
    colorscale,
    marker_size,
    marker_opacity,
    show_colorbar,
    vq_latent,
    show_models,
    width=340,
)

# Page layout: controls next to the status/metrics column, with the reactive
# plot row underneath.
dashboard = pn.Column(
    pn.Row(controls, pn.Column(status_pane, metrics_pane, sizing_mode="stretch_width")),
    plot_row,
    sizing_mode="stretch_width",
)

# Bare expression at the end of the cell; presumably this is what makes the
# notebook display the dashboard.
dashboard
