Skip to content

Commit b93d4fa

Browse files
committed
Fix dimension inference staging: do not close dims at stage 2
1 parent 4e85a0f commit b93d4fa

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

docs/shape_inference.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,9 @@ The constraints are solved by: unification of the equation constraints, unificat
173173
Simplification of an inequality, and constraint propagation, can generate more constraints, so we need to be careful to keep it terminating. The solution proceeds in stages. Currently there are 8 stages, with a fractional stage coming from splitting an earlier design.
174174

175175
* Stage 1 is online as tensors are composed, and conservatively performs unification and constraint propagation. Stages 2, 3, 4 are only performed once necessary: when projections or dimensions are requested.
176-
* Stage 2, when solving the constraints, substitutes dim variables in terminal shapes that do not have a LUB or other constraints, by dimension-1. (This is generalized at stage 6 to all variables.) (FIXME: reconsider this, see the algo for row variables: a new LUB can still be inferred.) Forces coefficients coming from precision byte sizes.
177-
* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds LUB (if any), but for rows only if they don't have a `Total_elems 1` constraint. It substitutes row variables in terminal shapes that do not have a LUB by one axis if that's required to satisfy the variable's constraint. In Total_elems constraints with multiple row variables, it substitutes row variables originating from axes of non-output kind, and which do not have a LUB, by no-further-axes -- otherwise these constraints can be too hard to unlock.
178-
* Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses, to their LUBs if they have any. It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 6 to all variables.) At this stage, we inject `Shape_row` constraints into the inequalities, so that we can re-process the variables of interest without traversing the whole environment.
176+
* Stage 2, forces coefficients coming from precision byte sizes.
177+
* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds LUB (if any), but for rows only if they don't have a `Total_elems 1` constraint. It substitutes dimension variables in terminal shapes that do not have a LUB by dim 1. It substitutes row variables in terminal shapes that do not have a LUB by one axis if that's required to satisfy the variable's constraint. In Total_elems constraints with multiple row variables, it substitutes row variables originating from axes of non-output kind, and which do not have a LUB, by no-further-axes -- otherwise these constraints can be too hard to unlock.
178+
* Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses, or terminal ones, to their LUBs if they have any. It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 6 to all variables.) At this stage, we inject `Shape_row` constraints into the inequalities, so that we can re-process the variables of interest without traversing the whole environment.
179179
* Stage 5 addresses `Total_elems` and `Exact` constraints with yet-unknown row variables. For `Total_elems` and a single row variable: if the constraint can be satisfied by assuming the row variable is no-further-axes, it sets the row variable to `Broadcastable`, otherwise it sets it to one axis of the required dimension. For multiple row variables, if one is of the Output kind, sets the other variables to no-further-axes, and retries.
180180
* Stage 6 sets row variables in the remaining inequalities and updated shapes to no-further-axes values. This can unlock further between-axis inequalities because of row variables sandwiched between leftmost axes from their side of the inequality and rightmost axes from the other side of the inequality. In row constraints, this also unlocks inference for the embedded dim variables.
181181
* Stage 7 sets all dim variables remaining in updated shapes to the lower bound if they have any, otherwise to dimension-1.

tensor/row.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2200,11 +2200,11 @@ let%debug5_sexp close_dim_terminal ~(stage : stage) origin (env : environment) (
22002200
| Var v -> (
22012201
match Map.find env.dim_env v with
22022202
| Some (Solved_dim _) -> assert false
2203-
| Some (Bounds_dim { lub = None; constr = Unconstrained_dim; _ }) when is_stage2_up stage ->
2203+
| Some (Bounds_dim { lub = None; constr = Unconstrained_dim; _ }) when is_stage3_up stage ->
22042204
[ Dim_eq { d1 = dim; d2 = get_dim ~d:1 (); origin } ]
2205-
| Some (Bounds_dim { lub = Some lub; _ }) when is_stage3_up stage ->
2205+
| Some (Bounds_dim { lub = Some lub; _ }) when is_stage4_up stage ->
22062206
[ Dim_eq { d1 = dim; d2 = lub; origin } ]
2207-
| _ when not (is_stage4_up stage) -> [ Terminal_dim (dim, origin) ]
2207+
| _ when not (is_stage5_up stage) -> [ Terminal_dim (dim, origin) ]
22082208
| _ -> [])
22092209
| Conv_input _ ->
22102210
(* The input dimension itself cannot be dim-1, and the output dimension doesn't become

0 commit comments

Comments
 (0)