Skip to content

Commit 72b981e

Browse files
committed
Fill-in eliminate_rows_constraint multi-row-var coverage
1 parent 19637e0 commit 72b981e

File tree

3 files changed

+40
-7
lines changed

3 files changed

+40
-7
lines changed

arrayjit/lib/utils.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -952,6 +952,8 @@ let safe_force gated =
952952
gated.value <- `Value v;
953953
v
954954

955+
let is_safe_val = function { value = `Value _; _ } -> true | _ -> false
956+
955957
let safe_map ~upd ~f gated =
956958
let unique_id = gated.unique_id ^ "_" ^ upd in
957959
match gated.value with

lib/row.ml

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,7 @@ let rec apply_rows_constraint ~stage (rows : row list) (constr : row_constraint)
871871
(* Cannot deduce no_further_axes, return unchanged *)
872872
([ Rows_constr { r = rows; constr } ], env))
873873
| Exact [ single_dim ] -> (
874-
(* Keep existing logic for single_dim case *)
874+
(* Shapes must have non-empty output rows. *)
875875
match List.rev rows with
876876
| { dims = []; bcast = Broadcastable; id = _ } :: more_rows ->
877877
apply_rows_constraint ~stage (List.rev more_rows) constr env
@@ -985,7 +985,8 @@ and apply_row_constraint ~stage (r : row) (constr : row_constraint) env : constr
985985
( List.map ~f:(fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d:1 () }) (vs1 @ vs2)
986986
@ extras,
987987
env )
988-
| Strided_var { coeff; var; denom }, [], [ v ] when equal_dim_var var v ->
988+
| Strided_var { coeff; var; denom }, [], [ v ]
989+
when equal_dim_var var v && (Utils.is_safe_val coeff || is_stage2_up stage) ->
989990
(* Total = (coeff * v / denom) / v = coeff / denom *)
990991
if Utils.safe_force coeff % denom = 0 then
991992
( Dim_eq { d1 = Var v; d2 = get_dim ~d:(Utils.safe_force coeff / denom) () }
@@ -1890,11 +1891,35 @@ let%debug5_sexp close_dim_terminal ~(stage : stage) (env : environment) (dim : d
18901891

18911892
let last_dim_is dims d2 = match List.last dims with Some (Dim { d; _ }) -> d = d2 | _ -> false
18921893

1894+
let r_dims r =
1895+
match r.bcast with Broadcastable -> r.dims | Row_var { beg_dims; _ } -> beg_dims @ r.dims
1896+
18931897
let rec eliminate_rows_constraint ~lub (rows : row list) (constr : row_constraint)
18941898
(env : environment) : constraint_ list =
18951899
match rows_to_row_or_vars rows with
18961900
| Either.First single_row -> eliminate_row_constraint ~lub single_row constr env
1897-
| Either.Second (_all_dims, _row_vars) -> [ Rows_constr { r = rows; constr } ]
1901+
| Either.Second (_all_dims, row_vars) -> (
1902+
let rev_row_vars = List.rev row_vars in
1903+
match
1904+
( constr,
1905+
List.findi rev_row_vars ~f:(fun _ -> function
1906+
| _, { kind = `Output; _ } -> true
1907+
| _ -> false) )
1908+
with
1909+
| Total_elems _, Some (idx, (v, id)) ->
1910+
let other_vars = List.filteri rev_row_vars ~f:(fun i _ -> i <> idx) in
1911+
let other_vars = List.map other_vars ~f:(fun (v, id) -> row_of_var v id) in
1912+
let other_eqs =
1913+
List.map other_vars ~f:(fun r ->
1914+
Row_eq { r1 = r; r2 = { dims = []; bcast = Broadcastable; id } })
1915+
in
1916+
let rows =
1917+
List.map rows ~f:(function
1918+
| { bcast = Row_var { v = v'; _ }; _ } as r when equal_row_var v' v -> r
1919+
| r -> { r with dims = r_dims r; bcast = Broadcastable })
1920+
in
1921+
other_eqs @ eliminate_rows_constraint ~lub rows constr env
1922+
| _ -> [ Rows_constr { r = rows; constr } ])
18981923

18991924
and eliminate_row_constraint ~lub (r : row) (constr : row_constraint) env : constraint_ list =
19001925
match r with
@@ -1935,9 +1960,15 @@ and eliminate_row_constraint ~lub (r : row) (constr : row_constraint) env : cons
19351960
(* We can't determine the exact shape without knowing var *)
19361961
[]
19371962
| Strided_var { coeff; var; denom }, [ v ], _ when equal_dim_var var v ->
1938-
(* coeff * var / var = coeff *)
1939-
no_further_axes
1940-
:: [ Dim_eq { d1 = Var v; d2 = get_dim ~d:(Utils.safe_force coeff / denom) () } ]
1963+
let coeff = Utils.safe_force coeff in
1964+
if coeff % denom <> 0 then
1965+
raise
1966+
@@ Shape_error
1967+
( [%string "Strided_var constraint: %{coeff#Int} not divisible by %{denom#Int}"],
1968+
[ Row_mismatch [ r ] ] )
1969+
else
1970+
(* coeff * var / var = coeff *)
1971+
no_further_axes :: [ Dim_eq { d1 = Var v; d2 = get_dim ~d:(coeff / denom) () } ]
19411972
| _ -> [])
19421973
| Exact dims -> [ Row_eq { r1; r2 = { dims; bcast = Broadcastable; id } } ]
19431974
| _ -> [])

lib/shape_inference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ The constraints are solved by: unification of the equation constraints, unificat
164164
* Stage 2, when solving the constraints, substitutes dim variables in terminal shapes that do not have a LUB or other constraints, by dimension-1. (This is generalized at stage 6 to all variables.) It substitutes row variables in terminal shapes that do not have a LUB by one axis if that's required to satisfy the variable's constraint.
165165
* Stage 3, when solving the constraints, sets yet-unknown dimension and row variables in terminal shapes to their least upper bounds (if any). It substitutes row variables in terminal shapes that do not have a LUB by no-further-axes. (This is generalized at stage 5 to all variables.)
166166
* Stage 4 sets yet-unknown dimensions with >1 lower bounds from direct accesses, to their LUBs if they have any, otherwise to the lower bound.
167-
* Stage 5 addresses `Total_elems` constraints with yet-unknown row variables. If the constraint can be satisfied by assuming the row variable is no-further-axes, it sets the row variable to `Broadcastable`, otherwise it sets it to one axis of the required dimension.
167+
* Stage 5 addresses `Total_elems` and `Exact` constraints with yet-unknown row variables. For `Total_elems` and a single row variable: if the constraint can be satisfied by assuming the row variable is no-further-axes, it sets the row variable to `Broadcastable`, otherwise it sets it to one axis of the required dimension. For multiple row variables, if one is of the Output kind, sets the other variables to no-further-axes, and retries.
168168
* Stage 6 sets row variables in the remaining inequalities to no-further-axes values. This can unlock further between-axis inequalities because of row variables sandwiched between leftmost axes from their side of the inequality and rightmost axes from the other side of the inequality.
169169
* Stage 7 sets all dim resp. row variables remaining in updated shapes to dimension-1 resp. no-further-axes.
170170

0 commit comments

Comments
 (0)