- Tracing problem overview
- Definitions
- Generalisation of the tracing problem
- Catalyst implementation details
- Caveats
Consider the following Python program:
@qjit
def circuit(sz:int):
a0 = jnp.ones([sz], dtype=float)
@for_loop(0, 10, 1)
def loop(i, a):
return a*sz
a2 = loop(a0)
return a0 + a2
qjit
decorator at the top means that we are going to compile this program rather than interpret
it directly. In order to do it, Catalyst performs a series of transformations:
- Trace Python program in order to obtain Jaxpr program
- Lower Jaxpr program fither into the StableHLO MLIR dialect
- Apply a series of MLIR passes in order to lower the StableHLO into the LLVM MLIR dialect
- Emit the LLVM code and compile it into the machine's native binary, rendered as a shared library
This document explains the Tracing step of the workflow, with emphasis on the Jax dynamic API support.
Since we specified argument type in our program, namely sz:int
, Catalyst has compiled the program
already. To revisit tracing results, lets print an equivalent in the Jaxpr language, the main IR
language of the Jax library:
print(circuit.jaxpr)
{ lambda ; a:i64[]. let
b:f64[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 a
_:i64[] c:f64[a] = for_loop[
apply_reverse_transform=False
body_jaxpr={ lambda ; d:i64[] e:i64[] f:i64[] g:f64[e]. let
h:f64[] = convert_element_type[new_dtype=float64 weak_type=False] d
i:f64[e] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] h e
j:f64[e] = mul g i
in (e, j) }
body_nconsts=1
nimplicit=1
preserve_dimensions=True
] a a 0 10 1 0 b
k:f64[a] = add b c
in (k,) }
The following points are important to note:
- Jaxpr program has the following structure:
where operation attributes might contain nested Jaxpr programs.
{ lambda CONSTS ; ARGUMENTS . let A0 A1 ... An = opA[ ATTRIBUTES ] OPERANDS ... Z0 Z1 ... Zn = opZ[ ATTRIBUTES ] OPERANDS in Z }
- All Jaxpr variables have types.
x:f64[3,2]
means thatx
is a 2D tensor with dimensions 3 and 2 consisting of 64-bit floating point elements.a:i64[]
means thata
is a scalar integer tensor. - With dynamic API flag set to True (Catalyst sets it by default) Jaxpr types are allowed to
contain variables.
f:64[d]
means that the single dimension off
is not known at compile time. What is known is that at runtime the actual dimension will be available as variabled
. - Loop body of the source Python program has two arguments, while the
body_jaxpr
of the resulting Jaxpr program has four arguments. The additional arguments appeared due to different reasons:d:i64[]
: Usage of the outer-scope variable in the body loop in the source Python program. Jaxpr program does not allow capturing, so we have to pass captured variables as additional arguments.e:i64[]
: Requirement saying that Jaxpr variable must be declared before use. Since we usee
in the type ofg:f64[e]
, we must pass it as an additional argument before this use.
- Loop argument
b:f64[a]
and the loop resultc:f64[a]
have the same type. Jax takes special care of propagating types across operations. Jax binary operators likeadd
,mull
require operand types to be the same. Since the types ofb
andc
are indeed the same, the very last addition operation is possible. - In contrast to the regular Python evaluation, loop body is evaluated only once during the tracing. This is because we only want to record the execution path rather then perform the real computation.
In this section we define concepts required to describe the tracing algorithm.
Jax is a library which turns interpreters into compilers by means of tracing. Jax supports tracing of programs in two “source” languages:
- Python: regular Python interpreter with a custom Numpy API implementation are used for tracing.
- Jaxpr: the IR language with the interpreter implemented in Python.
Tracers are objects which track arguments of the traced program. By means of tracers, Jax transforms interpretation of a program into compilation.
Jax tracers (source) are Python classes having the following properties:
- Tracers are in 1-to-1 correspondence with Python and Jaxpr variables. If the dynamic api is enabled, tracers might contain other tracers, so referring a variable by name would mean referring to more than one tracer.
- Contain
AbstractValue
objects (typically, theDShapedArrays
) in theiraval
field. - Jax tracks unique identifiers of tracers in order to distinguish them from each other.
- Tracers typically belong to a
Trace
object (source) representing variable scope. Creating nested scopes requires users to also create tracers.
Examples:
Tracer(id=1111, aval=ShapedArray(shape=[3,4], dtype=int))
Tracer(id=2222, aval=DShapedArray(shape=[ Tracer(id=... aval=...) ], dtype=float))
AbstractValue is a Python class describing shape
(a list of dimensions) and dtype
(other
features omitted). Jax comes with two notable implementations:
- ShapedArray
(source)
is the implementation supporting constant dimensions.
- Example:
ShapedArray(shape=[3,4],dtype=int)
- Example:
- DShapedArray (source) is an implementation where, in addition to constants, variable dimensions are allowed in shapes. Dynamic dimensions are represented by other scalar tracers which become valid shape values. This implementation becomes available if the Dynamic API flag of JAX is enabled.
DShapedArray
(source)
are abstract values describing tracers whose dimensions might be unknown at the compile time. Like
for ShapedArray
, the main fields are shape
and dtype
. Unlike ShapedArrays, this class allows
more freedom in the shape
contents. Possible elements are:
- Numbers representing static dimensions.
- If used in the
aval
field of a tracer, other tracers are allowed in shapes. Nested tracers must be scalar (of unit shape) andint
dtype. - If used in a type signature (see below), de Brujin indices InDBIdx(val), OutDBIdx(val) and DBIdx(val) are allowed as values and have special meaning described below.
Examples:
DShapedArray(shape=[3,4],dtype=int)
- DShapedArrays are mostly backward compatible with ShapedArraysDShapedArray(shape=[Tracer(id=23232, aval=ShapedArray(shape=(),dtype=int)),4,1],dtype=float)
- shape might contain scalar integer tracers if this object is an abstract value of a tracer.DShapedArray(shape=[InDBIdx(val=0),InDBIdx(val=1)],dtype=float)
- shape might contain de Bruijn indices if this object is contained in a type signature.
In Jax, in_type/out_type
objects are list-like tuples of abstract values paired with Booleans.
Types are mainly used to transfer the information about tracers between scopes. Types are typically
deduced from the in the source scope and then interpreted in a target scope. The results of this
interpretation are new tracers living in the target scope.
- The MyPy type of types is
Tuple[Tuple[AbstractValue,bool],...]
link - The
bool
tuple elements represent “explicitness” of argument in a source Python program. Explicit arguments explicitly appear in user-defined Python functions. Implicit arguments, in contrast, are added by Jax to match Jaxpr typing requirements.
The important properties of types are the following one. Abstract values found in types are allowed
to include indices (called in Jax "de Brjuin indices" for some reason). The indices are described by
objects of DBIdx
, InDBIdx
and OutDBIdx
classes, all having their purpose.
Example:
((DShapedArray(shape=(), dtype=int), False),
(DShapedArray(shape=(), dtype=int), False),
(DShapedArray(shape=[OutDBIdx(0),OutDBIdx(1),InDBIdx(0)], dtype=float), True)
)
The above tuple may represent an output type of a Jaxpr program returning a 3D tensor along with the two of its three dynamic dimensions. The last dimension is assumed to be taken from the first input argument of the current program.
DBIdx(val)
are input type references. They are allowed in the shape values of DShapedArray
objects, found in the in_type signature of Python/Jaxpr programs. The integer value of a
reference are interpreted as an index in the current in_type
signature tuple.
Input type indices are:
- Produced when calculating the input type of a nested program source
- Consumed when creating argument tracers of a nested program source
InDBIdx(val)/OutDBIdx(val)
are output type references. references allowed in the shape values of DShapedArray
objects found in the out_type signatures of Python/Jaxpr programs.
- InDBIdx(val) refers to the position in the
in_type
signature tuple of the Jaxpr/Python program. - OutDBIdx(val) refers to the position in the
out_type
signature tuple of the Jaxpr/Python program.
Output type indices are:
- Produced when calculating the output type of a nested program source
- Consumed when creating output tracers in the scope of an enclosing program source
Binding, in Jax terms, is the process where a primitive applies itself to the interpreter stack. The binding includes:
- Calculation of the input type from the primitive's input arguments.
- Creation of the inner tracing scope and the corresponding tracers.
- Calculation of the output type based on the results of the nested program.
- Creation of the resulting tracers in the caller's tracing context.
Separating explicit and implicit arguments makes sense when we trace Python program. Explicit arguments/results are those which were explicitly mentioned in the source Python program. Implicit arguments are those that have to be added in order to fit into Jaxpr typing requirements.
For example, in the following Python program:
def f(sz):
o = jnp.ones((sz+1,), dtype=float)
return o
We map the Python tracer o
to the two Jaxpr variables b:i64[], c:f[b]
to get the following
equivalent Jaxpr program.
{ lambda ; a:i64[]. let
b:i64[] = add a 1
c:f64[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 b
in (b, c) }
The first variable b:i64[]
is the implicit result, while the second one is the explicit one.
Python and Jaxpr programs represent function arguments and results in a different level of details. In Python tensors are objects holding all the required information about there shapes. In Jaxpr, in contrast, shape dimension variables must be passes explicitly as additional arguments. In order to encode the program transformation, we attribute argument lists as collapsed or expanded, depending on whether implicit arguments were added or not.
In Jax the expansion is performed simply by adding together the tuple of deduced implicit values with another tuple of known explicit values.
We name lists or tuples as expanded_
if the implicit arguments are known to be already prepended
to them.
In this section we attempt to generalise the tracing problem. Consider the following schematic Python program:
def nested_function(*ARGS):
... # Calculate RESULTS from ARGS
return RESULTS
INPUTS = ... # Obtain INPUTS somehow
OUTPUTS = bind(nested_function, *INPUTS)
Notes:
- In this example,
nested_function
plays the role offor_loop
body,cond
branch oradjoint
region. bind
represents the binding API function, e.g.for_loop
.
We want to transform this program into the following Jaxpr program (also schematic):
{ lambda ; INPUTS . let
OUTPUTS = bind[
nested_function = { lambda ; ARGS . let
... // calculate RESULTS
in RESULTS };
] INPUTS;
in OUTPUTS }
In order to do so, Jax evaluates the source Python program by passing tracers objects as INPUTS. All operations applied to the tracers, including the nested function call, are recorded into the internal Jax equation list, which is then used to print the final Jaxpr program.
In the above example, bind is the most important operation, joining the outer and inner scope tracing into a single recursive tracing algorithm.
Below we describe one step of this algorithm:
-
$bind(Function, Inputs, S) -> Outputs_s$ , where:$(ExpandedInputs_s, InputType_s) \gets expandArgs(Inputs, strategy = S)$ -
$OutputType_s \gets AbstractEvaluation(InputType_s)$ where$AbstractEvaluation$ is defined as follows:$ExpandedArguments_s \gets initialize(InputType_s)$ $Arguments \gets collapse(ExpandedArguments_s)$ $Results \gets traceNested(Function, Arguments)$ $OutputType_s \gets expandResults(ExpandedArguments_s, Results)$ $return(OutputType_s)$
$ExpandedOutputs_s \gets initialize(OutputType_s, ExpandedInputs_s)$ $Outputs_s \gets collapse(ExpandedOutputs_s)$ $return(Outputs_s)$
The variations of this algorithm are implemented in every Catalyst binding function such as
for_loop
, while_loop
, cond
, etc. The common properties, however, are the same.
-
$Inputs$ are input tracers obtained from the user. -
$Outputs_s$ represents output tracers of a Python program obtained using the expansion strategyS
. Any set of tracers might be converted to a Jaxpr program at any time using the core Jax IR printer functionto_jaxpr
. Thus, having output tracers is equivalent to having the Jaxpr program. The expansion strategy captures specifics of the particular binding function. -
$expandArgs()$ determines the implicit parameters using the specified expansion strategy$S$ and calculates the input type signature. Source -
$expandResults()$ calculates implicit output variables and obtains the final output type signature. Source -
$initialize()$ reads the input type information and creates the required tracers in the inner tracing context. Note that the function interprets de Brjuin indices which might exist in inputs. Source (inputs) Source (outputs) -
$traceNested()$ runs the next recursion step of the tracing. It takes collapsed (not-expanded) list of input tracers and calculates the list of output tracers.
As illustrated in the overview, in order to transform Python program into a Jaxpr program, arguments and results of functions needs to be adjusted in the following ways:
- Dynamic dimensions must be added as implicit arguments to ensure Jaxpr program types are correct.
- Additional "constant" parameters must be added to hoist numeric constants and capture outer-scope variables.
- We start from the state where the explicit arguments/results mentioned in the source Python program are known.
- The expansion algorithm scans the explicit dimensions and prepends variables to the list of
explicit arguments.
- In the basic case, the variables found in the dimensions are added as-is.
- In case when several tensors use same dimension variables, different decisions are possible. In
Catalyst, we support the following two cases:
- (Default) Add only one implicit argument for shared dimension variable. For example:
a:f64[d], b:f64[d]
becomesd:i64[], a:f64[d], b:f64[d]
- "Forget" about the sharing and add new variable for every dimension variable separately.
For example:
a:f64[d], b:f64[d]
becomesd1:i64[], d2:i64[], a:f64[d1], b:f64[d2]
This mode is enabled ifexperimental_preserve_dimensions
parameter is set toFalse
.
- (Default) Add only one implicit argument for shared dimension variable. For example:
- Produce types (
in_type
for arguments,out_type
for results), describing the expansion results.- For arguments, we use
DBIdx
in types to refer to position in the same list. For example:- Arguments:
d:i64[], a:f64[d], b:f64[d]
- Input type:
[(i64, True), (f64[DBIdx(0)], False), (f64[DBIdx(0)], False)]
- Arguments:
- For results, we use
InDBIdx
andOutDBIdx
in type to refer to positions in the argument list and the result list (the current one) correspondingly. For example:- Arguments:
v:i64[], d:i64[], a:f64[d], b:f64[d]
- Results:
e:i64[], a:f64[e], b:f64[d]
- Output type:
[(i64, True), (f64[OutDBIdx(0)], False), (f64[InDBIdx(1)], False)]
- Arguments:
- For arguments, we use
In Catalyst, we usually record the number of implicit variables added using
num_implicit_inputs
/num_implicit_outputs
attributes.
Loop primitives have notable additional requirements. In order to lower loop bodies, types and numbers of loop body arguments must match types and numbers of the loop body results.
In the presence of dynamic dimensions, Jax needs determine which dimensions are going to change during the loop iterations and which one remain the same. Unfortunately, in a single-pass tracer it is hard to communicate this information to the compiler. We only see iteration-0 arguments and results and in general we can not extrapolate this information to later iterations.
We developed the following compromise in order to handle this situation:
-
By default, we assume that loop results will keep the same dimension sharing pattern as loop arguments. For example:
@for_loop(0, 10, 1) def loop(i, a, a_): return a, a_ loop(a0,a0) # CORRECT: one shared dimension in both inputs and outputs
@for_loop(0, 10, 1) def loop(i, a, a_): b = jnp.ones([sz+1], dtype=float) return b, b loop(a0,a0) # CORRECT: still one shared dimension
@for_loop(0, 10, 1) def loop(i, a, a_): b = jnp.ones([sz+1], dtype=float) return a, b # ERROR: dimensions are not the same any more loop(a0,a0)
-
With
experimental_preserve_dimensions=False
flag, we treat every same dimension as a 0-iteration conincidense. We create separate dimension during the argument/result expansion.@for_loop(0, 10, 1, experimental_preserve_dimensions=False) def loop(i, a, b, b_): return a, b, b_ # CORRECT # BUT `b + b_` is not possible, because `b` and `b_` now has different dimensions b0 = jnp.ones([sz+1], dtype=float) a2, b2, b2_ = loop(a0, b0, b0)
A special for for-loops: loop index variable could not be referred using InDBIdx
index in output
types. For example, in the following program
@for_loop(0, 10, 1)
def loop(i, a):
b = jnp.ones([i], dtype=float)
return b
Output type will contain OuDBIdx
in the dimension of b
.
Deduction of Jax constants happens during the final step of the tracing - at the same time with the Jaxpr program generation. Constants are prepended to argument lists. Thus, the program with arguments
d:i64[], a:f64[d], b:f64[d]
that captures a dinamically-dimensioned tensor o:f64[od]
from the outside scope might get the
following final list of arguments:
od:i64, o:f64[od], d:i64[], a:f64[d], b:f64[d]
In Catalyst, we usually record the number of constants added using body_nconsts
attribute. This
information is used during the StableHLO lowering.
-
Dimension variables obtained from constants never matches dimension variables from regular parameters. Thus, the following program will raise an error:
@qjit def circuit(sz:int): a0 = jnp.ones([sz], dtype=float) @for_loop(0, 10, 1) def loop(i, a): return a + a0 # a0 is a constant, dimension variable is different return loop(a0)