# TODO: This notebook is slow to run; profile it and rework the bottlenecks later.

# Latent Averages → Nearest Real Images (with Robust Path Resolver)

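This notebook computes the mean, component-wise median, and geometric-median latent vectors,
decodes each of them into an image, re-embeds the decoded images with the encoder, finds the
nearest real images in latent space (Euclidean and cosine distance), and shows the matches in
a 3-panel interactive viewer.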

**Image root:** `D:\Confocal_imaging_nuclei_tif\MIST_Fused_Images\ROIs`


In [1]:
%matplotlib widget
import os, math, glob
import h5py
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as W
from IPython.display import display
from tifffile import imread

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Dense, Reshape, Conv2D, Conv2DTranspose,
                                     ReLU, AveragePooling2D, Flatten)

# Report which TensorFlow build this kernel is using.
print("TensorFlow:", tf.__version__)
# Enable memory growth on every visible GPU. Any failure is reported and
# otherwise ignored so the notebook can continue.
try:
    for g in tf.config.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(g, True)
except Exception as e:
    print("[!] Could not set GPU memory growth:", e)


TensorFlow: 2.20.0


In [2]:
# ──────────────────────────────────────────────────────────────────────
# Paths & knobs (EDIT IF NEEDED)
# ──────────────────────────────────────────────────────────────────────
# Directory holding the latent table and the trained encoder/decoder weights.
RESULTS_DIR = r'D:/Results/09052025_AE1M_Conv2DTranspose'
LATENTS_H5  = os.path.join(RESULTS_DIR, 'latents.h5')
ENCODER_W   = os.path.join(RESULTS_DIR, 'encoder_weights.h5')
DECODER_W   = os.path.join(RESULTS_DIR, 'decoder_weights.h5')

# Model geometry; LATENT_DIM is checked against the stored latents later on.
INPUT_SHAPE = (256, 256, 1)
LATENT_DIM  = 512

CHUNK_ROWS   = 200_000   # rows per chunk when scanning /z
K_NEIGHBORS  = 1000      # top-K neighbors to browse

# Sample sizes and stopping criteria for the median / geometric-median estimates.
MEDIAN_SUBSAMPLE_N    = 200_000
GEOMEDIAN_SUBSAMPLE_N = 100_000
GEOMEDIAN_MAX_ITERS   = 200
GEOMEDIAN_TOL         = 1e-6

# Fail early, before any heavy work, if a required input file is absent.
assert os.path.isfile(LATENTS_H5), f"latents.h5 not found: {LATENTS_H5}"
assert os.path.isfile(ENCODER_W),  f"encoder_weights.h5 not found: {ENCODER_W}"
assert os.path.isfile(DECODER_W),  f"decoder_weights.h5 not found: {DECODER_W}"
print("Using:\n  ", LATENTS_H5, "\n  ", ENCODER_W, "\n  ", DECODER_W)


Using:
   D:/Results/09052025_AE1M_Conv2DTranspose\latents.h5 
   D:/Results/09052025_AE1M_Conv2DTranspose\encoder_weights.h5 
   D:/Results/09052025_AE1M_Conv2DTranspose\decoder_weights.h5


In [3]:
# ──────────────────────────────────────────────────────────────────────
# Robust path resolver
# ──────────────────────────────────────────────────────────────────────
# Root directory that holds the ROI image files.
# NOTE: these are raw strings, so each path separator is a SINGLE backslash.
# Doubling them inside r'...' would embed literal "\\" pairs in the value.
ROI_ROOT = r'D:\Confocal_imaging_nuclei_tif\MIST_Fused_Images\ROIs'

# Map any old prefix to ROI_ROOT if present
PATH_REWRITE = [
    (r'C:\Users\Work\Desktop\Github_repo\cell_browser\notebooks', ROI_ROOT),
]

# Directories searched by basename when a stored path cannot be used directly.
CANDIDATE_ROOTS = [ROI_ROOT]

def _normalize(p: str) -> str:
    # Normalize slashes and collapse redundant separators
    p = p.replace('/', '\\')
    return os.path.normpath(p)

def _try_extensions(path_wo_ext: str, exts=(".tif", ".tiff", ".png")):
    for ext in exts:
        q = path_wo_ext + ext
        if os.path.exists(q):
            return q
    return None

def resolve_path(p_in):
    """Resolve a stored image path to a file that exists on this machine.

    Strategies are tried in order and the first hit is returned:

    1. the normalized path itself;
    2. every ``PATH_REWRITE`` prefix swapped for its replacement, also trying
       alternative extensions;
    3. the basename placed directly inside each of ``CANDIDATE_ROOTS``, also
       trying alternative extensions;
    4. a recursive basename search below each of ``CANDIDATE_ROOTS``.

    Args:
        p_in: Path as ``str``, ``bytes``/``bytearray`` (decoded as UTF-8), or ``None``.

    Returns:
        The resolved path as a string, or ``None`` if nothing matched
        (or *p_in* was ``None``).
    """
    if p_in is None:
        return None
    p = p_in.decode('utf-8') if isinstance(p_in, (bytes, bytearray)) else str(p_in)
    p = _normalize(p)

    # direct
    if os.path.exists(p):
        return p

    # try rewrite prefixes (case-insensitive comparison)
    p_lower = p.lower()
    for old, new in PATH_REWRITE:
        old_n = _normalize(old)
        old_lower = old_n.lower()
        if not p_lower.startswith(old_lower):
            continue
        # Only accept a match that ends on a path-component boundary, so that
        # "...\notebooks_old\x.tif" is not treated as living under "...\notebooks".
        tail = p_lower[len(old_lower):]
        if tail and tail[0] not in '\\/' and not old_lower.endswith(('\\', '/')):
            continue
        rel = os.path.relpath(p, old_n)
        q = _normalize(os.path.join(new, rel))
        if os.path.exists(q):
            return q
        # try alt extensions
        q2 = _try_extensions(os.path.splitext(q)[0])
        if q2:
            return q2

    base = os.path.basename(p)
    if not base:
        return None

    # basename in candidate roots (direct)
    for root in CANDIDATE_ROOTS:
        q = _normalize(os.path.join(root, base))
        if os.path.exists(q):
            return q
        # same location, alternative extensions
        q2 = _try_extensions(os.path.splitext(q)[0])
        if q2:
            return q2

    # last resort: recursive basename match (can be slower but robust).
    # Root and file name are escaped so characters such as "[", "]", "*" or "?"
    # are matched literally instead of being interpreted as glob syntax.
    stem = os.path.splitext(base)[0]
    wanted = [base] + [stem + ext for ext in (".tif", ".tiff")]
    for root in CANDIDATE_ROOTS:
        root_esc = glob.escape(root)
        for name in wanted:
            pattern = os.path.join(root_esc, "**", glob.escape(name))
            hits = glob.glob(pattern, recursive=True)
            if hits:
                return _normalize(hits[0])

    return None


In [4]:
# ──────────────────────────────────────────────────────────────────────
# Utilities: streaming mean, subsample, Weiszfeld geometric median
# ──────────────────────────────────────────────────────────────────────
def stream_mean_z(h5path, dataset='z', chunk_rows=100_000):
    """Compute the column-wise mean of an HDF5 dataset without loading it whole.

    Rows are read ``chunk_rows`` at a time and summed in float64; the mean is
    returned as a float32 vector.
    """
    with h5py.File(h5path, 'r') as h5:
        table = h5[dataset]
        n_rows, n_cols = table.shape
        running_sum = np.zeros(n_cols, dtype=np.float64)
        for lo in range(0, n_rows, chunk_rows):
            hi = min(lo + chunk_rows, n_rows)
            block = table[lo:hi].astype(np.float64)
            running_sum += block.sum(axis=0)
        mean = running_sum / n_rows
        return mean.astype(np.float32)

def random_subsample(h5path, n_wanted, dataset='z', seed=42, chunk=65536):
    """Draw up to ``n_wanted`` distinct rows from an HDF5 dataset.

    Row indices are sampled without replacement using a generator seeded with
    *seed*, sorted ascending, and read ``chunk`` indices at a time. The rows
    are returned as a float32 array in ascending index order.
    """
    generator = np.random.default_rng(seed)
    with h5py.File(h5path, 'r') as h5:
        table = h5[dataset]
        n_rows, n_cols = table.shape
        n_take = min(n_wanted, n_rows)
        picks = generator.choice(n_rows, size=n_take, replace=False)
        picks.sort()
        sample = np.empty((n_take, n_cols), dtype=np.float32)
        lo = 0
        while lo < n_take:
            hi = min(lo + chunk, n_take)
            sample[lo:hi] = table[picks[lo:hi]]
            lo = hi
        return sample

def weiszfeld_geometric_median(points, x0=None, max_iters=200, tol=1e-6, eps=1e-12):
    """Estimate the geometric median of a point cloud with Weiszfeld iterations.

    The plain Weiszfeld update is undefined when the current estimate lands on
    a data point. Returning that point is only right when it is optimal, so
    this uses the Vardi-Zhang modified step in that case: the point is accepted
    only if it satisfies the optimality condition, otherwise the iteration
    moves away from it.

    Args:
        points: Array-like of shape (N, D) with N >= 1.
        x0: Optional starting estimate of shape (D,). Defaults to the
            component-wise median of *points*.
        max_iters: Maximum number of iterations.
        tol: Stop when the estimate moves less than this (Euclidean norm).
        eps: Distance below which a data point counts as coinciding with the
            current estimate.

    Returns:
        ``(median, iterations)`` where *median* is a float32 vector of length D
        and *iterations* is the number of iterations performed.

    Raises:
        ValueError: If *points* is not a non-empty 2-D array.
    """
    P = np.asarray(points, dtype=np.float64)
    if P.ndim != 2 or P.shape[0] == 0:
        raise ValueError("points must be a non-empty 2-D array of shape (N, D)")
    x = np.median(P, axis=0) if x0 is None else np.asarray(x0, dtype=np.float64)
    for it in range(1, max_iters + 1):
        diff = P - x
        dist = np.linalg.norm(diff, axis=1)
        at_x = dist < eps
        n_at_x = int(at_x.sum())
        if n_at_x == P.shape[0]:
            # Every point coincides with the estimate: nothing left to do.
            return x.astype(np.float32), it
        far = ~at_x
        w = 1.0 / dist[far]
        # Standard Weiszfeld step computed from the non-coincident points only.
        t = (P[far] * w[:, None]).sum(axis=0) / w.sum()
        if n_at_x:
            # Resultant of the unit vectors pointing from x to the other points.
            r = np.linalg.norm((diff[far] * w[:, None]).sum(axis=0))
            if r <= n_at_x:
                # The coincident data point satisfies the optimality condition.
                return x.astype(np.float32), it
            gamma = n_at_x / r
            x_new = (1.0 - gamma) * t + gamma * x
        else:
            x_new = t
        if np.linalg.norm(x_new - x) < tol:
            return x_new.astype(np.float32), it
        x = x_new
    return x.astype(np.float32), max_iters


In [5]:
# ──────────────────────────────────────────────────────────────────────
# Representatives: mean / median / geometric median
# ──────────────────────────────────────────────────────────────────────
# Read only the shape of the latent table to report it and validate LATENT_DIM.
with h5py.File(LATENTS_H5, 'r') as h5:
    N, D = h5['z'].shape
print(f"[i] Latents: N={N:,}, D={D}")
assert D == LATENT_DIM, f"LATENT_DIM mismatch: z has {D}, expected {LATENT_DIM}"

# Arithmetic mean over all rows, streamed in chunks.
z_mean = stream_mean_z(LATENTS_H5, 'z', CHUNK_ROWS)
# Component-wise median, estimated from a random subsample.
Z_med_sample = random_subsample(LATENTS_H5, MEDIAN_SUBSAMPLE_N, 'z')
z_median = np.median(Z_med_sample, axis=0).astype(np.float32)
# Geometric median on a second, differently seeded subsample, started from the
# component-wise median. `iters` is the iteration count reported by the solver.
Z_geo_sample = random_subsample(LATENTS_H5, GEOMEDIAN_SUBSAMPLE_N, 'z', seed=123)
z_geomed, iters = weiszfeld_geometric_median(Z_geo_sample, x0=z_median,
                                            max_iters=GEOMEDIAN_MAX_ITERS,
                                            tol=GEOMEDIAN_TOL)
print("[i] Representatives ready (mean/median/geomed).")


[i] Latents: N=1,061,277, D=512
[i] Representatives ready (mean/median/geomed).


In [6]:
# ──────────────────────────────────────────────────────────────────────
# Decoder → decode reps
# ──────────────────────────────────────────────────────────────────────
def build_decoder(latent_dim=512):
    """Build the decoder network (architecture only, no weights loaded).

    A dense layer reshaped to 16x16x128 is followed by four stages, each a
    stride-2 transposed convolution and a plain convolution (both with ReLU).
    A final sigmoid convolution produces a single output channel.
    """
    z_in = Input((latent_dim,), name='z_sampling')
    feat = Dense(16 * 16 * 128)(z_in)
    feat = Reshape((16, 16, 128))(feat)
    # Layers are created in the same order as before so that auto-generated
    # layer names and weight ordering stay unchanged.
    for n_filters in (128, 64, 32, 16):
        feat = Conv2DTranspose(n_filters, 3, strides=2, padding='same')(feat)
        feat = ReLU()(feat)
        feat = Conv2D(n_filters, 3, padding='same')(feat)
        feat = ReLU()(feat)
    out = Conv2D(1, 3, padding='same', activation='sigmoid', name='decoder_output')(feat)
    return Model(z_in, out, name='decoder')

# Instantiate the decoder, load its trained weights and freeze it.
decoder = build_decoder(LATENT_DIM)
decoder.load_weights(DECODER_W)
decoder.trainable = False

def decode_vec(v):
    """Decode one latent vector into a 2-D float32 image clipped to [0, 1]."""
    batch = np.asarray(v, dtype=np.float32)[None, ...]
    decoded = decoder.predict(batch, batch_size=1, verbose=0)
    plane = decoded[0, ..., 0]  # first (only) sample, first channel
    return np.clip(plane, 0.0, 1.0).astype(np.float32)

# Decode each representative latent vector into an image.
img_mean   = decode_vec(z_mean)
img_median = decode_vec(z_median)
img_geomed = decode_vec(z_geomed)
print("[i] Decoded mean/median/geomedian images.")


[i] Decoded mean/median/geomedian images.


In [7]:
# ──────────────────────────────────────────────────────────────────────
# Encoder → embed decoded images
# ──────────────────────────────────────────────────────────────────────
def build_encoder(input_shape=(256,256,1), latent_dim=512):
    """Build the encoder network (architecture only, no weights loaded).

    Four stages of two ReLU convolutions followed by 2x2 average pooling,
    then a flatten and a dense projection to ``latent_dim`` values.
    """
    image_in = Input(input_shape, name='encoder_input')
    feat = image_in
    # Layers are created in the same order as before so that auto-generated
    # layer names and weight ordering stay unchanged.
    for n_filters in (16, 32, 64, 128):
        feat = Conv2D(n_filters, 3, padding='same')(feat)
        feat = ReLU()(feat)
        feat = Conv2D(n_filters, 3, padding='same')(feat)
        feat = ReLU()(feat)
        feat = AveragePooling2D(pool_size=2, strides=2, padding='valid')(feat)
    flattened = Flatten()(feat)
    latent = Dense(latent_dim, name='z')(flattened)
    return Model(image_in, latent, name='encoder')

# Instantiate the encoder, load its trained weights and freeze it.
encoder = build_encoder(INPUT_SHAPE, LATENT_DIM)
encoder.load_weights(ENCODER_W)
encoder.trainable = False

def encode_image(img01):
    """Encode one 2-D image array into a float32 latent vector.

    Batch and channel axes are added before the image is passed to the encoder.
    """
    batch = img01.astype(np.float32)[None, ..., None]
    latent = encoder.predict(batch, batch_size=1, verbose=0)[0]
    return latent.astype(np.float32)

# Re-embed the decoded representative images to obtain the query vectors.
q_mean   = encode_image(img_mean)
q_median = encode_image(img_median)
q_geomed = encode_image(img_geomed)
print("[i] Encoded average images → latent vectors.")


[i] Encoded average images → latent vectors.


In [8]:
# ──────────────────────────────────────────────────────────────────────
# Nearest neighbors (streaming Top-K) + filename resolution
# ──────────────────────────────────────────────────────────────────────
def _merge_topk(idxs_a, dists_a, idxs_b, dists_b, K):
    if idxs_a.size == 0:
        idxs, dists = idxs_b, dists_b
    else:
        idxs = np.concatenate([idxs_a, idxs_b])
        dists = np.concatenate([dists_a, dists_b])
    if idxs.size <= K:
        return idxs, dists
    keep = np.argpartition(dists, K-1)[:K]
    idxs, dists = idxs[keep], dists[keep]
    order = np.argsort(dists)
    return idxs[order], dists[order]

def topk_euclidean(h5path, q, K=1000, chunk_rows=200_000):
    """Find the K rows of dataset ``z`` closest to *q* in Euclidean distance.

    The dataset is scanned ``chunk_rows`` rows at a time; squared distances are
    used while scanning and the square root is taken once at the end.

    Returns:
        ``(indices, distances)`` sorted by ascending distance.
    """
    query = q.astype(np.float32)
    best_idx = np.empty(0, dtype=np.int64)
    best_sq = np.empty(0, dtype=np.float32)
    with h5py.File(h5path, 'r') as h5:
        Z = h5['z']
        n_rows, _ = Z.shape
        for lo in range(0, n_rows, chunk_rows):
            hi = min(lo + chunk_rows, n_rows)
            delta = Z[lo:hi].astype(np.float32) - query
            sq = np.einsum('ij,ij->i', delta, delta)  # row-wise squared norm
            n_take = min(K, sq.size)
            local = np.argpartition(sq, n_take-1)[:n_take]
            cand_idx = (lo + local).astype(np.int64)  # chunk-local -> global row index
            best_idx, best_sq = _merge_topk(best_idx, best_sq, cand_idx, sq[local], K)
    dist = np.sqrt(best_sq)
    rank = np.argsort(dist)
    return best_idx[rank], dist[rank]

def topk_cosine(h5path, q, K=1000, chunk_rows=200_000, eps=1e-8):
    """Find the K rows of dataset ``z`` with the smallest cosine distance to *q*.

    Cosine distance is ``1 - cosine similarity``; norms are floored at *eps*
    to avoid division by zero. The dataset is scanned ``chunk_rows`` rows at a
    time.

    Returns:
        ``(indices, cosine_distances)`` sorted by ascending distance.
    """
    query = q.astype(np.float32)
    query_norm = max(np.linalg.norm(query), eps)
    best_idx = np.empty(0, dtype=np.int64)
    best_cd = np.empty(0, dtype=np.float32)
    with h5py.File(h5path, 'r') as h5:
        Z = h5['z']
        n_rows, _ = Z.shape
        for lo in range(0, n_rows, chunk_rows):
            hi = min(lo + chunk_rows, n_rows)
            rows = Z[lo:hi].astype(np.float32)
            row_norms = np.linalg.norm(rows, axis=1)
            similarity = (rows @ query) / (np.maximum(row_norms, eps) * query_norm)
            cos_dist = 1.0 - similarity
            n_take = min(K, cos_dist.size)
            local = np.argpartition(cos_dist, n_take-1)[:n_take]
            cand_idx = (lo + local).astype(np.int64)  # chunk-local -> global row index
            best_idx, best_cd = _merge_topk(best_idx, best_cd, cand_idx, cos_dist[local], K)
    rank = np.argsort(best_cd)
    return best_idx[rank], best_cd[rank]

def get_filenames(h5path, indices, dataset='filenames'):
    """Fetch the file names stored at *indices*, in the order requested.

    HDF5 fancy indexing needs a strictly increasing index list, so the dataset
    is read once with the sorted unique indices and the result is expanded back
    to the caller's order. This also makes repeated indices safe.

    Args:
        h5path: Path of the HDF5 file.
        indices: Integer row indices (any order, duplicates allowed).
        dataset: Name of the dataset holding the file names.

    Returns:
        A list of ``str`` with one entry per requested index. ``bytes`` entries
        are decoded as UTF-8. An empty *indices* gives an empty list without
        opening the file.
    """
    idx_in = np.asarray(indices, dtype=np.int64).ravel()
    if idx_in.size == 0:
        return []
    # `uniq` is sorted and duplicate-free; `inverse` maps each request to its slot.
    uniq, inverse = np.unique(idx_in, return_inverse=True)
    with h5py.File(h5path, 'r') as h5:
        raw = h5[dataset][uniq]
    names = [f.decode('utf-8') if isinstance(f, (bytes, bytearray)) else str(f)
             for f in raw]
    return [names[i] for i in inverse]

# Query vectors: the re-embedded representative images. These were already
# computed as q_mean / q_median / q_geomed, so they are reused here instead of
# running the encoder a second time on the same images.
queries = {
    'Arithmetic mean': q_mean,
    'Component-wise median': q_median,
    'Geometric median': q_geomed,
}

# For each query, collect the top-K neighbors under both metrics.
# 'val' holds Euclidean distance or cosine distance (1 - similarity).
neighbors = {'Euclidean': {}, 'Cosine': {}}
for key, q in queries.items():
    idx_e, d_e = topk_euclidean(LATENTS_H5, q, K=K_NEIGHBORS, chunk_rows=CHUNK_ROWS)
    neighbors['Euclidean'][key] = {'idx': idx_e, 'val': d_e}
    idx_c, d_c = topk_cosine(LATENTS_H5, q, K=K_NEIGHBORS, chunk_rows=CHUNK_ROWS)
    neighbors['Cosine'][key] = {'idx': idx_c, 'val': d_c}
print("[i] Nearest neighbors computed.")

# Look up the stored file name of every neighbor and resolve it to a local
# path. Unresolvable entries are kept as None so ranks stay aligned with
# neighbors[...]['idx'] / ['val'].
filenames_cache = {'Euclidean': {}, 'Cosine': {}}
for metric in ['Euclidean', 'Cosine']:
    for key in queries:
        idxs = neighbors[metric][key]['idx']
        raw  = get_filenames(LATENTS_H5, idxs)
        resolved = [resolve_path(p) for p in raw]
        missing = sum(1 for p in resolved if p is None)
        print(f"[{metric} · {key}] resolved={len(resolved)} missing={missing}")
        filenames_cache[metric][key] = resolved
print("[i] Filenames cached & resolved.")


[i] Nearest neighbors computed.
[Euclidean · Arithmetic mean] resolved=1000 missing=0
[Euclidean · Component-wise median] resolved=1000 missing=0
[Euclidean · Geometric median] resolved=1000 missing=0
[Cosine · Arithmetic mean] resolved=1000 missing=0
[Cosine · Component-wise median] resolved=1000 missing=0
[Cosine · Geometric median] resolved=1000 missing=0
[i] Filenames cached & resolved.


In [9]:
# ──────────────────────────────────────────────────────────────────────
# Sanity check: show a few resolved paths
# ──────────────────────────────────────────────────────────────────────
# Print the first five resolved paths for the arithmetic-mean query under each
# metric, as a quick visual check of the path resolver.
for metric in ['Euclidean', 'Cosine']:
    for key in ['Arithmetic mean']:
        sample = filenames_cache[metric][key][:5]
        print(metric, key, 'sample resolved paths:')
        for p in sample:
            print('  ', p)


Euclidean Arithmetic mean sample resolved paths:
   D:\Confocal_imaging_nuclei_tif\MIST_Fused_Images\ROIs\Run97BL_top_right_ROI_307.tif
   D:\Confocal_imaging_nuclei_tif\MIST_Fused_Images\ROIs\Run11BR_bottom_right_ROI_300.tif
   D:\Confocal_imaging_nuclei_tif\MIST_Fused_Images\ROIs\Run91TR_bottom_right_ROI_640.tif
   D:\Confocal_imaging_nuclei_tif\MIST_Fused_Images\ROIs\Run89BL_bottom_left_ROI_586.tif
   D:\Confocal_imaging_nuclei_tif\MIST_Fused_Images\ROIs\Run91TR_bottom_left_ROI_657.tif
Cosine Arithmetic mean sample resolved paths:
   D:\Confocal_imaging_nuclei_tif\MIST_Fused_Images\ROIs\Run48BL_bottom_left_ROI_15.tif
   D:\Confocal_imaging_nuclei_tif\MIST_Fused_Images\ROIs\Run74BL_top_left_ROI_316.tif
   D:\Confocal_imaging_nuclei_tif\MIST_Fused_Images\ROIs\Run97BL_top_right_ROI_307.tif
   D:\Confocal_imaging_nuclei_tif\MIST_Fused_Images\ROIs\Run11BL_bottom_right_ROI_166.tif
   D:\Confocal_imaging_nuclei_tif\MIST_Fused_Images\ROIs\Run97BL_top_left_ROI_345.tif


In [10]:
# ──────────────────────────────────────────────────────────────────────
# Interactive 3-panel viewer (single live figure)
# ──────────────────────────────────────────────────────────────────────
# Start from a clean state: close existing figures and turn interactive mode off.
plt.close('all'); plt.ioff()

# Controls: metric selector, neighbor rank, display intensity window (0-255
# scale), two window presets, and the "skip missing files" switch.
metric_toggle = W.ToggleButtons(options=['Euclidean', 'Cosine'], value='Euclidean', description='Metric:')
rank_slider   = W.IntSlider(value=0, min=0, max=min(K_NEIGHBORS, 1000)-1, step=1, description='Rank:')
rng_slider    = W.IntRangeSlider(value=[0, 255], min=0, max=255, step=1, description='Range:')
btn_auto      = W.Button(description='Auto (2–98%)')
btn_reset     = W.Button(description='Reset 0–255')
skip_missing  = W.Checkbox(value=True, description='Skip missing')

# One figure with three side-by-side panels, axes hidden.
fig, axes = plt.subplots(1, 3, figsize=(10, 4))
for ax in axes:
    ax.axis('off')

# Panel order; each panel shows the neighbor of one representative.
keys = ['Arithmetic mean', 'Component-wise median', 'Geometric median']
# Create one image artist per panel up front (blank placeholder) so that later
# updates only swap the pixel data.
im_artists = []
for ax in axes:
    im = ax.imshow(np.zeros((256,256), dtype=np.float32), cmap='gray', vmin=0, vmax=1)
    im_artists.append(im)

# In-memory cache of loaded images, keyed by path.
image_cache = {}
def load_img_cached(path):
    """Load an image as a float32 array scaled by 1/255 and clipped to [0, 1].

    Returns ``None`` when *path* is empty or does not exist. For 3-D arrays
    only index 0 of the last axis is kept. Results are memoized in
    ``image_cache``.
    """
    if path and os.path.exists(path):
        cached = image_cache.get(path)
        if cached is not None:
            return cached
        pixels = imread(path).astype(np.float32)
        if pixels.ndim == 3:
            pixels = pixels[..., 0]
        pixels = np.clip(pixels / 255.0, 0.0, 1.0)
        image_cache[path] = pixels
        return pixels
    return None

def _first_available(metric, key, start_k):
    """Return the first rank >= ``start_k`` whose resolved file exists on disk.

    Only ranks below ``K_NEIGHBORS`` (and within the cached list) are checked;
    ``start_k`` is returned unchanged when none qualifies.
    """
    paths = filenames_cache[metric][key]
    stop = min(len(paths), K_NEIGHBORS)
    usable = (rank for rank in range(start_k, stop)
              if paths[rank] and os.path.exists(paths[rank]))
    return next(usable, start_k)

def _rank_ready_in_all_panels(metric, k):
    """Return True when every panel has an existing file at rank ``k``."""
    for key in keys:
        paths = filenames_cache[metric][key]
        if k >= len(paths) or not paths[k] or not os.path.exists(paths[k]):
            return False
    return True

def update_view(*_):
    """Redraw the three panels for the current metric, rank and intensity window.

    With "Skip missing" enabled, the rank is advanced to the first rank (within
    the slider range) at which all three panels have an existing file; if no
    such rank exists the current rank is kept. Panels whose file is missing
    show a black placeholder and a "MISSING" title.
    """
    metric = metric_toggle.value
    k = rank_slider.value
    if skip_missing.value:
        # Taking the minimum of per-panel "first available" ranks only
        # guarantees ONE panel has a file; require all of them instead.
        k = next((kk for kk in range(k, rank_slider.max + 1)
                  if _rank_ready_in_all_panels(metric, kk)), k)
        rank_slider.value = k

    # Slider values are on a 0-255 scale; images are displayed on [0, 1].
    lo, hi = rng_slider.value
    lo01, hi01 = lo/255.0, hi/255.0

    for ax, key, im_artist in zip(axes, keys, im_artists):
        path  = filenames_cache[metric][key][k]
        val_k = neighbors[metric][key]['val'][k]
        img01 = load_img_cached(path)
        if img01 is None:
            img01 = np.zeros((256, 256), dtype=np.float32)
            base = os.path.basename(path) if path else "(missing)"
            title_val = ("cos sim=—" if metric == 'Cosine' else "L2=—")
            ax.set_title(f"{key}\nrank {k} • {title_val}\nMISSING • {base}")
        else:
            if metric == 'Cosine':
                # Stored value is cosine distance; show it as similarity.
                sim = 1.0 - float(val_k)
                ax.set_title(f"{key}\nrank {k} • cos sim={sim:.4f}\n{os.path.basename(path)}")
            else:
                ax.set_title(f"{key}\nrank {k} • L2={float(val_k):.4f}\n{os.path.basename(path)}")
        im_artist.set_data(img01)
        im_artist.set_clim(lo01, hi01)
    fig.canvas.draw_idle()

def on_auto_click(_):
    """Set the intensity window to the 2nd-98th percentile of the current image.

    The reference image is the 'Arithmetic mean' panel's neighbor at the
    current rank. The window falls back to 0-255 when that image is
    unavailable or the percentiles do not form a valid range.
    """
    rank = rank_slider.value
    ref_path = filenames_cache[metric_toggle.value]['Arithmetic mean'][rank]
    pixels = load_img_cached(ref_path)
    window = (0, 255)
    if pixels is not None:
        p_lo, p_hi = (int(np.clip(round(np.percentile(pixels, pct) * 255.0), 0, 255))
                      for pct in (2, 98))
        if p_lo < p_hi:
            window = (p_lo, p_hi)
    rng_slider.value = window
    update_view()

def on_reset_click(_):
    """Restore the full 0-255 intensity window and redraw."""
    full_window = (0, 255)
    rng_slider.value = full_window
    update_view()

# Redraw whenever the metric, rank or intensity window changes.
metric_toggle.observe(update_view, names='value')
rank_slider.observe(update_view, names='value')
rng_slider.observe(update_view, names='value')
btn_auto.on_click(on_auto_click)
btn_reset.on_click(on_reset_click)

# Stack the controls above the figure canvas.
controls = W.VBox([
    metric_toggle,
    rank_slider,
    rng_slider,
    W.HBox([btn_auto, btn_reset, skip_missing]),
])

display(W.VBox([controls, fig.canvas]))
# Initial draw so the panels are populated before any interaction.
update_view()


VBox(children=(VBox(children=(ToggleButtons(description='Metric:', options=('Euclidean', 'Cosine'), value='Euc…