# TopoEncoder Latent Interpretation

This notebook loads a TopoEncoder checkpoint and, when a `benchmarks.pt` file sits next to it, the optional benchmark models, and uses them to interpret the latent space, chart routing, and reconstructions.


## What this notebook covers
- Latent scatter by class (2D PCA + optional 3D view).
- Chart assignments and code usage heatmap.
- Reconstruction error distribution and per-chart error.
- Chart prototypes (for each of the most-used charts, the sample nearest to the chart centroid, shown with its reconstruction).
- z_n norm statistics and feature correlations.
- Optional benchmark comparison (VanillaAE and StandardVQ, if their weights are available).


In [None]:
%matplotlib inline
import os
from pathlib import Path

import numpy as np
import torch
import matplotlib.pyplot as plt

from fragile.core.layers import TopoEncoderPrimitives, StandardVQ, VanillaAE


In [None]:
# Update this path to your checkpoint
checkpoint_path = Path("outputs/topoencoder_mnist_cpu_adapt_lr7/topo_final.pt")
# Device string handed to `.to(device)` for the models and the test tensor.
device = "cpu"

# The optional benchmark file is looked up in the checkpoint's own directory.
output_dir = checkpoint_path.parent
bench_path = output_dir / "benchmarks.pt"


In [None]:
def load_checkpoint(path: Path):
    """Read a ``torch.save`` file from ``path`` with all tensors mapped to CPU.

    Args:
        path: Location of the serialized checkpoint.

    Returns:
        Whatever object ``torch.load`` yields for the file.

    Raises:
        FileNotFoundError: If ``path`` does not exist.

    NOTE(review): ``weights_only=False`` asks torch for full unpickling, so
    this should only be pointed at files from a trusted source.
    """
    if path.exists():
        common_kwargs = {"map_location": "cpu"}
        try:
            return torch.load(path, weights_only=False, **common_kwargs)
        except TypeError:
            # Presumably a torch build whose `load` has no `weights_only`
            # keyword; retry with the location mapping only.
            return torch.load(path, **common_kwargs)
    raise FileNotFoundError(f"Checkpoint not found: {path}")

# The main checkpoint is mandatory; the benchmark file is read only if present.
checkpoint = load_checkpoint(checkpoint_path)
if bench_path.exists():
    benchmarks = load_checkpoint(bench_path)
else:
    benchmarks = None

# "config" and "state" are required keys; "data" and "metrics" fall back to
# empty dicts when the checkpoint does not carry them.
config, state = checkpoint["config"], checkpoint["state"]
data = checkpoint.get("data", {})
metrics = checkpoint.get("metrics", {})

# Benchmark weights and dims stay None when no (non-empty) benchmark payload
# was loaded.
if benchmarks:
    bench_state = benchmarks.get("state", {})
    bench_dims = benchmarks.get("dims", {})
else:
    bench_state = bench_dims = None


In [None]:
def _resolve_hidden_dim(key, config, *sources):
    """Return the first truthy ``source[key]`` as an int, else ``config["hidden_dim"]``."""
    for source in sources:
        value = (source or {}).get(key)
        if value:
            return int(value)
    return int(config["hidden_dim"])


def _select_baseline(name, disabled, config, state, metrics, bench_state, bench_dims):
    """Choose the weights and hidden width for the baseline stored under ``name``.

    Weights in ``state`` win unless ``disabled`` is set; otherwise weights in
    ``bench_state`` are used. Returns ``(None, None)`` when neither has them.
    """
    dim_key = f"{name}_hidden_dim"
    if state.get(name) is not None and not disabled:
        return state[name], _resolve_hidden_dim(dim_key, config, metrics)
    if bench_state is not None and bench_state.get(name) is not None:
        # Benchmark-recorded dims take precedence over the checkpoint metrics.
        return bench_state[name], _resolve_hidden_dim(
            dim_key, config, bench_dims, metrics
        )
    return None, None


def build_models(config, state, metrics, bench_state, bench_dims, device):
    """Rebuild the atlas model and any available baselines in eval mode.

    Args:
        config: Mapping with ``input_dim``, ``hidden_dim``, ``latent_dim``,
            ``num_charts`` and ``codes_per_chart``; ``num_codes_standard`` is
            read only when a StandardVQ is built. ``disable_vq`` and
            ``disable_ae`` are optional flags.
        state: Mapping of state dicts; ``"atlas"`` is required, ``"std"`` and
            ``"ae"`` are optional.
        metrics: Mapping that may hold ``std_hidden_dim`` / ``ae_hidden_dim``.
        bench_state: Optional mapping with fallback ``"std"`` / ``"ae"``
            state dicts, or ``None``.
        bench_dims: Optional mapping with hidden widths for the fallback
            weights, or ``None``.
        device: Target passed to ``.to(device)`` for every model.

    Returns:
        ``(model_atlas, model_std, model_ae)``; the last two are ``None`` when
        no usable weights were found for them.

    Both weight sources resolve the hidden width the same way: the first
    truthy recorded value, cast to ``int``, else ``config["hidden_dim"]``.
    """
    model_atlas = TopoEncoderPrimitives(
        input_dim=config["input_dim"],
        hidden_dim=config["hidden_dim"],
        latent_dim=config["latent_dim"],
        num_charts=config["num_charts"],
        codes_per_chart=config["codes_per_chart"],
    ).to(device)
    model_atlas.load_state_dict(state["atlas"])
    model_atlas.eval()

    model_std = None
    std_weights, std_hidden_dim = _select_baseline(
        "std",
        config.get("disable_vq", False),
        config,
        state,
        metrics,
        bench_state,
        bench_dims,
    )
    if std_weights is not None:
        model_std = StandardVQ(
            input_dim=config["input_dim"],
            hidden_dim=std_hidden_dim,
            latent_dim=config["latent_dim"],
            num_codes=config["num_codes_standard"],
        ).to(device)
        model_std.load_state_dict(std_weights)
        model_std.eval()

    model_ae = None
    ae_weights, ae_hidden_dim = _select_baseline(
        "ae",
        config.get("disable_ae", False),
        config,
        state,
        metrics,
        bench_state,
        bench_dims,
    )
    if ae_weights is not None:
        model_ae = VanillaAE(
            input_dim=config["input_dim"],
            hidden_dim=ae_hidden_dim,
            latent_dim=config["latent_dim"],
        ).to(device)
        model_ae.load_state_dict(ae_weights)
        model_ae.eval()

    return model_atlas, model_std, model_ae

# Instantiate the atlas model plus whichever baselines have stored weights.
model_atlas, model_std, model_ae = build_models(
    config=config,
    state=state,
    metrics=metrics,
    bench_state=bench_state,
    bench_dims=bench_dims,
    device=device,
)


In [None]:
def dataset_specs(dataset: str):
    """Return ``(class_names, image_shape)`` for a supported dataset name.

    Args:
        dataset: Either ``"mnist"`` or ``"cifar10"``.

    Returns:
        A list of class-name strings and an ``(h, w, c)`` tuple.

    Raises:
        ValueError: For any other dataset name.
    """
    if dataset == "cifar10":
        # The class list is imported only on this branch.
        from fragile.datasets import CIFAR10_CLASSES

        return list(CIFAR10_CLASSES), (32, 32, 3)
    if dataset != "mnist":
        raise ValueError(f"Unsupported dataset: {dataset}")
    digit_names = list(map(str, range(10)))
    return digit_names, (28, 28, 1)

class_names, image_shape = dataset_specs(config.get("dataset", "mnist"))

# Both the test inputs and their labels must have been saved in the checkpoint.
X_test, labels_test = data.get("X_test"), data.get("labels_test")
if X_test is None or labels_test is None:
    raise RuntimeError("No test data in checkpoint. Re-run training with data saving enabled.")

# Inputs end up as a float torch tensor; numpy arrays are wrapped first.
X_test_tensor = (
    torch.from_numpy(X_test) if isinstance(X_test, np.ndarray) else X_test
).float()

# Labels end up as a numpy array regardless of how they were stored.
labels_np = (
    labels_test.cpu().numpy()
    if isinstance(labels_test, torch.Tensor)
    else np.asarray(labels_test)
)

X_test_device = X_test_tensor.to(device)


In [None]:
with torch.no_grad():
    enc_out = model_atlas.encoder(X_test_device)
    # Positional fields of the encoder output; position 6 is not read here.
    (
        K_chart,
        K_code,
        z_n,
        z_tex,
        enc_w,
        z_geo,
        indices,
        z_n_all,
        c_bar,
    ) = (enc_out[slot] for slot in (0, 1, 2, 3, 4, 5, 7, 8, 9))

    # The first element of the full forward pass is used as the reconstruction.
    recon_atlas = model_atlas(X_test_device, use_hard_routing=False)[0]

# Numpy copies of the quantities the plotting cells consume.
z_geo_np, z_n_np, K_chart_np, K_code_np = (
    tensor.cpu().numpy() for tensor in (z_geo, z_n, K_chart, K_code)
)

recon_atlas_cpu = recon_atlas.cpu()


In [None]:
def pca_project(x: np.ndarray, n_components: int = 2) -> np.ndarray:
    """Project the rows of ``x`` onto their leading principal directions.

    The column means are subtracted, an SVD of the centered matrix is taken,
    and the centered rows are multiplied by the first ``n_components``
    right-singular vectors.

    Args:
        x: 2D array with one sample per row.
        n_components: Number of leading directions to keep.

    Returns:
        The centered data expressed in the retained directions.
    """
    centered = x - x.mean(axis=0, keepdims=True)
    # Only the right-singular vectors (rows of the last factor) are needed.
    _, _, components = np.linalg.svd(centered, full_matrices=False)
    basis = components[:n_components]
    return centered @ basis.T

def scatter_2d(z: np.ndarray, c: np.ndarray, title: str, cmap: str = "tab10"):
    """Open a new 6x5 figure and scatter the first two columns of ``z``.

    Args:
        z: Array whose columns 0 and 1 are the x and y coordinates.
        c: Per-point values mapped to colors through ``cmap``.
        title: Axes title.
        cmap: Matplotlib colormap name.
    """
    _, ax = plt.subplots(figsize=(6, 5))
    xs, ys = z[:, 0], z[:, 1]
    ax.scatter(xs, ys, c=c, cmap=cmap, s=6, alpha=0.7)
    ax.set(title=title, xlabel="z1", ylabel="z2")
    ax.grid(alpha=0.3)

def scatter_3d(z: np.ndarray, c: np.ndarray, title: str, cmap: str = "tab10"):
    """Open a new 6x5 figure and scatter the first three columns of ``z`` in 3D.

    Args:
        z: Array whose columns 0, 1 and 2 are the x, y and z coordinates.
        c: Per-point values mapped to colors through ``cmap``.
        title: Axes title.
        cmap: Matplotlib colormap name.
    """
    canvas = plt.figure(figsize=(6, 5))
    ax = canvas.add_subplot(1, 1, 1, projection="3d")
    xs, ys, zs = z[:, 0], z[:, 1], z[:, 2]
    ax.scatter(xs, ys, zs, c=c, cmap=cmap, s=6, alpha=0.7)
    ax.set(title=title, xlabel="z1", ylabel="z2", zlabel="z3")


In [None]:
# 2D PCA of z_geo, drawn twice: colored by class label, then by chart id.
z_pca_2d = pca_project(z_geo_np, 2)
scatter_2d(z_pca_2d, labels_np, title="Latent (PCA 2D) by class")
scatter_2d(z_pca_2d, K_chart_np, title="Latent (PCA 2D) by chart", cmap="tab20")


In [None]:
# 3D PCA view, drawn only when z_geo has at least three columns.
if z_geo_np.shape[1] >= 3:
    z_pca_3d = pca_project(z_geo_np, 3)
    scatter_3d(z_pca_3d, labels_np, title="Latent (PCA 3D) by class")
    scatter_3d(z_pca_3d, K_chart_np, title="Latent (PCA 3D) by chart", cmap="tab20")


In [None]:
# Chart usage and code usage heatmap
num_charts = config["num_charts"]
codes_per_chart = config["codes_per_chart"]

# Number of test samples assigned to each chart; minlength keeps unused
# charts in the plot with a zero count.
chart_counts = np.bincount(K_chart_np, minlength=num_charts)
plt.figure(figsize=(6, 3))
plt.bar(np.arange(num_charts), chart_counts)
plt.title("Chart usage")
plt.xlabel("chart")
plt.ylabel("count")
plt.grid(alpha=0.3)

# code_counts[c, k] = number of samples assigned to code k of chart c.
# np.add.at accumulates unbuffered, so a repeated (chart, code) pair is counted
# once per occurrence (plain fancy-index `+=` would count it only once).
code_counts = np.zeros((num_charts, codes_per_chart), dtype=np.int64)
np.add.at(code_counts, (K_chart_np, K_code_np), 1)

plt.figure(figsize=(8, 4))
plt.imshow(code_counts, aspect="auto", cmap="viridis")
plt.colorbar(label="count")
plt.title("Code usage per chart")
plt.xlabel("code index")
plt.ylabel("chart")


In [None]:
# Per-sample reconstruction error: squared difference averaged over dim 1.
mse_per_sample = (recon_atlas_cpu - X_test_tensor).pow(2).mean(dim=1).cpu().numpy()

plt.figure(figsize=(6, 3))
plt.hist(mse_per_sample, bins=40, color="steelblue", alpha=0.8)
plt.title("Reconstruction MSE distribution")
plt.xlabel("mse")
plt.ylabel("count")

# Mean error of the samples routed to each chart; a chart that received no
# samples is reported as 0.0.
per_chart_mse = []
for c in range(num_charts):
    chart_errors = mse_per_sample[K_chart_np == c]
    per_chart_mse.append(chart_errors.mean() if chart_errors.size else 0.0)

plt.figure(figsize=(6, 3))
plt.bar(np.arange(num_charts), per_chart_mse)
plt.title("Mean reconstruction MSE per chart")
plt.xlabel("chart")
plt.ylabel("mean mse")


In [None]:
# Chart prototypes: nearest sample to each chart centroid
def to_image(x, shape):
    """Turn a tensor into an ``(h, w, c)`` numpy image with values in [0, 1].

    Args:
        x: Torch tensor holding ``h * w * c`` values.
        shape: ``(h, w, c)`` triple describing the target image layout.

    Returns:
        A numpy array of shape ``(h, w, c)``, clipped to the range [0, 1].
    """
    height, width, channels = shape
    pixels = x.detach().cpu().numpy()
    return np.clip(pixels.reshape(height, width, channels), 0, 1)

# Up to ten charts, ordered from most to least used.
top_charts = np.argsort(chart_counts)[::-1][: min(10, num_charts)]

# squeeze=False keeps `axes` two-dimensional even when there is one column,
# so `axes[row, col]` indexing works without a special case.
fig, axes = plt.subplots(
    2, len(top_charts), figsize=(2 * len(top_charts), 4), squeeze=False
)

for i, chart in enumerate(top_charts):
    # Hide ticks and frame on every panel, including those skipped below.
    # The axis itself stays on: `axis("off")` would also suppress the row
    # labels set on the first column after the loop.
    for panel in axes[:, i]:
        panel.set_xticks([])
        panel.set_yticks([])
        for spine in panel.spines.values():
            spine.set_visible(False)

    mask = K_chart_np == chart
    if not mask.any():
        continue
    # Prototype = the sample whose z_geo is closest (squared Euclidean
    # distance) to the mean z_geo of the samples assigned to this chart.
    chart_z = z_geo_np[mask]
    center = chart_z.mean(axis=0)
    idx = np.where(mask)[0]
    nearest = idx[np.argmin(((chart_z - center) ** 2).sum(axis=1))]

    axes[0, i].imshow(to_image(X_test_tensor[nearest], image_shape), cmap="gray")
    axes[0, i].set_title(f"chart {chart}")

    axes[1, i].imshow(to_image(recon_atlas_cpu[nearest], image_shape), cmap="gray")

axes[0, 0].set_ylabel("input")
axes[1, 0].set_ylabel("recon")
plt.tight_layout()


In [None]:
# z_n norm by chart: L2 norm of each sample's z_n, grouped by assigned chart.
z_n_norm = np.linalg.norm(z_n_np, axis=1)
data_by_chart = [z_n_norm[K_chart_np == c] for c in range(num_charts)]

plt.figure(figsize=(8, 3))
# Put box c at x == c so the tick labels are the 0-based chart ids used by the
# other chart plots; boxplot's default positions start at 1.
plt.boxplot(data_by_chart, positions=np.arange(num_charts), showfliers=False)
plt.title("z_n norm by chart")
plt.xlabel("chart")
plt.ylabel("norm")


In [None]:
# Pearson correlation of every z_geo column with the per-sample mean of the
# input taken over dim 1.
mean_intensity = X_test_tensor.mean(dim=1).cpu().numpy()
correlations = [
    np.corrcoef(latent_column, mean_intensity)[0, 1]
    for latent_column in z_geo_np.T
]

plt.figure(figsize=(6, 3))
plt.bar(np.arange(len(correlations)), correlations)
plt.title("Correlation of z_geo dims with mean intensity")
plt.xlabel("latent dim")
plt.ylabel("corr")


In [None]:
# Benchmark comparison (if available)
if model_std is None and model_ae is None:
    print("No benchmark models found in this checkpoint directory.")
else:
    with torch.no_grad():
        # Latent scatter plots for whichever baselines were built.
        if model_std is not None:
            z_std = model_std.encoder(X_test_device).cpu().numpy()
            z_std_2d = pca_project(z_std, 2)
            scatter_2d(z_std_2d, labels_np, title="StandardVQ latent (PCA 2D)")
        if model_ae is not None:
            # The VanillaAE forward pass is unpacked as (reconstruction, latent).
            recon_ae, z_ae = model_ae(X_test_device)
            z_ae_2d = pca_project(z_ae.cpu().numpy(), 2)
            scatter_2d(z_ae_2d, labels_np, title="VanillaAE latent (PCA 2D)")

        # Mean squared reconstruction error over the whole test set.
        if model_std is not None:
            recon_std = model_std(X_test_device)[0].cpu()
            mse_std = (recon_std - X_test_tensor).pow(2).mean().item()
            print(f"StandardVQ MSE: {mse_std:.5f}")
        if model_ae is not None:
            mse_ae = (recon_ae.cpu() - X_test_tensor).pow(2).mean().item()
            print(f"VanillaAE MSE: {mse_ae:.5f}")

    mse_atlas = (recon_atlas_cpu - X_test_tensor).pow(2).mean().item()
    print(f"TopoEncoder MSE: {mse_atlas:.5f}")
