
Commit 24f71f9

disable_rootness_check is going awaaaay
1 parent 9cd1261 commit 24f71f9

2 files changed: +20 / -50 lines

bin/zero2hero_1of7.ml

Lines changed: 0 additions & 9 deletions
@@ -149,15 +149,6 @@ let _suspended () =
     {|
 Now we updated the params, but after the forward and backward passes:
 only params values will change, compared to the above.|};
-  Train.printf_tree ~with_grad:true ~depth:9 l;
-  (* We could reuse the jitted code if we did not use `update_once`. [disable_rootness_check:true]
-     because it's not once, it's twice. *)
-  ignore (Train.update_once ~disable_rootness_check:true (module Backend) l);
-  Stdio.print_endline
-    {|
-Now again we did not update the params, they will remain as above, but both param
-gradients and the values and gradients of other nodes will change thanks to the forward and
-backward passes.|};
   Train.printf_tree ~with_grad:true ~depth:9 l
 
 let () =
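
For context, a hedged sketch (not part of this commit) of how a single gradient update reads with the simplified update_once signature, using the tensor l and the Backend module from the surrounding example. Since update_once now always consumes the forward and backprop roots, it can be called at most once per tensor; as the removed comment notes, repeating the pass would instead mean reusing the jitted code.

  (* Hedged sketch: one update, then print the tree with gradients. *)
  ignore (Train.update_once (module Backend) l);
  Train.printf_tree ~with_grad:true ~depth:9 l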

lib/train.ml

Lines changed: 20 additions & 41 deletions
@@ -73,8 +73,8 @@ let set_hosted (a : Tn.t) =
 
 (** Sets the tensor's value as "fully on host", returns the tensor's forward code with a
     label-derived comment. *)
-let forward ?(disable_rootness_check = false) t =
-  let fwd = if disable_rootness_check then t.Tensor.forward else Tensor.consume_forward_code t in
+let forward t =
+  let fwd = Tensor.consume_forward_code t in
   set_hosted t.Tensor.value;
   let label = Tn.debug_name t.value in
   { fwd with asgns = Asgns.Block_comment (label ^ " fwd", fwd.asgns) }

@@ -101,22 +101,16 @@ let grad_update_nochecks loss =
 (** Returns the tensor's forward, zeroing gradients, and backprop code wrapped with label-derived
     comments. Sets the tensor's value as "fully on host". If [setup_for_parallel] is true (false by
     default), sets the parameters and their gradients as "non-local" (on-device). *)
-let grad_update ?(disable_rootness_check = false) ?(setup_for_parallel = false) loss =
+let grad_update ?(setup_for_parallel = false) loss =
   set_hosted loss.Tensor.value;
   if setup_for_parallel then
     Set.iter loss.Tensor.params ~f:(fun p ->
        set_materialized (Option.value_exn ~here:[%here] p.diff).grad);
-  let fwd =
-    if disable_rootness_check then loss.Tensor.forward else Tensor.consume_forward_code loss
-  in
-  let diff = diff_or_error loss "Train.grad_update" in
-  let zero_grads, bprop =
-    if disable_rootness_check then (diff.zero_grads, diff.backprop)
-    else Tensor.consume_backprop_code loss
-  in
+  let fwd = Tensor.consume_forward_code loss in
+  let zero_grads, bprop = Tensor.consume_backprop_code loss in
   (* Note: the %cd syntax for [loss.grad] does not modify roots. *)
   [%cd
-    ~~(loss "gradient update";
+    ~~(loss "gradient update for" loss;
       ~~(loss "fwd";
          fwd);
       ~~(loss "zero grads";
@@ -391,9 +385,9 @@ type example_train_result = {
   used_memory : int;
 }
 
-let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init_lr ?lr_schedule
-    ?(copy_to_merge = false) ?max_num_streams ~data_len ~epochs ~inputs ~outputs ~model ~loss_fn
-    ~weight_decay ?per_batch_callback ?per_epoch_callback ?(per_epoch_debug_streams = false)
+let example_train_loop ~seed ~batch_size ~init_lr ?lr_schedule ?(copy_to_merge = false)
+    ?max_num_streams ~data_len ~epochs ~inputs ~outputs ~model ~loss_fn ~weight_decay
+    ?per_batch_callback ?per_epoch_callback ?(per_epoch_debug_streams = false)
     (module Backend : Backend) () =
   let module TDSL = Operation.TDSL in
   let module NTDSL = Operation.NTDSL in
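
As a hedged illustration of the new example_train_loop signature (the ?disable_rootness_check argument is gone), a call sketch with hypothetical user-supplied values; the argument labels come from the signature above, while mlp, my_loss, inputs, outputs, Backend, and all the literals are placeholders, not code from this commit.

  (* Hedged sketch: only the required labeled arguments are passed; the optional
     schedule, callbacks and stream settings keep their defaults. *)
  let _results =
    Train.example_train_loop ~seed:0 ~batch_size:32 ~init_lr:0.01 ~data_len:1000
      ~epochs:10 ~inputs ~outputs ~model:mlp ~loss_fn:my_loss ~weight_decay:0.0001
      (module Backend) ()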
@@ -421,7 +415,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
   let learning_rates = ref [] in
   let%op loss_tensor = loss_fn ~output:(model input) ~expectation in
   let%op scalar_loss = (loss_tensor ++ "...|... => 0") /. !..batch_size in
-  let update = grad_update ~disable_rootness_check ~setup_for_parallel:true scalar_loss in
+  let update = grad_update ~setup_for_parallel:true scalar_loss in
   (* Define learning_rate after scalar_loss is compiled, to not trigger rootness sanitizer. *)
   let%op learning_rate =
     match lr_schedule with

@@ -487,11 +481,8 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
   done;
   (* Using %cd instead of %op to avoid being asked to initialize [infer]. *)
   let%cd model_result = model "infer_input" in
-  let infer_fwd =
-    if disable_rootness_check then model_result.Tensor.forward
-    else Tensor.consume_forward_code model_result
-  in
-  if not disable_rootness_check then Tensor.remove_bprop_root model_result;
+  let infer_fwd = Tensor.consume_forward_code model_result in
+  Tensor.remove_bprop_root model_result;
   set_on_host model_result.Tensor.value;
   (* By using sgd_update.context, maybe we don't need to copy the parameters back to the host. *)
   let routine =

@@ -529,7 +520,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
     [reinit_all] is true (false by default), all parameters are reinitialized, otherwise only the
     parameters that are not in [ctx.ctx_arrays] are initialized. *)
 let%track3_sexp run_once ?(hosted = true) ?(skip_init = false) ?reinit_all
-    ?(disable_rootness_check = false) (type buffer_ptr dev runner event optimize_ctx)
+    (type buffer_ptr dev runner event optimize_ctx)
     (module Backend : Backend
       with type buffer_ptr = buffer_ptr
        and type dev = dev

@@ -539,7 +530,7 @@ let%track3_sexp run_once ?(hosted = true) ?(skip_init = false) ?reinit_all
   (* TODO: this will get nicer with modular explicits. *)
   if hosted then set_hosted t.Tensor.value;
   (* Compute the update early, to ensure the shape inference is done. *)
-  let update = f ~disable_rootness_check t in
+  let update = f t in
   let ctx =
     match ctx with
     | Some ctx -> ctx

@@ -549,45 +540,33 @@ let%track3_sexp run_once ?(hosted = true) ?(skip_init = false) ?reinit_all
     if skip_init || Set.is_empty t.params then ctx
     else init_params (module Backend) ~ctx ~hosted ?reinit_all bindings t
   in
-  let routine =
-    Backend.(link ctx @@ compile ctx.optimize_ctx bindings update)
-  in
+  let routine = Backend.(link ctx @@ compile ctx.optimize_ctx bindings update) in
   Task.run routine.schedule;
   routine.context
 
 (** [forward_once] is a wrapper around {!run_once} that runs the forward code of [t]. *)
-let forward_once ?hosted ?skip_init ?reinit_all ?(disable_rootness_check = false)
-    (type buffer_ptr dev runner event optimize_ctx)
+let forward_once ?hosted ?skip_init ?reinit_all (type buffer_ptr dev runner event optimize_ctx)
     (module Backend : Backend
       with type buffer_ptr = buffer_ptr
        and type dev = dev
        and type runner = runner
        and type optimize_ctx = optimize_ctx
        and type event = event) ?ctx ?bindings t =
-  let ctx =
-    run_once ?hosted ?skip_init ?reinit_all
-      (module Backend)
-      ~f:(fun ~disable_rootness_check t -> forward ~disable_rootness_check t)
-      ~disable_rootness_check ?ctx ?bindings t
-  in
+  let ctx = run_once ?hosted ?skip_init ?reinit_all (module Backend) ~f:forward ?ctx ?bindings t in
   (* FIXME: this is going away soon. *)
-  if not disable_rootness_check then Tensor.remove_bprop_root t;
+  Tensor.remove_bprop_root t;
   ctx
 
 (** [update_once] is a wrapper around {!run_once} that runs the gradient update code of [t]: both
     forward and backprop. *)
-let update_once ?hosted ?skip_init ?reinit_all ?(disable_rootness_check = false)
-    (type buffer_ptr dev runner event optimize_ctx)
+let update_once ?hosted ?skip_init ?reinit_all (type buffer_ptr dev runner event optimize_ctx)
     (module Backend : Backend
       with type buffer_ptr = buffer_ptr
        and type dev = dev
        and type runner = runner
        and type optimize_ctx = optimize_ctx
        and type event = event) ?ctx ?bindings t =
-  run_once ?hosted ?skip_init ?reinit_all
-    (module Backend)
-    ~f:(fun ~disable_rootness_check t -> grad_update ~disable_rootness_check t)
-    ~disable_rootness_check ?ctx ?bindings t
+  run_once ?hosted ?skip_init ?reinit_all (module Backend) ~f:grad_update ?ctx ?bindings t
 
 (** [printf] is a wrapper around {!Tensor.print} that assumes [~force:true], and by default sets
     [~with_code:false], [~with_grad:true], and [~style:`Default]. *)
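
Finally, a hedged sketch of how the two wrappers read after this commit, with hypothetical tensors t_infer and t_train and an assumed Backend module; the optional ?ctx/?bindings arguments are left at their defaults.

  (* Hedged sketch: forward_once runs only the forward pass of [t_infer] and then
     removes its backprop root; update_once runs forward, zero-grads and backprop
     for [t_train]. Each now always consumes the corresponding roots, so each can
     be applied at most once per tensor. *)
  let _inference_ctx = Train.forward_once (module Backend) t_infer
  let _training_ctx = Train.update_once (module Backend) t_train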
