# Inspect MST true sources (why NaNs appear)

This notebook inspects `static_datavectors_seed6.json` and the processed NPZ to explain
where NaNs in `mst_true` come from.

Key idea:
- If a block lacks `sigma_v_measured` or `kin_pred_samples`, MST is missing for that block.
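- For blocks with MST inputs, NaNs can come from non-finite values in `sigma_v_measured` or `kin_pred_samples`, or from zeros in `kin_pred_samples` (the ratio `sigma_v / kin_pred` computed in this notebook is then inf or NaN).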


In [None]:
import json
import numpy as np
from pathlib import Path

# Input locations: the raw per-block JSON and the processed NPZ archive.
# NOTE(review): npz_path is an absolute, user-specific path while json_path is
# relative to the notebook directory -- confirm both resolve on your machine.
json_path = Path('../Temp_data/static_datavectors_seed6.json')
npz_path = Path('/users/tianli/Temp_data/quasar_datavectors_seed6_processed.npz')

# Parse the whole JSON document in one go; later cells iterate over `data`
# and treat each element as one block.
data = json.loads(json_path.read_text())

print('Blocks:', len(data))


In [None]:
# ---------------------------
# Which blocks have MST inputs?
# ---------------------------
# A block counts as having MST inputs only when it carries BOTH the
# 'sigma_v_measured' and the 'kin_pred_samples' keys.
has_mst = np.array(
    [('sigma_v_measured' in d) and ('kin_pred_samples' in d) for d in data],
    # Force a boolean dtype: an empty `data` would otherwise produce a float64
    # array, and the `~has_mst` inversion below would raise TypeError.
    dtype=bool,
)

# Indices of the blocks that lack at least one of the two required keys.
missing_blocks = np.where(~has_mst)[0]
print('Blocks without MST inputs:', missing_blocks.tolist())


In [None]:
# ---------------------------
# Per-block sizes next to the MST-availability flag
# ---------------------------
# Each row is (block index, n_lens, n_td, n_obs, has_mst) with
# n_obs = n_lens * n_td.
# NOTE(review): n_td is read from axis 2 of 'td_measured', so that array is
# assumed to be at least 3-D -- confirm against the JSON schema.
rows = []
for block_idx, block in enumerate(data):
    lens_arr = np.asarray(block['z_lens'])
    td_arr = np.asarray(block['td_measured'])
    lens_count = lens_arr.shape[0]
    td_count = td_arr.shape[2]
    rows.append(
        (block_idx, lens_count, td_count, lens_count * td_count, bool(has_mst[block_idx]))
    )

# Fixed-width table, one line per block.
print('block  n_lens  n_td  n_obs  has_mst')
for idx, lens_count, td_count, obs_count, mst_flag in rows:
    print(f'{idx:5d} {lens_count:7d} {td_count:4d} {obs_count:6d} {str(mst_flag):>7}')


In [None]:
# ---------------------------
# Finite fractions for the blocks that carry MST inputs
# ---------------------------
# For every block with both inputs, report what fraction of entries is finite
# in sigma_v, in kin_pred, and in their ratio.
for b, d in enumerate(data):
    if has_mst[b]:
        sigma_v = np.asarray(d['sigma_v_measured'], dtype=float)
        kin_pred = np.asarray(d['kin_pred_samples'], dtype=float)

        # Element-wise ratio; a 3-D result is averaged over its last axis, so a
        # single non-finite entry along that axis makes the averaged value
        # non-finite as well.
        ratio = sigma_v / kin_pred
        ratio = ratio.mean(axis=2) if ratio.ndim == 3 else ratio

        frac_sigma, frac_kin, frac_ratio = (
            np.isfinite(arr).mean() for arr in (sigma_v, kin_pred, ratio)
        )
        print(
            f'block {b:02d}: sigma_v finite={frac_sigma:.4f}, '
            f'kin_pred finite={frac_kin:.4f}, mst finite={frac_ratio:.4f}'
        )


In [None]:
# ---------------------------
# Cross-check against the processed NPZ: count non-finite MST entries
# ---------------------------
# NOTE(review): the archive handle is left open for the rest of the session;
# consider a `with np.load(...)` block if no later cell needs `dnpz`.
dnpz = np.load(npz_path)
mst_true, mst_err, block_id = (
    dnpz[key] for key in ('mst_true', 'mst_err', 'block_id')
)

# An entry is flagged when either the value or its error is non-finite
# (NaN or +/-inf).
nan_mask = ~(np.isfinite(mst_true) & np.isfinite(mst_err))
print('Total MST NaN entries:', nan_mask.sum())

# Per-block breakdown; blocks without any flagged entry are not printed.
for b in np.unique(block_id):
    in_block = block_id == b
    n_nan = nan_mask[in_block].sum()
    n_tot = in_block.sum()
    if n_nan <= 0:
        continue
    print(f'block {int(b):02d}: NaN {n_nan}/{n_tot}')
