@@ -67,7 +67,11 @@ module Alloc_buffer = struct
6767 (if Array. length dims = 0 then 0 else Array. reduce_exn dims ~f: ( * )) * Ops. prec_in_bytes prec
6868 in
6969 set_ctx stream.device.dev.primary_context;
70- Cu.Deviceptr. mem_alloc ~size_in_bytes
70+ let ptr = Cu.Deviceptr. mem_alloc ~size_in_bytes in
71+ (* TODO: consider using memset_d8 to zero-initialize the memory. *)
72+ (* if size_in_bytes > 0 then
73+ Cu.Stream.memset_d8 ptr Unsigned.UChar.zero ~length:size_in_bytes stream.runner; *)
74+ ptr
7175
7276 let free_buffer = Some (fun _stream ptr -> Cu.Deviceptr. mem_free ptr)
7377end
@@ -283,10 +287,10 @@ end) : Ir.Backend_impl.Lowered_backend = struct
283287 | Void_prec -> " void"
284288
285289 let binop_syntax prec v =
290+ (* TODO: consider using binop_syntax inherited from Pure_C_config and overriding only
291+ where different. *)
286292 let open PPrint in
287- let f op_str v1 v2 =
288- group (lparen ^^ v1 ^^ space ^^ string op_str ^^ space ^^ v2 ^^ rparen)
289- in
293+ let f op_str v1 v2 = group @@ parens (v1 ^^ space ^^ string op_str ^^ space ^^ v2) in
290294 let func fn v1 v2 = group (string fn ^^ parens (separate comma [ v1; v2 ])) in
291295 match (v, prec) with
292296 | Ops. Arg1 , _ -> invalid_arg " Cuda_backend.binop_syntax: Arg1 is not an operator"
@@ -307,44 +311,50 @@ end) : Ir.Backend_impl.Lowered_backend = struct
307311 | ToPowOf , Byte_prec _ ->
308312 invalid_arg " Cuda_backend.binop_syntax: ToPowOf not supported for byte/integer precisions"
309313 | Relu_gate , Byte_prec _ ->
310- fun v1 v2 -> group (parens (v1 ^^ string " > 0" ) ^^ string " ? " ^^ v2 ^^ string " : 0" )
314+ fun v1 v2 ->
315+ group @@ parens (parens (v1 ^^ string " > 0" ) ^^ string " ? " ^^ v2 ^^ string " : 0" )
311316 | Relu_gate , Half_prec _ ->
312317 fun v1 v2 ->
313318 group
314- (parens
315- (string " __hgt(" ^^ v1 ^^ comma
316- ^^ string " __ushort_as_half((unsigned short)0x0000U))" )
317- ^^ string " ? " ^^ v2
318- ^^ string " : __ushort_as_half((unsigned short)0x0000U)" )
319+ @@ parens
320+ (parens
321+ (string " __hgt(" ^^ v1 ^^ comma
322+ ^^ string " __ushort_as_half((unsigned short)0x0000U))" )
323+ ^^ string " ? " ^^ v2
324+ ^^ string " : __ushort_as_half((unsigned short)0x0000U)" )
319325 | Relu_gate , _ ->
320326 fun v1 v2 ->
321- group (parens (v1 ^^ string " > 0.0" ) ^^ string " ? " ^^ v2 ^^ string " : 0.0" )
327+ group @@ parens (parens (v1 ^^ string " > 0.0" ) ^^ string " ? " ^^ v2 ^^ string " : 0.0" )
322328 | Satur01_gate , Byte_prec _ ->
323329 fun v1 v2 ->
324- parens
325- (parens
326- (string " (float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1 ^^ string " < 1.0f" )
327- ^^ string " ? " ^^ v2 ^^ string " : (unsigned char)0" )
330+ group
331+ @@ parens
332+ (parens
333+ (string " (float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1 ^^ string " < 1.0f" )
334+ ^^ string " ? " ^^ v2 ^^ string " : (unsigned char)0" )
328335 | Satur01_gate , Half_prec _ ->
329336 fun v1 v2 ->
330- parens
331- (parens
332- (string " __hgt(" ^^ v1 ^^ comma
333- ^^ string " __ushort_as_half((unsigned short)0x0000U)) && __hlt("
334- ^^ v1 ^^ comma
335- ^^ string " __ushort_as_half((unsigned short)0x3C00U)))" )
336- ^^ string " ? " ^^ v2
337- ^^ string " : __ushort_as_half((unsigned short)0x0000U)" )
337+ group
338+ @@ parens
339+ (parens
340+ (string " __hgt(" ^^ v1 ^^ comma
341+ ^^ string " __ushort_as_half((unsigned short)0x0000U)) && __hlt("
342+ ^^ v1 ^^ comma
343+ ^^ string " __ushort_as_half((unsigned short)0x3C00U)))" )
344+ ^^ string " ? " ^^ v2
345+ ^^ string " : __ushort_as_half((unsigned short)0x0000U)" )
338346 | Satur01_gate , Single_prec _ ->
339347 fun v1 v2 ->
340- parens
341- (parens (v1 ^^ string " > 0.0f && " ^^ v1 ^^ string " < 1.0f" )
342- ^^ string " ? " ^^ v2 ^^ string " : 0.0f" )
348+ group
349+ @@ parens
350+ (parens (v1 ^^ string " > 0.0f && " ^^ v1 ^^ string " < 1.0f" )
351+ ^^ string " ? " ^^ v2 ^^ string " : 0.0f" )
343352 | Satur01_gate , Double_prec _ ->
344353 fun v1 v2 ->
345- parens
346- (parens (v1 ^^ string " > 0.0 && " ^^ v1 ^^ string " < 1.0" )
347- ^^ string " ? " ^^ v2 ^^ string " : 0.0" )
354+ group
355+ @@ parens
356+ (parens (v1 ^^ string " > 0.0 && " ^^ v1 ^^ string " < 1.0" )
357+ ^^ string " ? " ^^ v2 ^^ string " : 0.0" )
348358 | Max , Byte_prec _ -> func " max"
349359 | Max , Half_prec _ -> func " __hmax"
350360 | Max , Double_prec _ -> func " fmax"
@@ -403,13 +413,16 @@ end) : Ir.Backend_impl.Lowered_backend = struct
403413 | Recip , Byte_prec _ ->
404414 invalid_arg " Cuda_backend.unop_syntax: Recip not supported for byte/integer precisions"
405415 | Recip , Half_prec _ -> func " hrcp"
406- | Recip , _ -> f " (1.0 / (" " ))"
416+ | Recip , Single_prec _ -> f " (1.0f / (" " ))"
417+ | Recip , Double_prec _ -> f " (1.0 / (" " ))"
418+ | Recip , _ -> f " (1 / (" " ))"
407419 | Recip_sqrt , Byte_prec _ ->
408420 invalid_arg
409421 " Cuda_backend.unop_syntax: Recip_sqrt not supported for byte/integer precisions"
410422 | Recip_sqrt , Half_prec _ -> func " hrsqrt"
411423 | Recip_sqrt , Double_prec _ -> f " (1.0 / sqrt(" " ))"
412- | Recip_sqrt , _ -> f " (1.0 / sqrtf(" " ))"
424+ | Recip_sqrt , Single_prec _ -> f " (1.0f / sqrtf(" " ))"
425+ | Recip_sqrt , _ -> f " (1 / sqrtf(" " ))"
413426 | Neg , _ -> f " (-(" " ))"
414427 | Tanh_approx , Byte_prec _ ->
415428 invalid_arg
0 commit comments