Skip to content

Commit 13f3be5

Browse files
committed
Fix CUDA backend staleness regressions and obsolete precision naming
Signed-off-by: Lukasz Stafiniak <lukstafi@gmail.com>
1 parent 91cb919 commit 13f3be5

File tree

4 files changed: 142 additions (+142) and 33 deletions (-33).

arrayjit/lib/arrayjit_builtins.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ __device__ double uint4x32_to_double_uniform(uint4x32_t x) {
132132
}
133133

134134
/* Uint4x32 to int32 uniform */
135-
__device__ int32_t uint4x32_to_i32_uniform(uint4x32_t x) {
135+
__device__ int32_t uint4x32_to_int32_uniform(uint4x32_t x) {
136136
return (int32_t)x.v[0];
137137
}
138138

@@ -162,13 +162,13 @@ __device__ uint8_t uint4x32_to_u8_uniform(uint4x32_t x) {
162162
}
163163

164164
/* Uint4x32 to bfloat16 uniform */
165-
__device__ uint16_t uint4x32_to_bf16_uniform(uint4x32_t x) {
165+
__device__ uint16_t uint4x32_to_bfloat16_uniform(uint4x32_t x) {
166166
float f = uint32_to_single_uniform(x.v[0]);
167167
return (uint16_t)(__float_as_uint(f) >> 16);
168168
}
169169

170170
/* Uint4x32 to float16 uniform using CUDA half intrinsics */
171-
__device__ __half uint4x32_to_fp16_uniform(uint4x32_t x) {
171+
__device__ __half uint4x32_to_half_uniform(uint4x32_t x) {
172172
float f = uint32_to_single_uniform(x.v[0]);
173173
return __float2half(f);
174174
}

arrayjit/lib/arrayjit_builtins.msl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ float uint4x32_to_double_uniform(uint4x32_t x) {
150150
}
151151

152152
/* Uint4x32 to int32 uniform */
153-
int32_t uint4x32_to_i32_uniform(uint4x32_t x) {
153+
int32_t uint4x32_to_int32_uniform(uint4x32_t x) {
154154
return int32_t(x.v.x);
155155
}
156156

@@ -180,13 +180,13 @@ uint8_t uint4x32_to_u8_uniform(uint4x32_t x) {
180180
}
181181

182182
/* Uint4x32 to bfloat16 uniform */
183-
uint16_t uint4x32_to_bf16_uniform(uint4x32_t x) {
183+
uint16_t uint4x32_to_bfloat16_uniform(uint4x32_t x) {
184184
float f = uint32_to_single_uniform(x.v.x);
185185
return uint16_t(as_type<uint32_t>(f) >> 16);
186186
}
187187

188188
/* Uint4x32 to float16 uniform */
189-
half uint4x32_to_fp16_uniform(uint4x32_t x) {
189+
half uint4x32_to_half_uniform(uint4x32_t x) {
190190
float f = uint32_to_single_uniform(x.v.x);
191191
return half(f);
192192
}

arrayjit/lib/cuda_backend.ml

Lines changed: 135 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ module Device_config = struct
4242
let name = "cuda"
4343
end
4444

45-
module Device_stream = Backend_impl.Device_types (Device_config)
45+
module Device_stream = Backend_impl.Device_types_ll (Device_config)
4646
open Device_config
4747

4848
let set_ctx ctx = Cu.Context.set_current ctx
@@ -321,7 +321,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
321321
(string "hexp2(hlog2(" ^^ v1 ^^ string "),"
322322
^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
323323
^^ string ")")
324-
| ToPowOf, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Fp8_prec _) ->
324+
| ToPowOf, (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Fp8_prec _ | Uint4x32_prec _) ->
325325
invalid_arg "Cuda_backend.binop_syntax: ToPowOf not supported for integer precisions"
326326
| ToPowOf, Bfloat16_prec _ ->
327327
fun v1 v2 ->
@@ -350,6 +350,50 @@ end) : Ir.Backend_impl.Lowered_backend = struct
350350
(nest 2
351351
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
352352
^^ string "__float2bfloat16(0.0f)"))))
353+
| Relu_gate, Half_prec _ ->
354+
fun v1 v2 ->
355+
group
356+
(parens
357+
(group (parens (v1 ^^ string " > 0.0h"))
358+
^^ ifflat
359+
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
360+
^^ string "0.0h")
361+
(nest 2
362+
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
363+
^^ string "0.0h"))))
364+
| Relu_gate, Single_prec _ ->
365+
fun v1 v2 ->
366+
group
367+
(parens
368+
(group (parens (v1 ^^ string " > 0.0f"))
369+
^^ ifflat
370+
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
371+
^^ string "0.0f")
372+
(nest 2
373+
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
374+
^^ string "0.0f"))))
375+
| Relu_gate, Double_prec _ ->
376+
fun v1 v2 ->
377+
group
378+
(parens
379+
(group (parens (v1 ^^ string " > 0.0"))
380+
^^ ifflat
381+
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
382+
^^ string "0.0")
383+
(nest 2
384+
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
385+
^^ string "0.0"))))
386+
| Relu_gate, Uint4x32_prec _ ->
387+
fun v1 v2 ->
388+
group
389+
(parens
390+
(group (parens (v1 ^^ string " > 0"))
391+
^^ ifflat
392+
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
393+
^^ string "0")
394+
(nest 2
395+
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
396+
^^ string "0"))))
353397
| Satur01_gate, Byte_prec _ ->
354398
fun v1 v2 ->
355399
group
@@ -402,14 +446,97 @@ end) : Ir.Backend_impl.Lowered_backend = struct
402446
(nest 2
403447
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
404448
^^ string "0.0"))))
449+
| Satur01_gate, Uint16_prec _ ->
450+
fun v1 v2 ->
451+
group
452+
(parens
453+
(group
454+
(parens
455+
(string "(float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1
456+
^^ string " < 1.0f"))
457+
^^ ifflat
458+
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
459+
^^ string "(unsigned short)0")
460+
(nest 2
461+
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
462+
^^ string "(unsigned short)0"))))
463+
| Satur01_gate, Int32_prec _ ->
464+
fun v1 v2 ->
465+
group
466+
(parens
467+
(group
468+
(parens
469+
(string "(float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1
470+
^^ string " < 1.0f"))
471+
^^ ifflat
472+
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
473+
^^ string "0")
474+
(nest 2
475+
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
476+
^^ string "0"))))
477+
| Satur01_gate, Uint4x32_prec _ ->
478+
fun v1 v2 ->
479+
group
480+
(parens
481+
(group
482+
(parens
483+
(string "(float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1
484+
^^ string " < 1.0f"))
485+
^^ ifflat
486+
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
487+
^^ string "0u")
488+
(nest 2
489+
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
490+
^^ string "0u"))))
491+
| Satur01_gate, Bfloat16_prec _ ->
492+
fun v1 v2 ->
493+
group
494+
(parens
495+
(group
496+
(parens
497+
(string "__bfloat162float(" ^^ v1
498+
^^ string ") > 0.0f && __bfloat162float("
499+
^^ v1 ^^ string ") < 1.0f"))
500+
^^ ifflat
501+
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
502+
^^ string "__float2bfloat16(0.0f)")
503+
(nest 2
504+
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
505+
^^ string "__float2bfloat16(0.0f)"))))
506+
| Satur01_gate, Fp8_prec _ ->
507+
fun v1 v2 ->
508+
group
509+
(parens
510+
(group
511+
(parens
512+
(string "(float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1
513+
^^ string " < 1.0f"))
514+
^^ ifflat
515+
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
516+
^^ string "(unsigned char)0")
517+
(nest 2
518+
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
519+
^^ string "(unsigned char)0"))))
405520
| Max, Byte_prec _ -> func "max"
406521
| Max, Half_prec _ -> func "__hmax"
407522
| Max, Double_prec _ -> func "fmax"
408523
| Max, Single_prec _ -> func "fmaxf"
524+
| Max, Uint16_prec _ -> func "max"
525+
| Max, Int32_prec _ -> func "max"
526+
| Max, Uint4x32_prec _ -> func "max"
527+
| Max, Bfloat16_prec _ ->
528+
(* FIXME: unverified -- this reuses [__hmax], the same function name emitted for Half_prec;
   confirm it is correct for bfloat16 operands and fix if not, here and elsewhere. *)
529+
func "__hmax"
530+
| Max, Fp8_prec _ -> func "max"
409531
| Min, Byte_prec _ -> func "min"
410532
| Min, Half_prec _ -> func "__hmin"
411533
| Min, Double_prec _ -> func "fmin"
412534
| Min, Single_prec _ -> func "fminf"
535+
| Min, Uint16_prec _ -> func "min"
536+
| Min, Int32_prec _ -> func "min"
537+
| Min, Uint4x32_prec _ -> func "min"
538+
| Min, Bfloat16_prec _ -> func "__hmin"
539+
| Min, Fp8_prec _ -> func "min"
413540
| Mod, Byte_prec _ -> f "%"
414541
| Mod, _ -> func "fmod"
415542
| Cmplt, _ -> f "<"
@@ -480,17 +607,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
480607
| Tanh_approx, _ -> func "tanh"
481608
| Not, _ -> f "(" " == 0.0 ? 1.0 : 0.0)"
482609
| Uint4x32_to_prec_uniform, _ ->
483-
let conv_func = match prec with
484-
| Ops.Single_prec _ -> "uint4x32_to_single_uniform"
485-
| Double_prec _ -> "uint4x32_to_double_uniform"
486-
| Half_prec _ -> "uint4x32_to_fp16_uniform"
487-
| Bfloat16_prec _ -> "uint4x32_to_bf16_uniform"
488-
| Byte_prec _ -> "uint4x32_to_u8_uniform"
489-
| Uint16_prec _ -> "uint4x32_to_u32_uniform" (* Should probably be u16, but using u32 for 16-bit unsigned *)
490-
| Int32_prec _ -> "uint4x32_to_i32_uniform"
491-
| _ -> "/* unsupported conversion from uint4x32 */ 0"
492-
in
493-
func conv_func
610+
func ("uint4x32_to_" ^ Ops.prec_string prec ^ "_uniform")
494611

495612
let ternop_syntax prec v =
496613
let open PPrint in
@@ -564,12 +681,15 @@ end) : Ir.Backend_impl.Lowered_backend = struct
564681
let idx_params = Indexing.bound_symbols bindings in
565682
let b = Buffer.create 4096 in
566683
(* Read and prepend the CUDA builtins file *)
567-
let builtins_path = Stdlib.Filename.concat (Stdlib.Filename.dirname __FILE__) "arrayjit_builtins.cu" in
684+
let builtins_path =
685+
Stdlib.Filename.concat (Stdlib.Filename.dirname Stdlib.__FILE__) "arrayjit_builtins.cu"
686+
in
568687
(try
569688
let builtins_content = Stdio.In_channel.read_all builtins_path in
570689
Buffer.add_string b builtins_content;
571690
Buffer.add_string b "\n\n"
572-
with _ -> ()); (* Silently skip if file not found *)
691+
with _ -> ());
692+
(* Silently skip if file not found *)
573693
let declarations_doc = Syntax.print_declarations () in
574694
let params_and_docs =
575695
Array.map2_exn names lowereds

arrayjit/lib/metal_backend.ml

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -555,18 +555,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
555555
| Recip_sqrt, _ -> func_doc "rsqrt"
556556
| Tanh_approx, _ -> func_doc "tanh"
557557
| Not, _ -> fun v -> string "!" ^^ v
558-
| Uint4x32_to_prec_uniform, _ ->
559-
let conv_func = match prec with
560-
| Ops.Single_prec _ -> "uint4x32_to_single_uniform"
561-
| Double_prec _ -> "uint4x32_to_double_uniform" (* Metal doesn't support double, but function exists *)
562-
| Half_prec _ -> "uint4x32_to_fp16_uniform"
563-
| Bfloat16_prec _ -> "uint4x32_to_bf16_uniform"
564-
| Byte_prec _ -> "uint4x32_to_u8_uniform"
565-
| Uint16_prec _ -> "uint4x32_to_u32_uniform" (* Should probably be u16 *)
566-
| Int32_prec _ -> "uint4x32_to_i32_uniform"
567-
| _ -> "/* unsupported conversion from uint4x32 */ 0"
568-
in
569-
func_doc conv_func
558+
| Uint4x32_to_prec_uniform, _ -> func_doc ("uint4x32_to_" ^ Ops.prec_string prec ^ "_uniform")
570559
(* Logical not *)
571560

572561
let convert_precision ~from ~to_ =

0 commit comments

Comments
 (0)