In [1]:
import onnx

# Path to the ONNX model file (a bare filename, so it is resolved against the
# notebook's current working directory)
model_path = 'simple_mnist.onnx'

# Deserialize the file into an in-memory ModelProto
model = onnx.load(model_path)

# Actually validate the model instead of only announcing that it loaded:
# check_model raises onnx.checker.ValidationError for a malformed model, so
# reaching the prints below means the model passed the checker.
onnx.checker.check_model(model)
print("Model is loaded")
print("Model type:", type(model))

# Human-readable dump of the whole graph (inputs, initializers, nodes, return)
print(onnx.helper.printable_graph(model.graph))

# Names of the graph-level inputs and outputs.
# The comprehension variable avoids reusing the builtin name `input`.
print("Inputs:", [value_info.name for value_info in model.graph.input])
print("Outputs:", [value_info.name for value_info in model.graph.output])


Model is loaded
Model type: <class 'onnx.onnx_ml_pb2.ModelProto'>
graph main_graph (
  %onnx::Reshape_0[FLOAT, 1x1x28x28]
) initializers (
  %fc.weight[FLOAT, 10x784]
  %fc.bias[FLOAT, 10]
) {
  %/Constant_output_0 = Constant[value = <Tensor>]()
  %/Reshape_output_0 = Reshape[allowzero = 0](%onnx::Reshape_0, %/Constant_output_0)
  %5 = Gemm[alpha = 1, beta = 1, transB = 1](%/Reshape_output_0, %fc.weight, %fc.bias)
  return %5
}
Inputs: ['onnx::Reshape_0']
Outputs: ['5']


In [2]:
# Keep a handle on the graph; the cells below read from it.
graph = model.graph

# Operators, in the order they are stored in the graph, with the tensor names
# each one consumes and produces.
for node_index, node in enumerate(graph.node):
    print(f"Node {node_index}: {node.op_type}")
    for label, tensor_names in (("  Inputs:", node.input), ("  Outputs:", node.output)):
        print(label, tensor_names)

# Initializers: the tensors stored inside the model file.
for initializer_name in (tensor_proto.name for tensor_proto in graph.initializer):
    print("Initializer:", initializer_name)

# Graph-level inputs and outputs, listed by name.
model_input_names = [value_info.name for value_info in graph.input]
model_output_names = [value_info.name for value_info in graph.output]
print("Model Inputs:", model_input_names)
print("Model Outputs:", model_output_names)


Node 0: Constant
  Inputs: []
  Outputs: ['/Constant_output_0']
Node 1: Reshape
  Inputs: ['onnx::Reshape_0', '/Constant_output_0']
  Outputs: ['/Reshape_output_0']
Node 2: Gemm
  Inputs: ['/Reshape_output_0', 'fc.weight', 'fc.bias']
  Outputs: ['5']
Initializer: fc.weight
Initializer: fc.bias
Model Inputs: ['onnx::Reshape_0']
Model Outputs: ['5']


In [3]:
# Walk the first three nodes again, this time including each node's attributes.
for node_index, node in enumerate(graph.node[:3]):  # Adjust range as needed
    print(f"Node {node_index}: {node.op_type}")
    for label, tensor_names in (("  Inputs:", node.input), ("  Outputs:", node.output)):
        print(label, tensor_names)

    # Attributes carry the operator's static configuration.
    for attribute in node.attribute:
        print(f"  Attribute: {attribute.name}")
        if attribute.type != onnx.AttributeProto.TENSOR:
            # Non-tensor attributes are shown using the protobuf's own text form.
            print("  Attribute Value:", attribute)
            continue
        # Tensor-valued attributes (e.g. the Constant node's `value`) are
        # decoded into a NumPy array before printing.
        tensor = onnx.numpy_helper.to_array(attribute.t)
        print("  Tensor Value:", tensor)

Node 0: Constant
  Inputs: []
  Outputs: ['/Constant_output_0']
  Attribute: value
  Tensor Value: [ -1 784]
Node 1: Reshape
  Inputs: ['onnx::Reshape_0', '/Constant_output_0']
  Outputs: ['/Reshape_output_0']
  Attribute: allowzero
  Attribute Value: name: "allowzero"
type: INT
i: 0

Node 2: Gemm
  Inputs: ['/Reshape_output_0', 'fc.weight', 'fc.bias']
  Outputs: ['5']
  Attribute: alpha
  Attribute Value: name: "alpha"
type: FLOAT
f: 1

  Attribute: beta
  Attribute Value: name: "beta"
type: FLOAT
f: 1

  Attribute: transB
  Attribute Value: name: "transB"
type: INT
i: 1

