@@ -873,11 +873,14 @@ let s_dim_one_in_entry v ~value (in_ : dim_entry) : _ * dim_entry =
873873let s_dim_one_in_row v ~value in_ =
874874 { in_ with dims = List. map in_.dims ~f: (fun in_ -> s_dim_one v ~value ~in_ ) }
875875
876+ let reapply_rows_constr = ref false
877+
876878let subst_row_constraint_impl ~subst_in_dim ~get_dim_val stage constr =
877879 let subst_total_elems_divided_by numerator divided_by =
878880 let substituted_divided_by = List. map divided_by ~f: (fun v -> subst_in_dim (Var v)) in
879881 match collect_factors substituted_divided_by with
880882 | Some (known_product , residual_vars ) ->
883+ reapply_rows_constr := true ;
881884 Total_elems
882885 { numerator = total_elems_divide numerator known_product; divided_by = residual_vars }
883886 | None ->
@@ -889,10 +892,12 @@ let subst_row_constraint_impl ~subst_in_dim ~get_dim_val stage constr =
889892 when is_stage2_up stage && Option. is_some (get_dim_val var) ->
890893 let dim = Option. value_exn (get_dim_val var) in
891894 let tot = Utils. safe_force coeff * dim in
895+ reapply_rows_constr := true ;
892896 if tot % denom = 0 then subst_total_elems_divided_by (Num_elems (tot / denom)) divided_by
893897 else raise @@ Shape_error (" Total_elems constraint: shape cannot be strided" , [] )
894898 | Total_elems { numerator = Strided_var { coeff; var; denom }; divided_by }
895899 when not (equal_dim (Var var) (subst_in_dim (Var var))) -> (
900+ reapply_rows_constr := true ;
896901 match subst_in_dim (Var var) with
897902 | Dim { d; _ } as value ->
898903 (* Replace (coeff * v / denom) with (coeff * d / denom) *)
@@ -910,7 +915,9 @@ let subst_row_constraint_impl ~subst_in_dim ~get_dim_val stage constr =
910915 (* FIXME: NOT IMPLEMENTED YET *)
911916 failwith " NOT IMPLEMENTED YET" )
912917 | Total_elems { numerator; divided_by } -> subst_total_elems_divided_by numerator divided_by
913- | Exact dims -> Exact (List. map dims ~f: subst_in_dim)
918+ | Exact dims ->
919+ (* The constraint update does not affect its applicability, so we don't need to reapply it. *)
920+ Exact (List. map dims ~f: subst_in_dim)
914921 | Unconstrained -> constr
915922
916923let s_dim_one_in_row_constr stage v ~value constr =
@@ -921,12 +928,23 @@ let s_dim_one_in_row_constr stage v ~value constr =
921928 ~subst_in_dim: (fun in_ -> s_dim_one v ~value ~in_ )
922929 ~get_dim_val stage constr
923930
924- let s_dim_one_in_row_entry stage v ~value in_ =
925- match in_ with
926- | Solved_row in_ -> Solved_row (s_dim_one_in_row v ~value in_)
927- | Bounds_row { cur; subr; lub; constr } ->
928- let constr = s_dim_one_in_row_constr stage v ~value constr in
929- Bounds_row { cur; subr; lub = Option. map lub ~f: (s_dim_one_in_row v ~value ); constr }
931+ let ineqs_from_reapply_rows_constr = ref []
932+
933+ let s_dim_one_in_row_entry stage v ~value ~key ~data =
934+ assert (not ! reapply_rows_constr);
935+ let result =
936+ match data with
937+ | Solved_row in_ -> Solved_row (s_dim_one_in_row v ~value in_)
938+ | Bounds_row { cur; subr; lub; constr } ->
939+ let constr = s_dim_one_in_row_constr stage v ~value constr in
940+ if ! reapply_rows_constr then
941+ ineqs_from_reapply_rows_constr :=
942+ Rows_constr { r = [ row_of_var key phantom_row_id ]; constr }
943+ :: ! ineqs_from_reapply_rows_constr;
944+ reapply_rows_constr := false ;
945+ Bounds_row { cur; subr; lub = Option. map lub ~f: (s_dim_one_in_row v ~value ); constr }
946+ in
947+ result
930948
931949let rec vars_of_dim = function
932950 | Dim _ -> Set. empty (module Dim_var )
@@ -1112,6 +1130,7 @@ and apply_row_constraint ~depth stage (r : row) (constr : row_constraint) env :
11121130 else if is_unconstrained constr then ([] , env)
11131131 else
11141132 let constr = subst_row_constraint stage env constr in
1133+ reapply_rows_constr := false ;
11151134 let reduce constr ~beg_dims ~dims =
11161135 try reduce_row_constraint constr ~beg_dims ~dims
11171136 with Shape_error (s , trace ) -> raise @@ Shape_error (s, Row_mismatch [ r ] :: trace)
@@ -1308,7 +1327,7 @@ let%debug5_sexp rec unify_dim ~stage (eq : dim * dim) (env : environment) :
13081327 let dim_env = Map. map env.dim_env ~f in
13091328 {
13101329 dim_env = Map. add_exn dim_env ~key: v ~data: (Solved_dim dim2);
1311- row_env = Map. map env.row_env ~f: (s_dim_one_in_row_entry stage v ~value: dim2);
1330+ row_env = Map. mapi env.row_env ~f: (s_dim_one_in_row_entry stage v ~value: dim2);
13121331 }
13131332 | Some (Solved_dim _ ) -> assert false
13141333 | Some (Bounds_dim { cur; subr; lub; constr } ) ->
@@ -1325,9 +1344,11 @@ let%debug5_sexp rec unify_dim ~stage (eq : dim * dim) (env : environment) :
13251344 ineqs := extras @ ! ineqs;
13261345 {
13271346 dim_env = Map. set dim_env ~key: v ~data: (Solved_dim dim2);
1328- row_env = Map. map env.row_env ~f: (s_dim_one_in_row_entry stage v ~value: dim2);
1347+ row_env = Map. mapi env.row_env ~f: (s_dim_one_in_row_entry stage v ~value: dim2);
13291348 }
13301349 in
1350+ ineqs := ! ineqs_from_reapply_rows_constr @ ! ineqs;
1351+ ineqs_from_reapply_rows_constr := [] ;
13311352 let dim_eqs, ineqs =
13321353 List. partition_map ! ineqs ~f: (function
13331354 | Dim_eq { d1; d2 } -> Either. First (d1, d2)
@@ -2119,6 +2140,7 @@ let%debug5_sexp eliminate_variables (env : environment) ({ dims; bcast; id } as
21192140 (subst_row_constraint stage env constr)
21202141 env
21212142 in
2143+ reapply_rows_constr := false ;
21222144 ineqs @ elim_dims
21232145 | _ -> elim_var :: elim_dims)
21242146
@@ -2164,6 +2186,7 @@ let%debug4_sexp solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (
21642186 (extras @ ineqs, env)
21652187 | Rows_constr { r = rows ; constr } ->
21662188 let constr : row_constraint = subst_row_constraint stage env constr in
2189+ reapply_rows_constr := false ;
21672190 let substituted_rows = List. map rows ~f: (subst_row env) in
21682191 let (more_ineqs : constraint_ list ), env =
21692192 if is_stage5_up stage then
@@ -2424,7 +2447,7 @@ let get_proj_index proj_env =
24242447 let repr, _ =
24252448 Utils. union_find ~equal: Proj_id. equal proj_env.proj_classes ~key: proj_id ~rank: 0
24262449 in
2427- match d, Map. find proj_env.proj_to_index repr with
2450+ match ( d, Map. find proj_env.proj_to_index repr) with
24282451 | _ , Some i -> i
24292452 | (0 | 1 ), None -> Fixed_idx 0
24302453 | _ -> unknown_projection proj_id d)
@@ -2606,7 +2629,7 @@ let get_dim_index proj_env =
26062629 let repr, _ =
26072630 Utils. union_find ~equal: Proj_id. equal proj_env.proj_classes ~key: proj_id ~rank: 0
26082631 in
2609- match d, Map. find proj_env.proj_to_index repr with
2632+ match ( d, Map. find proj_env.proj_to_index repr) with
26102633 | _ , Some i -> i
26112634 | (0 | 1 ), None -> Fixed_idx 0
26122635 | _ -> unknown_projection proj_id d)
@@ -2625,7 +2648,8 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list)
26252648 let p_dims = ref [] in
26262649 let proj_classes = ref @@ Map. empty (module Proj_id ) in
26272650 let non_product = ref @@ Set. empty (module Proj_id ) in
2628- let rec loop (eq : proj_equation ) : unit = match eq with
2651+ let rec loop (eq : proj_equation ) : unit =
2652+ match eq with
26292653 | Proj_eq (Proj (p1 , { d; _ } ), Proj (p2 , _ )) when Proj_id. equal p1 p2 ->
26302654 p_dims := (p1, d) :: ! p_dims
26312655 | Proj_eq (Var v1 , Var v2 ) when equal_dim_var v1 v2 -> ()
0 commit comments