# EM Compression — Demo Overview

This overview covers three EM (electron microscopy) compression and reconstruction modes. Each mode strikes a different balance between compression ratio, segmentation reliability, and fine-structure fidelity.

---

## Quick Summary

| Mode | Default Compression | Best For | Key Behavior | Known Artifacts | Extra Training |
|---|---:|---|---|---|---|
| 1) **Top-Only Training / Reconstruction** | **1024×** | Whole-cell segmentation without a trained prior | Membrane prediction remains stable even at extreme compression | Occasional **2D membrane breaks**; requires **3D repair** | None |
| 2) **Two-Level VQ-VAE** | **204×** | Preserving overall cellular **texture** with moderate compression | Good global appearance retention | Possible **vesicle** shape deformation | Optional **cross-attention prior** to correct deformation |
| 3) **Two-Level VQ-VAE + Transformer Prior** | **1024×** | High compression **with more detail** restored | **Top-VQ tokens** condition the prediction of **Bottom-VQ tokens**; bottom level fills details | Higher training cost | Requires **Transformer prior**; pairs well with Mode-2 **teacher net** |
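
The ratios in the table follow from a simple bit count: each stored token costs `ceil(log2(codebook size))` bits, and the raw tile is taken to be 8-bit grayscale. The sketch below is a minimal illustration, assuming 1024×1024 tiles, a 32×32 top grid, a 64×64 bottom grid, and 256-entry codebooks at both levels. The bottom codebook size is an assumption, and `compression_ratio` is a hypothetical helper, not part of the pipeline.

```python
import math

def compression_ratio(tile, grids, codebook_sizes, bits_per_pixel=8):
    """Raw bits of one square tile divided by the bits needed to store its VQ tokens."""
    latent_bits = sum(
        g * g * math.ceil(math.log2(k))   # tokens per level x bits per token
        for g, k in zip(grids, codebook_sizes)
    )
    return tile * tile * bits_per_pixel / latent_bits

# Modes 1 and 3: only the 32x32 top tokens are stored.
print(compression_ratio(1024, [32], [256]))           # 1024.0
# Mode 2: top (32x32) and bottom (64x64) tokens are both stored.
print(compression_ratio(1024, [32, 64], [256, 256]))  # 204.8
```

Under these assumptions the two-level figure comes out at 204.8×, which matches the ~204× quoted for Mode 2.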

---

## 1) Top-Only Training / Reconstruction (1024×)

- **Use when:** You do **not** have a trained prior network and want **extreme compression** to segment entire cells while ignoring most intracellular detail.
- **Behavior:** Cell membrane predictions are generally stable at **1024×**.
- **Limitation:** **Membrane breaks in 2D slices** can occur.
- **Mitigation:** Perform **3D reconstruction / consistency repair** (e.g., connect components across slices) to restore continuity.
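
One simple form of cross-slice repair can be sketched as follows: a pixel missing from the membrane mask in one slice is filled when the slices on both sides mark it as membrane. This is only an illustrative sketch, assuming a binary `[Z, Y, X]` mask and well-aligned slices. `repair_membrane_gaps_z` and `max_gap` are hypothetical names, and the actual repair step used with this demo may differ.

```python
import numpy as np

def repair_membrane_gaps_z(mask, max_gap=1):
    """Fill membrane gaps of up to `max_gap` slices along z in a binary [Z,Y,X] mask."""
    mask = np.asarray(mask, dtype=bool)
    repaired = mask.copy()
    depth = mask.shape[0]
    for gap in range(1, max_gap + 1):
        for start in range(1, depth - gap):
            # Pixels that are membrane both just before and just after the gap
            bridge = mask[start - 1] & mask[start + gap]
            repaired[start:start + gap] |= bridge
    return repaired

# Toy volume: a membrane line present in slices 0 and 2 but broken in slice 1.
vol = np.zeros((3, 4, 4), dtype=bool)
vol[0, 1, :] = True
vol[2, 1, :] = True
fixed = repair_membrane_gaps_z(vol)
print(fixed[1, 1, :].all())  # True
```

Bridging only between original (unrepaired) slices keeps the fill conservative: a gap is never closed on the strength of pixels that were themselves filled in.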

---

## 2) Two-Level VQ-VAE (204×)

- **Use when:** You need **overall texture preservation** with substantial compression.
- **Behavior:** Maintains global cellular appearance at ~**204×**.
- **Limitation:** May introduce **vesicle** deformation artifacts.
- **Mitigation:** Train a **targeted prior** with **cross-attention** to correct local shape issues **without changing the compression ratio**.
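
The cross-attention operation at the core of such a prior can be sketched in a few lines of NumPy. This is a bare single-head version without learned query/key/value projections, assuming the locally refined features act as queries and the coarse context features act as keys and values. Which features the actual prior attends over is not specified here, and `cross_attention` is a hypothetical helper.

```python
import numpy as np

def cross_attention(queries, context):
    """Single-head scaled dot-product attention: queries [Nq,D] attend to context [Nc,D]."""
    d = queries.shape[-1]
    scores = queries @ context.T / np.sqrt(d)        # [Nq,Nc] similarity scores
    scores -= scores.max(axis=-1, keepdims=True)     # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax over context positions
    return weights @ context, weights                # attended features [Nq,D], weights [Nq,Nc]

rng = np.random.default_rng(0)
q = rng.normal(size=(16, 8))   # e.g. 16 local positions, 8-dim features
c = rng.normal(size=(4, 8))    # e.g. 4 context positions
out, w = cross_attention(q, c)
print(out.shape, w.shape)      # (16, 8) (4 context weights per query: (16, 4))
```

Because the attended features only refine what the decoder sees, the stored token count, and therefore the compression ratio, is unchanged.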

---

## 3) Two-Level VQ-VAE + Transformer Prior (1024×)

- **Use when:** You want **1024× compression** and **more restored detail** than Mode 1 provides.
- **Idea:** Train a **Transformer prior** to **predict Bottom-VQ tokens conditioned on Top-VQ tokens**. Treat **Top-VQ** as the coarse **backbone**, and use **Bottom-VQ** to **fill in details**.
- **Training:** Requires an additional **Transformer**. For best results, **pair with the teacher/prior network** from Mode 2.
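
One way to prepare training examples for such a prior is sketched below: the flattened Top-VQ indices serve as the conditioning context, and the flattened Bottom-VQ indices are the targets, shifted right by one position for teacher forcing. The raster-scan ordering, the `bos_id` start token, and the helper name `build_prior_example` are assumptions for illustration, not details confirmed by the demo.

```python
import numpy as np

def build_prior_example(top_idx, bottom_idx, bos_id):
    """Turn one pair of token grids into (context, decoder_input, targets) sequences."""
    context = np.asarray(top_idx).reshape(-1)        # coarse tokens the prior conditions on
    targets = np.asarray(bottom_idx).reshape(-1)     # fine tokens the prior must predict
    # Teacher forcing: the decoder sees the previous target token at each step
    decoder_in = np.concatenate([[bos_id], targets[:-1]])
    return context, decoder_in, targets

# Toy grids; the real grids would be 32x32 (top) and 64x64 (bottom).
top = np.arange(4).reshape(2, 2)
bottom = np.arange(16).reshape(4, 4)
ctx, dec_in, tgt = build_prior_example(top, bottom, bos_id=256)
print(len(ctx), len(dec_in), len(tgt))  # 4 16 16
```

At inference time only the top tokens are stored; the prior generates the bottom tokens from them, which is how the 1024× ratio is kept.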

---

## Notes & Recommendations

- **3D repair is essential** for Mode 1 to address 2D membrane breaks.
- If **vesicle deformation** is unacceptable at 204× (Mode 2), add a **cross-attention prior** tailored to vesicle morphology.
- For **max compression with detail** (Mode 3), the **Transformer prior** uses **Top-VQ** as context to recover **Bottom-VQ**; combining it with the **teacher net** from Mode 2 typically improves perceptual quality and segmentation consistency.

---

## Evaluation & Repro Tips

- Fix random seeds for comparability.
- Report both **perceptual** metrics (e.g., SSIM/LPIPS) and **task** metrics (e.g., membrane F1, segmentation accuracy).
- Always evaluate **in 3D** when the downstream task is 3D segmentation.
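
As a concrete example of a task metric, a pixel-wise (or voxel-wise) membrane F1 can be computed directly from two binary masks. This is a minimal sketch, assuming both masks are already thresholded and aligned. `membrane_f1` is a hypothetical helper, and boundary-tolerant variants may be more appropriate for thin membranes.

```python
import numpy as np

def membrane_f1(pred, truth):
    """F1 score between two binary masks of identical shape (2D slices or 3D volumes)."""
    pred = np.asarray(pred, dtype=bool)
    truth = np.asarray(truth, dtype=bool)
    tp = np.count_nonzero(pred & truth)     # membrane in both
    fp = np.count_nonzero(pred & ~truth)    # predicted membrane, not in ground truth
    fn = np.count_nonzero(~pred & truth)    # missed membrane
    denom = 2 * tp + fp + fn
    return 1.0 if denom == 0 else 2.0 * tp / denom  # two empty masks count as a perfect match

pred = np.array([[1, 1], [0, 0]])
truth = np.array([[1, 0], [1, 0]])
print(membrane_f1(pred, truth))  # 0.5
```

The same function works on a full `[Z, Y, X]` volume, which fits the recommendation to evaluate in 3D.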

---

## Glossary (brief)

- **Top-VQ / Bottom-VQ:** Codebooks from the two-level VQ-VAE; **Top-VQ** captures coarse structure (context), **Bottom-VQ** refines local detail.
- **Prior (Transformer / Cross-attention):** A learned model used to predict or refine discrete VQ token sequences to improve reconstruction quality at fixed bitrate.


## 1) Top-Only Training / Reconstruction (1024×)

In [None]:
# ==============================================================
# VQ-VAE Top-Only (32x32 latent, EMA codebook) — End-to-End Pipeline
# - Reads dataset zip from Google Drive (compression_em.zip)
# - Unzips locally, sorts by filename, first 400 for training, last 100 for prediction
# - Handles large images (~2k x 4k) by tiling into 1024x1024 patches
# - Trains, exports bundle, reconstructs full-res predictions by stitching tiles
# ==============================================================

import os, time, json, math, zipfile, shutil, glob, re
import numpy as np
import tensorflow as tf
from PIL import Image
from tqdm import tqdm

# -----------------------------
# (0) Global Hyperparams
# -----------------------------
IMAGE_SIZE       = int(globals().get('IMAGE_SIZE', 1024))   # tile size for training/inference
TILE             = IMAGE_SIZE
TOP_GRID         = int(globals().get('TOP_GRID', 32))       # latent grid (32x32)

NUM_HIDDENS_TOP    = int(globals().get('NUM_HIDDENS_TOP', 128))
EMBEDDING_DIM_TOP  = int(globals().get('EMBEDDING_DIM_TOP', 96))
NUM_EMBEDDINGS_TOP = int(globals().get('NUM_EMBEDDINGS_TOP', 256))
COMMITMENT_COST_TOP= float(globals().get('COMMITMENT_COST_TOP', 1.0))

BATCH_SIZE       = int(globals().get('BATCH_SIZE', 2))
EPOCHS_WARMUP    = int(globals().get('EPOCHS_WARMUP', 100))
LR_WARMUP        = float(globals().get('LR_WARMUP', 2e-4))
WEIGHT_DECAY     = float(globals().get('WEIGHT_DECAY', 1e-4))

# Drive dataset zip path (adjust if different)
DRIVE_ZIP = "/content/drive/MyDrive/compression_em.zip"
LOCAL_ZIP = "/content/compression_em.zip"
EXTRACT_DIR = "/content/data/compression_em"

# -----------------------------
# (A) Residual helpers
# -----------------------------

def residual_block(x, filters):
    skip = x
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Conv2D(filters//2, 3, padding='same')(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Conv2D(filters, 1, padding='same')(x)
    return tf.keras.layers.Add()([skip, x])


def residual_stack(x, filters, num_blocks=2):
    for _ in range(num_blocks):
        x = residual_block(x, filters)
    return tf.keras.layers.ReLU()(x)

# -----------------------------
# (B) EMA VectorQuantizer
# -----------------------------
class VectorQuantizerEMA(tf.keras.layers.Layer):
    def __init__(self, num_embeddings, embedding_dim,
                 commitment_cost=0.25, decay=0.99, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.num_embeddings = int(num_embeddings)
        self.embedding_dim = int(embedding_dim)
        self.commitment_cost = float(commitment_cost)
        self.decay = float(decay)
        self.epsilon = float(epsilon)
        self._perplexity = tf.keras.metrics.Mean(name=f'{self.name}_perplexity')

    @property
    def metrics(self):
        return [self._perplexity]

    def build(self, input_shape):
        self.embeddings = self.add_weight(
            name='embeddings',
            shape=(self.embedding_dim, self.num_embeddings),
            initializer=tf.keras.initializers.RandomUniform(-1.0/self.num_embeddings, 1.0/self.num_embeddings),
            trainable=False
        )
        self.ema_cluster_size = self.add_weight(
            name='ema_cluster_size', shape=(self.num_embeddings,), initializer='zeros', trainable=False)
        self.ema_dw = self.add_weight(
            name='ema_dw', shape=(self.embedding_dim, self.num_embeddings), initializer='zeros', trainable=False)

    def call(self, inputs, training=None):
        shp  = tf.shape(inputs)                              # [B,H,W,C]
        flat = tf.reshape(inputs, [-1, self.embedding_dim])  # [N,C]

        dists = (tf.reduce_sum(flat**2, axis=1, keepdims=True)
                 - 2.0 * tf.matmul(flat, self.embeddings)
                 + tf.reduce_sum(self.embeddings**2, axis=0, keepdims=True))  # [N,K]
        idx  = tf.argmax(-dists, axis=1)                     # [N]
        one  = tf.one_hot(idx, self.num_embeddings, dtype=flat.dtype)  # [N,K]
        quant = tf.matmul(one, tf.transpose(self.embeddings))          # [N,C]
        quant = tf.reshape(quant, shp)                                   # [B,H,W,C]

        def _ema_update():
            cluster_size = tf.reduce_sum(one, axis=0)
            dw = tf.matmul(tf.transpose(flat), one)
            ema_cs = self.ema_cluster_size * self.decay + cluster_size * (1.0 - self.decay)
            ema_dw = self.ema_dw * self.decay + dw * (1.0 - self.decay)
            n = tf.reduce_sum(ema_cs)
            smoothed_cs = (ema_cs + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n
            new_embed = ema_dw / tf.expand_dims(smoothed_cs, 0)
            self.ema_cluster_size.assign(ema_cs)
            self.ema_dw.assign(ema_dw)
            self.embeddings.assign(new_embed)
            return 0.0

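        if training is None:
            # No explicit flag: default to inference so the codebook is not updated by accident.
            # (tf.keras.backend.learning_phase() is a legacy API and may be missing in newer Keras releases.)
            training = False
        tf.cond(tf.cast(training, tf.bool), _ema_update, lambda: 0.0)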

        # commitment loss
        e_loss = tf.reduce_mean((tf.stop_gradient(quant) - inputs)**2)
        self.add_loss(self.commitment_cost * e_loss)

        # perplexity
        avg_p = tf.reduce_mean(one, axis=0)
        perp  = tf.exp(-tf.reduce_sum(avg_p * tf.math.log(avg_p + 1e-10)))
        self._perplexity.update_state(perp)

        # straight-through
        quant = inputs + tf.stop_gradient(quant - inputs)
        return quant

# -----------------------------
# (C) Encoder/Decoder builders
# -----------------------------

def build_encoder_to_top(img_size, top_grid, num_hiddens):
    """Downsample 1024->32 via 5x stride-2."""
    assert (img_size % top_grid) == 0
    n_downs = int(math.log2(img_size // top_grid))  # 1024->32 = 5
    inputs = tf.keras.Input(shape=(img_size, img_size, 1))
    x = inputs
    chans = [64, 128, 128, 128, num_hiddens][:n_downs]
    for c in chans:
        x = tf.keras.layers.Conv2D(c, 4, strides=2, padding='same')(x)
        x = tf.keras.layers.ReLU()(x)
    x = residual_stack(x, num_hiddens, num_blocks=2)
    return tf.keras.Model(inputs, x, name=f'encoder_to_top_{top_grid}')


def build_decoder_top_to_image(img_size, top_grid, in_ch):
    """Upsample 32->1024 via 5x deconv."""
    assert (img_size % top_grid) == 0
    n_ups = int(math.log2(img_size // top_grid))    # 32->1024 = 5
    inputs = tf.keras.Input(shape=(top_grid, top_grid, in_ch))
    x = inputs
    ups_ch = [256, 128, 128, 64, 32][:n_ups]
    for c in ups_ch:
        x = tf.keras.layers.Conv2DTranspose(c, 4, strides=2, padding='same')(x)
        x = tf.keras.layers.ReLU()(x)
    out = tf.keras.layers.Conv2D(1, 3, padding='same')(x)
    return tf.keras.Model(inputs, out, name=f'decoder_top_{top_grid}_to_{img_size}')

# -----------------------------
# (D) Assemble Model
# -----------------------------
enc_top  = build_encoder_to_top(IMAGE_SIZE, TOP_GRID, NUM_HIDDENS_TOP)
pre_vq_top   = tf.keras.layers.Conv2D(EMBEDDING_DIM_TOP, 1, name='pre_vq_top')
vq_top       = VectorQuantizerEMA(NUM_EMBEDDINGS_TOP, EMBEDDING_DIM_TOP,
                                  commitment_cost=COMMITMENT_COST_TOP, decay=0.99, epsilon=1e-5, name='vq_top')
post_vq_top  = tf.keras.layers.Conv2D(NUM_HIDDENS_TOP, 1, name='post_vq_top')
dec_top = build_decoder_top_to_image(IMAGE_SIZE, TOP_GRID, in_ch=NUM_HIDDENS_TOP)

class VQTopOnlyEMA(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.enc_top   = enc_top
        self.pre_vq_top= pre_vq_top
        self.vq_top    = vq_top
        self.post_vq_top = post_vq_top
        self.dec_top   = dec_top
        self.norm_t    = tf.keras.layers.LayerNormalization(axis=-1)
        self.drop_t    = tf.keras.layers.SpatialDropout2D(0.1)

    def call(self, x, training=None):
        ht  = self.enc_top(x)                   # [B,32,32,Ht]
        zt  = self.pre_vq_top(ht)               # [B,32,32,Ct]
        ztq = self.vq_top(zt, training=training)
        ztq = self.post_vq_top(ztq)             # [B,32,32,Ht]
        ztq = self.norm_t(ztq)
        ztq = self.drop_t(ztq, training=training)
        y   = self.dec_top(ztq)                 # [B,1024,1024,1]
        return y

model = VQTopOnlyEMA()

# -----------------------------
# (E) Compile
# -----------------------------

def ssim_metric(y_true, y_pred):
    return tf.reduce_mean(tf.image.ssim(y_true + 0.5, y_pred + 0.5, max_val=1.0))

def psnr_metric(y_true, y_pred):
    return tf.reduce_mean(tf.image.psnr(y_true + 0.5, y_pred + 0.5, max_val=1.0))

def compile_warmup(m, lr=LR_WARMUP, wd=WEIGHT_DECAY):
    try:
        opt = tf.keras.optimizers.AdamW(learning_rate=lr, weight_decay=wd, beta_1=0.9, beta_2=0.95)
    except Exception:
        opt = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.9, beta_2=0.95)
    m.compile(optimizer=opt, loss='mae', metrics=[psnr_metric, ssim_metric])

compile_warmup(model)

# -----------------------------
# (F) Data IO: unzip -> split -> tf.data with tiling
# -----------------------------

def _natural_key(s):
    return [int(t) if t.isdigit() else t.lower() for t in re.split(r'(\d+)', os.path.basename(s))]

# Copy zip from Drive -> local to speed up reads
os.makedirs(os.path.dirname(LOCAL_ZIP), exist_ok=True)
shutil.copy(DRIVE_ZIP, LOCAL_ZIP)

# Unzip locally
os.makedirs(EXTRACT_DIR, exist_ok=True)
with zipfile.ZipFile(LOCAL_ZIP, 'r') as zf:
    zf.extractall(EXTRACT_DIR)

# Collect image files
IMG_EXTS = (".png", ".jpg", ".jpeg", ".tif", ".tiff")
all_files = [p for p in glob.glob(os.path.join(EXTRACT_DIR, "**", "*"), recursive=True)
             if os.path.splitext(p)[1].lower() in IMG_EXTS]  # case-insensitive extension match
all_files = sorted(all_files, key=_natural_key)
assert len(all_files) >= 500, f"Found {len(all_files)} files, expected ~500."

train_files = all_files[:400]
pred_files  = all_files[400:500]  # files 400-499 (the last 100 when the set has exactly 500)
print(f"Found {len(all_files)} images -> train {len(train_files)}, predict {len(pred_files)}")

AUTO = tf.data.AUTOTUNE

# Image decode (PNG/JPG via TF, TIFF via PIL fallback)

def _decode_any_image(path):
    path = tf.convert_to_tensor(path)
    ext = tf.strings.lower(tf.strings.regex_replace(path, r'^.*\.', '.'))
    img_bin = tf.io.read_file(path)

    def decode_png_jpg():
        img = tf.io.decode_image(img_bin, channels=0, expand_animations=False)
        return img

    def decode_tiff_py(p):
        import numpy as np
        from PIL import Image as _PILImage
        arr = np.array(_PILImage.open(p.decode('utf-8')))
        if arr.ndim == 2:
            arr = arr[..., None]
        return arr

    is_tiff = tf.reduce_any([tf.equal(ext, s) for s in [".tif", ".tiff"]])
    img = tf.cond(
        is_tiff,
        lambda: tf.numpy_function(decode_tiff_py, [path], Tout=tf.uint8),
        lambda: decode_png_jpg()
    )
    img.set_shape([None, None, None])  # H W C
    c = tf.shape(img)[-1]
    img = tf.cond(tf.equal(c, 1), lambda: img, lambda: tf.image.rgb_to_grayscale(img))
    img = tf.image.convert_image_dtype(img, tf.float32)  # [0,1]
    img = img - 0.5                                  # [-0.5,0.5]
    return img  # [H,W,1]


def _pad_to_multiple_tf(img, mult=TILE):
    h = tf.shape(img)[0]; w = tf.shape(img)[1]
    pad_h = (mult - (h % mult)) % mult
    pad_w = (mult - (w % mult)) % mult
    img = tf.pad(img, [[0, pad_h], [0, pad_w], [0, 0]], mode='SYMMETRIC')
    return img


def _tile_1024_tf(img):
    img = _pad_to_multiple_tf(img, TILE)
    ks = [1, TILE, TILE, 1]
    st = [1, TILE, TILE, 1]
    patches = tf.image.extract_patches(
        images=tf.expand_dims(img, 0),
        sizes=ks, strides=st, rates=[1,1,1,1], padding='VALID')
    n_h = tf.shape(patches)[1]
    n_w = tf.shape(patches)[2]
    patches = tf.reshape(patches, [n_h * n_w, TILE, TILE, 1])
    return patches


def files_to_tiles_ds(files, batch, shuffle=False):
    ds = tf.data.Dataset.from_tensor_slices(files)
    if shuffle:
        ds = ds.shuffle(buffer_size=len(files), reshuffle_each_iteration=True)

    def map_decode(path):
        img = _decode_any_image(path)
        tiles = _tile_1024_tf(img)
        tile_ds = tf.data.Dataset.from_tensor_slices(tiles)
        return tile_ds

    ds = ds.interleave(map_decode, cycle_length=8, num_parallel_calls=AUTO, deterministic=False)
    if shuffle:
        ds = ds.shuffle(2048)
    ds = ds.map(lambda x: (x, x), num_parallel_calls=AUTO)  # autoencoder target
    ds = ds.batch(batch).prefetch(AUTO)
    return ds

train_ds = files_to_tiles_ds(train_files, batch=BATCH_SIZE, shuffle=True)
val_ds   = files_to_tiles_ds(pred_files,  batch=BATCH_SIZE, shuffle=False)

for xb, yb in val_ds.take(1):
    print("Sample batch:", xb.shape, yb.shape)

# -----------------------------
# (G) Callbacks & Export helpers
# -----------------------------

def make_top_callbacks(out_dir,
                       monitor_primary="val_ssim_metric",
                       monitor_secondary="val_loss"):
    os.makedirs(out_dir, exist_ok=True)
    return [
        tf.keras.callbacks.ModelCheckpoint(
            os.path.join(out_dir, "top_best_by_ssim.weights.h5"),
            save_weights_only=True, save_best_only=True,
            monitor=monitor_primary, mode="max", verbose=1),
        tf.keras.callbacks.ModelCheckpoint(
            os.path.join(out_dir, "top_best_by_vloss.weights.h5"),
            save_weights_only=True, save_best_only=True,
            monitor=monitor_secondary, mode="min", verbose=1),
        tf.keras.callbacks.CSVLogger(os.path.join(out_dir, "top_log.csv")),
        tf.keras.callbacks.TensorBoard(log_dir=os.path.join(out_dir, "tb_top")),
    ]


def _to_uint8(x):
    x = np.clip(x, 0.0, 1.0)
    return (x * 255.0 + 0.5).astype(np.uint8)


def save_png01(img01, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    Image.fromarray(_to_uint8(img01)).save(path, 'PNG')


def _top_manifest_dict():
    return {
        "IMAGE_SIZE": IMAGE_SIZE,
        "TOP_GRID": TOP_GRID,
        "NUM_EMBEDDINGS_TOP": NUM_EMBEDDINGS_TOP,
        "EMBEDDING_DIM_TOP": EMBEDDING_DIM_TOP,
        "COMMITMENT_COST_TOP": COMMITMENT_COST_TOP,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    }


def export_top_bundle(model, out_root="/content/drive/MyDrive/vqvae_toponly_runs",
                      run_dir=None, history=None, take_val=1, also_savedmodel=False):
    os.makedirs(out_root, exist_ok=True)
    stamp = time.strftime("%Y%m%d-%H%M%S") if run_dir is None else run_dir
    bundle_dir = os.path.join(out_root, stamp)
    os.makedirs(bundle_dir, exist_ok=True)

    # build once
    _ = model(tf.zeros([1, IMAGE_SIZE, IMAGE_SIZE, 1]), training=False)

    # 1) weights
    w_path = os.path.join(bundle_dir, "toponly.weights.h5")
    model.save_weights(w_path)

    # 2) codebook + EMA
    np.save(os.path.join(bundle_dir, "vq_top_emb.npy"), model.vq_top.embeddings.numpy())
    np.save(os.path.join(bundle_dir, "vq_top_ema_cluster_size.npy"), model.vq_top.ema_cluster_size.numpy())
    np.save(os.path.join(bundle_dir, "vq_top_ema_dw.npy"), model.vq_top.ema_dw.numpy())

    # 3) manifest
    with open(os.path.join(bundle_dir, "manifest.json"), "w") as f:
        json.dump(_top_manifest_dict(), f, indent=2)

    # 4) history
    if history is not None and hasattr(history, "history"):
        with open(os.path.join(bundle_dir, "history.json"), "w") as f:
            json.dump(history.history, f, indent=2)

    # 5) sample from val_ds (first tile of each sampled batch; uses the global val_ds)
    sample_dir = os.path.join(bundle_dir, "samples"); os.makedirs(sample_dir, exist_ok=True)
    for i, (xb, _) in enumerate(val_ds.take(take_val)):
        yb = model(xb, training=False)
        save_png01((xb.numpy()[0,...,0] + 0.5), os.path.join(sample_dir, f"ae_input_val{i}.png"))
        save_png01((yb.numpy()[0,...,0] + 0.5), os.path.join(sample_dir, f"ae_recon_val{i}.png"))

    # 6) SavedModel
    if also_savedmodel:
        sm_dir = os.path.join(bundle_dir, "saved_model")
        model.save(sm_dir)

    # 7) latest marker
    latest = os.path.join(out_root, "latest")
    try:
        if os.path.islink(latest) or os.path.exists(latest):
            os.remove(latest)
        os.symlink(bundle_dir, latest)
    except Exception:
        with open(os.path.join(out_root, "LATEST.txt"), "w") as f:
            f.write(bundle_dir)

    print(f"✅ Top-only bundle saved to: {bundle_dir}")
    return bundle_dir

# -----------------------------
# (H) Bit-accurate compression math (Top-only)
# -----------------------------
bits_per_top  = int(math.ceil(math.log2(NUM_EMBEDDINGS_TOP)))   # 256 -> 8
tokens_top    = TOP_GRID * TOP_GRID                             # 1024
latent_bits   = tokens_top * bits_per_top                       # 8192
orig_bits     = IMAGE_SIZE * IMAGE_SIZE * 8                     # 8,388,608 (for 1024x1024, 8-bit)
bpp           = latent_bits / (IMAGE_SIZE * IMAGE_SIZE)         # ≈0.0078125
ratio         = orig_bits / latent_bits                         # ≈1024×

print("="*66)
print(f"Bit-accurate compression (Top-only): {ratio:.1f}×  (~{bpp:.5f} bpp)")
print(f" - tokens: top {tokens_top}")
print(f" - bits/token: top {bits_per_top}")
print(f" - commitment(top) = {COMMITMENT_COST_TOP}")
print("="*66)

# -----------------------------
# (I) Train → load best → export
# -----------------------------
run_stamp = time.strftime("%Y%m%d-%H%M%S")
out_root = "/content/drive/MyDrive/vqvae_toponly_runs"
run_dir   = os.path.join(out_root, run_stamp)

cbs = make_top_callbacks(
    run_dir,
    monitor_primary="val_ssim_metric",
    monitor_secondary="val_loss",
)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS_WARMUP,
    callbacks=cbs,
    verbose=1
)

best_ckpt = os.path.join(run_dir, "top_best_by_vloss.weights.h5")
model.load_weights(best_ckpt)
print(f"Loaded best weights: {best_ckpt}")

bundle_dir = export_top_bundle(
    model,
    out_root=out_root,
    run_dir=run_stamp,
    history=history,
    also_savedmodel=False
)
print("Exported:", bundle_dir)

# -----------------------------
# (J) Predict last-100 full images -> stitch back to full size
# -----------------------------

def reconstruct_image(model, path, tile=None, batch=4, out_dir="/content/recons_full"):
    """Reconstruct one full image by tiling with given tile size and stitching back.
    Saves PNG as <basename>_recon.png, returns path.
    """
    if tile is None:
        tile = IMAGE_SIZE
    os.makedirs(out_dir, exist_ok=True)

    # Decode to numpy using the same preproc
    img_t = _decode_any_image(tf.constant(path))   # [-0.5,0.5]
    img = img_t.numpy()
    h0, w0 = img.shape[:2]

    # Pad symmetrically to multiple of tile
    pad_h = (tile - (h0 % tile)) % tile
    pad_w = (tile - (w0 % tile)) % tile
    img_pad = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode='symmetric')
    h, w = img_pad.shape[:2]

    n_h = h // tile
    n_w = w // tile

    # Batch through the model
    recons_tiles = []
    batch_buf = []
    for ih in range(n_h):
        for iw in range(n_w):
            y = ih * tile; x = iw * tile
            patch = img_pad[y:y+tile, x:x+tile, :]
            batch_buf.append(patch)
            if len(batch_buf) == batch:
                xb = tf.convert_to_tensor(np.stack(batch_buf, axis=0), dtype=tf.float32)
                yb = model(xb, training=False).numpy()
                recons_tiles.append(yb)
                batch_buf = []
    if len(batch_buf) > 0:
        xb = tf.convert_to_tensor(np.stack(batch_buf, axis=0), dtype=tf.float32)
        yb = model(xb, training=False).numpy()
        recons_tiles.append(yb)
        batch_buf = []
    recons = np.concatenate(recons_tiles, axis=0) if len(recons_tiles) else np.zeros((0, tile, tile, 1), np.float32)

    # Stitch back
    canvas = np.zeros((h, w, 1), dtype=np.float32)
    k = 0
    for ih in range(n_h):
        for iw in range(n_w):
            y = ih * tile; x = iw * tile
            canvas[y:y+tile, x:x+tile, :] = recons[k]
            k += 1

    # Crop to original size and save
    canvas = canvas[:h0, :w0, :]
    out = (np.clip(canvas + 0.5, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
    base = os.path.splitext(os.path.basename(path))[0]
    save_path = os.path.join(out_dir, f"{base}_recon.png")
    Image.fromarray(out[...,0]).save(save_path, "PNG")
    return save_path

if globals().get('RUN_TRAIN', True):  # RUN_TRAIN is not set in this cell; default to the training path
    # keep the original behavior: only last-100
    files_to_recon = pred_files
    full_out_dir = os.path.join(run_dir, "recons_full")
    recon_desc = "Predict & stitch last-100"
else:
    # reconstruct ALL images using the provided checkpoint
    files_to_recon = all_files
    full_out_dir = os.path.join(run_dir, "recons_full_all")
    recon_desc = f"Reconstruct ALL ({len(all_files)}) from CKPT"

print("Reconstruction output dir:", full_out_dir)
os.makedirs(full_out_dir, exist_ok=True)
saved = []
for p in tqdm(files_to_recon, desc=recon_desc):
    saved.append(reconstruct_image(model, p, tile=IMAGE_SIZE, batch=max(1, BATCH_SIZE), out_dir=full_out_dir))
print(f"Saved {len(saved)} reconstructions to {full_out_dir}")


In [None]:
# =========================== Eval: Top-only, reconstruct the first test image ===========================
import os, glob, json, time, math
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image


# model.load_weights("/content/drive/MyDrive/vqvae_toponly_runs/<your_timestamp>/top_best_by_ssim.weights.h5")

def _to_uint8(x):
    x = np.clip(x, 0.0, 1.0)
    return (x * 255.0 + 0.5).astype(np.uint8)

def save_png01(img01, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    Image.fromarray(_to_uint8(img01)).save(path, 'PNG')

def find_first_test_image(test_dir):
    paths = sorted(glob.glob(os.path.join(test_dir, '*.png')))
    if len(paths) == 0:
        paths = sorted(glob.glob(os.path.join(test_dir, '**/*.png'), recursive=True))
    if len(paths) == 0:
        raise FileNotFoundError(f"No PNG images found under {test_dir}")
    return paths[0]


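def load_img_for_model(path):
    # `load_image` is not defined in this notebook section; use it if an earlier cell
    # provides it, otherwise fall back to `_decode_any_image` from the training cell.
    # Both are expected to return [H,W,1] float32 in [-0.5,0.5].
    loader = globals().get('load_image') or _decode_any_image
    return loader(path)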

def eval_recon_toponly(model, test_path,
                       out_root="/content/drive/MyDrive/vqvae_toponly_eval"):
    os.makedirs(out_root, exist_ok=True)
    stamp = time.strftime("%Y%m%d-%H%M%S")
    out_dir = os.path.join(out_root, stamp)
    os.makedirs(out_dir, exist_ok=True)


    x  = load_img_for_model(test_path)          # [H,W,1], [-0.5,0.5]
    x1 = tf.expand_dims(x, 0)                   # [1,H,W,1]
    y1 = model(x1, training=False)              # [1,H,W,1]


    x01 = (x1.numpy()[0, ..., 0] + 0.5)
    y01 = (y1.numpy()[0, ..., 0] + 0.5)


    psnr = float(tf.image.psnr(x1 + 0.5, y1 + 0.5, max_val=1.0).numpy().mean())
    ssim = float(tf.image.ssim(x1 + 0.5, y1 + 0.5, max_val=1.0).numpy().mean())


    in_path = os.path.join(out_dir, "input.png")
    re_path = os.path.join(out_dir, "recon.png")
    save_png01(x01, in_path)
    save_png01(y01, re_path)


    fig = plt.figure(figsize=(8, 4), dpi=150)
    ax = plt.subplot(1, 2, 1); ax.imshow(x01, cmap='gray'); ax.set_title("Input"); ax.axis('off')
    ax = plt.subplot(1, 2, 2); ax.imshow(y01, cmap='gray'); ax.set_title("Recon"); ax.axis('off')
    panel_path = os.path.join(out_dir, "panel.png")
    plt.tight_layout(); plt.savefig(panel_path, dpi=150); plt.close(fig)


    bits_per_top = int(math.ceil(math.log2(int(NUM_EMBEDDINGS_TOP))))
    tokens_top   = int(TOP_GRID) * int(TOP_GRID)
    latent_bits  = tokens_top * bits_per_top
    orig_bits    = int(IMAGE_SIZE) * int(IMAGE_SIZE) * 8
    ratio        = orig_bits / latent_bits
    bpp          = latent_bits / (IMAGE_SIZE * IMAGE_SIZE)

    info = {
        "test_path": test_path,
        "psnr": round(psnr, 4),
        "ssim": round(ssim, 6),
        "tokens_top": tokens_top,
        "bits_per_token": bits_per_top,
        "latent_bits": latent_bits,
        "orig_bits": orig_bits,
        "compression_ratio": round(ratio, 2),
        "bpp": bpp
    }
    with open(os.path.join(out_dir, "results.json"), "w") as f:
        json.dump(info, f, indent=2)

    print(f"[Test] {test_path}")
    print(f"[Metrics] PSNR={psnr:.2f}, SSIM={ssim:.4f}")
    print(f"[Bits] tokens={tokens_top}, bits/token={bits_per_top}, "
          f"latent={latent_bits}, orig={orig_bits}, ratio≈{ratio:.1f}× (~{bpp:.5f} bpp)")
    print("✅ Saved:", in_path)
    print("✅ Saved:", re_path)
    print("✅ Saved:", panel_path)
    print("✅ Saved:", os.path.join(out_dir, "results.json"))
    return out_dir


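# TEST_DIR must be defined before this cell runs, e.g. TEST_DIR = "/content/drive/MyDrive/your_test_dir"
test_img_path = find_first_test_image(TEST_DIR)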
_ = eval_recon_toponly(model, test_img_path)


In [None]:
import csv

def get_test_image_paths(test_dir, n=30):
    paths = sorted(glob.glob(os.path.join(test_dir, '*.png')))
    if len(paths) == 0:
        paths = sorted(glob.glob(os.path.join(test_dir, '**/*.png'), recursive=True))
    if len(paths) == 0:
        raise FileNotFoundError(f"No PNG images found under {test_dir}")
    return paths[:n]

def eval_recon_toponly_many(model, test_paths, out_root="/content/drive/MyDrive/vqvae_toponly_eval", batch_size=2):
    os.makedirs(out_root, exist_ok=True)
    stamp = time.strftime("%Y%m%d-%H%M%S")
    out_dir = os.path.join(out_root, f"{stamp}_batch")
    os.makedirs(out_dir, exist_ok=True)


    bits_per_top = int(math.ceil(math.log2(int(NUM_EMBEDDINGS_TOP))))
    tokens_top   = int(TOP_GRID) * int(TOP_GRID)
    latent_bits  = tokens_top * bits_per_top
    orig_bits    = int(IMAGE_SIZE) * int(IMAGE_SIZE) * 8
    ratio        = orig_bits / latent_bits
    bpp          = latent_bits / (IMAGE_SIZE * IMAGE_SIZE)

    results = []
    for i in range(0, len(test_paths), batch_size):
        chunk = test_paths[i:i+batch_size]
        xs = [load_img_for_model(p) for p in chunk]
        x1 = tf.stack(xs, axis=0)                     # [B,H,W,1]
        y1 = model(x1, training=False)


        psnrs = tf.image.psnr(x1 + 0.5, y1 + 0.5, max_val=1.0).numpy()
        ssims = tf.image.ssim(x1 + 0.5, y1 + 0.5, max_val=1.0).numpy()

        x01 = (x1.numpy() + 0.5)[..., 0]  # [B,H,W] in [0,1]
        y01 = (y1.numpy() + 0.5)[..., 0]

        for j, p in enumerate(chunk):
            idx = i + j
            stem = f"{idx:04d}"
            save_png01(x01[j], os.path.join(out_dir, f"{stem}_input.png"))
            save_png01(y01[j], os.path.join(out_dir, f"{stem}_recon.png"))
            results.append({
                "index": idx,
                "test_path": p,
                "psnr": float(psnrs[j]),
                "ssim": float(ssims[j])
            })


    avg_psnr = float(np.mean([r["psnr"] for r in results]))
    avg_ssim = float(np.mean([r["ssim"] for r in results]))

    summary = {
        "count": len(test_paths),
        "avg_psnr": round(avg_psnr, 4),
        "avg_ssim": round(avg_ssim, 6),
        "tokens_top": tokens_top,
        "bits_per_token": bits_per_top,
        "latent_bits": latent_bits,
        "orig_bits": orig_bits,
        "compression_ratio": round(ratio, 2),
        "bpp": bpp
    }

    with open(os.path.join(out_dir, "results.json"), "w") as f:
        json.dump({"summary": summary, "results": results}, f, indent=2)

    with open(os.path.join(out_dir, "results.csv"), "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["index", "test_path", "psnr", "ssim"])
        writer.writeheader()
        writer.writerows(results)

    print(f"[BATCH DONE] N={len(test_paths)} | Avg PSNR={avg_psnr:.2f}, Avg SSIM={avg_ssim:.4f}")
    print(f"Saved to: {out_dir}")
    return out_dir

# TEST_DIR = "/content/drive/MyDrive/your_test_dir"
paths_30 = get_test_image_paths(TEST_DIR, n=30)
_ = eval_recon_toponly_many(model, paths_30, batch_size=2)


## 2) Two-Level VQ-VAE (204×)

In [None]:
# ==============================================================
# Stage-1 (Two-level VQ-VAE-2, EMA codebook + LN/Dropout + FiLM)
# - Top = 32x32, Bottom = 64x64
# - Bottom commitment = 0.15
# ==============================================================

import os, time, json, math
import numpy as np
import tensorflow as tf
from PIL import Image

# -----------------------------
# (A) Residual helpers
# -----------------------------
if 'residual_stack' not in globals():
    def residual_block(x, filters):
        skip = x
        x = tf.keras.layers.ReLU()(x)
        x = tf.keras.layers.Conv2D(filters//2, 3, padding='same')(x)
        x = tf.keras.layers.ReLU()(x)
        x = tf.keras.layers.Conv2D(filters, 1, padding='same')(x)
        return tf.keras.layers.Add()([skip, x])

    def residual_stack(x, filters, num_blocks=2):
        for _ in range(num_blocks):
            x = residual_block(x, filters)
        return tf.keras.layers.ReLU()(x)

# -----------------------------
# (B) EMA VectorQuantizer (VQ-VAE-2)
# -----------------------------
class VectorQuantizerEMA(tf.keras.layers.Layer):
    def __init__(self, num_embeddings, embedding_dim,
                 commitment_cost=0.25, decay=0.99, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.num_embeddings = int(num_embeddings)
        self.embedding_dim = int(embedding_dim)
        self.commitment_cost = float(commitment_cost)
        self.decay = float(decay)
        self.epsilon = float(epsilon)
        self._perplexity = tf.keras.metrics.Mean(name=f'{self.name}_perplexity')

    @property
    def metrics(self):
        return [self._perplexity]

    def build(self, input_shape):
        self.embeddings = self.add_weight(
            name='embeddings',
            shape=(self.embedding_dim, self.num_embeddings),
            initializer=tf.keras.initializers.RandomUniform(-1.0/self.num_embeddings, 1.0/self.num_embeddings),
            trainable=False
        )
        self.ema_cluster_size = self.add_weight(
            name='ema_cluster_size',
            shape=(self.num_embeddings,),
            initializer='zeros',
            trainable=False
        )
        self.ema_dw = self.add_weight(
            name='ema_dw',
            shape=(self.embedding_dim, self.num_embeddings),
            initializer='zeros',
            trainable=False
        )

    def call(self, inputs, training=None):
        shp  = tf.shape(inputs)                              # [B,H,W,C]
        flat = tf.reshape(inputs, [-1, self.embedding_dim])  # [N,C]

        # NN lookup via expanded L2
        dists = (tf.reduce_sum(flat**2, axis=1, keepdims=True)
                 - 2.0 * tf.matmul(flat, self.embeddings)
                 + tf.reduce_sum(self.embeddings**2, axis=0, keepdims=True))  # [N,K]
        idx  = tf.argmax(-dists, axis=1)                     # [N]
        one  = tf.one_hot(idx, self.num_embeddings, dtype=flat.dtype)  # [N,K]
        quant = tf.matmul(one, tf.transpose(self.embeddings))          # [N,C]
        quant = tf.reshape(quant, shp)                                   # [B,H,W,C]

        # EMA updates (training only)
        def _ema_update():
            cluster_size = tf.reduce_sum(one, axis=0)                  # [K]
            dw = tf.matmul(tf.transpose(flat), one)                    # [C,K]
            ema_cs = self.ema_cluster_size * self.decay + cluster_size * (1.0 - self.decay)
            ema_dw = self.ema_dw * self.decay + dw * (1.0 - self.decay)
            n = tf.reduce_sum(ema_cs)
            smoothed_cs = (ema_cs + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n
            new_embed = ema_dw / tf.expand_dims(smoothed_cs, 0)        # [C,K]
            self.ema_cluster_size.assign(ema_cs)
            self.ema_dw.assign(ema_dw)
            self.embeddings.assign(new_embed)
            return 0.0

        if training is None:
            training = tf.keras.backend.learning_phase()
        tf.cond(tf.cast(training, tf.bool), _ema_update, lambda: 0.0)

        # only commitment loss
        e_loss = tf.reduce_mean((tf.stop_gradient(quant) - inputs)**2)
        self.add_loss(self.commitment_cost * e_loss)

        # perplexity
        avg_p = tf.reduce_mean(one, axis=0)
        perp  = tf.exp(-tf.reduce_sum(avg_p * tf.math.log(avg_p + 1e-10)))
        self._perplexity.update_state(perp)

        # straight-through
        quant = inputs + tf.stop_gradient(quant - inputs)
        return quant

# -----------------------------
# (C) Hyperparams (read/override from globals when present)
# -----------------------------
IMAGE_SIZE         = int(globals().get('IMAGE_SIZE', 1024))
TOP_GRID           = int(globals().get('TOP_GRID', 32))
BOTTOM_GRID        = int(globals().get('BOTTOM_GRID', 64))
NUM_HIDDENS_TOP    = int(globals().get('NUM_HIDDENS_TOP', 128))
NUM_HIDDENS_BOTTOM = int(globals().get('NUM_HIDDENS_BOTTOM', 128))

EMBEDDING_DIM_TOP     = int(globals().get('EMBEDDING_DIM_TOP', 96))
EMBEDDING_DIM_BOTTOM  = int(globals().get('EMBEDDING_DIM_BOTTOM', 96))
NUM_EMBEDDINGS_TOP    = int(globals().get('NUM_EMBEDDINGS_TOP', 256))
NUM_EMBEDDINGS_BOTTOM = int(globals().get('NUM_EMBEDDINGS_BOTTOM', 256))

COMMITMENT_COST_TOP    = float(globals().get('COMMITMENT_COST_TOP', 1.0))
# Bottom commitment is fixed at 0.15 for this stage; no global override is read.
COMMITMENT_COST_BOTTOM = 0.15


# -----------------------------
# (D) Building blocks
# -----------------------------
def build_encoder_bottom(img_size, bottom_grid, num_hiddens):
    assert (img_size % bottom_grid) == 0
    n_downs = int(math.log2(img_size // bottom_grid))  # 1024->64 = 4
    inputs = tf.keras.Input(shape=(img_size, img_size, 1))
    x = inputs
    chans = [64, 128, 128, num_hiddens][:n_downs]
    for c in chans:
        x = tf.keras.layers.Conv2D(c, 4, strides=2, padding='same')(x)
        x = tf.keras.layers.ReLU()(x)
    x = residual_stack(x, num_hiddens, num_blocks=2)
    return tf.keras.Model(inputs, x, name=f'encoder_bottom_{bottom_grid}')

def build_encoder_top_from_bottom(bottom_grid, top_grid, in_ch, out_ch):
    assert (bottom_grid % top_grid) == 0
    n_downs = int(math.log2(bottom_grid // top_grid))  # 64->32 = 1
    inputs = tf.keras.Input(shape=(bottom_grid, bottom_grid, in_ch))
    x = inputs
    for _ in range(n_downs):
        x = tf.keras.layers.Conv2D(out_ch, 4, strides=2, padding='same')(x)
        x = tf.keras.layers.ReLU()(x)
    x = residual_stack(x, out_ch, num_blocks=2)
    return tf.keras.Model(inputs, x, name=f'encoder_top_{top_grid}')

def build_top_to_bottom_upsampler(top_grid, bottom_grid, in_ch, out_ch):
    assert (bottom_grid % top_grid) == 0
    n_ups = int(math.log2(bottom_grid // top_grid))     # 32->64 = 1
    inputs = tf.keras.Input(shape=(top_grid, top_grid, in_ch))
    x = inputs
    for _ in range(n_ups):
        x = tf.keras.layers.Conv2DTranspose(out_ch, 4, strides=2, padding='same')(x)
        x = tf.keras.layers.ReLU()(x)
    return tf.keras.Model(inputs, x, name=f'top_upsampler_{top_grid}_to_{bottom_grid}')

def build_decoder_bottom_to_image(img_size, bottom_grid, in_ch):
    assert (img_size % bottom_grid) == 0
    n_ups = int(math.log2(img_size // bottom_grid))    # 64->1024 = 4
    inputs = tf.keras.Input(shape=(bottom_grid, bottom_grid, in_ch))
    x = inputs
    ups_ch = [256, 128, 128, 64][:n_ups]
    for c in ups_ch:
        x = tf.keras.layers.Conv2DTranspose(c, 4, strides=2, padding='same')(x)
        x = tf.keras.layers.ReLU()(x)
    out = tf.keras.layers.Conv2D(1, 3, padding='same')(x)
    return tf.keras.Model(inputs, out, name=f'decoder_bottom_{bottom_grid}_to_{img_size}')

# -----------------------------
# (E) Assemble Stage-1 model
# -----------------------------
enc_bottom = build_encoder_bottom(IMAGE_SIZE, BOTTOM_GRID, NUM_HIDDENS_BOTTOM)
enc_top    = build_encoder_top_from_bottom(BOTTOM_GRID, TOP_GRID, NUM_HIDDENS_BOTTOM, NUM_HIDDENS_TOP)

pre_vq_top     = tf.keras.layers.Conv2D(EMBEDDING_DIM_TOP, 1, name='pre_vq_top')
vq_top         = VectorQuantizerEMA(NUM_EMBEDDINGS_TOP, EMBEDDING_DIM_TOP,
                                    commitment_cost=COMMITMENT_COST_TOP, decay=0.99, epsilon=1e-5, name='vq_top')
post_vq_top    = tf.keras.layers.Conv2D(NUM_HIDDENS_TOP, 1, name='post_vq_top')

pre_vq_bottom  = tf.keras.layers.Conv2D(EMBEDDING_DIM_BOTTOM, 1, name='pre_vq_bottom')
vq_bottom      = VectorQuantizerEMA(NUM_EMBEDDINGS_BOTTOM, EMBEDDING_DIM_BOTTOM,
                                    commitment_cost=COMMITMENT_COST_BOTTOM, decay=0.99, epsilon=1e-5, name='vq_bottom')
post_vq_bottom = tf.keras.layers.Conv2D(NUM_HIDDENS_BOTTOM, 1, name='post_vq_bottom')

top_up   = build_top_to_bottom_upsampler(TOP_GRID, BOTTOM_GRID, NUM_HIDDENS_TOP, NUM_HIDDENS_BOTTOM)
dec_bot  = build_decoder_bottom_to_image(IMAGE_SIZE, BOTTOM_GRID, in_ch=NUM_HIDDENS_BOTTOM * 2)

class VQTwoLevelEMA(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.enc_bottom = enc_bottom
        self.enc_top    = enc_top
        self.pre_vq_top = pre_vq_top
        self.vq_top     = vq_top
        self.post_vq_top= post_vq_top
        self.pre_vq_bottom  = pre_vq_bottom
        self.vq_bottom      = vq_bottom
        self.post_vq_bottom = post_vq_bottom
        self.top_up   = top_up
        self.dec_bottom = dec_bot
        # info balancing
        self.norm_b = tf.keras.layers.LayerNormalization(axis=-1)
        self.norm_t = tf.keras.layers.LayerNormalization(axis=-1)
        self.drop_b = tf.keras.layers.SpatialDropout2D(0.3)  # training only
        # FiLM
        self.film_gamma = tf.keras.layers.Conv2D(NUM_HIDDENS_BOTTOM, 1)
        self.film_beta  = tf.keras.layers.Conv2D(NUM_HIDDENS_BOTTOM, 1)

    def call(self, x, training=None):
        hb  = self.enc_bottom(x)                 # [B,64,64,Hb]
        zb  = self.pre_vq_bottom(hb)             # [B,64,64,Cb]
        zbq = self.vq_bottom(zb, training=training)
        zbq = self.post_vq_bottom(zbq)           # [B,64,64,Hb]

        ht  = self.enc_top(hb)                   # [B,32,32,Ht]
        zt  = self.pre_vq_top(ht)                # [B,32,32,Ct]
        ztq = self.vq_top(zt, training=training)
        ztq = self.post_vq_top(ztq)              # [B,32,32,Ht]
        t_up= self.top_up(ztq)                   # [B,64,64,Hb]

        zbq = self.norm_b(zbq)
        t_up= self.norm_t(t_up)
        zbq = self.drop_b(zbq, training=training)

        gamma = self.film_gamma(t_up)
        beta  = self.film_beta(t_up)
        zbq_mod = zbq * (1.0 + gamma) + beta

        y = self.dec_bottom(tf.concat([zbq_mod, t_up], axis=-1))  # [B,1024,1024,1]
        return y

model = VQTwoLevelEMA()

# -----------------------------
# (F) Compile (define if missing)
# -----------------------------
def ssim_metric(y_true, y_pred):
    # data is in [-0.5, 0.5]; shift to [0,1] for SSIM
    return tf.reduce_mean(tf.image.ssim(y_true + 0.5, y_pred + 0.5, max_val=1.0))

def psnr_metric(y_true, y_pred):
    return tf.reduce_mean(tf.image.psnr(y_true + 0.5, y_pred + 0.5, max_val=1.0))

def compile_warmup(m, lr=3e-4, wd=1e-4):
    try:
        opt = tf.keras.optimizers.AdamW(learning_rate=lr, weight_decay=wd, beta_1=0.9, beta_2=0.95)
    except Exception:
        opt = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.9, beta_2=0.95)
    # L1 is more stable; the EMA VQ loss is already injected via layer.add_loss
    m.compile(optimizer=opt, loss='mae', metrics=[ssim_metric, psnr_metric])

# -----------------------------
# (G) Callbacks: auto-save best
# -----------------------------
def make_stage1_callbacks(out_dir,
                          monitor_primary="val_ssim_metric",
                          monitor_secondary="val_loss",
                          patience=8):
    os.makedirs(out_dir, exist_ok=True)
    cbs = [
        tf.keras.callbacks.ModelCheckpoint(
            os.path.join(out_dir, "stage1_best_by_ssim.weights.h5"),
            save_weights_only=True, save_best_only=True,
            monitor=monitor_primary, mode="max", verbose=1),
        tf.keras.callbacks.ModelCheckpoint(
            os.path.join(out_dir, "stage1_best_by_vloss.weights.h5"),
            save_weights_only=True, save_best_only=True,
            monitor=monitor_secondary, mode="min", verbose=1),
        tf.keras.callbacks.CSVLogger(os.path.join(out_dir, "stage1_log.csv")),
        tf.keras.callbacks.TensorBoard(log_dir=os.path.join(out_dir, "tb_stage1")),
        tf.keras.callbacks.EarlyStopping(monitor=monitor_secondary, mode="min",
                                         patience=patience, restore_best_weights=False),
    ]
    return cbs

# -----------------------------
# (H) Export bundle (weights + codebooks + manifest + samples)
# -----------------------------
def _to_uint8(x):
    x = np.clip(x, 0.0, 1.0)
    return (x * 255.0 + 0.5).astype(np.uint8)

def save_png01(img01, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    Image.fromarray(_to_uint8(img01)).save(path, 'PNG')

def _stage1_manifest_dict():
    return {
        "IMAGE_SIZE": IMAGE_SIZE,
        "TOP_GRID": TOP_GRID,
        "BOTTOM_GRID": BOTTOM_GRID,
        "NUM_EMBEDDINGS_TOP": NUM_EMBEDDINGS_TOP,
        "NUM_EMBEDDINGS_BOTTOM": NUM_EMBEDDINGS_BOTTOM,
        "EMBEDDING_DIM_TOP": EMBEDDING_DIM_TOP,
        "EMBEDDING_DIM_BOTTOM": EMBEDDING_DIM_BOTTOM,
        "COMMITMENT_COST_TOP": COMMITMENT_COST_TOP,
        "COMMITMENT_COST_BOTTOM": COMMITMENT_COST_BOTTOM,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    }

def export_stage1_bundle(model, out_root="/content/drive/MyDrive/vqvae_stage1_runs",
                         run_dir=None, history=None, take_val=1, also_savedmodel=False):
    os.makedirs(out_root, exist_ok=True)
    stamp = time.strftime("%Y%m%d-%H%M%S") if run_dir is None else run_dir
    bundle_dir = os.path.join(out_root, stamp)
    os.makedirs(bundle_dir, exist_ok=True)

    # ensure built
    try:
        _xb = next(iter(train_ds))[0]
        _ = model(_xb[:1], training=False)
    except Exception:
        pass

    # 1) weights
    w_path = os.path.join(bundle_dir, "stage1.weights.h5")
    model.save_weights(w_path)

    # 2) codebooks + EMA
    np.save(os.path.join(bundle_dir, "vq_top_emb.npy"),     model.vq_top.embeddings.numpy())
    np.save(os.path.join(bundle_dir, "vq_bottom_emb.npy"),  model.vq_bottom.embeddings.numpy())
    np.save(os.path.join(bundle_dir, "vq_top_ema_cluster_size.npy"),    model.vq_top.ema_cluster_size.numpy())
    np.save(os.path.join(bundle_dir, "vq_top_ema_dw.npy"),             model.vq_top.ema_dw.numpy())
    np.save(os.path.join(bundle_dir, "vq_bottom_ema_cluster_size.npy"), model.vq_bottom.ema_cluster_size.numpy())
    np.save(os.path.join(bundle_dir, "vq_bottom_ema_dw.npy"),          model.vq_bottom.ema_dw.numpy())

    # 3) manifest
    with open(os.path.join(bundle_dir, "manifest.json"), "w") as f:
        json.dump(_stage1_manifest_dict(), f, indent=2)

    # 4) history
    if history is not None and hasattr(history, "history"):
        with open(os.path.join(bundle_dir, "history.json"), "w") as f:
            json.dump(history.history, f, indent=2)

    # 5) samples from val (first image of each of the first `take_val` batches)
    sample_dir = os.path.join(bundle_dir, "samples")
    os.makedirs(sample_dir, exist_ok=True)
    for i, (xb, _) in enumerate(val_ds.take(take_val)):
        yb = model(xb, training=False)
        # index the filenames so take_val > 1 does not overwrite earlier samples
        save_png01(xb.numpy()[0, ..., 0] + 0.5, os.path.join(sample_dir, f"ae_input_val{i}.png"))
        save_png01(yb.numpy()[0, ..., 0] + 0.5, os.path.join(sample_dir, f"ae_recon_val{i}.png"))

    # 6) (optional) SavedModel
    if also_savedmodel:
        sm_dir = os.path.join(bundle_dir, "saved_model")
        model.save(sm_dir)

    # 7) latest symlink (best-effort)
    latest = os.path.join(out_root, "latest")
    try:
        if os.path.islink(latest) or os.path.exists(latest):
            os.remove(latest)
        os.symlink(bundle_dir, latest)
    except Exception:
        with open(os.path.join(out_root, "LATEST.txt"), "w") as f:
            f.write(bundle_dir)

    print(f"✅ Stage-1 bundle saved to: {bundle_dir}")
    return bundle_dir

# -----------------------------
# (I) Bit-accurate compression (print)
# -----------------------------
bits_per_top    = int(math.ceil(math.log2(NUM_EMBEDDINGS_TOP)))       # 256 -> 8
bits_per_bottom = int(math.ceil(math.log2(NUM_EMBEDDINGS_BOTTOM)))    # 256 -> 8
tokens_top      = TOP_GRID * TOP_GRID          # 1024
tokens_bottom   = BOTTOM_GRID * BOTTOM_GRID    # 4096
latent_bits     = tokens_top * bits_per_top + tokens_bottom * bits_per_bottom
orig_bits       = IMAGE_SIZE * IMAGE_SIZE * 8
bpp             = latent_bits / (IMAGE_SIZE * IMAGE_SIZE)
ratio           = orig_bits / latent_bits

print("="*66)
print(f"Bit-accurate compression (Top+Bottom): {ratio:.1f}×  (~{bpp:.5f} bpp)")
print(f" - tokens: top {tokens_top} + bottom {tokens_bottom} = {tokens_top + tokens_bottom}")
print(f" - bits/token: top {bits_per_top}, bottom {bits_per_bottom}")
print(f" - commitment(bottom) = {COMMITMENT_COST_BOTTOM}")
print("="*66)

# -----------------------------
# (J) Train with auto-save best → load best → export bundle
# -----------------------------
EPOCHS_WARMUP = int(globals().get('EPOCHS_WARMUP', 50))

compile_warmup(model)

out_root = "/content/drive/MyDrive/vqvae_stage1_runs"
run_stamp = time.strftime("%Y%m%d-%H%M%S")
run_dir   = os.path.join(out_root, run_stamp)
cbs = make_stage1_callbacks(
    run_dir,
    monitor_primary="val_ssim_metric",
    monitor_secondary="val_loss",
    patience=8
)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS_WARMUP,
    callbacks=cbs,
    verbose=1
)


best_ckpt = os.path.join(run_dir, "stage1_best_by_vloss.weights.h5")
model.load_weights(best_ckpt)
print(f"Loaded best weights: {best_ckpt}")


bundle_dir = export_stage1_bundle(
    model,
    out_root=out_root,
    run_dir=run_stamp,
    history=history,
    also_savedmodel=False
)
print("Exported:", bundle_dir)
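
To make the quantizer's two steps concrete, here is a NumPy sketch of what `VectorQuantizerEMA.call` computes: a nearest-codebook lookup via the expanded squared distance, followed by one EMA codebook update. The function name `vq_ema_step` and the toy shapes are illustrative assumptions; the `[C, K]` codebook layout mirrors the layer above.

```python
import numpy as np

def vq_ema_step(flat, embeddings, ema_cluster_size, ema_dw, decay=0.99, eps=1e-5):
    """One lookup + EMA update.

    flat: [N, C] encoder vectors; embeddings, ema_dw: [C, K]; ema_cluster_size: [K].
    """
    K = embeddings.shape[1]
    # ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2, evaluated for every (x, e) pair at once
    dists = (np.sum(flat**2, axis=1, keepdims=True)
             - 2.0 * flat @ embeddings
             + np.sum(embeddings**2, axis=0, keepdims=True))   # [N, K]
    idx = np.argmin(dists, axis=1)                             # [N]
    one_hot = np.eye(K, dtype=flat.dtype)[idx]                 # [N, K]
    quant = one_hot @ embeddings.T                             # [N, C]

    # EMA statistics: usage count per code and sum of the vectors assigned to it
    ema_cs = decay * ema_cluster_size + (1.0 - decay) * one_hot.sum(axis=0)
    ema_dw = decay * ema_dw + (1.0 - decay) * (flat.T @ one_hot)
    # Laplace smoothing keeps unused codes from causing a division by zero
    n = ema_cs.sum()
    smoothed = (ema_cs + eps) / (n + K * eps) * n
    new_embeddings = ema_dw / smoothed[None, :]
    return idx, quant, new_embeddings, ema_cs, ema_dw

rng = np.random.default_rng(0)
flat = rng.normal(size=(16, 4))
codebook = rng.normal(size=(4, 8))
idx, quant, new_cb, cs, dw = vq_ema_step(flat, codebook, np.zeros(8), np.zeros((4, 8)))
print(idx.shape, quant.shape, new_cb.shape)
```

Because the codebook is updated from these running averages rather than by gradients, the layer only needs the commitment term in its loss.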


## 3) Two-Level VQ-VAE + Transformer Prior (1024×)

You can train the prior with PixelCNN or with the Transformer below; either way, some fine-tuning may be needed.
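
The Transformer prior in the next cell is trained with teacher forcing: the bottom-token sequence is shifted right by one position and a BOS token is placed in front, so position t only sees tokens before t, while the loss target stays the unshifted sequence. A minimal NumPy sketch, assuming the vocabulary layout noted in that cell's `VBOT_PLUS` comment (id 0 for BOS, ids 1..256 for real tokens 0..255); `shift_right_with_bos` is a hypothetical helper name:

```python
import numpy as np

BOS_ID = 0  # assumption: id 0 is reserved for BOS, real tokens are stored as token + 1

def shift_right_with_bos(tokens):
    """tokens: [B, T] ints in 0..V-1 -> decoder inputs [B, T] in 0..V."""
    tokens = np.asarray(tokens)
    bos = np.full((tokens.shape[0], 1), BOS_ID, dtype=tokens.dtype)
    # drop the last token, offset the rest by +1, and prepend BOS
    return np.concatenate([bos, tokens[:, :-1] + 1], axis=1)

targets = np.array([[5, 0, 255, 7]])
print(shift_right_with_bos(targets))  # [[  0   6   1 256]]
```

Keeping BOS on its own id means the first position is never confused with a real token 0.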

In [None]:
# ============================================
# Prior Transformer: low-memory version (mixed precision / d_model 256 × 6 layers / float32 logits / batch=1)
# ============================================
import os, json, math, numpy as np, tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

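# -------- 0) GPU memory & precision setup (must run before building the model) ---------
# Optional: let TensorFlow grow GPU memory on demand instead of reserving it all up front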
try:
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        for g in gpus:
            tf.config.experimental.set_memory_growth(g, True)
        print("✅ Enabled GPU memory growth")
except Exception as e:
    print("⚠️ memory growth set failed:", e)

# Enable mixed precision: this must be called before any model/layer is created
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
print("✅ Mixed precision policy:", mixed_precision.global_policy())

# -------------------------
# Paths and token index
# -------------------------
EXT_DIR         = "/content/drive/MyDrive/vqvae_stage2_tokens"
INDEX_EXT       = os.path.join(EXT_DIR, "index_ext.json")
ADAPTER_WEIGHTS = "/content/drive/MyDrive/vqvae_stage2_adapter/adapter_best.weights.h5"
PRIOR_OUT       = "/content/drive/MyDrive/vqvae_stage2_prior_full_mem"
os.makedirs(PRIOR_OUT, exist_ok=True)

with open(INDEX_EXT, "r") as f:
    NPZ_LIST = json.load(f)
print("Total samples:", len(NPZ_LIST))

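# Sizes / vocabularies (keep consistent with the earlier cells)
TOP_H, TOP_W  = 32, 32
BOT_H, BOT_W  = 64, 64
TGT_LEN       = BOT_H * BOT_W
VBOT          = 256            # bottom vocabulary size
VBOT_PLUS     = VBOT + 1       # 0: BOS, 1..256 map to real tokens 0..255
VTOP          = 256
Hb            = 128            # bottom hidden-feature channels

# Batch size 1 (combined with mixed precision, this usually runs stably)
BATCH_SIZE    = 1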
VAL_RATIO     = 0.10

# -------------------------
# Dataset: ((bot_in, top_in, cond_mask), target_bot)
# -------------------------
def _load_npz(path_b):
    p = path_b.numpy().decode("utf-8")
    d = np.load(p, allow_pickle=False)
    top = d["top32_only"].astype(np.int32)        # [32,32]
    bot = d["bot64"].astype(np.int32)             # [64,64]
    return top, bot

def parse_for_prior(path):
    top, bot = tf.py_function(_load_npz, [path], [tf.int32, tf.int32])
    top.set_shape((TOP_H, TOP_W))
    bot.set_shape((BOT_H, BOT_W))
    cond_mask = tf.ones((), dtype=tf.float32)     # fully conditioned (1.0) during training; can be toggled for CFG at sampling time
    return ((bot, top, cond_mask), bot)

AUTO = tf.data.AUTOTUNE
def make_ds(files, shuffle=False):
    ds = tf.data.Dataset.from_tensor_slices(files)
    if shuffle:
        ds = ds.shuffle(min(20000, len(files)), reshuffle_each_iteration=True)
    ds = ds.map(parse_for_prior, num_parallel_calls=AUTO)
    ds = ds.batch(BATCH_SIZE, drop_remainder=False).prefetch(AUTO)
    return ds

n_total   = len(NPZ_LIST)
n_val     = max(1, int(n_total * VAL_RATIO))
VAL_FILES = NPZ_LIST[:n_val]
TRN_FILES = NPZ_LIST[n_val:]

train_ds = make_ds(TRN_FILES, shuffle=True)
val_ds   = make_ds(VAL_FILES, shuffle=False)
print(f"Train: {len(TRN_FILES)} | Val: {len(VAL_FILES)}")

# -------------------------
# Strong Adapter (reuses your already-trained adapter; frozen)
# -------------------------
D_EMB  = 256   # only needs to match the prior's d_model
HEADS  = 8

def GN(groups=16):
    # Use GroupNorm when available; otherwise fall back to LayerNorm
    try:
        return layers.GroupNormalization(groups=groups)
    except Exception:
        return layers.LayerNormalization(axis=-1)

def LambdaReshape_HW_to_seq(H, W):
    return layers.Lambda(lambda t: tf.reshape(t, [tf.shape(t)[0], H*W, tf.shape(t)[-1]]))

def LambdaReshape_seq_to_HW(H, W):
    return layers.Lambda(lambda t: tf.reshape(t, [tf.shape(t)[0], H, W, tf.shape(t)[-1]]))

def PixelShuffleLayer(r=2):
    return layers.Lambda(lambda t: tf.nn.depth_to_space(t, r))

def ResConv(x, ch, gn_groups=16):
    s = x
    x = layers.Conv2D(ch, 3, padding='same')(x); x = GN(gn_groups)(x); x = layers.Activation('gelu')(x)
    x = layers.Conv2D(ch, 3, padding='same')(x); x = GN(gn_groups)(x)
    return layers.Activation('gelu')(layers.Add()([x, s]))

def MHSA_2d_32(x, heads=HEADS, D=D_EMB):
    to_seq = LambdaReshape_HW_to_seq(32, 32); to_hw = LambdaReshape_seq_to_HW(32, 32)
    seq = to_seq(x)
    att = layers.MultiHeadAttention(num_heads=heads, key_dim=D//heads, output_shape=D)
    y = att(seq, seq, use_causal_mask=False)
    y = layers.LayerNormalization()(layers.Add()([y, seq]))
    y2= layers.Dense(4*D, activation='gelu')(y); y2 = layers.Dense(D)(y2)
    y = layers.LayerNormalization()(layers.Add()([y2, y]))
    return to_hw(y)

def CrossAttn_2d_64x32(q_64, kv_32, heads=HEADS, D_kv=D_EMB, C_q=None):
    if C_q is None: C_q = q_64.shape[-1]
    to_seq64 = LambdaReshape_HW_to_seq(64, 64); to_hw64 = LambdaReshape_seq_to_HW(64, 64)
    q = to_seq64(q_64)
    att = layers.MultiHeadAttention(num_heads=heads, key_dim=D_kv//heads, output_shape=int(C_q))
    y = att(q, kv_32)
    y = layers.LayerNormalization()(layers.Add()([y, q]))
    y2= layers.Dense(int(4*C_q), activation='gelu')(y); y2 = layers.Dense(int(C_q))(y2)
    y = layers.LayerNormalization()(layers.Add()([y2, y]))
    return to_hw64(y)

def build_strong_adapter(vocab_top=VTOP, D=D_EMB, Hb=Hb, heads=HEADS):
    t_idx = keras.Input((32,32), dtype='int32', name='toponly_idx_32x32')
    F32 = layers.Embedding(vocab_top, D, name='emb_toponly')(t_idx)
    x32 = ResConv(F32, D); x32 = ResConv(x32, D); x32 = MHSA_2d_32(x32, heads=heads, D=D)
    kv_seq = LambdaReshape_HW_to_seq(32, 32)(x32)
    pos_ids = tf.constant([[i for i in range(32*32)]], dtype=tf.int32)
    pos_emb = layers.Embedding(32*32, D, name='pos_kv_emb')(pos_ids)
    kv_seq  = layers.Add()([kv_seq, pos_emb])
    up64 = layers.Conv2D(D//2*4, 3, padding='same')(x32); up64 = PixelShuffleLayer(r=2)(up64)
    up64 = ResConv(up64, D//2); up64 = ResConv(up64, D//2)
    up64 = CrossAttn_2d_64x32(up64, kv_seq, heads=heads, D_kv=D, C_q=D//2)
    y = layers.Conv2D(D//2, 3, padding='same', dilation_rate=2, activation='gelu')(up64)
    y = layers.Conv2D(D//2, 3, padding='same', dilation_rate=3, activation='gelu')(y)
    y = layers.LayerNormalization(axis=-1)(y)
    out = layers.Conv2D(Hb, 1, padding='same', name='t_up_hat')(y)  # [B,64,64,Hb]
    return keras.Model(t_idx, out, name='AdapterStrong')

adapter = build_strong_adapter()
_ = adapter(tf.zeros([1,32,32], dtype=tf.int32))
adapter.load_weights(ADAPTER_WEIGHTS)
adapter.trainable = False
print("✅ Adapter loaded & frozen:", ADAPTER_WEIGHTS)

# -------------------------
# Prior: 2D-aware + condition injection + CFG support (low-memory version)
# Output: [B,64,64,256] logits, kept in float32 for a numerically stable cross-entropy
# -------------------------
def build_prior_transformer(
    d_model=256,            # ★ 512->256
    n_heads=8,
    n_layers=6,             # ★ 10->6
    ff_mult=4, dropout=0.1, cond_dropout=0.1
):
    # ---- Inputs ----
    bot_gt = keras.Input(shape=(BOT_H, BOT_W), dtype='int32', name='bot_gt_64x64')
    top_in = keras.Input(shape=(TOP_H, TOP_W), dtype='int32', name='top_only_32x32')
    cond_m = keras.Input(shape=(), dtype='float32', name='cond_mask')  # 1.0 during training; switch between 0/1 for CFG

    # ---- teacher forcing: offset real tokens by +1 (1..256), then shift right with BOS=0 ----
    tgt = layers.Reshape((TGT_LEN,), name="flat_bot")(bot_gt)  # [B,T]
    def _shift_right_plus1(t):
        bos = tf.zeros_like(t[:, :1])                    # BOS id 0, matching the VBOT_PLUS layout
        return tf.concat([bos, t[:, :-1] + 1], axis=1)   # BOS stays distinct from real token 0
    tgt_in = layers.Lambda(_shift_right_plus1, name="shift_right_plus1")(tgt)  # [B,T] ∈ [0..256]

    # ---- bottom token embedding ----
    x = layers.Embedding(VBOT_PLUS, d_model, name='emb_bot')(tgt_in)  # [B,T,D]

    # ---- 2D + absolute positions (constant indices, avoiding symbolic-tensor issues from in-graph tf.* ops) ----
    row_idx_np = np.repeat(np.arange(BOT_H)[:, None], BOT_W, axis=1).reshape(1, TGT_LEN).astype('int32')
    col_idx_np = np.repeat(np.arange(BOT_W)[None, :], BOT_H, axis=0).reshape(1, TGT_LEN).astype('int32')
    abs_idx_np = np.arange(TGT_LEN)[None, :].astype('int32')

    row_pe = layers.Embedding(BOT_H, d_model, name='row_pe')(tf.constant(row_idx_np))
    col_pe = layers.Embedding(BOT_W, d_model, name='col_pe')(tf.constant(col_idx_np))
    abs_pe = layers.Embedding(TGT_LEN, d_model, name='abs_pe')(tf.constant(abs_idx_np))
    x = layers.Add(name="add_pos_pe")([x, row_pe, col_pe, abs_pe])  # [B,T,D]

    # ---- Adapter conditioning ----
    t_up = adapter(top_in)                         # [B,64,64,Hb]
    t_up = layers.Dropout(cond_dropout)(t_up)      # train only
    tproj = layers.Conv2D(d_model, 1, padding='same', name='tup_proj')(t_up)  # [B,64,64,D]
    tproj = layers.Reshape((TGT_LEN, d_model))(tproj)                          # [B,T,D]
    cond_scale = layers.Lambda(lambda c: tf.reshape(c, (-1, 1, 1)), name="cond_scale")(cond_m)
    tproj = layers.Multiply(name="tproj_mask")([tproj, cond_scale])
    x = layers.Add(name="add_tproj")([x, tproj])   # [B,T,D]

    # ---- Top-only KV memory ----
    top_seq = layers.Reshape((TOP_H*TOP_W,), name="flat_top")(top_in)       # [B,1024]
    mem = layers.Embedding(VTOP, d_model, name='emb_top')(top_seq)          # [B,1024,D]
    top_pos_idx_np = np.arange(TOP_H*TOP_W)[None, :].astype('int32')
    mem_pos = layers.Embedding(TOP_H*TOP_W, d_model, name='pos_top')(tf.constant(top_pos_idx_np))
    mem = layers.Add(name="add_top_pos")([mem, mem_pos])                    # [B,1024,D]
    mem = layers.Dropout(cond_dropout)(mem)                                 # train only
    cond_scale_mem = layers.Lambda(lambda c: tf.reshape(c, (-1, 1, 1)), name="cond_scale_mem")(cond_m)
    mem = layers.Multiply(name="mem_mask")([mem, cond_scale_mem])

    # ---- Transformer decoder blocks (Pre-LN) ----
    for i in range(n_layers):
        ln1 = layers.LayerNormalization(name=f'ln1_{i}')(x)
        sa  = layers.MultiHeadAttention(num_heads=n_heads, key_dim=d_model//n_heads,
                                        dropout=dropout, name=f'self_attn_{i}')
        x_sa = sa(ln1, ln1, use_causal_mask=True)
        x = layers.Add(name=f'resid_sa_{i}')([x, x_sa])

        ln2 = layers.LayerNormalization(name=f'ln2_{i}')(x)
        ca  = layers.MultiHeadAttention(num_heads=n_heads, key_dim=d_model//n_heads,
                                        dropout=dropout, name=f'cross_attn_{i}')
        x_ca = ca(ln2, mem)
        x = layers.Add(name=f'resid_ca_{i}')([x, x_ca])

        ln3 = layers.LayerNormalization(name=f'ln3_{i}')(x)
        y = layers.Dense(ff_mult*d_model, activation='gelu', name=f'ffn_{i}_fc1')(ln3)
        y = layers.Dropout(dropout, name=f'ffn_{i}_drop')(y)
        y = layers.Dense(d_model, name=f'ffn_{i}_fc2')(y)
        x = layers.Add(name=f'resid_ffn_{i}')([x, y])

    x = layers.LayerNormalization(name='ln_out')(x)
    # ★ logits are explicitly float32 to avoid an unstable cross-entropy under half precision
    logits_seq = layers.Dense(VBOT, name='logits', dtype='float32')(x)                 # [B,4096,256] fp32
    logits_4d  = layers.Reshape((BOT_H, BOT_W, VBOT), name='logits_4d')(logits_seq)    # [B,64,64,256]
    return keras.Model(inputs=[bot_gt, top_in, cond_m], outputs=logits_4d, name='PriorTransformerFull')

prior = build_prior_transformer(
    d_model=256, n_heads=8, n_layers=6, ff_mult=4, dropout=0.1, cond_dropout=0.1
)
prior.summary(line_length=140)

# -------------------------
# Training setup: AdamW + warmup + cosine / gradient clipping / metrics
# -------------------------
class WarmupCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, base_lr, warmup_steps, total_steps, name=None):
        super().__init__()
        self.base_lr = float(base_lr)
        self.warmup_steps = int(warmup_steps)
        self.total_steps  = int(total_steps)
        self.name = name or "WarmupCosine"
    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        wu = tf.cast(self.warmup_steps, tf.float32)
        ts = tf.cast(tf.maximum(1, self.total_steps), tf.float32)
        base = tf.cast(self.base_lr, tf.float32)
        lr_wu = base * tf.minimum(1.0, step / tf.maximum(1.0, wu))
        progress = tf.clip_by_value((step - wu) / tf.maximum(1.0, ts - wu), 0.0, 1.0)
        lr_cos = 0.5 * base * (1.0 + tf.cos(tf.constant(math.pi, tf.float32) * progress))
        return tf.where(step < wu, lr_wu, lr_cos)
    def get_config(self):
        return {"base_lr": self.base_lr, "warmup_steps": self.warmup_steps,
                "total_steps": self.total_steps, "name": self.name}

steps_per_epoch = max(1, math.ceil(len(TRN_FILES) / BATCH_SIZE))
total_steps     = steps_per_epoch * 30
sched = WarmupCosine(base_lr=2e-4, warmup_steps=min(2000, steps_per_epoch*3), total_steps=total_steps)

try:
    opt = keras.optimizers.AdamW(learning_rate=sched, weight_decay=1e-4, clipnorm=1.0)
except Exception:
    opt = keras.optimizers.Adam(learning_rate=sched, clipnorm=1.0)

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
acc_fn  = keras.metrics.SparseCategoricalAccuracy(name="acc")
def ppl_metric(y_true, y_pred):
    ce = keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)
    return tf.exp(tf.reduce_mean(ce))

prior.compile(optimizer=opt, loss=loss_fn, metrics=[acc_fn, ppl_metric])

cbs = [
    keras.callbacks.ModelCheckpoint(
        os.path.join(PRIOR_OUT, "prior_best.weights.h5"),
        save_weights_only=True, save_best_only=True,
        monitor="val_loss", mode="min", verbose=1
    ),
    keras.callbacks.CSVLogger(os.path.join(PRIOR_OUT, "prior_log.csv")),
    keras.callbacks.TensorBoard(log_dir=os.path.join(PRIOR_OUT, "tb")),
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=6, restore_best_weights=True),
]

# -------------------------
# Start training
# -------------------------
hist = prior.fit(train_ds, validation_data=val_ds, epochs=30, callbacks=cbs, verbose=1)
prior.save_weights(os.path.join(PRIOR_OUT, "prior_final.weights.h5"))
print("✅ Prior saved to:", PRIOR_OUT)

# =========================
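# (Optional) gradient-accumulation variant: enable it for a larger model or a larger effective batch
# Usage:
#    1) Uncomment the class below and the replacement lines after it
#    2) Wrap prior in AccumulateModel, then compile & fit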
# =========================
# class AccumulateModel(keras.Model):
#     def __init__(self, acc_steps=4, **kwargs):
#         super().__init__(**kwargs)
#         self.acc_steps = acc_steps
#         self._grad_accum = None
#         self._step = 0
#     def train_step(self, data):
#         (bot_in, top_in, cond_m), y = data
#         with tf.GradientTape() as tape:
#             y_pred = self([bot_in, top_in, cond_m], training=True)
#             loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
#         grads = tape.gradient(loss, self.trainable_variables)
#         if self._grad_accum is None:
#             self._grad_accum = [tf.zeros_like(g) for g in grads]
#         self._grad_accum = [ga + g for ga, g in zip(self._grad_accum, grads)]
#         self._step += 1
#         if self._step % self.acc_steps == 0:
#             self.optimizer.apply_gradients(zip(self._grad_accum, self.trainable_variables))
#             self._grad_accum = None
#         self.compiled_metrics.update_state(y, y_pred)
#         logs = {m.name: m.result() for m in self.metrics}
#         logs.update({"loss": loss})
#         return logs

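# # To use gradient accumulation (train_step keeps its counters as plain Python
# # attributes, so run it eagerly; under tf.function they would only update at trace time):
# prior = AccumulateModel(acc_steps=4, inputs=prior.inputs, outputs=prior.outputs)
# prior.compile(optimizer=opt, loss=loss_fn, metrics=[acc_fn, ppl_metric], run_eagerly=True)
# hist = prior.fit(train_ds, validation_data=val_ds, epochs=30, callbacks=cbs, verbose=1)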
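
The idea behind the optional gradient-accumulation variant can be checked outside Keras. The following is a minimal NumPy sketch, not the notebook's implementation: for a loss that is a mean over examples, averaging the gradients of `acc_steps` equal-sized micro-batches reproduces the gradient of one large batch. The linear least-squares model, the helper name `mse_grad`, and the toy data are all hypothetical.

```python
import numpy as np

def mse_grad(w, x, y):
    """Gradient of mean((x @ w - y)**2) with respect to w."""
    residual = x @ w - y
    return 2.0 * x.T @ residual / len(x)

rng = np.random.default_rng(0)
x = rng.normal(size=(32, 5))      # 32 examples, 5 features (toy data)
y = rng.normal(size=32)
w = rng.normal(size=5)

acc_steps = 4                     # micro-batches per optimizer step
full_grad = mse_grad(w, x, y)     # gradient of one batch of 32

# Accumulate over 4 micro-batches of 8, dividing each by acc_steps so the
# running sum ends up as the mean of the micro-batch gradients.
grad_accum = np.zeros_like(w)
for xb, yb in zip(np.split(x, acc_steps), np.split(y, acc_steps)):
    grad_accum += mse_grad(w, xb, yb) / acc_steps

print(np.allclose(grad_accum, full_grad))
```

Note that the commented `AccumulateModel` class sums the micro-batch gradients without dividing by `acc_steps`, so the applied gradient is `acc_steps` times the averaged one; whether that matters depends on the optimizer and learning rate in use.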


## Bonus - Train an Adapter

For image generation, you can also train an Adapter to speed up decoding. The Adapter below maps the 32×32 top-only token indices directly to a 64×64×`Hb` feature map.

In [None]:
# =============================
# Strong Adapter (32x32 indices -> 64x64xHb) — Keras-only
# =============================
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

Hb   = 128   # bottom-level hidden channels of the two-level model
VTOP = 256   # top-only vocabulary size
D    = 256   # token embedding dimension
HEADS = 8

# ---- small helpers (Keras-only) ----
def GN(groups=16):
    # Compatibility: fall back to LayerNormalization if GroupNormalization is unavailable
    try:
        return layers.GroupNormalization(groups=groups)
    except Exception:
        return layers.LayerNormalization(axis=-1)

def LambdaReshape_HW_to_seq(H, W):
    return layers.Lambda(lambda t: tf.reshape(t, [tf.shape(t)[0], H*W, tf.shape(t)[-1]]))

def LambdaReshape_seq_to_HW(H, W):
    return layers.Lambda(lambda t: tf.reshape(t, [tf.shape(t)[0], H, W, tf.shape(t)[-1]]))

def PixelShuffleLayer(r=2):
    return layers.Lambda(lambda t: tf.nn.depth_to_space(t, r))

def ResConv(x, ch, gn_groups=16):
    skip = x
    x = layers.Conv2D(ch, 3, padding='same')(x)
    x = GN(gn_groups)(x)
    x = layers.Activation('gelu')(x)
    x = layers.Conv2D(ch, 3, padding='same')(x)
    x = GN(gn_groups)(x)
    x = layers.Activation('gelu')(layers.Add()([x, skip]))
    return x

def MHSA_2d_32(x, heads=HEADS, D=D):
    # x: [B,32,32,D]  -> attn -> [B,32,32,D]
    to_seq = LambdaReshape_HW_to_seq(32, 32)
    to_hw  = LambdaReshape_seq_to_HW(32, 32)
    seq = to_seq(x)
    att = layers.MultiHeadAttention(num_heads=heads, key_dim=D//heads, output_shape=D)
    y = att(seq, seq, use_causal_mask=False)
    y = layers.LayerNormalization()(layers.Add()([y, seq]))
    y2 = layers.Dense(4*D, activation='gelu')(y)
    y2 = layers.Dense(D)(y2)
    y  = layers.LayerNormalization()(layers.Add()([y2, y]))
    return to_hw(y)

def CrossAttn_2d_64x32(q_64, kv_32, heads=HEADS, D_kv=D, C_q=None):
    # q_64: [B,64,64,Cq]   kv_32: [B,32*32,D]
    # The output has the same shape as q_64
    if C_q is None:
        C_q = q_64.shape[-1]  # D//2 in this model
    to_seq64 = LambdaReshape_HW_to_seq(64, 64)
    to_hw64  = LambdaReshape_seq_to_HW(64, 64)
    q = to_seq64(q_64)                              # [B,4096,Cq]
    att = layers.MultiHeadAttention(num_heads=heads, key_dim=D_kv//heads, output_shape=int(C_q))
    y = att(q, kv_32)                               # [B,4096,Cq]
    y = layers.LayerNormalization()(layers.Add()([y, q]))
    y2= layers.Dense(int(4*C_q), activation='gelu')(y)
    y2= layers.Dense(int(C_q))(y2)
    y = layers.LayerNormalization()(layers.Add()([y2, y]))
    return to_hw64(y)

def build_strong_adapter(vocab_top=VTOP, D=D, Hb=Hb, heads=HEADS):
    # Input: 32x32 top-only indices
    t_idx = keras.Input((32,32), dtype='int32', name='toponly_idx_32x32')

    # 32×32 token embedding + two residual blocks + self-attention
    F32 = layers.Embedding(vocab_top, D, name='emb_toponly')(t_idx)  # [B,32,32,D]
    x32 = ResConv(F32, D); x32 = ResConv(x32, D)
    x32 = MHSA_2d_32(x32, heads=heads, D=D)

    # Sequence used as K/V (with absolute position embeddings added)
    kv_seq = LambdaReshape_HW_to_seq(32, 32)(x32)                    # [B,1024,D]
    pos_ids = tf.constant([[i for i in range(32*32)]], dtype=tf.int32)  # a constant is fine here
    pos_emb = layers.Embedding(32*32, D, name='pos_kv_emb')(pos_ids)    # [1,1024,D]
    kv_seq  = layers.Add()([kv_seq, pos_emb])                          # broadcasts over the batch dimension

    # Upsample to 64×64 (pixel shuffle), halving the channels -> Cq = D//2
    up64 = layers.Conv2D(D//2*4, 3, padding='same')(x32)  # pre-convolution, then PixelShuffle
    up64 = PixelShuffleLayer(r=2)(up64)                   # [B,64,64,D//2]
    up64 = ResConv(up64, D//2); up64 = ResConv(up64, D//2)

    # Cross-attention: Q=64×64×(D//2), K/V=kv_seq(32×32×D)
    up64 = CrossAttn_2d_64x32(up64, kv_seq, heads=heads, D_kv=D, C_q=D//2)

    # Dilated-convolution refinement + LN + 1×1 projection to Hb
    y = layers.Conv2D(D//2, 3, padding='same', dilation_rate=2, activation='gelu')(up64)
    y = layers.Conv2D(D//2, 3, padding='same', dilation_rate=3, activation='gelu')(y)
    y = layers.LayerNormalization(axis=-1)(y)
    out = layers.Conv2D(Hb, 1, padding='same', name='t_up_hat')(y)     # [B,64,64,Hb]

    return keras.Model(t_idx, out, name='AdapterStrong')

# --- build & check ---
adapter = build_strong_adapter()
_ = adapter(tf.zeros([1,32,32], dtype=tf.int32))  # build
adapter.summary()
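
The upsampling step in `build_strong_adapter` relies on `tf.nn.depth_to_space` to trade channels for spatial resolution. As a rough illustration of that rearrangement, here is a NumPy sketch for NHWC tensors. The helper name `pixel_shuffle_nhwc` is hypothetical, and the code is meant to mirror the default NHWC behavior of `tf.nn.depth_to_space`, not to replace it.

```python
import numpy as np

def pixel_shuffle_nhwc(x, r=2):
    """Rearrange [B, H, W, C*r*r] into [B, H*r, W*r, C]."""
    b, h, w, c_in = x.shape
    assert c_in % (r * r) == 0, "channels must be divisible by r*r"
    c = c_in // (r * r)
    x = x.reshape(b, h, w, r, r, c)        # split channels into (row offset, col offset, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)      # -> [B, H, r, W, r, C]
    return x.reshape(b, h * r, w * r, c)

# Same shapes as the adapter: a 32x32 map with (D//2)*4 channels -> 64x64 with D//2.
D = 256
feat32 = np.zeros((1, 32, 32, D // 2 * 4), dtype=np.float32)
up_shape = pixel_shuffle_nhwc(feat32).shape
print(up_shape)                            # (1, 64, 64, 128)

# Tiny case: one pixel with 4 channels becomes a 2x2 block of 1-channel pixels.
tiny = np.arange(4).reshape(1, 1, 1, 4)
tiny_out = pixel_shuffle_nhwc(tiny)[0, :, :, 0].tolist()
print(tiny_out)                            # [[0, 1], [2, 3]]
```

The convolution before the shuffle produces `D//2*4` channels precisely so that, after a factor-2 shuffle, `D//2` channels remain at 64×64.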


In [None]:
# ============================================
# Adapter training (end-to-end runnable version)
# - Reads tok_ext_*.npz  -> trains the Strong Adapter
# - Checkpoint filenames use the *.weights.h5 suffix
# ============================================
import os, json, numpy as np, tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# ---- Paths & constants (edit as needed) ----
EXT_DIR   = "/content/drive/MyDrive/vqvae_stage2_tokens"   # holds tok_ext_*.npz
INDEX_EXT = os.path.join(EXT_DIR, "index_ext.json")
assert os.path.exists(INDEX_EXT), "index_ext.json is missing (run the Stage-2 Prep++ export script first)"

with open(INDEX_EXT, "r") as f:
    npz_list = json.load(f)
print("Total samples:", len(npz_list))

Hb    = 128   # bottom-level hidden channels of the two-level model
VTOP  = 256   # top-only vocabulary size
D     = 256   # token embedding dimension
HEADS = 8

# -------------------------------
# Dataset: reads tok_ext_*.npz
#   Each file contains:
#     - top32_only : [32,32] int
#     - t_up_gt    : [64,64,Hb] float
# -------------------------------
def _load_npz(path_bytes):
    path = path_bytes.decode("utf-8")
    # Use a context manager so the .npz file handle is closed after reading.
    with np.load(path, allow_pickle=False) as d:
        t_idx = d["top32_only"].astype(np.int32)   # [32,32]
        t_up  = d["t_up_gt"].astype(np.float32)    # [64,64,Hb]
    return t_idx, t_up

def parse_fn(path):
    t_idx, t_up = tf.numpy_function(_load_npz, [path], [tf.int32, tf.float32])
    t_idx.set_shape((32, 32))
    t_up.set_shape((64, 64, Hb))
    # Return a dict so it can later be mapped to (x, y)
    return {"top_only": t_idx, "t_up": t_up}

AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 8

def make_ds(file_list, shuffle=False):
    ds = tf.data.Dataset.from_tensor_slices(file_list)
    if shuffle:
        ds = ds.shuffle(buffer_size=min(20000, len(file_list)), reshuffle_each_iteration=True)
    ds = ds.map(parse_fn, num_parallel_calls=AUTO)
    ds = ds.batch(BATCH_SIZE, drop_remainder=False).prefetch(AUTO)
    return ds

# Split into train/val
VAL_RATIO = 0.10
n_total   = len(npz_list)
n_val     = max(1, int(n_total * VAL_RATIO))
val_files   = npz_list[:n_val]
train_files = npz_list[n_val:]
print(f"Train: {len(train_files)} | Val: {len(val_files)}")

train_ds = make_ds(train_files, shuffle=True)
val_ds   = make_ds(val_files, shuffle=False)

# -------------------------------
# Strong Adapter (Keras-only, identical to the definition in the previous cell)
# -------------------------------
def GN(groups=16):
    try:
        return layers.GroupNormalization(groups=groups)
    except Exception:
        return layers.LayerNormalization(axis=-1)

def LambdaReshape_HW_to_seq(H, W):
    return layers.Lambda(lambda t: tf.reshape(t, [tf.shape(t)[0], H*W, tf.shape(t)[-1]]))

def LambdaReshape_seq_to_HW(H, W):
    return layers.Lambda(lambda t: tf.reshape(t, [tf.shape(t)[0], H, W, tf.shape(t)[-1]]))

def PixelShuffleLayer(r=2):
    return layers.Lambda(lambda t: tf.nn.depth_to_space(t, r))

def ResConv(x, ch, gn_groups=16):
    skip = x
    x = layers.Conv2D(ch, 3, padding='same')(x)
    x = GN(gn_groups)(x)
    x = layers.Activation('gelu')(x)
    x = layers.Conv2D(ch, 3, padding='same')(x)
    x = GN(gn_groups)(x)
    x = layers.Activation('gelu')(layers.Add()([x, skip]))
    return x

def MHSA_2d_32(x, heads=HEADS, D=D):
    to_seq = LambdaReshape_HW_to_seq(32, 32)
    to_hw  = LambdaReshape_seq_to_HW(32, 32)
    seq = to_seq(x)
    att = layers.MultiHeadAttention(num_heads=heads, key_dim=D//heads, output_shape=D)
    y = att(seq, seq, use_causal_mask=False)
    y = layers.LayerNormalization()(layers.Add()([y, seq]))
    y2 = layers.Dense(4*D, activation='gelu')(y)
    y2 = layers.Dense(D)(y2)
    y  = layers.LayerNormalization()(layers.Add()([y2, y]))
    return to_hw(y)

def CrossAttn_2d_64x32(q_64, kv_32, heads=HEADS, D_kv=D, C_q=None):
    if C_q is None:
        C_q = q_64.shape[-1]
    to_seq64 = LambdaReshape_HW_to_seq(64, 64)
    to_hw64  = LambdaReshape_seq_to_HW(64, 64)
    q = to_seq64(q_64)                              # [B,4096,Cq]
    att = layers.MultiHeadAttention(num_heads=heads, key_dim=D_kv//heads, output_shape=int(C_q))
    y = att(q, kv_32)                               # [B,4096,Cq]
    y = layers.LayerNormalization()(layers.Add()([y, q]))
    y2= layers.Dense(int(4*C_q), activation='gelu')(y)
    y2= layers.Dense(int(C_q))(y2)
    y = layers.LayerNormalization()(layers.Add()([y2, y]))
    return to_hw64(y)

def build_strong_adapter(vocab_top=VTOP, D=D, Hb=Hb, heads=HEADS):
    t_idx = keras.Input((32,32), dtype='int32', name='toponly_idx_32x32')

    F32 = layers.Embedding(vocab_top, D, name='emb_toponly')(t_idx)  # [B,32,32,D]
    x32 = ResConv(F32, D); x32 = ResConv(x32, D)
    x32 = MHSA_2d_32(x32, heads=heads, D=D)

    # K/V sequence + absolute positions
    kv_seq = LambdaReshape_HW_to_seq(32, 32)(x32)                     # [B,1024,D]
    pos_ids = tf.constant([[i for i in range(32*32)]], dtype=tf.int32)
    pos_emb = layers.Embedding(32*32, D, name='pos_kv_emb')(pos_ids)  # [1,1024,D]
    kv_seq  = layers.Add()([kv_seq, pos_emb])                         # broadcasts over the batch dimension

    # Upsample to 64×64
    up64 = layers.Conv2D(D//2*4, 3, padding='same')(x32)
    up64 = PixelShuffleLayer(r=2)(up64)                               # [B,64,64,D//2]
    up64 = ResConv(up64, D//2); up64 = ResConv(up64, D//2)
    up64 = CrossAttn_2d_64x32(up64, kv_seq, heads=heads, D_kv=D, C_q=D//2)

    y = layers.Conv2D(D//2, 3, padding='same', dilation_rate=2, activation='gelu')(up64)
    y = layers.Conv2D(D//2, 3, padding='same', dilation_rate=3, activation='gelu')(y)
    y = layers.LayerNormalization(axis=-1)(y)
    out = layers.Conv2D(Hb, 1, padding='same', name='t_up_hat')(y)    # [B,64,64,Hb]

    return keras.Model(t_idx, out, name='AdapterStrong')

adapter = build_strong_adapter()
_ = adapter(tf.zeros([1,32,32], dtype=tf.int32))  # build

# -------------------------------
# Loss & training
# -------------------------------
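def cosine_loss(a, b, eps=1e-6):
    # Cosine distance over the channel axis: 1 - <a, b> for L2-normalized a, b.
    a = tf.nn.l2_normalize(a, axis=-1, epsilon=eps)
    b = tf.nn.l2_normalize(b, axis=-1, epsilon=eps)
    # The dot product is a sum over channels (a mean would shrink it by 1/Hb).
    return tf.reduce_mean(1.0 - tf.reduce_sum(a*b, axis=-1))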

def adapter_loss(y_true, y_pred):
    mse = tf.reduce_mean(tf.square(y_true - y_pred))
    cos = cosine_loss(y_true, y_pred)
    return mse + 0.1 * cos

try:
    opt = tf.keras.optimizers.AdamW(3e-4, weight_decay=1e-4)
except Exception:
    opt = tf.keras.optimizers.Adam(3e-4)

adapter.compile(optimizer=opt, loss=adapter_loss,
                metrics=[keras.metrics.MeanSquaredError(name="mse")])

ADAPTER_OUT = "/content/drive/MyDrive/vqvae_stage2_adapter"
os.makedirs(ADAPTER_OUT, exist_ok=True)
cbs = [
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(ADAPTER_OUT, "adapter_best.weights.h5"),  # suffix must be .weights.h5
        save_weights_only=True, save_best_only=True,
        monitor="val_loss", mode="min", verbose=1),
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=6, restore_best_weights=True),
    tf.keras.callbacks.CSVLogger(os.path.join(ADAPTER_OUT, "adapter_log.csv")),
    tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=3, verbose=1),
]

# Map each batch dict to an (x, y) tuple
def map_xy(batch):
    return batch["top_only"], batch["t_up"]

EPOCHS = 30
hist = adapter.fit(
    train_ds.map(map_xy),
    validation_data=val_ds.map(map_xy),
    epochs=EPOCHS,
    callbacks=cbs,
    verbose=1
)

print("✅ Training complete; best weights:", os.path.join(ADAPTER_OUT, "adapter_best.weights.h5"))
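
The adapter loss combines a mean-squared error with a cosine term. As a sanity check on the cosine part, here is a small NumPy sketch of a per-position cosine distance over the channel axis. The helper name `cosine_loss_np` and the toy inputs are hypothetical. The similarity is the sum of the element-wise product of two L2-normalized vectors, so features pointing the same way should give a loss near 0, orthogonal features 1, and opposite features near 2.

```python
import numpy as np

def cosine_loss_np(a, b, eps=1e-6):
    """Mean cosine distance (1 - cosine similarity) over the last axis."""
    a = a / np.maximum(np.linalg.norm(a, axis=-1, keepdims=True), eps)
    b = b / np.maximum(np.linalg.norm(b, axis=-1, keepdims=True), eps)
    cos_sim = np.sum(a * b, axis=-1)       # dot product of unit vectors
    return float(np.mean(1.0 - cos_sim))

rng = np.random.default_rng(0)
feat = rng.normal(size=(2, 4, 4, 8))        # toy [B, H, W, C] feature map

same = cosine_loss_np(feat, 3.0 * feat)     # same direction, different scale
ortho = cosine_loss_np(np.array([[1.0, 0.0]]), np.array([[0.0, 1.0]]))
opposite = cosine_loss_np(feat, -feat)      # opposite direction

print(same, ortho, opposite)
```

Because the cosine term ignores magnitude, it complements the MSE term, which penalizes differences in scale as well as direction.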
