Skip to content

Commit 51bc961

Browse files
committed
Postpone setting dim variables to their lower bounds till the very end.
1 parent 89301ce commit 51bc961

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

lib/row.ml

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ let rec row_conjunction ?(id = phantom_row_id) stage constr1 constr2 =
664664
elems_mismatch numerator (Num_elems known_product)
665665
else None))
666666

667-
let rec apply_dim_constraint ~(source : source) ~(stage : stage) (dim : dim)
667+
let%track5_sexp rec apply_dim_constraint ~(source : source) ~(stage : stage) (dim : dim)
668668
(constr : dim_constraint) (env : environment) : constraint_ list * dim_constraint =
669669
let extras, constr =
670670
match (dim, constr) with
@@ -696,10 +696,12 @@ let rec apply_dim_constraint ~(source : source) ~(stage : stage) (dim : dim)
696696
| _ -> Option.value ~default:([], constr) @@ dim_conjunction constr bounds.constr))
697697
| _, Unconstrained_dim -> ([], constr)
698698
in
699-
match (dim, constr, stage) with
699+
(* FIXME: *)
700+
(* match (dim, constr, stage) with
700701
| Var _, At_least_dim d, Stage4 ->
701702
(Dim_eq { d1 = dim; d2 = get_dim ~d () } :: extras, Unconstrained_dim)
702-
| _ -> (extras, constr)
703+
| _ -> *)
704+
(extras, constr)
703705

704706
exception Given_up
705707

@@ -2012,7 +2014,7 @@ let%track5_sexp close_row_terminal ~(stage : stage) (env : environment)
20122014
[%log6 "terminal row: keeping", (_r : row), "as", (r1 : row)];
20132015
Terminal_row r1 :: term_dims ())
20142016

2015-
let%debug5_sexp eliminate_dim_entry v ~lub constr =
2017+
let%debug5_sexp eliminate_dim_entry ~final v ~lub constr =
20162018
match (lub, constr) with
20172019
| _, Unconstrained_dim | _, At_least_dim 1 -> None
20182020
| Some (Dim { d; _ } as lub), At_least_dim d2 when d2 > d ->
@@ -2021,7 +2023,7 @@ let%debug5_sexp eliminate_dim_entry v ~lub constr =
20212023
( [%string "dereferenced at dimension %{d2#Int}, higher than use site"],
20222024
[ Dim_mismatch [ lub; Var v ] ] )
20232025
| Some lub, At_least_dim _ -> Some (Dim_eq { d1 = Var v; d2 = lub })
2024-
| None, At_least_dim d -> Some (Dim_eq { d1 = Var v; d2 = get_dim ~d () })
2026+
| None, At_least_dim d -> if final then Some (Dim_eq { d1 = Var v; d2 = get_dim ~d () }) else None
20252027

20262028
let%debug5_sexp eliminate_variables (env : environment) ({ dims; bcast; id } as _r : row) :
20272029
constraint_ list =
@@ -2030,8 +2032,8 @@ let%debug5_sexp eliminate_variables (env : environment) ({ dims; bcast; id } as
20302032
Some
20312033
(match Map.find env.dim_env v with
20322034
| Some (Bounds_dim { lub; constr; _ }) ->
2033-
Option.value_or_thunk (eliminate_dim_entry v ~lub constr) ~default:(fun () ->
2034-
Dim_eq { d1; d2 = get_dim ~d:1 () })
2035+
Option.value_or_thunk (eliminate_dim_entry ~final:true v ~lub constr)
2036+
~default:(fun () -> Dim_eq { d1; d2 = get_dim ~d:1 () })
20352037
| Some (Solved_dim _) -> assert false
20362038
| None -> Dim_eq { d1; d2 = get_dim ~d:1 () })
20372039
| _ -> None
@@ -2109,13 +2111,14 @@ let%debug4_sexp solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (
21092111
match stage with
21102112
| Stage1 | Stage2 | Stage3 | Stage6 | Stage7 -> solve ineqs env
21112113
| Stage4 ->
2112-
let finalize_lower_bound (v : dim_var) = function
2113-
| Bounds_dim { lub; constr; _ } -> Option.to_list @@ eliminate_dim_entry v ~lub constr
2114+
let finalize_upper_lower_bound (v : dim_var) = function
2115+
| Bounds_dim { lub; constr; _ } ->
2116+
Option.to_list @@ eliminate_dim_entry ~final:false v ~lub constr
21142117
| _ -> []
21152118
in
21162119
let finalizing_entries : constraint_ list =
21172120
Map.fold env.dim_env ~init:[] ~f:(fun ~key ~data accu ->
2118-
finalize_lower_bound key data @ accu)
2121+
finalize_upper_lower_bound key data @ accu)
21192122
in
21202123
solve (finalizing_entries @ ineqs) env
21212124
| Stage5 ->

lib/shape_inference.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,10 @@ The constraints are solved by: unification of the equation constraints, unificat
163163
* Stage 1 is online as tensors are composed, and conservatively performs unification and constraint propagation. Stages 2, 3, 4 are only performed once necessary: when projections or dimensions are requested.
164164
* Stage 2, when solving the constraints, substitutes dim variables in terminal shapes that do not have a LUB or other constraints, by dimension-1. (This is generalized at stage 6 to all variables.) (FIXME: reconsider this, see the algo for row variables: a new LUB can still be inferred.) Forces coefficients coming from precision byte sizes.
165165
* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds (if any). It substitutes row variables in terminal shapes that do not have a LUB by one axis if that's required to satisfy the variable's constraint.
166-
* Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses, to their LUBs if they have any, otherwise to the lower bound. It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 6 to all variables.)
166+
* Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses, to their LUBs if they have any. It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 6 to all variables.)
167167
* Stage 5 addresses `Total_elems` and `Exact` constraints with yet-unknown row variables. For `Total_elems` and a single row variable: if the constraint can be satisfied by assuming the row variable is no-further-axes, it sets the row variable to `Broadcastable`, otherwise it sets it to one axis of the required dimension. For multiple row variables, if one is of the Output kind, sets the other variables to no-further-axes, and retries.
168168
* Stage 6 sets row variables in the remaining inequalities to no-further-axes values. This can unlock further between-axis inequalities because of row variables sandwiched between leftmost axes from their side of the inequality and rightmost axes from the other side of the inequality.
169-
* Stage 7 sets all dim resp. row variables remaining in updated shapes to dimension-1 resp. no-further-axes.
169+
* Stage 7 sets all dim variables remaining in updated shapes to the lower bound if they have any, otherwise to dimension-1. It sets all row variables remaining in updated shapes to no-further-axes.
170170

171171
Let's explain the shape inference functions.
172172

0 commit comments

Comments
 (0)