Commit 97eb776

Shape inference doc small update: monomorphism, new type defs
1 parent 7674214 commit 97eb776

1 file changed: +35 -30 lines changed

lib/shape_inference.md

@@ -6,11 +6,13 @@ Shape inference broadly speaking consists in OCANNL of inferring the `Shape.t` r
The bulk of the projections inference happens alongside shape inference, with the projections-relevant information stored in auxiliary fields -- this prevents subtle bugs where projection semantics deviates from shape semantics, and will simplify adding new shape/projection inference features. Shape inference happens during `propagate_shapes` calls, and then again in a `finish_inference` call, which is triggered whenever the dimensions or projections are required (typically by jitting). Finally, the projections are reconstructed in `derive_projections`. It would seem that `derive_projections` could reuse the already-computed constraint solutions, but we face a problem: we must prevent contaminating projections across different operations. To illustrate: suppose we conclude that the dimensions of two axes are the same because the axes are reduced together in another operation -- this should not force the axes to share a projection in the operation being processed. To prevent such contamination, in each `derive_projections` call we freshen the projection ids in the (inferred) shapes, and regenerate and re-solve the constraints with the fresh projection ids.
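To make the freshening step concrete, here is a minimal self-contained sketch (our illustration; the names and the integer-id representation are assumptions, not OCANNL's actual definitions):

```ocaml
(* Sketch (ours): freshen projection ids for one operation. Equal old ids map
   to equal fresh ids, but no id is ever shared across operations. *)
let fresh_proj_id =
  let counter = ref 0 in
  fun () ->
    incr counter;
    !counter

let freshen_proj_ids (ids : int list) : int list =
  let mapping = Hashtbl.create 8 in
  List.map
    (fun id ->
      match Hashtbl.find_opt mapping id with
      | Some fresh -> fresh
      | None ->
          let fresh = fresh_proj_id () in
          Hashtbl.add mapping id fresh;
          fresh)
    ids
```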

+ The shape system in OCANNL is currently monomorphic: both row and dimension variables are interpreted existentially. It could in principle be made polymorphic: by abstracting over the remaining fresh variables when forming a tensor-producing function, and by replacing the universally bound variables with fresh variables when applying such a function. However, this is non-trivial and would depend on introducing namespaces for tensor nodes. Then, we could perform "abstract interpretation" (a.k.a. tracing, as e.g. in JAX) by computing an OCaml function under an abstract tensor node namespace. Applying the function would not execute the OCaml code again, but would instead copy the tensors generated by the "abstract interpretation"-stage execution, with appropriately freshened shape variables, into a concrete tensor node namespace. There is a natural context for introducing such abstraction: the special `~config` labeled functions as processed by the `%op` syntax extension -- see [syntax extensions](./syntax_extensions.md). Exploring this is left as potential future work (no earlier than OCANNL v2).
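For intuition (our analogy, not from the OCANNL sources), the existential reading of shape variables behaves much like OCaml's weak type variables:

```ocaml
(* Our analogy: a shape variable is like the weak type variable below. *)
let r = ref []      (* r : '_weak1 list ref -- not generalized *)
let () = r := [ 1 ] (* the first use solves it: r : int list ref *)

(* A later use at a different type, e.g. [r := [ "a" ]], is a type error.
   Similarly, a tensor's shape variables are solved once and stay fixed;
   shape polymorphism would instead freshen them at each application. *)
```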

## Representing shapes and constraints

- A tensor shape in OCANNL is composed of three rows of axes: batch, input and output. These are ordered input-last (`batch @ output @ input`) in the underlying n-dimensional array implementation of tensors. A (fully inferred) tensor shape must have non-empty output axes; we do not use the convention where empty axes mean the tensor is a scalar -- scalars are 1-D output-only tensors. For printing and einsum-notation-like specifications, we use the syntax `batch|input->output` (or `input->output`, `batch|output`, `output`), where `batch`, `input`, `output` are whitespace-, comma-, or parenthesis-separated axis entries; if no separators are used, each character is its own axis entry (unless the specification is digits only).
+ A tensor shape in OCANNL is composed of three rows of axes: batch, input and output. These are ordered input-last (`batch @ output @ input`) in the underlying n-dimensional array implementation of tensors (at least when hosted, as backends can reorder axes via a stride mechanism; NOTE: NOT IMPLEMENTED YET). A (fully inferred) tensor shape must have non-empty output axes; we do not use the convention where empty axes mean the tensor is a scalar -- scalars are 1-D output-only tensors. For printing and einsum-notation-like specifications, we use the syntax `batch|input->output` (or `input->output`, `batch|output`, `output`), where `batch`, `input`, `output` are whitespace-, comma-, or parenthesis-separated axis entries; if no separators are used, each character is its own axis entry (unless the specification is digits only).
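A few example specifications under the conventions just described (our illustrations):

```ocaml
(* Example spec strings (ours) in the [batch|input->output] syntax. *)
let _s1 = "b|i->o"    (* no separators: each character is one axis entry *)
let _s2 = "ab|cd->ef" (* two batch, two input and two output axes *)
let _s3 = "batch | in1, in2 -> out" (* separators present: multi-character entries *)
let _s4 = "o"         (* output-only: e.g. a scalar as a 1-D output-only tensor *)
```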

- A row is a sequence of axes of a single kind: batch, input, or output. The shape type incorporates information relevant to inference, in particular shape variables: both for individual axes (`dim` variables), and for extending a row with more axes (`row` variables). Currently, all rows are (independently) broadcastable: they can be broadcast to a larger number of axes. However, in OCANNL the broadcasting can happen "in the middle", with not only the given trailing axes fixed, but also with the given leading axes fixed.
+ A row is a sequence of axes of a single kind: batch, input, or output. The shape type incorporates information relevant to inference, in particular shape variables: both for individual axes (`dim` variables), and for extending a row with more axes (`row` variables). Currently, all rows are (independently) broadcastable: they can be broadcast to a larger number of axes. However, in OCANNL the broadcasting can happen "in the middle", with not only the given trailing axes fixed but also the given leading axes fixed. (TODO: clarify here the precise logic as it is implemented; I'm not sure this description is correct.)

```ocaml
type solved_dim = { d : int; label : string option; proj_id : proj_id option }
@@ -93,40 +95,40 @@ The entry point to shape inference is the shape logic specification, that each o
type deduce_within_shape = Not_constrained | Input_equals_output

type compose_type =
-  | Pointwise_bin (** NumPy-style broadcast matching batch, input and output axes, e.g. as in [s1 + s2]. *)
+  | Pointwise_bin
+      (** NumPy-style broadcast matching batch, input and output axes, e.g. as in [s1 + s2]. *)
  | Compose
-      (** Compose the outputs of the second shape with the inputs of the first shape, i.e. the shape of
-          [fun x -> s1(s2(x))], or [s1 * s2] where [*] is the inner product (e.g. matrix multiply). *)
+      (** Compose the outputs of the second shape with the inputs of the first shape, i.e. the shape
+          of [fun x -> s1(s2(x))], or [s1 * s2] where [*] is the inner product (e.g. matrix
+          multiply). *)
  | Einsum of string
-      (** The [einsum] syntax: LABELS1;LABELS2=>LABELS3, where LABELSi are labels specifications.
-          Since OCANNL's extended einsum notation supports both axis variables and row variables, it makes
-          other compose types redundant.
-          The [axis_labels] use pseudo-labels local to the notation, to line up the axes.
-          For [Einsum (ls1^";"^ls2^"=>"^ls3)], the symmetric difference / disjunctive union of [ls1] and [ls2]'s
-          pseudo-labels should be equal to [ls3] pseudo-labels.
+      (** The binary "einsum" syntax: RHS1;RHS2=>LHS, where RHSi, LHS are labels specifications.
+          Since OCANNL's extended einsum notation supports both axis variables and row variables, it
+          makes other compose types redundant. The [axis_labels] use pseudo-labels local to the
+          notation, to line up the axes and row variables. The symmetric difference / disjunctive
+          union of RHS1's and RHS2's pseudo-labels should be equal to the LHS pseudo-labels.

-          Note: The "right-hand side" is on the left! I.e. the syntax is "rhs=>lhs", "rhs1;rhs2=>lhs". *)
+          Note: The "right-hand side" is on the left! I.e. the syntax is "rhs=>lhs",
+          "rhs1;rhs2=>lhs". *)

type transpose_type =
  | Transpose (** Swaps inputs and outputs of a shape, preserves batch axes. *)
  | Pointwise_un (** Preserves the shape. *)
-  | Permute of string
-      (** [Permute (ls1^"=>"^ls2)] is a variant of the [einsum] syntax [Einsum (ls1^";"^ls1^"=>"^ls2)].
-          Note: The "right-hand side" is on the left! I.e. the syntax is "rhs=>lhs", "rhs1;rhs2=>lhs". *)
+  | Permute of string (** The unary "einsum" syntax: RHS1=>LHS. *)
  | Batch_slice of Ir.Indexing.static_symbol (** Removes the leftmost batch axis. *)
-
-type logic =
-  | Broadcast of compose_type * shape * shape
-      (** Matches the shapes for a binary operation.
-
-          For [Broadcast (Einsum (ls1, ls2, ls3), s1, s2)], the labels of [s1] and [s2] must match
-          according to the [ls1], [ls2] lineup, and the resulting shape inherits the labels
-          according to the [ls3] lineup. *)
-  | Transpose of transpose_type * shape
-      (** Permutes the axes of a shape. One case of [Transpose] is to swap inputs with outputs of
-          [s1], hence the name. *)
-  | Terminal of [ `Data of Ir.Assignments.init_data | `Fetch of Ir.Assignments.fetch_op ]
-      (** Extracts any available shape information from the initialization, e.g. the number of elements. *)
+  | Uint4x32_to_prec of Ir.Ops.prec Lazy.t
+      (** Converts precision in a bit-efficient way, with a corresponding conversion in the total
+          number of elements. Currently assumes the incoming tensor (RHS) has just a single axis, to
+          not force unnecessary minimum sizes on output axes. *)
+
+(** If you miss expressivity here, leave a note on
+    {{:https://github.com/ahrefs/ocannl/issues/305}issue 305}. *)
+type ternary_type =
+  | Pointwise_tern (** As in the operation [Where]. *)
+  | Compose_accumulate (** As in the operation [FMA]. *)
+
+(** Extracts any available shape information from the initialization or fetch. *)
+type terminal_type = Data of Ir.Assignments.init_data | Fetch of Ir.Assignments.fetch_op
```
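As a worked illustration of the symmetric-difference rule for the einsum specs (our example strings, not taken from the OCANNL sources; the operators that consume such specs are described in [syntax extensions](./syntax_extensions.md)):

```ocaml
(* Worked examples (ours) of the RHS1;RHS2=>LHS and RHS=>LHS spec strings. *)

(* Matrix multiply: [k] occurs in both RHS specs, so it is contracted over;
   the symmetric difference of {i;k} and {k;j} is {i;j}, matching the LHS. *)
let _matmul_spec = "ik;kj=>ij"

(* Outer product: the RHS pseudo-label sets are disjoint, so their symmetric
   difference is the union {i;j}, and nothing is contracted. *)
let _outer_spec = "i;j=>ij"

(* Unary "einsum" ([Permute]): swap the two axes. *)
let _transpose_spec = "ij=>ji"
```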

### Non-tensor-like constraints
@@ -136,10 +138,13 @@ The above mechanisms (excluding `dim_constraint` and `row_constraint`) are suffi
```ocaml
type dim_constraint = Unconstrained_dim | At_least_dim of int

+type total_elems = Num_elems of int | Delayed of { coeff : int Lazy.t; var : dim_var }
+
type row_constraint =
  | Unconstrained
-  | Total_elems of { nominator : int; divided_by : dim_var list }
-      (** The row or remainder of a row, inclusive of the further row spec, has this many elements. *)
+  | Total_elems of { nominator : total_elems; divided_by : dim_var_set }
+      (** The rows, inclusive of the further row spec, have this many elements. *)
+  | Exact of dim list (** The concatenated rows have these axes. *)
```
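One plausible reading of `Total_elems`, written as a check that could run once all dimensions are solved (a sketch under our assumptions: `dim_var_set` approximated by a list, `lookup` resolving a variable to its solved dimension; the actual solver logic may differ):

```ocaml
(* Sketch (ours) of discharging [Total_elems] after all dims are solved.
   [dim_var], [lookup] and the list-based [divided_by] are assumptions. *)
type dim_var = string
type total_elems = Num_elems of int | Delayed of { coeff : int Lazy.t; var : dim_var }

let check_total_elems ~(lookup : dim_var -> int) ~(row_elems : int)
    ~(nominator : total_elems) ~(divided_by : dim_var list) : bool =
  let target =
    match nominator with
    | Num_elems n -> n
    | Delayed { coeff; var } -> Lazy.force coeff / lookup var
  in
  (* Reading: the rows hold [target] elements divided by the product of the
     [divided_by] dimensions. *)
  let divisor = List.fold_left (fun acc v -> acc * lookup v) 1 divided_by in
  row_elems = target / divisor
```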

During the solution process, the constraints are incorporated, or propagated, into the environment's `constr` entry fields and into further `constraint_` constraints, as needed. This provides sufficient scaffolding to implement other complex constraints as the need arises.
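Schematically (our sketch, not OCANNL's actual solver loop), one propagation pass can be pictured as folding a step function over the pending constraints:

```ocaml
(* Schematic single pass (ours): each constraint is either fully absorbed into
   the environment or leaves residual constraints for a later pass. *)
type 'c outcome = Solved | Residual of 'c list

let propagate_pass ~(step : env:'e -> 'c -> 'e * 'c outcome) (env : 'e)
    (constraints : 'c list) : 'e * 'c list =
  List.fold_left
    (fun (env, pending) c ->
      match step ~env c with
      | env', Solved -> (env', pending)
      | env', Residual residual -> (env', residual @ pending))
    (env, []) constraints

(* Such passes would be iterated -- cf. [propagate_shapes] and
   [finish_inference] -- until no constraint makes further progress. *)
```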
