@@ -74,6 +74,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
     (scalar_loss, 0.0)
   in
   (* So that we can inspect them. *)
+  let init_params = Tensor.init_params scalar_loss in
   let update = Train.grad_update scalar_loss in
   let%op learning_rate = 0.1 *. (!..steps - !@step_n) /. !..steps in
   Train.set_hosted learning_rate.value;
@@ -82,6 +83,9 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
   let module Backend = (val Backends.fresh_backend ()) in
   let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
   let ctx = Backend.make_context stream in
+  let init = Backend.link ctx @@ Backend.compile ctx.optimize_ctx IDX.empty init_params in
+  let ctx = init.context in
+  Train.run init;
   let routine = Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update; sgd ]) in
   (* Stdio.print_endline "\n******** scalar_loss **********"; Tensor.print_tree ~with_id:true
      ~with_grad:false ~depth:9 scalar_loss; Stdio.print_endline "\n******** learning_rate
@@ -114,7 +118,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
   let points = Tn.points_2d ~xdim:0 ~ydim:1 moons_flat.value in
   let classes = Tn.points_1d ~xdim:0 moons_classes.value in
   let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
-  let%op mlp_result = mlp "point" in
+  let%cd mlp_result = mlp "point" in
   Train.set_on_host mlp_result.value;
   (* By using jitted.context here, we don't need to copy the parameters back to the host. *)
   let result_routine =
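For orientation, the added lines gathered into one place: a sketch rather than standalone code. It assumes `scalar_loss`, `ctx`, `bindings`, `update`, and `sgd` are in scope exactly as in the hunks above, and the comments describe the apparent intent of the change, not documented API behavior.

    (* Collect the assignments that initialize the parameters of the loss graph. *)
    let init_params = Tensor.init_params scalar_loss in
    (* Compile and link them with no index bindings, then run them once so the
       parameters are populated before the training routine is built. *)
    let init = Backend.link ctx @@ Backend.compile ctx.optimize_ctx IDX.empty init_params in
    let ctx = init.context in
    Train.run init;
    (* The training routine is then built from the context returned by the init routine. *)
    let routine = Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update; sgd ]) in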