@@ -111,10 +111,7 @@ let devices = ref @@ Core.Weak.create 0
111111
112112(* Unlike [devices] above, [initialized_devices] never forgets its entries. *)
113113let initialized_devices = Hash_set. create (module Int )
114-
115- let set_ctx ctx =
116- let cur_ctx = Cu.Context. get_current () in
117- if not @@ phys_equal ctx cur_ctx then Cu.Context. set_current ctx
(** [set_ctx ctx] makes [ctx] the current context by delegating to [Cu.Context.set_current]. The
    call is issued unconditionally; no comparison against the currently active context is made. *)
let set_ctx target_ctx = Cu.Context.set_current target_ctx
118115
119116(* It's not actually used, but it's required by the [Backend] interface. *)
120117let alloc_buffer ?old_buffer ~size_in_bytes device =
@@ -135,6 +132,17 @@ let opt_alloc_merge_buffer ~size_in_bytes phys_dev =
135132 phys_dev.copy_merge_buffer < - Cu.Deviceptr. mem_alloc ~size_in_bytes ;
136133 phys_dev.copy_merge_buffer_capacity < - size_in_bytes)
137134
(** [cleanup_physical device] tears down the resources held by [device]: it makes the device's
    primary context current, synchronizes it, runs the [Utils.advance_captured_logs] callback if
    one is installed, and then frees [device.copy_merge_buffer] and the [ptr] of every entry in
    [device.cross_device_candidates].

    This function neither reads nor updates [device.released]; guarding against repeated cleanup
    is left to the caller. *)
let cleanup_physical device =
  Cu.Context.set_current device.primary_context;
  Cu.Context.synchronize ();
  (* The log hook is optional: run it only when one has been registered. *)
  (match !Utils.advance_captured_logs with
  | Some advance_logs -> advance_logs ()
  | None -> ());
  Cu.Deviceptr.mem_free device.copy_merge_buffer;
  Hashtbl.iter device.cross_device_candidates ~f:(fun candidate ->
      Cu.Deviceptr.mem_free candidate.ptr)
142+
(** [finalize_physical device] runs [cleanup_physical device] at most once per device. The
    [device.released] flag is atomically flipped from [false] to [true]; only the call that
    performs the flip proceeds with the cleanup, and every other call is a no-op. *)
let finalize_physical device =
  let claimed = Atomic.compare_and_set device.released false true in
  if claimed then cleanup_physical device
145+
138146let get_device ~(ordinal : int ) : physical_device =
139147 if num_physical_devices () < = ordinal then
140148 invalid_arg [% string " Exec_as_cuda.get_device %{ordinal#Int}: not enough devices" ];
@@ -166,6 +174,7 @@ let get_device ~(ordinal : int) : physical_device =
166174 owner_device_subordinal = Hashtbl. create (module Tn );
167175 }
168176 in
177+ Stdlib.Gc. finalise finalize_physical result;
169178 Core.Weak. set ! devices ordinal (Some result);
170179 result)
171180
@@ -243,18 +252,14 @@ let init device =
(** [unsafe_cleanup ()] cleans up every physical device that is still present in the weak
    [devices] table, then clears all slots of the table and compacts the heap.

    Each device is cleaned up at most once: the work is delegated to [finalize_physical], which
    only proceeds when it is the one to flip the device's [released] flag. *)
let unsafe_cleanup () =
  let len = Core.Weak.length !devices in
  (* NOTE: releasing the context should free its resources, so there's no need to finalize the
     remaining contexts; [finalize] and [finalize_physical] will not do anything for a [released]
     physical device. *)
  for i = 0 to len - 1 do
    (* Slots whose device is no longer held by the weak table are skipped; the [released] guard
       lives in [finalize_physical] rather than being repeated here. *)
    Option.iter (Core.Weak.get !devices i) ~f:finalize_physical
  done;
  Core.Weak.fill !devices 0 len None;
  (* NOTE(review): presumably compaction is requested so that pending finalisers run promptly --
     confirm. *)
  Stdlib.Gc.compact ()
258263
259264let % diagn_l_sexp from_host (ctx : context ) tn =
260265 match (tn, Map. find ctx.ctx_arrays tn) with
0 commit comments