Skip to content

Commit 0b7ae75

Browse files
committed
The CUDA backend is now a generative functor; Cu.init is called at module initialization
1 parent 2af41be commit 0b7ae75

File tree

10 files changed

+422
-424
lines changed

10 files changed

+422
-424
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
- Built per-tensor-node stream-to-stream synchronization into copying functions.
2424
- Re-introduced whole-device blocking synchronization, which now is just a slight optimization as it also cleans up event book-keeping.
2525
- Simplifications: no more explicit compilation postponing; no more hard-coded pointers (all non-local arrays are passed by parameter).
26+
- Fresh backends are now fresh modules, to structurally prevent any potential cache leakage.
2627

2728
### Fixed
2829

arrayjit/lib/backend_impl.ml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ struct
107107
{
108108
dev;
109109
ordinal;
110-
released = Atomic.make false;
111110
cross_stream_candidates = Hashtbl.create (module Tnode);
112111
owner_stream = Hashtbl.create (module Tnode);
113112
shared_writer_streams = Hashtbl.create (module Tnode);

arrayjit/lib/backend_intf.ml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ end
8080
type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
8181
dev : 'dev;
8282
ordinal : int;
83-
released : Utils.atomic_bool;
8483
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
8584
owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Hashtbl.M(Tnode).t;
8685
shared_writer_streams :
@@ -112,7 +111,6 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device =
112111
('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
113112
dev : 'dev;
114113
ordinal : int;
115-
released : Utils.atomic_bool;
116114
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
117115
(** Freshly created arrays that might be shared across streams. The map can both grow and
118116
shrink. *)

arrayjit/lib/backends.ml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,6 @@ let finalize (type buffer_ptr dev runner event)
462462
Option.iter Backend.free_buffer ~f:(fun mem_free ->
463463
if
464464
Atomic.compare_and_set ctx.finalized false true
465-
&& (not @@ Atomic.get ctx.stream.device.released)
466465
then (
467466
Backend.await ctx.stream;
468467
Map.iteri ctx.ctx_arrays ~f:(fun ~key ~data ->
@@ -475,6 +474,8 @@ let%track5_sexp fresh_backend ?backend_name () =
475474
Stdlib.Gc.full_major ();
476475
(* TODO: is running again needed to give time to weak arrays to become empty? *)
477476
Stdlib.Gc.full_major ();
477+
(* Note: we invoke functors from within fresh_backend to fully isolate backends from distinct
478+
calls to fresh_backend. *)
478479
match
479480
Option.value_or_thunk backend_name ~default:(fun () ->
480481
Utils.get_global_arg ~arg_name:"backend" ~default:"cc")
@@ -486,5 +487,5 @@ let%track5_sexp fresh_backend ?backend_name () =
486487
| "sync_cc" -> (module Make_device_backend_from_lowered (Schedulers.Sync) (Cc_backend) : Backend)
487488
| "sync_gccjit" ->
488489
(module Make_device_backend_from_lowered (Schedulers.Sync) (Gcc_backend) : Backend)
489-
| "cuda" -> (module Raise_backend ((Cuda_backend : Lowered_backend)) : Backend)
490+
| "cuda" -> (module Raise_backend ((Cuda_backend.Fresh () : Lowered_backend)) : Backend)
490491
| backend -> invalid_arg [%string "Backends.fresh_backend: unknown backend %{backend}"]

arrayjit/lib/backends.mli

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ val finalize :
1717
1818
Note: this type will get simpler with modular explicits. *)
1919

20-
val fresh_backend :
21-
?backend_name:string -> unit -> (module Backend_intf.Backend)
20+
val fresh_backend : ?backend_name:string -> unit -> (module Backend_intf.Backend)
2221
(** Creates a new backend corresponding to [backend_name], or if omitted, selected via the global
23-
[backend] setting. *)
22+
[backend] setting. It should be safe to call {!Tensor.unsafe_reinitialize} before
23+
[fresh_backend]. *)

0 commit comments

Comments
 (0)