Skip to content

Commit c16efd5

Browse files
committed
Fix Satur01_gate for non-Metal backends
1 parent 132144b commit c16efd5

File tree

5 files changed

+45
-13
lines changed

5 files changed

+45
-13
lines changed

CHANGES.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77
### Changed
88

99
- Removed `initialize` and `is_initialized` from the backend API; instead, backends should be initialized on functor application. The functors now take `config` as an argument.
10+
- More descriptive code identifier names in case of name conflicts.
11+
12+
### Fixed
13+
14+
- Avoid conflicts with C math function names like `fma`.
15+
- `Satur01_gate` had the wrong semantics on non-Metal backends.
1016

1117
## [0.5.2] -- 2025-04-07
1218

arrayjit/lib/c_syntax.ml

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,27 @@ struct
144144
Set.to_list !functions
145145

146146
let ternop_syntax prec op = ternop_adapter (Ops.ternop_c_syntax prec op)
147-
let binop_syntax prec op = binop_adapter (Ops.binop_c_syntax prec op)
147+
148+
(** [binop_syntax prec op] returns the printer that emits C code for the binary operator [op] at
    precision [prec].

    [Ops.Satur01_gate] is printed inline as a C ternary expression: it yields the second operand
    when the first operand is strictly greater than 0 and strictly less than 1, and a zero of
    the matching precision otherwise. The first operand is printed twice in the emitted
    condition; for [Ops.Byte_prec] it is cast to [float] before the comparison. Every other
    operator is printed through [binop_adapter] applied to [Ops.binop_c_syntax prec op].

    @raise Invalid_argument for [Ops.Satur01_gate] on [Ops.Void_prec]. *)
let binop_syntax prec op =
149+
match op with
150+
| Ops.Satur01_gate -> (
151+
match prec with
152+
| Ops.Byte_prec _ ->
153+
fun ppf pp1 v1 pp2 v2 ->
154+
Stdlib.Format.fprintf ppf "(((float)%a > 0.0f && (float)%a < 1.0f) ? %a : (unsigned char)0)"
155+
pp1 v1 pp1 v1 pp2 v2
156+
| Ops.Half_prec _ ->
157+
fun ppf pp1 v1 pp2 v2 ->
158+
Stdlib.Format.fprintf ppf "((%a > 0.0f16 && %a < 1.0f16) ? %a : 0.0f16)" pp1 v1 pp1 v1 pp2 v2
159+
| Ops.Single_prec _ ->
160+
fun ppf pp1 v1 pp2 v2 ->
161+
Stdlib.Format.fprintf ppf "((%a > 0.0f && %a < 1.0f) ? %a : 0.0f)" pp1 v1 pp1 v1 pp2 v2
162+
| Ops.Double_prec _ ->
163+
fun ppf pp1 v1 pp2 v2 ->
164+
Stdlib.Format.fprintf ppf "((%a > 0.0 && %a < 1.0) ? %a : 0.0)" pp1 v1 pp1 v1 pp2 v2
165+
| Ops.Void_prec -> invalid_arg "Pure_C_config.binop_syntax: Satur01_gate on Void_prec")
166+
| _ -> binop_adapter (Ops.binop_c_syntax prec op)
167+
148168
let unop_syntax prec op = unop_adapter (Ops.unop_c_syntax prec op)
149169
let convert_precision = Ops.c_convert_precision
150170
end

arrayjit/lib/cc_backend.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ let%diagn_sexp compile ~(name : string) bindings (lowered : Low_level.optimized)
9191
end)) in
9292
(* FIXME: do we really want all of them, or only the used ones? *)
9393
let idx_params = Indexing.bound_symbols bindings in
94-
let pp_file = Utils.pp_file ~base_name:name ~extension:".c" in
94+
let pp_file = Utils.pp_file ~base_name:name ~extension:".c" in
9595
Syntax.print_declarations pp_file.ppf;
9696
let params = Syntax.compile_proc ~name pp_file.ppf idx_params lowered in
9797
pp_file.finalize ();

arrayjit/lib/cuda_backend.ml

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -308,17 +308,21 @@ end) : Ir.Backend_impl.Lowered_backend = struct
308308
", __ushort_as_half((unsigned short)0x0000U)) ?",
309309
" : __ushort_as_half((unsigned short)0x0000U))" )
310310
| Relu_gate, _ -> C_syntax.binop_adapter ("(", " > 0.0 ?", " : 0.0)")
311-
| Satur01_gate, Byte_prec _ -> C_syntax.binop_adapter ("(abs(", ") > 0 ? 0 : (", ")")
311+
| Satur01_gate, Byte_prec _ ->
312+
fun ppf pp1 v1 pp2 v2 ->
313+
Stdlib.Format.fprintf ppf "(((float)%a > 0.0f && (float)%a < 1.0f) ? %a : (unsigned char)0)"
314+
pp1 v1 pp1 v1 pp2 v2
312315
| Satur01_gate, Half_prec _ ->
313-
C_syntax.binop_adapter
314-
( "(__hgt(__habs(htrunc(",
315-
")), __ushort_as_half((unsigned short)0x0000U)) ? __ushort_as_half((unsigned \
316-
short)0x0000U) : (",
317-
"))" )
318-
| Satur01_gate, Double_prec _ ->
319-
C_syntax.binop_adapter ("(fabs(trunc(", ")) > 0.0 ? 0.0 : (", "))")
316+
fun ppf pp1 v1 pp2 v2 ->
317+
Stdlib.Format.fprintf ppf
318+
"((__hgt(%a, __ushort_as_half((unsigned short)0x0000U)) && __hlt(%a, __ushort_as_half((unsigned short)0x3C00U))) ? %a : __ushort_as_half((unsigned short)0x0000U))"
319+
pp1 v1 pp1 v1 pp2 v2
320320
| Satur01_gate, Single_prec _ ->
321-
C_syntax.binop_adapter ("(fabsf(truncf(", ")) > 0.0 ? 0.0 : (", "))")
321+
fun ppf pp1 v1 pp2 v2 ->
322+
Stdlib.Format.fprintf ppf "((%a > 0.0f && %a < 1.0f) ? %a : 0.0f)" pp1 v1 pp1 v1 pp2 v2
323+
| Satur01_gate, Double_prec _ ->
324+
fun ppf pp1 v1 pp2 v2 ->
325+
Stdlib.Format.fprintf ppf "((%a > 0.0 && %a < 1.0) ? %a : 0.0)" pp1 v1 pp1 v1 pp2 v2
322326
| Max, Byte_prec _ -> func "max"
323327
| Max, Half_prec _ -> func "__hmax"
324328
| Max, Double_prec _ -> func "fmax"

arrayjit/lib/ops.ml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,10 @@ let binop_c_syntax prec v =
314314
| Relu_gate, Byte_prec _ -> ("(", " > 0 ?", " : 0)")
315315
| Relu_gate, _ -> ("(", " > 0.0 ?", " : 0.0)")
316316
| Satur01_gate, Byte_prec _ -> ("(abs(", " ) > 0 ? 0 : (", "))")
317-
| Satur01_gate, Single_prec _ -> ("(fabsf(truncf(", ")) > 0.0 ? 0.0 : (", "))")
318-
| Satur01_gate, _ -> ("(fabs(trunc(", ")) > 0.0 ? 0.0 : (", "))")
317+
| Satur01_gate, Single_prec _ ->
318+
(* This disagrees with the intended semantics at 0: the floor-based test lets the gated value
   through there. *)
319+
("(fabsf(floorf(", ")) > 0.0 ? 0.0 : (", "))")
320+
| Satur01_gate, _ -> ("(fabs(floor(", ")) > 0.0 ? 0.0 : (", "))")
319321
| Max, (Double_prec _ | Byte_prec _) -> ("fmax(", ",", ")")
320322
| Max, _ -> ("fmaxf(", ",", ")")
321323
| Min, (Double_prec _ | Byte_prec _) -> ("fmin(", ",", ")")

0 commit comments

Comments
 (0)