# Latent Averages → Reconstructions (Interactive Viewer)

This notebook:
1. Loads latent vectors from `latents.h5` (datasets `/z`, `/filenames`).
2. Computes **arithmetic mean**, **component-wise median**, and an **approximate geometric median** (Weiszfeld) of the latents.
3. Rebuilds your decoder (matching your training architecture), loads `decoder_weights.h5`, and reconstructs those three vectors.
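4. Displays an **interactive figure** where you can:
   - pick which reconstruction to show;
   - **adjust the display range** (0–255) with a slider, or reset it;
   - apply **Auto (2–98%)** percentile contrast stretching.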

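**Note:** The arithmetic mean is computed exactly by streaming over all rows. The component-wise median and the geometric median are estimated from configurable random **subsamples** for tractability on massive datasets.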


In [1]:
# ──────────────────────────────────────────────────────────────────────
# 0) Imports & interactive backend
# ──────────────────────────────────────────────────────────────────────
%matplotlib widget
import os, math
import h5py
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as W
from IPython.display import display

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Reshape, Conv2D, Conv2DTranspose, ReLU

print(tf.__version__)
_gpus = tf.config.list_physical_devices("GPU")
print("GPU devices:", _gpus)

# Ask TensorFlow to allocate GPU memory on demand rather than reserving it all
# up front. With no GPU present the loop body never runs.
try:
    for _gpu in _gpus:
        tf.config.experimental.set_memory_growth(_gpu, True)
except Exception as e:
    print("[!] Could not set GPU memory growth:", e)


2.20.0
GPU devices: []


In [2]:
# ──────────────────────────────────────────────────────────────────────
# 1) Paths & knobs (EDIT THESE IF NEEDED)
# ──────────────────────────────────────────────────────────────────────
# Folder with decoder_weights.h5 and latents.h5. Set the AE_RESULTS_DIR
# environment variable to use another folder without editing this cell.
RESULTS_DIR = os.environ.get('AE_RESULTS_DIR', r'D:/Results/09052025_AE1M_Conv2DTranspose')
LATENTS_H5  = os.path.join(RESULTS_DIR, 'latents.h5')
DECODER_W   = os.path.join(RESULTS_DIR, 'decoder_weights.h5')

# Model/latent shape (must match training)
INPUT_SHAPE = (256, 256, 1)
LATENT_DIM  = 512

# Rows read per chunk when streaming the arithmetic mean
CHUNK_ROWS  = 100_000  # adjust if needed

# Subsample sizes for (component-wise) median & geometric median (Weiszfeld)
MEDIAN_SUBSAMPLE_N    = 200_000   # per-dim median from a random subset
GEOMEDIAN_SUBSAMPLE_N = 100_000   # Weiszfeld runs on a random subset

# Weiszfeld algorithm knobs
GEOMEDIAN_MAX_ITERS = 200
GEOMEDIAN_TOL       = 1e-6

# Explicit checks instead of `assert`, which is skipped under `python -O`.
if not os.path.isfile(LATENTS_H5):
    raise FileNotFoundError(f"latents.h5 not found: {LATENTS_H5}")
if not os.path.isfile(DECODER_W):
    raise FileNotFoundError(f"decoder_weights.h5 not found: {DECODER_W}")
print("Using:")
print("  ", LATENTS_H5)
print("  ", DECODER_W)


Using:
   D:/Results/09052025_AE1M_Conv2DTranspose\latents.h5
   D:/Results/09052025_AE1M_Conv2DTranspose\decoder_weights.h5


In [3]:
# ──────────────────────────────────────────────────────────────────────
# 2) Utilities: streaming mean, subsampling, Weiszfeld geometric median
# ──────────────────────────────────────────────────────────────────────
def stream_mean_z(h5path, dataset='z', chunk_rows=100_000):
    """Compute the arithmetic mean of the rows of a 2-D HDF5 dataset by streaming chunks.

    Only ``chunk_rows`` rows are read at a time. Partial sums are accumulated
    in float64 to limit round-off across many rows.

    Parameters
    ----------
    h5path : str
        Path to the HDF5 file.
    dataset : str
        Name of the (N, D) dataset inside the file (default ``'z'``).
    chunk_rows : int
        Number of rows read per chunk; must be positive.

    Returns
    -------
    np.ndarray
        Mean row, shape (D,), dtype float32.

    Raises
    ------
    ValueError
        If ``chunk_rows`` is not positive, or the dataset has no rows
        (the mean would be undefined).
    """
    if chunk_rows <= 0:
        # A negative step would make the loop below a no-op and silently
        # return zeros; zero would fail inside range() with a vaguer message.
        raise ValueError(f"chunk_rows must be positive, got {chunk_rows}")
    with h5py.File(h5path, 'r') as h5:
        z = h5[dataset]
        N, D = z.shape
        if N == 0:
            raise ValueError(f"dataset '{dataset}' in {h5path} has no rows; mean is undefined")
        acc = np.zeros(D, dtype=np.float64)
        for start in range(0, N, chunk_rows):
            end = min(start + chunk_rows, N)
            # Accumulate in float64 without materialising a float64 copy of the chunk.
            acc += z[start:end].sum(axis=0, dtype=np.float64)
        return (acc / N).astype(np.float32)

def random_subsample(h5path, n_wanted, dataset='z', seed=42, chunk_rows=8192):
    """Return a uniform random subset of rows of a 2-D HDF5 dataset.

    Rows are drawn without replacement with ``np.random.default_rng(seed)``.
    They are returned in ascending file-row order, cast to float32.

    Parameters
    ----------
    h5path : str
        Path to the HDF5 file.
    n_wanted : int
        Requested number of rows; capped at the dataset's row count N
        (``n_wanted >= N`` returns every row).
    dataset : str
        Name of the (N, D) dataset inside the file (default ``'z'``).
    seed : int
        Seed for reproducible sampling.
    chunk_rows : int
        Maximum span of file rows fetched by one contiguous read. This bounds
        the temporary buffer at ``chunk_rows * D`` values.

    Returns
    -------
    np.ndarray
        Array of shape (min(n_wanted, N), D), dtype float32.

    Raises
    ------
    ValueError
        If ``chunk_rows`` is not positive.
    """
    if chunk_rows <= 0:
        raise ValueError(f"chunk_rows must be positive, got {chunk_rows}")
    rng = np.random.default_rng(seed)
    with h5py.File(h5path, 'r') as h5:
        z = h5[dataset]
        N, D = z.shape
        M = min(n_wanted, N)
        idx = rng.choice(N, size=M, replace=False)
        idx.sort()

        out = np.empty((M, D), dtype=np.float32)
        pos = 0
        while pos < M:
            lo = int(idx[pos])
            # Every sampled row in [lo, lo + chunk_rows) is served by one
            # contiguous slice read. h5py slicing is far cheaper than
            # point-wise fancy indexing, which would otherwise issue ~M
            # separate reads when the sample is sparse.
            stop = int(np.searchsorted(idx, lo + chunk_rows, side='left'))
            hi = int(idx[stop - 1]) + 1
            buf = z[lo:hi]
            # Pick the sampled rows out of the buffer, indexed relative to lo.
            out[pos:stop] = buf[idx[pos:stop] - lo]
            pos = stop

        return out


def weiszfeld_geometric_median(points, x0=None, max_iters=200, tol=1e-6, eps=1e-12):
    """
    Geometric median via Weiszfeld's algorithm with the Vardi–Zhang correction.

    The geometric median minimises sum_i ||p_i - x||_2. Plain Weiszfeld
    weights each point by 1/||p_i - x||, which breaks down when the iterate
    lands on a data point. Landing on a data point does NOT mean that point
    is the median. Following Vardi & Zhang (2000), such a point is accepted
    only if
        || sum_{p_i != x} (p_i - x) / ||p_i - x|| ||  <=  (#points at x);
    otherwise the iterate is pushed off it.

    points   : (N, D) array-like with N >= 1
    x0       : initial guess (D,); defaults to the component-wise median
    max_iters: iteration cap
    tol      : stop when the iterate moves less than this (Euclidean norm)
    eps      : distance below which a point counts as coinciding with the iterate
    Returns  : (median (D,) float32, iterations used); iterations == max_iters
               also when the cap was hit without meeting `tol`
    Raises   : ValueError if `points` is not a non-empty 2-D array
    """
    P = np.asarray(points, dtype=np.float64)
    if P.ndim != 2 or P.shape[0] == 0:
        raise ValueError(f"points must be a non-empty (N, D) array, got shape {P.shape}")
    if x0 is None:
        x = np.median(P, axis=0)
    else:
        x = np.asarray(x0, dtype=np.float64)

    for it in range(1, max_iters + 1):
        dist = np.linalg.norm(P - x, axis=1)
        coincident = dist < eps
        n_coincident = int(coincident.sum())

        # Weights 1/dist for points away from x; 0 for points sitting on x
        # (computed without dividing by zero).
        w = np.divide(1.0, dist, out=np.zeros_like(dist), where=~coincident)
        w_sum = w.sum()
        if w_sum == 0.0:
            # Every point coincides with x, so x is trivially the median.
            return P[coincident][0].astype(np.float32), it

        weighted = w @ P          # sum_i w_i p_i, with no (N, D) temporary
        x_tilde = weighted / w_sum  # classical Weiszfeld step over non-coincident points

        if n_coincident:
            # Optimality test at a data point: norm of the summed unit vectors
            # pointing from x to the other points.
            r = np.linalg.norm(weighted - w_sum * x)
            if r <= n_coincident:
                return P[coincident][0].astype(np.float32), it
            # Not optimal: move off the data point (Vardi–Zhang step).
            step = n_coincident / r
            x_new = (1.0 - step) * x_tilde + step * x
        else:
            x_new = x_tilde

        if np.linalg.norm(x_new - x) < tol:
            return x_new.astype(np.float32), it
        x = x_new
    return x.astype(np.float32), max_iters


In [4]:
# ──────────────────────────────────────────────────────────────────────
# 3) Compute latent representatives
# ──────────────────────────────────────────────────────────────────────
with h5py.File(LATENTS_H5, 'r') as h5:
    N, D = h5['z'].shape
print(f"[i] Latents: N={N:,}, D={D}")
assert D == LATENT_DIM, f"LATENT_DIM mismatch: file has {D}, config says {LATENT_DIM}"

# Arithmetic mean (streamed over all rows, exact)
z_mean = stream_mean_z(LATENTS_H5, 'z', CHUNK_ROWS)
print("[i] Arithmetic mean computed.")

# Component-wise median from subsample
Z_med_sample = random_subsample(LATENTS_H5, MEDIAN_SUBSAMPLE_N, 'z')
z_median = np.median(Z_med_sample, axis=0).astype(np.float32)
print(f"[i] Component-wise median from sample of {len(Z_med_sample):,}.")

# Geometric median (Weiszfeld) from an independent subsample (different seed),
# warm-started at the component-wise median
Z_geo_sample = random_subsample(LATENTS_H5, GEOMEDIAN_SUBSAMPLE_N, 'z', seed=123)
z_geomed, iters = weiszfeld_geometric_median(Z_geo_sample, x0=z_median,
                                            max_iters=GEOMEDIAN_MAX_ITERS,
                                            tol=GEOMEDIAN_TOL)
print(f"[i] Geometric median done in {iters} iters from sample of {len(Z_geo_sample):,}.")
if iters >= GEOMEDIAN_MAX_ITERS:
    # The solver reports max_iters when it stops on the iteration cap, so the
    # tolerance may not have been reached.
    print(f"[!] Weiszfeld reached the iteration cap ({GEOMEDIAN_MAX_ITERS}); the geometric "
          "median may not have converged. Consider raising GEOMEDIAN_MAX_ITERS.")


[i] Latents: N=1,061,277, D=512
[i] Arithmetic mean computed.
[i] Component-wise median from sample of 200,000.
[i] Geometric median done in 8 iters from sample of 100,000.


In [5]:
# ──────────────────────────────────────────────────────────────────────
# 4) Rebuild decoder (must match training) and load weights
# ──────────────────────────────────────────────────────────────────────
def build_decoder(latent_dim=512):
    """Rebuild the convolutional decoder graph so saved weights can be loaded into it.

    Architecture:
      Dense(16*16*128) -> Reshape (16, 16, 128)
      -> 4 stages of [Conv2DTranspose(stride 2) + ReLU, Conv2D + ReLU]
         with 128 / 64 / 32 / 16 filters
      -> Conv2D(1 channel, sigmoid) named 'decoder_output'.

    NOTE(review): this must match the training-time decoder layer for layer.
    Weight loading presumably relies on the same layer order/structure, so do
    not reorder, add, or remove layers. Confirm against the training script.

    Args:
        latent_dim: length of the input latent vector (input layer 'z_sampling').

    Returns:
        Keras Model named 'decoder' mapping (batch, latent_dim) to
        (batch, 256, 256, 1), with outputs in (0, 1) from the sigmoid.
    """
    latent_in = Input((latent_dim,), name='z_sampling')
    x = Dense(16 * 16 * 128)(latent_in)
    x = Reshape((16, 16, 128))(x)
    # Each stride-2 'same' transpose conv doubles H and W: 16 -> 32 -> 64 -> 128 -> 256
    for filters in [128, 64, 32, 16]:
        x = Conv2DTranspose(filters, 3, strides=2, padding='same')(x); x = ReLU()(x)
        x = Conv2D(filters, 3, padding='same')(x); x = ReLU()(x)
    decoded = Conv2D(1, 3, padding='same', activation='sigmoid', name='decoder_output')(x)
    return Model(latent_in, decoded, name='decoder')

decoder = build_decoder(LATENT_DIM)
# INPUT_SHAPE is the configured image shape. Fail early with a clear message if
# the rebuilt decoder does not produce images of that shape.
decoder_out_shape = tuple(decoder.output.shape)[1:]
assert decoder_out_shape == tuple(INPUT_SHAPE), (
    f"Decoder output shape {decoder_out_shape} != INPUT_SHAPE {tuple(INPUT_SHAPE)}")
decoder.load_weights(DECODER_W)
decoder.trainable = False  # inference only
print("[i] Decoder rebuilt & weights loaded.")


[i] Decoder rebuilt & weights loaded.


In [6]:
# ──────────────────────────────────────────────────────────────────────
# 5) Decode representatives → images in [0,1]
# ──────────────────────────────────────────────────────────────────────
def decode_vec(v):
    """Decode one latent vector with the global `decoder` into a 2-D image.

    The vector is wrapped into a batch of one. The first sample's single
    channel is taken, clipped to [0, 1], and returned as float32 of shape (H, W).
    """
    batch = np.expand_dims(np.asarray(v, dtype=np.float32), axis=0)  # (1, D)
    prediction = decoder.predict(batch, batch_size=1, verbose=0)
    image = prediction[0, ..., 0]
    return np.clip(image, 0.0, 1.0).astype(np.float32)

# Decode the three representative latents (in this order) into [0, 1] images.
img_mean, img_median, img_geomed = (
    decode_vec(vec) for vec in (z_mean, z_median, z_geomed)
)
print("[i] Decoded mean/median/geomedian.")


[i] Decoded mean/median/geomedian.


In [7]:
# ──────────────────────────────────────────────────────────────────────
# 6) Interactive viewer (single live figure, no duplicates)
# ──────────────────────────────────────────────────────────────────────
import matplotlib.pyplot as plt
import ipywidgets as W
from IPython.display import display

# Build the figure with auto-display turned off so that only the explicit
# display() in the final cell shows it. The `with plt.ioff():` context restores
# the previous interactive state afterwards; a bare plt.ioff() would leave
# interactive mode off for every later cell in the session.
plt.close('all')

images = {
    'Arithmetic mean': img_mean,
    'Component-wise median': img_median,
    'Geometric median': img_geomed,
}

current_key = 'Arithmetic mean'
vmin_init, vmax_init = 0, 255  # slider units; images are in [0, 1], hence /255 below

with plt.ioff():
    fig, ax = plt.subplots(figsize=(5, 5))
    im_artist = ax.imshow(images[current_key], cmap='gray',
                          vmin=vmin_init/255.0, vmax=vmax_init/255.0)
    ax.set_title(current_key)
    ax.axis('off')

dd = W.Dropdown(options=list(images.keys()), value=current_key, description='Image:')
# `rng` is the display-range slider here (not a random number generator).
rng = W.IntRangeSlider(value=[vmin_init, vmax_init], min=0, max=255, step=1,
                       description='Range:', continuous_update=True)
btn_auto  = W.Button(description='Auto (2–98%)')
btn_reset = W.Button(description='Reset 0–255')

def update_image(*_):
    """Show the selected image, mapping the slider's 0–255 range to [0, 1] color limits."""
    key = dd.value
    lo, hi = rng.value
    im_artist.set_data(images[key])
    im_artist.set_clim(lo/255.0, hi/255.0)
    ax.set_title(key)
    fig.canvas.draw_idle()

def on_auto_click(_):
    """Set the range to the 2nd–98th percentiles of the current image (0–255 units).

    Falls back to the full 0–255 range if the percentiles collapse (flat image).
    """
    arr = images[dd.value]
    lo = int(np.clip(round(np.percentile(arr, 2)  * 255.0), 0, 255))
    hi = int(np.clip(round(np.percentile(arr, 98) * 255.0), 0, 255))
    if lo >= hi:
        lo, hi = 0, 255
    rng.value = (lo, hi)
    # Explicit redraw: the observer does not fire if the value is unchanged.
    update_image()

def on_reset_click(_):
    """Restore the full 0–255 display range."""
    rng.value = (0, 255)
    update_image()

dd.observe(update_image, names='value')
rng.observe(update_image, names='value')
btn_auto.on_click(on_auto_click)
btn_reset.on_click(on_reset_click)

ui = W.VBox([dd, rng, W.HBox([btn_auto, btn_reset])])


### Notes
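- If you have *plenty* of RAM, you can set `MEDIAN_SUBSAMPLE_N` (and `GEOMEDIAN_SUBSAMPLE_N`) to your full `N`, or simply increase them.
  - At full `N`, the component-wise median becomes exact.
  - The geometric median remains an iterative approximation, controlled by `GEOMEDIAN_TOL` / `GEOMEDIAN_MAX_ITERS`.
  - Memory cost: each subsample is held in RAM as float32 (≈ N × D × 4 bytes, ~2.2 GB for 1.06 M × 512 latents). The geometric-median solver also works on a float64 copy plus a same-sized temporary per iteration.
  - For huge datasets, the current defaults are a good balance.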
- The geometric median uses **Weiszfeld** on a subsample with the component-wise median as the initial guess for faster convergence.
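- The viewer uses `%matplotlib widget` (provided by **ipympl**), so the figure is a live canvas that updates in place. Auto-display is turned off with `plt.ioff()`, and the figure is shown once via an explicit `display()` in the final cell, which avoids duplicate plots.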


In [8]:
#| label: fig:ae1m-theoretical
#| caption: "Interactive viewer of decoder reconstructions of the arithmetic-mean, component-wise-median, and geometric-median latent vectors (choose with the dropdown; adjust the display range with the slider)."
# NOTE(review): Quarto cross-references expect figure labels prefixed `fig-`
# (hyphen). Confirm that the `fig:` form above is recognised by your renderer.

# Display exactly once: the controls + the single canvas
display(W.VBox([ui, fig.canvas]))

VBox(children=(VBox(children=(Dropdown(description='Image:', options=('Arithmetic mean', 'Component-wise media…