@@ -33,30 +33,32 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
3333 Tensor. default_value_prec := value_prec;
3434 Tensor. default_grad_prec := grad_prec;
3535 Utils. settings.output_debug_files_in_build_directory < - true ;
36+ (* This will only log from routines if log-level is high enough. *)
3637 Utils. settings.debug_log_from_routines < - true ;
3738 Rand. init (* seed *) 0 ;
3839 (* let hid_2_3 = 8 in let hid_4_5 = 4 in *)
3940 let hid_dim = 16 in
4041 (* let hid_dim = 4 in *)
41- let len = batch_size * 20 in
42- (* let epochs = 100 in *)
43- let epochs = 20 in
44- (* let epochs = 10 in *)
42+ let data_len = 3 * 1024 in
43+ let flat_len = data_len / 2 in
44+ (* Note: [minibatch_size = batch_size / num_devices] is the actual per-device batch used. *)
45+ (* let epochs = 20 in *)
46+ let epochs = 10 in
4547 (* let epochs = 1 in *)
4648 let init_lr = 0.1 in
4749 let noise () = Rand. float_range (- 0.1 ) 0.1 in
4850 let moons_flat =
49- Array. concat_map (Array. create ~len () )
51+ Array. concat_map (Array. create ~len: flat_len () )
5052 ~f:
5153 Float. (
5254 fun () ->
53- let i = Rand. int len in
54- let v = of_int i * pi / of_int len in
55+ let i = Rand. int flat_len in
56+ let v = of_int i * pi / of_int flat_len in
5557 let c = cos v and s = sin v in
5658 [| c + noise () ; s + noise () ; 1.0 - c + noise () ; 0.5 - s + noise () |])
5759 in
5860 let moons_flat ~b = TDSL. init_const ~l: " moons_flat" ~b ~o: [ 2 ] moons_flat in
59- let moons_classes = Array. init (len * 2 ) ~f: (fun i -> if i % 2 = 0 then 1. else - 1. ) in
61+ let moons_classes = Array. init data_len ~f: (fun i -> if i % 2 = 0 then 1. else - 1. ) in
6062 let moons_classes ~b = TDSL. init_const ~l: " moons_classes" ~b ~o: [ 1 ] moons_classes in
6163
6264 let init_time = Time_now. nanoseconds_since_unix_epoch () in
@@ -80,7 +82,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
8082 let module Backend = (val backend) in
8183 Backend. initialize Train.BT. Most_parallel_devices ;
8284 let inputs, outputs, model_result, infer_callback, batch_losses, epoch_losses, learning_rates =
83- Train. example_train_loop ~seed ~batch_size ~init_lr ~max_num_devices: num_devices ~data_len: len
85+ Train. example_train_loop ~seed ~batch_size ~init_lr ~max_num_devices: num_devices ~data_len
8486 ~epochs ~inputs: moons_flat ~outputs: moons_classes ~model: mlp ~loss_fn ~weight_decay
8587 ~per_batch_callback ~per_epoch_callback
8688 (module Backend )
@@ -164,24 +166,12 @@ let _suspend () =
164166 @@ classify_moons ~seed: 0 ~on_device: true ~inlining_cutoff: 3 ~num_devices: 8 ~batch_size: 16
165167 ~backend_name: " gccjit" ~value_prec: CDSL. single ~grad_prec: CDSL. double ()
166168
167- let _cpu_benchmarks =
168- List. concat_map [ 0 ; 1 ; 2 ; 3 ] ~f: (fun inlining_cutoff ->
169- List. concat_map [ 1 ; 2 ; 5 ; 8 ; 10 ; 16 (* ; 20 *) ] ~f: (fun num_devices ->
170- List. concat_map [ 120 ; 160 (* ; 320; 640; 1280 *) ] ~f: (fun batch_size ->
171- List. concat_map [ 0 ; 1 (* ; 2; 3; 4 *) ] ~f: (fun seed ->
172- List. concat_map [ " gccjit" (* *; "cuda" *) ] ~f: (fun backend_name ->
173- [
174- classify_moons ~seed ~on_device: true ~inlining_cutoff ~num_devices
175- ~batch_size ~backend_name ~value_prec: CDSL. single ~grad_prec: CDSL. single;
176- ])))))
177-
178169let cuda_benchmarks =
179- List. concat_map [ 0 ; (* 1; 2; *) 3 ] ~f: (fun inlining_cutoff ->
180- List. concat_map [ 1 ; 2 ; 5 ; 8 ; 10 (* ; 16; 20; 30; 32; 40 ; 64 *) ] ~f: (fun num_devices ->
181- List. concat_map [ 120 ; 160 (* ; 320; 640; 1280 *) ] ~f: (fun batch_size ->
170+ List. concat_map [ 0 ; 1 ; (* 2; *) 3 ] ~f: (fun inlining_cutoff ->
171+ List. concat_map [ 1 ; 3 ; 6 ; 12 (* ; 16; 32 ; 64 *) ] ~f: (fun num_devices ->
172+ List. concat_map [ 64 ; 128 (* ; 256 *) ] ~f: (fun batch_size ->
182173 List. concat_map [ 0 ; 1 (* ; 2; 3; 4 *) ] ~f: (fun seed ->
183- List. concat_map [ (* "gccjit" ; *) " cc" (* ; "cuda" *) ]
184- ~f: (fun backend_name ->
174+ List. concat_map [ (* "gccjit" ; *) " cc" ; " cuda" ] ~f: (fun backend_name ->
185175 [
186176 classify_moons ~seed ~on_device: true ~inlining_cutoff ~num_devices
187177 ~batch_size ~backend_name ~value_prec: CDSL. single ~grad_prec: CDSL. single;
0 commit comments