Skip to content

Commit 9a988fd

Browse files
committed
Perform eliminate_rows_constraint even on standalone constraints, work around looping on no-progress constraints with depth tracking
1 parent dc14f2c commit 9a988fd

File tree

1 file changed

+42
-34
lines changed

1 file changed

+42
-34
lines changed

lib/row.ml

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -829,8 +829,8 @@ let check_empty_row r =
829829
[ Row_eq { r1 = row_of_var v r.id; r2 = { dims = []; bcast = Broadcastable; id = r.id } } ]
830830
else raise @@ Shape_error ("check_empty_row: row is not empty", [ Row_mismatch [ r ] ])
831831

832-
let rec apply_rows_constraint ~stage (rows : row list) (constr : row_constraint) (env : environment)
833-
: constraint_ list * environment =
832+
let%track5_sexp rec apply_rows_constraint ~stage (rows : row list) (constr : row_constraint)
833+
(env : environment) : constraint_ list * environment =
834834
match rows_to_row_or_vars rows with
835835
| Either.First single_row -> apply_row_constraint stage single_row constr env
836836
| Either.Second (all_dims, row_vars) -> (
@@ -1904,42 +1904,46 @@ let last_dim_is dims p = match List.last dims with Some (Dim { d; _ }) -> p d |
19041904
let r_dims r =
19051905
match r.bcast with Broadcastable -> r.dims | Row_var { beg_dims; _ } -> beg_dims @ r.dims
19061906

1907-
let rec eliminate_rows_constraint stage ~lub (rows : row list) (constr : row_constraint)
1908-
(env : environment) : constraint_ list =
1909-
match rows_to_row_or_vars rows with
1910-
| Either.First single_row -> eliminate_row_constraint stage ~lub single_row constr env
1911-
| Either.Second (_all_dims, row_vars) -> (
1912-
let rev_row_vars = List.rev row_vars in
1913-
match
1914-
( constr,
1915-
List.findi rev_row_vars ~f:(fun _ -> function
1916-
| _, { kind = `Output; _ } -> true
1917-
| _ -> false) )
1918-
with
1919-
| Total_elems _, Some (idx, (v, id)) when is_stage5_up stage ->
1920-
let other_vars = List.filteri rev_row_vars ~f:(fun i _ -> i <> idx) in
1921-
let other_vars = List.map other_vars ~f:(fun (v, id) -> row_of_var v id) in
1922-
let other_eqs =
1923-
List.map other_vars ~f:(fun r ->
1924-
Row_eq { r1 = r; r2 = { dims = []; bcast = Broadcastable; id } })
1925-
in
1926-
let rows =
1927-
List.map rows ~f:(function
1928-
| { bcast = Row_var { v = v'; _ }; _ } as r when equal_row_var v' v -> r
1929-
| r -> { r with dims = r_dims r; bcast = Broadcastable })
1930-
in
1931-
other_eqs @ eliminate_rows_constraint stage ~lub rows constr env
1932-
| _ -> [ Rows_constr { r = rows; constr } ])
1907+
let%track5_sexp rec eliminate_rows_constraint ~depth stage ~lub (rows : row list)
1908+
(constr : row_constraint) (env : environment) : constraint_ list =
1909+
if depth > 16 then []
1910+
else
1911+
match rows_to_row_or_vars rows with
1912+
| Either.First single_row ->
1913+
eliminate_row_constraint ~depth:(depth + 1) stage ~lub single_row constr env
1914+
| Either.Second (_all_dims, row_vars) -> (
1915+
let rev_row_vars = List.rev row_vars in
1916+
match
1917+
( constr,
1918+
List.findi rev_row_vars ~f:(fun _ -> function
1919+
| _, { Row_id.kind = `Output; _ } -> true
1920+
| _ -> false) )
1921+
with
1922+
| Total_elems _, Some (idx, (v, id)) when is_stage5_up stage ->
1923+
let other_vars = List.filteri rev_row_vars ~f:(fun i _ -> i <> idx) in
1924+
let other_vars = List.map other_vars ~f:(fun (v, id) -> row_of_var v id) in
1925+
let other_eqs =
1926+
List.map other_vars ~f:(fun r ->
1927+
Row_eq { r1 = r; r2 = { dims = []; bcast = Broadcastable; id } })
1928+
in
1929+
let rows =
1930+
List.map rows ~f:(function
1931+
| { bcast = Row_var { v = v'; _ }; _ } as r when equal_row_var v' v -> r
1932+
| r -> { r with dims = r_dims r; bcast = Broadcastable })
1933+
in
1934+
other_eqs @ eliminate_rows_constraint ~depth:(depth + 1) stage ~lub rows constr env
1935+
| _ -> [ Rows_constr { r = rows; constr } ])
19331936

1934-
and eliminate_row_constraint stage ~lub (r : row) (constr : row_constraint) env : constraint_ list =
1937+
and eliminate_row_constraint ~depth stage ~lub (r : row) (constr : row_constraint) env :
1938+
constraint_ list =
19351939
let keep_constr = if is_stage5_up stage then [] else [ Rows_constr { r = [ r ]; constr } ] in
19361940
match r with
19371941
| { bcast = Broadcastable; _ } ->
19381942
(* The environment is unchanged, as apply_row_constraint would update only the constr. *)
19391943
let ineqs, _env = apply_row_constraint stage r constr env in
19401944
List.concat_map ineqs ~f:(function
19411945
| Rows_constr { r = rows; constr } ->
1942-
eliminate_rows_constraint stage ~lub:None rows constr env
1946+
eliminate_rows_constraint ~depth stage ~lub:None rows constr env
19431947
| ineq -> [ ineq ])
19441948
| { bcast = Row_var { v; beg_dims }; dims; id } -> (
19451949
let r1 = row_of_var v id in
@@ -1963,7 +1967,7 @@ and eliminate_row_constraint stage ~lub (r : row) (constr : row_constraint) env
19631967
let ineqs, _env = apply_row_constraint stage lub constr env in
19641968
List.concat_map ineqs ~f:(function
19651969
| Rows_constr { r = rows; constr } ->
1966-
eliminate_rows_constraint stage ~lub:None rows constr env
1970+
eliminate_rows_constraint ~depth stage ~lub:None rows constr env
19671971
| ineq -> [ ineq ])
19681972
| Num_elems d, [ v ], None when is_stage4_up stage ->
19691973
no_further_axes :: [ Dim_eq { d1 = Var v; d2 = get_dim ~d () } ]
@@ -2000,7 +2004,7 @@ let%track5_sexp close_row_terminal ~(stage : stage) (env : environment)
20002004
when is_stage2_up stage && not (equal_row_constraint constr Unconstrained) ->
20012005
let ineqs =
20022006
(* This is the constraint on the row variable, not on the original row. *)
2003-
try eliminate_row_constraint stage r1 ~lub:None constr env
2007+
try eliminate_row_constraint ~depth:0 stage r1 ~lub:None constr env
20042008
with Shape_error (s, trace) -> raise @@ Shape_error (s, Row_mismatch [ r1 ] :: trace)
20052009
in
20062010
(* FIXME: at which stage should we drop the terminal row? *)
@@ -2094,7 +2098,11 @@ let%debug4_sexp solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (
20942098
(extras @ ineqs, env)
20952099
| Rows_constr { r = rows; constr } ->
20962100
let substituted_rows = List.map rows ~f:(subst_row env) in
2097-
let more_ineqs, env = apply_rows_constraint ~stage substituted_rows constr env in
2101+
let more_ineqs, env =
2102+
if is_stage5_up stage then
2103+
(eliminate_rows_constraint ~depth:0 stage ~lub:None substituted_rows constr env, env)
2104+
else apply_rows_constraint ~stage substituted_rows constr env
2105+
in
20982106
(more_ineqs @ ineqs, env)
20992107
| Terminal_dim d ->
21002108
let more_ineqs = close_dim_terminal ~stage env @@ subst_dim env d in
@@ -2129,7 +2137,7 @@ let%debug4_sexp solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (
21292137
| Bounds_row { lub; constr; _ } ->
21302138
(* TODO: should we store the id somewhere? *)
21312139
let id = phantom_row_id in
2132-
eliminate_row_constraint stage (row_of_var v id) ~lub constr env
2140+
eliminate_row_constraint ~depth:0 stage (row_of_var v id) ~lub constr env
21332141
| _ -> []
21342142
in
21352143
let finalizing_entries : constraint_ list =

0 commit comments

Comments
 (0)