Doc: clarify that einsum operations use equations, not inequalities
Einsum operations (both binary Einsum and unary Permute) generate
Row_eq and Dim_eq constraints, not Row_ineq and Dim_ineq. This means
they do NOT permit broadcasting, unlike Pointwise_bin, Pointwise_un,
and Compose operations, which use inequalities.
Updated docs/shape_inference.md and tensor/shape.mli to:
- Remove claim that einsum "makes other compose types redundant"
- Clarify einsum is more restrictive (no broadcasting) but more precise
- Update get_inequalities description to reflect equations for einsum
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
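
The practical difference is easiest to see in isolation. Below is a minimal, self-contained OCaml sketch of the equation-vs-inequality split; the constructor names merely echo the `Dim_eq`/`Dim_ineq` mentioned above, and none of this is OCANNL's actual solver:

```ocaml
(* Toy constraints echoing the Dim_eq / Dim_ineq split: an inequality lets a
   sub-shape dimension of 1 broadcast up to the current dimension, while an
   equation demands an exact match. Hypothetical types, not OCANNL's API. *)
type dim_constraint =
  | Dim_eq of int * int   (* cur = subr: exact match required *)
  | Dim_ineq of int * int (* cur >= subr: subr = 1 may broadcast *)

let satisfiable = function
  | Dim_eq (cur, subr) -> cur = subr
  | Dim_ineq (cur, subr) -> cur = subr || subr = 1

let () =
  (* Pointwise-style (inequality): dim 1 broadcasts against dim 3. *)
  assert (satisfiable (Dim_ineq (3, 1)));
  (* Einsum-style (equation): the same pair of dims is rejected. *)
  assert (not (satisfiable (Dim_eq (3, 1))))
```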
docs/shape_inference.md (13 additions & 5 deletions)
@@ -143,10 +143,14 @@ type compose_type =
           multiply). *)
   | Einsum of string * Ir.Indexing.variable_ref list
       (** The binary "einsum" syntax: RHS1;RHS2=>LHS, where RHSi, LHS are labels specifications.
-          Since OCANNL's extended einsum notation supports both axis variables and row variables, it
-          makes other compose types redundant. The [axis_labels] use pseudo-labels local to the
-          notation, to line up the axes and row variables. The symmetric difference / disjunctive
-          union of RHS1 and RHS2's pseudo-labels should be equal to LHS pseudo-labels.
+          OCANNL's extended einsum notation supports both axis variables and row variables.
+          The [axis_labels] use pseudo-labels local to the notation, to line up the axes and row
+          variables. The symmetric difference / disjunctive union of RHS1 and RHS2's pseudo-labels
+          should be equal to LHS pseudo-labels.
+
+          Unlike [Pointwise_bin] and [Compose], einsum operations use equations only (not
+          inequalities), so they do NOT permit broadcasting. This makes einsum more restrictive
+          but also more precise for operations where exact shape matching is required.

           The optional {!Ir.Indexing.variable_ref}s will capture the solutions of the dimensions
           corresponding to the specification labels equal to [ref_label] of a reference.
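
The pseudo-label invariant this docstring keeps (the symmetric difference of RHS1's and RHS2's labels equals the LHS labels) can be checked mechanically. Here is a hedged sketch for single-character pseudo-labels; `einsum_lhs_matches` is an illustrative helper, not an OCANNL function:

```ocaml
(* Checks the invariant: labels occurring in exactly one of RHS1, RHS2
   (their symmetric difference) must be exactly the LHS labels. *)
module CS = Set.Make (Char)

let labels s = CS.of_seq (String.to_seq s)

let einsum_lhs_matches rhs1 rhs2 lhs =
  let l1 = labels rhs1 and l2 = labels rhs2 in
  let sym_diff = CS.union (CS.diff l1 l2) (CS.diff l2 l1) in
  CS.equal sym_diff (labels lhs)

let () =
  (* "ij;jk=>ik": the contracted j drops out, i and k survive. *)
  assert (einsum_lhs_matches "ij" "jk" "ik");
  (* Keeping the contracted j on the LHS violates the invariant. *)
  assert (not (einsum_lhs_matches "ij" "jk" "ijk"))
```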
@@ -161,6 +165,10 @@ type transpose_type =
   | Permute of string * Ir.Indexing.variable_ref list
       (** The unary "einsum" syntax: RHS1=>LHS.

+          Unlike [Pointwise_un], permute operations use equations only (not inequalities), so they
+          do NOT permit broadcasting. This makes permute more restrictive but also more precise
+          for operations where exact shape matching is required.
+
           The optional {!Ir.Indexing.variable_ref}s will capture the solutions of the dimensions
           corresponding to the specification labels equal to [ref_label] of a reference. *)
   | Batch_slice of Ir.Indexing.static_symbol (** Removes the leftmost batch axis. *)
@@ -270,7 +278,7 @@ There is an important and intentional difference between `dims` in the `arrayjit
 Other important functions in the `Shape` module.

 * `einsum_slot_spec_to_dims_bio` parses an einsum spec for a single shape, returns the three rows and a mapping from axis (`dim`) variables to indices where the einsum specifies fixed indexing.
-* `get_inequalities` builds row inequalities by pairing the rows of the current shape (as `cur`) with the rows of sub-shapes (as `subr`). It also derives a batch row constraint for terminals initialized with `Constant_fill values`. For `Batch_slice` (the `@|` operation) it waits till the batch row variables (if any) are solved, and derives row equations (not inequalities) between the current shape and the sub-shape, with `cur_sh.batch.dims` expanded to account for the slicing / indexing. For einsum specs, it derives inequalities, roughly: _current shape ≥ lhs spec shape_, and _rhs spec shape ≥ sub-shape_.
+* `get_inequalities` builds row inequalities by pairing the rows of the current shape (as `cur`) with the rows of sub-shapes (as `subr`). It also derives a batch row constraint for terminals initialized with `Constant_fill values`. For `Batch_slice` (the `@|` operation) it waits till the batch row variables (if any) are solved, and derives row equations (not inequalities) between the current shape and the sub-shape, with `cur_sh.batch.dims` expanded to account for the slicing / indexing. For einsum specs, it derives equations (not inequalities), equating the current shape with the lhs spec shape, and the rhs spec shapes with the sub-shapes. This means einsum operations do NOT permit broadcasting, unlike pointwise and compose operations which use inequalities (see the sketch after this list).
 * `propagate_shapes` gets and then solves the inequalities, using a global state for the environment. It updates the shapes in-place with the partial solution. It is invoked twice for each `update_step`: first during the bottom-up process of building tensors, and then in reverse order from `finish_inference`.
 * `finish_inference` is called right before some projections or array dimensions are required (typically, because of jitting). It performs a second round of `propagate_shapes`, and then once again attempts to solve any remaining constraints that `propagate_shapes` didn't solve. Then it "closes the shapes": substitutes out remaining shape variables by their LUBs if any, or dimension-1 / `Broadcastable` (no-more-axes). Then it resets the environment state, since the shapes are now guaranteed to not have variables.
 * `derive_projections` starts by freshening the `proj_id`s in the `update_step`. Then it generates and solves shape inequalities, and then generates and solves projection equations, and constructs the `projections` record.
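
To make the `get_inequalities` split concrete, here is an illustrative OCaml sketch of the two constraint-generation modes; rows are reduced to bare dimension lists, and all names are hypothetical rather than the actual `Shape` interface:

```ocaml
(* Illustrative constraint generation mirroring the split above: pointwise
   and compose pair rows as inequalities (cur >= subr, so subr may
   broadcast), while einsum pairs them as equations (cur = lhs spec, each
   rhs spec = its sub-shape). Hypothetical names, not OCANNL's API. *)
type row = int list

type constr =
  | Row_eq of row * row
  | Row_ineq of row * row (* left >= right: right may broadcast *)

let pointwise_constraints ~cur ~subrs =
  List.map (fun subr -> Row_ineq (cur, subr)) subrs

let einsum_constraints ~cur ~lhs_spec ~rhs_specs ~subrs =
  Row_eq (cur, lhs_spec)
  :: List.map2 (fun rhs subr -> Row_eq (rhs, subr)) rhs_specs subrs
```

For a binary einsum there are two rhs specs (RHS1 and RHS2), so `rhs_specs` and `subrs` would each have length two; the unary `Permute` case has one of each.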