Skip to content

Commit e126bd2

Browse files
committed
Updated parsing of unary ops (%cd syntax)
1 parent ffbea70 commit e126bd2

File tree

2 files changed

+50
-28
lines changed

2 files changed

+50
-28
lines changed

lib/operation.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ let einsum1 ?(label = []) spec =
125125

126126
let relu ?(label = []) =
127127
let module NTDSL = Initial_NTDSL in
128-
let%cd op_asn ~v ~t1 ~projections = v =: ?/v1 ~projections in
128+
let%cd op_asn ~v ~t1 ~projections = v =: relu v1 ~projections in
129129
let%cd grad_asn ~v ~g ~t1 ~projections = g1 =+ v -?/ g in
130130
Tensor.unop ~label:("?/" :: label) ~transpose_op:Pointwise_un ~op_asn ~grad_asn
131131

lib/ppx_cd.ml

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -112,19 +112,25 @@ let is_binary_op ident =
112112
[ "+"; "-"; "*"; "/"; "**"; "-?/"; "-/>"; "-@>"; "<"; "<>"; "&&"; "%"; "@^"; "^^" ]
113113
ident ~equal:String.equal
114114

115-
let unary_op expr =
116-
(* This and is_unary_op should stay in sync with Arrayjit.Ops.unop_cd_syntax. *)
117-
let loc = expr.pexp_loc in
118-
match expr with
119-
| [%expr ( ~= )] -> ([%expr Shape.Pointwise_un], [%expr Arrayjit.Ops.Identity])
120-
| [%expr ( ?/ )] -> ([%expr Shape.Pointwise_un], [%expr Arrayjit.Ops.Relu])
121-
| _ ->
122-
( [%expr Shape.Pointwise_un],
123-
Ast_builder.Default.pexp_extension ~loc
124-
@@ Location.error_extensionf ~loc
125-
"ppx_ocannl %%cd: expected a unary operator, one of: = (Identity), ?/ (Relu)" )
126-
127-
let is_unary_op ident = List.mem [ "~="; "?/" ] ident ~equal:String.equal
115+
let unary_ops =
116+
Hashtbl.of_alist_exn
117+
(module String)
118+
[
119+
("id", fun loc -> [%expr Arrayjit.Ops.Identity]);
120+
("relu", fun loc -> [%expr Arrayjit.Ops.Relu]);
121+
("sat01", fun loc -> [%expr Arrayjit.Ops.Satur01]);
122+
("exp", fun loc -> [%expr Arrayjit.Ops.Exp]);
123+
("log", fun loc -> [%expr Arrayjit.Ops.Log]);
124+
("exp2", fun loc -> [%expr Arrayjit.Ops.Exp2]);
125+
("log2", fun loc -> [%expr Arrayjit.Ops.Log2]);
126+
("sin", fun loc -> [%expr Arrayjit.Ops.Sin]);
127+
("cos", fun loc -> [%expr Arrayjit.Ops.Cos]);
128+
("sqrt", fun loc -> [%expr Arrayjit.Ops.Sqrt]);
129+
("recip", fun loc -> [%expr Arrayjit.Ops.Recip]);
130+
("recip_sqrt", fun loc -> [%expr Arrayjit.Ops.Recip_sqrt]);
131+
("neg", fun loc -> [%expr Arrayjit.Ops.Neg]);
132+
("tanh", fun loc -> [%expr Arrayjit.Ops.Tanh_approx]);
133+
]
128134

129135
type result = {
130136
vbs : value_binding Map.M(String).t;
@@ -832,9 +838,24 @@ let translate (expr : expression) : result =
832838
[%e? lhs]
833839
([%e? bin_op] [%e? rhs1] ([%e? rhs2] ~projections:[%e? projections]))] ->
834840
process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~projections ~proj_in_scope:true ()
835-
| [%expr [%e? accu_op] [%e? lhs] (([%e? un_op] [%e? rhs]) ~projections:[%e? projections])]
836-
| [%expr [%e? accu_op] [%e? lhs] ([%e? un_op] ([%e? rhs] ~projections:[%e? projections]))] ->
837-
let _, un_op = unary_op un_op in
841+
| [%expr
842+
[%e? accu_op]
843+
[%e? lhs]
844+
([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }]
845+
[%e? rhs]
846+
~projections:[%e? projections])]
847+
| [%expr
848+
[%e? accu_op]
849+
[%e? lhs]
850+
(([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }] [%e? rhs])
851+
~projections:[%e? projections])]
852+
| [%expr
853+
[%e? accu_op]
854+
[%e? lhs]
855+
([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }]
856+
([%e? rhs] ~projections:[%e? projections]))]
857+
when Hashtbl.mem unary_ops unop_ident ->
858+
let un_op = Hashtbl.find_exn unary_ops unop_ident loc in
838859
(* Handle both un_op priority levels -- where application binds tighter and less tight. *)
839860
process_assign_unop ~accu_op ~lhs ~un_op ~rhs ~projections ~proj_in_scope:true ()
840861
| [%expr [%e? accu_op] [%e? lhs] ([%e? rhs] ~projections:[%e? projections])] ->
@@ -860,24 +881,24 @@ let translate (expr : expression) : result =
860881
| [%expr
861882
[%e? accu_op]
862883
[%e? lhs]
863-
(([%e? un_op] [%e? rhs])
884+
(([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }] [%e? rhs])
864885
~logic:[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic])]
865886
| [%expr
866887
[%e? accu_op]
867888
[%e? lhs]
868-
([%e? un_op]
889+
([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }]
869890
([%e? rhs]
870891
~logic:
871892
[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic]))]
872-
->
893+
when Hashtbl.mem unary_ops unop_ident ->
873894
(* Handle both un_op priority levels -- where application binds tighter and less tight. *)
874895
let logic =
875896
let loc = s_loc in
876897
if String.equal spec "." then [%expr Shape.Pointwise_un]
877898
else if String.equal spec "T" then [%expr Shape.Transpose]
878899
else [%expr Shape.Permute [%e logic]]
879900
in
880-
let _, un_op = unary_op un_op in
901+
let un_op = Hashtbl.find_exn unary_ops unop_ident loc in
881902
process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic
882903
| [%expr
883904
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
@@ -890,9 +911,9 @@ let translate (expr : expression) : result =
890911
| [%expr
891912
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
892913
[%e? lhs]
893-
([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ } as un_op] [%e? rhs])]
894-
when is_assignment accu_ident && is_unary_op unop_ident && proj_in_scope ->
895-
let _, un_op = unary_op un_op in
914+
([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }] [%e? rhs])]
915+
when is_assignment accu_ident && Hashtbl.mem unary_ops unop_ident && proj_in_scope ->
916+
let un_op = Hashtbl.find_exn unary_ops unop_ident loc in
896917
process_assign_unop ~accu_op ~lhs ~un_op ~rhs ~proj_in_scope ()
897918
| [%expr
898919
[%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } as accu_op]
@@ -913,10 +934,11 @@ let translate (expr : expression) : result =
913934
| [%expr
914935
[%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
915936
[%e? lhs]
916-
([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ } as un_op] [%e? rhs])]
917-
when is_assignment accu_ident && is_unary_op unop_ident ->
918-
let logic, un_op = unary_op un_op in
919-
process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic
937+
([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; _ }; _ }] [%e? rhs])]
938+
when is_assignment accu_ident && Hashtbl.mem unary_ops unop_ident ->
939+
let un_op = Hashtbl.find_exn unary_ops unop_ident loc in
940+
(* FIXME: projections logic! *)
941+
process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic:[%expr Pointwise_un]
920942
| [%expr
921943
[%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } as accu_op]
922944
[%e? lhs]

0 commit comments

Comments
 (0)