Skip to content

Commit 86bf625

Browse files
committed
cuda_backend.ml tiny refactoring fixes
Signed-off-by: lukstafi <lukstafi@users.noreply.github.com>
1 parent e906f5f commit 86bf625

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

arrayjit/lib/cuda_backend.ml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
open Base
2+
open Ir
23
module Tn = Tnode
34
module Lazy = Utils.Lazy
45
module Cu = Cuda
@@ -301,7 +302,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
301302
| ToPowOf, Half_prec _ -> C_syntax.binop_adapter ("hexp2(hlog2(", "),", ")")
302303
| ToPowOf, Byte_prec _ ->
303304
invalid_arg "Cuda_backend.binop_syntax: ToPowOf not supported for byte/integer precisions"
304-
| Relu_gate, Byte_prec _ -> ("(", " > 0 ?", " : 0)")
305+
| Relu_gate, Byte_prec _ -> C_syntax.binop_adapter ("(", " > 0 ?", " : 0)")
305306
| Relu_gate, Half_prec _ ->
306307
C_syntax.binop_adapter
307308
( "(__hgt(",
@@ -390,7 +391,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
390391
| Recip_sqrt, Half_prec _ -> func "hrsqrt"
391392
| Recip_sqrt, Double_prec _ -> f "(1.0 / sqrt(" "))"
392393
| Recip_sqrt, _ -> f "(1.0 / sqrtf(" "))"
393-
| Neg, _ -> (f "(-(", "))")
394+
| Neg, _ -> f "(-(" "))"
394395
| Tanh_approx, Byte_prec _ ->
395396
invalid_arg
396397
"Cuda_backend.unop_syntax: Tanh_approx not supported for byte/integer precisions"

0 commit comments

Comments
 (0)