Skip to content

Commit 330f855

Browse files
committed
Update to cudajit.0.7.0
Signed-off-by: Lukasz Stafiniak <lukstafi@gmail.com>
1 parent 463c570 commit 330f855

File tree

4 files changed

+14
-17
lines changed

4 files changed

+14
-17
lines changed

arrayjit.opam

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ depends: [
3636
"odoc" {with-doc}
3737
]
3838
depopts: [
39-
"cudajit" {>= "0.6.2"}
39+
"cudajit" {>= "0.7.0"}
4040
"gccjit" {>= "0.3.2"}
4141
]
4242
conflicts: [
43-
"cudajit" {< "0.6.2"}
43+
"cudajit" {< "0.7.0"}
4444
"gccjit" {< "0.3.2"}
4545
]
4646
build: [

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
open Base
22
module Tn = Tnode
33
module Lazy = Utils.Lazy
4-
module Cu = Cudajit
4+
module Cu = Cuda
55
open Backend_intf
66

77
let _get_local_debug_runtime = Utils.get_local_debug_runtime
@@ -99,7 +99,7 @@ module Fresh () = struct
9999

100100
let get_used_memory (device : device) =
101101
set_ctx device.dev.primary_context;
102-
let free, total = Cudajit.Device.get_free_and_total_mem () in
102+
let free, total = Cu.Device.get_free_and_total_mem () in
103103
total - free
104104

105105
let opt_alloc_merge_buffer ~size_in_bytes dev stream : unit =
@@ -213,7 +213,7 @@ module Fresh () = struct
213213

214214
type code = {
215215
traced_store : Low_level.traced_store;
216-
ptx : Cu.Nvrtc.compile_to_ptx_result;
216+
ptx : Nvrtc.compile_to_ptx_result;
217217
params : (string * param_source) list;
218218
bindings : Indexing.unit_bindings;
219219
name : string;
@@ -222,7 +222,7 @@ module Fresh () = struct
222222

223223
type code_batch = {
224224
traced_stores : Low_level.traced_store option array;
225-
ptx : Cu.Nvrtc.compile_to_ptx_result;
225+
ptx : Nvrtc.compile_to_ptx_result;
226226
bindings : Indexing.unit_bindings;
227227
params_and_names : ((string * param_source) list * string) option array;
228228
}
@@ -236,7 +236,6 @@ module Fresh () = struct
236236
Stdio.Out_channel.flush oc;
237237
Stdio.Out_channel.close oc);
238238
[%log "compiling to PTX"];
239-
let module Cu = Cudajit in
240239
let with_debug =
241240
Utils.settings.output_debug_files_in_build_directory || Utils.settings.log_level > 0
242241
in
@@ -245,15 +244,15 @@ module Fresh () = struct
245244
in
246245
(* FIXME: every now and then the compilation crashes because the options are garbled. *)
247246
(* Stdio.printf "PTX options %s\n%!" @@ String.concat ~sep:", " options; *)
248-
let ptx = Cu.Nvrtc.compile_to_ptx ~cu_src ~name:name_cu ~options ~with_debug in
247+
let ptx = Nvrtc.compile_to_ptx ~cu_src ~name:name_cu ~options ~with_debug in
249248
if Utils.settings.output_debug_files_in_build_directory then (
250249
let oc = Out_channel.open_text @@ Utils.build_file @@ name ^ ".ptx" in
251-
Stdio.Out_channel.output_string oc @@ Cu.Nvrtc.string_from_ptx ptx;
250+
Stdio.Out_channel.output_string oc @@ Nvrtc.string_from_ptx ptx;
252251
Stdio.Out_channel.flush oc;
253252
Stdio.Out_channel.close oc;
254253
let oc = Out_channel.open_text @@ Utils.build_file @@ name ^ ".cu_log" in
255254
Stdio.Out_channel.output_string oc
256-
@@ Option.value_exn ~here:[%here] (Cu.Nvrtc.compilation_log ptx);
255+
@@ Option.value_exn ~here:[%here] (Nvrtc.compilation_log ptx);
257256
Stdio.Out_channel.flush oc;
258257
Stdio.Out_channel.close oc);
259258
ptx
@@ -452,7 +451,6 @@ module Fresh () = struct
452451

453452
let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx_arrays
454453
lowered_bindings run_module =
455-
let module Cu = Cudajit in
456454
let func = Cu.Module.get_function run_module ~name in
457455
let stream = prior_context.stream in
458456
let runner_label = get_name stream in
@@ -540,7 +538,6 @@ module Fresh () = struct
540538
let lowered_bindings : Indexing.lowered_bindings =
541539
List.map idx_params ~f:(fun s -> (s, ref 0))
542540
in
543-
let module Cu = Cudajit in
544541
let ctx = ctx_of prior_context in
545542
set_ctx ctx;
546543
let run_module = Cu.Module.load_data_ex code_batch.ptx (run_options ()) in
@@ -557,10 +554,10 @@ module Fresh () = struct
557554

558555
let get_global_debug_info () =
559556
Sexp.message "cuda_global_debug"
560-
[ ("live_streams", [%sexp_of: int] @@ Cudajit.Stream.get_total_live_streams ()) ]
557+
[ ("live_streams", [%sexp_of: int] @@ Cu.Stream.get_total_live_streams ()) ]
561558

562559
let get_debug_info (stream : stream) =
563-
let tot, unr, unf = Cudajit.Stream.total_unreleased_unfinished_delimited_events stream.runner in
560+
let tot, unr, unf = Cu.Stream.total_unreleased_unfinished_delimited_events stream.runner in
564561
let i2s = [%sexp_of: int] in
565562
Sexp.message "cuda_stream_debug"
566563
[ ("total_events", i2s tot); ("unreleased_events", i2s unr); ("unfinished_events", i2s unf) ]

arrayjit/lib/dune

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
(select
2222
cuda_backend.ml
2323
from
24-
(cudajit -> cuda_backend.cudajit.ml)
24+
(cudajit.cuda -> cuda_backend.cudajit.ml)
2525
(-> cuda_backend.missing.ml))
2626
ppx_minidebug.runtime)
2727
(preprocess

dune-project

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,12 @@
6262
(>= 2.2.0)))
6363
(depopts
6464
(cudajit
65-
(>= 0.6.2))
65+
(>= 0.7.0))
6666
(gccjit
6767
(>= 0.3.2)))
6868
(conflicts
6969
(cudajit
70-
(< 0.6.2))
70+
(< 0.7.0))
7171
(gccjit
7272
(< 0.3.2)))
7373
(tags

0 commit comments

Comments
 (0)