# ONNX Submodule Dumper

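This notebook runs the exported ONNX submodules (embed → solvers → output) on a single input event and dumps each submodule's input data, output data, and dense-layer weights and biases to text files. Note that the cells below save only each submodule's input and first output; intermediate activations inside a submodule (such as Leaky ReLU activation outputs) are not written to separate files.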

**Please ensure the following directory structure and files exist before running this notebook:**

```
.
├── checkpoints/
│   └── config.yaml
├── data/
│   └── rom_det-3_part-200_cont-and-rounded_excerpt/
│       └── ...
├── submodule_onnx/
│   ├── submodule_embed.onnx
│   ├── submodule_solvers-0.onnx
│   ├── submodule_solvers-1.onnx
│   ├── submodule_solvers-2.onnx
│   └── submodule_output.onnx
├── mlp.py
└── dump_all_submodules.ipynb
```

In [1]:
# ======================== IMPORTS ========================
import os
import sys
from pathlib import Path
import yaml
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from rich import print as rprint

import torch
from torch import nn
import onnx
import onnxruntime as ort

from rtal.datasets.dataset import ROMDataset
from torch.utils.data import DataLoader

from mlp import MLP
from onnx import numpy_helper  # add this import


In [2]:
# ---- configuration: paths and sizes used by the cells below ----

# Exported ONNX submodules
SUBMODULE_DIR = "submodule_onnx"
EMBED_ONNX = SUBMODULE_DIR + "/submodule_embed.onnx"
OUTPUT_ONNX = SUBMODULE_DIR + "/submodule_output.onnx"

# Pipeline sizes -- adjust to match the exported model
NUM_SOLVERS = 3        # number of submodule_solvers-<i>.onnx files to load
SUBSET_SIZE = 6        # subset size handed to assemble_np before each solver
BATCH_SIZE = 1         # DataLoader batch size
NUM_PARTICLES = 50     # num_particles handed to ROMDataset; match your training/config

# Dataset location and split
DATA_ROOT = "data/rom_det-3_part-200_cont-and-rounded_excerpt/"
SPLIT = "train"

# Text dump settings
OUT_DIR = "onnx_txt"       # root directory for the generated txt files
SAVE_DTYPE = "float32"     # dtype arrays are cast to before being written

In [3]:
# ---- load the first event and flatten its readout per particle ----
dataset = ROMDataset(DATA_ROOT, split=SPLIT, num_particles=NUM_PARTICLES)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
event = next(iter(dataloader))  # shuffle=False, so this is always the first batch

# Per the original notes the raw readout is (B, num_detectors, num_particles, 2).
# Swapping axes 1 and 2 puts particles first, and flattening the last two axes
# merges (num_detectors, 2) into one feature axis:
#   -> (B, num_particles, num_detectors*2) == (B, num_particles, in_features)
readout = (
    event['readout_curr_cont']      # torch tensor (CPU)
    .transpose(1, 2)
    .flatten(start_dim=-2, end_dim=-1)
)

print("readout shape (torch):", tuple(readout.shape))

readout shape (torch): (1, 50, 6)


In [4]:
# ---- pick the numpy dtype that matches the embed model's declared input ----
embed_sess = ort.InferenceSession(EMBED_ONNX)

# Type string reported by ONNX Runtime, e.g. 'tensor(float16)' or 'tensor(float)'
embed_in_type = embed_sess.get_inputs()[0].type

# Anything that is not float16 is fed as float32.
if 'float16' in embed_in_type:
    np_dtype = np.float16
else:
    np_dtype = np.float32

onnx_inputs = readout.numpy().astype(np_dtype)
print("onnx_inputs shape:", onnx_inputs.shape, "dtype:", onnx_inputs.dtype)

onnx_inputs shape: (1, 50, 6) dtype: float32


In [5]:
from pathlib import Path
import numpy as np

def assemble_np(array: np.ndarray, subset_size: int) -> np.ndarray:
    """Stack cyclically shifted copies of `array` along its last axis.

    Copy number k (k = 0 .. subset_size - 1) is `array` rolled by k positions
    along axis 1; the copies are concatenated on the last axis, so the last
    dimension of the result is `subset_size` times that of the input.
    """
    shifted_copies = []
    for shift in range(subset_size):
        shifted_copies.append(np.roll(array, shift=shift, axis=1))
    return np.concatenate(shifted_copies, axis=-1)

def _cast_for_save(arr: np.ndarray) -> np.ndarray:
    """Return `arr` converted to the dtype named by the module-level SAVE_DTYPE.

    Only "float32" and "float16" are accepted; any other setting raises
    ValueError. The conversion uses copy=False, so an array that already has
    the target dtype is returned without copying.
    """
    if SAVE_DTYPE not in ("float32", "float16"):
        raise ValueError(f"SAVE_DTYPE must be 'float16' or 'float32', got {SAVE_DTYPE!r}")
    target_dtype = np.float32 if SAVE_DTYPE == "float32" else np.float16
    return arr.astype(target_dtype, copy=False)

def save_txt(path: Path, arr: np.ndarray):
    """Write `arr` to `path` as text, one value per line.

    Missing parent directories are created. The array is cast via
    _cast_for_save (for the written copy only), flattened to 1-D, and written
    with six digits after the decimal point.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    flat_values = np.asarray(_cast_for_save(arr)).reshape(-1)
    np.savetxt(path, flat_values, fmt="%.6f")

def run_and_dump(session: ort.InferenceSession, x: np.ndarray, tag: str):
    """Run `session` on `x` and dump both the input and the result as text.

    Files are written to OUT_DIR/<tag>/input.txt and OUT_DIR/<tag>/output.txt.
    Only the session's first declared input and first declared output are used.

    Args:
        session: ONNX Runtime session to execute.
        x: Array fed to the session unchanged; casting happens only in the
            written copy.
        tag: Sub-directory name under OUT_DIR.

    Returns:
        The first output exactly as produced by the session, so it can be
        passed on to the next stage.
    """
    dump_dir = Path(OUT_DIR) / tag
    save_txt(dump_dir / "input.txt", x)

    feed = {session.get_inputs()[0].name: x}
    requested = [session.get_outputs()[0].name]
    result = session.run(requested, feed)[0]

    save_txt(dump_dir / "output.txt", result)
    return result

def dump_dense_weights(model_path: str, tag: str):
    """
    Write the constant weights and biases of dense layers in an ONNX file as text.

    Nodes are visited in graph order and two patterns are recognised:
      - Gemm:   weight is input[1] and bias is input[2]; each one is saved only
                when it resolves to a graph initializer.
      - MatMul: handled when its right-hand input is an initializer; the bias
                is taken from the first consuming Add node whose other operand
                is an initializer, if there is one.

    Files go under OUT_DIR/<tag>/ as NN_linear_W.txt / NN_linear_B.txt, where NN
    is a two-digit counter that advances once per recognised node.
    NOTE: arrays are written through save_txt with the layout stored in the
    ONNX file; no transpose is applied.

    Args:
        model_path: Path of the ONNX model to inspect.
        tag: Sub-directory name under OUT_DIR that receives the files.
    """
    model = onnx.load(model_path)
    graph = model.graph

    # initializer name -> numpy array
    constants = {}
    for initializer in graph.initializer:
        constants[initializer.name] = numpy_helper.to_array(initializer)

    # tensor name -> nodes that read it (used to find the Add after a MatMul)
    readers = {}
    for node in graph.node:
        for tensor_name in node.input:
            if tensor_name not in readers:
                readers[tensor_name] = []
            readers[tensor_name].append(node)

    target_dir = Path(OUT_DIR) / tag
    target_dir.mkdir(parents=True, exist_ok=True)

    layer_no = 0
    for node in graph.node:
        operands = node.input

        if node.op_type == "Gemm":
            weight = constants.get(operands[1]) if len(operands) > 1 else None
            bias = constants.get(operands[2]) if len(operands) > 2 else None
            if weight is not None:
                save_txt(target_dir / f"{layer_no:02d}_linear_W.txt", weight)
            if bias is not None:
                save_txt(target_dir / f"{layer_no:02d}_linear_B.txt", bias)
            layer_no += 1
            continue

        has_constant_rhs = (
            node.op_type == "MatMul"
            and len(operands) > 1
            and operands[1] in constants
        )
        if not has_constant_rhs:
            continue

        # Right-hand operand of the MatMul is the constant weight matrix.
        save_txt(target_dir / f"{layer_no:02d}_linear_W.txt", constants[operands[1]])

        # Scan the direct consumers for an Add whose other operand is constant.
        produced = node.output[0]
        bias = None
        for reader in readers.get(produced, []):
            if reader.op_type != "Add":
                continue
            partners = [name for name in reader.input if name != produced]
            if partners and partners[0] in constants:
                bias = constants[partners[0]]
                break
        if bias is not None:
            save_txt(target_dir / f"{layer_no:02d}_linear_B.txt", bias)
        layer_no += 1

In [6]:
# ---- create one ONNX Runtime session per remaining submodule ----
solver_paths = [f"{SUBMODULE_DIR}/submodule_solvers-{i}.onnx" for i in range(NUM_SOLVERS)]
solver_sess = list(map(ort.InferenceSession, solver_paths))
output_sess = ort.InferenceSession(OUTPUT_ONNX)

# ---- execute the chain: embed -> solvers (in order) -> output ----
# Stage 1: embed
embed_out = run_and_dump(embed_sess, onnx_inputs, tag="embed")

# Stage 2: each solver receives the previous result after assemble_np expands it
arr = embed_out
for i, sess in enumerate(solver_sess):
    arr = run_and_dump(sess, assemble_np(arr, SUBSET_SIZE), tag=f"solvers-{i}")

# Stage 3: output head
out = run_and_dump(output_sess, arr, tag="output")

print("Done. Wrote TXT dumps under:", OUT_DIR)

Done. Wrote TXT dumps under: onnx_txt


In [7]:
# Dump dense-layer weights/biases for every submodule: embed, each solver, output.
weight_jobs = [(EMBED_ONNX, "embed")]
weight_jobs += [(p, f"solvers-{i}") for i, p in enumerate(solver_paths)]
weight_jobs.append((OUTPUT_ONNX, "output"))

for model_path, tag in weight_jobs:
    dump_dense_weights(model_path, tag)

In [None]:

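# NOTE: disabled scratch example (every line is commented out; nothing here runs).
# It sketches splitting a 128x128 weight matrix and its 128-long input along K
# into two 64-wide halves. The file names below are placeholders, not files
# written by the cells above.
# # Inputs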
# x      = np.loadtxt("input_x.txt", dtype=np.float32)           # 128
# W      = np.loadtxt("weights.txt", dtype=np.float32)           # 16384 (=128*128), row-major
# bias   = np.loadtxt("bias.txt",    dtype=np.float32)           # 128

# # Reshape weights row-major (rows=outputs, cols=K)
# W = W.reshape(128, 128)

# # Split along K=128 → two 64-wide blocks
# W0 = W[:, :64].reshape(-1)
# W1 = W[:, 64:].reshape(-1)

# # Split x consistently
# x0 = x[:64]
# x1 = x[64:]

# np.savetxt("weights_part0.txt", W0, fmt="%.6f")
# np.savetxt("weights_part1.txt", W1, fmt="%.6f")
# np.savetxt("x_part0.txt", x0, fmt="%.6f")
# np.savetxt("x_part1.txt", x1, fmt="%.6f")
