
Commit 56f3e7b

Restore the functionality of debug logging from the cuda backend
1 parent 5bc72f6 commit 56f3e7b

File tree

5 files changed: +83, -25 lines


CHANGES.md

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@
 - Reduced busy-waiting inside `c_compile_and_load`, propagating compilation errors now instead of infinite loop on error.
 - Fixed loss of significant digits for small numbers when outputting files.
 - Added missing mixed-precision conversions in the `C_syntax` backend builder.
+- Restored the functionality of debug logging from the cuda backend.
 
 ## [0.4.0] -- 2024-09-04
 

README.md

Lines changed: 6 additions & 0 deletions
@@ -52,6 +52,12 @@ A possible route to learning OCANNL:
 2. Backend-independent optimizations [arrayjit/lib/lowering_and_inlining.md](arrayjit/lib/lowering_and_inlining.md) -- _lowering_ means translating (compiling) from the high-level representation (as assignments) to the low-level representation.
 3. More documentation to come.
 
+### Using the tracing debugger with CUDA computations
+
+To use debugging as provided by configuring `Utils.settings.debug_log_from_routines <- true` with the `cuda` backend, you need to wrap the code scheduling tasks and synchronizing `cuda` devices with `Utils.capture_stdout_logs`. The reason is that CUDA kernels are allowed to use `printf`, but not `fprintf` -- the driver dumps the printing buffer of a device to `stdout` at certain times (e.g. when synchronizing the device). For an example, see the implementation of `Train.example_train_loop`. Specifically, it wraps two sections: the call to `Train.parallel_update`, and the body of the returned `infer_callback`.
+
+IMPORTANT: due to potential bugs, debug logging from CUDA in complex settings currently only works as intended for _very_ small computation sizes.
+
 ## Upcoming milestones
 
 This is very tentative.
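
In user code, the wrapping described in the README addition might look like the sketch below. This is a minimal, hedged illustration: `Backend`, `device`, `routine`, and `Train.run` stand in for an already-configured cuda backend, device, and compiled routine, and are not the exact API surface.

(* Minimal sketch of the wrapping described in the README section above.
   [Backend], [device], [routine] and [Train.run] are hypothetical
   stand-ins, not the confirmed API. *)
let () =
  Utils.settings.debug_log_from_routines <- true;
  (* Schedule work and synchronize the device inside the capture, so the
     driver dumps its printf buffer to stdout while stdout is redirected. *)
  Utils.capture_stdout_logs (fun () ->
      Train.run routine;
      Backend.await device)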

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 11 additions & 2 deletions
@@ -63,6 +63,9 @@ let is_initialized, initialize =
 let num_physical_devices = Cudajit.device_get_count
 let devices = ref @@ Core.Weak.create 0
 
+(* Unlike [devices] above, [initialized_devices] never forgets its entries. *)
+let initialized_devices = Hash_set.create (module Int)
+
 let set_ctx ctx =
   let cur_ctx = Cudajit.ctx_get_current () in
   if not @@ phys_equal ctx cur_ctx then Cudajit.ctx_set_current ctx
@@ -98,6 +101,9 @@ let get_device ~(ordinal : int) : physical_device =
   let primary_context = Cudajit.device_primary_ctx_retain dev in
   let copy_merge_buffer_capacity = 8 in
   set_ctx primary_context;
+  if Utils.debug_log_from_routines () && not (Hash_set.mem initialized_devices ordinal) then
+    Option.iter Utils.settings.cuda_printf_fifo_size ~f:Cudajit.(ctx_set_limit PRINTF_FIFO_SIZE);
+  Hash_set.add initialized_devices ordinal;
   let copy_merge_buffer = Cudajit.mem_alloc ~size_in_bytes:copy_merge_buffer_capacity in
   let result =
     {
@@ -147,7 +153,8 @@ let get_name device =
 
 let await device : unit =
   set_ctx device.physical.primary_context;
-  Cudajit.stream_synchronize device.stream
+  Cudajit.stream_synchronize device.stream;
+  Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ())
 
 let is_idle device = Cudajit.stream_is_ready device.stream
 
@@ -188,6 +195,7 @@ let unsafe_cleanup () =
       if Atomic.compare_and_set device.released false true then (
         Cudajit.ctx_set_current device.primary_context;
         Cudajit.ctx_synchronize ();
+        Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ());
         Cudajit.device_primary_ctx_release device.dev))
   done;
   Core.Weak.fill !devices 0 len None
@@ -477,12 +485,13 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~glo
   (* Map.iteri global_arrays ~f:(fun ~key ~data:ptr -> if key.Low_level.zero_initialized then
      Cu.memset_d8_async ptr Unsigned.UChar.zero ~length:(Tn.size_in_bytes key.Low_level.tn)); *)
   [%log "launching the kernel"];
+  (* TODO: This doesn't help. *)
+  (* Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ()); *)
   (if Utils.debug_log_from_routines () then
      Utils.add_log_processor ~prefix:log_id_prefix @@ fun _output ->
      [%log_block
        context.label;
        Utils.log_trace_tree _output]);
-  (* if Utils.debug_log_from_routines () then Cu.ctx_set_limit CU_LIMIT_PRINTF_FIFO_SIZE 4096; *)
   Cu.launch_kernel func ~grid_dim_x:1 ~block_dim_x:1 ~shared_mem_bytes:0 context.device.stream
     args;
   [%log "kernel launched"]
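
A hedged usage sketch of the new knob: enlarging the driver's printf FIFO before the first device initialization. The module path `Cuda_backend` and the 4 MiB figure are illustrative assumptions; the `initialized_devices` guard means the limit is applied at most once per ordinal. Equivalently, `restore_settings` now reads a `cuda_printf_fifo_size` global argument, as shown in the `utils.ml` diff below.

(* Illustrative only: enlarge the CUDA printf FIFO used for captured logs.
   [Cuda_backend] and the 4 MiB size are assumptions, not the confirmed API. *)
let () =
  Utils.settings.debug_log_from_routines <- true;
  Utils.settings.cuda_printf_fifo_size <- Some (4 * 1024 * 1024);
  (* The limit is set inside [get_device] the first time this ordinal is seen. *)
  ignore (Cuda_backend.get_device ~ordinal:0)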

arrayjit/lib/utils.ml

Lines changed: 63 additions & 22 deletions
@@ -43,6 +43,10 @@ type settings = {
   mutable check_half_prec_constants_cutoff : float option;
       (** If given, generic code optimization should fail if a half precision FP16 constant exceeds
           the cutoff. *)
+  mutable cuda_printf_fifo_size : int option;
+      (** If not [None], the setting will be used for the size of the CUDA devices buffer for
+          storing logs, see [debug_log_from_routines] above. If [None], the default buffer size on
+          the devices is not altered. *)
 }
 [@@deriving sexp]
 
@@ -55,6 +59,7 @@ let settings =
     fixed_state_for_init = None;
     print_decimals_precision = 2;
     check_half_prec_constants_cutoff = Some (2. **. 14.);
+    cuda_printf_fifo_size = None;
   }
 
 let accessed_global_args = Hash_set.create (module String)
@@ -321,7 +326,9 @@ let restore_settings () =
     Int.of_string @@ get_global_arg ~arg_name:"print_decimals_precision" ~default:"2";
   settings.check_half_prec_constants_cutoff <-
     Float.of_string_opt
-    @@ get_global_arg ~arg_name:"check_half_prec_constants_cutoff" ~default:"16384.0"
+    @@ get_global_arg ~arg_name:"check_half_prec_constants_cutoff" ~default:"16384.0";
+  settings.cuda_printf_fifo_size <-
+    Int.of_string_opt @@ get_global_arg ~arg_name:"cuda_printf_fifo_size" ~default:""
 
 let () = restore_settings ()
 let with_runtime_debug () = settings.output_debug_files_in_build_directory && settings.log_level > 1
@@ -507,6 +514,10 @@ let pp_file ~base_name ~extension =
 
 let captured_log_prefix = ref "!@#"
 
+(** To avoid the complication of a concurrent thread, we expose a callback for collaborative log
+    processing. *)
+let advance_captured_logs = ref None
+
 type captured_log_processor = { log_processor_prefix : string; process_logs : string list -> unit }
 
 let captured_log_processors : captured_log_processor list ref = ref []
@@ -515,39 +526,69 @@ let add_log_processor ~prefix process_logs =
   captured_log_processors :=
     { log_processor_prefix = prefix; process_logs } :: !captured_log_processors
 
+external input_scan_line : Stdlib.in_channel -> int = "caml_ml_input_scan_line"
+
+let input_line chan =
+  let n = input_scan_line chan in
+  if n = 0 then raise End_of_file;
+  let line = Stdlib.really_input_string chan (abs n) in
+  ( n > 0,
+    String.chop_suffix_if_exists ~suffix:"\n" @@ String.chop_suffix_if_exists line ~suffix:"\r\n" )
+
 let capture_stdout_logs ?(never_skip = false) arg =
   if (not never_skip) && not (debug_log_from_routines ()) then arg ()
   else (
     Stdlib.flush Stdlib.stdout;
-    let exitp, entrancep = Unix.pipe () and backup = Unix.dup Unix.stdout in
-    Unix.dup2 entrancep Unix.stdout;
-    Unix.set_nonblock entrancep;
-    (* FIXME: process logs in a parallel thread, and double check they are not getting cut off. *)
+    let ls = ref [] in
+    let lastl = ref "" in
+    let backup = ref (Unix.dup Unix.stdout) in
+    let exit_entrance = ref (Unix.pipe ()) in
+    let pre_advance () =
+      Unix.dup2 (snd !exit_entrance) Unix.stdout;
+      Unix.set_nonblock (snd !exit_entrance)
+    in
+    let advance is_last () =
+      Stdlib.flush Stdlib.stdout;
+      Unix.close (snd !exit_entrance);
+      Unix.dup2 !backup Unix.stdout;
+      let channel = Unix.in_channel_of_descr (fst !exit_entrance) in
+      (try
+         while true do
+           let is_endlined, line = input_line channel in
+           let line = !lastl ^ line in
+           if is_endlined then (
+             (match String.chop_prefix ~prefix:!captured_log_prefix line with
+             | None -> Stdlib.print_endline line
+             (* ls := line :: !ls *)
+             | Some logline -> ls := logline :: !ls);
+             lastl := "")
+           else lastl := line
+         done
+       with End_of_file -> ());
+      if not is_last then (
+        backup := Unix.dup Unix.stdout;
+        exit_entrance := Unix.pipe ();
+        pre_advance ())
+    in
+    advance_captured_logs := Some (advance false);
+    pre_advance ();
     let result =
       try arg ()
      with Sys_blocked_io ->
+        advance_captured_logs := None;
        invalid_arg
          "capture_stdout_logs: unfortunately, flushing stdout inside captured code is prohibited"
    in
-    Stdlib.flush Stdlib.stdout;
-    Unix.close entrancep;
-    Unix.dup2 backup Unix.stdout;
-    let ls = ref [] and channel = Unix.in_channel_of_descr exitp in
-    let output =
-      try
-        while true do
-          let line = Stdlib.input_line channel in
-          match String.chop_prefix ~prefix:!captured_log_prefix line with
-          | None -> Stdlib.print_endline line
-          | Some logline -> ls := logline :: !ls
-        done;
-        []
-      with End_of_file -> List.rev !ls
-    in
+    advance true ();
+    let output = List.rev !ls in
     Exn.protect
      ~f:(fun () ->
-        List.iter !captured_log_processors ~f:(fun { log_processor_prefix; process_logs } ->
+        (* Preserve the order in which kernels were launched. *)
+        List.iter (List.rev !captured_log_processors)
+          ~f:(fun { log_processor_prefix; process_logs } ->
            process_logs
            @@ List.filter_map output ~f:(String.chop_prefix ~prefix:log_processor_prefix)))
-      ~finally:(fun () -> captured_log_processors := []);
+      ~finally:(fun () ->
+        advance_captured_logs := None;
+        captured_log_processors := []);
    result)
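
The rewritten `capture_stdout_logs` pivots on `dup2`-redirecting `stdout` into a pipe that gets drained on each `advance`. Below is a self-contained sketch of just that single redirect-and-drain round, omitting the non-blocking write end, the log-prefix filtering, the log processors, and the re-arming that the real `pre_advance`/`advance` pair performs (requires the `unix` library; `capture_once` is a made-up name for illustration).

(* Standalone sketch of one capture round. Unlike the real implementation it
   reads blocking and never re-arms the pipe, so [f] must print less than the
   OS pipe buffer (typically 64 KiB); the real code sets the write end
   non-blocking so an overflow raises Sys_blocked_io instead of deadlocking.
   Error handling if [f] raises is also omitted. *)
let capture_once (f : unit -> 'a) : 'a * string list =
  Stdlib.flush Stdlib.stdout;
  let readable, writable = Unix.pipe () in
  let backup = Unix.dup Unix.stdout in
  Unix.dup2 writable Unix.stdout;
  let result = f () in
  Stdlib.flush Stdlib.stdout;
  Unix.dup2 backup Unix.stdout;
  Unix.close writable;
  (* Both write ends are now closed, so reading terminates with End_of_file. *)
  let channel = Unix.in_channel_of_descr readable in
  let lines = ref [] in
  (try
     while true do
       lines := Stdlib.input_line channel :: !lines
     done
   with End_of_file -> ());
  Stdlib.close_in channel;
  Unix.close backup;
  (result, List.rev !lines)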

lib/train.ml

Lines changed: 2 additions & 1 deletion
@@ -484,7 +484,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
       Tn.log_accessible_headers ());
   for epoch = 0 to epochs - 1 do
     epoch_loss := 0.;
-    update ();
+    Utils.capture_stdout_logs update;
     learning_rates := learning_rate.@[0] :: !learning_rates;
     epoch_losses := !epoch_loss :: !epoch_losses;
     Option.iter per_epoch_callback ~f:(fun f ->
@@ -509,6 +509,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
     Tensor.set_values infer values;
     (* For the gccjit backend, infer is only on host, not on device. For cuda, this will be
        needed. *)
+    Utils.capture_stdout_logs @@ fun () ->
     assert (Backend.from_host routine.context infer.value);
     run routine;
     assert (Backend.to_host routine.context model_result.value);
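
A note on the `Utils.capture_stdout_logs @@ fun () ->` form added above: the function body extends to the end of the callback, so the host transfers and the routine run are all captured together. A tiny self-contained demonstration of this scoping (the `with_banner` helper is made up for illustration):

(* [g @@ fun () -> e1; e2] passes the whole [e1; e2] sequence as the thunk:
   here both prints happen between "begin" and "end". *)
let with_banner thunk =
  print_endline "begin";
  thunk ();
  print_endline "end"

let () =
  with_banner @@ fun () ->
  print_endline "first";
  print_endline "second"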
