@@ -2295,12 +2295,7 @@ type proj_env = {
22952295}
22962296[@@ deriving sexp_of ]
22972297
2298- type proj_equation =
2299- | Proj_eq of proj * proj
2300- (* * Two projections are the same, e.g. two axes share the same iterator. *)
2301- | Iterated of proj
2302- (* * The projection needs to be an iterator even if an axis is not matched with another axis,
2303- e.g. for broadcasted-to axes of a tensor assigned a constant. *)
2298+ type proj_equation = Proj_eq of proj * proj | Iterated of proj | Non_iterated of proj
23042299[@@ deriving compare , equal , sexp ]
23052300
23062301let % debug4_sexp get_proj_equations (inequalities : constraint_ list ) proj_axis_env
@@ -2410,16 +2405,15 @@ let%debug4_sexp get_proj_equations (inequalities : constraint_ list) proj_axis_e
24102405 match List. rev dims with
24112406 | [] -> assert false
24122407 | inner :: other_dims ->
2413- Proj_eq
2414- ( to_proj
2415- (Conv_input
2416- {
2417- stride;
2418- output = subst_dim env (Var var);
2419- dilation = 0 ;
2420- kernel = get_dim ~d: 0 () ;
2421- }),
2422- to_proj inner )
2408+ let output = subst_dim env (Var var) in
2409+ let input = to_proj inner in
2410+ Iterated (to_proj output)
2411+ :: Non_iterated input
2412+ :: Proj_eq
2413+ ( to_proj
2414+ (Conv_input
2415+ { stride; output; dilation = 0 ; kernel = get_dim ~d: 0 () }),
2416+ input )
24232417 :: List. map other_dims ~f: (fun d -> Proj_eq (to_proj d, Solved Sub_axis )))
24242418 else assert false
24252419 | None -> [] )
@@ -2637,6 +2631,7 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list)
26372631 let verify_when_solved2 = ref [] in
26382632 let p_dims = ref [] in
26392633 let proj_classes = ref @@ Map. empty (module Proj_id ) in
2634+ let non_product = ref @@ Set. empty (module Proj_id ) in
26402635 let rec loop = function
26412636 | Proj_eq (Proj (p1 , { d; _ } ), Proj (p2 , _ )) when Proj_id. equal p1 p2 ->
26422637 p_dims := (p1, d) :: ! p_dims
@@ -2682,6 +2677,11 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list)
26822677 match Hashtbl. find v_env v with
26832678 | None -> Hashtbl. add_exn v_env ~key: v ~data: p
26842679 | Some p2 -> loop (Proj_eq (p, p2)))
2680+ | Non_iterated p -> (
2681+ match p with
2682+ | Proj (proj_id , _ ) | Conv_input { input_id = Some proj_id ; _ } ->
2683+ non_product := Set. add ! non_product proj_id
2684+ | _ -> () )
26852685 | Iterated (Solved _ ) -> ()
26862686 | Iterated (Proj (pid , { d; _ } )) -> p_dims := (pid, d) :: ! p_dims
26872687 | Iterated (Conv_input { output; dilation = 0 ; kernel = _ ; _ } ) ->
@@ -2698,8 +2698,7 @@ let%debug4_sexp solve_proj_equations (eqs : proj_equation list)
26982698 | Some proj -> loop @@ Iterated proj)
26992699 in
27002700 List. iter eqs ~f: loop;
2701- let projs = ref @@ Map. empty (module Proj_id )
2702- and non_product = ref @@ Set. empty (module Proj_id ) in
2701+ let projs = ref @@ Map. empty (module Proj_id ) in
27032702 List. iter ! p_solved ~f: (fun (p , idx ) ->
27042703 let repr, _ = Utils. union_find ~equal: Proj_id. equal ! proj_classes ~key: p ~rank: 0 in
27052704 non_product := Set. add ! non_product repr;
0 commit comments