@@ -829,8 +829,8 @@ let check_empty_row r =
829829 [ Row_eq { r1 = row_of_var v r.id; r2 = { dims = [] ; bcast = Broadcastable ; id = r.id } } ]
830830 else raise @@ Shape_error (" check_empty_row: row is not empty" , [ Row_mismatch [ r ] ])
831831
832- let rec apply_rows_constraint ~stage (rows : row list ) (constr : row_constraint ) ( env : environment )
833- : constraint_ list * environment =
832+ let % track5_sexp rec apply_rows_constraint ~stage (rows : row list ) (constr : row_constraint )
833+ (env : environment ) : constraint_ list * environment =
834834 match rows_to_row_or_vars rows with
835835 | Either. First single_row -> apply_row_constraint stage single_row constr env
836836 | Either. Second (all_dims , row_vars ) -> (
@@ -1904,42 +1904,46 @@ let last_dim_is dims p = match List.last dims with Some (Dim { d; _ }) -> p d |
19041904let r_dims r =
19051905 match r.bcast with Broadcastable -> r.dims | Row_var { beg_dims; _ } -> beg_dims @ r.dims
19061906
1907- let rec eliminate_rows_constraint stage ~lub (rows : row list ) (constr : row_constraint )
1908- (env : environment ) : constraint_ list =
1909- match rows_to_row_or_vars rows with
1910- | Either. First single_row -> eliminate_row_constraint stage ~lub single_row constr env
1911- | Either. Second (_all_dims , row_vars ) -> (
1912- let rev_row_vars = List. rev row_vars in
1913- match
1914- ( constr,
1915- List. findi rev_row_vars ~f: (fun _ -> function
1916- | _ , { kind = `Output ; _ } -> true
1917- | _ -> false ) )
1918- with
1919- | Total_elems _ , Some (idx , (v , id )) when is_stage5_up stage ->
1920- let other_vars = List. filteri rev_row_vars ~f: (fun i _ -> i <> idx) in
1921- let other_vars = List. map other_vars ~f: (fun (v , id ) -> row_of_var v id) in
1922- let other_eqs =
1923- List. map other_vars ~f: (fun r ->
1924- Row_eq { r1 = r; r2 = { dims = [] ; bcast = Broadcastable ; id } })
1925- in
1926- let rows =
1927- List. map rows ~f: (function
1928- | { bcast = Row_var { v = v' ; _ } ; _ } as r when equal_row_var v' v -> r
1929- | r -> { r with dims = r_dims r; bcast = Broadcastable })
1930- in
1931- other_eqs @ eliminate_rows_constraint stage ~lub rows constr env
1932- | _ -> [ Rows_constr { r = rows; constr } ])
1907+ let % track5_sexp rec eliminate_rows_constraint ~depth stage ~lub (rows : row list )
1908+ (constr : row_constraint ) (env : environment ) : constraint_ list =
1909+ if depth > 16 then []
1910+ else
1911+ match rows_to_row_or_vars rows with
1912+ | Either. First single_row ->
1913+ eliminate_row_constraint ~depth: (depth + 1 ) stage ~lub single_row constr env
1914+ | Either. Second (_all_dims , row_vars ) -> (
1915+ let rev_row_vars = List. rev row_vars in
1916+ match
1917+ ( constr,
1918+ List. findi rev_row_vars ~f: (fun _ -> function
1919+ | _ , { Row_id. kind = `Output ; _ } -> true
1920+ | _ -> false ) )
1921+ with
1922+ | Total_elems _ , Some (idx , (v , id )) when is_stage5_up stage ->
1923+ let other_vars = List. filteri rev_row_vars ~f: (fun i _ -> i <> idx) in
1924+ let other_vars = List. map other_vars ~f: (fun (v , id ) -> row_of_var v id) in
1925+ let other_eqs =
1926+ List. map other_vars ~f: (fun r ->
1927+ Row_eq { r1 = r; r2 = { dims = [] ; bcast = Broadcastable ; id } })
1928+ in
1929+ let rows =
1930+ List. map rows ~f: (function
1931+ | { bcast = Row_var { v = v' ; _ } ; _ } as r when equal_row_var v' v -> r
1932+ | r -> { r with dims = r_dims r; bcast = Broadcastable })
1933+ in
1934+ other_eqs @ eliminate_rows_constraint ~depth: (depth + 1 ) stage ~lub rows constr env
1935+ | _ -> [ Rows_constr { r = rows; constr } ])
19331936
1934- and eliminate_row_constraint stage ~lub (r : row ) (constr : row_constraint ) env : constraint_ list =
1937+ and eliminate_row_constraint ~depth stage ~lub (r : row ) (constr : row_constraint ) env :
1938+ constraint_ list =
19351939 let keep_constr = if is_stage5_up stage then [] else [ Rows_constr { r = [ r ]; constr } ] in
19361940 match r with
19371941 | { bcast = Broadcastable ; _ } ->
19381942 (* The environment is unchanged, as apply_row_constraint would update only the constr. *)
19391943 let ineqs, _env = apply_row_constraint stage r constr env in
19401944 List. concat_map ineqs ~f: (function
19411945 | Rows_constr { r = rows ; constr } ->
1942- eliminate_rows_constraint stage ~lub: None rows constr env
1946+ eliminate_rows_constraint ~depth stage ~lub: None rows constr env
19431947 | ineq -> [ ineq ])
19441948 | { bcast = Row_var { v; beg_dims } ; dims; id } -> (
19451949 let r1 = row_of_var v id in
@@ -1963,7 +1967,7 @@ and eliminate_row_constraint stage ~lub (r : row) (constr : row_constraint) env
19631967 let ineqs, _env = apply_row_constraint stage lub constr env in
19641968 List. concat_map ineqs ~f: (function
19651969 | Rows_constr { r = rows ; constr } ->
1966- eliminate_rows_constraint stage ~lub: None rows constr env
1970+ eliminate_rows_constraint ~depth stage ~lub: None rows constr env
19671971 | ineq -> [ ineq ])
19681972 | Num_elems d , [ v ], None when is_stage4_up stage ->
19691973 no_further_axes :: [ Dim_eq { d1 = Var v; d2 = get_dim ~d () } ]
@@ -2000,7 +2004,7 @@ let%track5_sexp close_row_terminal ~(stage : stage) (env : environment)
20002004 when is_stage2_up stage && not (equal_row_constraint constr Unconstrained ) ->
20012005 let ineqs =
20022006 (* This is the constraint on the row variable, not on the original row. *)
2003- try eliminate_row_constraint stage r1 ~lub: None constr env
2007+ try eliminate_row_constraint ~depth: 0 stage r1 ~lub: None constr env
20042008 with Shape_error (s , trace ) -> raise @@ Shape_error (s, Row_mismatch [ r1 ] :: trace)
20052009 in
20062010 (* FIXME: at which stage should we drop the terminal row? *)
@@ -2094,7 +2098,11 @@ let%debug4_sexp solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (
20942098 (extras @ ineqs, env)
20952099 | Rows_constr { r = rows ; constr } ->
20962100 let substituted_rows = List. map rows ~f: (subst_row env) in
2097- let more_ineqs, env = apply_rows_constraint ~stage substituted_rows constr env in
2101+ let more_ineqs, env =
2102+ if is_stage5_up stage then
2103+ (eliminate_rows_constraint ~depth: 0 stage ~lub: None substituted_rows constr env, env)
2104+ else apply_rows_constraint ~stage substituted_rows constr env
2105+ in
20982106 (more_ineqs @ ineqs, env)
20992107 | Terminal_dim d ->
21002108 let more_ineqs = close_dim_terminal ~stage env @@ subst_dim env d in
@@ -2129,7 +2137,7 @@ let%debug4_sexp solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (
21292137 | Bounds_row { lub; constr; _ } ->
21302138 (* TODO: should we store the id somewhere? *)
21312139 let id = phantom_row_id in
2132- eliminate_row_constraint stage (row_of_var v id) ~lub constr env
2140+ eliminate_row_constraint ~depth: 0 stage (row_of_var v id) ~lub constr env
21332141 | _ -> []
21342142 in
21352143 let finalizing_entries : constraint_ list =
0 commit comments