Skip to content

Commit f279856

Browse files
committed
Untested: primitive ops: change Cmpne to Cmpeq and add Not
1 parent a7c2053 commit f279856

File tree

4 files changed

+18
-13
lines changed

4 files changed

+18
-13
lines changed

arrayjit/lib/ops.ml

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ type binop =
147147
| Min
148148
| Mod
149149
| Cmplt
150-
| Cmpne
150+
| Cmpeq
151151
(* Waiting till we have a use-case to see how to sensibly introduce bitwise operations. *)
152152
(* | Shl *)
153153
(* | Shr *)
@@ -170,6 +170,7 @@ type unop =
170170
| Recip_sqrt
171171
| Neg
172172
| Tanh_approx
173+
| Not (** 0. -> 1. | _ -> 0. *)
173174
[@@deriving sexp, compare, equal]
174175

175176
type ternop = Where (** Where(a,b,c): if a then b else c *) | FMA (** FMA(a,b,c): (a * b) + c *)
@@ -188,7 +189,7 @@ let neutral_elem = function
188189
| Min -> Float.infinity
189190
| And -> 1.
190191
| Or -> 0.
191-
| Arg2 | Arg1 | Mod | Cmplt | Cmpne (* | Shl | Shr *) -> 0.
192+
| Arg2 | Arg1 | Mod | Cmplt | Cmpeq (* | Shl | Shr *) -> 0.
192193

193194
let interpret_binop op v1 v2 =
194195
let open Float in
@@ -205,7 +206,7 @@ let interpret_binop op v1 v2 =
205206
| Min -> min v1 v2
206207
| Mod -> v1 % v2
207208
| Cmplt -> if v1 < v2 then 1. else 0.
208-
| Cmpne -> if v1 <> v2 then 1. else 0.
209+
| Cmpeq -> if v1 = v2 then 1. else 0.
209210
(* | Shl -> v1 * (int_pow 2. @@ to_int v2) *)
210211
(* | Shr -> v1 / (int_pow 2. @@ to_int v2) *)
211212
| Or -> if v1 <> 0. || v2 <> 0. then 1. else 0.
@@ -231,6 +232,7 @@ let interpret_unop op v =
231232
| Recip_sqrt -> 1. / sqrt v
232233
| Neg -> ~-.v
233234
| Tanh_approx -> tanh v
235+
| Not -> if v = 0. then 1. else 0.
234236

235237
let interpret_ternop op v1 v2 v3 =
236238
let open Float in
@@ -251,7 +253,7 @@ let binop_cd_syntax = function
251253
| ToPowOf -> "**"
252254
| Relu_gate -> "-?/"
253255
| Cmplt -> "<"
254-
| Cmpne -> "<>"
256+
| Cmpeq -> "="
255257
| Or -> "||"
256258
| And -> "&&"
257259
| Mod -> "%"
@@ -272,7 +274,7 @@ let binop_cd_fallback_syntax = function
272274
| ToPowOf -> "pow"
273275
| Relu_gate -> "relu_gate"
274276
| Cmplt -> "lt"
275-
| Cmpne -> "le"
277+
| Cmpeq -> "eq"
276278
| Or -> "or_"
277279
| And -> "and_"
278280
| Mod -> "mod_"
@@ -302,7 +304,7 @@ let binop_c_syntax prec v =
302304
| Min, _ -> ("fminf(", ",", ")")
303305
| Mod, _ -> ("(", " %", ")")
304306
| Cmplt, _ -> ("(", " <", ")")
305-
| Cmpne, _ -> ("(", " !=", ")")
307+
| Cmpeq, _ -> ("(", " ==", ")")
306308
(* | Shl, Byte_prec _ -> ("(", " <<", ")") *)
307309
(* | Shl, _ -> ("((", ") * exp2(", "))") *)
308310
(* | Shr, Byte_prec _ -> ("(", " >>", ")") *)
@@ -311,7 +313,7 @@ let binop_c_syntax prec v =
311313
| And, _ -> ("(", " &&", ")")
312314

313315
let is_assign_op = function
314-
| Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpne -> false
316+
| Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpeq -> false
315317
| Add | Sub | Mul | Div | ToPowOf | Relu_gate | Arg2 | Max | Min | Or | And -> true
316318

317319
let assign_op_cd_syntax ~initialize_neutral = function
@@ -336,7 +338,7 @@ let assign_op_cd_syntax ~initialize_neutral = function
336338
| Min -> "=^^"
337339
| Or -> "=||"
338340
| And -> "=&&"
339-
| Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpne ->
341+
| Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpeq ->
340342
invalid_arg "Ops.assign_op_cd_syntax: not an assignment op"
341343

342344
(** Note: currently we do not support unary prefix symbols. *)
@@ -355,6 +357,7 @@ let unop_cd_syntax = function
355357
| Recip_sqrt -> "recip_sqrt"
356358
| Neg -> "neg"
357359
| Tanh_approx -> "tanh"
360+
| Not -> "not"
358361

359362
let unop_c_syntax prec op =
360363
let fmax () =
@@ -400,6 +403,7 @@ let unop_c_syntax prec op =
400403
| Tanh_approx, Byte_prec _ ->
401404
invalid_arg "Ops.unop_c_syntax: Tanh_approx not supported for byte/integer precisions"
402405
| Tanh_approx, _ -> ("tanhf(", ")")
406+
| Not, _ -> ("(", " == 0.0 ? 1.0 : 0.0)")
403407

404408
(** In the %cd syntax, we use uncurried notation for ternary ops. *)
405409
let ternop_cd_syntax = function Where -> "where" | FMA -> "fma"

lib/ppx_cd.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ let translate (expr : expression) : result =
406406
@@ Location.error_extensionf ~loc
407407
"ppx_ocannl %%cd: expected a binary operator, one of: %s"
408408
"+ (Add), - (Sub), * (Mul), / (Div), **(ToPowOf), -?/ (Relu_gate), -/> (Arg2), \
409-
< (Cmplt), <> (Cmpne), || (Or), && (And), % (Mod), @^(Max), ^^ (Min)" ))
409+
< (Cmplt), = (Cmpeq), || (Or), && (And), % (Mod), @^(Max), ^^ (Min)" ))
410410
in
411411
let ternary_op tern_op =
412412
loc

lib/ppx_shared.ml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ let binary_ops =
149149
("relu_gate", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Relu_gate]));
150150
("<", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmplt]));
151151
("lt", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmplt]));
152-
("<>", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpne]));
153-
("ne", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpne]));
152+
("=", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpeq]));
153+
("eq", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpeq]));
154154
("||", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Or]));
155155
("or_", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Or]));
156156
("&&", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.And]));
@@ -182,6 +182,7 @@ let unary_ops =
182182
("recip_sqrt", fun loc -> ([%expr Shape.Pointwise_un], [%expr Arrayjit.Ops.Recip_sqrt]));
183183
("neg", fun loc -> ([%expr Shape.Pointwise_un], [%expr Arrayjit.Ops.Neg]));
184184
("tanh", fun loc -> ([%expr Shape.Pointwise_un], [%expr Arrayjit.Ops.Tanh_approx]));
185+
("not", fun loc -> ([%expr Shape.Pointwise_un], [%expr Arrayjit.Ops.Not]));
185186
]
186187

187188
(** Ternary primitive ops. *)

lib/syntax_extensions.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,10 @@ The binary primitive operations:
7474
| `pow` | `**` | pointwise | `ToPowOf` | `=**`, `=:**` |
7575
| `relu_gate` | `-?/` | pointwise | `Relu_gate` | `=?/`, `=:?/` |
7676
| `lt` | `<` | pointwise | `Cmplt` | none |
77-
| `ne` | `<>` | pointwise | `Cmpne` | none |
77+
| `eq` | `<>` | pointwise | `Cmpeq` | none |
7878
| `or_` | `\|\|` | pointwise | `Or` | `=\|\|`, `=:\|\|` |
7979
| `and_` | `&&` | pointwise | `And` | `=&&`, `=:&&` |
80-
| `mod_` | `%` | pointwise | `Mod` | `=%`, `=:%` |
80+
| `mod_` | `%` | pointwise | `Mod` | none |
8181
| `max` | `@^` | pointwise | `Max` | `=@^`, `=:@^` |
8282
| `min` | `^^` | pointwise | `Min` | `=^^`, `=:^^` |
8383

0 commit comments

Comments
 (0)