Skip to content

Commit f6ea375

Browse files
committed
Untested: revert the Cmpne primitive op: can be used to test for NaN (x <> x ==> x = NaN)
1 parent 8b6a6fa commit f6ea375

File tree

4 files changed

+21
-12
lines changed

4 files changed

+21
-12
lines changed

arrayjit/lib/ops.ml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ type binop =
148148
| Mod
149149
| Cmplt
150150
| Cmpeq
151+
| Cmpne
151152
(* Waiting till we have a use-case to see how to sensibly introduce bitwise operations. *)
152153
(* | Shl *)
153154
(* | Shr *)
@@ -181,16 +182,18 @@ type op = Ternop of ternop | Binop of binop | Unop of unop [@@deriving sexp, com
181182
(** Either the left-neutral or right-neutral element of the operation. Unspecified if the operation
182183
does not have a neutral element. *)
183184
let neutral_elem = function
184-
| Add | Sub -> 0.
185-
| Mul | Div -> 1.
185+
| Add -> 0.
186+
| Sub -> 0.
187+
| Mul -> 1.
188+
| Div -> 1.
186189
| ToPowOf -> 1.
187190
| Relu_gate -> 1.
188191
| Satur01_gate -> 0.5
189192
| Max -> Float.neg_infinity
190193
| Min -> Float.infinity
191194
| And -> 1.
192195
| Or -> 0.
193-
| Arg2 | Arg1 | Mod | Cmplt | Cmpeq (* | Shl | Shr *) -> 0.
196+
| Arg2 | Arg1 | Mod | Cmplt | Cmpeq | Cmpne (* | Shl | Shr *) -> 0.
194197

195198
let interpret_binop op v1 v2 =
196199
let open Float in
@@ -210,6 +213,7 @@ let interpret_binop op v1 v2 =
210213
| Mod -> v1 % v2
211214
| Cmplt -> if v1 < v2 then 1. else 0.
212215
| Cmpeq -> if v1 = v2 then 1. else 0.
216+
| Cmpne -> if v1 <> v2 then 1. else 0.
213217
(* | Shl -> v1 * (int_pow 2. @@ to_int v2) *)
214218
(* | Shr -> v1 / (int_pow 2. @@ to_int v2) *)
215219
| Or -> if v1 <> 0. || v2 <> 0. then 1. else 0.
@@ -260,6 +264,7 @@ let binop_cd_syntax = function
260264
| Satur01_gate -> "-?^"
261265
| Cmplt -> "<"
262266
| Cmpeq -> "="
267+
| Cmpne -> "<>"
263268
| Or -> "||"
264269
| And -> "&&"
265270
| Mod -> "%"
@@ -282,6 +287,7 @@ let binop_cd_fallback_syntax = function
282287
| Satur01_gate -> "sat01_gate"
283288
| Cmplt -> "lt"
284289
| Cmpeq -> "eq"
290+
| Cmpne -> "ne"
285291
| Or -> "or_"
286292
| And -> "and_"
287293
| Mod -> "mod_"
@@ -315,6 +321,7 @@ let binop_c_syntax prec v =
315321
| Mod, _ -> ("(", " %", ")")
316322
| Cmplt, _ -> ("(", " <", ")")
317323
| Cmpeq, _ -> ("(", " ==", ")")
324+
| Cmpne, _ -> ("(", " !=", ")")
318325
(* | Shl, Byte_prec _ -> ("(", " <<", ")") *)
319326
(* | Shl, _ -> ("((", ") * exp2(", "))") *)
320327
(* | Shr, Byte_prec _ -> ("(", " >>", ")") *)
@@ -323,7 +330,7 @@ let binop_c_syntax prec v =
323330
| And, _ -> ("(", " &&", ")")
324331

325332
let is_assign_op = function
326-
| Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpeq -> false
333+
| Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpeq | Cmpne -> false
327334
| Add | Sub | Mul | Div | ToPowOf | Relu_gate | Satur01_gate | Arg2 | Max | Min | Or | And -> true
328335

329336
let assign_op_cd_syntax ~initialize_neutral = function
@@ -350,7 +357,7 @@ let assign_op_cd_syntax ~initialize_neutral = function
350357
| Min -> "=^^"
351358
| Or -> "=||"
352359
| And -> "=&&"
353-
| Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpeq ->
360+
| Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpeq | Cmpne ->
354361
invalid_arg "Ops.assign_op_cd_syntax: not an assignment op"
355362

356363
(** Note: currently we do not support unary prefix symbols. *)

lib/ppx_cd.ml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,8 @@ let translate (expr : expression) : result =
407407
@@ Location.error_extensionf ~loc
408408
"ppx_ocannl %%cd: expected a binary operator, one of: %s"
409409
"+ (Add), - (Sub), * (Mul), / (Div), **(ToPowOf), -?/ (Relu_gate), -?^ \
410-
(Satur01_gate), -/> (Arg2), < (Cmplt), = (Cmpeq), || (Or), && (And), % \
411-
(Mod), @^(Max), ^^ (Min)" ))
410+
(Satur01_gate), -/> (Arg2), < (Cmplt), = (Cmpeq), <> (Cmpne), || (Or), && \
411+
(And), % (Mod), @^(Max), ^^ (Min)" ))
412412
in
413413
let ternary_op tern_op =
414414
loc
@@ -694,10 +694,8 @@ let translate (expr : expression) : result =
694694
{ default_result with typ = Array; slot = LHS }
695695
| { pexp_desc = Pexp_ident { txt = Lident "rhs1"; _ }; _ } ->
696696
{ default_result with typ = Array; slot = RHS1 }
697-
| { pexp_desc = Pexp_ident { txt = Lident "t"; _ }; _ } ->
698-
{ default_result with slot = LHS }
699-
| { pexp_desc = Pexp_ident { txt = Lident "t1"; _ }; _ } ->
700-
{ default_result with slot = RHS1 }
697+
| { pexp_desc = Pexp_ident { txt = Lident "t"; _ }; _ } -> { default_result with slot = LHS }
698+
| { pexp_desc = Pexp_ident { txt = Lident "t1"; _ }; _ } -> { default_result with slot = RHS1 }
701699
| { pexp_desc = Pexp_ident { txt = Lident "v1"; _ }; _ } ->
702700
{ default_result with typ = Array; slot = RHS1; expr = [%expr t1.Tensor.value] }
703701
| { pexp_desc = Pexp_ident { txt = Lident "g1"; _ }; _ } ->

lib/ppx_shared.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ let binary_ops =
153153
("lt", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmplt]));
154154
("=", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpeq]));
155155
("eq", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpeq]));
156+
("<>", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpne]));
157+
("ne", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpne]));
156158
("||", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Or]));
157159
("or_", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Or]));
158160
("&&", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.And]));

lib/syntax_extensions.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ The binary primitive operations:
7676
| `relu_gate` | `-?/` | pointwise | `Relu_gate` | `=?/`, `=:?/` |
7777
| `sat01_gate` | `-?^` | pointwise | `Satur01_gate` | `=?^`, `=:?^` |
7878
| `lt` | `<` | pointwise | `Cmplt` | none |
79-
| `eq` | `<>` | pointwise | `Cmpeq` | none |
79+
| `eq` | `=` | pointwise | `Cmpeq` | none |
80+
| `ne` | `<>` | pointwise | `Cmpne` | none |
8081
| `or_` | `\|\|` | pointwise | `Or` | `=\|\|`, `=:\|\|` |
8182
| `and_` | `&&` | pointwise | `And` | `=&&`, `=:&&` |
8283
| `mod_` | `%` | pointwise | `Mod` | none |
@@ -133,6 +134,7 @@ let interpret_binop op v1 v2 =
133134
| Mod -> v1 % v2
134135
| Cmplt -> if v1 < v2 then 1. else 0.
135136
| Cmpeq -> if v1 = v2 then 1. else 0.
137+
| Cmpne -> if v1 <> v2 then 1. else 0.
136138
| Or -> if v1 <> 0. || v2 <> 0. then 1. else 0.
137139
| And -> if v1 <> 0. && v2 <> 0. then 1. else 0.
138140

0 commit comments

Comments
 (0)