Commit c465f79

Automatically init in Train.forward_and_ctx / forward_and_forget; refactoring for bin/ examples
The examples are still often broken and will be audited after another round of refactoring.
1 parent fa884ef commit c465f79
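
Concretely, the refactored bin/ examples now run parameter initialization as an explicit routine before the training routine. A minimal sketch of that pattern follows, assuming a Backend module, a context ctx, and a loss tensor l already set up as in the files below:

  (* Sketch only: Backend, ctx, IDX and the loss tensor l are assumed from the surrounding example. *)
  let init_params = Tensor.init_params l in
  let init = Backend.link ctx @@ Backend.compile ctx.optimize_ctx IDX.empty init_params in
  let update = Train.grad_update l in
  (* Compile the training routine against the context returned by linking the init routine. *)
  let routine = Train.to_routine (module Backend) init.context IDX.empty update in
  (* Run the initialization once, then the training step. *)
  Train.run init;
  Train.run routine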

File tree: 7 files changed, +41 −14 lines changed

bin/micrograd_basic.ml

Lines changed: 4 additions & 1 deletion
@@ -51,9 +51,12 @@ let%diagn_sexp () : unit =
   let%op g = g + (10. /. f) in
   List.iter ~f:(function Some diff -> Train.set_hosted diff.grad | None -> ()) [ a.diff; b.diff ];
   (* Train.every_non_literal_on_host g; *)
+  let init_params = Tensor.init_params g in
   let update = Train.grad_update g in
-  let routine = Train.to_routine (module Backend) ctx IDX.empty update in
+  let init = Backend.link ctx @@ Backend.compile ctx.optimize_ctx IDX.empty init_params in
+  let routine = Train.to_routine (module Backend) init.context IDX.empty update in
   Utils.capture_stdout_logs @@ fun () ->
+  Train.run init;
   Train.run routine;
   (* Tensor.print_tree ~with_grad:true ~depth:9 g; *)
   Tensor.print ~with_code:false ~with_grad:false `Default @@ g;

bin/micrograd_demo.ml

Lines changed: 5 additions & 1 deletion
@@ -74,6 +74,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
     (scalar_loss, 0.0)
   in
   (* So that we can inspect them. *)
+  let init_params = Tensor.init_params scalar_loss in
   let update = Train.grad_update scalar_loss in
   let%op learning_rate = 0.1 *. (!..steps - !@step_n) /. !..steps in
   Train.set_hosted learning_rate.value;
@@ -82,6 +83,9 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
   let module Backend = (val Backends.fresh_backend ()) in
   let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
   let ctx = Backend.make_context stream in
+  let init = Backend.link ctx @@ Backend.compile ctx.optimize_ctx IDX.empty init_params in
+  let ctx = init.context in
+  Train.run init;
   let routine = Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update; sgd ]) in
   (* Stdio.print_endline "\n******** scalar_loss **********"; Tensor.print_tree ~with_id:true
      ~with_grad:false ~depth:9 scalar_loss; Stdio.print_endline "\n******** learning_rate
@@ -114,7 +118,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
   let points = Tn.points_2d ~xdim:0 ~ydim:1 moons_flat.value in
   let classes = Tn.points_1d ~xdim:0 moons_classes.value in
   let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
-  let%op mlp_result = mlp "point" in
+  let%cd mlp_result = mlp "point" in
   Train.set_on_host mlp_result.value;
   (* By using jitted.context here, we don't need to copy the parameters back to the host. *)
   let result_routine =

bin/micrograd_demo_logging.ml

Lines changed: 4 additions & 0 deletions
@@ -30,9 +30,13 @@ let () =
   let%op g = f /. 2 in
   let%op g = g + (10. /. f) in
   List.iter ~f:(Option.iter ~f:(fun diff -> Train.set_hosted diff.Tensor.grad)) [ a.diff; b.diff ];
+  let init_params = Tensor.init_params g in
+  let init = Backend.link ctx @@ Backend.compile ctx.optimize_ctx IDX.empty init_params in
+  let ctx = init.context in
   let update = Train.grad_update g in
   let step = Train.to_routine (module Backend) ctx IDX.empty update in
   Utils.capture_stdout_logs @@ fun () ->
+  Train.run init;
   Train.run step;
   Tensor.print ~with_code:false ~with_grad:false `Default g;
   Tensor.print ~with_code:false ~with_grad:true `Default a;

bin/moons_benchmark.ml

Lines changed: 4 additions & 4 deletions
@@ -89,14 +89,14 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
   in
   Stdlib.Format.printf "Initial backend global debug info: %a\n%!" Sexp.pp_hum
   @@ Backend.get_global_debug_info ();
-  let per_batch_callback ~at_batch:_ ~at_step:_ ~learning_rate:_ ~batch_loss:_ ~epoch_loss:_ =
-    (* Stdio.printf "Batch=%d, step=%d, lr=%f, batch loss=%f, epoch loss=%f\n%!" at_batch at_step
-       learning_rate batch_loss epoch_loss; *)
+  let per_batch_callback ~at_batch ~at_step ~learning_rate ~batch_loss ~epoch_loss =
+    Stdio.printf "Batch=%d, step=%d, lr=%f, batch loss=%f, epoch loss=%f\n%!" at_batch at_step
+      learning_rate batch_loss epoch_loss;
     if Option.is_none !start_time then start_time := Some (Time_now.nanoseconds_since_unix_epoch ())
   in
   (* Tn.print_accessible_headers (); *)
   let per_epoch_callback ~at_step ~at_epoch ~learning_rate ~epoch_loss =
-    if at_epoch % 10 = 9 then
+    (* if at_epoch % 10 = 9 then *)
       Stdio.printf "Epoch=%d, step=%d, lr=%f, epoch loss=%f\n%!" at_epoch at_step learning_rate
         epoch_loss
   in

bin/moons_demo.ml

Lines changed: 6 additions & 2 deletions
@@ -57,9 +57,12 @@ let demo () =
   Train.set_hosted learning_rate.value;
   let sgd = Train.sgd_update ~learning_rate ~weight_decay scalar_loss in
 
-  let module Backend = (val Backends.fresh_backend ~backend_name:"cuda" ()) in
+  let module Backend = (val Backends.fresh_backend ~backend_name:"metal" ()) in
   let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
   let ctx = Backend.make_context stream in
+  let init_params = Tensor.init_params scalar_loss in
+  let init = Backend.link ctx @@ Backend.compile ctx.optimize_ctx IDX.empty init_params in
+  let ctx = init.context in
   let routine = Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update; sgd ]) in
 
   let points = Tn.points_2d ~xdim:0 ~ydim:1 moons_flat.value in
@@ -81,6 +84,7 @@ let demo () =
   let batch_ref = IDX.find_exn routine.bindings batch_n in
   let epoch_loss = ref 0. in
   step_ref := 0;
+  Train.run init;
   let%track_sexp _train_loop : unit =
     for epoch = 0 to epochs - 1 do
       for batch = 0 to n_batches - 1 do
@@ -95,7 +99,7 @@ let demo () =
     done
   in
 
-  let%op mlp_result = mlp "point" in
+  let%cd mlp_result = mlp "point" in
   Train.set_on_host mlp_result.value;
   let result_routine =
     Train.to_routine

bin/zero2hero_1of7.ml

Lines changed: 6 additions & 2 deletions
@@ -154,8 +154,13 @@ let () =
   Train.every_non_literal_on_host l;
   let module Backend = (val Backends.fresh_backend ()) in
   let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
+  let init_params = Tensor.init_params l in
   let update = Train.grad_update l in
-  let routine = Train.to_routine (module Backend) (Backend.make_context stream) IDX.empty update in
+  let ctx = Backend.make_context stream in
+  let init = Backend.link ctx @@ Backend.compile ctx.optimize_ctx IDX.empty init_params in
+  let ctx = init.context in
+  let routine = Train.to_routine (module Backend) ctx IDX.empty update in
+  Train.run init;
   Train.run routine;
   (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
      Backend.await stream; *)
@@ -181,7 +186,6 @@ let () =
     only params values will change, compared to the above.|};
   Tensor.print_tree ~with_grad:true ~depth:9 l;
   (* We could reuse the jitted code if we did not use `jit_and_run`. *)
-  let update = Train.grad_update l in
   let routine = Train.to_routine (module Backend) routine.context IDX.empty update in
   Train.run routine;
   (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));

lib/train.ml

Lines changed: 12 additions & 4 deletions
@@ -504,22 +504,30 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
   }
 
 (* Note: this will get nicer with modular explicits. *)
-let%track3_sexp forward_and_ctx ?(hosted = true) ?(disable_rootness_check = false)
-    (type buffer_ptr dev runner event optimize_ctx)
+let%track3_sexp forward_and_ctx ?(hosted = true) ?(skip_init = false)
+    ?(disable_rootness_check = false) (type buffer_ptr dev runner event optimize_ctx)
     (module Backend : Backend
       with type buffer_ptr = buffer_ptr
       and type dev = dev
      and type runner = runner
      and type optimize_ctx = optimize_ctx
     and type event = event) ctx ?(bindings = IDX.empty) t =
   if hosted then set_hosted t.Tensor.value;
+  let ctx =
+    if skip_init || Set.is_empty t.params then ctx
+    else
+      let init_params = Tensor.init_params t in
+      let init = Backend.link ctx @@ Backend.compile ctx.optimize_ctx bindings init_params in
+      run init;
+      init.context
+  in
   let routine =
     Backend.(link ctx @@ compile ctx.optimize_ctx bindings @@ forward ~disable_rootness_check t)
   in
   if not disable_rootness_check then Tensor.remove_bprop_root t;
   Task.run routine.schedule;
   routine.context
 
-let forward_and_forget ?hosted ?disable_rootness_check backend ctx ?bindings t =
+let forward_and_forget ?hosted ?skip_init ?disable_rootness_check backend ctx ?bindings t =
   (* FIXME: to properly forget we need to free the incrementally-allocated memory! *)
-  ignore @@ forward_and_ctx ?hosted ?disable_rootness_check backend ctx ?bindings t
+  ignore @@ forward_and_ctx ?hosted ?skip_init ?disable_rootness_check backend ctx ?bindings t