@@ -42,7 +42,7 @@ module Device_config = struct
4242 let name = " cuda"
4343end
4444
45- module Device_stream = Backend_impl. Device_types (Device_config )
45+ module Device_stream = Backend_impl. Device_types_ll (Device_config )
4646open Device_config
4747
4848let set_ctx ctx = Cu.Context. set_current ctx
@@ -321,7 +321,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
321321 (string " hexp2(hlog2(" ^^ v1 ^^ string " ),"
322322 ^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
323323 ^^ string " )" )
324- | ToPowOf , (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Fp8_prec _ ) ->
324+ | ToPowOf , (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Fp8_prec _ | Uint4x32_prec _ ) ->
325325 invalid_arg " Cuda_backend.binop_syntax: ToPowOf not supported for integer precisions"
326326 | ToPowOf , Bfloat16_prec _ ->
327327 fun v1 v2 ->
@@ -350,6 +350,50 @@ end) : Ir.Backend_impl.Lowered_backend = struct
350350 (nest 2
351351 (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
352352 ^^ string " __float2bfloat16(0.0f)" ))))
353+ | Relu_gate , Half_prec _ ->
354+ fun v1 v2 ->
355+ group
356+ (parens
357+ (group (parens (v1 ^^ string " > 0.0h" ))
358+ ^^ ifflat
359+ (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
360+ ^^ string " 0.0h" )
361+ (nest 2
362+ (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
363+ ^^ string " 0.0h" ))))
364+ | Relu_gate , Single_prec _ ->
365+ fun v1 v2 ->
366+ group
367+ (parens
368+ (group (parens (v1 ^^ string " > 0.0f" ))
369+ ^^ ifflat
370+ (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
371+ ^^ string " 0.0f" )
372+ (nest 2
373+ (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
374+ ^^ string " 0.0f" ))))
375+ | Relu_gate , Double_prec _ ->
376+ fun v1 v2 ->
377+ group
378+ (parens
379+ (group (parens (v1 ^^ string " > 0.0" ))
380+ ^^ ifflat
381+ (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
382+ ^^ string " 0.0" )
383+ (nest 2
384+ (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
385+ ^^ string " 0.0" ))))
386+ | Relu_gate , Uint4x32_prec _ ->
387+ fun v1 v2 ->
388+ group
389+ (parens
390+ (group (parens (v1 ^^ string " > 0" ))
391+ ^^ ifflat
392+ (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
393+ ^^ string " 0" )
394+ (nest 2
395+ (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
396+ ^^ string " 0" ))))
353397 | Satur01_gate , Byte_prec _ ->
354398 fun v1 v2 ->
355399 group
@@ -402,14 +446,97 @@ end) : Ir.Backend_impl.Lowered_backend = struct
402446 (nest 2
403447 (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
404448 ^^ string " 0.0" ))))
449+ | Satur01_gate , Uint16_prec _ ->
450+ fun v1 v2 ->
451+ group
452+ (parens
453+ (group
454+ (parens
455+ (string " (float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1
456+ ^^ string " < 1.0f" ))
457+ ^^ ifflat
458+ (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
459+ ^^ string " (unsigned short)0" )
460+ (nest 2
461+ (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
462+ ^^ string " (unsigned short)0" ))))
463+ | Satur01_gate , Int32_prec _ ->
464+ fun v1 v2 ->
465+ group
466+ (parens
467+ (group
468+ (parens
469+ (string " (float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1
470+ ^^ string " < 1.0f" ))
471+ ^^ ifflat
472+ (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
473+ ^^ string " 0" )
474+ (nest 2
475+ (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
476+ ^^ string " 0" ))))
477+ | Satur01_gate , Uint4x32_prec _ ->
478+ fun v1 v2 ->
479+ group
480+ (parens
481+ (group
482+ (parens
483+ (string " (float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1
484+ ^^ string " < 1.0f" ))
485+ ^^ ifflat
486+ (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
487+ ^^ string " 0u" )
488+ (nest 2
489+ (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
490+ ^^ string " 0u" ))))
491+ | Satur01_gate , Bfloat16_prec _ ->
492+ fun v1 v2 ->
493+ group
494+ (parens
495+ (group
496+ (parens
497+ (string " __bfloat162float(" ^^ v1
498+ ^^ string " ) > 0.0f && __bfloat162float("
499+ ^^ v1 ^^ string " ) < 1.0f" ))
500+ ^^ ifflat
501+ (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
502+ ^^ string " __float2bfloat16(0.0f)" )
503+ (nest 2
504+ (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
505+ ^^ string " __float2bfloat16(0.0f)" ))))
506+ | Satur01_gate , Fp8_prec _ ->
507+ fun v1 v2 ->
508+ group
509+ (parens
510+ (group
511+ (parens
512+ (string " (float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1
513+ ^^ string " < 1.0f" ))
514+ ^^ ifflat
515+ (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
516+ ^^ string " (unsigned char)0" )
517+ (nest 2
518+ (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
519+ ^^ string " (unsigned char)0" ))))
405520 | Max , Byte_prec _ -> func " max"
406521 | Max , Half_prec _ -> func " __hmax"
407522 | Max , Double_prec _ -> func " fmax"
408523 | Max , Single_prec _ -> func " fmaxf"
524+ | Max , Uint16_prec _ -> func " max"
525+ | Max , Int32_prec _ -> func " max"
526+ | Max , Uint4x32_prec _ -> func " max"
527+ | Max , Bfloat16_prec _ ->
528+ (* FIXME: [__hmax] is the same function the Half_prec case uses and has not been verified for bfloat16 operands; this might be wrong -- verify and fix if needed, here and elsewhere. *)
529+ func " __hmax"
530+ | Max , Fp8_prec _ -> func " max"
409531 | Min , Byte_prec _ -> func " min"
410532 | Min , Half_prec _ -> func " __hmin"
411533 | Min , Double_prec _ -> func " fmin"
412534 | Min , Single_prec _ -> func " fminf"
535+ | Min , Uint16_prec _ -> func " min"
536+ | Min , Int32_prec _ -> func " min"
537+ | Min , Uint4x32_prec _ -> func " min"
538+ | Min , Bfloat16_prec _ -> func " __hmin"
539+ | Min , Fp8_prec _ -> func " min"
413540 | Mod , Byte_prec _ -> f " %"
414541 | Mod , _ -> func " fmod"
415542 | Cmplt , _ -> f " <"
@@ -480,17 +607,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
480607 | Tanh_approx , _ -> func " tanh"
481608 | Not , _ -> f " (" " == 0.0 ? 1.0 : 0.0)"
482609 | Uint4x32_to_prec_uniform , _ ->
483- let conv_func = match prec with
484- | Ops. Single_prec _ -> " uint4x32_to_single_uniform"
485- | Double_prec _ -> " uint4x32_to_double_uniform"
486- | Half_prec _ -> " uint4x32_to_fp16_uniform"
487- | Bfloat16_prec _ -> " uint4x32_to_bf16_uniform"
488- | Byte_prec _ -> " uint4x32_to_u8_uniform"
489- | Uint16_prec _ -> " uint4x32_to_u32_uniform" (* Should probably be u16, but using u32 for 16-bit unsigned *)
490- | Int32_prec _ -> " uint4x32_to_i32_uniform"
491- | _ -> " /* unsupported conversion from uint4x32 */ 0"
492- in
493- func conv_func
610+ func (" uint4x32_to_" ^ Ops. prec_string prec ^ " _uniform" )
494611
495612 let ternop_syntax prec v =
496613 let open PPrint in
@@ -564,12 +681,15 @@ end) : Ir.Backend_impl.Lowered_backend = struct
564681 let idx_params = Indexing. bound_symbols bindings in
565682 let b = Buffer. create 4096 in
566683 (* Read and prepend the CUDA builtins file *)
567- let builtins_path = Stdlib.Filename. concat (Stdlib.Filename. dirname __FILE__) " arrayjit_builtins.cu" in
684+ let builtins_path =
685+ Stdlib.Filename. concat (Stdlib.Filename. dirname Stdlib. __FILE__) " arrayjit_builtins.cu"
686+ in
568687 (try
569688 let builtins_content = Stdio.In_channel. read_all builtins_path in
570689 Buffer. add_string b builtins_content;
571690 Buffer. add_string b " \n\n "
572- with _ -> () ); (* Silently skip if file not found *)
691+ with _ -> () );
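692+ (* Best-effort: the catch-all handler above silently ignores any exception raised while reading or prepending the builtins file, not only a missing file. *)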
573693 let declarations_doc = Syntax. print_declarations () in
574694 let params_and_docs =
575695 Array. map2_exn names lowereds
0 commit comments