In [None]:

from pathlib import Path
import re

# === CONFIG ===
# --- ONNX model locations (edit ONNX_DIR if your exports live elsewhere) ---
ONNX_DIR = Path("submodule_onnx")
EMBED_ONNX = ONNX_DIR / "submodule_embed.onnx"
OUTPUT_ONNX = ONNX_DIR / "submodule_output.onnx"
SOLVER_GLOB = "submodule_solvers-*.onnx"  # globbed under ONNX_DIR; stages are sorted by their number

# --- Text-dump locations and number formatting ---
OUTPUT_ROOT = Path("onnx_txt")                    # one sub-directory of .txt dumps per model
TMP_MODEL_DIR = Path("tmp_instrumented_models")   # instrumented copies of each model are saved here
TXT_FLOAT_FMT = "%.8g"                            # printf-style format applied to every dumped value

# --- ROMDataset input ---
DATA_ROOT = "data/rom_det-3_part-200_cont-and-rounded_excerpt/"
SPLIT = "train"
BATCH_SIZE = 1          # the (1, 50, 6) input coercion downstream expects a batch of one
FEATURE_KEY = "readout_curr_cont"
DEVICE = "cpu"


In [None]:

import numpy as np
import onnx, onnxruntime as ort
from onnx import numpy_helper, helper, shape_inference
from typing import List, Dict, Optional, Tuple
import torch
from torch.utils.data import DataLoader
from rtal.datasets.dataset import ROMDataset


In [None]:

def ensure_dir(p: Path):
    """Create directory ``p`` together with any missing parents; do nothing if it already exists."""
    p.mkdir(exist_ok=True, parents=True)

def sanitize(name: str) -> str:
    """Return ``name`` with path separators, shell/wildcard characters and spaces mapped to '_'.

    The characters replaced are: / \\ : * ? " < > | and the space character.
    """
    unsafe_chars = '/\\:*?"<>| '
    table = str.maketrans(dict.fromkeys(unsafe_chars, '_'))
    return name.translate(table)

def dump_txt(path: Path, arr: np.ndarray, fmt: str):
    """Write every element of ``arr`` (C-order flattened) to ``path``, one formatted value per line.

    Parent directories are created as needed and an existing file is overwritten.
    Each value is converted with ``float()`` and formatted with the printf-style ``fmt``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    values = np.asarray(arr).reshape(-1)
    lines = (f"{fmt % float(v)}\n" for v in values)
    with open(path, "w") as fh:
        fh.writelines(lines)

def ort_session(path: Path):
    """Open an ONNX Runtime session for ``path`` on the CPU provider with graph optimizations off.

    Optimizations are disabled presumably so that execution mirrors the ONNX graph exactly as
    written (useful when probing intermediate tensors) — TODO confirm intent.
    """
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    return ort.InferenceSession(
        str(path),
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )

def concrete_dims(shape: Tuple) -> Tuple:
    """Map a shape to a tuple where integer dims become ``int`` and anything else becomes None.

    Symbolic/dynamic entries (any non-integer value, e.g. a dim name string or None) are
    normalized to None so callers can test ``dim in (k, None)``.
    """
    return tuple(int(dim) if isinstance(dim, (int, np.integer)) else None for dim in shape)

def list_dense(model) -> List[Dict]:
    """Enumerate the dense (fully-connected) layers of an ONNX graph, in node order.

    Recognised patterns:
      * ``Gemm``: weight name is ``input[1]``; bias is ``input[2]`` when that is an initializer.
      * ``MatMul``: weight name is ``input[1]``. If some ``Add`` consumes the MatMul result and
        its *other* operand is an initializer (a bias, on either side of the Add), that Add is
        folded in: its output becomes the layer output and the initializer becomes the bias.
        An ``Add`` whose other operand is a computed tensor (e.g. a residual/skip connection) is
        NOT folded in, so the reported output stays consistent with the dumped weight/bias.

    Args:
        model: an ONNX ``ModelProto`` (only ``graph.node`` and ``graph.initializer`` are read).

    Returns:
        One dict per layer with keys ``kind`` ("Gemm" | "MatMul" | "MatMulAdd"), ``index``
        (0-based, in node order), ``out`` (tensor name to probe), ``W`` and ``B`` (tensor names
        or None). ``B`` is always an initializer name; ``W`` is not checked to be one.
    """
    g = model.graph
    init_names = {init.name for init in g.initializer}
    consumers = {}
    for node in g.node:
        for name in node.input:
            consumers.setdefault(name, []).append(node)

    def _bias_add(mm_out):
        """Return (add_node, bias_name) for the first Add adding an initializer to mm_out, else (None, None)."""
        for c in consumers.get(mm_out, []):
            if c.op_type != "Add" or len(c.input) != 2:
                continue
            other = c.input[1] if c.input[0] == mm_out else c.input[0]
            if other in init_names:
                return c, other
        return None, None

    dense = []
    for n in g.node:
        weight = n.input[1] if len(n.input) > 1 else None
        if n.op_type == "Gemm":
            # An omitted optional input is an empty string, which is never an initializer name.
            bias = n.input[2] if len(n.input) > 2 and n.input[2] in init_names else None
            dense.append({"kind": "Gemm", "index": len(dense), "out": n.output[0],
                          "W": weight, "B": bias})
        elif n.op_type == "MatMul":
            add, bias = _bias_add(n.output[0])
            dense.append({"kind": "MatMulAdd" if add is not None else "MatMul",
                          "index": len(dense),
                          "out": add.output[0] if add is not None else n.output[0],
                          "W": weight, "B": bias})
    return dense

def list_lrelu(model)->List[Dict]:
    """List the LeakyRelu nodes of an ONNX graph in node order.

    Returns one dict per node with ``index`` (0-based among LeakyRelu nodes), ``out`` (the node's
    first output name) and ``alpha`` (the node's first ``alpha`` attribute as a float, or 0.01
    when the attribute is absent).
    """
    lrelu_nodes = [node for node in model.graph.node if node.op_type == "LeakyRelu"]
    found = []
    for position, node in enumerate(lrelu_nodes):
        alpha = next((attr.f for attr in node.attribute if attr.name == "alpha"), 0.01)
        found.append({"index": position, "out": node.output[0], "alpha": float(alpha)})
    return found

def instrument(model, names):
    """Expose the tensors ``names`` as extra graph outputs of ``model``.

    The model is MUTATED in place (new entries are appended after the existing graph outputs)
    and also returned for convenience. Names that are already graph outputs, and repeated names,
    are added only once.

    Type/shape for each new output comes from ONNX shape inference: strict mode first, then a
    lenient retry (strict mode rejects the whole model on a single inference error). If inference
    is unavailable or yields nothing for a name, the output is declared as FLOAT with no shape.

    Args:
        model: ONNX ``ModelProto`` to modify.
        names: iterable of tensor names to expose.

    Returns:
        The same ``model`` object.
    """
    inferred = None
    try:
        inferred = shape_inference.infer_shapes(model, strict_mode=True)
    except Exception:
        try:
            inferred = shape_inference.infer_shapes(model, strict_mode=False)
        except Exception as exc:
            print(f"    [instrument] shape inference unavailable ({exc}); probes declared as untyped FLOAT")

    # First match wins, in the same precedence as before: inputs, then outputs, then value_info.
    known = {}
    if inferred is not None:
        for vi in list(inferred.graph.input) + list(inferred.graph.output) + list(inferred.graph.value_info):
            known.setdefault(vi.name, vi)

    g = model.graph
    have = {o.name for o in g.output}
    add = []
    for name in names:
        if name in have:
            continue
        have.add(name)  # a duplicate graph output would make the model invalid
        vi = known.get(name)
        if vi is None:
            vi = helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, None)
        add.append(vi)
    g.output.extend(add)
    return model

def safe_get_init_array(model, name: Optional[str]) -> Optional[np.ndarray]:
    """Return the initializer called ``name`` as a numpy array, or None.

    None is returned when ``name`` is falsy (None or "") or when no initializer has that name.
    If several initializers share the name, the first one in graph order is used.
    """
    if not name:
        return None
    match = next((tensor for tensor in model.graph.initializer if tensor.name == name), None)
    return None if match is None else numpy_helper.to_array(match)


In [None]:

def adapt_for_model(x: np.ndarray, model_in_shape: Tuple, prefer_3d_last: Optional[int]=None) -> np.ndarray:
    """Best-effort reshape of ``x`` so it can be fed to a model input declared as ``model_in_shape``.

    Strategies, tried in order:
      1. If ``prefer_3d_last`` (P) is set and the model wants 3-D input (batch 1/dynamic, 50/dynamic
         tokens): accept any (B, 50, P) as-is, or unflatten (·, 50*P) -> (1, 50, P).
      2. Fit the declared concrete dims:
         3-D with last dim D: (1, 50, D) as-is, or (·, 50*D) -> (1, 50, D);
         2-D with last dim D: (·, D) as-is, or 3-D (·, a, b) with a*b == D -> (1, D).
      3. Last resort: flatten to (1, N), then trim or zero-pad to the declared size. This can drop
         or invent values, so a WARNING is printed whenever the element count changes.
    If nothing applies, ``x`` is returned (as float32) unchanged and the runtime will reject it
    if it really does not fit.

    Args:
        x: input activation.
        model_in_shape: declared input shape (ints, or symbolic/None for dynamic dims).
        prefer_3d_last: preferred last dim when forming a (1, 50, P) tensor; takes precedence over
            a concrete declared last dim.

    Returns:
        float32 array, reshaped/trimmed/padded as described above.
    """
    xin = np.asarray(x, dtype=np.float32)
    dims = concrete_dims(model_in_shape)
    want_3d = (len(dims)==3 and (dims[0] in (1, None)) and (dims[1] in (50, None)))
    want_2d = (len(dims)==2 and (dims[0] in (1, None)))
    print(f"    [adapter] expects {model_in_shape} -> {dims}, prefer_3d_last={prefer_3d_last}")

    def _pad_or_trim(flat: np.ndarray, size: int) -> np.ndarray:
        """Return a (1, size) float32 array from flat (1, N): trimmed if N > size, zero-padded if N < size."""
        n = flat.shape[1]
        if n != size:
            action = "trimming" if n > size else "zero-padding"
            print(f"    [adapter] WARNING: no exact fit for {xin.shape}; {action} {n} -> {size} values")
        if n >= size:
            return flat[:, :size]
        padded = np.zeros((1, size), dtype=np.float32)
        padded[:, :n] = flat
        return padded

    # 1) Preferred 3-D layout.
    if want_3d and prefer_3d_last is not None:
        # (B, 50, P) for any batch B is accepted unchanged.
        if xin.ndim == 3 and xin.shape[1:] == (50, prefer_3d_last):
            return xin
        if xin.ndim == 2 and xin.shape[1] == 50 * prefer_3d_last:
            return xin.reshape(1, 50, prefer_3d_last)
    # 2) Exact fit to the declared dims.
    if want_3d and dims[2] is not None:
        D = dims[2]
        if xin.shape == (1, 50, D):
            return xin
        if xin.ndim == 2 and xin.shape[1] == 50 * D:
            return xin.reshape(1, 50, D)
    if want_2d and dims[1] is not None:
        D = dims[1]
        if xin.ndim == 2 and xin.shape[1] == D:
            return xin
        if xin.ndim == 3 and xin.shape[1] * xin.shape[2] == D:
            return xin.reshape(1, D)
    # 3) Last resort: flatten and pad/trim (lossy, warned above).
    flat = xin.reshape(1, -1)
    if want_2d and dims[1] is not None:
        return _pad_or_trim(flat, dims[1])
    if want_3d and dims[2] is not None:
        return _pad_or_trim(flat, 50 * dims[2]).reshape(1, 50, dims[2])
    return xin


In [None]:

def run_stage(model_path: Path, x_in: np.ndarray, tag: str, prefer_3d_last: Optional[int]=None) -> np.ndarray:
    """Run one ONNX stage on ``x_in`` and dump its dense-layer and LeakyRelu internals as text.

    Steps: load the model, expose every dense-layer and LeakyRelu output as an extra graph output
    (saved to ``TMP_MODEL_DIR/<stem>__inst.onnx``), adapt ``x_in`` to the first model input, run
    it, then write under ``OUTPUT_ROOT/<sanitized stem>/``:
      input.txt, dense_<i>_weights.txt, dense_<i>_bias.txt, dense_<i>_output.txt,
      leakyrelu_<i>_output.txt, leakyrelu_<i>_alpha.txt, model_output.txt.

    Args:
        model_path: ONNX file for this stage.
        x_in: activation from the previous stage (or the ROM input).
        tag: label used in log lines and error messages.
        prefer_3d_last: forwarded to ``adapt_for_model``.

    Returns:
        The stage's real result: the last output declared by the ORIGINAL graph. The instrumented
        probes are appended after the original outputs, so the session's last output is usually a
        probe, not the model's result.

    Raises:
        FileNotFoundError: ``model_path`` does not exist.
        ValueError: the model declares no graph outputs.
    """
    if not model_path.exists():
        raise FileNotFoundError(f"{tag}: ONNX not found: {model_path}")
    model = onnx.load(str(model_path), load_external_data=True)
    if not model.graph.output:
        raise ValueError(f"{tag}: ONNX model declares no graph outputs: {model_path}")
    final_name = model.graph.output[-1].name

    dense = list_dense(model); lrs = list_lrelu(model)
    inst = onnx.ModelProto(); inst.CopyFrom(model)  # instrument() mutates, so work on a copy
    names = [d["out"] for d in dense] + [r["out"] for r in lrs]
    inst = instrument(inst, names)
    TMP_MODEL_DIR.mkdir(parents=True, exist_ok=True)
    tmp = TMP_MODEL_DIR / f"{model_path.stem}__inst.onnx"
    onnx.save(inst, str(tmp))

    sess = ort_session(tmp)
    in_vi = sess.get_inputs()[0]
    print(f"[{tag}] provided {x_in.shape} | expects {in_vi.shape}")
    xin = adapt_for_model(x_in, tuple(in_vi.shape), prefer_3d_last=prefer_3d_last)
    print(f"[{tag}] feeding {xin.shape}")

    # A probed tensor can itself be the graph output; fetch each name only once.
    fetches = list(dict.fromkeys(names + [final_name]))
    outs = sess.run(fetches, {in_vi.name: xin})
    name_to_arr = dict(zip(fetches, outs))
    result = name_to_arr[final_name]

    out_dir = OUTPUT_ROOT / sanitize(model_path.stem); ensure_dir(out_dir)
    dump_txt(out_dir / "input.txt", xin, TXT_FLOAT_FMT)
    for d in dense:
        W = safe_get_init_array(model, d["W"]); B = safe_get_init_array(model, d["B"])
        if W is not None: dump_txt(out_dir / f"dense_{d['index']}_weights.txt", W, TXT_FLOAT_FMT)
        if B is not None: dump_txt(out_dir / f"dense_{d['index']}_bias.txt", B, TXT_FLOAT_FMT)
        arr = name_to_arr.get(d["out"])
        if arr is not None: dump_txt(out_dir / f"dense_{d['index']}_output.txt", arr, TXT_FLOAT_FMT)
    for r in lrs:
        arr = name_to_arr.get(r["out"])
        if arr is not None:
            dump_txt(out_dir / f"leakyrelu_{r['index']}_output.txt", arr, TXT_FLOAT_FMT)
            with open(out_dir / f"leakyrelu_{r['index']}_alpha.txt", "w") as f:
                f.write(f"{r['alpha']}\n")
    dump_txt(out_dir / "model_output.txt", result, TXT_FLOAT_FMT)

    return result


In [None]:

def rom_input_b506() -> np.ndarray:
    """Load the first ROMDataset batch and return ``event[FEATURE_KEY]`` as a float32 (1, 50, 6) array.

    Accepted source layouts:
      * (1, 3, 50, 2): axes 1 and 2 are swapped to (1, 50, 3, 2), then merged to (1, 50, 6);
      * (1, 50, 6): used as-is;
      * any tensor with >= 3 dims whose last two dims multiply to 6 and which holds exactly
        300 elements: reshaped to (1, 50, 6).

    Raises:
        RuntimeError: the data loader yields no batch.
        ValueError: the feature cannot be coerced to (1, 50, 6).
    """
    ds = ROMDataset(DATA_ROOT, split=SPLIT, num_particles=50)
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False)
    event = next(iter(dl), None)
    if event is None:
        raise RuntimeError(f"ROMDataset under {DATA_ROOT!r} (split={SPLIT!r}) yielded no batches")
    x = event[FEATURE_KEY].to(DEVICE)
    if x.ndim == 4 and x.shape[0]==1 and x.shape[1:]==(3,50,2):
        x = torch.transpose(x,1,2).reshape(1,50,6)
    elif x.shape == (1,50,6):
        pass
    elif x.ndim >= 3 and x.shape[-1]*x.shape[-2] == 6 and x.numel() == 50*6:
        # The element-count check turns an opaque reshape RuntimeError into the ValueError below.
        x = x.reshape(1,50,6)
    else:
        raise ValueError(f"Unexpected ROM shape {tuple(x.shape)}, cannot get (1,50,6)")
    return x.detach().cpu().numpy().astype(np.float32)


In [None]:

def discover_solvers(onnx_dir: Optional[Path] = None, pattern: Optional[str] = None) -> List[Path]:
    """Find the solver ONNX files and return them in execution order.

    Files are ordered by the first run of digits in their stem (numerically, so "10" sorts after
    "2"). Files without digits come after all numbered ones. Ties are broken by file name, so the
    order never depends on filesystem listing order.

    Args:
        onnx_dir: directory to search; defaults to ``ONNX_DIR``.
        pattern: glob pattern; defaults to ``SOLVER_GLOB``.

    Returns:
        Sorted list of matching paths.

    Raises:
        FileNotFoundError: nothing matches.
    """
    onnx_dir = ONNX_DIR if onnx_dir is None else Path(onnx_dir)
    pattern = SOLVER_GLOB if pattern is None else pattern

    def _order_key(p: Path):
        """Sort key: (has-no-number flag, number, name)."""
        m = re.search(r"(\d+)", p.stem)
        return (0, int(m.group(1)), p.name) if m else (1, 0, p.name)

    paths = sorted(onnx_dir.glob(pattern), key=_order_key)
    if not paths:
        raise FileNotFoundError(f"No solver ONNX files found under {onnx_dir} with pattern {pattern}")
    print("[DISCOVER] solvers:", [p.name for p in paths])
    return paths


In [None]:

def run_chain():
    """Run the ROM input through embed -> solver stages -> output, dumping every stage to text.

    Each stage's returned array is fed to the next stage. The embed and output model files are
    checked before any work starts; solver files are located by ``discover_solvers``.
    """
    for needed_dir in (OUTPUT_ROOT, TMP_MODEL_DIR):
        ensure_dir(needed_dir)
    # Fail fast, before loading data, if a fixed stage is missing.
    if not EMBED_ONNX.exists():
        raise FileNotFoundError(f"embed ONNX not found: {EMBED_ONNX}")
    if not OUTPUT_ONNX.exists():
        raise FileNotFoundError(f"output ONNX not found: {OUTPUT_ONNX}")
    solver_paths = discover_solvers()

    def _preferred_last(in_shape):
        """Last dim to aim for: a concrete 3-D last dim, else a concrete 2-D width divisible by 50, divided by 50."""
        dims = concrete_dims(in_shape)
        if len(dims) == 3 and isinstance(dims[2], int):
            return dims[2]
        if len(dims) == 2 and isinstance(dims[1], int) and dims[1] % 50 == 0:
            return dims[1] // 50
        return None

    # ROM features (1, 50, 6) -> embed stage.
    activation = rom_input_b506()
    activation = run_stage(EMBED_ONNX, activation, tag="EMBED", prefer_3d_last=6)

    # Solver stages in discovered order; the preferred layout comes from each model's first input.
    for stage_no, solver_path in enumerate(solver_paths):
        declared = ort_session(solver_path).get_inputs()[0].shape
        pref = _preferred_last(declared)
        print(f"[SOLVER-{stage_no}] prefers last={pref} from {declared}")
        activation = run_stage(solver_path, activation, tag=f"SOLVER-{stage_no}", prefer_3d_last=pref)

    # Output stage; 128 is the last dim this stage is expected to take (see the original contract note).
    run_stage(OUTPUT_ONNX, activation, tag="OUTPUT", prefer_3d_last=128)
    print("Chain complete. Dumps are under:", OUTPUT_ROOT)

# Execute the full embed -> solvers -> output chain. Text dumps go under OUTPUT_ROOT and
# instrumented model copies under TMP_MODEL_DIR.
run_chain()
