@@ -51,8 +51,10 @@ let%expect_test "Graph drawing recompile" =
     |}];
   let%op f = (3 *. ("x" [ 5 ] **. 2)) - (4 *. x) + 5 in
   Train.every_non_literal_on_host f;
+  let f_init = Train.to_routine (module Backend) ctx IDX.empty @@ Tensor.init_params f in
   let f_upd = Train.grad_update f in
-  let f_bprop = Train.to_routine (module Backend) ctx IDX.empty f_upd in
+  let f_bprop = Train.to_routine (module Backend) f_init.context IDX.empty f_upd in
+  Train.run f_init;
   Train.run f_bprop;
   Tensor.print_tree ~with_grad:true ~depth:9 f;
   [%expect
@@ -267,9 +269,10 @@ let%expect_test "Simple gradients hosted" =
   Train.every_non_literal_on_host l;
   Train.every_non_literal_on_host learning_rate;
   let sgd = Train.sgd_update ~learning_rate l in
-  let grad_routine = Train.to_routine (module Backend) ctx IDX.empty grad in
+  let f_init = Train.to_routine (module Backend) ctx IDX.empty @@ Tensor.init_params l in
+  let grad_routine = Train.to_routine (module Backend) f_init.context IDX.empty grad in
   let sgd_routine = Train.to_routine (module Backend) grad_routine.context IDX.empty sgd in
-  (* Check out the initial state without running a forward pass. *)
+  (* Check out the initial state without running an init or forward pass. *)
   Tensor.print_tree ~spy:true ~with_grad:true ~depth:9 l;
   [%expect
     {|
@@ -292,6 +295,7 @@ let%expect_test "Simple gradients hosted" =
     |}];
   (* Do not update the params: all values and gradients will be at initial points, which are
      specified in the tensor in the brackets. *)
+  Train.run f_init;
   Train.run grad_routine;
   Tensor.print_tree ~with_grad:true ~depth:9 l;
   [%expect
@@ -398,7 +402,8 @@ let%expect_test "Simple gradients virtual" =
     #1 grad_a Material/28 │#3 grad_b Material/28 │ │
     <not-in-yet>          │<not-in-yet>          │ │
     |}];
-  let grad_routine = Train.to_routine (module Backend) ctx IDX.empty grad in
+  let f_init = Train.to_routine (module Backend) ctx IDX.empty @@ Tensor.init_params l in
+  let grad_routine = Train.to_routine (module Backend) f_init.context IDX.empty grad in
   (* Check out the state without running a forward pass or compiling the SGD update. *)
   Tensor.print_tree ~spy:true ~with_grad:true ~depth:9 l;
   [%expect
@@ -422,6 +427,7 @@ let%expect_test "Simple gradients virtual" =
     |}];
   (* Do not update the params: all values and gradients will be at initial points, which are
      specified in the tensor in the brackets. *)
+  Train.run f_init;
   Train.run grad_routine;
   Tensor.print_tree ~with_grad:true ~depth:9 l;
   [%expect
@@ -507,7 +513,9 @@ let%expect_test "2D neuron hosted" =
   let%op v = ("w" [ (-3, 1) ] * "x" [ 2; 0 ]) + "b" [ 6.7 ] in
   Train.every_non_literal_on_host v;
   let update = Train.grad_update v in
-  let routine = Train.to_routine (module Backend) ctx IDX.empty update in
+  let f_init = Train.to_routine (module Backend) ctx IDX.empty @@ Tensor.init_params v in
+  let routine = Train.to_routine (module Backend) f_init.context IDX.empty update in
+  Train.run f_init;
   Train.run routine;
   Tensor.print_tree ~with_grad:true ~depth:9 v;
   [%expect
@@ -534,7 +542,9 @@ let%expect_test "2D neuron virtual" =
   let ctx = Backend.make_context stream in
   let%op v = ("w" [ (-3, 1) ] * "x" [ 2; 0 ]) + "b" [ 6.7 ] in
   let update = Train.grad_update v in
-  let routine = Train.to_routine (module Backend) ctx IDX.empty update in
+  let f_init = Train.to_routine (module Backend) ctx IDX.empty @@ Tensor.init_params v in
+  let routine = Train.to_routine (module Backend) f_init.context IDX.empty update in
+  Train.run f_init;
   Train.run routine;
   Tensor.print_tree ~with_grad:true ~depth:9 v;
   [%expect
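
Note: every hunk in this commit applies the same refactoring. A minimal standalone sketch of the pattern, reusing only names that appear in the tests above (Train, Tensor, Backend, ctx, IDX) with an arbitrary tensor l standing in for f/v:

  (* Compile parameter initialization (Tensor.init_params) as its own routine,
     instead of leaving initialization implicit in the gradient routine. *)
  let f_init = Train.to_routine (module Backend) ctx IDX.empty @@ Tensor.init_params l in
  (* Thread f_init.context into later routines so they share the initialized params. *)
  let grad_routine = Train.to_routine (module Backend) f_init.context IDX.empty (Train.grad_update l) in
  (* Run initialization once, then the forward/backward update. *)
  Train.run f_init;
  Train.run grad_routine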