Skip to content

Commit a15aeae

Browse files
committed
Fix insufficient propagation when Total_elems has both a row and a dim variable
1 parent 400f967 commit a15aeae

File tree

3 files changed

+42
-5
lines changed

3 files changed

+42
-5
lines changed

arrayjit/lib/utils.ml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,3 +1004,7 @@ let sexp_of_safe_lazy sexp_of_elem gated =
10041004
Sexp.List [ Sexp.Atom "id"; Sexp.Atom gated.unique_id ];
10051005
Sexp.List [ Sexp.Atom "value"; status ];
10061006
]
1007+
1008+
let gcd a b =
1009+
let rec loop a b = if b = 0 then a else loop b (a % b) in
1010+
loop (abs a) (abs b)

lib/row.ml

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,12 +247,15 @@ type error_trace +=
247247
| Row_mismatch of t list
248248
| Dim_mismatch of dim list
249249
| Index_mismatch of Idx.axis_index list
250+
| Rows_constr_failed of { constr : row_constraint }
250251

251252
let sexp_of_error_trace = function
252253
| Row_mismatch rs -> Sexp.List (Sexp.Atom "Row_mismatch" :: List.map rs ~f:sexp_of_t)
253254
| Dim_mismatch ds -> Sexp.List (Sexp.Atom "Dim_mismatch" :: List.map ds ~f:sexp_of_dim)
254255
| Index_mismatch idcs ->
255256
Sexp.List (Sexp.Atom "Index_mismatch" :: List.map idcs ~f:Idx.sexp_of_axis_index)
257+
| Rows_constr_failed { constr } ->
258+
Sexp.List (Sexp.Atom "Rows_constr_failed" :: [ sexp_of_row_constraint constr ])
256259
| _ -> Sexp.Atom "<outdated version of sexp_of_error_trace>"
257260

258261
exception Shape_error of string * error_trace list [@@deriving sexp_of]
@@ -894,7 +897,13 @@ let subst_row_constraint_impl ~subst_in_dim ~get_dim_val stage constr =
894897
let tot = Utils.safe_force coeff * dim in
895898
reapply_rows_constr := true;
896899
if tot % denom = 0 then subst_total_elems_divided_by (Num_elems (tot / denom)) divided_by
897-
else raise @@ Shape_error ("Total_elems constraint: shape cannot be strided", [])
900+
else
901+
raise
902+
@@ Shape_error
903+
( [%string
904+
"Total_elems constraint: shape cannot be strided, %{tot#Int} not divisible by \
905+
%{denom#Int}"],
906+
[ Rows_constr_failed { constr } ] )
898907
| Total_elems { numerator = Strided_var { coeff; var; denom }; divided_by }
899908
when not (equal_dim (Var var) (subst_in_dim (Var var))) -> (
900909
reapply_rows_constr := true;
@@ -925,7 +934,7 @@ let s_dim_one_in_row_constr stage v ~value constr =
925934
if equal_dim_var v v' then match value with Dim { d; _ } -> Some d | _ -> None else None
926935
in
927936
subst_row_constraint_impl
928-
~subst_in_dim:(fun in_ -> s_dim_one v ~value ~in_)
937+
~subst_in_dim:(fun in_ -> s_dim_one ~keep_conv:true v ~value ~in_)
929938
~get_dim_val stage constr
930939

931940
let ineqs_from_reapply_rows_constr = ref []
@@ -1208,7 +1217,10 @@ and apply_row_constraint ~depth stage (r : row) (constr : row_constraint) env :
12081217
else
12091218
raise
12101219
@@ Shape_error
1211-
("apply_row_constraint: Total_elems constraint failed", [ Row_mismatch [ r ] ])
1220+
( [%string
1221+
"apply_row_constraint: Total_elems constraint failed: %{denom*d#Int} not \
1222+
divisible by %{coeff#Int}"],
1223+
[ Row_mismatch [ r ] ] )
12121224
| { dims; bcast = Broadcastable; _ }, Total_elems { numerator; divided_by }
12131225
when List.length divided_by <= 1 -> (
12141226
try
@@ -1321,6 +1333,7 @@ let%debug5_sexp rec unify_dim ~stage (eq : dim * dim) (env : environment) :
13211333
ineqs := more_ineqs @ !ineqs;
13221334
result
13231335
in
1336+
ineqs_from_reapply_rows_constr := [];
13241337
let env =
13251338
match Map.find env.dim_env v with
13261339
| None ->
@@ -2060,7 +2073,27 @@ and eliminate_row_constraint ~depth stage ~lub (r : row) (constr : row_constrain
20602073
if d = d2 && is_stage5_up stage then no_further_axes else Row_eq { r1; r2 }
20612074
in
20622075
(row_eq :: [ Dim_eq { d1 = Var v; d2 = get_dim ~d:(d / d2) () } ], env)
2063-
| Strided_var { coeff = _; var = _; denom = _ }, _, _ -> keep_constr ()
2076+
| Strided_var { coeff; var; denom }, [], None
2077+
when is_stage5_up stage
2078+
&& (Utils.safe_force coeff > denom || denom % Utils.safe_force coeff <> 0) ->
2079+
let coeff = Utils.safe_force coeff in
2080+
let gcd = Utils.gcd coeff denom in
2081+
let d = denom / gcd in
2082+
let d2 = get_dim ~d () in
2083+
let d3 = get_dim ~d:(coeff / gcd) () in
2084+
( [
2085+
Dim_eq { d1 = Var var; d2 };
2086+
Row_eq { r1; r2 = { dims = [ d3 ]; bcast = Broadcastable; id } };
2087+
],
2088+
env )
2089+
| Strided_var { coeff; var; denom }, [], _
2090+
when is_stage6_up stage && denom % Utils.safe_force coeff = 0 ->
2091+
let d2 = get_dim ~d:(denom / Utils.safe_force coeff) () in
2092+
( [
2093+
Dim_eq { d1 = Var var; d2 };
2094+
Row_eq { r1; r2 = { dims = []; bcast = Broadcastable; id } };
2095+
],
2096+
env )
20642097
| _ -> keep_constr ())
20652098
| Exact dims -> ([ Row_eq { r1; r2 = { dims; bcast = Broadcastable; id } } ], env)
20662099
| Unconstrained -> ([], env))

lib/shape_inference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ The constraints are solved by: unification of the equation constraints, unificat
165165
* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds (if any), but for rows only if they don't have a `Total_elems 1` constraint. It substitutes row variables in terminal shapes that do not have a LUB by one axis if that's required to satisfy the variable's constraint.
166166
* Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses, to their LUBs if they have any. It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 6 to all variables.)
167167
* Stage 5 addresses `Total_elems` and `Exact` constraints with yet-unknown row variables. For `Total_elems` and a single row variable: if the constraint can be satisfied by assuming the row variable is no-further-axes, it sets the row variable to `Broadcastable`, otherwise it sets it to one axis of the required dimension. For multiple row variables, if one is of the Output kind, sets the other variables to no-further-axes, and retries.
168-
* Stage 6 sets row variables in the remaining inequalities to no-further-axes values. This can unlock further between-axis inequalities because of row variables sandwiched between leftmost axes from their side of the inequality and rightmost axes from the other side of the inequality.
168+
* Stage 6 sets row variables in the remaining inequalities to no-further-axes values. This can unlock further between-axis inequalities because of row variables sandwiched between leftmost axes from their side of the inequality and rightmost axes from the other side of the inequality. In row constraints, this also unlocks inference for the embedded dim variables.
169169
* Stage 7 sets all dim variables remaining in updated shapes to the lower bound if they have any, otherwise to dimension-1. It sets all row variables remaining in updated shapes to no-further-axes.
170170

171171
Let's explain the shape inference functions.

0 commit comments

Comments
 (0)