@@ -42,13 +42,7 @@ type device = {
4242}
4343[@@ deriving sexp_of ]
4444
(* Stream type of this backend.
   The removed lines declared a local record with the fields [device],
   [cu_stream], [unique_name] and a mutable [merge_buffer]. The added line
   instead instantiates a five-parameter [stream] type whose definition is
   outside this view. [nonrec] makes the [stream] on the right-hand side refer
   to that earlier definition rather than to the alias being declared here.
   NOTE(review): judging by the [make_stream ~state:() ~runner:cu_stream] call
   in [new_stream], the [unit] argument is presumably the per-stream state and
   [Cu.Stream.t] the runner handle -- confirm against the generic definition. *)
45- type stream = {
46- device : device ;
47- cu_stream : Cu.Stream .t ;
48- unique_name : string ;
49- mutable merge_buffer : (buffer_ptr * Tn .t ) option ;
50- }
51- [@@ deriving sexp_of ]
45+ type nonrec stream = (buffer_ptr , event , device , unit , Cu.Stream .t ) stream [@@ deriving sexp_of ]
5246
5347type context = {
5448 label : string ;
@@ -68,10 +62,10 @@ type context = {
(* Accessor for the [ctx_arrays] field of a context. *)
6862let ctx_arrays ctx = ctx.ctx_arrays
(* Mutable global configuration cell, initialised to [For_parallel_copying]. *)
6963let global_config = ref For_parallel_copying
(* [is_done event] forwards [event] to [Cu.Delimited_event.query]. *)
7064let is_done event = Cu.Delimited_event. query event
(* [will_wait_for context event] passes the stream handle of [context] and
   [event] to [Cu.Delimited_event.wait]. This change renames the handle field
   from [cu_stream] to [runner]; the call itself is unchanged. *)
71- let will_wait_for context event = Cu.Delimited_event. wait context.stream.cu_stream event
65+ let will_wait_for context event = Cu.Delimited_event. wait context.stream.runner event
(* [sync event] forwards [event] to [Cu.Delimited_event.synchronize]. *)
7266let sync event = Cu.Delimited_event. synchronize event
(* [all_work stream] applies [Cu.Delimited_event.record] to the stream handle
   ([cu_stream] before this change, [runner] after).
   [scheduled_merge_node stream] returns the second component of the
   [merge_buffer] pair when one is set, and [None] otherwise. After this change
   [merge_buffer] is read through [!], i.e. it is a reference cell instead of a
   mutable record field. *)
73- let all_work stream = Cu.Delimited_event. record stream.cu_stream
74- let scheduled_merge_node stream = Option. map ~f: snd stream.merge_buffer
67+ let all_work stream = Cu.Delimited_event. record stream.runner
68+ let scheduled_merge_node stream = Option. map ~f: snd ! ( stream.merge_buffer)
7569
7670let is_initialized, initialize =
7771 let initialized = ref false in
@@ -172,7 +166,7 @@ let%track3_sexp new_stream (device : device) : stream =
172166 (* Strange that we need ctx_set_current even with a single device! *)
173167 set_ctx device.primary_context;
174168 let cu_stream = Cu.Stream. create ~non_blocking: true () in
175- { device; cu_stream; unique_name; merge_buffer = None }
169+ make_stream ~ device ~state: () ~ unique_name ~runner: cu_stream
176170
177171let cuda_properties =
178172 let cache =
@@ -199,10 +193,10 @@ let get_name stream = stream.unique_name
199193
(* [await stream] runs three steps in order:
   1. calls [set_ctx] on the primary context of the stream device;
   2. calls [Cu.Stream.synchronize] on the stream handle ([cu_stream] before
      this change, [runner] after);
   3. if [Utils.advance_captured_logs] currently holds a callback, invokes it
      with [()].
   NOTE(review): step 2 presumably blocks until the work queued on the stream
   has finished -- confirm against the [Cu.Stream] binding. *)
200194let await stream : unit =
201195 set_ctx stream.device.primary_context;
202- Cu.Stream. synchronize stream.cu_stream ;
196+ Cu.Stream. synchronize stream.runner ;
203197 Option. iter ! Utils. advance_captured_logs ~f: (fun callback -> callback () )
204198
(* [is_idle stream] returns the result of [Cu.Stream.is_ready] on the stream
   handle. Only the field name changes here ([cu_stream] becomes [runner]).
   Unlike [await], this does not call [set_ctx] first. *)
205- let is_idle stream = Cu.Stream. is_ready stream.cu_stream
199+ let is_idle stream = Cu.Stream. is_ready stream.runner
206200
207201let % track3_sexp finalize (ctx : context ) : unit =
208202 if
@@ -235,23 +229,23 @@ let init stream =
235229
(* [from_host ~dst_ptr ~dst hosted] copies a host array to device memory.
   It calls [set_ctx] on [dst.ctx], then applies [Cu.Stream.memcpy_H_to_D] with
   destination [dst_ptr] on the stream handle of [dst] ([cu_stream] before this
   change, [runner] after) to [hosted] via [Ndarray.map].
   NOTE(review): [f] is wrapped in a record, presumably so that [Ndarray.map]
   can apply it to whichever element type [hosted] carries -- confirm against
   the [Ndarray] module. *)
236230let from_host ~dst_ptr ~dst hosted =
237231 set_ctx dst.ctx;
238- let f src = Cu.Stream. memcpy_H_to_D ~dst: dst_ptr ~src dst.stream.cu_stream in
232+ let f src = Cu.Stream. memcpy_H_to_D ~dst: dst_ptr ~src dst.stream.runner in
239233 Ndarray. map { f } hosted
240234
(* [to_host ~src_ptr ~src hosted] copies device memory into a host array.
   It calls [set_ctx] on [src.ctx], then applies [Cu.Stream.memcpy_D_to_H] with
   source [src_ptr] on the stream handle of [src] ([cu_stream] before this
   change, [runner] after) to [hosted] via [Ndarray.map]; [hosted] supplies the
   destination.
   NOTE(review): whether the copy has completed when this returns depends on
   the [Cu.Stream] binding -- confirm before reading [hosted] without a prior
   [await]. *)
241235let to_host ~src_ptr ~src hosted =
242236 set_ctx src.ctx;
243- let f dst = Cu.Stream. memcpy_D_to_H ~dst ~src: src_ptr src.stream.cu_stream in
237+ let f dst = Cu.Stream. memcpy_D_to_H ~dst ~src: src_ptr src.stream.runner in
244238 Ndarray. map { f } hosted
245239
246240let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
247241 let same_device = dst.stream.device.ordinal = src.stream.device.ordinal in
248242 let memcpy ~dst_ptr =
249243 if same_device then
250244 Cu.Stream. memcpy_D_to_D ~size_in_bytes: (Tn. size_in_bytes tn) ~dst: dst_ptr ~src: src_ptr
251- dst.stream.cu_stream
245+ dst.stream.runner
252246 else
253247 Cu.Stream. memcpy_peer ~size_in_bytes: (Tn. size_in_bytes tn) ~dst: dst_ptr ~dst_ctx: dst.ctx
254- ~src: src_ptr ~src_ctx: src.ctx dst.stream.cu_stream
248+ ~src: src_ptr ~src_ctx: src.ctx dst.stream.runner
255249 in
256250 match (into_merge_buffer, dst_ptr) with
257251 | No , None -> invalid_arg " Cuda_backend.device_to_device: missing dst_ptr"
@@ -260,13 +254,13 @@ let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
260254 memcpy ~dst_ptr
261255 | Streaming , _ ->
262256 assert same_device;
263- dst.stream.merge_buffer < - Some (src_ptr, tn)
257+ dst.stream.merge_buffer := Some (src_ptr, tn)
264258 | Copy , _ ->
265259 set_ctx dst.ctx;
266260 let size_in_bytes = Tn. size_in_bytes tn in
267261 opt_alloc_merge_buffer ~size_in_bytes dst.stream.device;
268262 memcpy ~dst_ptr: dst.stream.device.copy_merge_buffer;
269- dst.stream.merge_buffer < - Some (dst.stream.device.copy_merge_buffer, tn)
263+ dst.stream.merge_buffer := Some (dst.stream.device.copy_merge_buffer, tn)
270264
271265type code = {
272266 traced_store : Low_level .traced_store ;
@@ -463,7 +457,7 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
463457 S. Tensor arr
464458 | _name , Log_file_name -> S. Int log_id
465459 | _name , Merge_buffer ->
466- let ptr = fst @@ Option. value_exn ~here: [% here] context.stream.merge_buffer in
460+ let ptr = fst @@ Option. value_exn ~here: [% here] ! ( context.stream.merge_buffer) in
467461 S. Tensor ptr
468462 | _name , Static_idx s ->
469463 let i = Indexing. find_exn lowered_bindings s in
@@ -492,8 +486,7 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
492486 [% log_block
493487 context.label;
494488 Utils. log_trace_tree _output]);
495- S. launch_kernel func ~grid_dim_x: 1 ~block_dim_x: 1 ~shared_mem_bytes: 0 context.stream.cu_stream
496- args;
489+ S. launch_kernel func ~grid_dim_x: 1 ~block_dim_x: 1 ~shared_mem_bytes: 0 context.stream.runner args;
497490 [% log " kernel launched" ]
498491 in
499492 ( context,
0 commit comments