Skip to content

Commit 38b4a3d

Browse files
committed
Reapply rows constraints as they get updated in the environment; debuggability
1 parent e5544d6 commit 38b4a3d

File tree

2 files changed

+46
-16
lines changed

2 files changed

+46
-16
lines changed

lib/row.ml

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -873,11 +873,14 @@ let s_dim_one_in_entry v ~value (in_ : dim_entry) : _ * dim_entry =
873873
let s_dim_one_in_row v ~value in_ =
874874
{ in_ with dims = List.map in_.dims ~f:(fun in_ -> s_dim_one v ~value ~in_) }
875875

876+
let reapply_rows_constr = ref false
877+
876878
let subst_row_constraint_impl ~subst_in_dim ~get_dim_val stage constr =
877879
let subst_total_elems_divided_by numerator divided_by =
878880
let substituted_divided_by = List.map divided_by ~f:(fun v -> subst_in_dim (Var v)) in
879881
match collect_factors substituted_divided_by with
880882
| Some (known_product, residual_vars) ->
883+
reapply_rows_constr := true;
881884
Total_elems
882885
{ numerator = total_elems_divide numerator known_product; divided_by = residual_vars }
883886
| None ->
@@ -889,10 +892,12 @@ let subst_row_constraint_impl ~subst_in_dim ~get_dim_val stage constr =
889892
when is_stage2_up stage && Option.is_some (get_dim_val var) ->
890893
let dim = Option.value_exn (get_dim_val var) in
891894
let tot = Utils.safe_force coeff * dim in
895+
reapply_rows_constr := true;
892896
if tot % denom = 0 then subst_total_elems_divided_by (Num_elems (tot / denom)) divided_by
893897
else raise @@ Shape_error ("Total_elems constraint: shape cannot be strided", [])
894898
| Total_elems { numerator = Strided_var { coeff; var; denom }; divided_by }
895899
when not (equal_dim (Var var) (subst_in_dim (Var var))) -> (
900+
reapply_rows_constr := true;
896901
match subst_in_dim (Var var) with
897902
| Dim { d; _ } as value ->
898903
(* Replace (coeff * v / denom) with (coeff * d / denom) *)
@@ -910,7 +915,9 @@ let subst_row_constraint_impl ~subst_in_dim ~get_dim_val stage constr =
910915
(* FIXME: NOT IMPLEMENTED YET *)
911916
failwith "NOT IMPLEMENTED YET")
912917
| Total_elems { numerator; divided_by } -> subst_total_elems_divided_by numerator divided_by
913-
| Exact dims -> Exact (List.map dims ~f:subst_in_dim)
918+
| Exact dims ->
919+
(* The constraint update does not affect its applicability, so we don't need to reapply it. *)
920+
Exact (List.map dims ~f:subst_in_dim)
914921
| Unconstrained -> constr
915922

916923
let s_dim_one_in_row_constr stage v ~value constr =
@@ -921,12 +928,23 @@ let s_dim_one_in_row_constr stage v ~value constr =
921928
~subst_in_dim:(fun in_ -> s_dim_one v ~value ~in_)
922929
~get_dim_val stage constr
923930

924-
let s_dim_one_in_row_entry stage v ~value in_ =
925-
match in_ with
926-
| Solved_row in_ -> Solved_row (s_dim_one_in_row v ~value in_)
927-
| Bounds_row { cur; subr; lub; constr } ->
928-
let constr = s_dim_one_in_row_constr stage v ~value constr in
929-
Bounds_row { cur; subr; lub = Option.map lub ~f:(s_dim_one_in_row v ~value); constr }
931+
let ineqs_from_reapply_rows_constr = ref []
932+
933+
let s_dim_one_in_row_entry stage v ~value ~key ~data =
934+
assert (not !reapply_rows_constr);
935+
let result =
936+
match data with
937+
| Solved_row in_ -> Solved_row (s_dim_one_in_row v ~value in_)
938+
| Bounds_row { cur; subr; lub; constr } ->
939+
let constr = s_dim_one_in_row_constr stage v ~value constr in
940+
if !reapply_rows_constr then
941+
ineqs_from_reapply_rows_constr :=
942+
Rows_constr { r = [ row_of_var key phantom_row_id ]; constr }
943+
:: !ineqs_from_reapply_rows_constr;
944+
reapply_rows_constr := false;
945+
Bounds_row { cur; subr; lub = Option.map lub ~f:(s_dim_one_in_row v ~value); constr }
946+
in
947+
result
930948

931949
let rec vars_of_dim = function
932950
| Dim _ -> Set.empty (module Dim_var)
@@ -1112,6 +1130,7 @@ and apply_row_constraint ~depth stage (r : row) (constr : row_constraint) env :
11121130
else if is_unconstrained constr then ([], env)
11131131
else
11141132
let constr = subst_row_constraint stage env constr in
1133+
reapply_rows_constr := false;
11151134
let reduce constr ~beg_dims ~dims =
11161135
try reduce_row_constraint constr ~beg_dims ~dims
11171136
with Shape_error (s, trace) -> raise @@ Shape_error (s, Row_mismatch [ r ] :: trace)
@@ -1308,7 +1327,7 @@ let%debug5_sexp rec unify_dim ~stage (eq : dim * dim) (env : environment) :
13081327
let dim_env = Map.map env.dim_env ~f in
13091328
{
13101329
dim_env = Map.add_exn dim_env ~key:v ~data:(Solved_dim dim2);
1311-
row_env = Map.map env.row_env ~f:(s_dim_one_in_row_entry stage v ~value:dim2);
1330+
row_env = Map.mapi env.row_env ~f:(s_dim_one_in_row_entry stage v ~value:dim2);
13121331
}
13131332
| Some (Solved_dim _) -> assert false
13141333
| Some (Bounds_dim { cur; subr; lub; constr }) ->
@@ -1325,9 +1344,11 @@ let%debug5_sexp rec unify_dim ~stage (eq : dim * dim) (env : environment) :
13251344
ineqs := extras @ !ineqs;
13261345
{
13271346
dim_env = Map.set dim_env ~key:v ~data:(Solved_dim dim2);
1328-
row_env = Map.map env.row_env ~f:(s_dim_one_in_row_entry stage v ~value:dim2);
1347+
row_env = Map.mapi env.row_env ~f:(s_dim_one_in_row_entry stage v ~value:dim2);
13291348
}
13301349
in
1350+
ineqs := !ineqs_from_reapply_rows_constr @ !ineqs;
1351+
ineqs_from_reapply_rows_constr := [];
13311352
let dim_eqs, ineqs =
13321353
List.partition_map !ineqs ~f:(function
13331354
| Dim_eq { d1; d2 } -> Either.First (d1, d2)
@@ -2119,6 +2140,7 @@ let%debug5_sexp eliminate_variables (env : environment) ({ dims; bcast; id } as
21192140
(subst_row_constraint stage env constr)
21202141
env
21212142
in
2143+
reapply_rows_constr := false;
21222144
ineqs @ elim_dims
21232145
| _ -> elim_var :: elim_dims)
21242146

@@ -2164,6 +2186,7 @@ let%debug4_sexp solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (
21642186
(extras @ ineqs, env)
21652187
| Rows_constr { r = rows; constr } ->
21662188
let constr : row_constraint = subst_row_constraint stage env constr in
2189+
reapply_rows_constr := false;
21672190
let substituted_rows = List.map rows ~f:(subst_row env) in
21682191
let (more_ineqs : constraint_ list), env =
21692192
if is_stage5_up stage then
@@ -2424,7 +2447,7 @@ let get_proj_index proj_env =
24242447
let repr, _ =
24252448
Utils.union_find ~equal:Proj_id.equal proj_env.proj_classes ~key:proj_id ~rank:0
24262449
in
2427-
match d, Map.find proj_env.proj_to_index repr with
2450+
match (d, Map.find proj_env.proj_to_index repr) with
24282451
| _, Some i -> i
24292452
| (0 | 1), None -> Fixed_idx 0
24302453
| _ -> unknown_projection proj_id d)
@@ -2606,7 +2629,7 @@ let get_dim_index proj_env =
26062629
let repr, _ =
26072630
Utils.union_find ~equal:Proj_id.equal proj_env.proj_classes ~key:proj_id ~rank:0
26082631
in
2609-
match d, Map.find proj_env.proj_to_index repr with
2632+
match (d, Map.find proj_env.proj_to_index repr) with
26102633
| _, Some i -> i
26112634
| (0 | 1), None -> Fixed_idx 0
26122635
| _ -> unknown_projection proj_id d)
@@ -2625,7 +2648,8 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list)
26252648
let p_dims = ref [] in
26262649
let proj_classes = ref @@ Map.empty (module Proj_id) in
26272650
let non_product = ref @@ Set.empty (module Proj_id) in
2628-
let rec loop (eq : proj_equation) : unit = match eq with
2651+
let rec loop (eq : proj_equation) : unit =
2652+
match eq with
26292653
| Proj_eq (Proj (p1, { d; _ }), Proj (p2, _)) when Proj_id.equal p1 p2 ->
26302654
p_dims := (p1, d) :: !p_dims
26312655
| Proj_eq (Var v1, Var v2) when equal_dim_var v1 v2 -> ()

lib/shape.ml

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -694,13 +694,13 @@ let apply_env_update ~eliminate_variables env update_step =
694694
let%debug4_sexp propagate_shapes (update_step : update_step) : unit =
695695
(* Allow the derivation of constraints to depend on the shapes (currently, only Batch_slice
696696
does). *)
697-
ignore (apply_env_update ~eliminate_variables:false !state update_step);
697+
assert (List.is_empty (apply_env_update ~eliminate_variables:false !state update_step));
698698
let _, ineqs = get_inequalities update_step in
699699
active_update_steps := update_step :: !active_update_steps;
700700
active_constraints := ineqs @ !active_constraints;
701701
let ineqs', env = Row.solve_inequalities ~stage:Row.Stage1 ineqs !state in
702702
let _debug_remaining_constraints : Row.constraint_ list = ineqs' in
703-
ignore (apply_env_update ~eliminate_variables:false env update_step);
703+
assert (List.is_empty (apply_env_update ~eliminate_variables:false env update_step));
704704
state := env
705705

706706
let%debug4_sexp finish_inference (() : unit) : unit =
@@ -712,18 +712,23 @@ let%debug4_sexp finish_inference (() : unit) : unit =
712712
let unsolved, env = Row.solve_inequalities ~stage:Stage5 unsolved env in
713713
let unsolved, env = Row.solve_inequalities ~stage:Stage6 unsolved env in
714714
let unsolved, env = Row.solve_inequalities ~stage:Stage7 unsolved env in
715+
let _active_update_steps : update_step list = !active_update_steps in
715716
let eliminated =
716717
List.concat_map ~f:(apply_env_update ~eliminate_variables:true env) !active_update_steps
717718
in
718719
let unsolved, env = Row.solve_inequalities ~stage:Stage7 (eliminated @ unsolved) env in
719720
assert (List.is_empty unsolved);
720-
ignore @@ List.map ~f:(apply_env_update ~eliminate_variables:false env) !active_update_steps;
721+
List.iter
722+
~f:(fun update_step ->
723+
assert (List.is_empty (apply_env_update ~eliminate_variables:false env update_step)))
724+
!active_update_steps;
725+
let _applied_update_steps : update_step list = !active_update_steps in
721726
active_constraints := [];
722727
active_update_steps := [];
723728
(* There should not be any shape variables remaining in any inference-undergoing update steps. *)
724729
state := Row.empty_env
725730

726-
let row_to_dims row =
731+
let%debug4_sexp row_to_dims (row : Row.t) : int array =
727732
let open Row in
728733
let f = function
729734
| Dim { d; _ } -> d
@@ -734,6 +739,7 @@ let row_to_dims row =
734739
^ Sexp.to_string_hum ([%sexp_of: dim_var] v),
735740
[ Row_mismatch [ row ] ] )
736741
| Conv_input _ ->
742+
(* FIXME: reconsider this, we could return the input dimension of the convolution. *)
737743
raise
738744
@@ Row.Shape_error
739745
( "Not enough shape information: affine dimension cannot be converted to single int",

0 commit comments

Comments
 (0)