Skip to content

Commit d37645d

Browse files
committed
Tightly control what enters into the product space for Total_elems with Strided_var
We will probably need the same for convolution / strided iteration...
1 parent 87f78c1 commit d37645d

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

lib/row.ml

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2295,12 +2295,7 @@ type proj_env = {
22952295
}
22962296
[@@deriving sexp_of]
22972297

2298-
type proj_equation =
2299-
| Proj_eq of proj * proj
2300-
(** Two projections are the same, e.g. two axes share the same iterator. *)
2301-
| Iterated of proj
2302-
(** The projection needs to be an iterator even if an axis is not matched with another axis,
2303-
e.g. for broadcasted-to axes of a tensor assigned a constant. *)
2298+
type proj_equation = Proj_eq of proj * proj | Iterated of proj | Non_iterated of proj
23042299
[@@deriving compare, equal, sexp]
23052300

23062301
let%debug4_sexp get_proj_equations (inequalities : constraint_ list) proj_axis_env
@@ -2410,16 +2405,15 @@ let%debug4_sexp get_proj_equations (inequalities : constraint_ list) proj_axis_e
24102405
match List.rev dims with
24112406
| [] -> assert false
24122407
| inner :: other_dims ->
2413-
Proj_eq
2414-
( to_proj
2415-
(Conv_input
2416-
{
2417-
stride;
2418-
output = subst_dim env (Var var);
2419-
dilation = 0;
2420-
kernel = get_dim ~d:0 ();
2421-
}),
2422-
to_proj inner )
2408+
let output = subst_dim env (Var var) in
2409+
let input = to_proj inner in
2410+
Iterated (to_proj output)
2411+
:: Non_iterated input
2412+
:: Proj_eq
2413+
( to_proj
2414+
(Conv_input
2415+
{ stride; output; dilation = 0; kernel = get_dim ~d:0 () }),
2416+
input )
24232417
:: List.map other_dims ~f:(fun d -> Proj_eq (to_proj d, Solved Sub_axis)))
24242418
else assert false
24252419
| None -> [])
@@ -2637,6 +2631,7 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list)
26372631
let verify_when_solved2 = ref [] in
26382632
let p_dims = ref [] in
26392633
let proj_classes = ref @@ Map.empty (module Proj_id) in
2634+
let non_product = ref @@ Set.empty (module Proj_id) in
26402635
let rec loop = function
26412636
| Proj_eq (Proj (p1, { d; _ }), Proj (p2, _)) when Proj_id.equal p1 p2 ->
26422637
p_dims := (p1, d) :: !p_dims
@@ -2682,6 +2677,11 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list)
26822677
match Hashtbl.find v_env v with
26832678
| None -> Hashtbl.add_exn v_env ~key:v ~data:p
26842679
| Some p2 -> loop (Proj_eq (p, p2)))
2680+
| Non_iterated p -> (
2681+
match p with
2682+
| Proj (proj_id, _) | Conv_input { input_id = Some proj_id; _ } ->
2683+
non_product := Set.add !non_product proj_id
2684+
| _ -> ())
26852685
| Iterated (Solved _) -> ()
26862686
| Iterated (Proj (pid, { d; _ })) -> p_dims := (pid, d) :: !p_dims
26872687
| Iterated (Conv_input { output; dilation = 0; kernel = _; _ }) ->
@@ -2698,8 +2698,7 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list)
26982698
| Some proj -> loop @@ Iterated proj)
26992699
in
27002700
List.iter eqs ~f:loop;
2701-
let projs = ref @@ Map.empty (module Proj_id)
2702-
and non_product = ref @@ Set.empty (module Proj_id) in
2701+
let projs = ref @@ Map.empty (module Proj_id) in
27032702
List.iter !p_solved ~f:(fun (p, idx) ->
27042703
let repr, _ = Utils.union_find ~equal:Proj_id.equal !proj_classes ~key:p ~rank:0 in
27052704
non_product := Set.add !non_product repr;

lib/row.mli

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ type proj_equation =
151151
| Iterated of proj
152152
(** The projection needs to be an iterator even if an axis is not matched with another axis,
153153
e.g. for broadcasted-to axes of a tensor assigned a constant. *)
154+
| Non_iterated of proj
155+
(** The projection is not part of a product space, e.g. for convolution input. *)
154156
[@@deriving compare, equal, sexp]
155157

156158
val get_proj_equations :

0 commit comments

Comments
 (0)