@@ -15,15 +15,15 @@ let _get_local_debug_runtime = Arrayjit.Utils._get_local_debug_runtime
1515[%% global_debug_log_level 9 ]
1616[%% global_debug_log_level_from_env_var " OCANNL_LOG_LEVEL" ]
1717
18- let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~backend_name
18+ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~backend_name
1919 ~value_prec ~grad_prec () =
2020 [% track_sexp
2121 let _debug : string = " started" in
2222 (fun (started : unit ) -> started) () ];
2323 (* ignore seed; *)
2424 let bench_title =
2525 [% string
26- " seed %{seed#Int}, inline %{inlining_cutoff#Int}, parallel %{num_devices #Int}, batch \
26+ " seed %{seed#Int}, inline %{inlining_cutoff#Int}, parallel %{num_streams #Int}, batch \
2727 %{batch_size#Int}, backend %{backend_name}, val prec %{Ops.prec_string value_prec}, grad \
2828 prec %{Ops.prec_string grad_prec}" ]
2929 in
@@ -45,11 +45,11 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
4545 (* TINY for debugging: *)
4646 (* let data_len = 3 * 4 in *)
4747 let flat_len = data_len / 2 in
48- (* Note: [minibatch_size = batch_size / num_devices ] is the actual per-device batch used. *)
48+ (* Note: [minibatch_size = batch_size / num_streams ] is the actual per-device batch used. *)
4949 (* let epochs = 200 in *)
50- let epochs = 100 in
50+ (* let epochs = 100 in *)
5151 (* TINY for debugging: *)
52- (* let epochs = 2 in *)
52+ let epochs = 2 in
5353 (* let epochs = 1 in *)
5454 (* let init_lr = 0.1 in *)
5555 let init_lr = 0.01 in
@@ -78,7 +78,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
7878 let % op loss_fn ~output ~expectation = ?/ (! ..1 - (expectation *. output)) in
7979 let start_time = ref None in
8080 let weight_decay = 0.0002 in
81- Arrayjit.Backends. sync_suggested_num_streams := num_devices ;
81+ Arrayjit.Backends. sync_suggested_num_streams := num_streams ;
8282 let backend = Arrayjit.Backends. fresh_backend ~backend_name () in
8383 let per_batch_callback ~at_batch :_ ~at_step :_ ~learning_rate :_ ~batch_loss :_ ~epoch_loss :_ =
8484 if Option. is_none ! start_time then start_time := Some (Time_now. nanoseconds_since_unix_epoch () )
@@ -90,8 +90,17 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
9090 in
9191 let module Backend = (val backend) in
9292 Backend. initialize Train.BT. Most_parallel_streams ;
93- let inputs, outputs, model_result, infer_callback, batch_losses, epoch_losses, learning_rates =
94- Train. example_train_loop ~seed ~batch_size ~init_lr ~max_num_streams: num_devices ~data_len
93+ let {
94+ Train. inputs;
95+ outputs;
96+ model_result;
97+ infer_callback;
98+ batch_losses;
99+ epoch_losses;
100+ learning_rates;
101+ used_memory;
102+ } =
103+ Train. example_train_loop ~seed ~batch_size ~init_lr ~max_num_streams: num_streams ~data_len
95104 ~epochs ~inputs: moons_flat ~outputs: moons_classes ~model: mlp ~loss_fn ~weight_decay
96105 ~per_batch_callback ~per_epoch_callback
97106 (module Backend )
@@ -161,8 +170,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
161170 {
162171 bench_title;
163172 time_in_sec;
164- (* FIXME: implement total mem assessment. *)
165- mem_in_bytes = 0 ;
173+ mem_in_bytes = used_memory;
166174 result_label = " init time in sec, min loss, last loss" ;
167175 result =
168176 [% sexp_of: float * float * float ]
@@ -176,11 +184,11 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
176184
177185let _suspend () =
178186 ignore
179- @@ classify_moons ~seed: 0 ~on_device: true ~inlining_cutoff: 3 ~num_devices : 8 ~batch_size: 16
187+ @@ classify_moons ~seed: 0 ~on_device: true ~inlining_cutoff: 3 ~num_streams : 8 ~batch_size: 16
180188 ~backend_name: " gccjit" ~value_prec: CDSL. single ~grad_prec: CDSL. double ()
181189
182- let cuda_benchmarks =
183- List. concat_map [ 1 ; 3 ; 6 ; 12 ; 16 ; 20 (* 32; 64 *) ] ~f: (fun num_devices ->
190+ let _cuda_benchmarks =
191+ List. concat_map [ 1 ; 3 ; 6 ; 12 ; 16 ; 20 (* 32; 64 *) ] ~f: (fun num_streams ->
184192 List. concat_map
185193 [
186194 (* TINY for debugging: *)
@@ -194,7 +202,26 @@ let cuda_benchmarks =
194202 List. concat_map [ (* CDSL.double; *) CDSL. single (* ; CDSL.half *) ]
195203 ~f: (fun value_prec ->
196204 [
197- classify_moons ~seed ~on_device: true ~inlining_cutoff ~num_devices
205+ classify_moons ~seed ~on_device: true ~inlining_cutoff ~num_streams
206+ ~batch_size ~backend_name ~value_prec ~grad_prec: value_prec;
207+ ]))))))
208+
209+ let _mem_benchmarks =
210+ List. concat_map [ 1 ; 3 ; 6 ; 12 ; 16 (* ; 20; 32; 64 *) ] ~f: (fun num_streams ->
211+ List. concat_map
212+ [
213+ (* TINY for debugging: *)
214+ (* 3 * 2 *)
215+ 3 * 5 * 16 (* ; 3 * 5 * 32; 3 * 5 * 64 *) ;
216+ ]
217+ ~f: (fun batch_size ->
218+ List. concat_map [ 0 ; (* 1; 2; *) 3 ] ~f: (fun inlining_cutoff ->
219+ List. concat_map [ (* 1; 3; *) 7 (* *) ] ~f: (fun seed ->
220+ List. concat_map [ (* "gccjit" ; *) " cc" (* ; "cuda" *) ] ~f: (fun backend_name ->
221+ List. concat_map [ CDSL. double; CDSL. single (* ; CDSL.half *) ]
222+ ~f: (fun value_prec ->
223+ [
224+ classify_moons ~seed ~on_device: true ~inlining_cutoff ~num_streams
198225 ~batch_size ~backend_name ~value_prec ~grad_prec: value_prec;
199226 ]))))))
200227
@@ -204,7 +231,7 @@ let cuda_benchmarks =
204231 (nth - 1) *)
205232
206233let fixed_seed_search seed =
207- classify_moons ~seed ~on_device: true ~inlining_cutoff: 3 ~num_devices : 1 ~batch_size: 20
234+ classify_moons ~seed ~on_device: true ~inlining_cutoff: 3 ~num_streams : 1 ~batch_size: 20
208235 ~backend_name: " cuda" ~value_prec: CDSL. single ~grad_prec: CDSL. single ()
209236
210237let _suspended () =
@@ -213,8 +240,8 @@ let _suspended () =
213240(* let () = List.map benchmarks ~f:(nth_best 2) |> PrintBox_utils.table |> PrintBox_text.output
214241 Stdio.stdout *)
215242
216- let benchmark () =
217- List. map cuda_benchmarks ~f: (fun bench -> bench () )
243+ let benchmark benchmarks =
244+ List. map benchmarks ~f: (fun bench -> bench () )
218245 |> PrintBox_utils. table |> PrintBox_text. output Stdio. stdout
219246
220- let () = benchmark ()
247+ let () = benchmark _mem_benchmarks
0 commit comments