Skip to content

Commit f654955

Browse files
committed
moons_benchmark: modified the settings sweep
1 parent 41e59a6 commit f654955

File tree

2 files changed

+21
-30
lines changed

2 files changed

+21
-30
lines changed

bin/moons_benchmark.ml

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -33,30 +33,32 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
3333
Tensor.default_value_prec := value_prec;
3434
Tensor.default_grad_prec := grad_prec;
3535
Utils.settings.output_debug_files_in_build_directory <- true;
36+
(* This will only log from routines if the log level is high enough. *)
3637
Utils.settings.debug_log_from_routines <- true;
3738
Rand.init (* seed *) 0;
3839
(* let hid_2_3 = 8 in let hid_4_5 = 4 in *)
3940
let hid_dim = 16 in
4041
(* let hid_dim = 4 in *)
41-
let len = batch_size * 20 in
42-
(* let epochs = 100 in *)
43-
let epochs = 20 in
44-
(* let epochs = 10 in *)
42+
let data_len = 3 * 1024 in
43+
let flat_len = data_len / 2 in
44+
(* Note: [minibatch_size = batch_size / num_devices] is the actual per-device batch size used. *)
45+
(* let epochs = 20 in *)
46+
let epochs = 10 in
4547
(* let epochs = 1 in *)
4648
let init_lr = 0.1 in
4749
let noise () = Rand.float_range (-0.1) 0.1 in
4850
let moons_flat =
49-
Array.concat_map (Array.create ~len ())
51+
Array.concat_map (Array.create ~len:flat_len ())
5052
~f:
5153
Float.(
5254
fun () ->
53-
let i = Rand.int len in
54-
let v = of_int i * pi / of_int len in
55+
let i = Rand.int flat_len in
56+
let v = of_int i * pi / of_int flat_len in
5557
let c = cos v and s = sin v in
5658
[| c + noise (); s + noise (); 1.0 - c + noise (); 0.5 - s + noise () |])
5759
in
5860
let moons_flat ~b = TDSL.init_const ~l:"moons_flat" ~b ~o:[ 2 ] moons_flat in
59-
let moons_classes = Array.init (len * 2) ~f:(fun i -> if i % 2 = 0 then 1. else -1.) in
61+
let moons_classes = Array.init data_len ~f:(fun i -> if i % 2 = 0 then 1. else -1.) in
6062
let moons_classes ~b = TDSL.init_const ~l:"moons_classes" ~b ~o:[ 1 ] moons_classes in
6163

6264
let init_time = Time_now.nanoseconds_since_unix_epoch () in
@@ -80,7 +82,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
8082
let module Backend = (val backend) in
8183
Backend.initialize Train.BT.Most_parallel_devices;
8284
let inputs, outputs, model_result, infer_callback, batch_losses, epoch_losses, learning_rates =
83-
Train.example_train_loop ~seed ~batch_size ~init_lr ~max_num_devices:num_devices ~data_len:len
85+
Train.example_train_loop ~seed ~batch_size ~init_lr ~max_num_devices:num_devices ~data_len
8486
~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn ~weight_decay
8587
~per_batch_callback ~per_epoch_callback
8688
(module Backend)
@@ -164,24 +166,12 @@ let _suspend () =
164166
@@ classify_moons ~seed:0 ~on_device:true ~inlining_cutoff:3 ~num_devices:8 ~batch_size:16
165167
~backend_name:"gccjit" ~value_prec:CDSL.single ~grad_prec:CDSL.double ()
166168

167-
let _cpu_benchmarks =
168-
List.concat_map [ 0; 1; 2; 3 ] ~f:(fun inlining_cutoff ->
169-
List.concat_map [ 1; 2; 5; 8; 10; 16 (* ; 20 *) ] ~f:(fun num_devices ->
170-
List.concat_map [ 120; 160 (* ; 320; 640; 1280 *) ] ~f:(fun batch_size ->
171-
List.concat_map [ 0; 1 (* ; 2; 3; 4 *) ] ~f:(fun seed ->
172-
List.concat_map [ "gccjit" (* *; "cuda" *) ] ~f:(fun backend_name ->
173-
[
174-
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_devices
175-
~batch_size ~backend_name ~value_prec:CDSL.single ~grad_prec:CDSL.single;
176-
])))))
177-
178169
let cuda_benchmarks =
179-
List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
180-
List.concat_map [ 1; 2; 5; 8; 10 (* ; 16; 20; 30; 32; 40; 64 *) ] ~f:(fun num_devices ->
181-
List.concat_map [ 120; 160 (* ; 320; 640; 1280 *) ] ~f:(fun batch_size ->
170+
List.concat_map [ 0; 1; (* 2; *) 3 ] ~f:(fun inlining_cutoff ->
171+
List.concat_map [ 1; 3; 6; 12 (* ; 16; 32; 64 *) ] ~f:(fun num_devices ->
172+
List.concat_map [ 64; 128 (* ; 256 *) ] ~f:(fun batch_size ->
182173
List.concat_map [ 0; 1 (* ; 2; 3; 4 *) ] ~f:(fun seed ->
183-
List.concat_map [ (* "gccjit" ; *) "cc" (* ; "cuda" *) ]
184-
~f:(fun backend_name ->
174+
List.concat_map [ (* "gccjit" ; *) "cc"; "cuda" ] ~f:(fun backend_name ->
185175
[
186176
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_devices
187177
~batch_size ~backend_name ~value_prec:CDSL.single ~grad_prec:CDSL.single;

lib/train.ml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -421,12 +421,13 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
421421
in
422422
let num_devices = Array.length prior_contexts in
423423
let minibatch_size = batch_size / num_devices in
424-
let n_batches = data_len / minibatch_size in
425-
let inputs = inputs ~b:[ n_batches; minibatch_size ] in
426-
let outputs = outputs ~b:[ n_batches; minibatch_size ] in
427-
let steps = epochs * n_batches in
424+
let n_minibatches = data_len / minibatch_size in
425+
let inputs = inputs ~b:[ n_minibatches; minibatch_size ] in
426+
let outputs = outputs ~b:[ n_minibatches; minibatch_size ] in
427+
(* This is the joint number of steps done by the round-robin scheduler across devices. *)
428+
let steps = epochs * n_minibatches in
428429
Utils.settings.fixed_state_for_init <- Some seed;
429-
let batch_n, bindings = IDX.get_static_symbol ~static_range:n_batches IDX.empty in
430+
let batch_n, bindings = IDX.get_static_symbol ~static_range:n_minibatches IDX.empty in
430431
let step_n, bindings = IDX.get_static_symbol bindings in
431432
let%op input = inputs @| batch_n in
432433
let%op expectation = outputs @| batch_n in

0 commit comments

Comments
 (0)