
Commit 64f1ec1

lukstafi and claude committed

Doc: clarify that einsum operations use equations, not inequalities

Einsum operations (both binary Einsum and unary Permute) generate Row_eq and Dim_eq constraints, not Row_ineq and Dim_ineq. This means they do NOT permit broadcasting, unlike Pointwise_bin, Pointwise_un, and Compose operations which use inequalities.

Updated docs/shape_inference.md and tensor/shape.mli to:

- Remove claim that einsum "makes other compose types redundant"
- Clarify einsum is more restrictive (no broadcasting) but more precise
- Update get_inequalities description to reflect equations for einsum

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 3659fa5 commit 64f1ec1
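The equations-versus-inequalities distinction the commit message describes can be illustrated with a minimal sketch. The types below are hypothetical stand-ins that loosely mirror the `Dim_eq` / `Dim_ineq` constructor names touched by this commit, not OCANNL's actual solver:

```ocaml
(* Toy model: dimensions are plain ints here; in OCANNL they may be variables. *)
type constraint_ =
  | Dim_eq of int * int                    (* einsum / Permute: exact match required *)
  | Dim_ineq of { cur : int; subr : int }  (* pointwise / compose: subr may broadcast *)

let satisfiable = function
  | Dim_eq (d1, d2) -> d1 = d2
  | Dim_ineq { cur; subr } -> subr = 1 || subr = cur

let () =
  (* A dimension-1 operand broadcasts under an inequality constraint... *)
  assert (satisfiable (Dim_ineq { cur = 8; subr = 1 }));
  (* ...but the same pairing fails under the equation an einsum generates. *)
  assert (not (satisfiable (Dim_eq (8, 1))));
  print_endline "ok"
```

Under an inequality, a dimension-1 sub-shape is allowed to broadcast up to the current shape's dimension; under an equation the two dimensions must unify exactly, which is why einsum rejects shapes that pointwise operations would accept.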

File tree

3 files changed: +43 −27 lines


docs/shape_inference.md

Lines changed: 13 additions & 5 deletions

```diff
@@ -143,10 +143,14 @@ type compose_type =
         multiply). *)
   | Einsum of string * Ir.Indexing.variable_ref list
       (** The binary "einsum" syntax: RHS1;RHS2=>LHS, where RHSi, LHS are labels specifications.
-          Since OCANNL's extended einsum notation supports both axis variables and row variables, it
-          makes other compose types redundant. The [axis_labels] use pseudo-labels local to the
-          notation, to line up the axes and row variables. The symmetric difference / disjunctive
-          union of RHS1 and RHS2's pseudo-labels should be equal to LHS pseudo-labels.
+          OCANNL's extended einsum notation supports both axis variables and row variables.
+          The [axis_labels] use pseudo-labels local to the notation, to line up the axes and row
+          variables. The symmetric difference / disjunctive union of RHS1 and RHS2's pseudo-labels
+          should be equal to LHS pseudo-labels.
+
+          Unlike [Pointwise_bin] and [Compose], einsum operations use equations only (not
+          inequalities), so they do NOT permit broadcasting. This makes einsum more restrictive
+          but also more precise for operations where exact shape matching is required.
 
           The optional {!Ir.Indexing.variable_ref}s will capture the solutions of the dimensions
           corresponding to the specification labels equal to [ref_label] of a reference.
@@ -161,6 +165,10 @@ type transpose_type =
   | Permute of string * Ir.Indexing.variable_ref list
       (** The unary "einsum" syntax: RHS1=>LHS.
+
+          Unlike [Pointwise_un], permute operations use equations only (not inequalities), so they
+          do NOT permit broadcasting. This makes permute more restrictive but also more precise
+          for operations where exact shape matching is required.
 
           The optional {!Ir.Indexing.variable_ref}s will capture the solutions of the dimensions
           corresponding to the specification labels equal to [ref_label] of a reference. *)
   | Batch_slice of Ir.Indexing.static_symbol (** Removes the leftmost batch axis. *)
@@ -270,7 +278,7 @@ There is an important and intentional difference between `dims` in the `arrayjit
 Other important functions in the `Shape` module.
 
 * `einsum_slot_spec_to_dims_bio` parses an einsum spec for a single shape, returns the three rows and a mapping from axis (`dim`) variables to indices where the einsum specifies fixed indexing.
-* `get_inequalities` builds row inequalities by pairing the rows of the current shape (as `cur`) with the rows of sub-shapes (as `subr`). It also derives a batch row constraint for terminals initialized with `Constant_fill values`. For `Batch_slice` (the `@|` operation) it waits till the batch row variables (if any) are solved, and derives row equations (not inequalities) between the current shape and the sub-shape, with `cur_sh.batch.dims` expanded to account for the slicing / indexing. For einsum specs, it derives inequalities, roughly: _current shape ≤ lhs spec shape_, and _rhs spec shape ≥ sub-shape_.
+* `get_inequalities` builds row inequalities by pairing the rows of the current shape (as `cur`) with the rows of sub-shapes (as `subr`). It also derives a batch row constraint for terminals initialized with `Constant_fill values`. For `Batch_slice` (the `@|` operation) it waits till the batch row variables (if any) are solved, and derives row equations (not inequalities) between the current shape and the sub-shape, with `cur_sh.batch.dims` expanded to account for the slicing / indexing. For einsum specs, it derives equations (not inequalities), equating the current shape with the lhs spec shape, and the rhs spec shapes with the sub-shapes. This means einsum operations do NOT permit broadcasting, unlike pointwise and compose operations which use inequalities.
 * `propagate_shapes` gets and then solves the inequalities, using a global state for the environment. It updates the shapes in-place with the partial solution. It is invoked twice for each `update_step`: first during the bottom-up process of building tensors, and then in reverse order from `finish_inference`.
 * `finish_inference` is called right before some projections or array dimensions are required (typically, because of jitting). It performs a second round of `propagate_shapes`, and then once again attempts to solve any remaining constraints that `propagate_shapes` didn't solve. Then it "closes the shapes": substitutes out remaining shape variables by their LUBs if any, or dimension-1 / `Broadcastable` (no-more-axes). Then it resets the environment state, since the shapes are now guaranteed to not have variables.
 * `derive_projections` starts by freshening the `proj_id`s in the `update_step`. Then it generates and solves shape inequalities, and then generates and solves projection equations, and constructs the `projections` record.
```
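The updated `get_inequalities` behavior can be sketched as follows. This is a hypothetical simplification (row is a stand-in type, and the real function handles batch/input/output rows, origins, and projections), not OCANNL's actual implementation:

```ocaml
(* Hypothetical sketch of the pairing described above: einsum specs yield
   equations between the current shape's rows and the spec's rows, while
   pointwise/compose logic yields inequalities that permit broadcasting. *)
type row = string list  (* stand-in for OCANNL's row type *)

type constraint_ =
  | Row_eq of { r1 : row; r2 : row }        (* exact: no broadcasting *)
  | Row_ineq of { cur : row; subr : row }   (* sub-shape may broadcast *)

let constraints_for ~einsum ~cur_rows ~sub_rows =
  List.map2
    (fun cur subr ->
      if einsum then Row_eq { r1 = cur; r2 = subr }
      else Row_ineq { cur; subr })
    cur_rows sub_rows
```

The field renames in the diff below (`cur`/`subr` becoming `r1`/`r2`) follow from this switch: an equation relates two rows symmetrically, so the asymmetric current/sub-shape field names no longer apply.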

tensor/shape.ml

Lines changed: 18 additions & 18 deletions

```diff
@@ -984,10 +984,10 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
     ( proj_env,
       extras_dim_refs @ extras_rhs @ extras_lhs
       @ [
-          Row_ineq
+          Row_eq
             {
-              cur = cur_sh.batch;
-              subr = b_lhs;
+              r1 = cur_sh.batch;
+              r2 = b_lhs;
               origin =
                 [
                   {
@@ -1014,10 +1014,10 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
                   };
                 ];
             };
-          Row_ineq
+          Row_eq
             {
-              cur = cur_sh.input;
-              subr = i_lhs;
+              r1 = cur_sh.input;
+              r2 = i_lhs;
               origin =
                 [
                   {
@@ -1044,10 +1044,10 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
                   };
                 ];
             };
-          Row_ineq
+          Row_eq
             {
-              cur = cur_sh.output;
-              subr = o_lhs;
+              r1 = cur_sh.output;
+              r2 = o_lhs;
               origin =
                 [
                   {
@@ -1210,10 +1210,10 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
     ( proj_env,
       extras_dim_refs @ extras_rhs1 @ extras_rhs2 @ extras_lhs
       @ [
-          Row_ineq
+          Row_eq
            {
-              cur = cur_sh.batch;
-              subr = b_lhs;
+              r1 = cur_sh.batch;
+              r2 = b_lhs;
               origin =
                 [
                   {
@@ -1255,10 +1255,10 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
                   };
                 ];
             };
-          Row_ineq
+          Row_eq
             {
-              cur = cur_sh.input;
-              subr = i_lhs;
+              r1 = cur_sh.input;
+              r2 = i_lhs;
               origin =
                 [
                   {
@@ -1300,10 +1300,10 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
                   };
                 ];
             };
-          Row_ineq
+          Row_eq
             {
-              cur = cur_sh.output;
-              subr = o_lhs;
+              r1 = cur_sh.output;
+              r2 = o_lhs;
               origin =
                 [
                   {
```

tensor/shape.mli

Lines changed: 12 additions & 4 deletions

```diff
@@ -113,10 +113,14 @@ type compose_type =
         multiply). *)
   | Einsum of string * delayed_var_ref list
       (** The binary "einsum" syntax: RHS1;RHS2=>LHS, where RHSi, LHS are labels specifications.
-          Since OCANNL's extended einsum notation supports both axis variables and row variables, it
-          makes other compose types redundant. The [axis_labels] use pseudo-labels local to the
-          notation, to line up the axes and row variables. The symmetric difference / disjunctive
-          union of RHS1 and RHS2's pseudo-labels should be equal to LHS pseudo-labels.
+          OCANNL's extended einsum notation supports both axis variables and row variables.
+          The [axis_labels] use pseudo-labels local to the notation, to line up the axes and row
+          variables. The symmetric difference / disjunctive union of RHS1 and RHS2's pseudo-labels
+          should be equal to LHS pseudo-labels.
+
+          Unlike [Pointwise_bin] and [Compose], einsum operations use equations only (not
+          inequalities), so they do NOT permit broadcasting. This makes einsum more restrictive
+          but also more precise for operations where exact shape matching is required.
 
           The optional {!Ir.Indexing.variable_ref}s will capture the solutions of the dimensions
           corresponding to the specification labels equal to [ref_label] of a reference.
@@ -131,6 +135,10 @@ type transpose_type =
   | Permute of string * delayed_var_ref list
       (** The unary "einsum" syntax: RHS1=>LHS.
+
+          Unlike [Pointwise_un], permute operations use equations only (not inequalities), so they
+          do NOT permit broadcasting. This makes permute more restrictive but also more precise
+          for operations where exact shape matching is required.
 
           The optional {!Ir.Indexing.variable_ref}s will capture the solutions of the dimensions
           corresponding to the specification labels equal to [ref_label] of a reference. *)
   | Batch_slice of Ir.Indexing.static_symbol (** Removes the leftmost batch axis. *)
```
