The bulk of the projections inference happens alongside shape inference, with the projections-relevant information stored in auxiliary fields -- this prevents subtle bugs where projection semantics deviate from shape semantics, and will simplify adding new shape/projection inference features. Shape inference happens during `propagate_shapes` calls, and then again in a `finish_inference` call, which is triggered whenever the dimensions or projections are required (typically by jitting). Finally, the projections are reconstructed in `derive_projections`. It would seem that `derive_projections` could reuse the already-computed constraint solutions, but we face a problem: we must prevent contaminating projections across different operations. To illustrate: suppose we conclude that the dimensions of two axes are the same because they are reduced together in another operation -- this should not force the axes to share a projection in the operation currently being processed. To prevent the contamination, in each `derive_projections` call we freshen the projection ids in the (inferred) shapes, and regenerate and re-solve the constraints with the fresh projection ids.
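
To make the freshening step concrete, here is a minimal, hypothetical OCaml sketch (not OCANNL's actual code; the names and types are invented for illustration): fresh projection ids are allocated per call, so id sharing is preserved within one `derive_projections`-style call but can never leak across operations.

```ocaml
(* Hypothetical sketch: freshen projection ids so they are shared only within one call. *)
let fresh_proj_id =
  let counter = ref 0 in
  fun () ->
    incr counter;
    !counter

(* [axes] pairs a dimension with a (possibly shared) projection id; rewrite the ids,
   preserving the sharing structure among this call's axes only. *)
let freshen_proj_ids (axes : (int * int) list) : (int * int) list =
  let table = Hashtbl.create 8 in
  List.map
    (fun (dim, old_id) ->
      let new_id =
        match Hashtbl.find_opt table old_id with
        | Some id -> id
        | None ->
            let id = fresh_proj_id () in
            Hashtbl.add table old_id id;
            id
      in
      (dim, new_id))
    axes
```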
The shape system in OCANNL is currently monomorphic: both row and dimension variables are interpreted existentially. It could in principle be made polymorphic: by abstracting over the remaining fresh variables when forming a tensor-producing function, and by replacing the universally bound variables with fresh variables when applying such a function. However, this is non-trivial and would depend on introducing namespaces for tensor nodes. Then, we could perform "abstract interpretation" (also known as tracing, as in e.g. JAX) by computing an OCaml function under an abstract tensor node namespace. Applying the function would not execute the OCaml code again; instead, it would copy the tensors generated by the "abstract interpretation"-stage execution, with appropriately freshened shape variables, into a concrete tensor node namespace. There is a natural context for introducing such abstraction: the special `~config` labeled functions as processed by the `%op` syntax extension -- see [syntax extensions](./syntax_extensions.md). Exploring this is left as potential future work (no earlier than OCANNL v2).
## Representing shapes and constraints
A tensor shape in OCANNL is composed of three rows of axes: batch, input and output. These are ordered input-last (`batch @ output @ input`) in the underlying n-dimensional array implementation of tensors (at least when hosted; backends can reorder axes via a stride mechanism -- NOTE: NOT IMPLEMENTED YET). A (fully inferred) tensor shape must have non-empty output axes; we do not use the convention where empty axes mean the tensor is a scalar -- scalars are 1-D output-only tensors. For printing and einsum-notation-like specifications, we use the syntax `batch|input->output` (or `input->output`, `batch|output`, `output`), where `batch`, `input`, `output` are axis entries separated by whitespace, commas or parentheses; if no separators are used, each individual character is an axis entry (except when the row consists of digits only).
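
As an illustration of the separator rule just described, here is a minimal, hypothetical sketch of splitting such a specification into its three rows of entries; this is not OCANNL's actual parser, and it deliberately glosses over how a digits-only row is interpreted.

```ocaml
(* Hypothetical sketch, not OCANNL's parser: split "batch|input->output" into rows of entries. *)
let split_entries s =
  let s = String.trim s in
  let has_separators =
    String.exists (fun c -> c = ' ' || c = ',' || c = '(' || c = ')') s
  in
  if s = "" then []
  else if has_separators then
    (* Entries are separated by whitespace, commas or parentheses. *)
    String.map (function ' ' | ',' | '(' | ')' -> ' ' | c -> c) s
    |> String.split_on_char ' '
    |> List.filter (fun e -> e <> "")
  else if String.for_all (fun c -> c >= '0' && c <= '9') s then
    [ s ] (* digits-only row: handled specially, not modeled here *)
  else
    (* No separators: each individual character is an axis entry. *)
    List.init (String.length s) (fun i -> String.make 1 s.[i])

let split_spec spec =
  let batch, rest =
    match String.index_opt spec '|' with
    | Some i -> (String.sub spec 0 i, String.sub spec (i + 1) (String.length spec - i - 1))
    | None -> ("", spec)
  in
  let input, output =
    match String.index_opt rest '-' with
    | Some i when i + 1 < String.length rest && rest.[i + 1] = '>' ->
        (String.sub rest 0 i, String.sub rest (i + 2) (String.length rest - i - 2))
    | _ -> ("", rest)
  in
  (split_entries batch, split_entries input, split_entries output)

(* For example, [split_spec "b|i->o"] gives (["b"], ["i"], ["o"]), and
   [split_spec "b1 b2|i->o1, o2"] gives (["b1"; "b2"], ["i"], ["o1"; "o2"]). *)
```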
A row is a sequence of axes of a single kind: batch, input, or output. The shape type incorporates information relevant to inference, in particular shape variables: both for individual axes (`dim` variables) and for extending a row with more axes (`row` variables). Currently, all rows are (independently) broadcastable: they can be broadcast to a larger number of axes. However, in OCANNL the broadcasting can happen "in the middle", with not only the given trailing axes fixed, but also the given leading axes fixed. (TODO: clarify here the precise logic as it is implemented, I'm not sure this description is correct.)
```ocaml
type solved_dim = { d : int; label : string option; proj_id : proj_id option }

(* ... *)

type deduce_within_shape = Not_constrained | Input_equals_output

type compose_type =
  | Pointwise_bin
      (** NumPy-style broadcast matching batch, input and output axes, e.g. as in [s1 + s2]. *)
  | Compose
      (** Compose the outputs of the second shape with the inputs of the first shape, i.e. the
          shape of [fun x -> s1(s2(x))], or [s1 * s2] where [*] is the inner product (e.g. matrix
          multiply). *)
  | Einsum of string
      (** The binary "einsum" syntax: RHS1;RHS2=>LHS, where RHSi, LHS are labels specifications.
          Since OCANNL's extended einsum notation supports both axis variables and row variables,
          it makes other compose types redundant. The [axis_labels] use pseudo-labels local to the
          notation, to line up the axes and row variables. The symmetric difference / disjunctive
          union of RHS1 and RHS2's pseudo-labels should be equal to LHS pseudo-labels.

          Note: The "right-hand-side" is on the left! I.e. the syntax is "rhs=>lhs",
          "rhs1;rhs2=>lhs". *)

type transpose_type =
  | Transpose  (** Swaps inputs and outputs of a shape, preserves batch axes. *)
  | Pointwise_un  (** Preserves the shape. *)
  | Permute of string  (** The unary "einsum" syntax: RHS1=>LHS. *)
  | Batch_slice of Ir.Indexing.static_symbol  (** Removes the leftmost batch axis. *)
  | Uint4x32_to_prec of Ir.Ops.prec Lazy.t
      (** Converts precision in a bit-efficient way, with a corresponding conversion in the total
          number of elements. Currently, assumes the incoming tensor (RHS) has just a single axis,
          to not force unnecessary minimum sizes on output axes. *)

type logic =
  | Broadcast of compose_type * shape * shape
      (** Matches the shapes for a binary operation.

          For [Broadcast (Einsum spec, s1, s2)], the labels of [s1] and [s2] must match according
          to the RHS1 and RHS2 lineups of [spec], and the resulting shape inherits the labels
          according to the LHS lineup. *)
  | Transpose of transpose_type * shape
      (** Permutes the axes of a shape. One case of [Transpose] is to swap inputs with outputs of
          [s1], hence the name. *)
  | Terminal of [ `Data of Ir.Assignments.init_data | `Fetch of Ir.Assignments.fetch_op ]
      (** Extracts any available shape information from the initialization, e.g. the number of
          elements. *)

(* ... *)

  (** If you miss expressivity here, leave a note on ... *)
  (* ... *)
  (** The rows, inclusive of the further row spec, have this many elements. *)
  | Exact of dim list  (** The concatenated rows have these axes. *)
```
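
For concreteness, here is a hedged illustration of values of the above types. The spec strings are made-up examples in the extended einsum notation, relying on the reversed convention that the right-hand sides appear to the left of `=>`:

```ocaml
(* Hypothetical example values, assuming the constructors above are in scope. *)
let matmul : compose_type = Einsum "ij;jk=>ik" (* [j] is the pseudo-label shared by the two RHS rows *)
let swap_axes : transpose_type = Permute "ij=>ji" (* unary einsum: swap the two axes *)
let add_like : compose_type = Pointwise_bin (* NumPy-style broadcast, as in [s1 + s2] *)
```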
During the solution process, the constraints are incorporated, or propagated, into the `constr` fields of environment entries, and into further `constraint_` constraints, as needed. This provides sufficient scaffolding to implement other complex constraints as the need arises.
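
As a rough illustration of what incorporating a constraint into the environment can look like, here is a minimal, hypothetical sketch; the types and names below are invented and far simpler than OCANNL's actual environment and `constraint_` types. Solving a dimension-equality constraint either extends the variable bindings or is detected as inconsistent.

```ocaml
(* Hypothetical sketch of propagating constraints into an environment of bindings. *)
type dim = Dim of int | Var of string

type constraint_ = Dim_eq of dim * dim

(* Resolve a dim through the current bindings. *)
let rec subst env d =
  match d with
  | Var v -> (match List.assoc_opt v env with Some d' -> subst env d' | None -> Var v)
  | Dim _ -> d

(* Incorporate one constraint: extend the environment or report an inconsistency. *)
let solve env (Dim_eq (a, b)) =
  match (subst env a, subst env b) with
  | a', b' when a' = b' -> env (* already satisfied *)
  | Var v, d | d, Var v -> (v, d) :: env (* bind the variable *)
  | Dim m, Dim n -> failwith (Printf.sprintf "inconsistent dimensions: %d vs %d" m n)

(* Solving [Dim_eq (Var "a", Dim 3)] and then [Dim_eq (Var "a", Dim 4)] against the same
   environment triggers the inconsistency error. *)
```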