Commit b9987fa

Add backend debugging and missing synchronization between epochs

1 parent 0159bfd

7 files changed: +44 −7 lines

arrayjit/lib/backend_intf.ml

Lines changed: 6 additions & 1 deletion

@@ -79,7 +79,6 @@ module type Device_config = sig
   val name : string
 end

-
 type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
   dev : 'dev;
   ordinal : int;
@@ -270,6 +269,12 @@ module type Backend_device_common = sig
   val get_used_memory : device -> int
   (** Returns (an upper bound of) the memory used for arrays, in bytes. *)

+  val get_global_debug_info : unit -> Sexp.t
+  (** Global debug information; backend-specific and might evolve independently on the backends. *)
+
+  val get_debug_info : stream -> Sexp.t
+  (** Per-stream debug information; backend-specific and might evolve independently on the backends *)
+
   val await : stream -> unit
   (** Blocks till the stream becomes idle, i.e. synchronizes the stream. *)

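Both new functions return a Sexp.t that callers pretty-print; a minimal sketch of the call pattern (hypothetical caller, mirroring the prints this commit adds in bin/moons_benchmark.ml and lib/train.ml below):

  (* Sketch: pretty-print the new debug reports for some module [Backend]
     satisfying [Backend_device_common] and one of its [stream]s. *)
  Stdlib.Format.printf "Global debug info: %a\n%!" Sexp.pp_hum
  @@ Backend.get_global_debug_info ();
  Stdlib.Format.printf "Stream debug info: %a\n%!" Sexp.pp_hum
  @@ Backend.get_debug_info stream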
arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 10 additions & 0 deletions

@@ -476,3 +476,13 @@ let%track3_sexp link_batch prior_context (code_batch : code_batch) ctx_arrays =
           Some task))
   in
   (lowered_bindings, procs)
+
+let get_global_debug_info () =
+  Sexp.message "cuda_global_debug"
+    [ ("live_streams", [%sexp_of: int] @@ Cudajit.Stream.get_total_live_streams ()) ]
+
+let get_debug_info (stream : stream) =
+  let tot, unr, unf = Cudajit.Stream.total_unreleased_unfinished_delimited_events stream.runner in
+  let i2s = [%sexp_of: int] in
+  Sexp.message "cuda_stream_debug"
+    [ ("total_events", i2s tot); ("unreleased_events", i2s unr); ("unfinished_events", i2s unf) ]

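Sexp.message builds a labeled list, (tag (field value) ...), so with illustrative counts the per-stream report renders along the lines of:

  (cuda_stream_debug (total_events 12) (unreleased_events 3) (unfinished_events 1))

Unreleased or unfinished event counts that keep rising across epochs would point at an event leak, which is presumably what this instrumentation is meant to expose.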
arrayjit/lib/cuda_backend.missing.ml

Lines changed: 2 additions & 0 deletions

@@ -71,4 +71,6 @@ let num_devices () = 0
 let suggested_num_streams Unimplemented_device = 0
 let get_ctx_stream Unimplemented_ctx = Unimplemented_stream
 let to_ordinal _stream = 0
+let get_global_debug_info () = Sexp.message "global_debug" []
+let get_debug_info Unimplemented_stream = Sexp.message "stream_debug" []
 let name = "cuda"

arrayjit/lib/schedulers.ml

Lines changed: 5 additions & 0 deletions

@@ -206,6 +206,9 @@ module Multicore (Backend : For_add_scheduler) :
     let stream = spinup_stream ~stream_id:!latest_stream_id in
     Stdlib.Gc.finalise cleanup_stream stream;
     stream
+
+  let get_global_debug_info () = Sexp.message "global_debug" []
+  let get_debug_info (stream : stream) = sexp_of_runner stream.runner
 end

 (** For debugging, allow [Sync_scheduler(...).suggested_num_streams] calls to return >1 numbers. *)
@@ -263,4 +266,6 @@ module Sync (Backend : For_add_scheduler) = struct
   let initialize = Backend.initialize
   let is_initialized = Backend.is_initialized
   let schedule_task _stream task = Task.run task
+  let get_global_debug_info () = Sexp.message "global_debug" []
+  let get_debug_info (stream : stream) = sexp_of_runner stream.runner
 end

bin/moons_benchmark.ml

Lines changed: 8 additions & 3 deletions

@@ -80,13 +80,16 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
   let weight_decay = 0.0002 in
   Arrayjit.Schedulers.sync_suggested_num_streams := num_streams;
   let module Backend = (val Arrayjit.Backends.fresh_backend ~backend_name ()) in
+  Stdlib.Format.printf "Initial backend global debug info: %a\n%!" Sexp.pp_hum
+  @@ Backend.get_global_debug_info ();
   let per_batch_callback ~at_batch:_ ~at_step:_ ~learning_rate:_ ~batch_loss:_ ~epoch_loss:_ =
     if Option.is_none !start_time then start_time := Some (Time_now.nanoseconds_since_unix_epoch ())
   in
   (* Tn.print_accessible_headers (); *)
   let per_epoch_callback ~at_step ~at_epoch ~learning_rate ~epoch_loss =
     Stdio.printf "Epoch=%d, step=%d, lr=%f, epoch loss=%f\n%!" at_epoch at_step learning_rate
-      epoch_loss
+      epoch_loss;
+
   in
   Backend.initialize Train.BT.Most_parallel_streams;
   let {
@@ -101,7 +104,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
   } =
     Train.example_train_loop ~seed ~batch_size ~init_lr ~max_num_streams:num_streams ~data_len
       ~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn ~weight_decay
-      ~per_batch_callback ~per_epoch_callback
+      ~per_batch_callback ~per_epoch_callback ~per_epoch_debug_streams:true
      (module Backend)
      ()
   in
@@ -177,6 +180,8 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
   }
   in
   Stdio.printf "\n\n%!";
+  Stdlib.Format.printf "Final backend global debug info: %a\n%!" Sexp.pp_hum
+  @@ Backend.get_global_debug_info ();
   result

 let _suspend () =
@@ -248,4 +253,4 @@ let benchmark benchmarks =
   List.map benchmarks ~f:(fun bench -> bench ())
   |> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout

-let () = benchmark _mem_benchmarks
+let () = benchmark _cuda_benchmarks

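Bracketing the run with the initial and final global reports makes stream lifetime visible; with the CUDA backend the output would look roughly like this (counts invented for illustration):

  Initial backend global debug info: (cuda_global_debug (live_streams 0))
  ...
  Final backend global debug info: (cuda_global_debug (live_streams 8))

A final live_streams count well above the requested number of streams would hint at a stream leak, one candidate explanation for the todo item about parallelism slowing down in later epochs.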
lib/train.ml

Lines changed: 11 additions & 2 deletions

@@ -456,7 +456,8 @@ type example_train_result = {

 let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init_lr ?lr_schedule
     ?(copy_to_merge = false) ?max_num_streams ~data_len ~epochs ~inputs ~outputs ~model ~loss_fn
-    ~weight_decay ?per_batch_callback ?per_epoch_callback (module Backend : Backend) () =
+    ~weight_decay ?per_batch_callback ?per_epoch_callback ?(per_epoch_debug_streams = false)
+    (module Backend : Backend) () =
   let module TDSL = Operation.TDSL in
   let module NTDSL = Operation.NTDSL in
   Rand.init seed;
@@ -528,7 +529,15 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
     epoch_losses := !epoch_loss :: !epoch_losses;
     Option.iter per_epoch_callback ~f:(fun f ->
         f ~at_step:!step_ref ~at_epoch:epoch ~learning_rate:learning_rate.@[0]
-          ~epoch_loss:!epoch_loss)
+          ~epoch_loss:!epoch_loss);
+    let debug_at pos =
+      Array.iter streams ~f:(fun s ->
+          Stdlib.Format.printf "Stream %d debug %s:@ %a\n%!" s.stream_id pos Sexp.pp_hum
+          @@ Backend.get_debug_info s)
+    in
+    if per_epoch_debug_streams then debug_at "before sync";
+    Array.iter streams ~f:Backend.await;
+    if per_epoch_debug_streams then debug_at "after sync"
   done;
   let%op model_result = model "infer" in
   let infer_fwd =

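The Array.iter streams ~f:Backend.await line is the "missing synchronization between epochs" from the commit title: without it, the next epoch's work could be scheduled while streams are still busy with the previous one. A self-contained sketch of the resulting before/after-sync debug pattern, with a toy stand-in for the backend (all names here are stand-ins, not the real API):

  open Base

  type stream = { stream_id : int; mutable pending : int }

  (* Toy backend: a real [await] blocks until the stream is idle. *)
  let await s = s.pending <- 0
  let get_debug_info s = Sexp.message "stream_debug" [ ("pending", sexp_of_int s.pending) ]

  let () =
    let streams = [| { stream_id = 0; pending = 2 }; { stream_id = 1; pending = 0 } |] in
    let debug_at pos =
      Array.iter streams ~f:(fun s ->
          Stdlib.Format.printf "Stream %d debug %s:@ %a\n%!" s.stream_id pos Sexp.pp_hum
          @@ get_debug_info s)
    in
    debug_at "before sync";
    Array.iter streams ~f:await;
    debug_at "after sync"

Comparing the two dumps shows exactly which streams still had work queued at the epoch boundary.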
todo.md

Lines changed: 2 additions & 1 deletion

@@ -2,4 +2,5 @@
 (B) bin/moons_benchmark with the cc backend crashes with half-prec overflow {cm:2024-11-24}
 (B) remove syncing from the data parallel algo: stream-to-stream syncing is now automatic {cm:2024-11-23}
 (A) cuda backend crashes in bin/moons_benchmark {cm:2024-11-22}
-(B) figure out why cuda backend parallelism slows down in later epochs
+(B) figure out why cuda backend parallelism slows down in later epochs {cm:2024-11-25}
+(A) Ensure that reading from host on CPU performs required synchronization
