Skip to content

Commit 655d5bb

Browse files
committed
Better names for builtin files, for CUDA handle large builtins via function pointers
Signed-off-by: Lukasz Stafiniak <lukstafi@gmail.com>
1 parent 13f3be5 commit 655d5bb

File tree

7 files changed

+117
-84
lines changed

7 files changed

+117
-84
lines changed
File renamed without changes.
File renamed without changes.

arrayjit/lib/arrayjit_builtins.cu renamed to arrayjit/lib/builtins_large.cu

Lines changed: 2 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ __device__ __forceinline__ void threefry_round(uint4 &x, unsigned int r0, unsign
4343
}
4444

4545
/* Threefry4x32 implementation - 20 rounds */
46-
__device__ uint4x32_t arrayjit_threefry4x32(uint4x32_t key, uint4x32_t counter) {
46+
__device__ uint4x32_t arrayjit_threefry4x32_impl(uint4x32_t key, uint4x32_t counter) {
4747
uint4 x = make_uint4(counter.v[0], counter.v[1], counter.v[2], counter.v[3]);
4848
uint4 k = make_uint4(key.v[0], key.v[1], key.v[2], key.v[3]);
4949

@@ -107,68 +107,4 @@ __device__ uint4x32_t arrayjit_threefry4x32(uint4x32_t key, uint4x32_t counter)
107107
return result;
108108
}
109109

110-
/* Conversion functions from uint4x32 to various precisions uniformly */
111-
112-
/* Convert to float in [0, 1) using CUDA intrinsics */
113-
__device__ __forceinline__ float uint32_to_single_uniform(uint32_t x) {
114-
/* Use __uint2float_rn for correct rounding */
115-
return __uint2float_rn(x >> 8) * (1.0f / 16777216.0f);
116-
}
117-
118-
/* Convert to double in [0, 1) */
119-
__device__ __forceinline__ double uint32_to_double_uniform(uint32_t x) {
120-
return __uint2double_rn(x) * (1.0 / 4294967296.0);
121-
}
122-
123-
/* Uint4x32 to float32 uniform */
124-
__device__ float uint4x32_to_single_uniform(uint4x32_t x) {
125-
return uint32_to_single_uniform(x.v[0]);
126-
}
127-
128-
/* Uint4x32 to float64 uniform */
129-
__device__ double uint4x32_to_double_uniform(uint4x32_t x) {
130-
uint64_t combined = __double_as_longlong(__hiloint2double(x.v[1], x.v[0]));
131-
return __longlong_as_double(combined) * (1.0 / 18446744073709551616.0);
132-
}
133-
134-
/* Uint4x32 to int32 uniform */
135-
__device__ int32_t uint4x32_to_int32_uniform(uint4x32_t x) {
136-
return (int32_t)x.v[0];
137-
}
138-
139-
/* Uint4x32 to int64 uniform */
140-
__device__ int64_t uint4x32_to_i64_uniform(uint4x32_t x) {
141-
return __double_as_longlong(__hiloint2double(x.v[1], x.v[0]));
142-
}
143-
144-
/* Uint4x32 to uint32 uniform */
145-
__device__ uint32_t uint4x32_to_u32_uniform(uint4x32_t x) {
146-
return x.v[0];
147-
}
148-
149-
/* Uint4x32 to uint64 uniform */
150-
__device__ uint64_t uint4x32_to_u64_uniform(uint4x32_t x) {
151-
return (uint64_t)__double_as_longlong(__hiloint2double(x.v[1], x.v[0]));
152-
}
153-
154-
/* Uint4x32 to int8 uniform */
155-
__device__ int8_t uint4x32_to_i8_uniform(uint4x32_t x) {
156-
return (int8_t)(x.v[0] & 0xFF);
157-
}
158-
159-
/* Uint4x32 to uint8 uniform */
160-
__device__ uint8_t uint4x32_to_u8_uniform(uint4x32_t x) {
161-
return (uint8_t)(x.v[0] & 0xFF);
162-
}
163-
164-
/* Uint4x32 to bfloat16 uniform */
165-
__device__ uint16_t uint4x32_to_bfloat16_uniform(uint4x32_t x) {
166-
float f = uint32_to_single_uniform(x.v[0]);
167-
return (uint16_t)(__float_as_uint(f) >> 16);
168-
}
169-
170-
/* Uint4x32 to float16 uniform using CUDA half intrinsics */
171-
__device__ __half uint4x32_to_half_uniform(uint4x32_t x) {
172-
float f = uint32_to_single_uniform(x.v[0]);
173-
return __float2half(f);
174-
}
110+
__device__ uint4x32_t ( *arrayjit_threefry4x32)(uint4x32_t key, uint4x32_t counter) = arrayjit_threefry4x32_impl;

arrayjit/lib/builtins_small.cu

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
2+
typedef struct {
3+
uint32_t v[4];
4+
} uint4x32_t;
5+
6+
/* Conversion functions from uint4x32 to various precisions uniformly */
7+
8+
/* Convert to float in [0, 1) using CUDA intrinsics */
9+
__device__ __forceinline__ float uint32_to_single_uniform(uint32_t x) {
10+
/* Use __uint2float_rn for correct rounding */
11+
return __uint2float_rn(x >> 8) * (1.0f / 16777216.0f);
12+
}
13+
14+
/* Convert to double in [0, 1) */
15+
__device__ __forceinline__ double uint32_to_double_uniform(uint32_t x) {
16+
return __uint2double_rn(x) * (1.0 / 4294967296.0);
17+
}
18+
19+
/* Uint4x32 to float32 uniform */
20+
__device__ float uint4x32_to_single_uniform(uint4x32_t x) {
21+
return uint32_to_single_uniform(x.v[0]);
22+
}
23+
24+
/* Uint4x32 to float64 uniform */
25+
__device__ double uint4x32_to_double_uniform(uint4x32_t x) {
26+
uint64_t combined = __double_as_longlong(__hiloint2double(x.v[1], x.v[0]));
27+
return __longlong_as_double(combined) * (1.0 / 18446744073709551616.0);
28+
}
29+
30+
/* Uint4x32 to int32 uniform */
31+
__device__ int32_t uint4x32_to_int32_uniform(uint4x32_t x) {
32+
return (int32_t)x.v[0];
33+
}
34+
35+
/* Uint4x32 to int64 uniform */
36+
__device__ int64_t uint4x32_to_i64_uniform(uint4x32_t x) {
37+
return __double_as_longlong(__hiloint2double(x.v[1], x.v[0]));
38+
}
39+
40+
/* Uint4x32 to uint32 uniform */
41+
__device__ uint32_t uint4x32_to_u32_uniform(uint4x32_t x) {
42+
return x.v[0];
43+
}
44+
45+
/* Uint4x32 to uint64 uniform */
46+
__device__ uint64_t uint4x32_to_u64_uniform(uint4x32_t x) {
47+
return (uint64_t)__double_as_longlong(__hiloint2double(x.v[1], x.v[0]));
48+
}
49+
50+
/* Uint4x32 to int8 uniform */
51+
__device__ int8_t uint4x32_to_i8_uniform(uint4x32_t x) {
52+
return (int8_t)(x.v[0] & 0xFF);
53+
}
54+
55+
/* Uint4x32 to uint8 uniform */
56+
__device__ uint8_t uint4x32_to_u8_uniform(uint4x32_t x) {
57+
return (uint8_t)(x.v[0] & 0xFF);
58+
}
59+
60+
/* Uint4x32 to bfloat16 uniform */
61+
__device__ uint16_t uint4x32_to_bfloat16_uniform(uint4x32_t x) {
62+
float f = uint32_to_single_uniform(x.v[0]);
63+
return (uint16_t)(__float_as_uint(f) >> 16);
64+
}
65+
66+
/* Uint4x32 to float16 uniform using CUDA half intrinsics */
67+
__device__ __half uint4x32_to_half_uniform(uint4x32_t x) {
68+
float f = uint32_to_single_uniform(x.v[0]);
69+
return __float2half(f);
70+
}

arrayjit/lib/cc_backend.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ open Backend_intf
1313
let name = "cc"
1414

1515
(* Header declarations for arrayjit builtins *)
16-
let arrayjit_builtins_header = {|
16+
let builtins_header = {|
1717
/* ArrayJIT builtins declarations */
1818
#include <stdint.h>
1919

@@ -184,7 +184,7 @@ let%diagn_sexp compile ~(name : string) bindings (lowered : Low_level.optimized)
184184
let build_file = Utils.open_build_file ~base_name:name ~extension:".c" in
185185
let declarations_doc = Syntax.print_declarations () in
186186
let params, proc_doc = Syntax.compile_proc ~name idx_params lowered in
187-
let header_doc = PPrint.string arrayjit_builtins_header in
187+
let header_doc = PPrint.string builtins_header in
188188
let final_doc = PPrint.(header_doc ^^ declarations_doc ^^ proc_doc) in
189189
(* Use ribbon = 1.0 for usual code formatting, width 110 *)
190190
PPrint.ToChannel.pretty 1.0 110 build_file.oc final_doc;
@@ -214,7 +214,7 @@ let%diagn_sexp compile_batch ~names bindings (lowereds : Low_level.optimized opt
214214
Syntax.compile_proc ~name idx_params lowered))
215215
in
216216
let all_proc_docs = List.filter_map (Array.to_list params_and_docs) ~f:(Option.map ~f:snd) in
217-
let header_doc = PPrint.string arrayjit_builtins_header in
217+
let header_doc = PPrint.string builtins_header in
218218
let final_doc = PPrint.(header_doc ^^ declarations_doc ^^ separate hardline all_proc_docs) in
219219
PPrint.ToChannel.pretty 1.0 110 build_file.oc final_doc;
220220
build_file.finalize ();

arrayjit/lib/cuda_backend.ml

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -606,8 +606,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
606606
| Tanh_approx, Single_prec _ -> func "__tanhf"
607607
| Tanh_approx, _ -> func "tanh"
608608
| Not, _ -> f "(" " == 0.0 ? 1.0 : 0.0)"
609-
| Uint4x32_to_prec_uniform, _ ->
610-
func ("uint4x32_to_" ^ Ops.prec_string prec ^ "_uniform")
609+
| Uint4x32_to_prec_uniform, _ -> func ("uint4x32_to_" ^ Ops.prec_string prec ^ "_uniform")
611610

612611
let ternop_syntax prec v =
613612
let open PPrint in
@@ -657,6 +656,24 @@ end) : Ir.Backend_impl.Lowered_backend = struct
657656
^^ rparen ^^ semi
658657
end
659658

659+
let builtins_large_header =
660+
{|
661+
__device__ uint4x32_t ( *arrayjit_threefry4x32)(uint4x32_t key, uint4x32_t counter) = nullptr;
662+
|}
663+
664+
let prepend_builtins b =
665+
if Utils.debug_log_from_routines () then
666+
Buffer.add_string b "__device__ int printf (const char * format, ... );\n";
667+
Buffer.add_string b "\n\n";
668+
let builtins_path =
669+
Stdlib.Filename.concat (Stdlib.Filename.dirname Stdlib.__FILE__) "builtins_small.cu"
670+
in
671+
let builtins_content = Stdio.In_channel.read_all builtins_path in
672+
Buffer.add_string b builtins_content;
673+
(* Needs to be after the small builtins, because uses uint4x32_t. *)
674+
Buffer.add_string b builtins_large_header;
675+
Buffer.add_string b "\n\n"
676+
660677
let%diagn2_sexp compile ~name bindings ({ Low_level.traced_store; _ } as lowered) =
661678
(* TODO: The following link seems to claim it's better to expand into loops than use memset.
662679
https://stackoverflow.com/questions/23712558/how-do-i-best-initialize-a-local-memory-array-to-0 *)
@@ -665,8 +682,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
665682
end)) in
666683
let idx_params = Indexing.bound_symbols bindings in
667684
let b = Buffer.create 4096 in
668-
if Utils.debug_log_from_routines () then
669-
Buffer.add_string b "__device__ int printf (const char * format, ... );\n";
685+
prepend_builtins b;
670686
let declarations_doc = Syntax.print_declarations () in
671687
let params, proc_doc = Syntax.compile_proc ~name idx_params lowered in
672688
let final_doc = PPrint.(declarations_doc ^^ proc_doc) in
@@ -680,16 +696,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
680696
end)) in
681697
let idx_params = Indexing.bound_symbols bindings in
682698
let b = Buffer.create 4096 in
683-
(* Read and prepend the CUDA builtins file *)
684-
let builtins_path =
685-
Stdlib.Filename.concat (Stdlib.Filename.dirname Stdlib.__FILE__) "arrayjit_builtins.cu"
686-
in
687-
(try
688-
let builtins_content = Stdio.In_channel.read_all builtins_path in
689-
Buffer.add_string b builtins_content;
690-
Buffer.add_string b "\n\n"
691-
with _ -> ());
692-
(* Silently skip if file not found *)
699+
prepend_builtins b;
693700
let declarations_doc = Syntax.print_declarations () in
694701
let params_and_docs =
695702
Array.map2_exn names lowereds
@@ -787,10 +794,29 @@ end) : Ir.Backend_impl.Lowered_backend = struct
787794
Cu.Module.[ GENERATE_DEBUG_INFO true; GENERATE_LINE_INFO true ]
788795
else []
789796

797+
let set_ptr_in_kernel kernel_module src name =
798+
let dst, _ = Cuda.Module.get_global kernel_module ~name in
799+
(* Copy the helper function address to the kernel's function pointer variable *)
800+
Cuda.Deviceptr.memcpy_D_to_D ~dst ~src ~size_in_bytes:8 (* pointer size *) ()
801+
802+
let set_builtins_in_kernel =
803+
assert !initialized;
804+
let builtins_path =
805+
Stdlib.Filename.concat (Stdlib.Filename.dirname Stdlib.__FILE__) "builtins_large.cu"
806+
in
807+
let cu_src = Stdio.In_channel.read_all builtins_path in
808+
let code = cuda_to_ptx ~name:"builtins_large" cu_src in
809+
(* set_ctx ctx; *)
810+
let run_module = Cu.Module.load_data_ex code (run_options ()) in
811+
let threefry4x32_ptr, _ = Cu.Module.get_global run_module ~name:"arrayjit_threefry4x32" in
812+
fun kernel_module ->
813+
set_ptr_in_kernel kernel_module threefry4x32_ptr "arrayjit_threefry4x32"
814+
790815
let%track3_sexp link prior_context (code : code) ctx_arrays =
791816
let ctx = ctx_of prior_context in
792817
set_ctx ctx;
793818
let run_module = Cu.Module.load_data_ex code.ptx (run_options ()) in
819+
set_builtins_in_kernel run_module;
794820
let idx_params = Indexing.bound_symbols code.bindings in
795821
let lowered_bindings : Indexing.lowered_bindings =
796822
List.map idx_params ~f:(fun s -> (s, ref 0))
@@ -809,6 +835,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
809835
let ctx = ctx_of prior_context in
810836
set_ctx ctx;
811837
let run_module = Cu.Module.load_data_ex code_batch.ptx (run_options ()) in
838+
set_builtins_in_kernel run_module;
812839
let procs =
813840
Array.mapi code_batch.params_and_names ~f:(fun i pns ->
814841
Option.value ~default:None

arrayjit/lib/dune

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
ppx_minidebug.runtime)
3636
(foreign_stubs
3737
(language c)
38-
(names arrayjit_builtins))
38+
(names builtins))
3939
(preprocess
4040
(pps
4141
ppx_compare

0 commit comments

Comments
 (0)