Skip to content

Commit bcd89d3

Browse files
committed
Refactors %cd parsing of primitive ops (uniform, hashtable lookups)
1 parent 9682b4a commit bcd89d3

File tree

3 files changed

+218
-197
lines changed

3 files changed

+218
-197
lines changed

lib/ppx_cd.ml

Lines changed: 99 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -35,90 +35,6 @@ let is_unknown = function Unknown -> true | _ -> false
3535

3636
type projections_slot = LHS | RHS1 | RHS2 | RHS3 | Nonslot | Undet [@@deriving equal, sexp]
3737

38-
let assignment_op expr =
39-
(* This should stay in sync with Arrayjit.Ops.assign_op_cd_syntax. *)
40-
let loc = expr.pexp_loc in
41-
match expr with
42-
| [%expr ( =: )] -> (false, [%expr Arrayjit.Ops.Arg2])
43-
| [%expr ( =+ )] -> (false, [%expr Arrayjit.Ops.Add])
44-
| [%expr ( =- )] -> (false, [%expr Arrayjit.Ops.Sub])
45-
| [%expr ( =* )] -> (false, [%expr Arrayjit.Ops.Mul])
46-
| [%expr ( =/ )] -> (false, [%expr Arrayjit.Ops.Div])
47-
| [%expr ( =** )] -> (false, [%expr Arrayjit.Ops.ToPowOf])
48-
| [%expr ( =?/ )] -> (false, [%expr Arrayjit.Ops.Relu_gate])
49-
| [%expr ( =|| )] -> (false, [%expr Arrayjit.Ops.Or])
50-
| [%expr ( =&& )] -> (false, [%expr Arrayjit.Ops.And])
51-
| [%expr ( =@^ )] -> (false, [%expr Arrayjit.Ops.Max])
52-
| [%expr ( =^^ )] -> (false, [%expr Arrayjit.Ops.Min])
53-
| [%expr ( =:+ )] -> (true, [%expr Arrayjit.Ops.Add])
54-
| [%expr ( =:- )] -> (true, [%expr Arrayjit.Ops.Sub])
55-
| [%expr ( =:* )] -> (true, [%expr Arrayjit.Ops.Mul])
56-
| [%expr ( =:/ )] -> (true, [%expr Arrayjit.Ops.Div])
57-
| [%expr ( =:** )] -> (true, [%expr Arrayjit.Ops.ToPowOf])
58-
| [%expr ( =:?/ )] -> (true, [%expr Arrayjit.Ops.Relu_gate])
59-
| [%expr ( =:|| )] -> (true, [%expr Arrayjit.Ops.Or])
60-
| [%expr ( =:&& )] -> (true, [%expr Arrayjit.Ops.And])
61-
| [%expr ( =:@^ )] -> (true, [%expr Arrayjit.Ops.Max])
62-
| [%expr ( =:^^ )] -> (true, [%expr Arrayjit.Ops.Min])
63-
| _ ->
64-
( false,
65-
Ast_builder.Default.pexp_extension ~loc
66-
@@ Location.error_extensionf ~loc
67-
"ppx_ocannl %%cd: expected an assignment operator, one of: %s %s"
68-
"=+ (Add), =- (Sub), =* (Mul), =/ (Div), =** (ToPowOf), =?/ (Relu_gate), =|| (Or), \
69-
=&& (And), =@^ (Max), =^^ (Min), =: (Arg2), =:+, =:-,"
70-
" =:*, =:/, =:**, =:?/, =:||, =:&&, =:@^, =:^^ (same with initializing the tensor to \
71-
the neutral value before the start of the calculation)" )
72-
73-
let binary_op expr =
74-
(* This and is_binary_op should stay in sync with Arrayjit.Ops.binop_cd_syntax. *)
75-
(* FIXME: get rid of this and use binary_ops table instead. *)
76-
let loc = expr.pexp_loc in
77-
match expr with
78-
| [%expr ( + )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Add])
79-
| [%expr ( - )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Sub])
80-
| [%expr ( * )] ->
81-
( Ast_builder.Default.pexp_extension ~loc
82-
@@ Location.error_extensionf ~loc
83-
"No default compose type for binary `*`, try e.g. ~logic:\".\" for pointwise, %s"
84-
"~logic:\"@\" for matrix multiplication",
85-
[%expr Arrayjit.Ops.Mul] )
86-
| [%expr ( / )] ->
87-
( Ast_builder.Default.pexp_extension ~loc
88-
@@ Location.error_extensionf ~loc
89-
"For clarity, no default compose type for binary `/`, use ~logic:\".\" for pointwise \
90-
division",
91-
[%expr Arrayjit.Ops.Div] )
92-
| [%expr ( ** )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.ToPowOf])
93-
| [%expr ( -?/ )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Relu_gate])
94-
| [%expr ( -/> )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Arg2])
95-
| [%expr ( -@> )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Arg1])
96-
| [%expr ( < )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmplt])
97-
| [%expr ( <> )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpne])
98-
| [%expr ( || )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Or])
99-
| [%expr ( && )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.And])
100-
| [%expr ( % )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Mod])
101-
| [%expr ( @^ )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Max])
102-
| [%expr ( ^^ )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Min])
103-
| _ ->
104-
( [%expr Shape.Pointwise_bin],
105-
Ast_builder.Default.pexp_extension ~loc
106-
@@ Location.error_extensionf ~loc "ppx_ocannl %%cd: expected a binary operator, one of: %s"
107-
"+ (Add), - (Sub), * (Mul), / (Div), ** (ToPowOf), -?/ (Relu_gate), -/> (Arg2), < \
108-
(Cmplt), <> (Cmpne), || (Or), && (And), % (Mod), @^ (Max), ^^ (Min)" )
109-
110-
let ternary_op expr =
111-
(* FIXME: get rid of this and use ternary_ops table instead. *)
112-
let loc = expr.pexp_loc in
113-
match expr with
114-
| [%expr where] -> ([%expr Shape.Pointwise_tern], [%expr Arrayjit.Ops.Where])
115-
| [%expr fma] -> ([%expr Shape.Compose_accumulate], [%expr Arrayjit.Ops.FMA])
116-
| _ ->
117-
( [%expr Shape.Pointwise_bin],
118-
Ast_builder.Default.pexp_extension ~loc
119-
@@ Location.error_extensionf ~loc "ppx_ocannl %%cd: expected a ternary operator, one of: %s"
120-
"where, fma" )
121-
12238
type result = {
12339
vbs : value_binding Map.M(String).t;
12440
(** [vbs] are the bindings introduced by inline tensor declarations (aka. punning). These
@@ -460,6 +376,46 @@ let translate (expr : expression) : result =
460376
{ vbs = no_vbs; typ = Tensor; slot = Undet; expr; array_opt_of_code = None }
461377
in
462378
let loop = transl ~bad_pun_hints in
379+
let assignment_op accu_op =
380+
loc
381+
|> Option.value_or_thunk (Hashtbl.find assignment_ops accu_op) ~default:(fun () _loc ->
382+
( false,
383+
Ast_builder.Default.pexp_extension ~loc
384+
@@ Location.error_extensionf ~loc
385+
"ppx_ocannl %%cd: expected an assignment operator, one of: %s %s"
386+
"=+ (Add), =- (Sub), =* (Mul),=/ (Div), =** (ToPowOf), =?/ (Relu_gate), =|| \
387+
(Or), =&& (And), =@^ (Max), =^^ (Min), =: (Arg2),=:+, =:-,"
388+
" =:*, =:/, =:**, =:?/, =:||, =:&&, =:@^, =:^^ (same with initializing the \
389+
tensor to the neutral value before the start of the calculation)" ))
390+
in
391+
let unary_op un_op =
392+
loc
393+
|> Option.value_or_thunk (Hashtbl.find unary_ops un_op) ~default:(fun () loc ->
394+
( [%expr Shape.Pointwise_un],
395+
Ast_builder.Default.pexp_extension ~loc
396+
@@ Location.error_extensionf ~loc
397+
"ppx_ocannl %%cd: expected an assignment operator, one of: %s"
398+
"id, relu, sat01, exp, log, exp2, log2, sin, cos, sqrt, recip, recip_sqrt, \
399+
neg, tanh" ))
400+
in
401+
let binary_op bin_op =
402+
loc
403+
|> Option.value_or_thunk (Hashtbl.find binary_ops bin_op) ~default:(fun () _loc ->
404+
( [%expr Shape.Pointwise_bin],
405+
Ast_builder.Default.pexp_extension ~loc
406+
@@ Location.error_extensionf ~loc
407+
"ppx_ocannl %%cd: expected a binary operator, one of: %s"
408+
"+ (Add), - (Sub), * (Mul), / (Div), **(ToPowOf), -?/ (Relu_gate), -/> (Arg2), \
409+
< (Cmplt), <> (Cmpne), || (Or), && (And), % (Mod), @^(Max), ^^ (Min)" ))
410+
in
411+
let ternary_op tern_op =
412+
loc
413+
|> Option.value_or_thunk (Hashtbl.find ternary_ops tern_op) ~default:(fun () _loc ->
414+
( [%expr Shape.Pointwise_tern],
415+
Ast_builder.Default.pexp_extension ~loc
416+
@@ Location.error_extensionf ~loc
417+
"ppx_ocannl %%cd: expected a ternary operator, one of: %s" "where, fma" ))
418+
in
463419
(* FIXME: collapse these (code reuse) *)
464420
let process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ?projections ~proj_in_scope
465421
() =
@@ -590,7 +546,8 @@ let translate (expr : expression) : result =
590546
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r1; setup_r2 ] body
591547
in
592548
let process_assign_unop ~accu_op ~lhs ~un_op ~rhs ?projections ~proj_in_scope () =
593-
let initialize_neutral, accu_op = assignment_op accu_op in
549+
let initialize_neutral, accum = assignment_op accu_op in
550+
let _, op = unary_op un_op in
594551
(* FIXME: I think this ignores the slot information here! Just assuming [projections] is
595552
as-should-be, but that's not consistent with omitting the projections arg (assuming it
596553
comes from the context). *)
@@ -620,8 +577,8 @@ let translate (expr : expression) : result =
620577
{
621578
p.debug_info with
622579
trace =
623-
( "ppx_cd " ^ [%e expr2string_or_empty accu_op] ^ " "
624-
^ [%e expr2string_or_empty un_op],
580+
( "ppx_cd " ^ [%e string_expr ~loc accu_op] ^ " "
581+
^ [%e string_expr ~loc un_op],
625582
Arrayjit.Indexing.unique_debug_id () )
626583
:: p.debug_info.trace;
627584
};
@@ -636,9 +593,9 @@ let translate (expr : expression) : result =
636593
Arrayjit.Assignments.Accum_unop
637594
{
638595
initialize_neutral = [%e initialize_neutral];
639-
accum = [%e accu_op];
596+
accum = [%e accum];
640597
lhs;
641-
op = [%e un_op];
598+
op = [%e op];
642599
rhs;
643600
projections = [%e projections];
644601
})]
@@ -926,45 +883,50 @@ let translate (expr : expression) : result =
926883
}];
927884
}
928885
| [%expr
929-
[%e? accu_op]
886+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
930887
[%e? lhs]
931-
([%e? bin_op] [%e? rhs1] ([%e? rhs2] ~projections:[%e? projections]))] ->
888+
([%e? { pexp_desc = Pexp_ident { txt = Lident bin_op; _ }; _ }]
889+
[%e? rhs1]
890+
([%e? rhs2] ~projections:[%e? projections]))] ->
932891
(* Note: when clause not needed here and below, it's an error if bin_op is not a primitive
933892
binary op. *)
934893
process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~projections ~proj_in_scope:true ()
935894
| [%expr
936-
[%e? accu_op]
895+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
937896
[%e? lhs]
938-
([%e? tern_op] ([%e? rhs1], [%e? rhs2], [%e? rhs3]) ~projections:[%e? projections])] ->
897+
([%e? { pexp_desc = Pexp_ident { txt = Lident tern_op; _ }; _ }]
898+
([%e? rhs1], [%e? rhs2], [%e? rhs3])
899+
~projections:[%e? projections])] ->
939900
process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~projections
940901
~proj_in_scope:true ()
941902
| [%expr
942-
[%e? accu_op]
903+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
943904
[%e? lhs]
944-
([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }]
905+
([%e? { pexp_desc = Pexp_ident { txt = Lident un_op; _ }; _ }]
945906
[%e? rhs]
946907
~projections:[%e? projections])]
947908
| [%expr
948-
[%e? accu_op]
909+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
949910
[%e? lhs]
950-
(([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }] [%e? rhs])
911+
(([%e? { pexp_desc = Pexp_ident { txt = Lident un_op; _ }; _ }] [%e? rhs])
951912
~projections:[%e? projections])]
952913
| [%expr
953-
[%e? accu_op]
914+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
954915
[%e? lhs]
955-
([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }]
916+
([%e? { pexp_desc = Pexp_ident { txt = Lident un_op; _ }; _ }]
956917
([%e? rhs] ~projections:[%e? projections]))]
957-
when Hashtbl.mem unary_ops unop_ident ->
958-
let un_op = Hashtbl.find_exn unary_ops unop_ident loc in
918+
when Hashtbl.mem unary_ops un_op ->
959919
(* Handle both un_op priority levels -- where application binds tighter and less tight. *)
960920
process_assign_unop ~accu_op ~lhs ~un_op ~rhs ~projections ~proj_in_scope:true ()
961-
| [%expr [%e? accu_op] [%e? lhs] ([%e? rhs] ~projections:[%e? projections])] ->
962-
process_assign_unop ~accu_op ~lhs ~un_op:[%expr Arrayjit.Ops.Identity] ~rhs ~projections
963-
~proj_in_scope:true ()
964921
| [%expr
965-
[%e? accu_op]
922+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
966923
[%e? lhs]
967-
([%e? bin_op]
924+
([%e? rhs] ~projections:[%e? projections])] ->
925+
process_assign_unop ~accu_op ~lhs ~un_op:"id" ~rhs ~projections ~proj_in_scope:true ()
926+
| [%expr
927+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
928+
[%e? lhs]
929+
([%e? { pexp_desc = Pexp_ident { txt = Lident bin_op; _ }; _ }]
968930
[%e? rhs1]
969931
([%e? rhs2]
970932
~logic:
@@ -979,9 +941,9 @@ let translate (expr : expression) : result =
979941
let _, bin_op = binary_op bin_op in
980942
process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic
981943
| [%expr
982-
[%e? accu_op]
944+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
983945
[%e? lhs]
984-
([%e? tern_op]
946+
([%e? { pexp_desc = Pexp_ident { txt = Lident tern_op; _ }; _ }]
985947
([%e? rhs1], [%e? rhs2], [%e? rhs3])
986948
~logic:[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ }])] ->
987949
let logic =
@@ -995,15 +957,15 @@ let translate (expr : expression) : result =
995957
operators not supported yet, see issue #305"
996958
spec
997959
in
998-
let _, tern_op = binary_op tern_op in
960+
let _, tern_op = ternary_op tern_op in
999961
process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic
1000962
| [%expr
1001-
[%e? accu_op]
963+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
1002964
[%e? lhs]
1003965
(([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }] [%e? rhs])
1004966
~logic:[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic])]
1005967
| [%expr
1006-
[%e? accu_op]
968+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
1007969
[%e? lhs]
1008970
([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }]
1009971
([%e? rhs]
@@ -1017,67 +979,54 @@ let translate (expr : expression) : result =
1017979
else if String.equal spec "T" then [%expr Shape.Transpose]
1018980
else [%expr Shape.Permute [%e logic]]
1019981
in
1020-
let un_op = Hashtbl.find_exn unary_ops unop_ident loc in
982+
let _, un_op = Hashtbl.find_exn unary_ops unop_ident loc in
1021983
process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic
1022984
| [%expr
1023-
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
985+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
1024986
[%e? lhs]
1025-
([%e? { pexp_desc = Pexp_ident { txt = Lident binop_ident; _ }; _ } as bin_op]
1026-
[%e? rhs1]
1027-
[%e? rhs2])]
1028-
when is_assignment accu_ident && Hashtbl.mem binary_ops binop_ident && proj_in_scope ->
987+
([%e? { pexp_desc = Pexp_ident { txt = Lident bin_op; _ }; _ }] [%e? rhs1] [%e? rhs2])]
988+
when is_assignment accu_op && Hashtbl.mem binary_ops bin_op && proj_in_scope ->
1029989
process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~proj_in_scope ()
1030990
| [%expr
1031-
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
991+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
1032992
[%e? lhs]
1033-
([%e? { pexp_desc = Pexp_ident { txt = Lident ternop_ident; _ }; _ } as tern_op]
993+
([%e? { pexp_desc = Pexp_ident { txt = Lident tern_op; _ }; _ }]
1034994
([%e? rhs1], [%e? rhs2], [%e? rhs3]))]
1035-
when is_assignment accu_ident && Hashtbl.mem ternary_ops ternop_ident && proj_in_scope ->
995+
when is_assignment accu_op && Hashtbl.mem ternary_ops tern_op && proj_in_scope ->
1036996
process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~proj_in_scope ()
1037997
| [%expr
1038-
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
998+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
1039999
[%e? lhs]
1040-
([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }] [%e? rhs])]
1041-
when is_assignment accu_ident && Hashtbl.mem unary_ops unop_ident && proj_in_scope ->
1042-
let un_op = Hashtbl.find_exn unary_ops unop_ident loc in
1000+
([%e? { pexp_desc = Pexp_ident { txt = Lident un_op; _ }; _ }] [%e? rhs])]
1001+
when is_assignment accu_op && Hashtbl.mem unary_ops un_op && proj_in_scope ->
10431002
process_assign_unop ~accu_op ~lhs ~un_op ~rhs ~proj_in_scope ()
1003+
| [%expr [%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }] [%e? lhs] [%e? rhs]]
1004+
when is_assignment accu_op && proj_in_scope ->
1005+
process_assign_unop ~accu_op ~lhs ~un_op:"id" ~rhs ~proj_in_scope ()
10441006
| [%expr
1045-
[%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } as accu_op]
1046-
[%e? lhs]
1047-
[%e? rhs]]
1048-
when is_assignment op_ident && proj_in_scope ->
1049-
process_assign_unop ~accu_op ~lhs ~un_op:[%expr Arrayjit.Ops.Identity] ~rhs ~proj_in_scope
1050-
()
1051-
| [%expr
1052-
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
1007+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
10531008
[%e? lhs]
1054-
([%e? { pexp_desc = Pexp_ident { txt = Lident binop_ident; _ }; _ } as bin_op]
1055-
[%e? rhs1]
1056-
[%e? rhs2])]
1057-
when is_assignment accu_ident && Hashtbl.mem binary_ops binop_ident ->
1009+
([%e? { pexp_desc = Pexp_ident { txt = Lident bin_op; _ }; _ }] [%e? rhs1] [%e? rhs2])]
1010+
when is_assignment accu_op && Hashtbl.mem binary_ops bin_op ->
10581011
let logic, bin_op = binary_op bin_op in
10591012
process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic
10601013
| [%expr
1061-
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
1014+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
10621015
[%e? lhs]
1063-
([%e? { pexp_desc = Pexp_ident { txt = Lident ternop_ident; _ }; _ } as tern_op]
1016+
([%e? { pexp_desc = Pexp_ident { txt = Lident tern_op; _ }; _ }]
10641017
([%e? rhs1], [%e? rhs2], [%e? rhs3]))]
1065-
when is_assignment accu_ident && Hashtbl.mem ternary_ops ternop_ident ->
1018+
when is_assignment accu_op && Hashtbl.mem ternary_ops tern_op ->
10661019
let logic, tern_op = ternary_op tern_op in
10671020
process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic
10681021
| [%expr
1069-
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
1070-
[%e? lhs]
1071-
([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }] [%e? rhs])]
1072-
when is_assignment accu_ident && Hashtbl.mem unary_ops unop_ident ->
1073-
let un_op = Hashtbl.find_exn unary_ops unop_ident loc in
1074-
(* FIXME: projections logic! *)
1075-
process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic:[%expr Pointwise_un]
1076-
| [%expr
1077-
[%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } as accu_op]
1022+
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }]
10781023
[%e? lhs]
1079-
[%e? rhs]]
1080-
when is_assignment op_ident ->
1024+
([%e? { pexp_desc = Pexp_ident { txt = Lident un_op; _ }; _ }] [%e? rhs])]
1025+
when is_assignment accu_op && Hashtbl.mem unary_ops un_op ->
1026+
let logic, un_op = Hashtbl.find_exn unary_ops un_op loc in
1027+
process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic
1028+
| [%expr [%e? { pexp_desc = Pexp_ident { txt = Lident accu_op; _ }; _ }] [%e? lhs] [%e? rhs]]
1029+
when is_assignment accu_op ->
10811030
process_raw_unop ~accu_op ~lhs ~un_op:[%expr Arrayjit.Ops.Identity] ~rhs
10821031
~logic:[%expr Shape.Pointwise_un]
10831032
| [%expr [%e? expr1] [%e? expr2] [%e? expr3]] ->

0 commit comments

Comments
 (0)