@@ -112,19 +112,25 @@ let is_binary_op ident =
112112 [ " +" ; " -" ; " *" ; " /" ; " **" ; " -?/" ; " -/>" ; " -@>" ; " <" ; " <>" ; " &&" ; " %" ; " @^" ; " ^^" ]
113113 ident ~equal: String. equal
114114
115- let unary_op expr =
116- (* This and is_unary_op should stay in sync with Arrayjit.Ops.unop_cd_syntax. *)
117- let loc = expr.pexp_loc in
118- match expr with
119- | [% expr ( ~= )] -> ([% expr Shape. Pointwise_un ], [% expr Arrayjit.Ops. Identity ])
120- | [% expr ( ?/ )] -> ([% expr Shape. Pointwise_un ], [% expr Arrayjit.Ops. Relu ])
121- | _ ->
122- ( [% expr Shape. Pointwise_un ],
123- Ast_builder.Default. pexp_extension ~loc
124- @@ Location. error_extensionf ~loc
125- " ppx_ocannl %%cd: expected a unary operator, one of: = (Identity), ?/ (Relu)" )
126-
127- let is_unary_op ident = List. mem [ " ~=" ; " ?/" ] ident ~equal: String. equal
115+ let unary_ops =
116+ Hashtbl. of_alist_exn
117+ (module String )
118+ [
119+ (" id" , fun loc -> [% expr Arrayjit.Ops. Identity ]);
120+ (" relu" , fun loc -> [% expr Arrayjit.Ops. Relu ]);
121+ (" sat01" , fun loc -> [% expr Arrayjit.Ops. Satur01 ]);
122+ (" exp" , fun loc -> [% expr Arrayjit.Ops. Exp ]);
123+ (" log" , fun loc -> [% expr Arrayjit.Ops. Log ]);
124+ (" exp2" , fun loc -> [% expr Arrayjit.Ops. Exp2 ]);
125+ (" log2" , fun loc -> [% expr Arrayjit.Ops. Log2 ]);
126+ (" sin" , fun loc -> [% expr Arrayjit.Ops. Sin ]);
127+ (" cos" , fun loc -> [% expr Arrayjit.Ops. Cos ]);
128+ (" sqrt" , fun loc -> [% expr Arrayjit.Ops. Sqrt ]);
129+ (" recip" , fun loc -> [% expr Arrayjit.Ops. Recip ]);
130+ (" recip_sqrt" , fun loc -> [% expr Arrayjit.Ops. Recip_sqrt ]);
131+ (" neg" , fun loc -> [% expr Arrayjit.Ops. Neg ]);
132+ (" tanh" , fun loc -> [% expr Arrayjit.Ops. Tanh_approx ]);
133+ ]
128134
129135type result = {
130136 vbs : value_binding Map .M (String ).t;
@@ -832,9 +838,24 @@ let translate (expr : expression) : result =
832838 [% e? lhs]
833839 ([% e? bin_op] [% e? rhs1] ([% e? rhs2] ~projections: [% e? projections]))] ->
834840 process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~projections ~proj_in_scope: true ()
835- | [% expr [% e? accu_op] [% e? lhs] (([% e? un_op] [% e? rhs]) ~projections: [% e? projections])]
836- | [% expr [% e? accu_op] [% e? lhs] ([% e? un_op] ([% e? rhs] ~projections: [% e? projections]))] ->
837- let _, un_op = unary_op un_op in
841+ | [% expr
842+ [% e? accu_op]
843+ [% e? lhs]
844+ ([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }]
845+ [% e? rhs]
846+ ~projections: [% e? projections])]
847+ | [% expr
848+ [% e? accu_op]
849+ [% e? lhs]
850+ (([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }] [% e? rhs])
851+ ~projections: [% e? projections])]
852+ | [% expr
853+ [% e? accu_op]
854+ [% e? lhs]
855+ ([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }]
856+ ([% e? rhs] ~projections: [% e? projections]))]
857+ when Hashtbl. mem unary_ops unop_ident ->
858+ let un_op = Hashtbl. find_exn unary_ops unop_ident loc in
838859 (* Handle both un_op priority levels -- where application binds tighter and less tight. *)
839860 process_assign_unop ~accu_op ~lhs ~un_op ~rhs ~projections ~proj_in_scope: true ()
840861 | [% expr [% e? accu_op] [% e? lhs] ([% e? rhs] ~projections: [% e? projections])] ->
@@ -860,24 +881,24 @@ let translate (expr : expression) : result =
860881 | [% expr
861882 [% e? accu_op]
862883 [% e? lhs]
863- (([% e? un_op ] [% e? rhs])
884+ (([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ } ] [% e? rhs])
864885 ~logic: [% e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic])]
865886 | [% expr
866887 [% e? accu_op]
867888 [% e? lhs]
868- ([% e? un_op ]
889+ ([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ } ]
869890 ([% e? rhs]
870891 ~logic:
871892 [% e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic]))]
872- ->
893+ when Hashtbl. mem unary_ops unop_ident ->
873894 (* Handle both un_op priority levels -- where application binds tighter and less tight. *)
874895 let logic =
875896 let loc = s_loc in
876897 if String. equal spec " ." then [% expr Shape. Pointwise_un ]
877898 else if String. equal spec " T" then [% expr Shape. Transpose ]
878899 else [% expr Shape. Permute [% e logic]]
879900 in
880- let _, un_op = unary_op un_op in
901+ let un_op = Hashtbl. find_exn unary_ops unop_ident loc in
881902 process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic
882903 | [% expr
883904 [% e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
@@ -890,9 +911,9 @@ let translate (expr : expression) : result =
890911 | [% expr
891912 [% e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
892913 [% e? lhs]
893- ([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ } as un_op ] [% e? rhs])]
894- when is_assignment accu_ident && is_unary_op unop_ident && proj_in_scope ->
895- let _, un_op = unary_op un_op in
914+ ([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }] [% e? rhs])]
915+ when is_assignment accu_ident && Hashtbl. mem unary_ops unop_ident && proj_in_scope ->
916+ let un_op = Hashtbl. find_exn unary_ops unop_ident loc in
896917 process_assign_unop ~accu_op ~lhs ~un_op ~rhs ~proj_in_scope ()
897918 | [% expr
898919 [% e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } as accu_op]
@@ -913,10 +934,11 @@ let translate (expr : expression) : result =
913934 | [% expr
914935 [% e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
915936 [% e? lhs]
916- ([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ } as un_op] [% e? rhs])]
917- when is_assignment accu_ident && is_unary_op unop_ident ->
918- let logic, un_op = unary_op un_op in
919- process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic
937+ ([% e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }] [% e? rhs])]
938+ when is_assignment accu_ident && Hashtbl. mem unary_ops unop_ident ->
939+ let un_op = Hashtbl. find_exn unary_ops unop_ident loc in
940+ (* FIXME: projections logic! *)
941+ process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic: [% expr Pointwise_un ]
920942 | [% expr
921943 [% e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } as accu_op]
922944 [% e? lhs]
0 commit comments