Skip to content

Commit 0378b25

Browse files
committed
When filtering stdout for log messages, forward non-log lines to the original stdout as soon as available.
Signed-off-by: Lukasz Stafiniak <lukstafi@gmail.com>
1 parent 52f80ac commit 0378b25

File tree

3 files changed

+17
-16
lines changed

3 files changed

+17
-16
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ A possible route to learning OCANNL:
6060

6161
To use debugging as provided by configuring `Utils.settings.debug_log_from_routines <- true` with the `cuda` backend, you need to wrap the code scheduling tasks and synchronizing `cuda` devices with `Utils.capture_stdout_logs`. The reason is that CUDA kernels are allowed to use `printf`, but not `fprintf` -- the driver dumps the printing buffer of a device to `stdout` at certain times (e.g. when synchronizing the device). For an example, see the implementation of `Train.example_train_loop`. Specifically, it wraps two sections: the call to `Train.parallel_update`, and the body of the returned `infer_callback`.
6262

63-
IMPORTANT: debug logging from CUDA in complex settings currently only works as intended for _very_ small computation sizes. If facing issues, try the setting `never_capture_stdout=true` (see [ocannl_config.example](ocannl_config.example)).
63+
NOTE: debug logging from CUDA in complex settings is a bit tricky: it involves another thread (domain) intercepting and filtering `stdout`. If facing issues, try the setting `never_capture_stdout=true` (see [ocannl_config.example](ocannl_config.example)).
6464

6565
## Upcoming milestones
6666

arrayjit/lib/anatomy_of_a_backend.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ We output a log line only for comments and array assignments (corresponding to n
143143

144144
#### Tracing via `stdout`
145145

146-
Since the CUDA backend can only log to the standard output, it passes `let logs_to_stdout = true` to `C_syntax`. This uses `printf`, and prefixes each log line with a kernel run ID. When postprocessing the logs, each run extracts its own log lines. Simultaneous logging from multiple CUDA devices should still be clean -- without interleaving lines -- because the driver is supposed to dump the logs to standard output at device synchronization points.
146+
Since the CUDA backend can only log to the standard output, it uses `printf`, and prefixes each log line with a kernel run ID. When postprocessing the logs, each run extracts its own log lines. Simultaneous logging from multiple CUDA devices should still be clean -- without interleaving lines -- because the driver is supposed to dump the logs to standard output at device synchronization points.
147147

148148
When using the default stream, CUDA would predictably write to the standard output at context synchronization only. Unfortunately, it does not appear to be the case with asynchronous streams. [Despite the assurance from the documentation, output happens in between CUDA calls...](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#formatted-output) To remedy this, we implement a `stdout` filtering scheme (function `Utils.capture_stdout_logs`), where all output is captured, tracing lines extracted, and other output printed on the original `stdout`.
149149

arrayjit/lib/utils.ml

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -655,27 +655,33 @@ let capture_stdout_logs arg =
655655
(* The reader domain will close pipe_read_fd. *)
656656

657657
let collected_logs_ref = ref [] in
658-
let passthrough_lines_ref = ref [] in (* Buffer for non-log lines *)
659658
let reader_domain_failed = Atomic.make false in
660659

661660
let reader_domain_logic () =
662661
let in_channel = Unix.in_channel_of_descr pipe_read_fd in
662+
(* Create an output channel to the original stdout for immediate passthrough *)
663+
let orig_out = Unix.out_channel_of_descr (Unix.dup original_stdout_fd) in
663664
try
664665
while true do
665666
let _is_endlined, line = input_line in_channel in
666667
match String.chop_prefix ~prefix:!captured_log_prefix line with
667668
| Some logline -> collected_logs_ref := logline :: !collected_logs_ref
668-
| None -> passthrough_lines_ref := line :: !passthrough_lines_ref (* Buffer the line *)
669+
| None ->
670+
(* Forward non-log lines to original stdout immediately *)
671+
Stdlib.output_string orig_out (line ^ "\n");
672+
Stdlib.flush orig_out
669673
done;
674+
Stdlib.close_out_noerr orig_out;
670675
Stdlib.close_in_noerr in_channel (* This closes pipe_read_fd *)
671676
with
672-
| End_of_file -> () (* Normal termination of the reader *)
673-
| exn ->
674-
Atomic.set reader_domain_failed true;
675-
Stdio.eprintf "Exception in stdout reader domain: %s\\nBacktrace:\\n%s\\n%!"
676-
(Exn.to_string exn) (Stdlib.Printexc.get_backtrace ());
677+
| End_of_file -> () (* Normal termination of the reader *)
678+
| exn ->
679+
Stdlib.close_out_noerr orig_out;
680+
Atomic.set reader_domain_failed true;
681+
Stdio.eprintf "Exception in stdout reader domain: %s\\nBacktrace:\\n%s\\n%!"
682+
(Exn.to_string exn) (Stdlib.Printexc.get_backtrace ());
677683
Stdlib.close_in_noerr in_channel (* This closes pipe_read_fd *);
678-
Stdlib.Printexc.raise_with_backtrace exn (Stdlib.Printexc.get_raw_backtrace ())
684+
Stdlib.Printexc.raise_with_backtrace exn (Stdlib.Printexc.get_raw_backtrace ())
679685
in
680686

681687
let reader_domain = Domain.spawn reader_domain_logic in
@@ -707,8 +713,6 @@ let capture_stdout_logs arg =
707713
~f:(fun { log_processor_prefix; process_logs } ->
708714
process_logs
709715
@@ List.filter_map captured_output ~f:(String.chop_prefix ~prefix:log_processor_prefix));
710-
(* Print passthrough lines even if arg() failed, if reader was ok *)
711-
List.iter (List.rev !passthrough_lines_ref) ~f:Stdlib.print_endline;
712716
);
713717
captured_log_processors := []; (* Clear processors *)
714718
Stdlib.Printexc.raise_with_backtrace exn (Stdlib.Printexc.get_raw_backtrace ())
@@ -740,10 +744,7 @@ let capture_stdout_logs arg =
740744
~f:(fun { log_processor_prefix; process_logs } ->
741745
process_logs
742746
@@ List.filter_map captured_output ~f:(String.chop_prefix ~prefix:log_processor_prefix)))
743-
~finally:(fun () -> captured_log_processors := []);
744-
745-
(* Then print passthrough lines to the now-restored original stdout *)
746-
List.iter (List.rev !passthrough_lines_ref) ~f:Stdlib.print_endline;
747+
~finally:(fun () -> captured_log_processors := [])
747748
) else (
748749
captured_log_processors := []; (* Clear processors if reader failed *)
749750
);

0 commit comments

Comments
 (0)