Skip to content

Commit 495b9f4

Browse files
committed
Tentative domain-based stdout capture
Signed-off-by: lukstafi <lukstafi@users.noreply.github.com>
1 parent 326a2e6 commit 495b9f4

File tree

3 files changed

+94
-53
lines changed

3 files changed

+94
-53
lines changed

arrayjit/lib/cuda_backend.ml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@ end) : Ir.Backend_impl.Lowered_backend = struct
113113
let%track4_sexp finalize_device (device : device) =
114114
Cu.Context.set_current device.dev.primary_context;
115115
Cu.Context.synchronize ();
116-
Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ());
117116
(* Note: this is not necessary as releasing the primary context by GC will reset the context. *)
118117
Hashtbl.iter device.cross_stream_candidates ~f:(fun buffer_ptr ->
119118
Cu.Deviceptr.mem_free buffer_ptr)
@@ -168,8 +167,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
168167

169168
let await stream : unit =
170169
set_ctx stream.device.dev.primary_context;
171-
Cu.Stream.synchronize stream.runner;
172-
Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ())
170+
Cu.Stream.synchronize stream.runner
173171

174172
let is_idle stream = Cu.Stream.is_ready stream.runner
175173

arrayjit/lib/utils.ml

Lines changed: 92 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ let captured_log_prefix = ref "!@#"
622622

623623
(** To avoid the complication of a concurrent thread, we expose a callback for collaborative log
624624
processing. *)
625-
let advance_captured_logs = ref None
625+
let advance_captured_logs : (unit -> unit) option ref = ref None
626626

627627
type captured_log_processor = { log_processor_prefix : string; process_logs : string list -> unit }
628628

@@ -644,59 +644,101 @@ let input_line chan =
644644
let capture_stdout_logs arg =
645645
if never_capture_stdout () || not (debug_log_from_routines ()) then arg ()
646646
else (
647-
Stdlib.flush Stdlib.stdout;
648-
let ls = ref [] in
649-
let lastl = ref "" in
650-
let backup = ref (Unix.dup Unix.stdout) in
651-
let exit_entrance = ref (Unix.pipe ()) in
652-
let pre_advance () =
653-
Unix.dup2 (snd !exit_entrance) Unix.stdout;
654-
Unix.set_nonblock (snd !exit_entrance)
655-
in
656-
let advance is_last () =
657-
Stdlib.flush Stdlib.stdout;
658-
Unix.close (snd !exit_entrance);
659-
Unix.dup2 !backup Unix.stdout;
660-
let channel = Unix.in_channel_of_descr (fst !exit_entrance) in
661-
(try
662-
while true do
663-
let is_endlined, line = input_line channel in
664-
let line = !lastl ^ line in
665-
if is_endlined then (
666-
(match String.chop_prefix ~prefix:!captured_log_prefix line with
667-
| None -> Stdlib.print_endline line
668-
(* ls := line :: !ls *)
669-
| Some logline -> ls := logline :: !ls);
670-
lastl := "")
671-
else lastl := line
672-
done
673-
with End_of_file -> ());
674-
if not is_last then (
675-
backup := Unix.dup Unix.stdout;
676-
exit_entrance := Unix.pipe ();
677-
pre_advance ())
647+
Stdlib.flush Stdlib.stdout; (* Ensure previous stdout is flushed *)
648+
let original_stdout_fd = Unix.dup Unix.stdout in
649+
let old_advance_captured_logs_val = !advance_captured_logs in
650+
advance_captured_logs := None;
651+
652+
let pipe_read_fd, pipe_write_fd = Unix.pipe ~cloexec:true () in
653+
Unix.dup2 pipe_write_fd Unix.stdout;
654+
(* pipe_write_fd is now the new Stdlib.stdout, do not close it in parent until done. *)
655+
(* The reader domain will close pipe_read_fd. *)
656+
657+
let collected_logs_ref = ref [] in
658+
let passthrough_lines_ref = ref [] in (* Buffer for non-log lines *)
659+
let reader_domain_failed = Atomic.make false in
660+
661+
let reader_domain_logic () =
662+
let in_channel = Unix.in_channel_of_descr pipe_read_fd in
663+
try
664+
while true do
665+
let _is_endlined, line = input_line in_channel in
666+
match String.chop_prefix ~prefix:!captured_log_prefix line with
667+
| Some logline -> collected_logs_ref := logline :: !collected_logs_ref
668+
| None -> passthrough_lines_ref := line :: !passthrough_lines_ref (* Buffer the line *)
669+
done;
670+
Stdlib.close_in_noerr in_channel (* This closes pipe_read_fd *)
671+
with
672+
| End_of_file -> () (* Normal termination of the reader *)
673+
| exn ->
674+
Atomic.set reader_domain_failed true;
675+
Stdio.eprintf "Exception in stdout reader domain: %s\\nBacktrace:\\n%s\\n%!"
676+
(Exn.to_string exn) (Stdlib.Printexc.get_backtrace ());
677+
Stdlib.close_in_noerr in_channel (* This closes pipe_read_fd *);
678+
Stdlib.Printexc.raise_with_backtrace exn (Stdlib.Printexc.get_raw_backtrace ())
678679
in
679-
advance_captured_logs := Some (advance false);
680-
pre_advance ();
680+
681+
let reader_domain = Domain.spawn reader_domain_logic in
682+
681683
let result =
682684
try arg ()
683-
with Sys_blocked_io ->
684-
advance_captured_logs := None;
685-
invalid_arg
686-
"capture_stdout_logs: unfortunately, flushing stdout inside captured code is prohibited"
685+
with exn ->
686+
(* Ensure cleanup even if arg() fails *)
687+
Stdlib.flush Stdlib.stdout; (* Flush to pipe_write_fd *)
688+
Unix.close pipe_write_fd; (* Signal EOF to reader domain *)
689+
(try Domain.join reader_domain
690+
with e ->
691+
Stdio.eprintf "Exception while joining reader domain (arg failed): %s\\n%!"
692+
(Exn.to_string e));
693+
694+
Unix.dup2 original_stdout_fd Unix.stdout; (* Restore stdout *)
695+
Unix.close original_stdout_fd;
696+
advance_captured_logs := old_advance_captured_logs_val;
697+
698+
if not (Atomic.get reader_domain_failed) then (
699+
let captured_output = List.rev !collected_logs_ref in
700+
List.iter (List.rev !captured_log_processors)
701+
~f:(fun { log_processor_prefix; process_logs } ->
702+
process_logs
703+
@@ List.filter_map captured_output ~f:(String.chop_prefix ~prefix:log_processor_prefix));
704+
(* Print passthrough lines even if arg() failed, if reader was ok *)
705+
List.iter (List.rev !passthrough_lines_ref) ~f:Stdlib.print_endline;
706+
);
707+
captured_log_processors := []; (* Clear processors *)
708+
Stdlib.Printexc.raise_with_backtrace exn (Stdlib.Printexc.get_raw_backtrace ())
687709
in
688-
advance true ();
689-
let output = List.rev !ls in
690-
Exn.protect
691-
~f:(fun () ->
692-
(* Preserve the order in which kernels were launched. *)
693-
List.iter (List.rev !captured_log_processors)
694-
~f:(fun { log_processor_prefix; process_logs } ->
695-
process_logs
696-
@@ List.filter_map output ~f:(String.chop_prefix ~prefix:log_processor_prefix)))
697-
~finally:(fun () ->
698-
advance_captured_logs := None;
699-
captured_log_processors := []);
710+
711+
(* Normal path: arg() completed successfully *)
712+
Stdlib.flush Stdlib.stdout; (* Flush to pipe_write_fd *)
713+
Unix.close pipe_write_fd; (* Signal EOF to reader domain *)
714+
715+
(try Domain.join reader_domain
716+
with e ->
717+
Stdio.eprintf "Exception while joining reader domain (arg succeeded): %s\\n%!"
718+
(Exn.to_string e);
719+
if Atomic.get reader_domain_failed then
720+
Stdlib.Printexc.raise_with_backtrace e (Stdlib.Printexc.get_raw_backtrace ()));
721+
722+
Unix.dup2 original_stdout_fd Unix.stdout; (* Restore stdout *)
723+
Unix.close original_stdout_fd;
724+
advance_captured_logs := old_advance_captured_logs_val;
725+
726+
if not (Atomic.get reader_domain_failed) then (
727+
let captured_output = List.rev !collected_logs_ref in
728+
Exn.protect
729+
~f:(fun () ->
730+
(* Process captured logs by processors first. *)
731+
List.iter (List.rev !captured_log_processors)
732+
~f:(fun { log_processor_prefix; process_logs } ->
733+
process_logs
734+
@@ List.filter_map captured_output ~f:(String.chop_prefix ~prefix:log_processor_prefix)))
735+
~finally:(fun () -> captured_log_processors := []);
736+
737+
(* Then print passthrough lines to the now-restored original stdout *)
738+
List.iter (List.rev !passthrough_lines_ref) ~f:Stdlib.print_endline;
739+
) else (
740+
captured_log_processors := []; (* Clear processors if reader failed *)
741+
);
700742
result)
701743

702744
let log_debug_routine_logs ~log_contents ~stream_name =

bin/micrograd_demo_logging.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ let () =
3232
List.iter ~f:(Option.iter ~f:(fun diff -> Train.set_hosted diff.Tensor.grad)) [ a.diff; b.diff ];
3333
let update = Train.grad_update g in
3434
let step = Train.to_routine (module Backend) ctx IDX.empty update.fwd_bprop in
35+
Utils.capture_stdout_logs @@ fun () ->
3536
Train.run step;
3637
Tensor.print ~with_code:false ~with_grad:false `Default g;
3738
Tensor.print ~with_code:false ~with_grad:true `Default a;

0 commit comments

Comments
 (0)