Skip to content

Commit e22b6f9

Browse files
committed
Missing param initializations: zero2hero_1of7 test
1 parent eea2724 commit e22b6f9

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

test/operations/zero2hero_1of7.ml

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@ let%expect_test "Graph drawing recompile" =
5151
|}];
5252
let%op f = (3 *. ("x" [ 5 ] **. 2)) - (4 *. x) + 5 in
5353
Train.every_non_literal_on_host f;
54+
let f_init = Train.to_routine (module Backend) ctx IDX.empty @@ Tensor.init_params f in
5455
let f_upd = Train.grad_update f in
55-
let f_bprop = Train.to_routine (module Backend) ctx IDX.empty f_upd in
56+
let f_bprop = Train.to_routine (module Backend) f_init.context IDX.empty f_upd in
57+
Train.run f_init;
5658
Train.run f_bprop;
5759
Tensor.print_tree ~with_grad:true ~depth:9 f;
5860
[%expect
@@ -267,9 +269,10 @@ let%expect_test "Simple gradients hosted" =
267269
Train.every_non_literal_on_host l;
268270
Train.every_non_literal_on_host learning_rate;
269271
let sgd = Train.sgd_update ~learning_rate l in
270-
let grad_routine = Train.to_routine (module Backend) ctx IDX.empty grad in
272+
let f_init = Train.to_routine (module Backend) ctx IDX.empty @@ Tensor.init_params l in
273+
let grad_routine = Train.to_routine (module Backend) f_init.context IDX.empty grad in
271274
let sgd_routine = Train.to_routine (module Backend) grad_routine.context IDX.empty sgd in
272-
(* Check out the initial state without running a forward pass. *)
275+
(* Check out the initial state without running an init or forward pass. *)
273276
Tensor.print_tree ~spy:true ~with_grad:true ~depth:9 l;
274277
[%expect
275278
{|
@@ -292,6 +295,7 @@ let%expect_test "Simple gradients hosted" =
292295
|}];
293296
(* Do not update the params: all values and gradients will be at initial points, which are
294297
specified in the tensor in the brackets. *)
298+
Train.run f_init;
295299
Train.run grad_routine;
296300
Tensor.print_tree ~with_grad:true ~depth:9 l;
297301
[%expect
@@ -398,7 +402,8 @@ let%expect_test "Simple gradients virtual" =
398402
#1 grad_a Material/28 │#3 grad_b Material/28 │ │
399403
<not-in-yet><not-in-yet> │ │
400404
|}];
401-
let grad_routine = Train.to_routine (module Backend) ctx IDX.empty grad in
405+
let f_init = Train.to_routine (module Backend) ctx IDX.empty @@ Tensor.init_params l in
406+
let grad_routine = Train.to_routine (module Backend) f_init.context IDX.empty grad in
402407
(* Check out the state without running a forward pass or compiling the SGD update. *)
403408
Tensor.print_tree ~spy:true ~with_grad:true ~depth:9 l;
404409
[%expect
@@ -422,6 +427,7 @@ let%expect_test "Simple gradients virtual" =
422427
|}];
423428
(* Do not update the params: all values and gradients will be at initial points, which are
424429
specified in the tensor in the brackets. *)
430+
Train.run f_init;
425431
Train.run grad_routine;
426432
Tensor.print_tree ~with_grad:true ~depth:9 l;
427433
[%expect
@@ -507,7 +513,9 @@ let%expect_test "2D neuron hosted" =
507513
let%op v = ("w" [ (-3, 1) ] * "x" [ 2; 0 ]) + "b" [ 6.7 ] in
508514
Train.every_non_literal_on_host v;
509515
let update = Train.grad_update v in
510-
let routine = Train.to_routine (module Backend) ctx IDX.empty update in
516+
let f_init = Train.to_routine (module Backend) ctx IDX.empty @@ Tensor.init_params v in
517+
let routine = Train.to_routine (module Backend) f_init.context IDX.empty update in
518+
Train.run f_init;
511519
Train.run routine;
512520
Tensor.print_tree ~with_grad:true ~depth:9 v;
513521
[%expect
@@ -534,7 +542,9 @@ let%expect_test "2D neuron virtual" =
534542
let ctx = Backend.make_context stream in
535543
let%op v = ("w" [ (-3, 1) ] * "x" [ 2; 0 ]) + "b" [ 6.7 ] in
536544
let update = Train.grad_update v in
537-
let routine = Train.to_routine (module Backend) ctx IDX.empty update in
545+
let f_init = Train.to_routine (module Backend) ctx IDX.empty @@ Tensor.init_params v in
546+
let routine = Train.to_routine (module Backend) f_init.context IDX.empty update in
547+
Train.run f_init;
538548
Train.run routine;
539549
Tensor.print_tree ~with_grad:true ~depth:9 v;
540550
[%expect

0 commit comments

Comments
 (0)