@@ -33,7 +33,7 @@ type expr_type =
3333
3434let is_unknown = function Unknown -> true | _ -> false
3535
36- type projections_slot = LHS | RHS1 | RHS2 | Nonslot | Undet [@@ deriving equal , sexp ]
36+ type projections_slot = LHS | RHS1 | RHS2 | RHS3 | Nonslot | Undet [@@ deriving equal , sexp ]
3737
3838let assignment_op expr =
3939 (* This should stay in sync with Arrayjit.Ops.assign_op_cd_syntax. *)
@@ -72,6 +72,7 @@ let assignment_op expr =
7272
7373let binary_op expr =
7474 (* This and is_binary_op should stay in sync with Arrayjit.Ops.binop_cd_syntax. *)
75+ (* FIXME: get rid of this and use binary_ops table instead. *)
7576 let loc = expr.pexp_loc in
7677 match expr with
7778 | [% expr ( + )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Add ])
@@ -106,6 +107,18 @@ let binary_op expr =
106107 " + (Add), - (Sub), * (Mul), / (Div), ** (ToPowOf), -?/ (Relu_gate), -/> (Arg2), < \
107108 (Cmplt), <> (Cmpne), || (Or), && (And), % (Mod), @^ (Max), ^^ (Min)" )
108109
110+ let ternary_op expr =
111+ (* FIXME: get rid of this and use ternary_ops table instead. *)
112+ let loc = expr.pexp_loc in
113+ match expr with
114+ | [% expr where] -> ([% expr Shape. Pointwise_tern ], [% expr Arrayjit.Ops. Where ])
115+ | [% expr fma] -> ([% expr Shape. Compose_accumulate ], [% expr Arrayjit.Ops. FMA ])
116+ | _ ->
117+ ( [% expr Shape. Pointwise_bin ],
118+ Ast_builder.Default. pexp_extension ~loc
119+ @@ Location. error_extensionf ~loc " ppx_ocannl %%cd: expected a ternary operator, one of: %s"
120+ " where, fma" )
121+
109122type result = {
110123 vbs : value_binding Map .M (String ).t;
111124 (* * [vbs] are the bindings introduced by inline tensor declarations (aka. punning). These
@@ -206,6 +219,7 @@ let project_p_slot debug loc slot =
206219 | LHS -> [% expr p.project_lhs]
207220 | RHS1 -> [% expr p.project_rhs.(0 )]
208221 | RHS2 -> [% expr p.project_rhs.(1 )]
222+ | RHS3 -> [% expr p.project_rhs.(2 )]
209223 | Nonslot ->
210224 Ast_builder.Default. pexp_extension ~loc
211225 @@ Location. error_extensionf ~loc
@@ -221,6 +235,7 @@ let project_p_dims debug loc slot =
221235 | LHS -> [% expr p.lhs_dims]
222236 | RHS1 -> [% expr p.rhs_dims.(0 )]
223237 | RHS2 -> [% expr p.rhs_dims.(1 )]
238+ | RHS3 -> [% expr p.rhs_dims.(2 )]
224239 | Nonslot ->
225240 Ast_builder.Default. pexp_extension ~loc
226241 @@ Location. error_extensionf ~loc
@@ -344,6 +359,7 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
344359 | LHS -> [% pat? nondiff__lhs]
345360 | RHS1 -> [% pat? nondiff__rhs1]
346361 | RHS2 -> [% pat? nondiff__rhs2]
362+ | RHS3 -> [% pat? nondiff__rhs3]
347363 | Nonslot | Undet -> [% pat? nondiff__tensor]
348364 in
349365 let t = pat2expr v in
@@ -444,6 +460,74 @@ let translate (expr : expression) : result =
444460 { vbs = no_vbs; typ = Tensor ; slot = Undet ; expr; array_opt_of_code = None }
445461 in
446462 let loop = transl ~bad_pun_hints in
463+ (* FIXME: collapse these (code reuse) *)
464+ let process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ?projections ~proj_in_scope
465+ () =
466+ let initialize_neutral, accu_op = assignment_op accu_op in
467+ let setup_l =
468+ setup_array ~punned ~bad_pun_hints ~is_lhs: true @@ loop ~proj_in_scope: true lhs
469+ in
470+ let _, tern_op = ternary_op tern_op in
471+ let setup_r1 = setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs1 in
472+ let setup_r2 = setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs2 in
473+ let setup_r3 = setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs3 in
474+ let initialize_neutral = if initialize_neutral then [% expr true ] else [% expr false ] in
475+ let projections =
476+ match projections with
477+ | Some prjs -> prjs
478+ | None ->
479+ let lhs_dims = project_p_dims " LHS" lhs.pexp_loc setup_l.slot in
480+ let rhs1_dims = project_p_dims " RHS1" lhs.pexp_loc setup_r1.slot in
481+ let rhs2_dims = project_p_dims " RHS2" lhs.pexp_loc setup_r2.slot in
482+ let rhs3_dims = project_p_dims " RHS3" lhs.pexp_loc setup_r3.slot in
483+ let project_lhs = project_p_slot " LHS" lhs.pexp_loc setup_l.slot in
484+ let project_rhs1 = project_p_slot " RHS1" rhs1.pexp_loc setup_r1.slot in
485+ let project_rhs2 = project_p_slot " RHS2" rhs2.pexp_loc setup_r2.slot in
486+ let project_rhs3 = project_p_slot " RHS3" rhs3.pexp_loc setup_r3.slot in
487+ [% expr
488+ lazy
489+ (let p = Lazy. force projections in
490+ Arrayjit.Indexing.
491+ {
492+ product_space = p.product_space;
493+ product_iterators = p.product_iterators;
494+ lhs_dims = [% e lhs_dims];
495+ rhs_dims = [| [% e rhs1_dims]; [% e rhs2_dims]; [% e rhs3_dims] |];
496+ project_lhs = [% e project_lhs];
497+ project_rhs = [| [% e project_rhs1]; [% e project_rhs2]; [% e project_rhs3] |];
498+ debug_info =
499+ {
500+ p.debug_info with
501+ trace =
502+ ( " ppx_cd " ^ [% e expr2string_or_empty accu_op] ^ " "
503+ ^ [% e expr2string_or_empty tern_op],
504+ Arrayjit.Indexing. unique_debug_id () )
505+ :: p.debug_info.trace;
506+ };
507+ })]
508+ in
509+ (* TODO: might be better to treat missing [rhs1, rhs2, rhs3] as zeros or errors rather than
510+ eliding the code. *)
511+ let body =
512+ [% expr
513+ Option. value ~default: Arrayjit.Assignments. Noop
514+ @@ Option. map [% e setup_l.array_opt] ~f: (fun lhs ->
515+ Option. map3 [% e setup_r1.array_opt] [% e setup_r2.array_opt] [% e setup_r2.array_opt]
516+ ~f: (fun rhs1 rhs2 rhs3 ->
517+ Arrayjit.Assignments. Accum_ternop
518+ {
519+ initialize_neutral = [% e initialize_neutral];
520+ accum = [% e accu_op];
521+ lhs;
522+ op = [% e tern_op];
523+ rhs1;
524+ rhs2;
525+ rhs3;
526+ projections = [% e projections];
527+ }))]
528+ in
529+ assignment ~punned ~lhs: setup_l ~rhses: [ setup_r1; setup_r2; setup_r3 ] body
530+ in
447531 let process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ?projections ~proj_in_scope () =
448532 let initialize_neutral, accu_op = assignment_op accu_op in
449533 let setup_l =
@@ -561,6 +645,27 @@ let translate (expr : expression) : result =
561645 in
562646 assignment ~punned ~lhs: setup_l ~rhses: [ setup_r ] body
563647 in
648+ let process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic =
649+ let initialize_neutral, accu_op = assignment_op accu_op in
650+ let setup_l = setup_array ~punned ~bad_pun_hints ~is_lhs: true @@ loop ~proj_in_scope lhs in
651+ let setup_r1 = setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs1 in
652+ let setup_r2 = setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs2 in
653+ let setup_r3 = setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs3 in
654+ let initialize_neutral = if initialize_neutral then [% expr true ] else [% expr false ] in
655+ let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
656+ let t1_expr, rhs1_is_grad, rhs1_is_merge = args_for ~loc setup_r1 in
657+ let t2_expr, rhs2_is_grad, rhs2_is_merge = args_for ~loc setup_r2 in
658+ let t3_expr, rhs3_is_grad, rhs3_is_merge = args_for ~loc setup_r3 in
659+ let body =
660+ [% expr
661+ Tensor. raw_ternop ~initialize_neutral: [% e initialize_neutral] ~accum: [% e accu_op]
662+ ~t: [% e t_expr] ~lhs_is_grad: [% e lhs_is_grad] ~op: [% e tern_op] ~t1: [% e t1_expr]
663+ ~rhs1_is_grad: [% e rhs1_is_grad] ~rhs1_is_merge: [% e rhs1_is_merge] ~t2: [% e t2_expr]
664+ ~rhs2_is_grad: [% e rhs2_is_grad] ~rhs2_is_merge: [% e rhs2_is_merge] ~t3: [% e t3_expr]
665+ ~rhs3_is_grad: [% e rhs3_is_grad] ~rhs3_is_merge: [% e rhs3_is_merge] ~logic: [% e logic]]
666+ in
667+ assignment ~punned ~lhs: setup_l ~rhses: [ setup_r1; setup_r2; setup_r3 ] body
668+ in
564669 let process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic =
565670 let initialize_neutral, accu_op = assignment_op accu_op in
566671 let setup_l = setup_array ~punned ~bad_pun_hints ~is_lhs: true @@ loop ~proj_in_scope lhs in
@@ -655,6 +760,19 @@ let translate (expr : expression) : result =
655760 slot = RHS2 ;
656761 expr = [% expr Option. map t2.Tensor. diff ~f: (fun d -> d.Tensor. grad)];
657762 }
763+ | { pexp_desc = Pexp_ident { txt = Lident "rhs3" ; _ } ; _ } ->
764+ { default_result with typ = Array ; slot = RHS3 }
765+ | { pexp_desc = Pexp_ident { txt = Lident "t3" ; _ } ; _ } ->
766+ { default_result with typ = Tensor ; slot = RHS3 }
767+ | { pexp_desc = Pexp_ident { txt = Lident "v3" ; _ } ; _ } ->
768+ { default_result with typ = Array ; slot = RHS3 ; expr = [% expr t3.Tensor. value] }
769+ | { pexp_desc = Pexp_ident { txt = Lident "g3" ; _ } ; _ } ->
770+ {
771+ default_result with
772+ typ = Grad_of_tensor [% expr t3];
773+ slot = RHS3 ;
774+ expr = [% expr Option. map t3.Tensor. diff ~f: (fun d -> d.Tensor. grad)];
775+ }
658776 | { pexp_desc = Pexp_ident { txt = Lident op_ident ; _ } ; _ } when is_primitive_op op_ident ->
659777 default_result
660778 | [% expr [% e? expr1] **. [% e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
@@ -811,7 +929,15 @@ let translate (expr : expression) : result =
811929 [% e? accu_op]
812930 [% e? lhs]
813931 ([% e? bin_op] [% e? rhs1] ([% e? rhs2] ~projections: [% e? projections]))] ->
932+ (* Note: when clause not needed here and below, it's an error if bin_op is not a primitive
933+ binary op. *)
814934 process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~projections ~proj_in_scope: true ()
935+ | [% expr
936+ [% e? accu_op]
937+ [% e? lhs]
938+ ([% e? tern_op] ([% e? rhs1], [% e? rhs2], [% e? rhs3]) ~projections: [% e? projections])] ->
939+ process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~projections
940+ ~proj_in_scope: true ()
815941 | [% expr
816942 [% e? accu_op]
817943 [% e? lhs]
@@ -852,6 +978,25 @@ let translate (expr : expression) : result =
852978 in
853979 let _, bin_op = binary_op bin_op in
854980 process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic
981+ | [% expr
982+ [% e? accu_op]
983+ [% e? lhs]
984+ ([% e? tern_op]
985+ ([% e? rhs1], [% e? rhs2], [% e? rhs3])
986+ ~logic: [% e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ }])] ->
987+ let logic =
988+ let loc = s_loc in
989+ if String. equal spec " ." then [% expr Shape. Pointwise_bin ]
990+ else if String. equal spec " @" then [% expr Shape. Compose ]
991+ else
992+ Ast_builder.Default. pexp_extension ~loc
993+ @@ Location. error_extensionf ~loc
994+ " ppx_ocannl %%cd: expected <.> or <@>, found <%s> -- einsum notation for ternary \
995+ operators not supported yet, see issue #305"
996+ spec
997+ in
998+ let _, tern_op = binary_op tern_op in
999+ process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic
8551000 | [% expr
8561001 [% e? accu_op]
8571002 [% e? lhs]
@@ -882,6 +1027,13 @@ let translate (expr : expression) : result =
8821027 [% e? rhs2])]
8831028 when is_assignment accu_ident && Hashtbl. mem binary_ops binop_ident && proj_in_scope ->
8841029 process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~proj_in_scope ()
1030+ | [% expr
1031+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
1032+ [% e? lhs]
1033+ ([% e? { pexp_desc = Pexp_ident { txt = Lident ternop_ident; _ }; _ } as tern_op]
1034+ ([% e? rhs1], [% e? rhs2], [% e? rhs3]))]
1035+ when is_assignment accu_ident && Hashtbl. mem ternary_ops ternop_ident && proj_in_scope ->
1036+ process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~proj_in_scope ()
8851037 | [% expr
8861038 [% e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
8871039 [% e? lhs]
@@ -905,6 +1057,14 @@ let translate (expr : expression) : result =
9051057 when is_assignment accu_ident && Hashtbl. mem binary_ops binop_ident ->
9061058 let logic, bin_op = binary_op bin_op in
9071059 process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic
1060+ | [% expr
1061+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
1062+ [% e? lhs]
1063+ ([% e? { pexp_desc = Pexp_ident { txt = Lident ternop_ident; _ }; _ } as tern_op]
1064+ ([% e? rhs1], [% e? rhs2], [% e? rhs3]))]
1065+ when is_assignment accu_ident && Hashtbl. mem ternary_ops ternop_ident ->
1066+ let logic, tern_op = ternary_op tern_op in
1067+ process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic
9081068 | [% expr
9091069 [% e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
9101070 [% e? lhs]
0 commit comments