@@ -46,21 +46,29 @@ let assignment_op expr =
4646 | [% expr ( =/ )] -> (false , [% expr Arrayjit.Ops. Div ])
4747 | [% expr ( =** )] -> (false , [% expr Arrayjit.Ops. ToPowOf ])
4848 | [% expr ( =?/ )] -> (false , [% expr Arrayjit.Ops. Relu_gate ])
49+ | [% expr ( =|| )] -> (false , [% expr Arrayjit.Ops. Or ])
50+ | [% expr ( =&& )] -> (false , [% expr Arrayjit.Ops. And ])
51+ | [% expr ( =@^ )] -> (false , [% expr Arrayjit.Ops. Max ])
52+ | [% expr ( =^^ )] -> (false , [% expr Arrayjit.Ops. Min ])
4953 | [% expr ( =:+ )] -> (true , [% expr Arrayjit.Ops. Add ])
5054 | [% expr ( =:- )] -> (true , [% expr Arrayjit.Ops. Sub ])
5155 | [% expr ( =:* )] -> (true , [% expr Arrayjit.Ops. Mul ])
5256 | [% expr ( =:/ )] -> (true , [% expr Arrayjit.Ops. Div ])
5357 | [% expr ( =:** )] -> (true , [% expr Arrayjit.Ops. ToPowOf ])
5458 | [% expr ( =:?/ )] -> (true , [% expr Arrayjit.Ops. Relu_gate ])
59+ | [% expr ( =:|| )] -> (true , [% expr Arrayjit.Ops. Or ])
60+ | [% expr ( =:&& )] -> (true , [% expr Arrayjit.Ops. And ])
61+ | [% expr ( =:@^ )] -> (true , [% expr Arrayjit.Ops. Max ])
62+ | [% expr ( =:^^ )] -> (true , [% expr Arrayjit.Ops. Min ])
5563 | _ ->
5664 ( false ,
5765 Ast_builder.Default. pexp_extension ~loc
5866 @@ Location. error_extensionf ~loc
5967 " ppx_ocannl %%cd: expected an assignment operator, one of: %s %s"
60- " =+ (Add), =- (Sub), =* (Mul), =/ (Div), =** (ToPowOf), =?/ (Relu_gate), =: (Arg2 ), \
61- =:+, =:-,"
62- " =:*, =:/, =:**, =:?/ (same with initializing the tensor to the neutral value before \
63- the start of the calculation)" )
68+ " =+ (Add), =- (Sub), =* (Mul), =/ (Div), =** (ToPowOf), =?/ (Relu_gate), =|| (Or ), \
69+ =&& (And), =@^ (Max), =^^ (Min), =: (Arg2), = :+, =:-,"
70+ " =:*, =:/, =:**, =:?/, =:||, =:&&, =:@^, =:^^ (same with initializing the tensor to \
71+ the neutral value before the start of the calculation)" )
6472
6573let binary_op expr =
6674 (* This and is_binary_op should stay in sync with Arrayjit.Ops.binop_cd_syntax. *)
@@ -84,14 +92,25 @@ let binary_op expr =
8492 | [% expr ( -?/ )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Relu_gate ])
8593 | [% expr ( -/> )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Arg2 ])
8694 | [% expr ( -@> )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Arg1 ])
95+ | [% expr ( < )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Cmplt ])
96+ | [% expr ( <> )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Cmpne ])
97+ | [% expr ( || )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Or ])
98+ | [% expr ( && )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. And ])
99+ | [% expr ( % )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Mod ])
100+ | [% expr ( @^ )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Max ])
101+ | [% expr ( ^^ )] -> ([% expr Shape. Pointwise_bin ], [% expr Arrayjit.Ops. Min ])
87102 | _ ->
88103 ( [% expr Shape. Pointwise_bin ],
89104 Ast_builder.Default. pexp_extension ~loc
90105 @@ Location. error_extensionf ~loc " ppx_ocannl %%cd: expected a binary operator, one of: %s"
91- " + (Add), - (Sub), * (Mul), / (Div), ** (ToPowOf), -?/ (Relu_gate), -/> (Arg2)" )
106+ " + (Add), - (Sub), * (Mul), / (Div), ** (ToPowOf), -?/ (Relu_gate), -/> (Arg2), < \
107+ (Cmplt), <> (Cmpne), || (Or), && (And), % (Mod), @^ (Max), ^^ (Min)" )
92108
93109let is_binary_op ident =
94- List. mem [ " +" ; " -" ; " *" ; " /" ; " **" ; " -?/" ; " -/>" ; " -@>" ] ident ~equal: String. equal
110+ (* TODO: compile into a hashtable *)
111+ List. mem
112+ [ " +" ; " -" ; " *" ; " /" ; " **" ; " -?/" ; " -/>" ; " -@>" ; " <" ; " <>" ; " &&" ; " %" ; " @^" ; " ^^" ]
113+ ident ~equal: String. equal
95114
96115let unary_op expr =
97116 (* This and is_unary_op should stay in sync with Arrayjit.Ops.unop_cd_syntax. *)
0 commit comments