# THINGS × VDVAE – Quick Data Checks

Use this notebook to sanity-check that your **predicted VDVAE latents** and **ref_latents** exist, load correctly, and match expected shapes.

## What this checks
- Confirms existence of files (predicted latents `.npy`, `ref_latents.npz`).
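- Prints the shape and dtype of the predicted latents, plus the number of reference levels and the shape of the first level's `z`.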
- Verifies that the **flattened latent dimension** equals the sum of expected per-level sizes (`layer_dims`).
- Optionally lists the first/last few entries of your `test_image_paths.txt` to confirm index↔image ordering.

💡 **Instructions:**
1. Edit the parameters in the first code cell (subject, alpha, base directories) to match your setup.
2. Run cells top-to-bottom.


In [None]:
# --- Parameters (EDIT ME) ---
# SUBJ is formatted with `:02d` in every path below, so 1 becomes "subj01".
SUBJ = 1            # subject number (unpadded)
# ALPHA is embedded in the prediction file name (see PRED_NPY), so it selects
# which predicted-latents file gets loaded.
ALPHA = 50000       # ridge alpha used for predictions (e.g., 50000)

# Base directories (EDIT if your layout differs)
# BASE is the root that every per-subject directory below is built from.
BASE = "/home/rothermm/THINGS"
# Directory that test_image_paths.txt is read from.
PREPROC_DIR = f"{BASE}/02_data/preprocessed_data/subj{SUBJ:02d}"
# Directory that the predicted-latents .npy file is read from.
PRED_DIR    = f"{BASE}/02_data/predicted_features/subj{SUBJ:02d}"
# Directory that ref_latents.npz is read from.
FEAT_DIR    = f"{BASE}/02_data/extracted_features/subj{SUBJ:02d}"
# NOTE(review): RESULTS_DIR is not referenced by any cell visible here —
# confirm whether a later cell uses it.
RESULTS_DIR = f"{BASE}/03_results/vdvae/subj{SUBJ:02d}"

# File names
def alpha_tag(a):
    """Return the text tag used for the ridge alpha ``a`` in file names.

    Whole-valued alphas are rendered as a plain integer
    (``50000.0`` -> ``"50000"``). Any other value keeps its ``str`` form
    with every ``.`` replaced by ``p`` (``0.5`` -> ``"0p5"``).
    """
    if float(a).is_integer():
        return str(int(a))
    return str(a).replace('.', 'p')

# Predicted latents for the chosen subject/alpha, and the reference latents.
# NOTE(review): the file name uses "sub" while the directories use "subj" —
# confirm this matches the name written by the prediction script.
PRED_NPY = f"{PRED_DIR}/things_vdvae_pred_sub{SUBJ:02d}_31l_alpha{alpha_tag(ALPHA)}.npy"
REF_NPZ = f"{FEAT_DIR}/ref_latents.npz"

# Image list; only needed for the optional ordering inspection.
TEST_PATHS_TXT = f"{PREPROC_DIR}/test_image_paths.txt"

# Echo the resolved paths so a wrong SUBJ/ALPHA/BASE is visible immediately.
print("Using paths:")
for _label, _path in (
    ("  PRED_NPY:", PRED_NPY),
    ("  REF_NPZ :", REF_NPZ),
    ("  TEST_PATHS_TXT:", TEST_PATHS_TXT),
):
    print(_label, _path)


In [None]:
# --- Existence checks & basic loading ---
import os, numpy as np

# Probe both inputs up front so the report shows the status of each file,
# even when only one of them is missing.
pred_exists = os.path.exists(PRED_NPY)
ref_exists  = os.path.exists(REF_NPZ)
print("pred exists:", pred_exists)
print("ref  exists:", ref_exists)

# `pred` / `ref` stay None when their file is absent, so code that runs
# afterwards can test for None instead of hitting a NameError.
pred = None
if pred_exists:
    pred = np.load(PRED_NPY)
    print("pred shape:", pred.shape, "dtype:", pred.dtype)

ref = None
if ref_exists:
    # SECURITY: allow_pickle=True unpickles the stored objects, which can run
    # arbitrary code. Only load ref_latents.npz files you produced yourself.
    # np.load on an .npz returns an NpzFile that keeps the archive open; the
    # `with` block closes it once the array has been read into memory.
    with np.load(REF_NPZ, allow_pickle=True) as ref_npz:
        ref = ref_npz["ref_latent"]
    # presumably an object array with one dict per level (each entry is
    # indexed with ['z'] below) — TODO confirm against the extraction script
    print("ref levels:", len(ref))
    if len(ref) > 0:
        print("ref[0]['z'] shape:", ref[0]['z'].shape)


In [None]:
# --- Verify flattened latent dimension matches per-level sizes ---
import numpy as np

# Flattened size of each of the 31 latent levels, in stacking order; this
# must match the order used when the latents were extracted. Expressed as
# (size, run length) pairs so the count per size is explicit:
# 2 + 4 + 8 + 16 + 1 = 31 levels.
_level_sizes = np.array([2**4, 2**8, 2**10, 2**12, 2**14], dtype=np.int64)
_level_runs = [2, 4, 8, 16, 1]
layer_dims = np.repeat(_level_sizes, _level_runs)

if pred is None:
    print("Skip: pred not loaded")
else:
    # Width of one predicted row must equal the concatenated per-level sizes.
    total = int(layer_dims.sum())
    print("Sum(layer_dims)   =", total)
    print("pred latent width =", pred.shape[1])
    print(
        "OK: pred width matches sum(layer_dims)"
        if pred.shape[1] == total
        else "WARNING: width mismatch -> check your layer_dims vs extraction"
    )

    # Offsets at which each level starts inside a flat row (leading 0 added).
    offs = np.cumsum(np.r_[0, layer_dims])
    print("First 10 offsets:", offs[:10])


In [None]:
# --- Optional: inspect test_image_paths to confirm ordering (head/tail) ---
def head_tail(path, n=5):
    """Print the count, first ``n`` and last ``n`` non-blank lines of a file.

    Lines are stripped of surrounding whitespace and blank lines are ignored
    before counting. If the file does not exist, a "File not found" message
    is printed instead of raising.

    Args:
        path: Path of the text file to inspect.
        n: Number of lines to show at each end. Values of zero or below show
            no lines (only the count and the HEAD/TAIL headers).
    """
    # Only the read can raise FileNotFoundError, so keep the try body minimal.
    try:
        with open(path, 'r', encoding='utf-8') as f:
            lines = [ln.strip() for ln in f if ln.strip()]
    except FileNotFoundError:
        print("File not found:", path)
        return

    # Clamp: a negative n would turn the slices below into "all but n" views.
    n = max(n, 0)

    print(f"Total lines: {len(lines)}")
    print("\nHEAD:")
    for ln in lines[:n]:
        print("  ", ln)
    print("\nTAIL:")
    # lines[-0:] is the whole list, so n == 0 needs an explicit empty tail.
    tail = lines[-n:] if n > 0 else []
    for ln in tail:
        print("  ", ln)

# Show five entries from each end of the test image list.
head_tail(path=TEST_PATHS_TXT, n=5)


In [None]:
# --- Helper: reshape one sample's flat vector to per-level maps using ref shapes ---
def flat_to_pyramid_one(vec, ref_levels, layer_dims):
    """Split one flat latent vector into per-level ``(C, H, W)`` arrays.

    Args:
        vec: 1-D array holding all levels concatenated in order. Its length
            must equal ``sum(layer_dims)``.
        ref_levels: Indexable of per-level mappings. The shape of each
            entry's ``'z'`` array, minus its first axis, gives the target
            ``(C, H, W)`` for that level.
        layer_dims: Flattened size of each level, in stacking order.

    Returns:
        A list with one array per entry of ``layer_dims``, each reshaped to
        the corresponding reference ``(C, H, W)``.

    Raises:
        ValueError: If ``ref_levels`` has fewer entries than ``layer_dims``,
            if ``len(vec)`` differs from ``sum(layer_dims)``, or if a level's
            size does not match its reference shape (raised by ``reshape``).
    """
    n_levels = len(layer_dims)
    if len(ref_levels) < n_levels:
        raise ValueError(
            f"ref_levels has {len(ref_levels)} entries but layer_dims "
            f"describes {n_levels} levels"
        )

    # offs[i]:offs[i+1] is the slice of `vec` that belongs to level i.
    offs = np.cumsum(np.r_[0, layer_dims])
    total = int(offs[-1])
    # Without this check a too-long vector would be silently truncated.
    if len(vec) != total:
        raise ValueError(
            f"vec has length {len(vec)} but sum(layer_dims) is {total}"
        )

    out = []
    for i, (start, stop) in enumerate(zip(offs[:-1], offs[1:])):
        # Drop the first axis of the reference shape; the rest is (C, H, W).
        c, h, w = ref_levels[i]['z'].shape[1:]
        out.append(vec[start:stop].reshape(c, h, w))
    return out

# Reshape the first predicted sample as a smoke test; this needs both the
# predictions and the reference shapes to be loaded.
if pred is None or ref is None:
    print("Skip: need both pred and ref loaded")
else:
    maps0 = flat_to_pyramid_one(pred[0], ref, layer_dims)
    print("Levels:", len(maps0))
    print("Level 0 shape (C,H,W):", maps0[0].shape)
