@@ -73,8 +73,8 @@ let set_hosted (a : Tn.t) =
 
 (** Sets the tensor's value as "fully on host", returns the tensor's forward code with a
     label-derived comment. *)
-let forward ?(disable_rootness_check = false) t =
-  let fwd = if disable_rootness_check then t.Tensor.forward else Tensor.consume_forward_code t in
+let forward t =
+  let fwd = Tensor.consume_forward_code t in
   set_hosted t.Tensor.value;
   let label = Tn.debug_name t.value in
   { fwd with asgns = Asgns.Block_comment (label ^ " fwd", fwd.asgns) }
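A minimal usage sketch of the simplified [forward] (the tensor [t] is a hypothetical placeholder, not part of this diff): with the ~disable_rootness_check escape hatch gone, the forward root is always consumed.

  (* Sketch only: [t] is some already-defined tensor whose forward root is intact. *)
  let fwd_code = Train.forward t in
  (* [fwd_code.asgns] is wrapped in a Block_comment labeled after [t]; calling
     [Train.forward t] a second time would fail because the root was consumed. *)
  ignore fwd_code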
@@ -101,22 +101,16 @@ let grad_update_nochecks loss =
 (** Returns the tensor's forward, zeroing gradients, and backprop code wrapped with label-derived
     comments. Sets the tensor's value as "fully on host". If [setup_for_parallel] is true (false by
     default), sets the parameters and their gradients as "non-local" (on-device). *)
-let grad_update ?(disable_rootness_check = false) ?(setup_for_parallel = false) loss =
+let grad_update ?(setup_for_parallel = false) loss =
   set_hosted loss.Tensor.value;
   if setup_for_parallel then
     Set.iter loss.Tensor.params ~f:(fun p ->
         set_materialized (Option.value_exn ~here:[%here] p.diff).grad);
-  let fwd =
-    if disable_rootness_check then loss.Tensor.forward else Tensor.consume_forward_code loss
-  in
-  let diff = diff_or_error loss "Train.grad_update" in
-  let zero_grads, bprop =
-    if disable_rootness_check then (diff.zero_grads, diff.backprop)
-    else Tensor.consume_backprop_code loss
-  in
+  let fwd = Tensor.consume_forward_code loss in
+  let zero_grads, bprop = Tensor.consume_backprop_code loss in
   (* Note: the %cd syntax for [loss.grad] does not modify roots. *)
   [%cd
-    ~~(loss "gradient update";
+    ~~(loss "gradient update for" loss;
       ~~(loss "fwd";
          fwd);
       ~~(loss "zero grads";
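For reference, a hedged call-site sketch of the new [grad_update]; the name [scalar_loss] mirrors the train loop further below and is assumed to be a freshly built loss tensor with its roots intact.

  (* Sketch only: consumes [scalar_loss]'s forward and backprop roots. *)
  let update = Train.grad_update ~setup_for_parallel:true scalar_loss in
  (* Per the doc comment above, [update] bundles the forward, zero-grads and
     backprop sections under label-derived comments. *)
  ignore update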
@@ -391,9 +385,9 @@ type example_train_result = {
   used_memory : int;
 }
 
-let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init_lr ?lr_schedule
-    ?(copy_to_merge = false) ?max_num_streams ~data_len ~epochs ~inputs ~outputs ~model ~loss_fn
-    ~weight_decay ?per_batch_callback ?per_epoch_callback ?(per_epoch_debug_streams = false)
+let example_train_loop ~seed ~batch_size ~init_lr ?lr_schedule ?(copy_to_merge = false)
+    ?max_num_streams ~data_len ~epochs ~inputs ~outputs ~model ~loss_fn ~weight_decay
+    ?per_batch_callback ?per_epoch_callback ?(per_epoch_debug_streams = false)
     (module Backend : Backend) () =
   let module TDSL = Operation.TDSL in
   let module NTDSL = Operation.NTDSL in
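A hedged call-site sketch for the updated [example_train_loop] signature; every concrete value below (the backend module, the datasets, the model and loss functions, and the numeric settings) is a placeholder assumption, not something defined in this diff.

  (* Sketch only: [Backend], [inputs], [outputs], [mlp] and [mse] are assumed placeholders. *)
  let result =
    Train.example_train_loop ~seed:0 ~batch_size:32 ~init_lr:0.01 ~data_len:1000 ~epochs:10
      ~inputs ~outputs ~model:mlp ~loss_fn:mse ~weight_decay:0.0001
      (module Backend) ()
  in
  ignore result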
@@ -421,7 +415,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
   let learning_rates = ref [] in
   let%op loss_tensor = loss_fn ~output:(model input) ~expectation in
   let%op scalar_loss = (loss_tensor ++ "...|... => 0") /. !..batch_size in
-  let update = grad_update ~disable_rootness_check ~setup_for_parallel:true scalar_loss in
+  let update = grad_update ~setup_for_parallel:true scalar_loss in
   (* Define learning_rate after scalar_loss is compiled, to not trigger rootness sanitizer. *)
   let%op learning_rate =
     match lr_schedule with
@@ -487,11 +481,8 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
   done;
   (* Using %cd instead of %op to avoid being asked to initialize [infer]. *)
   let%cd model_result = model "infer_input" in
-  let infer_fwd =
-    if disable_rootness_check then model_result.Tensor.forward
-    else Tensor.consume_forward_code model_result
-  in
-  if not disable_rootness_check then Tensor.remove_bprop_root model_result;
+  let infer_fwd = Tensor.consume_forward_code model_result in
+  Tensor.remove_bprop_root model_result;
   set_on_host model_result.Tensor.value;
   (* By using sgd_update.context, maybe we don't need to copy the parameters back to the host. *)
   let routine =
@@ -529,7 +520,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
     [reinit_all] is true (false by default), all parameters are reinitialized, otherwise only the
     parameters that are not in [ctx.ctx_arrays] are initialized. *)
 let%track3_sexp run_once ?(hosted = true) ?(skip_init = false) ?reinit_all
-    ?(disable_rootness_check = false) (type buffer_ptr dev runner event optimize_ctx)
+    (type buffer_ptr dev runner event optimize_ctx)
     (module Backend : Backend
       with type buffer_ptr = buffer_ptr
       and type dev = dev
@@ -539,7 +530,7 @@ let%track3_sexp run_once ?(hosted = true) ?(skip_init = false) ?reinit_all
   (* TODO: this will get nicer with modular explicits. *)
   if hosted then set_hosted t.Tensor.value;
   (* Compute the update early, to ensure the shape inference is done. *)
-  let update = f ~disable_rootness_check t in
+  let update = f t in
   let ctx =
     match ctx with
     | Some ctx -> ctx
@@ -549,45 +540,33 @@ let%track3_sexp run_once ?(hosted = true) ?(skip_init = false) ?reinit_all
     if skip_init || Set.is_empty t.params then ctx
     else init_params (module Backend) ~ctx ~hosted ?reinit_all bindings t
   in
-  let routine =
-    Backend.(link ctx @@ compile ctx.optimize_ctx bindings update)
-  in
+  let routine = Backend.(link ctx @@ compile ctx.optimize_ctx bindings update) in
   Task.run routine.schedule;
   routine.context
 
 (** [forward_once] is a wrapper around {!run_once} that runs the forward code of [t]. *)
-let forward_once ?hosted ?skip_init ?reinit_all ?(disable_rootness_check = false)
-    (type buffer_ptr dev runner event optimize_ctx)
+let forward_once ?hosted ?skip_init ?reinit_all (type buffer_ptr dev runner event optimize_ctx)
     (module Backend : Backend
       with type buffer_ptr = buffer_ptr
       and type dev = dev
       and type runner = runner
       and type optimize_ctx = optimize_ctx
       and type event = event) ?ctx ?bindings t =
-  let ctx =
-    run_once ?hosted ?skip_init ?reinit_all
-      (module Backend)
-      ~f:(fun ~disable_rootness_check t -> forward ~disable_rootness_check t)
-      ~disable_rootness_check ?ctx ?bindings t
-  in
+  let ctx = run_once ?hosted ?skip_init ?reinit_all (module Backend) ~f:forward ?ctx ?bindings t in
   (* FIXME: this is going away soon. *)
-  if not disable_rootness_check then Tensor.remove_bprop_root t;
+  Tensor.remove_bprop_root t;
   ctx
 
 (** [update_once] is a wrapper around {!run_once} that runs the gradient update code of [t]: both
     forward and backprop. *)
-let update_once ?hosted ?skip_init ?reinit_all ?(disable_rootness_check = false)
-    (type buffer_ptr dev runner event optimize_ctx)
+let update_once ?hosted ?skip_init ?reinit_all (type buffer_ptr dev runner event optimize_ctx)
     (module Backend : Backend
       with type buffer_ptr = buffer_ptr
       and type dev = dev
       and type runner = runner
       and type optimize_ctx = optimize_ctx
       and type event = event) ?ctx ?bindings t =
-  run_once ?hosted ?skip_init ?reinit_all
-    (module Backend)
-    ~f:(fun ~disable_rootness_check t -> grad_update ~disable_rootness_check t)
-    ~disable_rootness_check ?ctx ?bindings t
+  run_once ?hosted ?skip_init ?reinit_all (module Backend) ~f:grad_update ?ctx ?bindings t
 
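A hedged sketch of the simplified wrappers in use; the backend module and the tensors are assumptions, and the optional flags other than ?reinit_all keep their defaults.

  (* Sketch only: [Backend] matches the Backend signature; [t1] and [t2] are tensors
     whose code roots are intact. Each wrapper consumes the roots it needs. *)
  let infer_ctx = Train.forward_once (module Backend) t1 in
  let train_ctx = Train.update_once ~reinit_all:true (module Backend) t2 in
  ignore (infer_ctx, train_ctx)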
 (** [printf] is a wrapper around {!Tensor.print} that assumes [~force:true], and by default sets
     [~with_code:false], [~with_grad:true], and [~style:`Default]. *)