11open Base
22module Tn = Tnode
33module Lazy = Utils. Lazy
4- module Cu = Cudajit
4+ module Cu = Cuda
55open Backend_intf
66
77let _get_local_debug_runtime = Utils. get_local_debug_runtime
@@ -99,7 +99,7 @@ module Fresh () = struct
9999
100100 let get_used_memory (device : device ) =
101101 set_ctx device.dev.primary_context;
102- let free, total = Cudajit .Device. get_free_and_total_mem () in
102+ let free, total = Cu .Device. get_free_and_total_mem () in
103103 total - free
104104
105105 let opt_alloc_merge_buffer ~size_in_bytes dev stream : unit =
@@ -213,7 +213,7 @@ module Fresh () = struct
213213
214214 type code = {
215215 traced_store : Low_level .traced_store ;
216- ptx : Cu. Nvrtc.compile_to_ptx_result ;
216+ ptx : Nvrtc .compile_to_ptx_result ;
217217 params : (string * param_source ) list ;
218218 bindings : Indexing .unit_bindings ;
219219 name : string ;
@@ -222,7 +222,7 @@ module Fresh () = struct
222222
223223 type code_batch = {
224224 traced_stores : Low_level .traced_store option array ;
225- ptx : Cu. Nvrtc.compile_to_ptx_result ;
225+ ptx : Nvrtc .compile_to_ptx_result ;
226226 bindings : Indexing .unit_bindings ;
227227 params_and_names : ((string * param_source ) list * string ) option array ;
228228 }
@@ -236,7 +236,6 @@ module Fresh () = struct
236236 Stdio.Out_channel. flush oc;
237237 Stdio.Out_channel. close oc);
238238 [% log " compiling to PTX" ];
239- let module Cu = Cudajit in
240239 let with_debug =
241240 Utils. settings.output_debug_files_in_build_directory || Utils. settings.log_level > 0
242241 in
@@ -245,15 +244,15 @@ module Fresh () = struct
245244 in
246245 (* FIXME: every now and then the compilation crashes because the options are garbled. *)
247246 (* Stdio.printf "PTX options %s\n%!" @@ String.concat ~sep:", " options; *)
248- let ptx = Cu. Nvrtc. compile_to_ptx ~cu_src ~name: name_cu ~options ~with_debug in
247+ let ptx = Nvrtc. compile_to_ptx ~cu_src ~name: name_cu ~options ~with_debug in
249248 if Utils. settings.output_debug_files_in_build_directory then (
250249 let oc = Out_channel. open_text @@ Utils. build_file @@ name ^ " .ptx" in
251- Stdio.Out_channel. output_string oc @@ Cu. Nvrtc. string_from_ptx ptx;
250+ Stdio.Out_channel. output_string oc @@ Nvrtc. string_from_ptx ptx;
252251 Stdio.Out_channel. flush oc;
253252 Stdio.Out_channel. close oc;
254253 let oc = Out_channel. open_text @@ Utils. build_file @@ name ^ " .cu_log" in
255254 Stdio.Out_channel. output_string oc
256- @@ Option. value_exn ~here: [% here] (Cu. Nvrtc. compilation_log ptx);
255+ @@ Option. value_exn ~here: [% here] (Nvrtc. compilation_log ptx);
257256 Stdio.Out_channel. flush oc;
258257 Stdio.Out_channel. close oc);
259258 ptx
@@ -452,7 +451,6 @@ module Fresh () = struct
452451
453452 let link_proc ~prior_context ~name ~(params : (string * param_source) list ) ~ctx_arrays
454453 lowered_bindings run_module =
455- let module Cu = Cudajit in
456454 let func = Cu.Module. get_function run_module ~name in
457455 let stream = prior_context.stream in
458456 let runner_label = get_name stream in
@@ -540,7 +538,6 @@ module Fresh () = struct
540538 let lowered_bindings : Indexing.lowered_bindings =
541539 List. map idx_params ~f: (fun s -> (s, ref 0 ))
542540 in
543- let module Cu = Cudajit in
544541 let ctx = ctx_of prior_context in
545542 set_ctx ctx;
546543 let run_module = Cu.Module. load_data_ex code_batch.ptx (run_options () ) in
@@ -557,10 +554,10 @@ module Fresh () = struct
557554
558555 let get_global_debug_info () =
559556 Sexp. message " cuda_global_debug"
560- [ (" live_streams" , [% sexp_of: int ] @@ Cudajit .Stream. get_total_live_streams () ) ]
557+ [ (" live_streams" , [% sexp_of: int ] @@ Cu .Stream. get_total_live_streams () ) ]
561558
562559 let get_debug_info (stream : stream ) =
563- let tot, unr, unf = Cudajit .Stream. total_unreleased_unfinished_delimited_events stream.runner in
560+ let tot, unr, unf = Cu .Stream. total_unreleased_unfinished_delimited_events stream.runner in
564561 let i2s = [% sexp_of: int ] in
565562 Sexp. message " cuda_stream_debug"
566563 [ (" total_events" , i2s tot); (" unreleased_events" , i2s unr); (" unfinished_events" , i2s unf) ]
0 commit comments