Skip to content

Commit b54b343

Browse files
committed
Untested: fix builtins modules across devices
Signed-off-by: Lukasz Stafiniak <lukstafi@gmail.com>
1 parent 655d5bb commit b54b343

File tree

1 file changed

+64
-55
lines changed

1 file changed

+64
-55
lines changed

arrayjit/lib/cuda_backend.ml

Lines changed: 64 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,13 @@ end
3535
module Device_config = struct
3636
include Backend_buffer
3737

38-
type dev = { dev : Cu.Device.t; primary_context : Cu.Context.t } [@@deriving sexp_of]
38+
(* Per-device record: the CUDA device handle, its primary context, and a
   per-device hook for installing the builtins into a loaded module. *)
type dev = {
39+
dev : Cu.Device.t;
40+
primary_context : Cu.Context.t;
41+
(* Applied to each freshly loaded kernel module by the linking code. It is
   built in [get_device] as [set_builtins_for_device ~primary_context], so
   each device gets a closure tied to its own primary context. *)
set_builtins_in : Cu.Module.t -> unit;
42+
}
43+
[@@deriving sexp_of]
44+
3945
type runner = Cu.Stream.t [@@deriving sexp_of]
4046
type event = Cu.Delimited_event.t [@@deriving sexp_of]
4147

@@ -98,6 +104,8 @@ end) : Ir.Backend_impl.Lowered_backend = struct
98104
initialized := true)
99105

100106
let num_devices = Cu.Device.get_count
107+
108+
(* [devices] is mutable to support plugging in new devices. *)
101109
let devices = ref @@ Array.create ~len:(num_devices ()) None
102110

103111
let get_used_memory (device : device) =
@@ -121,6 +129,57 @@ end) : Ir.Backend_impl.Lowered_backend = struct
121129
Hashtbl.iter device.cross_stream_candidates ~f:(fun buffer_ptr ->
122130
Cu.Deviceptr.mem_free buffer_ptr)
123131

132+
(* [cuda_to_ptx ~name cu_src] compiles the CUDA source text [cu_src] with
   [Nvrtc.compile_to_ptx], passing [name] with the .cu extension appended as
   the program name, and returns the compilation result.

   When [Utils.settings.output_debug_files_in_build_directory] is set, it also
   writes the source to a .cu build file before compiling and, after
   compiling, writes the PTX text to a .ptx file and the compilation log to a
   .cu_log file (both paths obtained via [Utils.build_file]).

   NOTE(review): [Option.value_exn] on the compilation log presumably raises
   if no log is available; confirm that a log is always produced when
   [with_debug] is true. *)
let%diagn2_sexp cuda_to_ptx ~name cu_src =
133+
let name_cu = name ^ ".cu" in
134+
if Utils.settings.output_debug_files_in_build_directory then (
135+
let build_file = Utils.open_build_file ~base_name:name ~extension:".cu" in
136+
Stdio.Out_channel.output_string build_file.oc cu_src;
137+
build_file.finalize ());
138+
[%log "compiling to PTX"];
139+
(* Debug mode is requested from the compiler whenever debug files are being
   written or the log level is positive. *)
let with_debug =
140+
Utils.settings.output_debug_files_in_build_directory || Utils.settings.log_level > 0
141+
in
142+
(* Fast math is always enabled; device debugging is added only when
   [Utils.with_runtime_debug ()] holds. *)
let options =
143+
"--use_fast_math" :: (if Utils.with_runtime_debug () then [ "--device-debug" ] else [])
144+
in
145+
(* FIXME: every now and then the compilation crashes because the options are garbled. *)
146+
(* Stdio.printf "PTX options %s\n%!" @@ String.concat ~sep:", " options; *)
147+
let ptx = Nvrtc.compile_to_ptx ~cu_src ~name:name_cu ~options ~with_debug in
148+
if Utils.settings.output_debug_files_in_build_directory then (
149+
let oc = Out_channel.open_text @@ Utils.build_file @@ name ^ ".ptx" in
150+
Stdio.Out_channel.output_string oc @@ Nvrtc.string_from_ptx ptx;
151+
Stdio.Out_channel.flush oc;
152+
Stdio.Out_channel.close oc;
153+
let oc = Out_channel.open_text @@ Utils.build_file @@ name ^ ".cu_log" in
154+
Stdio.Out_channel.output_string oc
155+
@@ Option.value_exn ~here:[%here] (Nvrtc.compilation_log ptx);
156+
Stdio.Out_channel.flush oc;
157+
Stdio.Out_channel.close oc);
158+
ptx
159+
160+
(* [run_options ()] is the option list given to [Cu.Module.load_data_ex] when
   loading compiled code: debug info and line info generation are requested
   when [Utils.with_runtime_debug ()] holds, otherwise the list is empty. *)
let run_options () =
161+
if Utils.with_runtime_debug () then
162+
Cu.Module.[ GENERATE_DEBUG_INFO true; GENERATE_LINE_INFO true ]
163+
else []
164+
165+
(* [set_ptr_in_kernel kernel_module src name] looks up the global called
   [name] in [kernel_module] and overwrites its first 8 bytes with the 8 bytes
   found at the device pointer [src] (a device-to-device copy). The second
   component returned by [get_global] is ignored.
   NOTE(review): the fixed size of 8 assumes 64-bit device pointers. *)
let set_ptr_in_kernel kernel_module src name =
166+
let dst, _ = Cuda.Module.get_global kernel_module ~name in
167+
(* Copy the helper function address to the kernel's function pointer variable *)
168+
Cuda.Deviceptr.memcpy_D_to_D ~dst ~src ~size_in_bytes:8 (* pointer size *) ()
169+
170+
(* [set_builtins_for_device ~primary_context] returns a function that installs
   the builtins into a given kernel module for the device owning
   [primary_context].

   The work is staged. This binding takes no arguments before the inner [fun],
   so the assertion, the read of builtins_large.cu and its compilation to PTX
   run once, when the binding itself is evaluated. Each application to
   [~primary_context] then makes that context current, loads the compiled
   builtins as a module, and looks up the global arrayjit_threefry4x32 in it.
   The resulting closure copies that pointer into the same-named global of
   each kernel module it is applied to.

   NOTE(review): the assertion requires [initialized] to already be true when
   this binding is evaluated; confirm against the initialization order. *)
let set_builtins_for_device =
171+
assert !initialized;
172+
(* The builtins source is located next to this source file, as named by
   [Stdlib.__FILE__]. NOTE(review): presumably this relies on the source
   path being valid at run time; verify against the build layout. *)
let builtins_path =
173+
Stdlib.Filename.concat (Stdlib.Filename.dirname Stdlib.__FILE__) "builtins_large.cu"
174+
in
175+
let cu_src = Stdio.In_channel.read_all builtins_path in
176+
let code = cuda_to_ptx ~name:"builtins_large" cu_src in
177+
fun ~primary_context ->
178+
(* Make the given context current before loading the builtins module. *)
set_ctx primary_context;
179+
let run_module = Cu.Module.load_data_ex code (run_options ()) in
180+
let threefry4x32_ptr, _ = Cu.Module.get_global run_module ~name:"arrayjit_threefry4x32" in
181+
fun kernel_module -> set_ptr_in_kernel kernel_module threefry4x32_ptr "arrayjit_threefry4x32"
182+
124183
let%track3_sexp get_device ~(ordinal : int) : device =
125184
if num_devices () <= ordinal then
126185
invalid_arg [%string "Exec_as_cuda.get_device %{ordinal#Int}: not enough devices"];
@@ -130,7 +189,8 @@ end) : Ir.Backend_impl.Lowered_backend = struct
130189
let default () =
131190
let dev = Cu.Device.get ~ordinal in
132191
let primary_context : Cu.Context.t = Cu.Context.get_primary dev in
133-
let dev = { dev; primary_context } in
192+
let set_builtins_in = set_builtins_for_device ~primary_context in
193+
let dev = { dev; primary_context; set_builtins_in } in
134194
set_ctx primary_context;
135195
if Utils.debug_log_from_routines () && not (Hash_set.mem initialized_devices ordinal) then
136196
Int.of_string_opt @@ Utils.get_global_arg ~arg_name:"cuda_printf_fifo_size" ~default:""
@@ -229,34 +289,6 @@ end) : Ir.Backend_impl.Lowered_backend = struct
229289
}
230290
[@@deriving sexp_of]
231291

232-
let%diagn2_sexp cuda_to_ptx ~name cu_src =
233-
let name_cu = name ^ ".cu" in
234-
if Utils.settings.output_debug_files_in_build_directory then (
235-
let build_file = Utils.open_build_file ~base_name:name ~extension:".cu" in
236-
Stdio.Out_channel.output_string build_file.oc cu_src;
237-
build_file.finalize ());
238-
[%log "compiling to PTX"];
239-
let with_debug =
240-
Utils.settings.output_debug_files_in_build_directory || Utils.settings.log_level > 0
241-
in
242-
let options =
243-
"--use_fast_math" :: (if Utils.with_runtime_debug () then [ "--device-debug" ] else [])
244-
in
245-
(* FIXME: every now and then the compilation crashes because the options are garbled. *)
246-
(* Stdio.printf "PTX options %s\n%!" @@ String.concat ~sep:", " options; *)
247-
let ptx = Nvrtc.compile_to_ptx ~cu_src ~name:name_cu ~options ~with_debug in
248-
if Utils.settings.output_debug_files_in_build_directory then (
249-
let oc = Out_channel.open_text @@ Utils.build_file @@ name ^ ".ptx" in
250-
Stdio.Out_channel.output_string oc @@ Nvrtc.string_from_ptx ptx;
251-
Stdio.Out_channel.flush oc;
252-
Stdio.Out_channel.close oc;
253-
let oc = Out_channel.open_text @@ Utils.build_file @@ name ^ ".cu_log" in
254-
Stdio.Out_channel.output_string oc
255-
@@ Option.value_exn ~here:[%here] (Nvrtc.compilation_log ptx);
256-
Stdio.Out_channel.flush oc;
257-
Stdio.Out_channel.close oc);
258-
ptx
259-
260292
module Cuda_syntax_config (Input : sig
261293
val procs : Low_level.optimized array
262294
end) =
@@ -789,34 +821,11 @@ end) : Ir.Backend_impl.Lowered_backend = struct
789821
work;
790822
}
791823

792-
let run_options () =
793-
if Utils.with_runtime_debug () then
794-
Cu.Module.[ GENERATE_DEBUG_INFO true; GENERATE_LINE_INFO true ]
795-
else []
796-
797-
let set_ptr_in_kernel kernel_module src name =
798-
let dst, _ = Cuda.Module.get_global kernel_module ~name in
799-
(* Copy the helper function address to the kernel's function pointer variable *)
800-
Cuda.Deviceptr.memcpy_D_to_D ~dst ~src ~size_in_bytes:8 (* pointer size *) ()
801-
802-
let set_builtins_in_kernel =
803-
assert !initialized;
804-
let builtins_path =
805-
Stdlib.Filename.concat (Stdlib.Filename.dirname Stdlib.__FILE__) "builtins_large.cu"
806-
in
807-
let cu_src = Stdio.In_channel.read_all builtins_path in
808-
let code = cuda_to_ptx ~name:"builtins_large" cu_src in
809-
(* set_ctx ctx; *)
810-
let run_module = Cu.Module.load_data_ex code (run_options ()) in
811-
let threefry4x32_ptr, _ = Cu.Module.get_global run_module ~name:"arrayjit_threefry4x32" in
812-
fun kernel_module ->
813-
set_ptr_in_kernel kernel_module threefry4x32_ptr "arrayjit_threefry4x32"
814-
815824
let%track3_sexp link prior_context (code : code) ctx_arrays =
816825
let ctx = ctx_of prior_context in
817826
set_ctx ctx;
818827
let run_module = Cu.Module.load_data_ex code.ptx (run_options ()) in
819-
set_builtins_in_kernel run_module;
828+
prior_context.stream.device.dev.set_builtins_in run_module;
820829
let idx_params = Indexing.bound_symbols code.bindings in
821830
let lowered_bindings : Indexing.lowered_bindings =
822831
List.map idx_params ~f:(fun s -> (s, ref 0))
@@ -835,7 +844,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
835844
let ctx = ctx_of prior_context in
836845
set_ctx ctx;
837846
let run_module = Cu.Module.load_data_ex code_batch.ptx (run_options ()) in
838-
set_builtins_in_kernel run_module;
847+
prior_context.stream.device.dev.set_builtins_in run_module;
839848
let procs =
840849
Array.mapi code_batch.params_and_names ~f:(fun i pns ->
841850
Option.value ~default:None

0 commit comments

Comments
 (0)