@@ -35,90 +35,6 @@ let is_unknown = function Unknown -> true | _ -> false
3535
3636type projections_slot = LHS | RHS1 | RHS2 | RHS3 | Nonslot | Undet [@@ deriving equal , sexp ]
3737
38- let assignment_op expr =
39- (* This should stay in sync with Arrayjit.Ops.assign_op_cd_syntax. *)
40- let loc = expr.pexp_loc in
41- match expr with
42- | [% expr ( =: )] -> (false , [% expr Arrayjit.Ops. Arg2 ])
43- | [% expr ( =+ )] -> (false , [% expr Arrayjit.Ops. Add ])
44- | [% expr ( =- )] -> (false , [% expr Arrayjit.Ops. Sub ])
45- | [% expr ( =* )] -> (false , [% expr Arrayjit.Ops. Mul ])
46- | [% expr ( =/ )] -> (false , [% expr Arrayjit.Ops. Div ])
47- | [% expr ( =** )] -> (false , [% expr Arrayjit.Ops. ToPowOf ])
48- | [% expr ( =?/ )] -> (false , [% expr Arrayjit.Ops. Relu_gate ])
49- | [% expr ( =|| )] -> (false , [% expr Arrayjit.Ops. Or ])
50- | [% expr ( =&& )] -> (false , [% expr Arrayjit.Ops. And ])
51- | [% expr ( =@^ )] -> (false , [% expr Arrayjit.Ops. Max ])
52- | [% expr ( =^^ )] -> (false , [% expr Arrayjit.Ops. Min ])
53- | [% expr ( =:+ )] -> (true , [% expr Arrayjit.Ops. Add ])
54- | [% expr ( =:- )] -> (true , [% expr Arrayjit.Ops. Sub ])
55- | [% expr ( =:* )] -> (true , [% expr Arrayjit.Ops. Mul ])
56- | [% expr ( =:/ )] -> (true , [% expr Arrayjit.Ops. Div ])
57- | [% expr ( =:** )] -> (true , [% expr Arrayjit.Ops. ToPowOf ])
58- | [% expr ( =:?/ )] -> (true , [% expr Arrayjit.Ops. Relu_gate ])
59- | [% expr ( =:|| )] -> (true , [% expr Arrayjit.Ops. Or ])
60- | [% expr ( =:&& )] -> (true , [% expr Arrayjit.Ops. And ])
61- | [% expr ( =:@^ )] -> (true , [% expr Arrayjit.Ops. Max ])
62- | [% expr ( =:^^ )] -> (true , [% expr Arrayjit.Ops. Min ])
63- | _ ->
64- ( false ,
65- Ast_builder.Default. pexp_extension ~loc
66- @@ Location. error_extensionf ~loc
67- " ppx_ocannl %%cd: expected an assignment operator, one of: %s %s"
68- " =+ (Add), =- (Sub), =* (Mul), =/ (Div), =** (ToPowOf), =?/ (Relu_gate), =|| (Or), \
69- =&& (And), =@^ (Max), =^^ (Min), =: (Arg2), =:+, =:-,"
70- " =:*, =:/, =:**, =:?/, =:||, =:&&, =:@^, =:^^ (same with initializing the tensor to \
71- the neutral value before the start of the calculation)" )
72-
73- let binary_op expr =
74- (* This and is_binary_op should stay in sync with Arrayjit.Ops.binop_cd_syntax. *)
75- (* FIXME: get rid of this and use binary_ops table instead. *)
76- let loc = expr.pexp_loc in
77- match expr with
78- | [% expr ( + )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Add ])
79- | [% expr ( - )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Sub ])
80- | [% expr ( * )] ->
81- ( Ast_builder.Default. pexp_extension ~loc
82- @@ Location. error_extensionf ~loc
83- " No default compose type for binary `*`, try e.g. ~logic:\" .\" for pointwise, %s"
84- " ~logic:\" @\" for matrix multiplication" ,
85- [% expr Arrayjit.Ops. Mul ] )
86- | [% expr ( / )] ->
87- ( Ast_builder.Default. pexp_extension ~loc
88- @@ Location. error_extensionf ~loc
89- " For clarity, no default compose type for binary `/`, use ~logic:\" .\" for pointwise \
90- division" ,
91- [% expr Arrayjit.Ops. Div ] )
92- | [% expr ( ** )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. ToPowOf ])
93- | [% expr ( -?/ )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Relu_gate ])
94- | [% expr ( -/> )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Arg2 ])
95- | [% expr ( -@> )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Arg1 ])
96- | [% expr ( < )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Cmplt ])
97- | [% expr ( <> )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Cmpne ])
98- | [% expr ( || )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Or ])
99- | [% expr ( && )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. And ])
100- | [% expr ( % )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Mod ])
101- | [% expr ( @^ )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Max ])
102- | [% expr ( ^^ )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Min ])
103- | _ ->
104- ( [% expr Shape. Pointwise_bin ],
105- Ast_builder.Default. pexp_extension ~loc
106- @@ Location. error_extensionf ~loc " ppx_ocannl %%cd: expected a binary operator, one of: %s"
107- " + (Add), - (Sub), * (Mul), / (Div), ** (ToPowOf), -?/ (Relu_gate), -/> (Arg2), < \
108- (Cmplt), <> (Cmpne), || (Or), && (And), % (Mod), @^ (Max), ^^ (Min)" )
109-
110- let ternary_op expr =
111- (* FIXME: get rid of this and use ternary_ops table instead. *)
112- let loc = expr.pexp_loc in
113- match expr with
114- | [% expr where] -> ([% expr Shape. Pointwise_tern ], [% expr Arrayjit.Ops. Where ])
115- | [% expr fma] -> ([% expr Shape. Compose_accumulate ], [% expr Arrayjit.Ops. FMA ])
116- | _ ->
117- ( [% expr Shape. Pointwise_bin ],
118- Ast_builder.Default. pexp_extension ~loc
119- @@ Location. error_extensionf ~loc " ppx_ocannl %%cd: expected a ternary operator, one of: %s"
120- " where, fma" )
121-
12238type result = {
12339 vbs : value_binding Map .M (String ).t;
12440 (* * [vbs] are the bindings introduced by inline tensor declarations (aka. punning). These
@@ -460,6 +376,46 @@ let translate (expr : expression) : result =
460376 { vbs = no_vbs; typ = Tensor ; slot = Undet ; expr; array_opt_of_code = None }
461377 in
462378 let loop = transl ~bad_pun_hints in
379+ let assignment_op accu_op =
380+ loc
381+ |> Option. value_or_thunk (Hashtbl. find assignment_ops accu_op) ~default: (fun () _loc ->
382+ ( false ,
383+ Ast_builder.Default. pexp_extension ~loc
384+ @@ Location. error_extensionf ~loc
385+ " ppx_ocannl %%cd: expected an assignment operator, one of: %s %s"
386+ " =+ (Add), =- (Sub), =* (Mul),=/ (Div), =** (ToPowOf), =?/ (Relu_gate), =|| \
387+ (Or), =&& (And), =@^ (Max), =^^ (Min), =: (Arg2),=:+, =:-,"
388+ " =:*, =:/, =:**, =:?/, =:||, =:&&, =:@^, =:^^ (same with initializing the \
389+ tensor to the neutral value before the start of the calculation)" ))
390+ in
391+ let unary_op un_op =
392+ loc
393+ |> Option. value_or_thunk (Hashtbl. find unary_ops un_op) ~default: (fun () loc ->
394+ ( [% expr Shape. Pointwise_un ],
395+ Ast_builder.Default. pexp_extension ~loc
396+ @@ Location. error_extensionf ~loc
397+ " ppx_ocannl %%cd: expected an assignment operator, one of: %s"
398+ " id, relu, sat01, exp, log, exp2, log2, sin, cos, sqrt, recip, recip_sqrt, \
399+ neg, tanh" ))
400+ in
401+ let binary_op bin_op =
402+ loc
403+ |> Option. value_or_thunk (Hashtbl. find binary_ops bin_op) ~default: (fun () _loc ->
404+ ( [% expr Shape. Pointwise_bin ],
405+ Ast_builder.Default. pexp_extension ~loc
406+ @@ Location. error_extensionf ~loc
407+ " ppx_ocannl %%cd: expected a binary operator, one of: %s"
408+ " + (Add), - (Sub), * (Mul), / (Div), **(ToPowOf), -?/ (Relu_gate), -/> (Arg2), \
409+ < (Cmplt), <> (Cmpne), || (Or), && (And), % (Mod), @^(Max), ^^ (Min)" ))
410+ in
411+ let ternary_op tern_op =
412+ loc
413+ |> Option. value_or_thunk (Hashtbl. find ternary_ops tern_op) ~default: (fun () _loc ->
414+ ( [% expr Shape. Pointwise_tern ],
415+ Ast_builder.Default. pexp_extension ~loc
416+ @@ Location. error_extensionf ~loc
417+ " ppx_ocannl %%cd: expected a ternary operator, one of: %s" " where, fma" ))
418+ in
463419 (* FIXME: collapse these (code reuse) *)
464420 let process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ?projections ~proj_in_scope
465421 () =
@@ -590,7 +546,8 @@ let translate (expr : expression) : result =
590546 assignment ~punned ~lhs: setup_l ~rhses: [ setup_r1; setup_r2 ] body
591547 in
592548 let process_assign_unop ~accu_op ~lhs ~un_op ~rhs ?projections ~proj_in_scope () =
593- let initialize_neutral, accu_op = assignment_op accu_op in
549+ let initialize_neutral, accum = assignment_op accu_op in
550+ let _, op = unary_op un_op in
594551 (* FIXME: I think this ignores the slot information here! Just assuming [projections] is
595552 as-should-be, but that's not consistent with omitting the projections arg (assuming it
596553 comes from the context). *)
@@ -620,8 +577,8 @@ let translate (expr : expression) : result =
620577 {
621578 p.debug_info with
622579 trace =
623- ( " ppx_cd " ^ [% e expr2string_or_empty accu_op] ^ " "
624- ^ [% e expr2string_or_empty un_op],
580+ ( " ppx_cd " ^ [% e string_expr ~loc accu_op] ^ " "
581+ ^ [% e string_expr ~loc un_op],
625582 Arrayjit.Indexing. unique_debug_id () )
626583 :: p.debug_info.trace;
627584 };
@@ -636,9 +593,9 @@ let translate (expr : expression) : result =
636593 Arrayjit.Assignments. Accum_unop
637594 {
638595 initialize_neutral = [% e initialize_neutral];
639- accum = [% e accu_op ];
596+ accum = [% e accum ];
640597 lhs;
641- op = [% e un_op ];
598+ op = [% e op ];
642599 rhs;
643600 projections = [% e projections];
644601 })]
@@ -926,45 +883,50 @@ let translate (expr : expression) : result =
926883 }];
927884 }
928885 | [% expr
929- [% e? accu_op]
886+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ } ]
930887 [% e? lhs]
931- ([% e? bin_op] [% e? rhs1] ([% e? rhs2] ~projections: [% e? projections]))] ->
888+ ([% e? { pexp_desc = Pexp_ident { txt = Lident bin_op; _ }; _ }]
889+ [% e? rhs1]
890+ ([% e? rhs2] ~projections: [% e? projections]))] ->
932891 (* Note: when clause not needed here and below, it's an error if bin_op is not a primitive
933892 binary op. *)
934893 process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~projections ~proj_in_scope: true ()
935894 | [% expr
936- [% e? accu_op]
895+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ } ]
937896 [% e? lhs]
938- ([% e? tern_op] ([% e? rhs1], [% e? rhs2], [% e? rhs3]) ~projections: [% e? projections])] ->
897+ ([% e? { pexp_desc = Pexp_ident { txt = Lident tern_op; _ }; _ }]
898+ ([% e? rhs1], [% e? rhs2], [% e? rhs3])
899+ ~projections: [% e? projections])] ->
939900 process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~projections
940901 ~proj_in_scope: true ()
941902 | [% expr
942- [% e? accu_op]
903+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ } ]
943904 [% e? lhs]
944- ([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident ; _ }; _ }]
905+ ([% e? { pexp_desc = Pexp_ident { txt = Lident un_op ; _ }; _ }]
945906 [% e? rhs]
946907 ~projections: [% e? projections])]
947908 | [% expr
948- [% e? accu_op]
909+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ } ]
949910 [% e? lhs]
950- (([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident ; _ }; _ }] [% e? rhs])
911+ (([% e? { pexp_desc = Pexp_ident { txt = Lident un_op ; _ }; _ }] [% e? rhs])
951912 ~projections: [% e? projections])]
952913 | [% expr
953- [% e? accu_op]
914+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ } ]
954915 [% e? lhs]
955- ([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident ; _ }; _ }]
916+ ([% e? { pexp_desc = Pexp_ident { txt = Lident un_op ; _ }; _ }]
956917 ([% e? rhs] ~projections: [% e? projections]))]
957- when Hashtbl. mem unary_ops unop_ident ->
958- let un_op = Hashtbl. find_exn unary_ops unop_ident loc in
918+ when Hashtbl. mem unary_ops un_op ->
959919 (* Handle both un_op priority levels -- where application binds tighter and less tight. *)
960920 process_assign_unop ~accu_op ~lhs ~un_op ~rhs ~projections ~proj_in_scope: true ()
961- | [% expr [% e? accu_op] [% e? lhs] ([% e? rhs] ~projections: [% e? projections])] ->
962- process_assign_unop ~accu_op ~lhs ~un_op: [% expr Arrayjit.Ops. Identity ] ~rhs ~projections
963- ~proj_in_scope: true ()
964921 | [% expr
965- [% e? accu_op]
922+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ } ]
966923 [% e? lhs]
967- ([% e? bin_op]
924+ ([% e? rhs] ~projections: [% e? projections])] ->
925+ process_assign_unop ~accu_op ~lhs ~un_op: " id" ~rhs ~projections ~proj_in_scope: true ()
926+ | [% expr
927+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
928+ [% e? lhs]
929+ ([% e? { pexp_desc = Pexp_ident { txt = Lident bin_op; _ }; _ }]
968930 [% e? rhs1]
969931 ([% e? rhs2]
970932 ~logic:
@@ -979,9 +941,9 @@ let translate (expr : expression) : result =
979941 let _, bin_op = binary_op bin_op in
980942 process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic
981943 | [% expr
982- [% e? accu_op]
944+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ } ]
983945 [% e? lhs]
984- ([% e? tern_op]
946+ ([% e? { pexp_desc = Pexp_ident { txt = Lident tern_op; _ }; _ } ]
985947 ([% e? rhs1], [% e? rhs2], [% e? rhs3])
986948 ~logic: [% e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ }])] ->
987949 let logic =
@@ -995,15 +957,15 @@ let translate (expr : expression) : result =
995957 operators not supported yet, see issue #305"
996958 spec
997959 in
998- let _, tern_op = binary_op tern_op in
960+ let _, tern_op = ternary_op tern_op in
999961 process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic
1000962 | [% expr
1001- [% e? accu_op]
963+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ } ]
1002964 [% e? lhs]
1003965 (([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }] [% e? rhs])
1004966 ~logic: [% e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic])]
1005967 | [% expr
1006- [% e? accu_op]
968+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ } ]
1007969 [% e? lhs]
1008970 ([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }]
1009971 ([% e? rhs]
@@ -1017,67 +979,54 @@ let translate (expr : expression) : result =
1017979 else if String. equal spec " T" then [% expr Shape. Transpose ]
1018980 else [% expr Shape. Permute [% e logic]]
1019981 in
1020- let un_op = Hashtbl. find_exn unary_ops unop_ident loc in
982+ let _, un_op = Hashtbl. find_exn unary_ops unop_ident loc in
1021983 process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic
1022984 | [% expr
1023- [% e? { pexp_desc = Pexp_ident { txt = Lident accu_ident ; _ }; _ } as accu_op ]
985+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op ; _ }; _ }]
1024986 [% e? lhs]
1025- ([% e? { pexp_desc = Pexp_ident { txt = Lident binop_ident; _ }; _ } as bin_op]
1026- [% e? rhs1]
1027- [% e? rhs2])]
1028- when is_assignment accu_ident && Hashtbl. mem binary_ops binop_ident && proj_in_scope ->
987+ ([% e? { pexp_desc = Pexp_ident { txt = Lident bin_op; _ }; _ }] [% e? rhs1] [% e? rhs2])]
988+ when is_assignment accu_op && Hashtbl. mem binary_ops bin_op && proj_in_scope ->
1029989 process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~proj_in_scope ()
1030990 | [% expr
1031- [% e? { pexp_desc = Pexp_ident { txt = Lident accu_ident ; _ }; _ } as accu_op ]
991+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op ; _ }; _ }]
1032992 [% e? lhs]
1033- ([% e? { pexp_desc = Pexp_ident { txt = Lident ternop_ident ; _ }; _ } as tern_op ]
993+ ([% e? { pexp_desc = Pexp_ident { txt = Lident tern_op ; _ }; _ }]
1034994 ([% e? rhs1], [% e? rhs2], [% e? rhs3]))]
1035- when is_assignment accu_ident && Hashtbl. mem ternary_ops ternop_ident && proj_in_scope ->
995+ when is_assignment accu_op && Hashtbl. mem ternary_ops tern_op && proj_in_scope ->
1036996 process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~proj_in_scope ()
1037997 | [% expr
1038- [% e? { pexp_desc = Pexp_ident { txt = Lident accu_ident ; _ }; _ } as accu_op ]
998+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op ; _ }; _ }]
1039999 [% e? lhs]
1040- ([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }] [% e? rhs])]
1041- when is_assignment accu_ident && Hashtbl. mem unary_ops unop_ident && proj_in_scope ->
1042- let un_op = Hashtbl. find_exn unary_ops unop_ident loc in
1000+ ([% e? { pexp_desc = Pexp_ident { txt = Lident un_op; _ }; _ }] [% e? rhs])]
1001+ when is_assignment accu_op && Hashtbl. mem unary_ops un_op && proj_in_scope ->
10431002 process_assign_unop ~accu_op ~lhs ~un_op ~rhs ~proj_in_scope ()
1003+ | [% expr [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }] [% e? lhs] [% e? rhs]]
1004+ when is_assignment accu_op && proj_in_scope ->
1005+ process_assign_unop ~accu_op ~lhs ~un_op: " id" ~rhs ~proj_in_scope ()
10441006 | [% expr
1045- [% e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } as accu_op]
1046- [% e? lhs]
1047- [% e? rhs]]
1048- when is_assignment op_ident && proj_in_scope ->
1049- process_assign_unop ~accu_op ~lhs ~un_op: [% expr Arrayjit.Ops. Identity ] ~rhs ~proj_in_scope
1050- ()
1051- | [% expr
1052- [% e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
1007+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
10531008 [% e? lhs]
1054- ([% e? { pexp_desc = Pexp_ident { txt = Lident binop_ident; _ }; _ } as bin_op]
1055- [% e? rhs1]
1056- [% e? rhs2])]
1057- when is_assignment accu_ident && Hashtbl. mem binary_ops binop_ident ->
1009+ ([% e? { pexp_desc = Pexp_ident { txt = Lident bin_op; _ }; _ }] [% e? rhs1] [% e? rhs2])]
1010+ when is_assignment accu_op && Hashtbl. mem binary_ops bin_op ->
10581011 let logic, bin_op = binary_op bin_op in
10591012 process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic
10601013 | [% expr
1061- [% e? { pexp_desc = Pexp_ident { txt = Lident accu_ident ; _ }; _ } as accu_op ]
1014+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op ; _ }; _ }]
10621015 [% e? lhs]
1063- ([% e? { pexp_desc = Pexp_ident { txt = Lident ternop_ident ; _ }; _ } as tern_op ]
1016+ ([% e? { pexp_desc = Pexp_ident { txt = Lident tern_op ; _ }; _ }]
10641017 ([% e? rhs1], [% e? rhs2], [% e? rhs3]))]
1065- when is_assignment accu_ident && Hashtbl. mem ternary_ops ternop_ident ->
1018+ when is_assignment accu_op && Hashtbl. mem ternary_ops tern_op ->
10661019 let logic, tern_op = ternary_op tern_op in
10671020 process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic
10681021 | [% expr
1069- [% e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
1070- [% e? lhs]
1071- ([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }] [% e? rhs])]
1072- when is_assignment accu_ident && Hashtbl. mem unary_ops unop_ident ->
1073- let un_op = Hashtbl. find_exn unary_ops unop_ident loc in
1074- (* FIXME: projections logic! *)
1075- process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic: [% expr Pointwise_un ]
1076- | [% expr
1077- [% e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } as accu_op]
1022+ [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
10781023 [% e? lhs]
1079- [% e? rhs]]
1080- when is_assignment op_ident ->
1024+ ([% e? { pexp_desc = Pexp_ident { txt = Lident un_op; _ }; _ }] [% e? rhs])]
1025+ when is_assignment accu_op && Hashtbl. mem unary_ops un_op ->
1026+ let logic, un_op = Hashtbl. find_exn unary_ops un_op loc in
1027+ process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic
1028+ | [% expr [% e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }] [% e? lhs] [% e? rhs]]
1029+ when is_assignment accu_op ->
10811030 process_raw_unop ~accu_op ~lhs ~un_op: [% expr Arrayjit.Ops. Identity ] ~rhs
10821031 ~logic: [% expr Shape. Pointwise_un ]
10831032 | [% expr [% e? expr1] [% e? expr2] [% e? expr3]] ->
0 commit comments