Skip to content

Commit f1ca041

Browse files
committed
Logging support for the Metal backend, by Gemini
1 parent 41a9d1e commit f1ca041

File tree

1 file changed

+92
-17
lines changed

1 file changed

+92
-17
lines changed

arrayjit/lib/metal_backend.ml

Lines changed: 92 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ end) : Ir.Backend_impl.Lowered_backend = struct
113113
let metal_devices : Me.Device.t array = Me.Device.copy_all_devices ()
114114
let () = assert (Array.length metal_devices > 0)
115115

116+
(* Store for captured logs per stream_id *)
117+
let stream_logs : (int, string list ref) Hashtbl.t = Hashtbl.create (module Int)
118+
116119
(* Metal has unified memory on Apple Silicon, so we can use host memory *)
117120
let get_buffer_for_ptr device ~size_in_bytes bytes =
118121
Me.Buffer.on_device_with_bytes_no_copy device ~bytes ~length:size_in_bytes
@@ -142,12 +145,47 @@ end) : Ir.Backend_impl.Lowered_backend = struct
142145

143146
let new_stream (device_wrapper : device) : stream =
144147
let metal_device = device_wrapper.dev in
145-
let queue = Me.CommandQueue.on_device metal_device in
148+
let queue =
149+
if Utils.debug_log_from_routines () then (
150+
let log_entries_ref = ref [] in
151+
(* This ref will be captured by the log handler *)
152+
let log_desc = Me.LogStateDescriptor.create () in
153+
Me.LogStateDescriptor.set_level log_desc Me.LogLevel.Debug;
154+
(* Capture all debug logs and above *)
155+
Me.LogStateDescriptor.set_buffer_size log_desc (1024 * 100);
156+
(* 100KB buffer *)
157+
let log_state = Me.LogState.on_device_with_descriptor metal_device log_desc in
158+
Me.LogState.add_log_handler log_state (fun ~sub_system:_ ~category:_ ~level:_ ~message ->
159+
(* Strip the !Utils.captured_log_prefix and the run_id prefix *)
160+
match String.chop_prefix ~prefix:!Utils.captured_log_prefix message with
161+
| None -> () (* Not a log line we are interested in, or malformed *)
162+
| Some rest_of_message -> (
163+
match String.lsplit2 ~on:':' rest_of_message with
164+
| Some (_run_id_str, actual_log_line_with_space) ->
165+
let actual_log_line = String.strip actual_log_line_with_space in
166+
log_entries_ref := actual_log_line :: !log_entries_ref
167+
| None -> () (* Malformed after prefix *)));
168+
let queue_desc = Me.CommandQueueDescriptor.create () in
169+
Me.CommandQueueDescriptor.set_log_state queue_desc (Some log_state);
170+
(* The log_state and its handler (capturing log_entries_ref) are kept alive by the
171+
queue_desc / queue itself. *)
172+
let created_q = Me.CommandQueue.on_device_with_descriptor metal_device queue_desc in
173+
(* Store the log_entries_ref for later retrieval, associated with the stream_id which will
174+
be assigned by make_stream shortly. We\'ll add it after make_stream. *)
175+
(created_q, Some log_entries_ref))
176+
else (Me.CommandQueue.on_device metal_device, None)
177+
in
178+
let actual_queue, opt_log_entries_ref = queue in
146179
let shared_event_obj = Me.SharedEvent.on_device metal_device in
147180
let counter = Unsigned.ULLong.one in
148181
(* Next value = 1 *)
149-
let runner = { queue; event = shared_event_obj; counter } in
150-
make_stream device_wrapper runner
182+
let runner = { queue = actual_queue; event = shared_event_obj; counter } in
183+
let stream_obj = make_stream device_wrapper runner in
184+
(* Finalize linking log_entries_ref to stream_id and set up GC finalizer *)
185+
Option.iter opt_log_entries_ref ~f:(fun log_ref ->
186+
Hashtbl.add_exn stream_logs ~key:stream_obj.stream_id ~data:log_ref);
187+
Stdlib.Gc.finalise (fun s -> Hashtbl.remove stream_logs s.stream_id) stream_obj;
188+
stream_obj
151189

152190
(* --- Event Handling --- *)
153191
let is_done event =
@@ -186,7 +224,15 @@ end) : Ir.Backend_impl.Lowered_backend = struct
186224
let queue = stream.runner.queue in
187225
let command_buffer = Me.CommandBuffer.on_queue queue in
188226
Me.CommandBuffer.commit command_buffer;
189-
Me.CommandBuffer.wait_until_completed command_buffer
227+
Me.CommandBuffer.wait_until_completed command_buffer;
228+
(* Process captured logs if any *)
229+
if Utils.debug_log_from_routines () then
230+
match Hashtbl.find stream_logs stream.stream_id with
231+
| Some log_entries_ref ->
232+
let logs_to_process = List.rev !log_entries_ref in
233+
if not (List.is_empty logs_to_process) then Utils.log_trace_tree logs_to_process;
234+
log_entries_ref := [] (* Clear processed logs *)
235+
| None -> () (* No log bucket for this stream, logging likely not enabled for it *)
190236

191237
let is_idle stream =
192238
(* FIXME: store the latest CommandBuffer with the stream runner and check that it's completed *)
@@ -263,7 +309,14 @@ end) : Ir.Backend_impl.Lowered_backend = struct
263309
let get_global_debug_info () = Sexp.Atom "Metal global debug info NYI"
264310

265311
let get_debug_info stream =
266-
Sexp.message "Metal stream debug info NYI" [ ("stream_id", sexp_of_int stream.stream_id) ]
312+
let num_pending_logs =
313+
match Hashtbl.find stream_logs stream.stream_id with None -> 0 | Some r -> List.length !r
314+
in
315+
Sexp.message "Metal stream debug info"
316+
[
317+
("stream_id", sexp_of_int stream.stream_id);
318+
("pending_shader_logs", sexp_of_int num_pending_logs);
319+
]
267320

268321
(* --- Copy Operations --- *)
269322
let from_host ~dst_ptr ~dst hosted =
@@ -385,12 +438,16 @@ end) : Ir.Backend_impl.Lowered_backend = struct
385438
"uint3 gid [[threadgroup_position_in_grid]]"; "uint3 lid [[thread_position_in_threadgroup]]";
386439
]
387440

388-
let includes = [ "<metal_stdlib>"; "<metal_math>"; "<metal_compute>"; "<metal_atomic>" ]
389-
441+
let includes =
442+
[ "<metal_stdlib>"; "<metal_math>"; "<metal_logging>"; "<metal_compute>"; "<metal_atomic>" ]
443+
390444
let metal_log_object_name = "custom_log" (* As used in logging_tests.ml *)
391-
let extra_declarations =
392-
[ "using namespace metal;";
393-
"constant os_log " ^ metal_log_object_name ^ "(\"com.custom_log.subsystem\", \"custom_category\");"
445+
446+
let extra_declarations =
447+
[
448+
"using namespace metal;";
449+
"constant os_log " ^ metal_log_object_name
450+
^ "(\"com.custom_log.subsystem\", \"custom_category\");";
394451
]
395452

396453
let typ_of_prec = function
@@ -490,15 +547,21 @@ end) : Ir.Backend_impl.Lowered_backend = struct
490547
let convert_precision ~from ~to_ =
491548
if Ops.equal_prec from to_ then ("", "") else ("(" ^ typ_of_prec to_ ^ ")(", ")")
492549

493-
let kernel_log_param = Some ("int", "log_id")
550+
let kernel_log_param = Some ("const int&", "log_id")
494551
let log_involves_file_management = false
495552

496553
let pp_log_statement ~log_param_c_expr_doc ~base_message_literal ~args_docs =
497554
let open PPrint in
498-
(* Metal os_log handles newlines directly. Prefix with captured_log_prefix and log_id for consistency. *)
499-
let format_string_literal =
500-
!Utils.captured_log_prefix ^ "%d: " ^ base_message_literal
555+
(* Metal os_log handles newlines directly. Prefix with captured_log_prefix and log_id for
556+
consistency. *)
557+
let base_message_literal =
558+
let with_ = if for_log_trace_tree then "$" else "\\n" in
559+
let res = String.substr_replace_all base_message_literal ~pattern:"\n" ~with_ in
560+
if for_log_trace_tree && String.is_suffix res ~suffix:"$" then
561+
String.drop_suffix res 1 ^ "\\n"
562+
else res
501563
in
564+
let format_string_literal = !Utils.captured_log_prefix ^ "%d: " ^ base_message_literal in
502565
let all_args =
503566
match log_param_c_expr_doc with
504567
| Some doc -> doc :: args_docs
@@ -507,13 +570,20 @@ end) : Ir.Backend_impl.Lowered_backend = struct
507570
group
508571
(string metal_log_object_name ^^ string ".log("
509572
^^ dquotes (string format_string_literal)
510-
^^ comma ^^ space ^^ separate (comma ^^ space) all_args ^^ rparen ^^ semi)
573+
^^ comma ^^ space
574+
^^ separate (comma ^^ space) all_args
575+
^^ rparen ^^ semi)
511576
end
512577

513578
let%diagn_sexp compile_metal_source ~name ~source ~device =
514579
let options = Me.CompileOptions.init () in
515-
Me.CompileOptions.set_language_version options Me.CompileOptions.LanguageVersion.version_3_0;
516-
if Utils.debug_log_from_routines () then Me.CompileOptions.set_enable_logging options true;
580+
if Utils.debug_log_from_routines () then (
581+
Me.CompileOptions.set_language_version options Me.CompileOptions.LanguageVersion.version_3_2;
582+
Me.CompileOptions.set_enable_logging options true
583+
) else (
584+
Me.CompileOptions.set_language_version options Me.CompileOptions.LanguageVersion.version_3_0
585+
(* Logging is disabled by default in CompileOptions, so no need to explicitly set it to false *)
586+
);
517587

518588
if Utils.with_runtime_debug () then (
519589
let metal_file = Utils.build_file (name ^ ".metal") in
@@ -587,9 +657,14 @@ end) : Ir.Backend_impl.Lowered_backend = struct
587657
let runner_label = get_name stream in
588658
let func = Me.Library.new_function_with_name library func_name in
589659
let pso, _ = Me.ComputePipelineState.on_device_with_function device func in
660+
(* let log_id_prefix_for_util = if Utils.debug_log_from_routines () then Int.to_string
661+
run_log_id ^ ": " else "" in *)
590662

591663
let work () : unit =
592664
[%log3_result "Launching", func_name, "on", runner_label, (run_log_id : int)];
665+
(* Unlike CUDA, we don\'t use Utils.add_log_processor here. Logs are captured by the LogState
666+
handler installed on the CommandQueue. They will be processed by Utils.log_trace_tree in
667+
`await`. *)
593668
try
594669
let command_buffer = Me.CommandBuffer.on_queue queue in
595670
let encoder = Me.ComputeCommandEncoder.on_buffer command_buffer in

0 commit comments

Comments
 (0)