@@ -102,7 +102,11 @@ type delayed_var_ref = {
102102let get_variable_ref ref_label =
103103 { var_ref = Ir.Indexing. { ref_label; solved_dim = None }; var = `Not_set_yet }
104104
105- type compose_type = Pointwise_bin | Compose | Einsum of string * delayed_var_ref list
105+ type compose_type =
106+ | Pointwise_bin
107+ | Compose
108+ | Einsum of string * delayed_var_ref list
109+ | Defined_by_cd_logic
106110[@@ deriving sexp_of , equal ]
107111
108112type transpose_type =
@@ -111,12 +115,14 @@ type transpose_type =
111115 | Permute of string * delayed_var_ref list
112116 | Batch_slice of Idx .static_symbol
113117 | Uint4x32_to_prec of Ir.Ops .prec Lazy .t
118+ | Defined_by_cd_logic
114119[@@ deriving equal , sexp_of ]
115120
116121type terminal_type = Data of Ir.Assignments .init_data | Fetch of Ir.Assignments .fetch_op
117122[@@ deriving equal , sexp_of ]
118123
119- type ternary_type = Pointwise_tern | Compose_accumulate [@@ deriving sexp , equal ]
124+ type ternary_type = Pointwise_tern | Compose_accumulate | Defined_by_cd_logic
125+ [@@ deriving sexp , equal ]
120126
121127let identifier ~multichar =
122128 let open Angstrom in
@@ -277,6 +283,10 @@ let logic_to_spec = function
277283 | Transpose (Transpose, _ ) -> " T"
278284 | Transpose (Batch_slice _ , _ ) -> " @|"
279285 | Transpose (Uint4x32_to_prec _ , _ ) -> " U4x32"
286+ | Broadcast (Defined_by_cd_logic , _, _)
287+ | Transpose (Defined_by_cd_logic , _)
288+ | Broadcast_tern (Defined_by_cd_logic, _ , _ , _ ) ->
289+ " <cd_logic>"
280290 | Terminal _ -> " <terminal>"
281291
282292module Update_id = struct
@@ -371,8 +381,7 @@ let axes_spec_to_dims_bio ~sh_id ~row_var_env ~dim_var_env:_ ~f labels =
371381 let output = to_row `Output labels.bcast_output o_dims beg_o_dims in
372382 (batch, input, output)
373383
374- let einsum_slot_spec_to_dims_bio ~original_spec ~sh_id ~row_var_env ~dim_var_env labels
375- =
384+ let einsum_slot_spec_to_dims_bio ~original_spec ~sh_id ~row_var_env ~dim_var_env labels =
376385 let proj_env_update = ref @@ Row. dim_map_empty in
377386 let extras = ref [] in
378387 let f kind = function
@@ -840,6 +849,10 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
840849 subr = sh3.output;
841850 };
842851 ] )
852+ | Broadcast (Defined_by_cd_logic , _, _)
853+ | Transpose (Defined_by_cd_logic , _)
854+ | Broadcast_tern (Defined_by_cd_logic, _ , _ , _ ) ->
855+ (Row. dim_map_empty, [] )
843856 | Transpose (Batch_slice { static_range; static_symbol } , sh ) ->
844857 Hash_set. remove unused_shapes sh.id;
845858 let slice_v = get_var () in
@@ -915,12 +928,12 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
915928 let dim_var_env = Hashtbl. create (module String ) in
916929
917930 let extras_rhs, proj_env_rhs, (b_rhs, i_rhs, o_rhs) =
918- einsum_slot_spec_to_dims_bio ~original_spec: spec ~sh_id: sh.id ~row_var_env
919- ~dim_var_env ls_rhs
931+ einsum_slot_spec_to_dims_bio ~original_spec: spec ~sh_id: sh.id ~row_var_env ~dim_var_env
932+ ls_rhs
920933 in
921934 let extras_lhs, proj_env_lhs, (b_lhs, i_lhs, o_lhs) =
922- einsum_slot_spec_to_dims_bio ~original_spec: spec ~sh_id: cur_sh.id ~row_var_env
923- ~dim_var_env ls_lhs
935+ einsum_slot_spec_to_dims_bio ~original_spec: spec ~sh_id: cur_sh.id ~row_var_env ~dim_var_env
936+ ls_lhs
924937 in
925938 (* Bind delayed_var_refs to the variables after they are created *)
926939 let extras_dim_refs =
@@ -1134,16 +1147,16 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
11341147 let row_var_env = Hashtbl. create (module String ) in
11351148 let dim_var_env = Hashtbl. create (module String ) in
11361149 let extras_rhs1, proj_env_rhs1, (b_rhs1, i_rhs1, o_rhs1) =
1137- einsum_slot_spec_to_dims_bio ~original_spec: spec ~sh_id: sh1.id ~row_var_env
1138- ~dim_var_env ls_rhs1
1150+ einsum_slot_spec_to_dims_bio ~original_spec: spec ~sh_id: sh1.id ~row_var_env ~dim_var_env
1151+ ls_rhs1
11391152 in
11401153 let extras_rhs2, proj_env_rhs2, (b_rhs2, i_rhs2, o_rhs2) =
1141- einsum_slot_spec_to_dims_bio ~original_spec: spec ~sh_id: sh2.id ~row_var_env
1142- ~dim_var_env ls_rhs2
1154+ einsum_slot_spec_to_dims_bio ~original_spec: spec ~sh_id: sh2.id ~row_var_env ~dim_var_env
1155+ ls_rhs2
11431156 in
11441157 let extras_lhs, proj_env_lhs, (b_lhs, i_lhs, o_lhs) =
1145- einsum_slot_spec_to_dims_bio ~original_spec: spec ~sh_id: cur_sh.id ~row_var_env
1146- ~dim_var_env ls_lhs
1158+ einsum_slot_spec_to_dims_bio ~original_spec: spec ~sh_id: cur_sh.id ~row_var_env ~dim_var_env
1159+ ls_lhs
11471160 in
11481161 (* Bind delayed_var_refs to the variables after they are created *)
11491162 (* TODO: refactor to avoid duplication with the one for unary einsum *)
@@ -1735,6 +1748,10 @@ let fresh_proj_ids update =
17351748 in
17361749 fresh_shape update.shape;
17371750 (match update.logic with
1751+ | Broadcast (Defined_by_cd_logic , _, _)
1752+ | Transpose (Defined_by_cd_logic , _)
1753+ | Broadcast_tern (Defined_by_cd_logic, _ , _ , _ ) ->
1754+ ()
17381755 | Terminal _ -> ()
17391756 | Transpose (_ , sh ) -> fresh_shape sh
17401757 | Broadcast (_ , sh1 , sh2 ) ->
@@ -1777,6 +1794,12 @@ let%debug4_sexp derive_projections (update_step : update_step) : Idx.projections
17771794 let lhs = update_step.shape in
17781795 let rhs =
17791796 match update_step.logic with
1797+ | Broadcast (Defined_by_cd_logic , _, _)
1798+ | Transpose (Defined_by_cd_logic , _)
1799+ | Broadcast_tern (Defined_by_cd_logic, _ , _ , _ ) ->
1800+ raise
1801+ @@ Utils. User_error
1802+ " Defined_by_cd_logic: use explicit ~logic annotations when defining this operation"
17801803 | Terminal _ -> []
17811804 | Transpose (_ , sh ) -> [ sh ]
17821805 | Broadcast (_ , sh1 , sh2 ) -> [ sh1; sh2 ]
@@ -1830,7 +1853,7 @@ let make ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_ax
18301853 match kind with
18311854 | `Batch | `Input -> get_dim ~d ()
18321855 | `Output ->
1833- if not known_no_batch && num_dim1_output = 1 && d = 1 then
1856+ if ( not known_no_batch) && num_dim1_output = 1 && d = 1 then
18341857 let label = debug_name ^ " _output" in
18351858 get_dim ~d ~label ()
18361859 else get_dim ~d ()
0 commit comments