@@ -80,13 +80,16 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
   let weight_decay = 0.0002 in
   Arrayjit.Schedulers.sync_suggested_num_streams := num_streams;
   let module Backend = (val Arrayjit.Backends.fresh_backend ~backend_name ()) in
+  Stdlib.Format.printf "Initial backend global debug info: %a\n%!" Sexp.pp_hum
+  @@ Backend.get_global_debug_info ();
   let per_batch_callback ~at_batch:_ ~at_step:_ ~learning_rate:_ ~batch_loss:_ ~epoch_loss:_ =
     if Option.is_none !start_time then start_time := Some (Time_now.nanoseconds_since_unix_epoch ())
   in
   (* Tn.print_accessible_headers (); *)
   let per_epoch_callback ~at_step ~at_epoch ~learning_rate ~epoch_loss =
     Stdio.printf "Epoch=%d, step=%d, lr=%f, epoch loss=%f\n%!" at_epoch at_step learning_rate
-      epoch_loss
+      epoch_loss;
+
   in
   Backend.initialize Train.BT.Most_parallel_streams;
   let {
@@ -101,7 +104,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
   } =
     Train.example_train_loop ~seed ~batch_size ~init_lr ~max_num_streams:num_streams ~data_len
       ~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn ~weight_decay
-      ~per_batch_callback ~per_epoch_callback
+      ~per_batch_callback ~per_epoch_callback ~per_epoch_debug_streams:true
       (module Backend)
       ()
   in
@@ -177,6 +180,8 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
   }
   in
   Stdio.printf "\n\n%!";
+  Stdlib.Format.printf "Final backend global debug info: %a\n%!" Sexp.pp_hum
+  @@ Backend.get_global_debug_info ();
   result

 let _suspend () =
@@ -248,4 +253,4 @@ let benchmark benchmarks =
   List.map benchmarks ~f:(fun bench -> bench ())
   |> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout

-let () = benchmark _mem_benchmarks
+let () = benchmark _cuda_benchmarks
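
The added printf pairs bracket the training run with dumps of the backend's global debug state. A minimal sketch of that pattern as a reusable helper, assuming (as the diff's use with Sexp.pp_hum suggests) that the dump function returns a Base.Sexp.t; bracket_debug is a hypothetical name, not part of the repository:

open Base

(* Hypothetical helper, not part of the commit: bracket a workload [f] with
   labelled s-expression debug dumps, printing once before and once after.
   [dump] stands in for Backend.get_global_debug_info, assumed here to
   return a Sexp.t because the diff formats it with Sexp.pp_hum. *)
let bracket_debug ~label ~(dump : unit -> Sexp.t) f =
  Stdlib.Format.printf "Initial %s: %a\n%!" label Sexp.pp_hum (dump ());
  let result = f () in
  Stdlib.Format.printf "Final %s: %a\n%!" label Sexp.pp_hum (dump ());
  result

Under those assumptions, the two dumps in the diff would collapse to a single call such as bracket_debug ~label:"backend global debug info" ~dump:Backend.get_global_debug_info (fun () -> Train.example_train_loop ... ).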