3535module Device_config = struct
3636 include Backend_buffer
3737
38- type dev = { dev : Cu.Device .t ; primary_context : Cu.Context .t } [@@ deriving sexp_of ]
(* Per-device record: the device handle plus the state that is created once per device. *)
type dev = {
  (* The underlying device handle. *)
  dev : Cu.Device.t;
  (* The context obtained for [dev]; it is made current before module loading. *)
  primary_context : Cu.Context.t;
  (* Applied to every freshly loaded kernel module before it is used. *)
  set_builtins_in : Cu.Module.t -> unit;
}
[@@deriving sexp_of]
44+
(* A runner is a stream of the underlying CUDA bindings. *)
type runner = Cu.Stream.t [@@deriving sexp_of]
(* An event is a delimited event of the underlying CUDA bindings. *)
type event = Cu.Delimited_event.t [@@deriving sexp_of]
4147
@@ -98,6 +104,8 @@ end) : Ir.Backend_impl.Lowered_backend = struct
98104 initialized := true )
99105
(* Number of devices, as reported by [Cu.Device.get_count] at the time of the call. *)
let num_devices () = Cu.Device.get_count ()
107+
(* One slot per device, all initially [None]. The array sits behind a [ref] so that it can be
   replaced as a whole, which is what supports plugging in new devices. *)
let devices = ref (Array.create ~len:(num_devices ()) None)
102110
103111 let get_used_memory (device : device ) =
@@ -121,6 +129,57 @@ end) : Ir.Backend_impl.Lowered_backend = struct
121129 Hashtbl. iter device.cross_stream_candidates ~f: (fun buffer_ptr ->
122130 Cu.Deviceptr. mem_free buffer_ptr)
123131
(* [cuda_to_ptx ~name cu_src] compiles the CUDA source text [cu_src] to PTX with
   [Nvrtc.compile_to_ptx], passing [name ^ ".cu"] as the program name, and returns the result.

   When [Utils.settings.output_debug_files_in_build_directory] is set, three build files are also
   written: the source ([.cu]) before compiling, then the PTX ([.ptx]) and the compilation log
   ([.cu_log]) afterwards. In that mode [Option.value_exn] raises if [Nvrtc.compilation_log]
   returns [None]. *)
let%diagn2_sexp cuda_to_ptx ~name cu_src =
  (* Writes [contents] to the build file [name ^ extension]. [with_file] closes the channel even
     if the write raises, so a failed dump does not leak the file descriptor. The file is opened
     in text mode, as before. *)
  let dump ~extension contents =
    Stdio.Out_channel.with_file ~binary:false
      (Utils.build_file (name ^ extension))
      ~f:(fun oc -> Stdio.Out_channel.output_string oc contents)
  in
  let name_cu = name ^ ".cu" in
  if Utils.settings.output_debug_files_in_build_directory then (
    let build_file = Utils.open_build_file ~base_name:name ~extension:".cu" in
    Stdio.Out_channel.output_string build_file.oc cu_src;
    build_file.finalize ());
  [%log "compiling to PTX"];
  (* Debug information is requested both when dumping build files and when logging is on. *)
  let with_debug =
    Utils.settings.output_debug_files_in_build_directory || Utils.settings.log_level > 0
  in
  let options =
    "--use_fast_math" :: (if Utils.with_runtime_debug () then [ "--device-debug" ] else [])
  in
  (* FIXME: every now and then the compilation crashes because the options are garbled. *)
  (* Stdio.printf "PTX options %s\n%!" @@ String.concat ~sep:", " options; *)
  let ptx = Nvrtc.compile_to_ptx ~cu_src ~name:name_cu ~options ~with_debug in
  if Utils.settings.output_debug_files_in_build_directory then (
    dump ~extension:".ptx" (Nvrtc.string_from_ptx ptx);
    dump ~extension:".cu_log" (Option.value_exn ~here:[%here] (Nvrtc.compilation_log ptx)));
  ptx
159+
(* Options passed to [Cu.Module.load_data_ex]: debug info and line info are requested only when
   [Utils.with_runtime_debug ()] holds; otherwise the option list is empty. *)
let run_options () =
  match Utils.with_runtime_debug () with
  | true -> Cu.Module.[ GENERATE_DEBUG_INFO true; GENERATE_LINE_INFO true ]
  | false -> []
164+
(* [set_ptr_in_kernel kernel_module src name] looks up the global called [name] in
   [kernel_module] and overwrites it with 8 bytes copied device-to-device from [src]. It is used
   to store a helper function's address into the kernel's function-pointer variable. *)
let set_ptr_in_kernel kernel_module src name =
  (* [get_global] returns a pair; only the device pointer (first component) is needed. *)
  let dst = fst (Cuda.Module.get_global kernel_module ~name) in
  let pointer_size_in_bytes = 8 in
  Cuda.Deviceptr.memcpy_D_to_D ~dst ~src ~size_in_bytes:pointer_size_in_bytes ()
169+
(* [set_builtins_for_device ~primary_context] returns a function that, given a loaded kernel
   module, copies the [arrayjit_threefry4x32] global of the builtins module into the global of
   the same name in that kernel module (via [set_ptr_in_kernel]).

   Staging matters here:
   - When this definition is evaluated (once), [initialized] is asserted, [builtins_large.cu] is
     read from the directory of this source file ([Stdlib.__FILE__]) and compiled to PTX.
   - Each time [~primary_context] is supplied, [set_ctx primary_context] is called and the PTX is
     loaded as a module, from which the builtin's global is looked up.
   - The resulting closure can then be applied to any number of kernel modules. *)
let set_builtins_for_device =
  assert !initialized;
  let threefry4x32_name = "arrayjit_threefry4x32" in
  let builtins_path =
    Stdlib.Filename.concat (Stdlib.Filename.dirname Stdlib.__FILE__) "builtins_large.cu"
  in
  let cu_src = Stdio.In_channel.read_all builtins_path in
  let code = cuda_to_ptx ~name:"builtins_large" cu_src in
  fun ~primary_context ->
    set_ctx primary_context;
    let run_module = Cu.Module.load_data_ex code (run_options ()) in
    let threefry4x32_ptr, _ = Cu.Module.get_global run_module ~name:threefry4x32_name in
    fun kernel_module -> set_ptr_in_kernel kernel_module threefry4x32_ptr threefry4x32_name
182+
124183 let % track3_sexp get_device ~(ordinal : int ) : device =
125184 if num_devices () < = ordinal then
126185 invalid_arg [% string " Exec_as_cuda.get_device %{ordinal#Int}: not enough devices" ];
@@ -130,7 +189,8 @@ end) : Ir.Backend_impl.Lowered_backend = struct
130189 let default () =
131190 let dev = Cu.Device. get ~ordinal in
132191 let primary_context : Cu.Context.t = Cu.Context. get_primary dev in
133- let dev = { dev; primary_context } in
192+ let set_builtins_in = set_builtins_for_device ~primary_context in
193+ let dev = { dev; primary_context; set_builtins_in } in
134194 set_ctx primary_context;
135195 if Utils. debug_log_from_routines () && not (Hash_set. mem initialized_devices ordinal) then
136196 Int. of_string_opt @@ Utils. get_global_arg ~arg_name: " cuda_printf_fifo_size" ~default: " "
@@ -229,34 +289,6 @@ end) : Ir.Backend_impl.Lowered_backend = struct
229289 }
230290 [@@ deriving sexp_of ]
231291
232- let % diagn2_sexp cuda_to_ptx ~name cu_src =
233- let name_cu = name ^ " .cu" in
234- if Utils. settings.output_debug_files_in_build_directory then (
235- let build_file = Utils. open_build_file ~base_name: name ~extension: " .cu" in
236- Stdio.Out_channel. output_string build_file.oc cu_src;
237- build_file.finalize () );
238- [% log " compiling to PTX" ];
239- let with_debug =
240- Utils. settings.output_debug_files_in_build_directory || Utils. settings.log_level > 0
241- in
242- let options =
243- " --use_fast_math" :: (if Utils. with_runtime_debug () then [ " --device-debug" ] else [] )
244- in
245- (* FIXME: every now and then the compilation crashes because the options are garbled. *)
246- (* Stdio.printf "PTX options %s\n%!" @@ String.concat ~sep:", " options; *)
247- let ptx = Nvrtc. compile_to_ptx ~cu_src ~name: name_cu ~options ~with_debug in
248- if Utils. settings.output_debug_files_in_build_directory then (
249- let oc = Out_channel. open_text @@ Utils. build_file @@ name ^ " .ptx" in
250- Stdio.Out_channel. output_string oc @@ Nvrtc. string_from_ptx ptx;
251- Stdio.Out_channel. flush oc;
252- Stdio.Out_channel. close oc;
253- let oc = Out_channel. open_text @@ Utils. build_file @@ name ^ " .cu_log" in
254- Stdio.Out_channel. output_string oc
255- @@ Option. value_exn ~here: [% here] (Nvrtc. compilation_log ptx);
256- Stdio.Out_channel. flush oc;
257- Stdio.Out_channel. close oc);
258- ptx
259-
260292 module Cuda_syntax_config (Input : sig
261293 val procs : Low_level .optimized array
262294 end ) =
@@ -789,34 +821,11 @@ end) : Ir.Backend_impl.Lowered_backend = struct
789821 work;
790822 }
791823
792- let run_options () =
793- if Utils. with_runtime_debug () then
794- Cu.Module. [ GENERATE_DEBUG_INFO true ; GENERATE_LINE_INFO true ]
795- else []
796-
797- let set_ptr_in_kernel kernel_module src name =
798- let dst, _ = Cuda.Module. get_global kernel_module ~name in
799- (* Copy the helper function address to the kernel's function pointer variable *)
800- Cuda.Deviceptr. memcpy_D_to_D ~dst ~src ~size_in_bytes: 8 (* pointer size *) ()
801-
802- let set_builtins_in_kernel =
803- assert ! initialized;
804- let builtins_path =
805- Stdlib.Filename. concat (Stdlib.Filename. dirname Stdlib. __FILE__) " builtins_large.cu"
806- in
807- let cu_src = Stdio.In_channel. read_all builtins_path in
808- let code = cuda_to_ptx ~name: " builtins_large" cu_src in
809- (* set_ctx ctx; *)
810- let run_module = Cu.Module. load_data_ex code (run_options () ) in
811- let threefry4x32_ptr, _ = Cu.Module. get_global run_module ~name: " arrayjit_threefry4x32" in
812- fun kernel_module ->
813- set_ptr_in_kernel kernel_module threefry4x32_ptr " arrayjit_threefry4x32"
814-
815824 let % track3_sexp link prior_context (code : code ) ctx_arrays =
816825 let ctx = ctx_of prior_context in
817826 set_ctx ctx;
818827 let run_module = Cu.Module. load_data_ex code.ptx (run_options () ) in
819- set_builtins_in_kernel run_module;
828+ prior_context.stream.device.dev.set_builtins_in run_module;
820829 let idx_params = Indexing. bound_symbols code.bindings in
821830 let lowered_bindings : Indexing.lowered_bindings =
822831 List. map idx_params ~f: (fun s -> (s, ref 0 ))
@@ -835,7 +844,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
835844 let ctx = ctx_of prior_context in
836845 set_ctx ctx;
837846 let run_module = Cu.Module. load_data_ex code_batch.ptx (run_options () ) in
838- set_builtins_in_kernel run_module;
847+ prior_context.stream.device.dev.set_builtins_in run_module;
839848 let procs =
840849 Array. mapi code_batch.params_and_names ~f: (fun i pns ->
841850 Option. value ~default: None
0 commit comments