# PCA projections of latent space.

In [1]:
#| label: ae1m-pca

import os, h5py
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from ipywidgets import Dropdown, VBox, Output
from pathlib import Path

# CONFIG — update if needed
RESULTS_DIR = Path("..") / "data" / "nucleus-ae" 
PCA_H5 = f"{RESULTS_DIR}/pca_embeddings.h5"

# Enable interactive Matplotlib (ipympl) in notebooks
%matplotlib widget

plt.rcParams['figure.figsize'] = (7, 6)
plt.rcParams['figure.dpi'] = 100
plt.rcParams['savefig.bbox'] = 'tight'


def _load_pairs_and_meta(pca_path):
    """
    Read the 2-D PCA projections and their metadata from an HDF5 file.

    A missing file is not an error: the empty defaults are returned.

    Args:
      pca_path: path of the HDF5 file to read.

    Returns:
      pairs: list[(name:str, arr: float32 ndarray)] — one entry per stored
             pair; arrays are expected to be (N, 2) but the shape is not
             checked here.
      meta:  dict with keys:
             - 'evr': np.ndarray or None   (explained variance ratios; the
               top3 attribute wins over top5 when both exist)
             - 'scores_key': str or None   ('scores_top3' or 'scores_top5',
               whichever is found first in that order)
    """
    meta = {"evr": None, "scores_key": None}
    pairs = []
    if not os.path.isfile(pca_path):
        return pairs, meta

    with h5py.File(pca_path, "r") as h5:
        uses_pairs_group = "pairs" in h5 and isinstance(h5["pairs"], h5py.Group)
        if uses_pairs_group:
            # New file layout: every dataset under /pairs, in sorted name order.
            group = h5["pairs"]
            pairs.extend(
                (key, np.asarray(group[key][...], dtype=np.float32))
                for key in sorted(group.keys())
            )
        else:
            # Legacy layout: PCi_PCj datasets stored at the root (i < j <= 5).
            legacy_names = (
                f"PC{i}_PC{j}" for i in range(1, 6) for j in range(i + 1, 6)
            )
            pairs.extend(
                (key, np.asarray(h5[key][...], dtype=np.float32))
                for key in legacy_names
                if key in h5
            )

        # First matching attribute wins, so top3 takes precedence over top5.
        for attr_name in ("explained_variance_ratio_top3",
                          "explained_variance_ratio_top5"):
            if attr_name in h5.attrs:
                meta["evr"] = np.asarray(h5.attrs[attr_name])
                break

        # Record which scores dataset is present, with the same precedence.
        for dataset_name in ("scores_top3", "scores_top5"):
            if dataset_name in h5:
                meta["scores_key"] = dataset_name
                break

    return pairs, meta


def _autosize_point_params(N):
    """
    Pick a two-value plotting parameter pair from the number of points N.

    The pair is returned as a tuple; larger N maps to smaller values so dense
    clouds stay readable. Each upper bound below is inclusive.
    """
    tiers = (
        (50_000, (6.0, 0.8)),
        (200_000, (2.5, 0.6)),
        (1_000_000, (0.8, 0.5)),
    )
    for upper_bound, params in tiers:
        if N <= upper_bound:
            return params
    # Anything above the last tier.
    return 0.4, 0.5


def scatter2d(ax, xy, title=None):
    """
    Scatter the first two columns of ``xy`` onto ``ax`` and return ``ax``.

    Marker size and alpha come from ``_autosize_point_params`` based on the
    number of rows. The points are rasterized, the axis labels are set to
    "Dim 1"/"Dim 2", and the grid is switched off. ``title`` is applied only
    when it is truthy.
    """
    size, opacity = _autosize_point_params(xy.shape[0])
    xs, ys = xy[:, 0], xy[:, 1]
    ax.scatter(xs, ys, s=size, alpha=opacity, rasterized=True)
    if title:
        ax.set_title(title)
    ax.set_xlabel("Dim 1")
    ax.set_ylabel("Dim 2")
    ax.grid(False)
    return ax


# Read the PCA pairs and the file metadata a single time (nothing is printed)
pca_pairs, pca_meta = _load_pairs_and_meta(PCA_H5)
evr = pca_meta.get("evr")

#| label: ae1m-pca
# Output area for the plot, plus a dropdown to choose which pair is shown.
# With no pairs available the dropdown shows a placeholder and is disabled.
pca_out = Output()
pca_dd = Dropdown(
    options=[pair_name for pair_name, _ in pca_pairs] or ["<no PCA pairs found>"],
    description='PC Pair:',
    disabled=not pca_pairs,
)

def _title_with_evr(name):
    """
    Build the plot title for the PC pair called ``name``.

    When the module-level ``evr`` is available and ``name`` looks like
    'PCi_PCj' with both i and j inside the range covered by ``evr``, the
    explained variance ratio of each component is appended to the title.
    In every other case the plain title "PCA — <name>" is returned.

    Args:
      name: pair name, e.g. 'PC1_PC2'.

    Returns:
      The title string.
    """
    plain = f"PCA — {name}"
    if evr is None or "_" not in name:
        return plain
    try:
        pcx, pcy = name.split("_")
        ix = int(pcx.replace("PC", "")) - 1
        iy = int(pcy.replace("PC", "")) - 1
        n_evr = len(evr)
        # Reject negative indices as well: 'PC0' would otherwise become -1 and
        # silently pick up the EVR of the last component.
        if not (0 <= ix < n_evr and 0 <= iy < n_evr):
            return plain
        return f"PCA — {name}  (EVR: PC{ix+1}={evr[ix]:.4f}, PC{iy+1}={evr[iy]:.4f})"
    except (ValueError, TypeError):
        # Name is not of the form 'PC<int>_PC<int>', or evr has no length /
        # holds values that cannot be formatted: fall back to the plain title.
        return plain

def draw_pca(name):
    """
    Draw the scatter plot for the PC pair called ``name`` inside ``pca_out``.

    The output area is cleared first. If ``name`` is not among ``pca_pairs``
    (or is the "<no ...>" placeholder), a short message is printed instead of
    a plot. Axis labels are taken from the two halves of a 'PCi_PCj' style
    name; any other name falls back to "Dim 1"/"Dim 2".

    Args:
      name: pair name as listed in the dropdown.
    """
    pca_out.clear_output(wait=True)

    # pyplot keeps every figure it creates alive until it is closed, so
    # release the figure from the previous selection before making a new one.
    previous_fig = getattr(draw_pca, "_figure", None)
    if previous_fig is not None:
        plt.close(previous_fig)
        draw_pca._figure = None

    xy = next((arr for nm, arr in pca_pairs if nm == name), None)
    with pca_out:
        if xy is None or (isinstance(name, str) and name.startswith("<no ")):
            print("PCA dataset not found.")
            return
        fig, ax = plt.subplots()
        draw_pca._figure = fig
        # Only a name with exactly one underscore yields two axis labels;
        # unpacking any other split would raise ValueError.
        parts = name.split("_")
        xlab, ylab = parts if len(parts) == 2 else ("Dim 1", "Dim 2")
        scatter2d(ax, xy, title=_title_with_evr(name))
        # scatter2d sets generic labels; replace them with the pair's own.
        ax.set_xlabel(xlab)
        ax.set_ylabel(ylab)
        plt.show()

def _on_pca_change(change):
    """Redraw the plot when the observed change is for the 'value' trait."""
    if change['name'] != 'value':
        return
    draw_pca(change['new'])

# Draw the first pair immediately so the output area is populated on load.
if pca_dd.options and not pca_dd.disabled:
    draw_pca(pca_dd.options[0])

# Redraw whenever the dropdown value changes, then display both widgets.
pca_dd.observe(_on_pca_change, names='value')
VBox(children=[pca_dd, pca_out])


VBox(children=(Dropdown(description='PC Pair:', options=('PC1_PC2', 'PC1_PC3', 'PC2_PC3'), value='PC1_PC2'), O…