In [1]:
# Latent interpolation browser (no duplicate static figure)
# - Randomly selects N_PAIRS (default 100) image pairs from the top level of the ROIs folder
# - For each pair, decodes the latent interpolation at t = 0.0..1.0 (step 0.1) and shows the frames in a 2-row grid
# - Only the interactive widget is displayed; no extra static figure is rendered

%matplotlib widget

import html
import os, random, math
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as W
from IPython.display import display
from tifffile import imread

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, Conv2DTranspose, AveragePooling2D,
    Flatten, Dense, Reshape, ReLU
)

# ── Config ──────────────────────────────────────────────────────────────
# Folder holding the 256x256 ROI TIFFs (only its top level is scanned).
DATASET_DIR = r'D:\Confocal_imaging_nuclei_tif\MIST_Fused_Images\ROIs'

# Candidate weight folders, tried in order; the first one holding both files is used.
WEIGHT_DIRS = [
    r'D:\Results\09052025_AE1M_Conv2DTranspose',
    r'D:/Results/09052025_AE1M_Conv2DTranspose',
]
ENC_NAME    = 'encoder_weights.h5'
DEC_NAME    = 'decoder_weights.h5'
LATENT_DIM  = 512

N_PAIRS     = 100          # number of pairs to browse
RESERVOIR_K = 1200         # how many candidate files to sample from top-level dir
# 11 interpolation weights 0.0, 0.1, ..., 1.0 (rounded to kill float noise like 0.30000000000000004)
T_VALUES    = np.round(np.linspace(start=0.0, stop=1.0, num=11), decimals=2)

# ── Build models to match your AE ───────────────────────────────────────
def build_encoder(latent_dim=LATENT_DIM):
    """Build the convolutional encoder: (256, 256, 1) image -> latent vector.

    Four stages of [Conv2D+ReLU, Conv2D+ReLU, 2x2 AveragePooling] with
    16/32/64/128 filters shrink 256x256 down to 16x16x128; that tensor is
    flattened and projected by a Dense layer named 'z' to ``latent_dim`` units.

    Keep the layer sequence unchanged: the weights loaded later with
    ``load_weights`` were presumably saved from exactly this architecture.
    """
    image_in = Input((256, 256, 1), name='encoder_input')
    h = image_in
    for n_filters in (16, 32, 64, 128):
        # two same-padded 3x3 convolutions, then halve the spatial size
        for _ in range(2):
            h = Conv2D(n_filters, 3, padding='same')(h)
            h = ReLU()(h)
        h = AveragePooling2D(pool_size=2)(h)
    h = Flatten()(h)
    latent = Dense(latent_dim, name='z')(h)
    return Model(image_in, latent, name='encoder')

def build_decoder(latent_dim=LATENT_DIM):
    """Build the decoder: latent vector -> (256, 256, 1) image with values in (0, 1).

    A Dense layer expands the latent to 16*16*128 values, reshaped to 16x16x128.
    Four stages of [stride-2 Conv2DTranspose+ReLU, Conv2D+ReLU] with
    128/64/32/16 filters double the resolution each time (16 -> 256), and a
    final sigmoid Conv2D named 'decoder_output' emits one channel.

    Keep the layer sequence unchanged so the saved decoder weights still load.
    """
    base_hw, base_ch = 16, 128
    latent_in = Input((latent_dim,), name='z_sampling')
    h = Dense(base_hw * base_hw * base_ch)(latent_in)
    h = Reshape((base_hw, base_hw, base_ch))(h)
    for n_filters in (128, 64, 32, 16):
        h = Conv2DTranspose(n_filters, 3, strides=2, padding='same')(h)
        h = ReLU()(h)
        h = Conv2D(n_filters, 3, padding='same')(h)
        h = ReLU()(h)
    image_out = Conv2D(1, 3, padding='same', activation='sigmoid', name='decoder_output')(h)
    return Model(latent_in, image_out, name='decoder')

# Human-readable setup notes, shown in the info panel of the browser.
info_msg = []

# Tame GPU memory spikes. Best effort: this can fail (e.g. if the GPUs were
# already initialised in this kernel), so record the reason instead of hiding it.
try:
    for g in tf.config.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(g, True)
except Exception as ex:
    info_msg.append(f"⚠️ Could not enable GPU memory growth: {ex}")

encoder = build_encoder()
decoder = build_decoder()

# Use the first directory in WEIGHT_DIRS that contains BOTH weight files.
enc_path = dec_path = None
for weight_dir in WEIGHT_DIRS:
    enc_candidate = os.path.join(weight_dir, ENC_NAME)
    dec_candidate = os.path.join(weight_dir, DEC_NAME)
    if os.path.isfile(enc_candidate) and os.path.isfile(dec_candidate):
        enc_path, dec_path = enc_candidate, dec_candidate
        break

# Interpolation is only enabled once both models have loaded their weights.
latent_ready = False
if enc_path and dec_path:
    try:
        encoder.load_weights(enc_path)
        decoder.load_weights(dec_path)
        latent_ready = True
        info_msg.append(f"Loaded weights:\n• {enc_path}\n• {dec_path}")
    except Exception as ex:
        info_msg.append(f"⚠️ Could not load weights: {ex}")
else:
    info_msg.append(
        f"⚠️ Weights not found ({ENC_NAME}, {DEC_NAME}) in any of: {', '.join(WEIGHT_DIRS)}. "
        "Check WEIGHT_DIRS/ENC_NAME/DEC_NAME."
    )

# ── Helpers ─────────────────────────────────────────────────────────────
def normalize01(x):
    """Return ``x`` as a new float32 array scaled to [0, 1] when it looks 8-bit.

    If the array's maximum exceeds 1.0 it is assumed to hold 0..255 data and is
    divided by 255; otherwise it is returned unscaled (assumed already in [0, 1]).
    Empty arrays are returned as empty float32 arrays instead of raising.

    NOTE(review): integer data wider than 8 bits (e.g. 16-bit TIFFs) is also only
    divided by 255 and can therefore exceed 1.0 -- confirm this matches the
    normalisation used when the autoencoder was trained.
    """
    # astype always copies, so callers never alias their input array.
    x = np.asarray(x).astype(np.float32)
    # x.max() on a zero-size array raises ValueError, so check the size first.
    if x.size and x.max() > 1.0:
        return x / 255.0
    return x

def load_img_256_gray(path):
    """Read a TIFF as a single-channel 256x256 float32 image (via ``normalize01``).

    3-D arrays are treated as channels-last:
      * 1 channel   -> that channel;
      * 3 channels  -> Rec. 709 luma, 0.2126 R + 0.7152 G + 0.0722 B;
      * otherwise   -> the first channel only.

    Raises:
        ValueError: if the result is not 256x256. The message reports the shape
            as stored in the file, so e.g. a channels-first (3, 256, 256) stack is
            reported as such rather than as the sliced (3, 256).
    """
    img = imread(path)
    raw_shape = img.shape  # kept for the error message; img is sliced below
    if img.ndim == 3:
        n_ch = img.shape[-1]
        if n_ch == 1:
            img = img[..., 0]
        elif n_ch == 3:
            img = 0.2126 * img[..., 0] + 0.7152 * img[..., 1] + 0.0722 * img[..., 2]
        else:
            img = img[..., 0]
    if img.shape != (256, 256):
        raise ValueError(
            f"{os.path.basename(path)} has shape {raw_shape}, expected 256x256 "
            f"(optionally with a trailing channel axis)."
        )
    return normalize01(img)

def encode_img(img01):
    """Encode a single image with ``encoder`` and return its latent vector.

    The image gets a leading batch axis and a trailing channel axis before
    ``encoder.predict``; the single row of the prediction is returned.
    """
    batch = np.expand_dims(np.expand_dims(img01, 0), -1)   # (1, H, W, 1) for a 2-D image
    latents = encoder.predict(batch, batch_size=8, verbose=0)
    return latents[0]

def decode_many(z_batch):
    """Decode a batch of latent vectors into single-channel images clipped to [0, 1].

    Runs ``decoder.predict`` in batches of at most 16 and drops the trailing
    channel axis, giving one 2-D image per latent row.
    """
    chunk = min(len(z_batch), 16)
    decoded = decoder.predict(z_batch, batch_size=chunk, verbose=0)
    images = decoded[..., 0]
    return np.clip(images, 0.0, 1.0)

def reservoir_sample_tifs(folder, k=RESERVOIR_K, seed=None):
    """Uniformly sample up to ``k`` TIFF paths from the top level of ``folder``.

    Single-pass reservoir sampling (Algorithm R) over ``os.scandir``, so only
    the ``k``-slot reservoir is held in memory, never the full listing.
    Subdirectories are not descended into.

    Args:
        folder: directory to scan.
        k: maximum number of paths to return (``k <= 0`` yields ``[]``).
        seed: optional seed for the sampler, for reproducible samples;
            ``None`` (default) gives a fresh random sample on every call.

    Returns:
        Sorted list of at most ``k`` paths whose names end in .tif/.tiff
        (case-insensitive).

    Raises:
        OSError: propagated from ``os.scandir`` (e.g. the folder does not exist).
    """
    rng = random.Random(seed)
    reservoir, seen = [], 0
    with os.scandir(folder) as it:
        for entry in it:
            if not entry.is_file():
                continue
            if not entry.name.lower().endswith(('.tif', '.tiff')):
                continue
            seen += 1
            if len(reservoir) < k:
                reservoir.append(entry.path)
            else:
                # Keep the new file with probability k/seen, evicting a random slot.
                j = rng.randrange(seen)
                if j < k:
                    reservoir[j] = entry.path
    return sorted(reservoir)

def make_pairs(paths, n_pairs=N_PAIRS, seed=None):
    """Draw ``n_pairs`` random (A, B) path pairs from ``paths``.

    With at least ``2 * n_pairs`` paths the pairs are disjoint (each path is used
    at most once). Otherwise pairs are drawn independently, so paths may repeat
    across pairs, but each pair holds two distinct list entries whenever
    ``paths`` has at least two; only a single-path list yields (A, A) pairs.

    ``paths`` itself is never modified (a copy is shuffled/sampled).

    Args:
        paths: sequence of candidate file paths.
        n_pairs: number of pairs to return (``<= 0`` yields ``[]``).
        seed: optional RNG seed for reproducible pairs; ``None`` = random.

    Raises:
        IndexError: if ``paths`` is empty and ``n_pairs > 0``.
    """
    if n_pairs <= 0:
        return []
    rng = random.Random(seed)
    pool = list(paths)
    if len(pool) >= 2 * n_pairs:
        chosen = rng.sample(pool, 2 * n_pairs)
        return [(chosen[i], chosen[i + 1]) for i in range(0, 2 * n_pairs, 2)]

    pairs = []
    can_pick_two = len(pool) >= 2
    for _ in range(n_pairs):
        if can_pick_two:
            a, b = rng.sample(pool, 2)   # two distinct entries: avoids degenerate A->A pairs
        else:
            a = b = rng.choice(pool)     # only one candidate (IndexError if none)
        pairs.append((a, b))
    return pairs

# ── Widgets ─────────────────────────────────────────────────────────────
# info_msg can contain raw exception text, so escape it before embedding in HTML.
_info_text = f"Dataset: {DATASET_DIR}\n" + "\n".join(info_msg)
w_info   = W.HTML(f"<pre style='white-space:pre-wrap'>{html.escape(_info_text)}</pre>")
w_resK   = W.BoundedIntText(value=RESERVOIR_K, min=100, max=10000, step=100, description="Candidates:")
w_rescan = W.Button(description=f"Sample {N_PAIRS} pairs", button_style="info")  # label follows N_PAIRS
w_pair   = W.IntSlider(value=0, min=0, max=max(N_PAIRS - 1, 0), step=1, description="Pair #",
                       readout=True, continuous_update=False)
w_prev   = W.Button(description="Prev")
w_next   = W.Button(description="Next")
w_status = W.HTML("")

# ── Figure (no auto-display of a static snapshot) ───────────────────────
# One image axis per t in T_VALUES, laid out row-major on a 2-row grid with
# ceil(len(T_VALUES) / 2) columns (default 11 values -> 6 on row 1, 5 on row 2).
GRID_ROWS = 2
GRID_COLS = max(1, math.ceil(len(T_VALUES) / GRID_ROWS))

_was_interactive = plt.isinteractive()
plt.ioff()  # prevent Jupyter from auto-rendering a static figure during construction
try:
    fig = plt.figure(figsize=(11.5, 4.8))
    gs  = fig.add_gridspec(GRID_ROWS, GRID_COLS, wspace=0.05, hspace=0.20)
    axes, ims, titles = [], [], []
    for k, t in enumerate(T_VALUES):
        row, col = divmod(k, GRID_COLS)
        ax = fig.add_subplot(gs[row, col])
        im = ax.imshow(np.zeros((256, 256)), cmap="gray", vmin=0, vmax=1)
        ax.set_xticks([]); ax.set_yticks([])
        # NOTE: one decimal matches the default 0.1 step; widen if T_VALUES gets finer.
        axes.append(ax); ims.append(im); titles.append(ax.set_title(f"t={t:.1f}", fontsize=8))

    fig.canvas.toolbar_visible = True
    fig.canvas.header_visible  = False
finally:
    # Restore the prior interactive state even if construction raised.
    if _was_interactive:
        plt.ion()

# ── State & cache ───────────────────────────────────────────────────────
pairs = list()         # (path_a, path_b) tuples, rebuilt by each rescan
pair_cache = dict()    # pair index -> {'imgs': decoded frames, 'A', 'B': paths, 'A_z', 'B_z': latents}

def set_status(msg, ok=True):
    """Show ``msg`` in the status widget, dark green when ``ok`` else dark red.

    The message is HTML-escaped because it often carries exception text or file
    names; characters such as ``<`` or ``&`` would otherwise be parsed as markup.
    """
    color = "#1b5e20" if ok else "#b71c1c"
    w_status.value = f"<span style='color:{color}'>{html.escape(str(msg))}</span>"

def show_frames(frames11, pair_idx):
    """Push decoded frames into the image grid and retitle the figure for a pair.

    Args:
        frames11: sequence of 2-D frames, one per image axis in ``ims``
            (11 with the default T_VALUES).
        pair_idx: 0-based index into ``pairs``.

    Raises:
        ValueError: if the number of frames differs from the number of image axes
            (previously extra frames were silently dropped and too few raised
            IndexError mid-update).
    """
    if len(frames11) != len(ims):
        raise ValueError(f"Expected {len(ims)} frames, got {len(frames11)}.")
    for im, frame in zip(ims, frames11):
        im.set_data(frame)
    a, b = pairs[pair_idx]
    # Use the live pair count (as the status line does) rather than the N_PAIRS constant.
    fig.suptitle(
        f"Pair {pair_idx+1}/{len(pairs)}\nA: {os.path.basename(a)}    B: {os.path.basename(b)}",
        fontsize=10
    )
    fig.canvas.draw_idle()

def compute_pair(idx):
    """Return the decoded interpolation frames for pair ``idx``, memoised in ``pair_cache``.

    Both images are loaded and encoded; for every t in T_VALUES the latent
    (1 - t) * z_A + t * z_B is formed, and all of them are decoded in one call.
    The cache entry also keeps both paths and both latent vectors.
    """
    hit = pair_cache.get(idx)
    if hit is not None:
        return hit["imgs"]

    a_path, b_path = pairs[idx]
    img_a, img_b = (load_img_256_gray(p) for p in (a_path, b_path))
    z_a = encode_img(img_a)
    z_b = encode_img(img_b)
    mixes = [(1.0 - t) * z_a + t * z_b for t in T_VALUES]
    frames = decode_many(np.stack(mixes, axis=0))   # one frame per t
    pair_cache[idx] = dict(imgs=frames, A=a_path, B=b_path, A_z=z_a, B_z=z_b)
    return frames

def on_rescan_clicked(_):
    """Button handler: resample candidate TIFFs, rebuild ``pairs`` and show pair 1.

    Resets ``pair_cache`` and the pair slider. Failures (weights not loaded,
    unreadable dataset folder, no TIFFs, a bad first pair) are reported in the
    status line instead of raising, so the initial call in the display cell
    cannot crash that cell.
    """
    global pairs, pair_cache
    if not latent_ready:
        set_status("Weights not loaded; cannot interpolate in latent space.", ok=False)
        return
    set_status("Sampling candidates…", ok=True)
    try:
        candidates = reservoir_sample_tifs(DATASET_DIR, k=int(w_resK.value))
    except OSError as ex:
        # e.g. DATASET_DIR missing or not readable; keep the current pairs untouched.
        set_status(f"Cannot scan {DATASET_DIR}: {ex}", ok=False)
        return
    if len(candidates) == 0:
        set_status("No .tif/.tiff files found in the top-level of the folder.", ok=False)
        return
    pairs = make_pairs(candidates, n_pairs=N_PAIRS)
    pair_cache = {}
    w_pair.max = max(len(pairs) - 1, 0)
    w_pair.value = 0
    set_status(f"Built {len(pairs)} random pairs from {len(candidates)} candidates.")
    try:
        frames = compute_pair(0)
        show_frames(frames, 0)
    except Exception as ex:
        set_status(f"Failed pair 1: {ex}", ok=False)

def on_pair_change(change):
    """Slider observer: decode (or reuse from cache) and display the selected pair.

    Does nothing until pairs have been sampled; any failure is shown in red in
    the status line.
    """
    new_idx = int(change['new'])
    if pairs:
        try:
            frames = compute_pair(new_idx)
            show_frames(frames, new_idx)
            set_status(f"Showing pair {new_idx+1}/{len(pairs)}.")
        except Exception as ex:
            set_status(f"Failed pair {new_idx+1}: {ex}", ok=False)

def on_prev(_):
    """'Prev' button handler: move the pair slider back by one, stopping at 0."""
    if pairs and w_pair.value > 0:
        w_pair.value -= 1

def on_next(_):
    """'Next' button handler: move the pair slider forward by one, stopping at its max."""
    if pairs and w_pair.value < w_pair.max:
        w_pair.value += 1

# Connect the buttons and the slider to their handlers.
for _button, _handler in ((w_rescan, on_rescan_clicked), (w_prev, on_prev), (w_next, on_next)):
    _button.on_click(_handler)
w_pair.observe(on_pair_change, names='value')



In [2]:

#| label: fig:ae1m-interpolation
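#| caption: "Interactive latent-space interpolation (t = 0.0–1.0) between randomly paired ROI images, decoded by the AE1M Conv2DTranspose autoencoder."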

# ── Display: only the interactive widget (no duplicate static) ──────────
# The ipympl canvas is itself a widget, so it sits directly in the layout.
controls = W.HBox([w_resK, w_rescan, w_prev, w_next, w_pair])
container = W.VBox([w_info, controls, w_status, fig.canvas])
display(container)

# Fill the browser with a first random batch of pairs right away.
on_rescan_clicked(None)

# End the cell on a bare None so no other object's repr is echoed as output.
None


VBox(children=(HTML(value="<pre style='white-space:pre-wrap'>Dataset: D:\\Confocal_imaging_nuclei_tif\\MIST_Fu…