@@ -147,7 +147,7 @@ type binop =
147147 | Min
148148 | Mod
149149 | Cmplt
150- | Cmpne
150+ | Cmpeq
151151 (* Waiting till we have a use-case to see how to sensibly introduce bitwise operations. *)
152152 (* | Shl *)
153153 (* | Shr *)
@@ -170,6 +170,7 @@ type unop =
170170 | Recip_sqrt
171171 | Neg
172172 | Tanh_approx
173+ | Not (* * 0. -> 1. | _ -> 0. *)
173174[@@ deriving sexp , compare , equal ]
174175
175176type ternop = Where (* * Where(a,b,c): if a then b else c *) | FMA (* * FMA(a,b,c): (a * b) + c *)
@@ -188,7 +189,7 @@ let neutral_elem = function
188189 | Min -> Float. infinity
189190 | And -> 1.
190191 | Or -> 0.
191- | Arg2 | Arg1 | Mod | Cmplt | Cmpne (* | Shl | Shr * ) -> 0.
192+ | Arg2 | Arg1 | Mod | Cmplt | Cmpeq (* | Shl | Shr * ) -> 0.
192193
193194let interpret_binop op v1 v2 =
194195 let open Float in
@@ -205,7 +206,7 @@ let interpret_binop op v1 v2 =
205206 | Min -> min v1 v2
206207 | Mod -> v1 % v2
207208 | Cmplt -> if v1 < v2 then 1. else 0.
208- | Cmpne -> if v1 <> v2 then 1. else 0.
209+ | Cmpeq -> if v1 = v2 then 1. else 0.
209210 (* | Shl -> v1 * (int_pow 2. @@ to_int v2) *)
210211 (* | Shr -> v1 / (int_pow 2. @@ to_int v2) *)
211212 | Or -> if v1 <> 0. || v2 <> 0. then 1. else 0.
@@ -231,6 +232,7 @@ let interpret_unop op v =
231232 | Recip_sqrt -> 1. / sqrt v
232233 | Neg -> ~-. v
233234 | Tanh_approx -> tanh v
235+ | Not -> if v = 0. then 1. else 0.
234236
235237let interpret_ternop op v1 v2 v3 =
236238 let open Float in
@@ -251,7 +253,7 @@ let binop_cd_syntax = function
251253 | ToPowOf -> " **"
252254 | Relu_gate -> " -?/"
253255 | Cmplt -> " <"
254- | Cmpne -> " <> "
256+ | Cmpeq -> " = "
255257 | Or -> " ||"
256258 | And -> " &&"
257259 | Mod -> " %"
@@ -272,7 +274,7 @@ let binop_cd_fallback_syntax = function
272274 | ToPowOf -> " pow"
273275 | Relu_gate -> " relu_gate"
274276 | Cmplt -> " lt"
275- | Cmpne -> " le "
277+ | Cmpeq -> " eq "
276278 | Or -> " or_"
277279 | And -> " and_"
278280 | Mod -> " mod_"
@@ -302,7 +304,7 @@ let binop_c_syntax prec v =
302304 | Min , _ -> (" fminf(" , " ," , " )" )
303305 | Mod , _ -> (" (" , " %" , " )" )
304306 | Cmplt , _ -> (" (" , " <" , " )" )
305- | Cmpne , _ -> (" (" , " ! =" , " )" )
307+ | Cmpeq , _ -> (" (" , " = =" , " )" )
306308 (* | Shl, Byte_prec _ -> ("(", " <<", ")") *)
307309 (* | Shl, _ -> ("((", ") * exp2(", "))") *)
308310 (* | Shr, Byte_prec _ -> ("(", " >>", ")") *)
@@ -311,7 +313,7 @@ let binop_c_syntax prec v =
311313 | And , _ -> (" (" , " &&" , " )" )
312314
313315let is_assign_op = function
314- | Arg1 | Mod (* | Shl | Shr * ) | Cmplt | Cmpne -> false
316+ | Arg1 | Mod (* | Shl | Shr * ) | Cmplt | Cmpeq -> false
315317 | Add | Sub | Mul | Div | ToPowOf | Relu_gate | Arg2 | Max | Min | Or | And -> true
316318
317319let assign_op_cd_syntax ~initialize_neutral = function
@@ -336,7 +338,7 @@ let assign_op_cd_syntax ~initialize_neutral = function
336338 | Min -> " =^^"
337339 | Or -> " =||"
338340 | And -> " =&&"
339- | Arg1 | Mod (* | Shl | Shr * ) | Cmplt | Cmpne ->
341+ | Arg1 | Mod (* | Shl | Shr * ) | Cmplt | Cmpeq ->
340342 invalid_arg " Ops.assign_op_cd_syntax: not an assignment op"
341343
342344(* * Note: currently we do not support unary prefix symbols. *)
@@ -355,6 +357,7 @@ let unop_cd_syntax = function
355357 | Recip_sqrt -> " recip_sqrt"
356358 | Neg -> " neg"
357359 | Tanh_approx -> " tanh"
360+ | Not -> " not"
358361
359362let unop_c_syntax prec op =
360363 let fmax () =
@@ -400,6 +403,7 @@ let unop_c_syntax prec op =
400403 | Tanh_approx , Byte_prec _ ->
401404 invalid_arg " Ops.unop_c_syntax: Tanh_approx not supported for byte/integer precisions"
402405 | Tanh_approx , _ -> (" tanhf(" , " )" )
406+ | Not , _ -> (" (" , " == 0.0 ? 1.0 : 0.0)" )
403407
404408(* * In the %cd syntax, we use uncurried notation for ternary ops. *)
405409let ternop_cd_syntax = function Where -> " where" | FMA -> " fma"
0 commit comments