Skip to content

Commit 387acd3

Browse files
committed
cuda backend: Fix: unsafe_cleanup was working with a destroyed context / finalized device
1 parent 58b4a60 commit 387acd3

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 18 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -111,10 +111,7 @@ let devices = ref @@ Core.Weak.create 0
111111

112112
(* Unlike [devices] above, [initialized_devices] never forgets its entries. *)
113113
let initialized_devices = Hash_set.create (module Int)
114-
115-
let set_ctx ctx =
116-
let cur_ctx = Cu.Context.get_current () in
117-
if not @@ phys_equal ctx cur_ctx then Cu.Context.set_current ctx
114+
let set_ctx ctx = Cu.Context.set_current ctx
118115

119116
(* It's not actually used, but it's required by the [Backend] interface. *)
120117
let alloc_buffer ?old_buffer ~size_in_bytes device =
@@ -135,6 +132,17 @@ let opt_alloc_merge_buffer ~size_in_bytes phys_dev =
135132
phys_dev.copy_merge_buffer <- Cu.Deviceptr.mem_alloc ~size_in_bytes;
136133
phys_dev.copy_merge_buffer_capacity <- size_in_bytes)
137134

135+
let cleanup_physical device =
136+
Cu.Context.set_current device.primary_context;
137+
Cu.Context.synchronize ();
138+
Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ());
139+
Cu.Deviceptr.mem_free device.copy_merge_buffer;
140+
Hashtbl.iter device.cross_device_candidates ~f:(fun ctx_array ->
141+
Cu.Deviceptr.mem_free ctx_array.ptr)
142+
143+
let finalize_physical device =
144+
if Atomic.compare_and_set device.released false true then cleanup_physical device
145+
138146
let get_device ~(ordinal : int) : physical_device =
139147
if num_physical_devices () <= ordinal then
140148
invalid_arg [%string "Exec_as_cuda.get_device %{ordinal#Int}: not enough devices"];
@@ -166,6 +174,7 @@ let get_device ~(ordinal : int) : physical_device =
166174
owner_device_subordinal = Hashtbl.create (module Tn);
167175
}
168176
in
177+
Stdlib.Gc.finalise finalize_physical result;
169178
Core.Weak.set !devices ordinal (Some result);
170179
result)
171180

@@ -243,18 +252,14 @@ let init device =
243252
let unsafe_cleanup () =
244253
let len = Core.Weak.length !devices in
245254
(* NOTE: releasing the context should free its resources, there's no need to finalize the
246-
remaining contexts, and [finalize] will not do anything for a [released] physical device. *)
255+
remaining contexts, and [finalize], [finalize_physical] will not do anything for a [released]
256+
physical device. *)
247257
for i = 0 to len - 1 do
248258
Option.iter (Core.Weak.get !devices i) ~f:(fun device ->
249-
if Atomic.compare_and_set device.released false true then (
250-
Cu.Context.set_current device.primary_context;
251-
Cu.Context.synchronize ();
252-
Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ());
253-
Hashtbl.iter device.cross_device_candidates ~f:(fun ctx_array ->
254-
Cu.Deviceptr.mem_free ctx_array.ptr);
255-
Cu.Device.primary_ctx_release device.dev))
259+
if Atomic.compare_and_set device.released false true then cleanup_physical device)
256260
done;
257-
Core.Weak.fill !devices 0 len None
261+
Core.Weak.fill !devices 0 len None;
262+
Stdlib.Gc.compact ()
258263

259264
let%diagn_l_sexp from_host (ctx : context) tn =
260265
match (tn, Map.find ctx.ctx_arrays tn) with

0 commit comments

Comments
 (0)