Skip to content

Commit 2eb8cc1

Browse files
committed
Fix enabling of zero-dimension scalars in metal and cuda backends
1 parent 21e0243 commit 2eb8cc1

File tree

4 files changed

+38
-22
lines changed

4 files changed

+38
-22
lines changed

arrayjit/lib/cuda_backend.ml

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -69,9 +69,7 @@ module Alloc_buffer = struct
6969
{ ptr = Cu.Deviceptr.mem_alloc ~size_in_bytes; size_in_bytes }
7070

7171
let alloc_zero_init_array prec ~dims stream =
72-
let size_in_bytes =
73-
(if Array.length dims = 0 then 0 else Array.reduce_exn dims ~f:( * )) * Ops.prec_in_bytes prec
74-
in
72+
let size_in_bytes = Array.fold dims ~init:1 ~f:( * ) * Ops.prec_in_bytes prec in
7573
set_ctx stream.device.dev.primary_context;
7674
let ptr = Cu.Deviceptr.mem_alloc ~size_in_bytes in
7775
(* TODO: consider using memset_d8 to zero-initialize the memory. *)
@@ -588,13 +586,17 @@ end) : Ir.Backend_impl.Lowered_backend = struct
588586
| Cmpeq, _ -> f "=="
589587
| Or, _ -> f "||"
590588
| And, _ -> f "&&"
591-
| Threefry4x32, _ ->
589+
| Threefry4x32, _ -> (
592590
(* Threefry4x32 must output to uint4x32 precision *)
593-
(match prec with
591+
match prec with
594592
| Ops.Uint4x32_prec _ -> func "arrayjit_threefry4x32"
595-
| _ -> raise @@ Utils.User_error
596-
(Printf.sprintf "CUDA backend: Threefry4x32 requires target precision to be uint4x32, but got %s"
597-
(Ops.prec_string prec)))
593+
| _ ->
594+
raise
595+
@@ Utils.User_error
596+
(Printf.sprintf
597+
"CUDA backend: Threefry4x32 requires target precision to be uint4x32, but \
598+
got %s"
599+
(Ops.prec_string prec)))
598600

599601
let unop_syntax prec v =
600602
let open PPrint in

arrayjit/lib/metal_backend.ml

Lines changed: 19 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -80,10 +80,8 @@ module Alloc_buffer = struct
8080
track_allocation new_buffer_obj;
8181
{ ptr = new_buffer_obj; size_in_bytes }
8282

83-
let alloc_zero_init_array prec ~dims (stream : stream) =
84-
let size_in_bytes =
85-
(if Array.length dims = 0 then 0 else Array.reduce_exn dims ~f:( * )) * Ops.prec_in_bytes prec
86-
in
83+
let%track7_sexp alloc_zero_init_array (prec : Ops.prec) ~(dims : int array) (stream : stream) =
84+
let size_in_bytes = Array.fold dims ~init:1 ~f:( * ) * Ops.prec_in_bytes prec in
8785
let device = stream.device.dev in
8886
let buffer = Me.Buffer.on_device device ~length:size_in_bytes resource_options in
8987
track_allocation buffer;
@@ -448,13 +446,15 @@ end) : Ir.Backend_impl.Lowered_backend = struct
448446
| Ops.Bfloat16_prec _ -> "bfloat" (* Metal supports bfloat16 natively *)
449447
| Ops.Fp8_prec _ -> invalid_arg "Metal backend does not support FP8 precision"
450448
| Ops.Single_prec _ -> "float"
451-
| Ops.Double_prec _ -> raise @@ Utils.User_error "Metal backend does not support double precision"
449+
| Ops.Double_prec _ ->
450+
raise @@ Utils.User_error "Metal backend does not support double precision"
452451
| Ops.Void_prec -> "void"
453452

454453
let vec_typ_of_prec ~length prec =
455454
match (prec, length) with
456455
| Ops.Single_prec _, 4 -> "float4_t"
457-
| Ops.Double_prec _, 2 -> raise @@ Utils.User_error "Metal backend does not support double precision"
456+
| Ops.Double_prec _, 2 ->
457+
raise @@ Utils.User_error "Metal backend does not support double precision"
458458
| Ops.Int32_prec _, 4 -> "int32x4_t"
459459
| (Ops.Byte_prec _ | Ops.Fp8_prec _), 16 -> "int8x16_t"
460460
| (Ops.Uint16_prec _ | Ops.Bfloat16_prec _), 8 -> "uint16x8_t"
@@ -472,7 +472,8 @@ end) : Ir.Backend_impl.Lowered_backend = struct
472472
| Ops.Bfloat16_prec _ -> "bf" (* TODO: Verify actual Metal suffix for bfloat16 *)
473473
| Ops.Fp8_prec _ -> invalid_arg "Metal backend does not support FP8 precision"
474474
| Ops.Single_prec _ -> "f"
475-
| Ops.Double_prec _ -> raise @@ Utils.User_error "Metal backend does not support double precision"
475+
| Ops.Double_prec _ ->
476+
raise @@ Utils.User_error "Metal backend does not support double precision"
476477
| Ops.Void_prec -> ""
477478

478479
let ternop_syntax _prec op =
@@ -532,13 +533,17 @@ end) : Ir.Backend_impl.Lowered_backend = struct
532533
^^ space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
533534
^^ string ("0.0" ^ s)))
534535
| ToPowOf, _ -> func "pow"
535-
| Threefry4x32, _ ->
536+
| Threefry4x32, _ -> (
536537
(* Threefry4x32 must output to uint4x32 precision *)
537-
(match prec with
538+
match prec with
538539
| Ops.Uint4x32_prec _ -> func "arrayjit_threefry4x32"
539-
| _ -> raise @@ Utils.User_error
540-
(Printf.sprintf "Metal backend: Threefry4x32 requires target precision to be uint4x32, but got %s"
541-
(Ops.prec_string prec)))
540+
| _ ->
541+
raise
542+
@@ Utils.User_error
543+
(Printf.sprintf
544+
"Metal backend: Threefry4x32 requires target precision to be uint4x32, but \
545+
got %s"
546+
(Ops.prec_string prec)))
542547
| Arg1, _ | Arg2, _ -> invalid_arg "Metal C_syntax_config: Arg1/Arg2 not operators"
543548

544549
let unop_syntax prec op =
@@ -555,7 +560,8 @@ end) : Ir.Backend_impl.Lowered_backend = struct
555560
| Sqrt, _ -> func_doc "sqrt"
556561
| Relu, Ops.Half_prec _ -> fun v -> func_doc "max" (separate comma_sep [ string "0.0h"; v ])
557562
| Relu, Ops.Single_prec _ -> fun v -> func_doc "max" (separate comma_sep [ string "0.0f"; v ])
558-
| Relu, Ops.Double_prec _ -> raise @@ Utils.User_error "Metal backend does not support double precision"
563+
| Relu, Ops.Double_prec _ ->
564+
raise @@ Utils.User_error "Metal backend does not support double precision"
559565
| Relu, _ (* Byte_prec, Void_prec *) ->
560566
fun v -> func_doc "max" (separate comma_sep [ string "0"; v ])
561567
| Satur01, p ->

test/einsum/moons_demo_variant.expected

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ n93 sgd_delta_w3 as sgd_delta_w3: Virt/15; single prec 1x16; mem in bytes: <not-
9494
n94 sgd_momentum_w3 as sgd_momentum_w3: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
9595
n95 0.0001 as n95: Virt/40; single prec 1; mem in bytes: <not-in-yet>
9696
n96 *. as n96: Virt/15; single prec 1x16; mem in bytes: <not-in-yet>
97-
n97 point as point: Host-const/37; single prec 2; mem in bytes: 8
97+
n97 point as point: Host-const/37; single prec 2; mem in bytes: <not-in-yet>
9898
n98 * as n98: Local/1046; single prec 16; mem in bytes: <not-in-yet>
9999
n99 grad_* as n98.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
100100
n100 + as n100: Virt/15; single prec 16; mem in bytes: <not-in-yet>

test/operations/micrograd_demo_logging-metal-0-0.log.expected

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,12 @@
11
float *a &[1] = 0xNNNN
2+
float *b &[1] = 0xNNNN
3+
COMMENT: init params for g
4+
# b[0] := 2;
5+
b[0]{=MAYBE UNINITIALIZED} = 2000e-3 = (float)(2)
6+
# a[0] := -4;
7+
a[0]{=MAYBE UNINITIALIZED} = -4000e-3 = (float)(-4)
8+
COMMENT: end
9+
float *a &[1] = 0xNNNN
210
float *a_grad &[1] = 0xNNNN
311
float *b &[1] = 0xNNNN
412
float *b_grad &[1] = 0xNNNN

0 commit comments

Comments
 (0)