@@ -192,39 +192,30 @@ let unsafe_cleanup () =
192192 done ;
193193 Core.Weak. fill ! devices 0 len None
194194
195- let % diagn_sexp from_host ?(rt : (module Minidebug_runtime.Debug_runtime ) option ) (ctx : context ) tn
196- =
195+ let % diagn_l_sexp from_host (ctx : context ) tn =
197196 match (tn, Map. find ctx.global_arrays tn) with
198197 | { Tn. array = (lazy (Some hosted )); _ } , Some dst ->
199198 set_ctx ctx.ctx;
200- (if Utils. settings.with_debug_level > 0 then
201- let module Debug_runtime =
202- (val Option. value_or_thunk rt ~default: (fun () -> (module Debug_runtime )))
203- in
204- [% log " copying" , Tn. debug_name tn, " to" , (dst : ctx_array ), " from host" ]);
199+ if Utils. settings.with_debug_level > 0 then
200+ [% log " copying" , Tn. debug_name tn, " to" , (dst : ctx_array ), " from host" ];
205201 let f src = Cudajit. memcpy_H_to_D_async ~dst ~src ctx.device.stream in
206202 Ndarray. map { f } hosted;
207203 true
208204 | _ -> false
209205
210- let % track_sexp to_host ?(rt : (module Minidebug_runtime.Debug_runtime ) option ) (ctx : context )
211- (tn : Tn.t ) =
206+ let % track_l_sexp to_host (ctx : context ) (tn : Tn.t ) =
212207 match (tn, Map. find ctx.global_arrays tn) with
213208 | { Tn. array = (lazy (Some hosted )); _ } , Some src ->
214209 set_ctx ctx.ctx;
215- (if Utils. settings.with_debug_level > 0 then
216- let module Debug_runtime =
217- (val Option. value_or_thunk rt ~default: (fun () ->
218- (module Debug_runtime : Minidebug_runtime.Debug_runtime )))
219- in
220- [% log " copying" , Tn. debug_name tn, " at" , (src : ctx_array ), " to host" ]);
210+ if Utils. settings.with_debug_level > 0 then
211+ [% log " copying" , Tn. debug_name tn, " at" , (src : ctx_array ), " to host" ];
221212 let f dst = Cudajit. memcpy_D_to_H_async ~dst ~src ctx.device.stream in
222213 Ndarray. map { f } hosted;
223214 true
224215 | _ -> false
225216
226- let % track_sexp rec device_to_device ?(rt : (module Minidebug_runtime.Debug_runtime ) option )
227- (tn : Tn.t ) ~into_merge_buffer ~(dst : context ) ~(src : context ) =
217+ let % track_l_sexp rec device_to_device (tn : Tn.t ) ~into_merge_buffer ~(dst : context )
218+ ~(src : context ) =
228219 let memcpy ~d_arr ~s_arr =
229220 if phys_equal dst.device.physical src.device.physical then
230221 Cudajit. memcpy_D_to_D_async ~size_in_bytes: (Tn. size_in_bytes tn) ~dst: d_arr ~src: s_arr
@@ -243,46 +234,34 @@ let%track_sexp rec device_to_device ?(rt : (module Minidebug_runtime.Debug_runti
243234 | Some d_arr ->
244235 set_ctx dst.ctx;
245236 memcpy ~d_arr ~s_arr ;
246- (if Utils. settings.with_debug_level > 0 then
247- let module Debug_runtime =
248- (val Option. value_or_thunk rt ~default: (fun () ->
249- (module Debug_runtime : Minidebug_runtime.Debug_runtime )))
250- in
251- [% log
252- " copied" ,
253- Tn. debug_name tn,
254- " from" ,
255- src.label,
256- " at" ,
257- (s_arr : ctx_array ),
258- " to" ,
259- (d_arr : ctx_array )]);
237+ if Utils. settings.with_debug_level > 0 then
238+ [% log
239+ " copied" ,
240+ Tn. debug_name tn,
241+ " from" ,
242+ src.label,
243+ " at" ,
244+ (s_arr : ctx_array ),
245+ " to" ,
246+ (d_arr : ctx_array )];
260247 true )
261248 | Streaming ->
262249 if phys_equal dst.device.physical src.device.physical then (
263250 dst.device.merge_buffer < - Some (s_arr, tn);
264- (if Utils. settings.with_debug_level > 0 then
265- let module Debug_runtime =
266- (val Option. value_or_thunk rt ~default: (fun () ->
267- (module Debug_runtime : Minidebug_runtime.Debug_runtime )))
268- in
269- [% log " using merge buffer for" , Tn. debug_name tn, " from" , src.label]);
251+ if Utils. settings.with_debug_level > 0 then
252+ [% log " using merge buffer for" , Tn. debug_name tn, " from" , src.label];
270253 true )
271254 else
272255 (* TODO: support proper streaming, but it might be difficult. *)
273- device_to_device ?rt tn ~into_merge_buffer: Copy ~dst ~src
256+ device_to_device tn ~into_merge_buffer: Copy ~dst ~src
274257 | Copy ->
275258 set_ctx dst.ctx;
276259 let size_in_bytes = Tn. size_in_bytes tn in
277260 opt_alloc_merge_buffer ~size_in_bytes dst.device.physical;
278261 memcpy ~d_arr: dst.device.physical.copy_merge_buffer ~s_arr ;
279262 dst.device.merge_buffer < - Some (dst.device.physical.copy_merge_buffer, tn);
280- (if Utils. settings.with_debug_level > 0 then
281- let module Debug_runtime =
282- (val Option. value_or_thunk rt ~default: (fun () ->
283- (module Debug_runtime : Minidebug_runtime.Debug_runtime )))
284- in
285- [% log " copied into merge buffer" , Tn. debug_name tn, " from" , src.label]);
263+ if Utils. settings.with_debug_level > 0 then
264+ [% log " copied into merge buffer" , Tn. debug_name tn, " from" , src.label];
286265 true )
287266
288267type code = {
@@ -522,7 +501,7 @@ let%track_sexp link_batch prior_context (code_batch : code_batch) =
522501 in
523502 (context, lowered_bindings, procs)
524503
525- let to_buffer ? rt : _ _tn ~dst:_ ~src:_ = failwith " CUDA low-level: NOT IMPLEMENTED YET"
526- let host_to_buffer ? rt : _ _tn ~dst:_ = failwith " CUDA low-level: NOT IMPLEMENTED YET"
527- let buffer_to_host ? rt : _ _tn ~src:_ = failwith " CUDA low-level: NOT IMPLEMENTED YET"
504+ let to_buffer _tn ~dst :_ ~src :_ = failwith " CUDA low-level: NOT IMPLEMENTED YET"
505+ let host_to_buffer _tn ~dst :_ = failwith " CUDA low-level: NOT IMPLEMENTED YET"
506+ let buffer_to_host _tn ~src :_ = failwith " CUDA low-level: NOT IMPLEMENTED YET"
528507let get_buffer _tn _context = failwith " CUDA low-level: NOT IMPLEMENTED YET"
0 commit comments