In [None]:

from pathlib import Path
import re

# === CONFIG ===
# ONNX sub-models of the chain, run in order: embed -> solvers (numbered) -> output.
ONNX_DIR: Path = Path("submodule_onnx")
EMBED_ONNX: Path = ONNX_DIR.joinpath("submodule_embed.onnx")
SOLVER_GLOB: str = "submodule_solvers-*.onnx"  # glob pattern, evaluated inside ONNX_DIR
OUTPUT_ONNX: Path = ONNX_DIR.joinpath("submodule_output.onnx")

# Destinations: text dumps per stage, and the instrumented model copies that get executed.
OUTPUT_ROOT: Path = Path("onnx_txt")
TMP_MODEL_DIR: Path = Path("tmp_instrumented_models")
TXT_FLOAT_FMT: str = "%.8g"  # printf-style format applied to every dumped value

# ROMDataset input settings
DATA_ROOT: str = "data/rom_det-3_part-200_cont-and-rounded_excerpt/"
SPLIT: str = "train"
BATCH_SIZE: int = 1
FEATURE_KEY: str = "readout_curr_cont"  # key of the input tensor in each loaded batch dict
DEVICE: str = "cpu"


In [None]:

import numpy as np
import onnx, onnxruntime as ort
from onnx import numpy_helper, helper, shape_inference
from typing import List, Dict, Optional, Tuple
import torch
from torch.utils.data import DataLoader
from rtal.datasets.dataset import ROMDataset


In [None]:

def ensure_dir(p: Path):
    """Create directory ``p`` together with any missing parents; no error if it already exists."""
    p.mkdir(exist_ok=True, parents=True)

def sanitize(name: str) -> str:
    """Make ``name`` safe to use as a single path component.

    Every path separator, Windows-reserved character (: * ? " < > |) and space is
    replaced by an underscore; all other characters are kept.
    """
    unsafe = '/\\:*?"<>| '
    return name.translate({ord(ch): "_" for ch in unsafe})

def dump_txt(path: Path, arr: np.ndarray, fmt: str = "%.8g"):
    """Write all elements of ``arr`` to ``path``, one formatted value per line.

    The array is flattened in C (row-major) order. Each element goes through ``float()``
    and is then formatted with the printf-style ``fmt``. Missing parent directories are
    created, and an existing file is overwritten.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    values = np.asarray(arr).reshape(-1)
    with path.open("w") as fh:
        fh.writelines(f"{fmt % float(v)}\n" for v in values)

def ort_session(path: Path):
    """Open an ONNX Runtime session for ``path`` on the CPU execution provider.

    All ORT graph optimizations are disabled (ORT_DISABLE_ALL). This is presumably so that
    the executed graph stays node-for-node identical to the saved ONNX file — TODO confirm.
    """
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    return ort.InferenceSession(
        str(path),
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )

def concrete_dims(shape: Tuple) -> Tuple:
    """Normalize a tensor shape to a tuple of plain ints.

    Python/numpy integer dims become ``int``. Every other dim (e.g. a symbolic string
    such as "batch", or None) becomes None.
    """
    normalized = []
    for dim in shape:
        normalized.append(int(dim) if isinstance(dim, (int, np.integer)) else None)
    return tuple(normalized)


In [None]:

# ---- Bias-aware graph introspection ----
def _constant_node_value(node) -> Optional[np.ndarray]:
    """Return the dense numeric value of an ONNX ``Constant`` node as a numpy array, or None."""
    for a in node.attribute:
        if a.name == "value":
            return numpy_helper.to_array(a.t)
        # Constant may carry its payload in the scalar/list attribute forms instead of a tensor.
        if a.name == "value_float":
            return np.asarray(a.f, dtype=np.float32)
        if a.name == "value_floats":
            return np.asarray(list(a.floats), dtype=np.float32)
        if a.name == "value_int":
            return np.asarray(a.i, dtype=np.int64)
        if a.name == "value_ints":
            return np.asarray(list(a.ints), dtype=np.int64)
    # sparse_value / value_string(s) are not dense numeric tensors; leave them unresolved.
    return None


def build_maps(model):
    """Index an ONNX model's graph for constant lookup and producer/consumer traversal.

    Only the top-level graph is scanned (subgraphs of If/Loop are not visited).

    Returns:
        (init_map, const_map, consumers, producers), where
        - init_map:  initializer name -> numpy array
        - const_map: output name of each Constant node -> numpy array (supports the
                     ``value``, ``value_float(s)`` and ``value_int(s)`` attribute forms)
        - consumers: tensor name -> list of nodes that take it as an input, in graph order
        - producers: tensor name -> node that outputs it
    """
    g = model.graph
    init_map = {init.name: numpy_helper.to_array(init) for init in g.initializer}
    const_map = {}
    producers = {}
    consumers = {}
    for n in g.node:
        for o in n.output:
            producers[o] = n
        for i in n.input:
            consumers.setdefault(i, []).append(n)
        if n.op_type == "Constant":
            value = _constant_node_value(n)
            if value is not None:
                const_map[n.output[0]] = value
    return init_map, const_map, consumers, producers

def resolve_bias_tensor(name: Optional[str], init_map, const_map, producers):
    """Look up a constant tensor by name.

    The lookup tries, in order: the initializer map, the Constant-output map, and
    finally a producing ``Constant`` node's ``value`` attribute.

    Returns:
        The numpy array, or None when ``name`` is empty/None or the tensor is not constant.
    """
    if not name:
        return None
    for table in (init_map, const_map):
        if name in table:
            return table[name]
    node = producers.get(name)
    if node is None or node.op_type != "Constant":
        return None
    value_attr = next((a for a in node.attribute if a.name == "value"), None)
    return None if value_attr is None else numpy_helper.to_array(value_attr.t)

def list_dense_with_bias_and_outputs(model) -> List[Dict]:
    """Identify dense-like operations and their output tensors, weights, and bias.

    - Gemm: W = input[1], B = input[2]; out = node.output[0].
    - MatMul(+Add): W = MatMul input[1]. An Add that consumes the MatMul output counts
      as the bias add only when its other operand resolves to a constant (initializer
      or Constant node). In that case B is that operand and out is the Add output
      (kind "MatMulAdd"). Otherwise out is the MatMul output and B is None
      (kind "MatMul"). This keeps an Add of two runtime tensors (e.g. a residual
      connection) from being reported as a bias add.

    W and B are resolved through initializers and Constant nodes. They are None when
    the tensor is not a graph constant or the input is absent. Gemm's transA/transB/
    alpha/beta attributes are not applied: W is reported exactly as stored.

    Returns:
        One dict per dense-like node, in graph order, with keys "kind", "index"
        (0-based position in the returned list), "out", "W", "B".
    """
    init_map, const_map, consumers, producers = build_maps(model)

    def const_value(name: Optional[str]):
        # Shared lookup for weights and biases: initializer -> Constant -> None.
        return resolve_bias_tensor(name, init_map, const_map, producers)

    dense: List[Dict] = []
    for n in model.graph.node:
        if n.op_type == "Gemm":
            W = const_value(n.input[1] if len(n.input) > 1 else None)
            B = const_value(n.input[2] if len(n.input) > 2 else None)
            dense.append({"kind": "Gemm", "index": len(dense), "out": n.output[0], "W": W, "B": B})
        elif n.op_type == "MatMul":
            mm_out = n.output[0]
            W = const_value(n.input[1] if len(n.input) > 1 else None)
            add_node, B = None, None
            for c in consumers.get(mm_out, []):
                if c.op_type != "Add" or len(c.input) < 2:
                    continue
                a0, a1 = c.input[0], c.input[1]
                cand = a1 if a0 == mm_out else a0
                bias = const_value(cand)
                if bias is not None:
                    # First Add with a constant operand is the bias add.
                    add_node, B = c, bias
                    break
            out = add_node.output[0] if add_node is not None else mm_out
            kind = "MatMulAdd" if add_node is not None else "MatMul"
            dense.append({"kind": kind, "index": len(dense), "out": out, "W": W, "B": B})
    return dense

def list_lrelu_nodes(model)->List[Dict]:
    """List every LeakyRelu node of the top-level graph, in graph order.

    Returns:
        Dicts with keys "index" (0-based among LeakyRelu nodes), "out" (first output
        name) and "alpha" (the node's ``alpha`` attribute, or 0.01 when it is absent).
    """
    found = []
    leaky_nodes = (node for node in model.graph.node if node.op_type == "LeakyRelu")
    for position, node in enumerate(leaky_nodes):
        alpha = next((attr.f for attr in node.attribute if attr.name == "alpha"), 0.01)
        found.append({"index": position, "out": node.output[0], "alpha": float(alpha)})
    return found

def instrument_model_for_outputs(model, names: List[str]):
    """Append the tensors named in ``names`` to the model's graph outputs so ORT can fetch them.

    Mutates ``model.graph`` in place and returns the same model. The new outputs are
    appended after the existing ones. Each name is added at most once: names that are
    already graph outputs, and repeats inside ``names``, are skipped, because duplicate
    graph-output entries are redundant and may be rejected by ONNX Runtime.

    Type/shape info comes from ONNX shape inference: strict mode first, then non-strict
    mode (which can still yield partial info). If inference fails entirely, or a name has
    no inferred info, the output is declared FLOAT with unknown shape. This assumes the
    instrumented intermediates are float32 — TODO confirm for non-float graphs.
    """
    vi_map = {}
    for strict in (True, False):
        try:
            inferred = shape_inference.infer_shapes(model, strict_mode=strict)
        except Exception:
            continue  # best effort: try the next mode, else fall back to FLOAT/unknown shape
        ig = inferred.graph
        vi_map = {vi.name: vi for vi in (*ig.input, *ig.output, *ig.value_info)}
        break

    g = model.graph
    seen = {o.name for o in g.output}
    to_add = []
    for name in names:
        if name in seen:
            continue
        seen.add(name)
        vi = vi_map.get(name)
        if vi is None:
            vi = helper.make_tensor_value_info(name, onnx.TensorProto.FLOAT, None)
        to_add.append(vi)
    g.output.extend(to_add)
    return model


In [None]:

# ---- Shape adapter & stage runner ----
def _fit_flat_width(x: np.ndarray, width: int) -> np.ndarray:
    """Flatten x to (1, N) and trim or zero-pad it to (1, width), printing a warning when N != width."""
    flat = x.reshape(1, -1)
    n = flat.shape[1]
    if n >= width:
        if n > width:
            print(f"[adapt_for_model] WARNING: trimming input from {n} to {width} values")
        return flat[:, :width]
    print(f"[adapt_for_model] WARNING: zero-padding input from {n} to {width} values")
    padded = np.zeros((1, width), dtype=np.float32)
    padded[:, :n] = flat
    return padded


def adapt_for_model(x: np.ndarray, model_in_shape: Tuple, prefer_3d_last: Optional[int]=None) -> np.ndarray:
    """Coerce ``x`` (cast to float32) into a layout accepted by a model input of shape ``model_in_shape``.

    The model counts as 3-D when its shape is (1|?, 50|?, D) and as 2-D when its shape is
    (1|?, N). The fixed 50 matches ``num_particles=50`` used when loading ROMDataset.
    Rules are tried in order:
      1. 3-D model with ``prefer_3d_last`` = D: accept x of shape (B, 50, D) as is, or
         reshape (B, 50*D) to (1, 50, D).
      2. 3-D model with a concrete last dim D: accept (1, 50, D), or reshape (B, 50*D).
         2-D model with a concrete N: accept (B, N), or reshape (B, a, b) with a*b == N
         to (1, N).
      3. Fallback (lossy): flatten and trim/zero-pad to N (2-D) or 50*D (3-D). A warning
         is printed whenever values are dropped or padded.
    If no rule applies (e.g. all dims are dynamic), ``x`` is returned unchanged.
    The reshapes assume a batch size of 1 (assumption, as in the original callers).
    """
    x = np.asarray(x, dtype=np.float32)
    dims = concrete_dims(model_in_shape)
    want_3d = len(dims) == 3 and dims[0] in (1, None) and dims[1] in (50, None)
    want_2d = len(dims) == 2 and dims[0] in (1, None)

    # 1) Caller-preferred (.., 50, D) layout for 3-D models.
    if want_3d and prefer_3d_last is not None:
        D = prefer_3d_last
        if x.ndim == 3 and x.shape[1:] == (50, D):
            return x
        if x.ndim == 2 and x.shape[1] == 50 * D:
            return x.reshape(1, 50, D)

    # 2) Lossless fits against the model's declared concrete dims.
    if want_3d and dims[2] is not None:
        D = dims[2]
        if x.shape == (1, 50, D):
            return x
        if x.ndim == 2 and x.shape[1] == 50 * D:
            return x.reshape(1, 50, D)
    if want_2d and dims[1] is not None:
        D = dims[1]
        if x.ndim == 2 and x.shape[1] == D:
            return x
        if x.ndim == 3 and x.shape[1] * x.shape[2] == D:
            return x.reshape(1, D)

    # 3) Lossy fallback: flatten, then trim/pad to the declared size (warns if data changes).
    if want_2d and dims[1] is not None:
        return _fit_flat_width(x, dims[1])
    if want_3d and dims[2] is not None:
        return _fit_flat_width(x, 50 * dims[2]).reshape(1, 50, dims[2])
    return x

def run_stage(model_path: Path, x_in: np.ndarray, tag: str, prefer_3d_last: Optional[int]=None) -> np.ndarray:
    """Run one ONNX sub-model on ``x_in`` and dump its input, parameters and activations as text.

    Steps:
      - Copy the model and expose every dense-like output (Gemm / MatMul[+Add]) and every
        LeakyRelu output as an extra graph output.
      - Save the copy to TMP_MODEL_DIR and run it with ONNX Runtime.
      - Write the dumps to OUTPUT_ROOT / <sanitized model stem>/.

    Args:
        model_path: path of the ONNX sub-model.
        x_in: input array, fitted to the model's first input by ``adapt_for_model``.
        tag: label used in log lines and the not-found error.
        prefer_3d_last: preferred last dim D when the model takes a (1, 50, D) input.

    Returns:
        The model's own output: the last output declared in the ORIGINAL graph, not an
        instrumented intermediate.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
    """
    if not model_path.exists():
        raise FileNotFoundError(f"{tag}: ONNX not found: {model_path}")
    model = onnx.load(str(model_path), load_external_data=True)
    dense = list_dense_with_bias_and_outputs(model)
    lrs = list_lrelu_nodes(model)

    # Take the real output name from the un-instrumented graph. Instrumentation appends the
    # extra outputs AFTER the original ones, so the session's last output would be an
    # intermediate tensor whenever anything is appended.
    model_out_name = model.graph.output[-1].name

    # Order-preserving dedupe: dense entries may share an output tensor, and no tensor
    # should be registered or fetched twice.
    names = list(dict.fromkeys([d["out"] for d in dense] + [r["out"] for r in lrs]))

    # Instrument a copy so the loaded model is left untouched.
    inst = onnx.ModelProto()
    inst.CopyFrom(model)
    inst = instrument_model_for_outputs(inst, names)
    TMP_MODEL_DIR.mkdir(parents=True, exist_ok=True)
    tmp = TMP_MODEL_DIR / f"{model_path.stem}__inst.onnx"
    onnx.save(inst, str(tmp))

    # Run
    sess = ort_session(tmp)
    in_vi = sess.get_inputs()[0]
    print(f"[{tag}] provided {x_in.shape} | expects {in_vi.shape}")
    xin = adapt_for_model(x_in, tuple(in_vi.shape), prefer_3d_last=prefer_3d_last)
    print(f"[{tag}] feeding {xin.shape}")
    fetches = names if model_out_name in names else names + [model_out_name]
    outs = sess.run(fetches, {in_vi.name: xin})
    name_to_arr = dict(zip(fetches, outs))
    model_output = name_to_arr[model_out_name]

    # Dumps
    out_dir = OUTPUT_ROOT / sanitize(model_path.stem)
    ensure_dir(out_dir)
    dump_txt(out_dir / "input.txt", xin, TXT_FLOAT_FMT)
    for d in dense:
        if d["W"] is not None:
            dump_txt(out_dir / f"dense_{d['index']}_weights.txt", d["W"], TXT_FLOAT_FMT)
        if d["B"] is not None:
            dump_txt(out_dir / f"dense_{d['index']}_bias.txt", d["B"], TXT_FLOAT_FMT)
    for d in dense:
        arr = name_to_arr.get(d["out"])
        if arr is not None:
            dump_txt(out_dir / f"dense_{d['index']}_output.txt", arr, TXT_FLOAT_FMT)
    for r in lrs:
        arr = name_to_arr.get(r["out"])
        if arr is not None:
            dump_txt(out_dir / f"leakyrelu_{r['index']}_output.txt", arr, TXT_FLOAT_FMT)
            with open(out_dir / f"leakyrelu_{r['index']}_alpha.txt", "w") as f:
                f.write(f"{r['alpha']}\n")
    dump_txt(out_dir / "model_output.txt", model_output, TXT_FLOAT_FMT)
    return model_output


In [None]:

# ---- ROMDataset and discovery ----
def rom_input_b506() -> np.ndarray:
    """Load the first ROMDataset batch and return its FEATURE_KEY tensor as float32 of shape (1, 50, 6).

    Accepted incoming layouts:
      - (1, 3, 50, 2): axes 1 and 2 are swapped to (1, 50, 3, 2), then merged to (1, 50, 6);
      - (1, 50, 6): used as is;
      - any other tensor with >= 3 dims whose last two dims multiply to 6 AND which holds
        exactly 50 * 6 = 300 values: reshaped to (1, 50, 6).

    Raises:
        ValueError: for any other shape, e.g. a batch holding more than one event.
            Without the element-count check this case would surface as an opaque torch
            reshape error.
    """
    ds = ROMDataset(DATA_ROOT, split=SPLIT, num_particles=50)
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False)
    event = next(iter(dl))
    x = event[FEATURE_KEY].to(DEVICE)
    target = (1, 50, 6)
    shape = tuple(x.shape)
    if shape == (1, 3, 50, 2):
        x = torch.transpose(x, 1, 2).reshape(target)
    elif shape == target:
        pass
    elif x.ndim >= 3 and shape[-1] * shape[-2] == 6 and x.numel() == 50 * 6:
        x = x.reshape(target)
    else:
        raise ValueError(f"Unexpected ROM shape {shape}, cannot get (1,50,6)")
    return x.cpu().numpy().astype(np.float32)

def discover_solvers(onnx_dir: Optional[Path] = None, pattern: Optional[str] = None) -> List[Path]:
    """Find the solver ONNX files and return them in numeric order.

    Args:
        onnx_dir: directory to search; defaults to ONNX_DIR.
        pattern: glob pattern inside ``onnx_dir``; defaults to SOLVER_GLOB.

    Files are ordered by the first integer in their stem (e.g. "submodule_solvers-10" -> 10).
    Files without an integer come last. Ties are broken by file name, so the order never
    depends on filesystem enumeration order.

    Raises:
        FileNotFoundError: if no file matches.
    """
    onnx_dir = ONNX_DIR if onnx_dir is None else onnx_dir
    pattern = SOLVER_GLOB if pattern is None else pattern

    def sort_key(p: Path):
        m = re.search(r"(\d+)", p.stem)
        # Leading 0/1 flag puts numbered files before unnumbered ones; name breaks ties.
        return (0, int(m.group(1)), p.name) if m else (1, 0, p.name)

    paths = sorted(onnx_dir.glob(pattern), key=sort_key)
    if not paths:
        raise FileNotFoundError(f"No solver ONNX files found under {onnx_dir} with pattern {pattern}")
    print("[DISCOVER] solvers:", [p.name for p in paths])
    return paths


In [None]:

# ---- Chain runner ----
def _preferred_last_dim(in_shape) -> Optional[int]:
    """Infer the per-row width D a model input expects: D for a 3-D input, N // 50 for a 2-D (.., N) input with N % 50 == 0."""
    dims = concrete_dims(in_shape)
    if len(dims) == 3 and isinstance(dims[2], int):
        return dims[2]
    if len(dims) == 2 and isinstance(dims[1], int) and dims[1] % 50 == 0:
        return dims[1] // 50
    return None


def run_chain(embed_last: int = 6, output_last: int = 128):
    """Run ROM input -> embed -> every solver (in discovery order) -> output, dumping each stage.

    Each stage's returned output is fed to the next stage. All dumps are written under
    OUTPUT_ROOT.

    Args:
        embed_last: preferred last dim D of the embed model's (1, 50, D) input
            (the ROM input is (1, 50, 6)).
        output_last: preferred last dim D of the output model's (1, 50, D) input.

    Raises:
        FileNotFoundError: if the embed or output model, or every solver model, is missing.
    """
    ensure_dir(OUTPUT_ROOT)
    ensure_dir(TMP_MODEL_DIR)
    if not EMBED_ONNX.exists():
        raise FileNotFoundError(f"embed ONNX not found: {EMBED_ONNX}")
    if not OUTPUT_ONNX.exists():
        raise FileNotFoundError(f"output ONNX not found: {OUTPUT_ONNX}")
    solvers = discover_solvers()
    # 1) ROM -> (1, 50, 6)
    x = rom_input_b506()
    # 2) embed
    x = run_stage(EMBED_ONNX, x, tag="EMBED", prefer_3d_last=embed_last)
    # 3) solvers in order; the preferred width is read from each solver's declared input shape.
    for i, spath in enumerate(solvers):
        in_shape = ort_session(spath).get_inputs()[0].shape
        pref = _preferred_last_dim(in_shape)
        print(f"[SOLVER-{i}] prefers last={pref} from {in_shape}")
        x = run_stage(spath, x, tag=f"SOLVER-{i}", prefer_3d_last=pref)
    # 4) output
    run_stage(OUTPUT_ONNX, x, tag="OUTPUT", prefer_3d_last=output_last)
    print("Chain complete. Dumps are under:", OUTPUT_ROOT)

# Run the whole embed -> solvers -> output chain; every stage's dumps land under OUTPUT_ROOT.
run_chain();
