In [1]:
import onnx
from onnx import helper, numpy_helper, shape_inference
import onnxruntime as ort
from pathlib import Path


In [2]:
# Source model and destination for the variant that also exposes features.
ONNX_IN = Path("/home/hschatzle/monte-carlo-selection/data/models/resnet50_tl_20250829.onnx")
ONNX_OUT = Path("/home/hschatzle/monte-carlo-selection/data/models/resnet50_tl_20250829_with_feats.onnx")

model = onnx.load(str(ONNX_IN))

# Logits: the model is expected to declare exactly one graph output, and that
# output's tensor name is what the logits head will be wired to later.
graph_outputs = model.graph.output
assert len(graph_outputs) == 1, "Expected single output for logits"
logits_src = graph_outputs[0].name
print("Logits source tensor:", logits_src)

# Penultimate features: output of the last Flatten node in graph order.
# NOTE(review): this presumes the last Flatten feeds the classifier directly;
# nothing here verifies that.
flatten_nodes = [node for node in model.graph.node if node.op_type == "Flatten"]
assert flatten_nodes, "No Flatten node found"
feats_src = flatten_nodes[-1].output[0]
print("Feats source tensor:", feats_src)

Logits source tensor: output
Feats source tensor: /Flatten_output_0


In [3]:
# Cell 2. Add Identity heads for both outputs (stable naming) + set graph outputs (feats + logits)
import onnx
from onnx import helper, TensorProto, shape_inference

# Run ONNX shape inference once so intermediate tensors carry type/shape info.
model_inf = shape_inference.infer_shapes(model)

def find_vi(name: str):
    """Return the ValueInfoProto called ``name`` from the shape-inferred model.

    Looks through ``value_info`` first, then graph inputs, then graph outputs,
    and returns the first entry whose name matches; ``None`` if there is none.
    """
    inferred_graph = model_inf.graph
    for collection in (inferred_graph.value_info, inferred_graph.input, inferred_graph.output):
        for candidate in collection:
            if candidate.name == name:
                return candidate
    return None

def make_output_vi(src_name: str, out_name: str):
    """Build the graph-output ValueInfoProto for ``out_name`` from ``src_name``.

    The type and shape that shape inference recorded for ``src_name`` are
    copied verbatim and only the name is replaced. Copying (instead of
    rebuilding the shape dim by dim) keeps symbolic dims, partially known
    shapes and zero-sized dims intact, and does not turn a tensor whose shape
    is entirely unknown into a rank-0 scalar.

    Args:
        src_name: Name of the tensor whose inferred type/shape is reused.
        out_name: Name to give the returned value info.

    Returns:
        A new ``ValueInfoProto`` named ``out_name``.
    """
    src_vi = find_vi(src_name)
    if src_vi is None:
        # Fallback when inference produced nothing for this tensor: unknown shape.
        # NOTE(review): FLOAT is an assumption about the model's dtype - confirm
        # before using this on non-float32 models.
        return helper.make_tensor_value_info(out_name, TensorProto.FLOAT, None)

    out_vi = onnx.ValueInfoProto()
    out_vi.CopyFrom(src_vi)
    out_vi.name = out_name
    # Only type/shape are wanted; do not carry over the source's doc string.
    out_vi.ClearField("doc_string")
    return out_vi

# Create explicit output names
FEATS_OUT  = "features"
LOGITS_OUT = "logits"

# Edit a copy of the loaded model. Mutating `model` in place would append a
# second pair of Identity nodes (duplicate tensor names, invalid graph) every
# time this cell is re-run.
model_with_heads = onnx.ModelProto()
model_with_heads.CopyFrom(model)
heads_graph = model_with_heads.graph

# The new output names must not already be used by any tensor in the graph.
taken_names = {name for node in heads_graph.node for name in node.output}
taken_names.update(vi.name for vi in heads_graph.input)
taken_names.update(init.name for init in heads_graph.initializer)
clashes = sorted(taken_names & {FEATS_OUT, LOGITS_OUT})
if clashes:
    raise ValueError(f"Tensor name(s) already present in the graph: {clashes}")

# Add Identity nodes to expose both tensors under stable names. Appending at
# the end keeps the node list topologically sorted (both sources come earlier).
id_feats  = helper.make_node("Identity", inputs=[feats_src],  outputs=[FEATS_OUT],  name="ExposeFeats")
id_logits = helper.make_node("Identity", inputs=[logits_src], outputs=[LOGITS_OUT], name="ExposeLogits")
heads_graph.node.extend([id_feats, id_logits])

# Replace graph outputs with both (features first, then logits)
heads_graph.ClearField("output")
heads_graph.output.extend([
    make_output_vi(feats_src, FEATS_OUT),
    make_output_vi(logits_src, LOGITS_OUT),
])

# Validate before writing so a broken graph never reaches disk.
onnx.checker.check_model(model_with_heads)

onnx.save(model_with_heads, str(ONNX_OUT))
print("Saved:", ONNX_OUT)


Saved: /home/hschatzle/monte-carlo-selection/data/models/resnet50_tl_20250829_with_feats.onnx


In [4]:
# Cell 3. Verify with ONNX Runtime
import onnxruntime as ort

# Load the saved model on CPU and list its outputs and inputs as ORT sees them.
sess = ort.InferenceSession(str(ONNX_OUT), providers=["CPUExecutionProvider"])

for heading, node_args in (("Outputs:", sess.get_outputs()), ("\nInputs:", sess.get_inputs())):
    print(heading)
    for node_arg in node_args:
        print(" ", node_arg.name, node_arg.shape, node_arg.type)


Outputs:
  features ['batch_size', 2048] tensor(float)
  logits ['batch_size', 54] tensor(float)

Inputs:
  input ['batch_size', 3, 224, 224] tensor(float)


In [1]:
from pathlib import Path

# The three "*_with_feats" exports whose signatures are inspected below.
M1 = Path("/home/hschatzle/monte-carlo-selection/data/models/resnet50_geirhos_tl_with_feats.onnx")
M2 = Path("/home/hschatzle/monte-carlo-selection/data/models/resnet50_jangtong_tl_with_feats.onnx")
M3 = Path("/home/hschatzle/monte-carlo-selection/data/models/resnet50_tl_20250829_with_feats.onnx")

# Stop immediately if any of the files is not on disk.
for model_path in (M1, M2, M3):
    assert model_path.exists(), f"Missing: {model_path}"


In [2]:
import onnx
import onnxruntime as ort

def dump_signature(onnx_path: Path):
    """Print a summary of an ONNX model's interface.

    Prints, in order: the model path, IR version, opset imports and producer;
    the graph inputs and outputs as declared in the ONNX file; the inputs and
    outputs as reported by an ONNX Runtime CPU session; and the op type and
    output names of the last 12 nodes.

    Args:
        onnx_path: Path of the ``.onnx`` file to inspect.
    """

    def describe(value_info):
        # (name, elem_type enum value, shape). A symbolic dim is shown by its
        # name, a dim with value 0 and no name is shown as "?".
        tensor_type = value_info.type.tensor_type
        dims = [
            dim.dim_param if dim.dim_param else (dim.dim_value if dim.dim_value != 0 else "?")
            for dim in tensor_type.shape.dim
        ]
        return value_info.name, tensor_type.elem_type, dims

    def print_node_args(node_args):
        for node_arg in node_args:
            print(" ", node_arg.name, node_arg.shape, node_arg.type)

    print("\n" + "="*90)
    print("MODEL:", onnx_path)

    model_proto = onnx.load(str(onnx_path))

    print("IR version:", model_proto.ir_version)
    print("Opset imports:", [(opset.domain, opset.version) for opset in model_proto.opset_import])
    print("Producer:", model_proto.producer_name, model_proto.producer_version)

    print("\nGraph inputs:")
    for graph_input in model_proto.graph.input:
        print(" ", describe(graph_input))

    print("\nGraph outputs:")
    for graph_output in model_proto.graph.output:
        print(" ", describe(graph_output))

    # ONNX Runtime's view of the same file (can differ when shape info is missing).
    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    print("\nORT inputs:")
    print_node_args(session.get_inputs())
    print("\nORT outputs:")
    print_node_args(session.get_outputs())

    # Tail of the node list: shows which ops produce the final tensors.
    print("\nLast 12 nodes:")
    for tail_node in model_proto.graph.node[-12:]:
        print(f" {tail_node.op_type} -> {list(tail_node.output)}")

# Print the signature of each model, in the order M1, M2, M3.
for model_path in (M1, M2, M3):
    dump_signature(model_path)



MODEL: /home/hschatzle/monte-carlo-selection/data/models/resnet50_geirhos_tl_with_feats.onnx
IR version: 6
Opset imports: [('', 11)]
Producer: pytorch 2.8.0

Graph inputs:
  ('input', 1, ['batch_size', 3, 224, 224])

Graph outputs:
  ('output', 1, ['batch_size', 54])
  ('features', 1, ['batch_size', 2048])

ORT inputs:
  input ['batch_size', 3, 224, 224] tensor(float)

ORT outputs:
  output ['batch_size', 54] tensor(float)
  features ['batch_size', 2048] tensor(float)

Last 12 nodes:
 Relu -> ['/layer4/layer4/1/relu_2/Relu/Relu_output_0']
 Conv -> ['/layer4/layer4/2/conv1/Conv/Conv_output_0']
 Relu -> ['/layer4/layer4/2/relu/Relu/Relu_output_0']
 Conv -> ['/layer4/layer4/2/conv2/Conv/Conv_output_0']
 Relu -> ['/layer4/layer4/2/relu_1/Relu/Relu_output_0']
 Conv -> ['/layer4/layer4/2/conv3/Conv/Conv_output_0']
 Add -> ['/layer4/layer4/2/Add/Add_output_0']
 Relu -> ['/layer4/layer4/2/relu_2/Relu/Relu_output_0']
 GlobalAveragePool -> ['/avgpool/GlobalAveragePool/GlobalAveragePool_output_0

In [3]:
from pathlib import Path
import onnx
from onnx import helper, TensorProto, shape_inference

def canonicalize_outputs(
    onnx_in: Path,
    onnx_out: Path,
    logits_src_tensor: str = "output",      # in your graphs, Gemm produces 'output'
    features_tensor_name: str = "features", # your identity already uses this in all 3
):
    """Rewrite a model so its graph outputs are exactly ``output`` then ``features``.

    Steps:
      1. If no tensor named ``features_tensor_name`` exists, expose the output
         of the last ``Flatten`` node under that name via an Identity node.
      2. Check that ``logits_src_tensor`` exists in the graph.
      3. If a source tensor is not already called ``output`` / ``features``,
         add an Identity node that gives it the canonical name.
      4. Run shape inference on the *modified* graph, replace the graph
         outputs with the canonical pair (type and shape copied from the
         inferred info, symbolic dims included), validate and save.

    Nodes that only fed previously declared outputs (for example an extra
    Identity producing ``logits``) are left in the graph as dead nodes.

    Args:
        onnx_in: Path of the model to read.
        onnx_out: Path the rewritten model is written to.
        logits_src_tensor: Name of the tensor holding the logits.
        features_tensor_name: Name of the tensor holding the features; created
            from the last Flatten output if it does not exist.

    Returns:
        ``onnx_out``.

    Raises:
        RuntimeError: If features have to be derived but there is no Flatten
            node, if ``logits_src_tensor`` does not exist, or if a canonical
            name is already used by a different tensor.
    """
    LOGITS_OUT = "output"
    FEATS_OUT = "features"

    m = onnx.load(str(onnx_in))

    # Every tensor name that can legally be referenced in the graph.
    available = {name for node in m.graph.node for name in node.output}
    available.update(vi.name for vi in m.graph.input)
    available.update(init.name for init in m.graph.initializer)

    # 1) Ensure the features tensor exists; fall back to the last Flatten output.
    if features_tensor_name not in available:
        flatten_nodes = [n for n in m.graph.node if n.op_type == "Flatten"]
        if not flatten_nodes:
            raise RuntimeError("No Flatten node found to derive features from.")
        feats_src = flatten_nodes[-1].output[0]
        m.graph.node.append(
            helper.make_node("Identity", inputs=[feats_src], outputs=[features_tensor_name], name="ExposeFeatures")
        )
        available.add(features_tensor_name)

    # 2) The logits tensor must exist; declaring a missing tensor as a graph
    #    output would silently produce an invalid model.
    if logits_src_tensor not in available:
        raise RuntimeError(f"Logits tensor {logits_src_tensor!r} not found in {onnx_in}")

    # 3) Give both tensors their canonical names. With the default arguments
    #    the names already match and no node is added.
    for src_name, canon_name, node_name in (
        (logits_src_tensor, LOGITS_OUT, "CanonLogits"),
        (features_tensor_name, FEATS_OUT, "CanonFeatures"),
    ):
        if src_name == canon_name:
            continue
        if canon_name in available:
            raise RuntimeError(
                f"Cannot expose {src_name!r} as {canon_name!r}: that name is already used by another tensor."
            )
        m.graph.node.append(
            helper.make_node("Identity", inputs=[src_name], outputs=[canon_name], name=node_name)
        )
        available.add(canon_name)

    # 4) Infer shapes only now, so that tensors created above are covered too.
    m_inf = shape_inference.infer_shapes(m)

    def find_vi(name: str):
        for collection in (m_inf.graph.value_info, m_inf.graph.input, m_inf.graph.output):
            for vi in collection:
                if vi.name == name:
                    return vi
        return None

    def make_vi(src_name: str, out_name: str):
        # Prefer info inferred for the canonical tensor, then for its source.
        vi = find_vi(out_name)
        if vi is None:
            vi = find_vi(src_name)
        if vi is None:
            # Nothing inferred: unknown shape.
            # NOTE(review): FLOAT is an assumption about the model's dtype - confirm.
            return helper.make_tensor_value_info(out_name, TensorProto.FLOAT, None)
        # Copy type/shape verbatim (keeps symbolic dims such as a named batch
        # dim, and keeps "shape unknown" distinct from "scalar"); rename only.
        out_vi = onnx.ValueInfoProto()
        out_vi.CopyFrom(vi)
        out_vi.name = out_name
        out_vi.ClearField("doc_string")
        return out_vi

    canonical_outputs = [
        make_vi(logits_src_tensor, LOGITS_OUT),
        make_vi(features_tensor_name, FEATS_OUT),
    ]

    # Replace graph outputs with canonical order and names.
    m.graph.ClearField("output")
    m.graph.output.extend(canonical_outputs)

    # Validate before writing so a broken graph never reaches disk.
    onnx.checker.check_model(m)

    onnx.save(m, str(onnx_out))
    return onnx_out

# Canonicalise each "*_with_feats" model and write it next to the original as
# "<stem without _with_feats>_CANON.onnx".
paths = [
    Path("/home/hschatzle/monte-carlo-selection/data/models/resnet50_geirhos_tl_with_feats.onnx"),
    Path("/home/hschatzle/monte-carlo-selection/data/models/resnet50_jangtong_tl_with_feats.onnx"),
    Path("/home/hschatzle/monte-carlo-selection/data/models/resnet50_tl_20250829_with_feats.onnx"),
]

for src_path in paths:
    canon_stem = src_path.stem.replace("_with_feats", "")
    canon_path = src_path.with_name(canon_stem + "_CANON.onnx")
    canonicalize_outputs(src_path, canon_path)
    print("Wrote:", canon_path)


Wrote: /home/hschatzle/monte-carlo-selection/data/models/resnet50_geirhos_tl_CANON.onnx
Wrote: /home/hschatzle/monte-carlo-selection/data/models/resnet50_jangtong_tl_CANON.onnx
Wrote: /home/hschatzle/monte-carlo-selection/data/models/resnet50_tl_20250829_CANON.onnx
