Skip to content

Commit 9540fe7

Browse files
committed
Fix: cuda: keep track of virtual device ids (aka subordinals)
1 parent 56f3e7b commit 9540fe7

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type physical_device = {
2323
primary_context : (Cudajit.context[@sexp.opaque]);
2424
mutable copy_merge_buffer : buffer_ptr;
2525
mutable copy_merge_buffer_capacity : int;
26+
mutable latest_subordinal : int;
2627
released : Utils.atomic_bool;
2728
}
2829
[@@deriving sexp_of]
@@ -109,6 +110,7 @@ let get_device ~(ordinal : int) : physical_device =
109110
{
110111
dev;
111112
ordinal;
113+
latest_subordinal = 0;
112114
primary_context;
113115
copy_merge_buffer;
114116
copy_merge_buffer_capacity;
@@ -119,7 +121,8 @@ let get_device ~(ordinal : int) : physical_device =
119121
result)
120122

121123
let new_virtual_device physical =
122-
let subordinal = 0 in
124+
let subordinal = physical.latest_subordinal in
125+
physical.latest_subordinal <- physical.latest_subordinal + 1;
123126
(* Strange that we need ctx_set_current even with a single device! *)
124127
set_ctx physical.primary_context;
125128
let stream = Cudajit.stream_create ~non_blocking:true () in

0 commit comments

Comments
 (0)