# Decoded latent interpolations between pairs from NucleusNet-10K.

In [1]:
#| label: nucleusnet10k-interpolation

%matplotlib widget

import os, random, math, logging
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as W
from IPython.display import display
from tifffile import imread as tiff_imread
from urllib.parse import urlparse
import fsspec
from pathlib import Path

# Quiet HF logs: turn off Hugging Face Hub / tqdm progress output via their
# environment switches and raise these libraries' loggers to ERROR so the
# widget output stays clean.
os.environ.update({
    "HF_HUB_DISABLE_PROGRESS_BARS": "1",
    "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
    "TQDM_DISABLE": "1",
})
for _logger_name in ("fsspec", "huggingface_hub", "urllib3", "datasets"):
    logging.getLogger(_logger_name).setLevel(logging.ERROR)

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, Conv2DTranspose, AveragePooling2D,
    Flatten, Dense, Reshape, ReLU
)

# ── Config ──────────────────────────────────────────────────────────────
# Hugging Face dataset repository (and branch) holding the images.
REPO_ID = "RussellBarkley/msa-em-figures"
BRANCH = "main"

# Directories searched, in order, for the autoencoder weight files.
# Relative entries resolve against the current working directory.
WEIGHT_DIRS = [Path("..", "data", "nucleus-ae")]
ENC_NAME = "encoder_weights.h5"
DEC_NAME = "decoder_weights.h5"
LATENT_DIM = 512  # latent width used by both builders; must match the saved weights

# Pairs per draw, and the interpolation weights shown left to right
# (t = 0 decodes image A's code, t = 1 image B's).
N_PAIRS = 50
T_VALUES = np.linspace(0.0, 1.0, 6).round(2)

# ── Build models to match your AE ───────────────────────────────────────
def build_encoder(latent_dim=LATENT_DIM):
    """Build the convolutional encoder: (256, 256, 1) image -> ``latent_dim`` vector.

    Four stages of two 3x3 Conv2D+ReLU layers, each followed by 2x2 average
    pooling, reduce the 256x256 input to 16x16x128. That map is flattened
    and projected by a Dense layer with no activation to the code ``z``.

    The layer sequence must stay identical to the architecture the weight
    file was saved from, since ``load_weights`` is applied to this model
    later in the cell.
    """
    inp = Input((256,256,1), name='encoder_input')
    x = inp
    for filters in [16, 32, 64, 128]:
        x = Conv2D(filters, 3, padding='same')(x); x = ReLU()(x)
        x = Conv2D(filters, 3, padding='same')(x); x = ReLU()(x)
        # Halve spatial size each stage: 256 -> 128 -> 64 -> 32 -> 16.
        x = AveragePooling2D(pool_size=2)(x)
    flat = Flatten()(x)
    z = Dense(latent_dim, name='z')(flat)
    return Model(inp, z, name='encoder')

def build_decoder(latent_dim=LATENT_DIM):
    """Build the decoder: ``latent_dim`` vector -> (256, 256, 1) image.

    A Dense layer with no activation expands ``z`` to 16*16*128 values,
    reshaped to a 16x16x128 map. Four stages then upsample it to 256x256:
    each stage is a stride-2 Conv2DTranspose+ReLU followed by a 3x3
    Conv2D+ReLU. A final 1-channel 3x3 conv with sigmoid keeps outputs
    in (0, 1).

    As with the encoder, this layer sequence must match the architecture
    the decoder weight file was saved from.
    """
    z_in = Input((latent_dim,), name='z_sampling')
    x = Dense(16 * 16 * 128)(z_in)
    x = Reshape((16, 16, 128))(x)
    for filters in [128, 64, 32, 16]:
        # stride-2 transpose conv doubles spatial size: 16 -> 32 -> 64 -> 128 -> 256.
        x = Conv2DTranspose(filters, 3, strides=2, padding='same')(x); x = ReLU()(x)
        x = Conv2D(filters, 3, padding='same')(x); x = ReLU()(x)
    out = Conv2D(1, 3, padding='same', activation='sigmoid', name='decoder_output')(x)
    return Model(z_in, out, name='decoder')

# Tame GPU memory spikes: ask TF to grow GPU memory on demand rather than
# reserving it all up front. Applied per device, so one device refusing the
# setting does not prevent it on the others.
for _gpu in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(_gpu, True)
    except (RuntimeError, ValueError):
        # Best effort. TF raises RuntimeError once the runtime is already
        # initialized (e.g. when this cell is re-run), and ValueError for an
        # invalid device; in both cases keep whatever setting is in place.
        pass

encoder = build_encoder()
decoder = build_decoder()

# Try to find and load weights. The first directory in WEIGHT_DIRS holding
# BOTH files wins. Nothing is printed (deliberately silent UI): on failure
# `latent_ready` stays False, so the widgets report that weights are missing,
# and the reason is kept in `weights_error` for debugging.
enc_path = dec_path = None
for _wdir in WEIGHT_DIRS:
    _enc = os.path.join(_wdir, ENC_NAME)
    _dec = os.path.join(_wdir, DEC_NAME)
    if os.path.isfile(_enc) and os.path.isfile(_dec):
        enc_path, dec_path = _enc, _dec
        break

latent_ready = False
weights_error = None  # str describing why weights are unavailable, else None
if enc_path and dec_path:
    try:
        encoder.load_weights(enc_path)
        decoder.load_weights(dec_path)
        latent_ready = True
    except Exception as _err:
        # Keep the UI quiet but do not lose the cause (e.g. architecture mismatch).
        weights_error = f"{type(_err).__name__}: {_err}"
else:
    weights_error = (
        f"{ENC_NAME} and {DEC_NAME} not found together in any of "
        f"{[str(d) for d in WEIGHT_DIRS]}"
    )

# ── HF file/dataset discovery ───────────────────────────────────────────
# When True, remote reads go through fsspec's "simplecache::" chain
# (applied by _maybe_cache), which keeps a local copy of each downloaded file.
USE_SIMPLECACHE: bool = True

# fsspec filesystem for the "hf" protocol; used below to glob the repo.
fs_hf = fsspec.filesystem("hf")

def _ensure_hf_uri(p: str) -> str:
    """Return ``p`` as an ``hf://`` URI.

    Strings already carrying the scheme pass through untouched; otherwise
    leading slashes are dropped and ``hf://`` is prepended.
    """
    if p.startswith("hf://"):
        return p
    return f"hf://{p.lstrip('/')}"

def _maybe_cache(uri: str) -> str:
    """Prefix ``uri`` with fsspec's ``simplecache::`` chain when USE_SIMPLECACHE is on."""
    if not USE_SIMPLECACHE:
        return uri
    return "simplecache::" + uri

def _basename_from_uri(uri: str) -> str:
    """Return the last path component of ``uri``.

    Scheme, host, query and fragment are ignored. The result is "" when
    the path ends with a slash.
    """
    path_part = urlparse(uri).path
    return os.path.basename(path_part)

def _list_all_tiff_paths():
    """Return sorted, de-duplicated ``hf://`` URIs of all .tif/.tiff files in the repo.

    Each extension is globbed recursively under ``REPO_ID@BRANCH``. A glob
    that raises is skipped silently (results collected before the error are
    kept), so an unreachable repository simply yields an empty list.
    """
    root = f"hf://datasets/{REPO_ID}@{BRANCH}"
    found = set()
    for ext in ("tif", "tiff"):
        try:
            for hit in fs_hf.glob(f"{root}/**/*.{ext}"):
                found.add(_ensure_hf_uri(hit))
        except Exception:
            continue
    return sorted(found)

# Prefer individual TIFF files; when none are listed, fall back to loading
# the repo with `datasets.load_dataset` (train split).
ALL_PATHS = _list_all_tiff_paths()
MODE = "files" if ALL_PATHS else "dataset"

DS_OBJ, DS_LEN = None, 0
if MODE == "dataset":
    from datasets import load_dataset

    DS_OBJ = load_dataset(REPO_ID, split="train", keep_in_memory=False)
    DS_LEN = len(DS_OBJ)

# ── Helpers ─────────────────────────────────────────────────────────────
def normalize01(x, src_dtype=None):
    """Return ``x`` as a float32 array scaled into [0, 1].

    Unsigned-integer data is divided by its dtype's maximum (255 for uint8,
    65535 for uint16). The scale therefore depends on the storage type, not
    on the image content: a dim uint8 image whose max is 1 is still divided
    by 255. Any other data (floats, e.g. a luminance mix of integer
    channels) keeps the heuristic of dividing by 255 when its max exceeds 1.

    Args:
        x: array-like image.
        src_dtype: dtype the pixels had before any conversion (e.g. before
            a float channel mix). Defaults to ``x``'s own dtype.
    """
    dtype = np.dtype(src_dtype) if src_dtype is not None else np.asarray(x).dtype
    out = np.asarray(x).astype(np.float32)
    if np.issubdtype(dtype, np.unsignedinteger):
        return out / np.float32(np.iinfo(dtype).max)
    # Guard the heuristic: .max() raises on an empty array.
    if out.size and out.max() > 1.0:
        return out / 255.0
    return out


def _to_gray_256(arr, name):
    """Collapse ``arr`` to a single 256x256 channel and normalize it to [0, 1].

    Shared by both loaders. ``name`` is used only in the error message.
    Three- and four-channel input (RGB/RGBA, alpha ignored) is mixed with
    the Rec. 709 luma weights; other 3-D input keeps its first channel.

    Raises:
        ValueError: if the result is not 256x256.
    """
    arr = np.asarray(arr)
    src_dtype = arr.dtype  # remembered so normalize01 scales by the storage type
    if arr.ndim == 3:
        if arr.shape[-1] in (3, 4):
            arr = 0.2126 * arr[..., 0] + 0.7152 * arr[..., 1] + 0.0722 * arr[..., 2]
        else:
            arr = arr[..., 0]
    if arr.shape != (256, 256):
        raise ValueError(f"{name} has shape {arr.shape}, expected 256x256.")
    return normalize01(arr, src_dtype=src_dtype)


def load_img_file_hf(uri: str):
    """Read the TIFF at ``uri`` (through the optional simplecache) as a 256x256 float32 image."""
    with fsspec.open(_maybe_cache(uri), "rb") as f:
        img = tiff_imread(f)
    return _to_gray_256(img, _basename_from_uri(uri))


def load_img_dataset_idx(idx: int):
    """Return the ``"image"`` field of dataset row ``idx`` as a 256x256 float32 image."""
    ex = DS_OBJ[int(idx)]
    return _to_gray_256(ex["image"], f"sample_{idx}")

def encode_img(img01):
    """Encode one 2-D image (values in [0, 1]) to its latent vector of length LATENT_DIM.

    Calls the model directly instead of ``encoder.predict``. ``predict``
    sets up a batched input pipeline on every call, which the Keras docs
    advise against for inputs that fit in one batch.
    """
    x = np.asarray(img01, dtype=np.float32)[None, ..., None]  # -> (1, H, W, 1)
    return np.asarray(encoder(x, training=False))[0]

def decode_many(z_batch):
    """Decode a batch of latent vectors into images clipped to [0, 1].

    Args:
        z_batch: array-like of shape (n, LATENT_DIM).

    Returns:
        Array of shape (n, H, W); the decoder's single output channel is dropped.
    """
    z = np.asarray(z_batch, dtype=np.float32)
    if len(z) <= 16:
        # Small batches (one slider position decodes len(T_VALUES) frames):
        # a direct call avoids predict()'s per-call overhead. This also
        # handles n == 0, where the old batch_size=min(n, 16) was 0.
        rec = np.asarray(decoder(z, training=False))
    else:
        rec = decoder.predict(z, batch_size=16, verbose=0)
    return np.clip(rec[..., 0], 0.0, 1.0)

# ── Pair construction from FULL dataset (no 1k pre-sample) ──────────────
rng_py = random.Random()


def make_pairs_from_full(n_pairs=N_PAIRS):
    """Draw ``n_pairs`` (A, B) pairs from the whole repository.

    Handles are ``hf://`` URIs in "files" mode and integer row indices in
    "dataset" mode. As many pairs as possible are built from distinct
    images (sampled without replacement). If the pool holds fewer than
    ``2 * n_pairs`` images, the remainder are random pairs, still with
    A != B so no interpolation is flat.

    Returns:
        A list of ``n_pairs`` tuples, or ``[]`` when fewer than two images
        exist or ``n_pairs <= 0``.
    """
    handles = ALL_PATHS if MODE == "files" else range(DS_LEN)
    n = len(handles)
    if n < 2 or n_pairs <= 0:
        return []

    # Disjoint pairs first: sample only the indices we need (O(m), not O(n)).
    m = min(2 * n_pairs, n)
    m -= m % 2
    chosen = rng_py.sample(range(n), m)
    result = [(handles[chosen[i]], handles[chosen[i + 1]]) for i in range(0, m, 2)]

    # Pool too small: top up with random pairs of two *different* images.
    while len(result) < n_pairs:
        a, b = rng_py.sample(range(n), 2)
        result.append((handles[a], handles[b]))
    return result

# ── Widgets ─────────────────────────────────────────────────────────────
# Controls: resample button, prev/next stepping buttons, the pair slider,
# and an HTML status line written by set_status.
w_rescan = W.Button(description="Resample pairs", button_style="info")
w_prev = W.Button(description="Prev")
w_next = W.Button(description="Next")
w_pair = W.IntSlider(
    value=0,
    min=0,
    max=N_PAIRS - 1,
    step=1,
    description="Pair #",
    readout=True,
    # Update only on release, so dragging does not trigger a decode per tick.
    continuous_update=False,
)
w_status = W.HTML("")

# ── Figure (single small row, ~4.4x2 in; titles above each image) ───────
# Interactive mode is off while the figure is built and restored afterwards.
# Presumably this is so the canvas appears only inside the widget container
# displayed further down.
_was_interactive = plt.isinteractive()
plt.ioff()

K = len(T_VALUES)
n_rows, n_cols = 1, K

FIG_W_IN, FIG_H_IN, DPI = 4.4, 2.0, 150
fig = plt.figure(figsize=(FIG_W_IN, FIG_H_IN), dpi=DPI, constrained_layout=False)
gs = fig.add_gridspec(n_rows, n_cols)

axes, ims = [], []
for col, t_val in enumerate(T_VALUES):
    ax = fig.add_subplot(gs[0, col])
    # Fixed vmin/vmax: frames are clipped to [0, 1] by decode_many, so all
    # panels share one absolute gray scale.
    im = ax.imshow(np.zeros((256, 256)), cmap="gray", vmin=0, vmax=1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)
    ax.set_title(f"t={t_val:.2f}", fontsize=6, pad=2)  # title sits above the image
    axes.append(ax)
    ims.append(im)

# Tight spacing; canvas size left at "auto" (not forced to 100%) so the
# figure is not zoomed or stretched.
fig.subplots_adjust(left=0.01, right=0.99, bottom=0.02, top=0.88, wspace=0.02)
canvas = fig.canvas
canvas.header_visible = False
canvas.toolbar_visible = False
canvas.layout.width = "auto"
canvas.layout.height = "auto"

if _was_interactive:
    plt.ion()


# ── State & cache ───────────────────────────────────────────────────────
# Current pair handles: (uriA, uriB) in "files" mode, (idxA, idxB) in
# "dataset" mode. Replaced wholesale by on_rescan_clicked.
pairs: list = []
# Memoized results keyed by pair index; each value is a dict with keys
# "imgs" (decoded frames), "A", "B" (handles) and "A_z", "B_z" (latents).
pair_cache: dict = {}

def set_status(msg, ok=True):
    """Show ``msg`` in the status widget: green when ``ok``, red otherwise.

    ``msg`` is HTML-escaped. Callers interpolate exception text, which may
    contain ``<``, ``>`` or ``&`` (e.g. ``<class 'int'>`` in a repr). Without
    escaping, the browser would parse that as markup and drop it.
    """
    import html

    color = "#1b5e20" if ok else "#b71c1c"
    w_status.value = f"<span style='color:{color}'>{html.escape(str(msg))}</span>"

def _handle_to_name(h):
    """Return a display filename for a pair handle.

    In "files" mode the handle is a URI and its basename is returned. In
    "dataset" mode it is a row index, formatted as ``sample_NNNNNN.tif``.
    NOTE(review): not called anywhere in this cell.
    """
    if MODE != "files":
        return f"sample_{int(h):06d}.tif"
    return _basename_from_uri(h)

def _load_handle(h):
    """Load the 256x256 image behind a pair handle with the loader matching MODE."""
    return load_img_file_hf(h) if MODE == "files" else load_img_dataset_idx(int(h))

def show_frames(framesK, pair_idx):
    """Put frame k into image artist k for every panel, then schedule a redraw.

    ``pair_idx`` is accepted but currently unused.
    """
    for slot, artist in enumerate(ims):
        artist.set_data(framesK[slot])
    fig.canvas.draw_idle()


def compute_pair(idx):
    """Return the decoded interpolation frames for pair ``idx`` (memoized).

    On a cache miss:
    - load both images of ``pairs[idx]`` and encode each;
    - blend the codes as ``(1 - t) * z_A + t * z_B`` for every ``t`` in ``T_VALUES``;
    - decode the blends in one batch.

    Frames, handles and latent codes are stored in ``pair_cache[idx]``.
    """
    hit = pair_cache.get(idx)
    if hit is not None:
        return hit["imgs"]

    handle_a, handle_b = pairs[idx]
    img_a, img_b = _load_handle(handle_a), _load_handle(handle_b)
    z_a, z_b = encode_img(img_a), encode_img(img_b)

    blended = [(1.0 - t) * z_a + t * z_b for t in T_VALUES]  # K latent vectors
    frames = decode_many(np.stack(blended, axis=0))            # (K, 256, 256)

    pair_cache[idx] = dict(imgs=frames, A=handle_a, B=handle_b, A_z=z_a, B_z=z_b)
    return frames

def on_rescan_clicked(_):
    """Draw a fresh set of pairs, reset the slider and render pair 0.

    Bound to the "Resample pairs" button and also called once at startup
    with ``None``; the argument is ignored. Problems are reported through
    the status line rather than raised.
    """
    global pairs, pair_cache

    # Guard clauses: explain why interpolation is impossible and bail out.
    if not latent_ready:
        set_status("Weights not loaded; cannot interpolate in latent space.", ok=False)
        return
    if MODE == "files" and len(ALL_PATHS) < 2:
        set_status("No TIFFs found in repo.", ok=False)
        return
    if MODE == "dataset" and DS_LEN < 2:
        set_status("Dataset is empty.", ok=False)
        return

    pairs = make_pairs_from_full(N_PAIRS)
    pair_cache = {}
    if not pairs:
        set_status("Could not build pairs from the repository.", ok=False)
        return

    w_pair.max = len(pairs) - 1
    # Changing the value triggers on_pair_change. If the slider is already
    # at 0 nothing fires, hence the explicit render below; the cache makes
    # a repeat render cheap.
    w_pair.value = 0
    try:
        frames = compute_pair(0)
        show_frames(frames, 0)
    except Exception as ex:
        set_status(f"Failed pair 1: {ex}", ok=False)
    else:
        w_status.value = ""  # clear any stale error from an earlier attempt

def on_pair_change(change):
    """Slider observer: compute (or fetch from cache) and show pair ``change['new']``.

    Failures are reported in the status line with a 1-based pair number
    instead of being raised inside the widget callback. A successful
    render clears any stale error message.
    """
    if not pairs:
        return
    idx = int(change['new'])
    try:
        frames = compute_pair(idx)
        show_frames(frames, idx)
    except Exception as ex:
        set_status(f"Failed pair {idx+1}: {ex}", ok=False)
    else:
        w_status.value = ""

def on_prev(_):
    """Handle the Prev button: step the pair slider back one, never below 0."""
    if not pairs:
        return
    w_pair.value = max(0, w_pair.value - 1)

def on_next(_):
    """Handle the Next button: step the pair slider forward one, never past its max."""
    if not pairs:
        return
    w_pair.value = min(w_pair.max, w_pair.value + 1)

# Wire widget events to their handlers.
for _btn, _cb in ((w_rescan, on_rescan_clicked), (w_prev, on_prev), (w_next, on_next)):
    _btn.on_click(_cb)
w_pair.observe(on_pair_change, names="value")

# ── Display ──────────
# w_status must be part of the layout; otherwise every message written by
# set_status (missing weights, failed pairs) is invisible.
container = W.VBox([
    W.HBox([w_rescan, w_prev, w_next, w_pair]),
    w_status,
    fig.canvas,
], layout=W.Layout(width="100%"))
display(container)

# Kick off an initial build
on_rescan_clicked(None)

None


VBox(children=(HBox(children=(Button(button_style='info', description='Resample pairs', style=ButtonStyle()), …