@@ -80,10 +80,8 @@ module Alloc_buffer = struct
8080 track_allocation new_buffer_obj;
8181 { ptr = new_buffer_obj; size_in_bytes }
8282
83- let alloc_zero_init_array prec ~dims (stream : stream ) =
84- let size_in_bytes =
85- (if Array. length dims = 0 then 0 else Array. reduce_exn dims ~f: ( * )) * Ops. prec_in_bytes prec
86- in
83+ let % track7_sexp alloc_zero_init_array (prec : Ops.prec ) ~(dims : int array ) (stream : stream ) =
84+ let size_in_bytes = Array. fold dims ~init: 1 ~f: ( * ) * Ops. prec_in_bytes prec in
8785 let device = stream.device.dev in
8886 let buffer = Me.Buffer. on_device device ~length: size_in_bytes resource_options in
8987 track_allocation buffer;
@@ -448,13 +446,15 @@ end) : Ir.Backend_impl.Lowered_backend = struct
448446 | Ops. Bfloat16_prec _ -> " bfloat" (* Metal supports bfloat16 natively *)
449447 | Ops. Fp8_prec _ -> invalid_arg " Metal backend does not support FP8 precision"
450448 | Ops. Single_prec _ -> " float"
451- | Ops. Double_prec _ -> raise @@ Utils. User_error " Metal backend does not support double precision"
449+ | Ops. Double_prec _ ->
450+ raise @@ Utils. User_error " Metal backend does not support double precision"
452451 | Ops. Void_prec -> " void"
453452
454453 let vec_typ_of_prec ~length prec =
455454 match (prec, length) with
456455 | Ops. Single_prec _ , 4 -> " float4_t"
457- | Ops. Double_prec _ , 2 -> raise @@ Utils. User_error " Metal backend does not support double precision"
456+ | Ops. Double_prec _ , 2 ->
457+ raise @@ Utils. User_error " Metal backend does not support double precision"
458458 | Ops. Int32_prec _ , 4 -> " int32x4_t"
459459 | (Ops. Byte_prec _ | Ops. Fp8_prec _ ), 16 -> " int8x16_t"
460460 | (Ops. Uint16_prec _ | Ops. Bfloat16_prec _ ), 8 -> " uint16x8_t"
@@ -472,7 +472,8 @@ end) : Ir.Backend_impl.Lowered_backend = struct
472472 | Ops. Bfloat16_prec _ -> " bf" (* TODO: Verify actual Metal suffix for bfloat16 *)
473473 | Ops. Fp8_prec _ -> invalid_arg " Metal backend does not support FP8 precision"
474474 | Ops. Single_prec _ -> " f"
475- | Ops. Double_prec _ -> raise @@ Utils. User_error " Metal backend does not support double precision"
475+ | Ops. Double_prec _ ->
476+ raise @@ Utils. User_error " Metal backend does not support double precision"
476477 | Ops. Void_prec -> " "
477478
478479 let ternop_syntax _prec op =
@@ -532,13 +533,17 @@ end) : Ir.Backend_impl.Lowered_backend = struct
532533 ^^ space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
533534 ^^ string (" 0.0" ^ s)))
534535 | ToPowOf , _ -> func " pow"
535- | Threefry4x32 , _ ->
536+ | Threefry4x32 , _ -> (
536537 (* Threefry4x32 must output to uint4x32 precision *)
537- ( match prec with
538+ match prec with
538539 | Ops. Uint4x32_prec _ -> func " arrayjit_threefry4x32"
539- | _ -> raise @@ Utils. User_error
540- (Printf. sprintf " Metal backend: Threefry4x32 requires target precision to be uint4x32, but got %s"
541- (Ops. prec_string prec)))
540+ | _ ->
541+ raise
542+ @@ Utils. User_error
543+ (Printf. sprintf
544+ " Metal backend: Threefry4x32 requires target precision to be uint4x32, but \
545+ got %s"
546+ (Ops. prec_string prec)))
542547 | Arg1 , _ | Arg2 , _ -> invalid_arg " Metal C_syntax_config: Arg1/Arg2 not operators"
543548
544549 let unop_syntax prec op =
@@ -555,7 +560,8 @@ end) : Ir.Backend_impl.Lowered_backend = struct
555560 | Sqrt , _ -> func_doc " sqrt"
556561 | Relu , Ops. Half_prec _ -> fun v -> func_doc " max" (separate comma_sep [ string " 0.0h" ; v ])
557562 | Relu , Ops. Single_prec _ -> fun v -> func_doc " max" (separate comma_sep [ string " 0.0f" ; v ])
558- | Relu , Ops. Double_prec _ -> raise @@ Utils. User_error " Metal backend does not support double precision"
563+ | Relu , Ops. Double_prec _ ->
564+ raise @@ Utils. User_error " Metal backend does not support double precision"
559565 | Relu , _ (* Byte_prec, Void_prec *) ->
560566 fun v -> func_doc " max" (separate comma_sep [ string " 0" ; v ])
561567 | Satur01 , p ->
0 commit comments