Skip to content

Commit a2471cf

Browse files
committed
Shape inference: more agressive solving of Total_elems rows constraints that takes Least Upper Bounds into account
1 parent 3023848 commit a2471cf

File tree

2 files changed

+76
-17
lines changed

2 files changed

+76
-17
lines changed

lib/row.ml

Lines changed: 75 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,6 +1280,28 @@ and apply_row_constraint ~depth stage (r : row) (constr : row_constraint) env :
12801280
:: List.map2_exn exact_dims (beg_dims @ dims) ~f:(fun d1 d2 -> Dim_eq { d1; d2 })
12811281
@ extras,
12821282
env )
1283+
| ( { bcast = Row_var { v; _ }; _ },
1284+
Total_elems { numerator = Strided_var { coeff; var = _; denom }; divided_by = [] } )
1285+
when is_stage2_up stage -> (
1286+
(* Check if we have a LUB and if it meets our conditions *)
1287+
match Map.find env.row_env v with
1288+
| Some (Bounds_row { lub = Some ({ dims = lub_dims; bcast = Broadcastable; _ } as lub); _ })
1289+
when Utils.is_safe_val coeff && Utils.safe_force coeff > denom -> (
1290+
(* Check if all LUB dimensions are known *)
1291+
match collect_factors lub_dims with
1292+
| Some (_known_product, []) ->
1293+
(* Check if LUB has at most one dimension greater than 1 *)
1294+
let greater_than_one =
1295+
List.filter lub_dims ~f:(function Dim { d; _ } -> d > 1 | _ -> false)
1296+
in
1297+
if List.length greater_than_one <= 1 then
1298+
(Row_eq { r1 = row_of_var v r.id; r2 = lub } :: extras, env)
1299+
else if stored then (extras, env)
1300+
else (Rows_constr { r = [ r ]; constr } :: extras, env)
1301+
| _ ->
1302+
if stored then (extras, env) else (Rows_constr { r = [ r ]; constr } :: extras, env)
1303+
)
1304+
| _ -> if stored then (extras, env) else (Rows_constr { r = [ r ]; constr } :: extras, env))
12831305
| { bcast = Row_var _; _ }, _ | _, Total_elems { numerator = _; divided_by = _ } ->
12841306
if stored then (extras, env)
12851307
else (Rows_constr { r = [ r ]; constr } :: extras, env (* Wait for more shape inference. *))
@@ -2004,26 +2026,43 @@ let%track5_sexp rec eliminate_rows_constraint ~depth stage ~lub (rows : row list
20042026
| _, { Row_id.kind = `Output; _ } -> true
20052027
| _ -> false) )
20062028
with
2007-
| Total_elems _, Some (idx, (v, id)) when is_stage5_up stage ->
2008-
let other_vars = List.filteri rev_row_vars ~f:(fun i _ -> i <> idx) in
2009-
let other_vars = List.map other_vars ~f:(fun (v, id) -> row_of_var v id) in
2010-
let other_eqs =
2011-
List.map other_vars ~f:(fun r ->
2012-
Row_eq { r1 = r; r2 = { dims = []; bcast = Broadcastable; id } })
2013-
in
2014-
let rows =
2015-
List.map rows ~f:(function
2016-
| { bcast = Row_var { v = v'; _ }; _ } as r when equal_row_var v' v -> r
2017-
| r -> { r with dims = r_dims r; bcast = Broadcastable })
2029+
| Total_elems _, Some (idx, (v, _id)) when is_stage4_up stage ->
2030+
(* TODO: in stage 4, consider restricting to a strided dimension variable case. *)
2031+
let other_vars : (row_var * Row_id.t) list =
2032+
List.filteri rev_row_vars ~f:(fun i _ -> i <> idx)
20182033
in
2019-
let ineqs, env =
2020-
eliminate_rows_constraint ~depth:(depth + 1) stage ~lub rows constr env
2034+
let other_eqs : constraint_ list =
2035+
List.concat_map other_vars ~f:(fun (v, id) ->
2036+
if
2037+
is_stage5_up stage
2038+
||
2039+
match Map.find env.row_env v with
2040+
| None
2041+
| Some
2042+
(Bounds_row { lub = None | Some { dims = []; bcast = Broadcastable; _ }; _ })
2043+
->
2044+
true
2045+
| _ -> false
2046+
then
2047+
let r1 = row_of_var v id in
2048+
[ Row_eq { r1; r2 = { dims = []; bcast = Broadcastable; id } } ]
2049+
else [])
20212050
in
2022-
(other_eqs @ ineqs, env)
2051+
if is_stage5_up stage then
2052+
let rows =
2053+
List.map rows ~f:(function
2054+
| { bcast = Row_var { v = v'; _ }; _ } as r when equal_row_var v' v -> r
2055+
| r -> { r with dims = r_dims r; bcast = Broadcastable })
2056+
in
2057+
let ineqs, env =
2058+
eliminate_rows_constraint ~depth:(depth + 1) stage ~lub rows constr env
2059+
in
2060+
(other_eqs @ ineqs, env)
2061+
else (other_eqs @ [ Rows_constr { r = rows; constr } ], env)
20232062
| _ -> ([ Rows_constr { r = rows; constr } ], env))
20242063

2025-
and eliminate_row_constraint ~depth stage ~terminal ~lub (r : row) (constr : row_constraint) env :
2026-
constraint_ list * environment =
2064+
and eliminate_row_constraint ~depth stage ~terminal ~(lub : row option) (r : row)
2065+
(constr : row_constraint) env : constraint_ list * environment =
20272066
let keep_constr () =
20282067
let ineqs, env = apply_row_constraint ~depth stage r constr env in
20292068
List.fold ineqs ~init:([], env) ~f:(fun (ineqs, env) ineq ->
@@ -2043,6 +2082,7 @@ and eliminate_row_constraint ~depth stage ~terminal ~lub (r : row) (constr : row
20432082
(* Note: the reduced constraint applies to just the row variable. *)
20442083
match reduce_row_constraint constr ~beg_dims ~dims with
20452084
| Total_elems { numerator; divided_by } -> (
2085+
let _divided_by : dim_var list = divided_by in
20462086
match (numerator, divided_by, lub) with
20472087
| Num_elems 1, vs, _ when is_stage5_up stage ->
20482088
( no_further_axes
@@ -2100,6 +2140,25 @@ and eliminate_row_constraint ~depth stage ~terminal ~lub (r : row) (constr : row
21002140
Row_eq { r1; r2 = { dims = []; bcast = Broadcastable; id } };
21012141
],
21022142
env )
2143+
| ( Strided_var { coeff; var = _; denom },
2144+
[],
2145+
Some ({ dims = lub_dims; bcast = _; id = lub_id } as lub) )
2146+
when is_stage5_up stage && Utils.safe_force coeff > denom -> (
2147+
(* Check if coeff > denom * product of known dimensions of the LUB *)
2148+
match collect_factors lub_dims with
2149+
| Some (known_product, []) ->
2150+
let coeff_val = Utils.safe_force coeff in
2151+
if coeff_val > denom * known_product then ([ Row_eq { r1; r2 = lub } ], env)
2152+
else
2153+
(* Equate the row variable to the dimensions of the LUB *)
2154+
( [ Row_eq { r1; r2 = { dims = lub_dims; bcast = Broadcastable; id = lub_id } } ],
2155+
env )
2156+
| _ -> keep_constr ())
2157+
| Strided_var { coeff; var; denom }, _, _ when is_stage5_up stage ->
2158+
let _var : dim_var = var in
2159+
let _coeff : int = Utils.safe_force coeff in
2160+
let _denom : int = denom in
2161+
keep_constr ()
21032162
| _ -> keep_constr ())
21042163
| Exact dims -> ([ Row_eq { r1; r2 = { dims; bcast = Broadcastable; id } } ], env)
21052164
| Unconstrained -> ([], env))

lib/shape_inference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ Simplification of an inequality, and constraint propagation, can generate more c
167167
* Stage 1 is online as tensors are composed, and conservatively performs unification and constraint propagation. Stages 2, 3, 4 are only performed once necessary: when projections or dimensions are requested.
168168
* Stage 2, when solving the constraints, substitutes dim variables in terminal shapes that do not have a LUB or other constraints, by dimension-1. (This is generalized at stage 6 to all variables.) (FIXME: reconsider this, see the algo for row variables: a new LUB can still be inferred.) Forces coefficients coming from precision byte sizes.
169169
* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds (if any), but for rows only if they don't have a `Total_elems 1` constraint. It substitutes row variables in terminal shapes that do not have a LUB by one axis if that's required to satisfy the variable's constraint.
170-
* Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses, to their LUBs if they have any. It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 6 to all variables.) At this stage, we inject `Shape_row` constraints into the inequalities, so that we can re-process the variables of interest without traversing the whole environment.
170+
* Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses, to their LUBs if they have any. It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 6 to all variables.) In Total_elems constraints with multiple row variables, it substitutes row variables originating from axes of non-output kind, and which do not have a LUB, by no-further-axes -- otherwise these constraints can be too hard to unlock. At this stage, we inject `Shape_row` constraints into the inequalities, so that we can re-process the variables of interest without traversing the whole environment.
171171
* Stage 5 addresses `Total_elems` and `Exact` constraints with yet-unknown row variables. For `Total_elems` and a single row variable: if the constraint can be satisfied by assuming the row variable is no-further-axes, it sets the row variable to `Broadcastable`, otherwise it sets it to one axis of the required dimension. For multiple row variables, if one is of the Output kind, sets the other variables to no-further-axes, and retries.
172172
* Stage 6 sets row variables in the remaining inequalities and updated shapes to no-further-axes values. This can unlock further between-axis inequalities because of row variables sandwiched between leftmost axes from their side of the inequality and rightmost axes from the other side of the inequality. In row constraints, this also unlocks inference for the embedded dim variables.
173173
* Stage 7 sets all dim variables remaining in updated shapes to the lower bound if they have any, otherwise to dimension-1.

0 commit comments

Comments
 (0)