Commit e2780a6 (parent a09e2d7)

Fixes #245: report used memory
Note: cuda backend migration to `Tnode.sharing` still broken.

7 files changed: +122 additions, -50 deletions

arrayjit/lib/cc_backend.ml (1 addition, 3 deletions)

@@ -49,9 +49,7 @@ let is_initialized, initialize =
 let finalize _ctx = ()
 
 let init ~label =
-  let result =
-    { label; arrays = { used_memory = Atomic.make 0; ctx_arrays = Map.empty (module Tn) } }
-  in
+  let result = { label; arrays = empty_ctx_arrays } in
   Stdlib.Gc.finalise finalize result;
   result
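
For context, a plausible shape for `empty_ctx_arrays`, assuming it merely factors out the record literal deleted above. The field names `used_memory` and `ctx_arrays` and the module `Tn` come from the deleted line; the actual definition site (and hence whether the counter is shared between contexts or created fresh per call) is not shown in this diff:

(* Hypothetical sketch only, not the repository's actual definition. *)
let empty_ctx_arrays =
  {
    used_memory = Atomic.make 0 (* running count of bytes allocated on the device *);
    ctx_arrays = Map.empty (module Tn) (* no tensor-node arrays allocated yet *);
  }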

bin/compilation_speed.ml (3 additions, 2 deletions)

@@ -24,6 +24,7 @@ let benchmark_overhead backend () =
   (* Train.every_non_literal_on_host f; *)
   let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
   let ctx = Backend.init stream in
+  let init_mem = Backend.(get_used_memory @@ get_stream_device stream) in
   let update_f = Train.grad_update f in
   (* Initialize the context with a mock update of x to ensure that it is not optimized as a
      constant. *)
@@ -60,13 +61,13 @@ let benchmark_overhead backend () =
   in
   let final_time = Time_now.nanoseconds_since_unix_epoch () in
   let time_in_sec = Int63.(to_float @@ (final_time - init_time)) /. 1000_000_000. in
+  let mem_in_bytes = Backend.(get_used_memory @@ get_stream_device stream) - init_mem in
   let result =
     PrintBox_utils.Benchmark
       {
         bench_title = Backend.name ^ " overhead";
         time_in_sec;
-        (* FIXME: global mem consumption *)
-        mem_in_bytes = 0;
+        mem_in_bytes;
         result_label = "x, f(x)";
         result =
           [%sexp_of: (float * float) list]
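
The measurement added here is a plain before/after delta around the benchmarked work. A minimal standalone sketch of the same idiom, assuming `Backend.get_used_memory : device -> int` returns cumulative bytes allocated on the device; the helper name `with_mem_delta` is hypothetical:

(* Run [f] and report how many bytes of device memory it left allocated. *)
let with_mem_delta dev f =
  let init_mem = Backend.get_used_memory dev in
  let result = f () in
  (result, Backend.get_used_memory dev - init_mem)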

bin/moons_benchmark.ml (45 additions, 18 deletions)

@@ -15,15 +15,15 @@ let _get_local_debug_runtime = Arrayjit.Utils._get_local_debug_runtime
 [%%global_debug_log_level 9]
 [%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
 
-let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~backend_name
+let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~backend_name
     ~value_prec ~grad_prec () =
   [%track_sexp
     let _debug : string = "started" in
     (fun (started : unit) -> started) ()];
   (* ignore seed; *)
   let bench_title =
     [%string
-      "seed %{seed#Int}, inline %{inlining_cutoff#Int}, parallel %{num_devices#Int}, batch \
+      "seed %{seed#Int}, inline %{inlining_cutoff#Int}, parallel %{num_streams#Int}, batch \
        %{batch_size#Int}, backend %{backend_name}, val prec %{Ops.prec_string value_prec}, grad \
        prec %{Ops.prec_string grad_prec}"]
   in
@@ -45,11 +45,11 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
   (* TINY for debugging: *)
   (* let data_len = 3 * 4 in *)
   let flat_len = data_len / 2 in
-  (* Note: [minibatch_size = batch_size / num_devices] is the actual per-device batch used. *)
+  (* Note: [minibatch_size = batch_size / num_streams] is the actual per-device batch used. *)
   (* let epochs = 200 in *)
-  let epochs = 100 in
+  (* let epochs = 100 in *)
   (* TINY for debugging: *)
-  (* let epochs = 2 in *)
+  let epochs = 2 in
   (* let epochs = 1 in *)
   (* let init_lr = 0.1 in *)
   let init_lr = 0.01 in
@@ -78,7 +78,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
   let%op loss_fn ~output ~expectation = ?/(!..1 - (expectation *. output)) in
   let start_time = ref None in
   let weight_decay = 0.0002 in
-  Arrayjit.Backends.sync_suggested_num_streams := num_devices;
+  Arrayjit.Backends.sync_suggested_num_streams := num_streams;
   let backend = Arrayjit.Backends.fresh_backend ~backend_name () in
   let per_batch_callback ~at_batch:_ ~at_step:_ ~learning_rate:_ ~batch_loss:_ ~epoch_loss:_ =
     if Option.is_none !start_time then start_time := Some (Time_now.nanoseconds_since_unix_epoch ())
@@ -90,8 +90,17 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
   in
   let module Backend = (val backend) in
   Backend.initialize Train.BT.Most_parallel_streams;
-  let inputs, outputs, model_result, infer_callback, batch_losses, epoch_losses, learning_rates =
-    Train.example_train_loop ~seed ~batch_size ~init_lr ~max_num_streams:num_devices ~data_len
+  let {
+    Train.inputs;
+    outputs;
+    model_result;
+    infer_callback;
+    batch_losses;
+    epoch_losses;
+    learning_rates;
+    used_memory;
+  } =
+    Train.example_train_loop ~seed ~batch_size ~init_lr ~max_num_streams:num_streams ~data_len
       ~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn ~weight_decay
       ~per_batch_callback ~per_epoch_callback
       (module Backend)
@@ -161,8 +170,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
     {
       bench_title;
       time_in_sec;
-      (* FIXME: implement total mem assessment. *)
-      mem_in_bytes = 0;
+      mem_in_bytes = used_memory;
      result_label = "init time in sec, min loss, last loss";
      result =
        [%sexp_of: float * float * float]
@@ -176,11 +184,11 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
 
 let _suspend () =
   ignore
-  @@ classify_moons ~seed:0 ~on_device:true ~inlining_cutoff:3 ~num_devices:8 ~batch_size:16
+  @@ classify_moons ~seed:0 ~on_device:true ~inlining_cutoff:3 ~num_streams:8 ~batch_size:16
        ~backend_name:"gccjit" ~value_prec:CDSL.single ~grad_prec:CDSL.double ()
 
-let cuda_benchmarks =
-  List.concat_map [ 1; 3; 6; 12; 16; 20 (* 32; 64 *) ] ~f:(fun num_devices ->
+let _cuda_benchmarks =
+  List.concat_map [ 1; 3; 6; 12; 16; 20 (* 32; 64 *) ] ~f:(fun num_streams ->
       List.concat_map
         [
          (* TINY for debugging: *)
@@ -194,7 +202,26 @@ let cuda_benchmarks =
             List.concat_map [ (* CDSL.double; *) CDSL.single (* ; CDSL.half *) ]
               ~f:(fun value_prec ->
                 [
-                  classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_devices
+                  classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams
+                    ~batch_size ~backend_name ~value_prec ~grad_prec:value_prec;
+                ]))))))
+
+let _mem_benchmarks =
+  List.concat_map [ 1; 3; 6; 12; 16 (* ; 20; 32; 64 *) ] ~f:(fun num_streams ->
+      List.concat_map
+        [
+          (* TINY for debugging: *)
+          (* 3 * 2 *)
+          3 * 5 * 16 (* ; 3 * 5 * 32; 3 * 5 * 64 *);
+        ]
+        ~f:(fun batch_size ->
+          List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
+            List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed ->
+              List.concat_map [ (* "gccjit" ; *) "cc" (* ; "cuda" *) ] ~f:(fun backend_name ->
+                List.concat_map [ CDSL.double; CDSL.single (* ; CDSL.half *) ]
+                  ~f:(fun value_prec ->
+                    [
+                      classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams
                         ~batch_size ~backend_name ~value_prec ~grad_prec:value_prec;
                     ]))))))
@@ -204,7 +231,7 @@ let cuda_benchmarks =
      (nth - 1) *)
 
 let fixed_seed_search seed =
-  classify_moons ~seed ~on_device:true ~inlining_cutoff:3 ~num_devices:1 ~batch_size:20
+  classify_moons ~seed ~on_device:true ~inlining_cutoff:3 ~num_streams:1 ~batch_size:20
     ~backend_name:"cuda" ~value_prec:CDSL.single ~grad_prec:CDSL.single ()
 
 let _suspended () =
@@ -213,8 +240,8 @@ let _suspended () =
   (* let () = List.map benchmarks ~f:(nth_best 2) |> PrintBox_utils.table |> PrintBox_text.output
      Stdio.stdout *)
 
-let benchmark () =
-  List.map cuda_benchmarks ~f:(fun bench -> bench ())
+let benchmark benchmarks =
+  List.map benchmarks ~f:(fun bench -> bench ())
   |> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout
 
-let () = benchmark ()
+let () = benchmark _mem_benchmarks
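
The suites above are built as Cartesian products of configurations via nested `List.concat_map`. A minimal generic illustration of that pattern, assuming `open Base` and the `stdio` library; the `grid` helper is made up for the example:

(* Cartesian product via nested concat_map: applies f to every (x, y) pair. *)
let grid xs ys ~f =
  List.concat_map xs ~f:(fun x -> List.map ys ~f:(fun y -> f x y))

let () =
  grid [ 1; 3; 6 ] [ "cc"; "cuda" ] ~f:(fun n b -> Printf.sprintf "%d streams on %s" n b)
  |> List.iter ~f:Stdio.print_endline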

bin/moons_demo_parallel.ml (11 additions, 2 deletions)

@@ -56,7 +56,16 @@ let experiment ~seed ~backend_name ~config () =
     epoch_loss
   in
   let module Backend = (val backend) in
-  let inputs, outputs, model_result, infer_callback, batch_losses, epoch_losses, learning_rates =
+  let {
+    Train.inputs;
+    outputs;
+    model_result;
+    infer_callback;
+    batch_losses;
+    epoch_losses;
+    learning_rates;
+    used_memory;
+  } =
    Train.example_train_loop ~seed ~batch_size ~max_num_streams:(batch_size / 2) ~init_lr
      ~data_len:len ~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn
      ~weight_decay ~per_batch_callback ~per_epoch_callback
@@ -68,7 +77,7 @@ let experiment ~seed ~backend_name ~config () =
   let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
   Stdio.print_endline "\n******** mlp_result **********";
   Tensor.print_tree ~with_id:true ~with_grad:false ~depth:9 model_result;
-  Stdio.printf "\n********\n%!";
+  Stdio.printf "\n********\nUsed memory: %d\n%!" used_memory;
   let callback (x, y) = Float.((infer_callback [| x; y |]).(0) >= 0.) in
   let plot_moons =
     let open PrintBox_utils in
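
A side note on the destructuring pattern above: in OCaml, qualifying only the first field (`Train.inputs`) is enough to resolve the record type for the whole pattern, so the remaining fields need no prefix. A tiny self-contained illustration; the `Geo` module and `point` type are made up for the example:

module Geo = struct
  type point = { x : float; y : float }
end

(* Qualifying the first field selects Geo.point for the entire pattern. *)
let norm { Geo.x; y } = Float.sqrt ((x *. x) +. (y *. y))

let () = Stdio.printf "%.1f\n" (norm { Geo.x = 3.; y = 4. })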

lib/train.ml (42 additions, 21 deletions)

@@ -405,43 +405,53 @@ let%track3_sexp parallel_update (type context)
 
 let get_all_suggested_streams ?(max_num_streams : int option) (type device stream)
     (backend : (module Backend_type with type device = device and type stream = stream)) :
-    stream array =
+    device array * stream array =
   let max_num_streams = Option.value max_num_streams ~default:Int.max_value_30_bits in
   let module Backend =
     (val backend : Backend_type with type device = device and type stream = stream)
   in
   let num_devices = min max_num_streams @@ Backend.num_devices () in
   let devices = Array.init num_devices ~f:(fun ordinal -> Backend.get_device ~ordinal) in
-  Array.folding_mapi devices ~init:0 ~f:(fun ordinal num_collected device ->
-      let remaining_devices = num_devices - ordinal - 1 in
-      let max_current = Backend.suggested_num_streams device in
-      let take_current = min max_current @@ (max_num_streams - remaining_devices) in
-      ( num_collected + take_current,
-        Array.init take_current ~f:(fun _subordinal -> Backend.new_stream device) ))
-  |> Array.concat_map ~f:Fn.id
+  let result =
+    Array.folding_mapi devices ~init:0 ~f:(fun ordinal num_collected device ->
+        let remaining_devices = num_devices - ordinal - 1 in
+        let max_current = Backend.suggested_num_streams device in
+        let take_current = min max_current @@ (max_num_streams - remaining_devices) in
+        ( num_collected + take_current,
+          Array.init take_current ~f:(fun _subordinal -> Backend.new_stream device) ))
+    |> Array.concat_map ~f:Fn.id
+  in
+  (devices, result)
 
 let to_routine (type context) (module Backend : Backend_type with type context = context)
     (context : context) ?shared ?name bindings comp =
   Backend.link context @@ Backend.compile ?shared ?name bindings comp
 
+type example_train_result = {
+  inputs : Tensor.t;
+  outputs : Tensor.t;
+  model_result : Tensor.t;
+  infer_callback : float array -> float array;
+      (** Note: infer_callback is significantly less efficient than using the model via arrayjit. *)
+  batch_losses : float list;
+  epoch_losses : float list;
+  learning_rates : float list;
+  used_memory : int;
+}
+
 let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init_lr ?lr_schedule
     ?(copy_to_merge = false) ?max_num_streams ~data_len ~epochs ~inputs ~outputs ~model ~loss_fn
     ~weight_decay ?per_batch_callback ?per_epoch_callback (type context)
-    ?(prior_contexts : context array option)
     (backend : (module Backend_type with type context = context)) () =
   let module TDSL = Operation.TDSL in
   let module NTDSL = Operation.NTDSL in
   Rand.init seed;
   let module Backend = (val backend : Backend_type with type context = context) in
-  let prior_contexts =
-    match prior_contexts with
-    | Some contexts -> contexts
-    | None ->
-        let devices = get_all_suggested_streams ?max_num_streams (module Backend) in
-        Array.map devices ~f:Backend.init
-  in
-  let num_devices = Array.length prior_contexts in
-  let minibatch_size = batch_size / num_devices in
+  let devices, streams = get_all_suggested_streams ?max_num_streams (module Backend) in
+  let num_streams = Array.length streams in
+  let contexts = Array.map streams ~f:Backend.init in
+  let init_mem = Array.fold devices ~init:0 ~f:(fun acc dev -> acc + Backend.get_used_memory dev) in
+  let minibatch_size = batch_size / num_streams in
   let n_minibatches = data_len / minibatch_size in
   let inputs = inputs ~b:[ n_minibatches; minibatch_size ] in
   let outputs = outputs ~b:[ n_minibatches; minibatch_size ] in
@@ -468,7 +478,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
   set_hosted learning_rate.value;
   let sgd = sgd_update ~learning_rate ~weight_decay update in
   let grad_update = Backend.compile ~shared:true bindings update.fwd_bprop in
-  let grad_updates = Array.map prior_contexts ~f:(fun ctx -> Backend.link ctx grad_update) in
+  let grad_updates = Array.map contexts ~f:(fun ctx -> Backend.link ctx grad_update) in
   let sgd_update = to_routine (module Backend) grad_updates.(0).context bindings sgd in
   Tensor.log_debug_info ~from_log_level:2 inputs;
   Tensor.log_debug_info ~from_log_level:2 outputs;
@@ -534,8 +544,19 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
     Backend.(await @@ get_ctx_stream routine.context);
     Tensor.get_values model_result
   in
-  (* Note: infer_callback is significantly less efficient than using the model via arrayjit. *)
-  (inputs, outputs, model_result, infer_callback, !batch_losses, !epoch_losses, !learning_rates)
+  let used_memory =
+    Array.fold devices ~init:0 ~f:(fun acc dev -> acc + Backend.get_used_memory dev) - init_mem
+  in
+  {
+    inputs;
+    outputs;
+    model_result;
+    infer_callback;
+    batch_losses = !batch_losses;
+    epoch_losses = !epoch_losses;
+    learning_rates = !learning_rates;
+    used_memory;
+  }
 
 let%track3_sexp forward_and_ctx ?(disable_rootness_check = false) (type context)
     (module Backend : Backend_type with type context = context) ctx ?(bindings = IDX.empty) t =
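
Callers that need only some of the new record's fields can use a partial pattern rather than naming all eight. A minimal sketch, assuming a linked `Backend` module and the tensors and hyperparameters (`moons_flat`, `moons_classes`, `mlp`, `loss_fn`) are in scope; the labeled-argument values are placeholders:

let () =
  (* Bind just the fields we need; `_` ignores the rest of example_train_result. *)
  let { Train.used_memory; epoch_losses; _ } =
    Train.example_train_loop ~seed:0 ~batch_size:16 ~init_lr:0.01 ~data_len:320 ~epochs:2
      ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn ~weight_decay:0.0002
      (module Backend) ()
  in
  Stdio.printf "Used memory: %d bytes across all devices\n" used_memory;
  Stdio.printf "Epochs recorded: %d\n" (List.length epoch_losses)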

test/moons_demo_parallel.ml (10 additions, 2 deletions)

@@ -54,8 +54,16 @@ let main () =
     epoch_loss
   in
   let module Backend = (val backend) in
-  let inputs, outputs, _model_result, infer_callback, _batch_losses, _epoch_losses, _learning_rates
-      =
+  let {
+    Train.inputs;
+    outputs;
+    model_result = _;
+    infer_callback;
+    batch_losses = _;
+    epoch_losses = _;
+    learning_rates = _;
+    used_memory = _;
+  } =
    Train.example_train_loop ~seed ~batch_size ~max_num_streams:(batch_size / 2) ~init_lr
      ~data_len:len ~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn
      ~weight_decay ~per_batch_callback ~per_epoch_callback

test/moons_demo_parallel_run.ml (10 additions, 2 deletions)

@@ -53,8 +53,16 @@ let main () =
     epoch_loss
   in
   let module Backend = (val backend) in
-  let inputs, outputs, _model_result, infer_callback, _batch_losses, _epoch_losses, _learning_rates
-      =
+  let {
+    Train.inputs;
+    outputs;
+    model_result = _;
+    infer_callback;
+    batch_losses = _;
+    epoch_losses = _;
+    learning_rates = _;
+    used_memory = _;
+  } =
    Train.example_train_loop ~seed ~batch_size ~max_num_streams:(batch_size / 2) ~init_lr
      ~data_len:len ~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn
      ~weight_decay ~per_batch_callback ~per_epoch_callback
