Skip to content

Commit e85a09d

Browse files
committed
Debugging tweaks
1 parent 9c29e0c commit e85a09d

File tree

6 files changed

+66
-47
lines changed

6 files changed

+66
-47
lines changed

arrayjit/lib/assignments.ml

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,7 @@ let is_total ~initialize_neutral ~projections =
8080
(** Returns materialized nodes in the sense of {!Tnode.is_in_context}. NOTE: it should be called
8181
after compilation and ideally after linking with the relevant contexts; otherwise, it is an
8282
under-estimate. *)
83-
let context_nodes ~use_host_memory asgns =
83+
let%debug3_sexp context_nodes ~(use_host_memory : bool) (asgns : t) : Tn.t_set =
8484
let open Utils.Set_O in
8585
let empty = Set.empty (module Tn) in
8686
let one tn =
@@ -117,12 +117,12 @@ let%diagn1_sexp to_low_level code =
117117
if not (Array.length idcs = Array.length (Lazy.force tn.Tn.dims)) then
118118
[%log
119119
"get",
120-
"a=",
121-
(tn : Tn.t),
122-
":",
123-
Tn.label tn,
124-
(idcs : Indexing.axis_index array),
125-
(Lazy.force tn.dims : int array)];
120+
"a=",
121+
(tn : Tn.t),
122+
":",
123+
Tn.label tn,
124+
(idcs : Indexing.axis_index array),
125+
(Lazy.force tn.dims : int array)];
126126
assert (Array.length idcs = Array.length (Lazy.force tn.Tn.dims));
127127
match buffer with
128128
| Node tn -> Low_level.Get (tn, idcs)
@@ -133,12 +133,12 @@ let%diagn1_sexp to_low_level code =
133133
if not (Array.length idcs = Array.length (Lazy.force tn.Tn.dims)) then
134134
[%log
135135
"set",
136-
"a=",
137-
(tn : Tn.t),
138-
":",
139-
Tn.label tn,
140-
(idcs : Indexing.axis_index array),
141-
(Lazy.force tn.dims : int array)];
136+
"a=",
137+
(tn : Tn.t),
138+
":",
139+
Tn.label tn,
140+
(idcs : Indexing.axis_index array),
141+
(Lazy.force tn.dims : int array)];
142142
assert (Array.length idcs = Array.length (Lazy.force tn.Tn.dims));
143143
Low_level.Set { tn; idcs; llv; debug = "" }
144144
in

arrayjit/lib/backends.ml

Lines changed: 26 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ let check_merge_buffer stream ~code_node =
2222
^ ", expected by code: " ^ name code_node)
2323

2424
module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing) = struct
25-
let[@landmark] wait_for_all ctx streams tn =
25+
let wait_for_all ctx streams tn =
2626
let s = ctx.stream in
2727
Hashtbl.update_and_return streams tn
2828
~f:
@@ -31,15 +31,15 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
3131
|> List.iter ~f:(fun (work_stream, e) ->
3232
if not (equal_stream work_stream s) then Backend.will_wait_for ctx e)
3333

34-
let[@landmark] wait_for_ready ~dst ~src tn =
34+
let wait_for_ready ~dst ~src tn =
3535
let s = src.stream in
3636
let d = dst.stream in
3737
(* TODO: maybe it's worthwhile to clean up s.updating_for every now and then. *)
3838
Hashtbl.find s.updating_for tn
3939
|> Option.iter ~f:(fun upd_e ->
4040
if not (equal_stream s d || Backend.is_done upd_e) then Backend.will_wait_for dst upd_e)
4141

42-
let[@landmark] update_writer_event ?e ?from s tn =
42+
let update_writer_event ?e ?from s tn =
4343
let e = Option.value_or_thunk e ~default:(fun () -> Backend.all_work s) in
4444
let f l = (s, e) :: Option.value ~default:[] l in
4545
(match (from, tn) with
@@ -59,22 +59,24 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
5959
(* Note: the previous event does not need to be done! *)
6060
s.updating_for_merge_buffer <- Some (tn, Some e)
6161

62-
let%track2_l_sexp[@landmark] from_host (ctx : Backend.context) tn =
62+
let%track2_l_sexp from_host (ctx : Backend.context) tn =
6363
match (tn, Map.find ctx.ctx_arrays tn) with
6464
| { Tn.array = (lazy (Some hosted)); _ }, Some dst ->
6565
wait_for_all ctx ctx.stream.reader_streams tn;
6666
[%log "copying", Tn.debug_name tn, "to", (dst : Backend.buffer_ptr), "from host"];
67+
(* Stdio.printf "copying: %s from_host\n" (Tn.debug_name tn); *)
6768
Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
6869
update_writer_event ~from:`Host ctx.stream @@ Node tn;
6970
true
7071
| _ -> false
7172

72-
let%track2_l_sexp[@landmark] to_host (ctx : Backend.context) (tn : Tn.t) =
73+
let%track2_l_sexp to_host (ctx : Backend.context) (tn : Tn.t) =
7374
match (tn, Map.find ctx.ctx_arrays tn) with
7475
| { Tn.array = (lazy (Some hosted)); _ }, Some src ->
7576
if Tn.potentially_cross_stream tn then
7677
wait_for_all ctx ctx.stream.device.shared_writer_streams tn;
7778
[%log "copying", Tn.debug_name tn, "at", (src : Backend.buffer_ptr), "to host"];
79+
(* Stdio.printf "copying: %s to_host\n" (Tn.debug_name tn); *)
7880
Backend.to_host ~src_ptr:src ~src:ctx hosted;
7981
let s = ctx.stream in
8082
let e = Backend.all_work s in
@@ -83,8 +85,8 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
8385
true
8486
| _ -> false
8587

86-
let%diagn2_l_sexp[@landmark] device_to_device (tn : Tn.t) ~into_merge_buffer
87-
~(dst : Backend.context) ~(src : Backend.context) =
88+
let%diagn2_l_sexp device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : Backend.context)
89+
~(src : Backend.context) =
8890
let ordinal_of ctx = ctx.stream.device.ordinal in
8991
let name_of ctx = Backend.(get_name ctx.stream) in
9092
let same_device = ordinal_of dst = ordinal_of src in
@@ -116,15 +118,17 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
116118
Backend.(
117119
device_to_device tn ~into_merge_buffer ~dst_ptr:None ~dst ~src_ptr:s_arr ~src);
118120
dst.stream.updating_for_merge_buffer <- Some (tn, None);
119-
let[@landmark] merge_task () = Task.run task in
121+
let merge_task () = Task.run task in
120122
merge_task ();
121123
update_writer_event ~from:(`Src src.stream) dst.stream @@ Merge_buffer tn;
122124
[%log "streaming into merge buffer", Tn.debug_name tn, "from", name_of src];
123125
true)
124126

125-
let%track2_l_sexp sync_routine r =
127+
type r = Backend.context routine [@@deriving sexp_of]
128+
129+
let%track2_l_sexp sync_routine (r : r) : r =
126130
let s = r.context.stream in
127-
let[@landmark] pre () =
131+
let pre () =
128132
Set.iter r.inputs ~f:(fun tn ->
129133
if Tn.potentially_cross_stream tn then
130134
Option.iter (Hashtbl.find s.device.shared_writer_streams tn) ~f:(fun data ->
@@ -135,13 +139,13 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
135139
else Hashtbl.remove s.device.shared_writer_streams tn)
136140
(* Since merge buffers are always per-stream, no need to check r.merge_buffer_input. *)
137141
in
138-
let[@landmark] post () =
142+
let post () =
139143
let e = Backend.all_work s in
140144
Set.iter r.outputs ~f:(fun tn -> update_writer_event ~e s @@ Node tn)
141145
in
142146
{ r with schedule = Task.(prepend ~work:pre @@ append ~work:post r.schedule) }
143147

144-
let[@landmark] sync_device device =
148+
let sync_device device =
145149
Utils.weak_iter device.streams ~f:Backend.await;
146150
Hashtbl.clear device.host_writing_streams;
147151
Hashtbl.clear device.host_reading_streams;
@@ -180,15 +184,16 @@ let lower_batch_assignments ?names ?occupancy bindings asgns_l =
180184
Some (Assignments.lower ~unoptim_ll_source ~ll_source ~cd_source ~name bound asgns) )
181185
else (None, None))
182186

183-
let verify_prior_context ~use_host_memory ~ctx_arrays ~from_prior_context =
187+
let%debug3_sexp verify_prior_context ~use_host_memory ~ctx_arrays ~from_prior_context : unit =
184188
Set.iter from_prior_context ~f:(fun tn ->
185189
if
186190
(* Err on the safe side. *)
187191
Option.value ~default:false (Tn.is_in_context ~use_host_memory tn)
188192
&& not (Option.is_some @@ Map.find ctx_arrays tn)
189193
then raise @@ Utils.User_error ("The linked context lacks node " ^ Tnode.debug_name tn))
190194

191-
let from_prior_context_batch ~use_host_memory comps =
195+
let%debug3_sexp from_prior_context_batch ~use_host_memory (comps : Assignments.comp option array) :
196+
Tn.t_set =
192197
Array.filter_map comps ~f:(fun comp ->
193198
Option.map comp ~f:(fun comp ->
194199
Set.diff
@@ -279,20 +284,20 @@ module Add_device
279284
in
280285
(Option.value_exn ~here:[%here] bindings, schedules)
281286

282-
let[@landmark] from_host ~dst_ptr ~dst hosted =
287+
let from_host ~dst_ptr ~dst hosted =
283288
let work () = host_to_buffer hosted ~dst:dst_ptr in
284289
(* TODO: pass description to from_host. *)
285290
schedule_task dst.stream
286291
(Task.Task
287292
{ context_lifetime = dst; description = "from_host on " ^ get_name dst.stream; work })
288293

289-
let[@landmark] to_host ~src_ptr ~src hosted =
294+
let to_host ~src_ptr ~src hosted =
290295
let work () = buffer_to_host hosted ~src:src_ptr in
291296
(* TODO: pass description to to_host. *)
292297
schedule_task src.stream
293298
(Task.Task { context_lifetime = src; description = "to_host on " ^ get_name src.stream; work })
294299

295-
let[@landmark] device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
300+
let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
296301
let s = dst.stream in
297302
let size_in_bytes = Tnode.size_in_bytes tn in
298303
let work =
@@ -343,15 +348,16 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
343348
}
344349
[@@deriving sexp_of]
345350

346-
let compile ?shared ?name bindings comp : code =
351+
let%debug3_sexp compile ?shared ?name bindings (comp : Assignments.comp) : code =
347352
let name, lowered = lower_assignments ?name bindings comp.Assignments.asgns in
348353
let code = compile ?shared ~name bindings lowered in
349354
let from_prior_context =
350355
Set.diff (Assignments.context_nodes ~use_host_memory comp.asgns) comp.embedded_nodes
351356
in
352357
{ from_prior_context; name; lowered; code; expected_merge_node = lowered.Low_level.merge_node }
353358

354-
let compile_batch ?shared ?names ?occupancy bindings comps =
359+
let%debug3_sexp compile_batch ?shared ?names ?occupancy bindings (comps : Assignments.comp array) :
360+
code_batch =
355361
let names, lowereds =
356362
lower_batch_assignments ?names ?occupancy bindings
357363
@@ Array.map comps ~f:(fun c -> c.Assignments.asgns)
@@ -479,7 +485,7 @@ let reinitialize (module Backend : Backend) config =
479485
Stdlib.Gc.full_major ();
480486
Backend.initialize config)
481487

482-
let[@landmark] finalize (type buffer_ptr dev runner event)
488+
let finalize (type buffer_ptr dev runner event)
483489
(module Backend : Backend
484490
with type buffer_ptr = buffer_ptr
485491
and type dev = dev

arrayjit/lib/c_syntax.ml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ struct
6161

6262
(* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim
6363
-> idx + (offset * dim)) *)
64-
let%diagn_sexp compile_globals ppf =
64+
let%debug3_sexp compile_globals ppf : Tn.t Hash_set.t =
6565
let open Stdlib.Format in
6666
let is_global = Hash_set.create (module Tn) in
6767
fprintf ppf {|@[<v 0>%a@,/* Global declarations. */@,|} (pp_print_list pp_print_string)

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -172,24 +172,28 @@ let suggested_num_streams device =
172172
| For_parallel_copying -> 1 + (cuda_properties device).async_engine_count
173173
| Most_parallel_streams -> (cuda_properties device).multiprocessor_count
174174

175-
let[@landmark] await stream : unit =
175+
let await stream : unit =
176176
set_ctx stream.device.dev.primary_context;
177177
Cu.Stream.synchronize stream.runner;
178178
Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ())
179179

180180
let is_idle stream = Cu.Stream.is_ready stream.runner
181181

182-
let[@landmark] from_host ~dst_ptr ~dst hosted =
182+
let from_host ~dst_ptr ~dst hosted =
183+
(* Stdio.printf "run: from_host on backend:0:%d\n" dst.stream.stream_id; *)
183184
set_ctx @@ ctx_of dst;
184185
let f src = Cu.Stream.memcpy_H_to_D ~dst:dst_ptr ~src dst.stream.runner in
185186
Ndarray.map { f } hosted
186187

187-
let[@landmark] to_host ~src_ptr ~src hosted =
188+
let to_host ~src_ptr ~src hosted =
189+
(* Stdio.printf "run: to_host on backend:0:%d\n" src.stream.stream_id; *)
188190
set_ctx @@ ctx_of src;
189191
let f dst = Cu.Stream.memcpy_D_to_H ~dst ~src:src_ptr src.stream.runner in
190192
Ndarray.map { f } hosted
191193

192-
let[@landmark] device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
194+
let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
195+
(* Stdio.printf "run: device_to_device %s dst backend:0:%d src backend:0:%d\n" (Tn.debug_name tn)
196+
dst.stream.stream_id src.stream.stream_id; *)
193197
let dev = dst.stream.device in
194198
let same_device = dev.ordinal = src.stream.device.ordinal in
195199
let size_in_bytes = Tn.size_in_bytes tn in

arrayjit/lib/task.ml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ type t =
1313

1414
let describe (Task task) = task.description
1515

16-
let%diagn_l_sexp run (Task task) =
16+
let%debug3_l_sexp run (Task task) : unit =
1717
[%log_result "run", task.description];
1818
task.work ()
1919

arrayjit/lib/tnode.ml

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -145,13 +145,13 @@ let log_debug_info ~from_log_level tn =
145145
from_log_level (debug_name tn);
146146
[%log
147147
"id:",
148-
(tn.id : int),
149-
"label:",
150-
(tn.label : string list),
151-
"mem:",
152-
debug_memory_mode tn.memory_mode,
153-
"backends:",
154-
(tn.backend_info : Sexp.t)];
148+
(tn.id : int),
149+
"label:",
150+
(tn.label : string list),
151+
"mem:",
152+
debug_memory_mode tn.memory_mode,
153+
"backends:",
154+
(tn.backend_info : Sexp.t)];
155155
if Lazy.is_val tn.array then
156156
match tn.array with
157157
| (lazy None) -> [%log "<not-on-host>"]
@@ -190,7 +190,7 @@ let is_materialized_force tn provenance =
190190

191191
(* Unlike the [known_] functions which can only change from [false] to [true], [is_in_context
192192
~use_host_memory tn] is more precise. Generally, it can only change away from [None]. *)
193-
let is_in_context ~use_host_memory tn =
193+
let%debug3_sexp is_in_context ~(use_host_memory : bool) (tn : t) : bool option =
194194
match tn.memory_mode with
195195
| Some (Hosted (Changed_on_devices Per_stream), _) -> Some true
196196
| Some ((Materialized | Hosted Nonconstant), _) when not use_host_memory -> Some true
@@ -404,6 +404,15 @@ let hash nd = Int.hash nd.id
404404
let hash_fold_t acc nd = hash_fold_int acc nd.id
405405
let hash_t = hash
406406

407+
module Comp = struct
408+
type nonrec t = t
409+
type nonrec comparator_witness = comparator_witness
410+
end
411+
412+
type t_set = Set.M(Comp).t
413+
414+
let sexp_of_t_set s = [%sexp_of: t Sequence.t] @@ Set.to_sequence s
415+
407416
let get_exn a =
408417
match a.array with
409418
| (lazy (Some nd)) -> nd

0 commit comments

Comments
 (0)