# Inference

## Introduction

Spox attempts to perform inference on operators immediately as they are constructed in Python.
This includes two main mechanisms: type (and shape) inference, and value propagation.

Both are done on a best-effort basis and are primarily based on the ONNX implementations.
Some type inference and value propagation routines may be _patched_ in the generated opset, meaning they are replaced by a Python implementation within Spox. A patch attempts to follow the standard, but may be imperfect and have bugs (as can the standard ONNX implementations).

For inference to be effective, Spox expects type information to be carried in `Var`s through the entire graph as it is constructed. This enables raising Python exceptions as early as possible when type inference fails, improves debuggability, and reduces the need to specify redundant type information in Python.

The general mechanism is the following: a single standard node is built into a _singleton_ (single-node) model, represented as an `onnx.ModelProto`. This model is then passed to `onnx` routines. Afterwards, the resulting information is extracted and converted back into Spox objects.

In [1]:
import numpy
import spox
import spox.opset.ai.onnx.v17 as op

## Type inference

Type and shape inference is run via `onnx.shape_inference.infer_shapes` on the singleton model. Types are converted to and from the ONNX representation internally. Some operators may have missing or incomplete type inference implementations (especially in ML operators of `ai.onnx.ml`), and may have a patch implemented in Spox.

Patches can currently be found as an `infer_output_types` implementation in the respective Node class.
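Part of what shape inference computes for elementwise operators such as `Add` is a broadcast output shape, following the multidirectional (NumPy-style) broadcasting rule. The following is a minimal pure-Python sketch of that rule for shapes with integer or symbolic (string) dimensions. The function name `broadcast_shapes` is hypothetical, and this is neither Spox's nor ONNX's implementation. As a simplifying assumption, two different symbolic names are treated as a mismatch, whereas a real implementation may instead leave the dimension unknown.

```python
from itertools import zip_longest


def broadcast_shapes(a, b):
    """Broadcast two shapes whose dimensions are ints or symbolic names (str)."""
    result = []
    # Align from the trailing dimension; missing leading dimensions count as 1.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da == 1:
            result.append(db)
        elif db == 1 or da == db:
            result.append(da)
        else:
            # Simplifying assumption: differing sizes or names are an error.
            raise ValueError(f"cannot broadcast {da!r} with {db!r}")
    return tuple(reversed(result))


print(broadcast_shapes(("N",), ()))          # ('N',)
print(broadcast_shapes((1, "N"), ("N", 1)))  # ('N', 'N')
```

The two calls correspond to broadcasting a vector with a scalar, and a row with a column.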

In [2]:
x = spox.argument(spox.Tensor(float, ('N',)))
y = spox.argument(spox.Tensor(float, ()))
z = spox.argument(spox.Tensor(int, ('N', 'M')))

In [3]:
# Broadcasting of (N) and () into (N)
op.add(x, y)

<Var from ai.onnx@14::Add->C of float64[N]>

In [4]:
# Casting element type with a Cast
op.cast(z, to=str)

<Var from ai.onnx@13::Cast->output of str[N][M]>

In [5]:
# Reshape of a matrix into a vector
op.reshape(z, op.constant(value_ints=[-1]))

<Var from ai.onnx@14::Reshape->reshaped of int64[?]>

In [6]:
# Using a broadcast of (1, N) and (N, 1) into (N, N)
op.add(
    op.unsqueeze(x, op.constant(value_ints=[0])),
    op.unsqueeze(x, op.constant(value_ints=[1]))
)

<Var from ai.onnx@14::Add->C of float64[N][N]>

> Spox does not have a facility for type hinting `Var` objects to perform type inference on the level of Python annotations. This is because the ONNX type system, and in particular the typing of tensors, is not expressible in type hints beyond possibly tensor element types.
> This may be reconsidered in the future if libraries like `numpy` start supporting similar functionality.

## Value propagation

Value propagation in Spox is run via the `onnx.reference` module (added in ONNX 1.13), the reference runtime implementation in Python. It replicates the _partial data propagation_ mechanism of type inference in ONNX, which is essentially constant folding.

In Spox, a ``Var`` may have a constant value associated with it. If all input variables of a standard operator have a value, propagation will be attempted by running the singleton model through the reference implementation.
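The "all inputs have a value" rule can be illustrated with a toy model in plain Python. This is only a sketch of the idea: `ToyVar` and `toy_apply` are hypothetical names, and plain Python callables stand in for running a singleton model through the reference implementation.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyVar:
    """Hypothetical stand-in for a Var: a name plus an optional known value."""
    name: str
    value: Optional[Any] = None


def toy_apply(name, fn, *inputs):
    """Apply `fn` eagerly only when every input carries a known value."""
    if all(v.value is not None for v in inputs):
        return ToyVar(name, fn(*(v.value for v in inputs)))
    # At least one input is unknown: the result has no propagated value.
    return ToyVar(name)


def scale(xs, k):
    return [x * k for x in xs]


def shift(xs, b):
    return [x + b for x in xs]


scaled = toy_apply("scaled", scale, ToyVar("c", [1, 2]), ToyVar("k", 2))
shifted = toy_apply("shifted", shift, scaled, ToyVar("b", 1))
unknown = toy_apply("unknown", shift, ToyVar("arg"), ToyVar("b", 1))

print(shifted.value)  # [3, 5]
print(unknown.value)  # None
```

Because each result carries its value forward, chained operations fold step by step, while a single unknown input stops propagation for everything downstream of it.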

The most common instance of value propagation is in the ``Reshape`` operator, where a constant target shape allows determining the resulting shape. If the target shape were not known, even the rank of the output shape could not be determined.
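To make the dependence on a constant target shape concrete, here is a rough pure-Python sketch of how an output shape could be derived for `Reshape`. The function `infer_reshape_shape` is hypothetical and is not the ONNX or Spox implementation. It assumes the default ONNX behaviour where a `0` entry copies the corresponding input dimension, and it uses `None` for an unknown dimension.

```python
from math import prod


def infer_reshape_shape(input_shape, target):
    """Return the output shape, or None when even the rank is unknown."""
    if target is None:
        # The target is not known as a constant: the rank cannot be determined.
        return None
    # A 0 entry copies the corresponding input dimension (assumed default).
    dims = [input_shape[i] if d == 0 else d for i, d in enumerate(target)]
    if -1 in dims:
        i = dims.index(-1)
        rest = dims[:i] + dims[i + 1:]
        if all(isinstance(d, int) for d in (*input_shape, *rest)):
            # All sizes are concrete: the -1 entry takes the remaining elements.
            dims[i] = prod(input_shape) // prod(rest)
        else:
            # Symbolic dimensions leave the -1 entry unresolved.
            dims[i] = None
    return tuple(dims)


print(infer_reshape_shape(("N",), [1, 2, 3]))  # (1, 2, 3)
print(infer_reshape_shape(("N", "M"), [-1]))   # (None,)
print(infer_reshape_shape((2, 3, 4), [0, -1])) # (2, 12)
print(infer_reshape_shape(("N",), None))       # None
```

With a known target the rank always equals the length of the target, even if some dimensions stay unknown; without it, nothing about the output shape can be stated.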

Value propagation can also be thought of as **eager execution** mode within Spox, and is well-suited for experimenting with (standard) operators.

Currently, there isn't a standard way of accessing the propagated value, but it is shown when the `Var` is printed.
Value propagation is rarely patched, since in most cases it is not critical to type inference. Where a patch exists, it is implemented by overriding the `propagate_values` method of the Node class.

In [7]:
# Can't perform value propagation - type inference fails & warns
t = spox.argument(spox.Tensor(int, (None,)))
op.reshape(x, t)

<Var from ai.onnx@14::Reshape->reshaped of float64[...]>

In [8]:
# Trivial reshape example
op.reshape(x, op.constant(value_ints=[1, 2, 3]))

<Var from ai.onnx@14::Reshape->reshaped of float64[1][2][3]>

In [9]:
s = op.add(
    op.mul(op.constant(value_ints=[1, 2]), op.constant(value_int=2)),
    op.constant(value_int=1)
)  # [1, 2] * 2 + 1 = [3, 5]
# Reshape with a basic constant fold
op.reshape(x, s)

<Var from ai.onnx@14::Reshape->reshaped of float64[3][5]>

Constant variable values can also be seen in the string representation.
Currently, there isn't a stable way of accessing them programmatically (the internal field is `_value`, but the representation isn't publicly specified).

In [10]:
def const(xs):
    return op.constant(value=numpy.array(xs))

In [11]:
# Trivial add
op.add(
    const(1),
    const([1, 2, 3])
)

<Var from ai.onnx@14::Add->C of int64[3] = [2 3 4]>

In [12]:
# Reshape
mat = op.reshape(
    const([1., 2., 3., 4.]),
    const([2, 2])
)
mat

<Var from ai.onnx@14::Reshape->reshaped of float64[2][2] = [[1. 2.]
 [3. 4.]]>

In [13]:
# Composing value propagation
op.matmul(mat, mat)

<Var from ai.onnx@13::MatMul->Y of float64[2][2] = [[ 7. 10.]
 [15. 22.]]>

In [14]:
# Unstable! Programmatic access
_value = op.add(
    const(1),
    const([1, 2, 3])
)._value
assert (_value.value == numpy.array([2, 3, 4])).all()
_value

PropValue(type=Tensor(dtype=int64, shape=(3,)), value=array([2, 3, 4]))