You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Don't guess dim variables as dim-1 prematurely, collab with Claude
If the variables participates in a `Total_elems` constraint (currently numerator only), it doesn't need to be guessed so shouldn't.
Claude wrote the docs update and helper functions.
Copy file name to clipboardExpand all lines: docs/shape_inference.md
+42-3Lines changed: 42 additions & 3 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -61,14 +61,51 @@ The shape and projection inference handles `Conv_input` terms differently depend
61
61
62
62
Shape inference does not maintain padding for axes of individual tensor nodes, these padding values are computed and updated during projections inference.
63
63
64
+
### Preventing Premature Guessing with Total_elems Constraints
65
+
66
+
A critical aspect of shape inference is avoiding premature "guessing" of dimension variables to minimal values (dimension-1 or no-further-axes for rows) when such guessing would make pending constraints unsatisfiable. This is particularly important for `Total_elems` constraints of the form:
where `var` appears in the numerator and other variables appear in `divided_by`.
73
+
74
+
The problem: If we guess `var` to be 1 prematurely (before the constraint is resolved), and the denominator variables are later inferred to have values > 1, the constraint becomes unsatisfiable. For example, if the constraint states `total = (64 * var) / [d1, d2]` and we guess `var = 1`, but later infer `d1 = 4` and `d2 = 1`, we get `total = 64 / 4 = 16`, which might not match the actual required total.
75
+
76
+
The solution: We use the `has_uniq_constr_unless` field in `Bounds_dim` entries to track these dependencies:
77
+
78
+
1.**Marking Phase**: When a `Total_elems` constraint with a `Strided_var` numerator is encountered, `mark_total_elems_vars` is called to mark the numerator variable with `has_uniq_constr_unless = Some (Set.of_list divided_by)`.
79
+
80
+
2.**Checking Phase**: Before guessing a variable to 1 in `close_dim_terminal` at Stage 3, we call `can_guess_dim_to_one` to check:
81
+
- If `has_uniq_constr_unless = None`: guessing is allowed (no restriction)
82
+
- If `has_uniq_constr_unless = Some unless_vars`: guessing is allowed only if at least one variable in `unless_vars` is also prevented from guessing (has its own `has_uniq_constr_unless` set)
83
+
84
+
3.**Cycle Breaking**: The "unless at least one is also prevented" condition prevents infinite prevention chains. If both numerator and denominator variables would block each other indefinitely, the condition allows progress by permitting guessing when mutual prevention is detected.
85
+
86
+
4.**Deferred Closing**: If a variable cannot be guessed, `close_dim_terminal` returns `Terminal_dim (is_param, dim, origin)` instead of `Dim_eq`, allowing later stages to retry after more constraints are resolved.
87
+
88
+
This mechanism ensures that `Total_elems` constraints with stride-based numerators are fully resolved before any involved variables are closed to minimal values, preventing the "shape cannot be strided" errors that would otherwise occur.
89
+
64
90
### Inference strategy
65
91
66
92
The actual shape inference combines row polymorphism with (nominal) subtyping, as known in the type inference literature. The subtyping stems merely from the fact that a dimension-1 axis can be used in the context of any dimension due to per-axis broadcasting. Row polymorphism stems from broadcasting to more axes: for example, when unifying an unknown (shape) row with a known one, we cannot assume that the unknown row will have just the axes of the known one, because maybe the known row is meant to be broadcasted here to more axes. The combination of row polymorphism with nominal subtyping means that the constraints we are solving are inequalities, both inequalities between rows (the `Row.t` type, i.e. the `row` type above), and between axes/dimensions (the `Row.dim` type). We maintain the inequality ordering between variables in the environment to compute the transitive closure during simplification. We also maintain a least upper bound on the solution.
67
93
68
94
```ocaml
69
95
type dim_entry =
70
96
| Solved_dim of dim
71
-
| Bounds_dim of { cur : dim_var list; subr : dim_var list; lub : dim option; constr : dim_constraint }
97
+
| Bounds_dim of {
98
+
is_in_param : bool;
99
+
has_uniq_constr_unless : dim_var_set option;
100
+
(** If set, the variable should not be guessed 1 unless a variable from the set is also
101
+
prevented from being guessed 1. Used to prevent premature guessing for variables in
102
+
Total_elems numerators. *)
103
+
cur : dim_var list;
104
+
subr : dim_var list;
105
+
lub : dim option;
106
+
constr : dim_constraint;
107
+
origin : constraint_origin list;
108
+
}
72
109
73
110
type row_entry =
74
111
| Solved_row of t
@@ -174,7 +211,7 @@ Simplification of an inequality, and constraint propagation, can generate more c
174
211
175
212
* Stage 1 is online as tensors are composed, and conservatively performs unification and constraint propagation. Stages 2, 3, 4 are only performed once necessary: when projections or dimensions are requested.
176
213
* Stage 2, forces coefficients coming from precision byte sizes.
177
-
* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds LUB (if any), but for rows only if they don't have a `Total_elems 1` constraint. It substitutes dimension variables in terminal shapes that do not have a LUB by dim 1. It substitutes row variables in terminal shapes that do not have a LUB by one axis if that's required to satisfy the variable's constraint. In Total_elems constraints with multiple row variables, it substitutes row variables originating from axes of non-output kind, and which do not have a LUB, by no-further-axes -- otherwise these constraints can be too hard to unlock.
214
+
* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds LUB (if any), but for rows only if they don't have a `Total_elems 1` constraint. It substitutes dimension variables in terminal shapes that do not have a LUB by dim 1, **except** when the variable has `has_uniq_constr_unless` set (indicating it's in the numerator of a `Total_elems` constraint) and none of the denominator variables are also prevented from guessing -- this prevents premature guessing that would make `Total_elems` constraints unsatisfiable. It substitutes row variables in terminal shapes that do not have a LUB by one axis if that's required to satisfy the variable's constraint. In Total_elems constraints with multiple row variables, it substitutes row variables originating from axes of non-output kind, and which do not have a LUB, by no-further-axes -- otherwise these constraints can be too hard to unlock.
178
215
* Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses, or terminal ones, to their LUBs if they have any. It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 6 to all variables.) At this stage, we inject `Shape_row` constraints into the inequalities, so that we can re-process the variables of interest without traversing the whole environment.
179
216
* Stage 5 addresses `Total_elems` and `Exact` constraints with yet-unknown row variables. For `Total_elems` and a single row variable: if the constraint can be satisfied by assuming the row variable is no-further-axes, it sets the row variable to `Broadcastable`, otherwise it sets it to one axis of the required dimension. For multiple row variables, if one is of the Output kind, sets the other variables to no-further-axes, and retries.
180
217
* Stage 6 sets row variables in the remaining inequalities and updated shapes to no-further-axes values. This can unlock further between-axis inequalities because of row variables sandwiched between leftmost axes from their side of the inequality and rightmost axes from the other side of the inequality. In row constraints, this also unlocks inference for the embedded dim variables.
@@ -190,7 +227,9 @@ Let's explain the shape inference functions.
190
227
*`apply_dim_constraint` resp. `apply_row_constraint`: if they cannot make any progress on the constraint, they return `None`. Otherwise, they return a list of derived constraints, and an updated `dim_constraint` resp. `row_constraint`.
191
228
*`solve_dim_ineq`: solves a single inequality between two values of type `dim`; returns derived equations and inequalities. It maintains the between-variable bounds and the least-upper-bound (LUB). But there can only be one LUB (a dimension > 1) without forcing the bound variable itself to a solved form (with dimension = 1).
192
229
*`solve_row_ineq`: solves a single inequality between two rows; returns derived equations and inequalities. It derives between-`dim` inequalities from the known parts of the compared rows. It maintains between-row-variable bounds (when known parts of the rows match) and the LUB. It forces the `cur` side to have at least the number of axes of the `subr` side (via a variables-only `template`). It updates the LUB by computing dimensions-wise LUBs.
193
-
*`close_dim_terminal` and `close_row_terminal`: produce the equal-to-LUB constraint when available, from `Terminal_dim` and `Terminal_row` constraints produced for shapes of leaf tensors in tensor expressions, but only when `~stage:true`.
230
+
*`mark_total_elems_vars`: when a `Total_elems` constraint with `Strided_var { var; _ }` in the numerator and non-empty `divided_by` list is encountered, this function marks `var` with `has_uniq_constr_unless = Some (Set.of_list divided_by)`. This tracking ensures that numerator variables aren't prematurely guessed to 1, which would make the constraint unsatisfiable.
231
+
*`can_guess_dim_to_one`: checks if a dimension variable can be guessed to 1. A variable with `has_uniq_constr_unless` can only be guessed if at least one of the "unless" (denominator) variables is also prevented from guessing. This cycle-breaking condition prevents infinite prevention chains where both numerator and denominator would block each other.
232
+
*`close_dim_terminal` and `close_row_terminal`: produce the equal-to-LUB constraint when available, from `Terminal_dim` and `Terminal_row` constraints produced for shapes of leaf tensors in tensor expressions, but only when `~stage:true`. `close_dim_terminal` now calls `can_guess_dim_to_one` before guessing a variable to 1, ensuring `Total_elems` constraints can be satisfied.
194
233
*`solve_inequalities`: solves equations, inequalities, and row constraints, until only row constraints remain. Row constraints can "pass" if there is not enough information, rather than reflecting their effect in the environment. Calls `close_dim_terminal` and `close_row_terminal` as appropriate.
195
234
196
235
The rationale behind only closing leaf (terminal) tensor shapes to their LUBs, while closing the remaining ones to dim-1:
0 commit comments