Skip to content

Commit 32a9731

Browse files
committed
Fix: CUDA syntax binops were missing outer parentheses
Signed-off-by: Lukasz Stafiniak <lukstafi@gmail.com>
1 parent b6f271d commit 32a9731

File tree

3 files changed

+54 −40 lines changed

3 files changed

+54 −40 lines changed

arrayjit/lib/cuda_backend.ml

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,11 @@ module Alloc_buffer = struct
6767
(if Array.length dims = 0 then 0 else Array.reduce_exn dims ~f:( * )) * Ops.prec_in_bytes prec
6868
in
6969
set_ctx stream.device.dev.primary_context;
70-
Cu.Deviceptr.mem_alloc ~size_in_bytes
70+
let ptr = Cu.Deviceptr.mem_alloc ~size_in_bytes in
71+
(* TODO: consider using memset_d8 to zero-initialize the memory. *)
72+
(* if size_in_bytes > 0 then
73+
Cu.Stream.memset_d8 ptr Unsigned.UChar.zero ~length:size_in_bytes stream.runner; *)
74+
ptr
7175

7276
let free_buffer = Some (fun _stream ptr -> Cu.Deviceptr.mem_free ptr)
7377
end
@@ -283,10 +287,10 @@ end) : Ir.Backend_impl.Lowered_backend = struct
283287
| Void_prec -> "void"
284288

285289
let binop_syntax prec v =
290+
(* TODO: consider using binop_syntax inherited from Pure_C_config and overriding only
291+
where different. *)
286292
let open PPrint in
287-
let f op_str v1 v2 =
288-
group (lparen ^^ v1 ^^ space ^^ string op_str ^^ space ^^ v2 ^^ rparen)
289-
in
293+
let f op_str v1 v2 = group @@ parens (v1 ^^ space ^^ string op_str ^^ space ^^ v2) in
290294
let func fn v1 v2 = group (string fn ^^ parens (separate comma [ v1; v2 ])) in
291295
match (v, prec) with
292296
| Ops.Arg1, _ -> invalid_arg "Cuda_backend.binop_syntax: Arg1 is not an operator"
@@ -307,44 +311,50 @@ end) : Ir.Backend_impl.Lowered_backend = struct
307311
| ToPowOf, Byte_prec _ ->
308312
invalid_arg "Cuda_backend.binop_syntax: ToPowOf not supported for byte/integer precisions"
309313
| Relu_gate, Byte_prec _ ->
310-
fun v1 v2 -> group (parens (v1 ^^ string " > 0") ^^ string " ? " ^^ v2 ^^ string " : 0")
314+
fun v1 v2 ->
315+
group @@ parens (parens (v1 ^^ string " > 0") ^^ string " ? " ^^ v2 ^^ string " : 0")
311316
| Relu_gate, Half_prec _ ->
312317
fun v1 v2 ->
313318
group
314-
(parens
315-
(string "__hgt(" ^^ v1 ^^ comma
316-
^^ string " __ushort_as_half((unsigned short)0x0000U))")
317-
^^ string " ? " ^^ v2
318-
^^ string " : __ushort_as_half((unsigned short)0x0000U)")
319+
@@ parens
320+
(parens
321+
(string "__hgt(" ^^ v1 ^^ comma
322+
^^ string " __ushort_as_half((unsigned short)0x0000U))")
323+
^^ string " ? " ^^ v2
324+
^^ string " : __ushort_as_half((unsigned short)0x0000U)")
319325
| Relu_gate, _ ->
320326
fun v1 v2 ->
321-
group (parens (v1 ^^ string " > 0.0") ^^ string " ? " ^^ v2 ^^ string " : 0.0")
327+
group @@ parens (parens (v1 ^^ string " > 0.0") ^^ string " ? " ^^ v2 ^^ string " : 0.0")
322328
| Satur01_gate, Byte_prec _ ->
323329
fun v1 v2 ->
324-
parens
325-
(parens
326-
(string "(float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1 ^^ string " < 1.0f")
327-
^^ string " ? " ^^ v2 ^^ string " : (unsigned char)0")
330+
group
331+
@@ parens
332+
(parens
333+
(string "(float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1 ^^ string " < 1.0f")
334+
^^ string " ? " ^^ v2 ^^ string " : (unsigned char)0")
328335
| Satur01_gate, Half_prec _ ->
329336
fun v1 v2 ->
330-
parens
331-
(parens
332-
(string "__hgt(" ^^ v1 ^^ comma
333-
^^ string " __ushort_as_half((unsigned short)0x0000U)) && __hlt("
334-
^^ v1 ^^ comma
335-
^^ string " __ushort_as_half((unsigned short)0x3C00U)))")
336-
^^ string " ? " ^^ v2
337-
^^ string " : __ushort_as_half((unsigned short)0x0000U)")
337+
group
338+
@@ parens
339+
(parens
340+
(string "__hgt(" ^^ v1 ^^ comma
341+
^^ string " __ushort_as_half((unsigned short)0x0000U)) && __hlt("
342+
^^ v1 ^^ comma
343+
^^ string " __ushort_as_half((unsigned short)0x3C00U)))")
344+
^^ string " ? " ^^ v2
345+
^^ string " : __ushort_as_half((unsigned short)0x0000U)")
338346
| Satur01_gate, Single_prec _ ->
339347
fun v1 v2 ->
340-
parens
341-
(parens (v1 ^^ string " > 0.0f && " ^^ v1 ^^ string " < 1.0f")
342-
^^ string " ? " ^^ v2 ^^ string " : 0.0f")
348+
group
349+
@@ parens
350+
(parens (v1 ^^ string " > 0.0f && " ^^ v1 ^^ string " < 1.0f")
351+
^^ string " ? " ^^ v2 ^^ string " : 0.0f")
343352
| Satur01_gate, Double_prec _ ->
344353
fun v1 v2 ->
345-
parens
346-
(parens (v1 ^^ string " > 0.0 && " ^^ v1 ^^ string " < 1.0")
347-
^^ string " ? " ^^ v2 ^^ string " : 0.0")
354+
group
355+
@@ parens
356+
(parens (v1 ^^ string " > 0.0 && " ^^ v1 ^^ string " < 1.0")
357+
^^ string " ? " ^^ v2 ^^ string " : 0.0")
348358
| Max, Byte_prec _ -> func "max"
349359
| Max, Half_prec _ -> func "__hmax"
350360
| Max, Double_prec _ -> func "fmax"
@@ -403,13 +413,16 @@ end) : Ir.Backend_impl.Lowered_backend = struct
403413
| Recip, Byte_prec _ ->
404414
invalid_arg "Cuda_backend.unop_syntax: Recip not supported for byte/integer precisions"
405415
| Recip, Half_prec _ -> func "hrcp"
406-
| Recip, _ -> f "(1.0 / (" "))"
416+
| Recip, Single_prec _ -> f "(1.0f / (" "))"
417+
| Recip, Double_prec _ -> f "(1.0 / (" "))"
418+
| Recip, _ -> f "(1 / (" "))"
407419
| Recip_sqrt, Byte_prec _ ->
408420
invalid_arg
409421
"Cuda_backend.unop_syntax: Recip_sqrt not supported for byte/integer precisions"
410422
| Recip_sqrt, Half_prec _ -> func "hrsqrt"
411423
| Recip_sqrt, Double_prec _ -> f "(1.0 / sqrt(" "))"
412-
| Recip_sqrt, _ -> f "(1.0 / sqrtf(" "))"
424+
| Recip_sqrt, Single_prec _ -> f "(1.0f / sqrtf(" "))"
425+
| Recip_sqrt, _ -> f "(1 / sqrtf(" "))"
413426
| Neg, _ -> f "(-(" "))"
414427
| Tanh_approx, Byte_prec _ ->
415428
invalid_arg

bin/micrograd_basic.ml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module Rand = Ir.Rand.Lib
88

99
let _get_local_debug_runtime = Utils.get_local_debug_runtime
1010

11-
let%diagn_sexp () =
11+
let%diagn_sexp _suspended() =
1212
let module Backend = (val Backends.fresh_backend ~backend_name:"multicore_cc" ()) in
1313
let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
1414
let ctx = Backend.make_context stream in
@@ -33,7 +33,7 @@ let%diagn_sexp () =
3333
Tensor.print ~with_code:false ~with_grad:true `Default @@ a;
3434
Tensor.print ~with_code:false ~with_grad:true `Default @@ b
3535

36-
let%diagn_sexp _suspended () : unit =
36+
let%diagn_sexp () : unit =
3737
let module Backend = (val Backends.fresh_backend ()) in
3838
let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
3939
let ctx = Backend.make_context stream in
@@ -53,6 +53,7 @@ let%diagn_sexp _suspended () : unit =
5353
(* Train.every_non_literal_on_host g; *)
5454
let update = Train.grad_update g in
5555
let routine = Train.to_routine (module Backend) ctx IDX.empty update.fwd_bprop in
56+
Utils.capture_stdout_logs @@ fun () ->
5657
Train.run routine;
5758
(* Tensor.print_tree ~with_grad:true ~depth:9 g; *)
5859
Tensor.print ~with_code:false ~with_grad:false `Default @@ g;

test/micrograd_demo_logging-cuda-0-0.log.expected

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ b_grad[0]{=MAYBE UNINITIALIZED} = -333201e-3 = fmaf(a[0]{=-4000e-3},n14_d_grad[0
6060
# b.grad[0] := fma((3 * (b[0] * b[0])), n14_d.grad[0], b.grad[0]);
6161
b_grad[0]{=MAYBE UNINITIALIZED} = 666402e-3 = fmaf(((float)(3) * (b[0]{=2000e-3} * b[0]{=2000e-3})),n14_d_grad[0]{=83300e-3},b_grad[0]{=-333201e-3})
6262
# b.grad[0] := (b.grad[0] + relu_gate(n31[0], n40_d.grad[0]));
63-
b_grad[0]{=MAYBE UNINITIALIZED} = 27766e-3 = (b_grad[0]{=666402e-3} + (n31[0]{=-2000e-3} > 0.0) ? n40_d_grad[0]{=27766e-3} : 0.0)
63+
b_grad[0]{=MAYBE UNINITIALIZED} = 666402e-3 = (b_grad[0]{=666402e-3} + ((n31[0]{=-2000e-3} > 0.0) ? n40_d_grad[0]{=27766e-3} : 0.0))
6464
# a.grad[0] := (a.grad[0] + relu_gate(n31[0], n40_d.grad[0]));
65-
a_grad[0]{=MAYBE UNINITIALIZED} = 27766e-3 = (a_grad[0]{=166600e-3} + (n31[0]{=-2000e-3} > 0.0) ? n40_d_grad[0]{=27766e-3} : 0.0)
65+
a_grad[0]{=MAYBE UNINITIALIZED} = 166600e-3 = (a_grad[0]{=166600e-3} + ((n31[0]{=-2000e-3} > 0.0) ? n40_d_grad[0]{=27766e-3} : 0.0))
6666
# b.grad[0] :=$ (b.grad[0] + relu_gate(n42[0], (-1 * ((2 * e[0]) * f.grad[0]))));
67-
b_grad[0]{=MAYBE UNINITIALIZED} = 6941e-3 = (b_grad[0]{=27766e-3} + (n42[0]{=6000e-3} > 0.0) ? ((float)(-1) * (((float)(2) * e[0]{=-7000e-3}) * f_grad[0]{=495e-3})) : 0.0)
67+
b_grad[0]{=MAYBE UNINITIALIZED} = 673344e-3 = (b_grad[0]{=666402e-3} + ((n42[0]{=6000e-3} > 0.0) ? ((float)(-1) * (((float)(2) * e[0]{=-7000e-3}) * f_grad[0]{=495e-3})) : 0.0))
6868
# a.grad[0] :=$ (a.grad[0] - relu_gate(n42[0], (-1 * ((2 * e[0]) * f.grad[0]))));
69-
a_grad[0]{=MAYBE UNINITIALIZED} = 6941e-3 = (a_grad[0]{=27766e-3} - (n42[0]{=6000e-3} > 0.0) ? ((float)(-1) * (((float)(2) * e[0]{=-7000e-3}) * f_grad[0]{=495e-3})) : 0.0)
69+
a_grad[0]{=MAYBE UNINITIALIZED} = 159658e-3 = (a_grad[0]{=166600e-3} - ((n42[0]{=6000e-3} > 0.0) ? ((float)(-1) * (((float)(2) * e[0]{=-7000e-3}) * f_grad[0]{=495e-3})) : 0.0))
7070
# n19_c.grad[0] := fma((2 * e[0]), f.grad[0], n19_c.grad[0]);
7171
n19_c_grad[0]{=MAYBE UNINITIALIZED} = -6941e-3 = fmaf(((float)(2) * e[0]{=-7000e-3}),f_grad[0]{=495e-3},n19_c_grad[0]{=0e-3})
7272
# n19_c.grad[0] := fma((2 * e[0]), f.grad[0], n19_c.grad[0]);
@@ -76,10 +76,10 @@ n4_c_grad[0]{=MAYBE UNINITIALIZED} = -13883e-3 = (n4_c_grad[0]{=0e-3} + n19_c_gr
7676
# n4_c.grad[0] := (n4_c.grad[0] + n19_c.grad[0]);
7777
n4_c_grad[0]{=MAYBE UNINITIALIZED} = -27766e-3 = (n4_c_grad[0]{=-13883e-3} + n19_c_grad[0]{=-13883e-3})
7878
# a.grad[0] := (a.grad[0] + n4_c.grad[0]);
79-
a_grad[0]{=MAYBE UNINITIALIZED} = -20825e-3 = (a_grad[0]{=6941e-3} + n4_c_grad[0]{=-27766e-3})
79+
a_grad[0]{=MAYBE UNINITIALIZED} = 131892e-3 = (a_grad[0]{=159658e-3} + n4_c_grad[0]{=-27766e-3})
8080
# b.grad[0] := (b.grad[0] + n4_c.grad[0]);
81-
b_grad[0]{=MAYBE UNINITIALIZED} = -20825e-3 = (b_grad[0]{=6941e-3} + n4_c_grad[0]{=-27766e-3})
81+
b_grad[0]{=MAYBE UNINITIALIZED} = 645577e-3 = (b_grad[0]{=673344e-3} + n4_c_grad[0]{=-27766e-3})
8282
# a.grad[0] := fma(-1, ((2 * e[0]) * f.grad[0]), a.grad[0]);
83-
a_grad[0]{=MAYBE UNINITIALIZED} = -13883e-3 = fmaf((float)(-1),(((float)(2) * e[0]{=-7000e-3}) * f_grad[0]{=495e-3}),a_grad[0]{=-20825e-3})
83+
a_grad[0]{=MAYBE UNINITIALIZED} = 138833e-3 = fmaf((float)(-1),(((float)(2) * e[0]{=-7000e-3}) * f_grad[0]{=495e-3}),a_grad[0]{=131892e-3})
8484
COMMENT: end
8585
COMMENT: end

0 commit comments

Comments (0)