@@ -25,7 +25,7 @@ let _suspended () =
   Train.every_non_literal_on_host v;
   let code = Train.grad_update v in
   let routine = Train.to_routine (module Backend) ctx IDX.empty code.fwd_bprop in
-  Train.sync_run (module Backend) routine v;
+  Train.run routine;
   Stdio.printf "\n%!";
   Tensor.print_tree ~with_id:true ~with_grad:true ~depth:9 v;
   Stdlib.Format.printf "\nHigh-level code:\n%!";
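Note: Train.sync_run (module Backend) routine v took the backend as a
first-class module plus the root tensor; its replacement, Train.run routine,
takes only the compiled routine. A minimal sketch of the resulting call
pattern, assuming (this diff does not confirm it) that Train.run handles any
needed synchronization through the context the routine was compiled against:

    (* Sketch: compile once against a context, then run. [Backend], [ctx]
       and the tensor [v] are as in the surrounding code; that [Train.run]
       needs no further arguments is exactly what this hunk shows. *)
    let code = Train.grad_update v in
    let routine = Train.to_routine (module Backend) ctx IDX.empty code.fwd_bprop in
    Train.run routine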
@@ -47,7 +47,7 @@ let _suspended () =
   Tensor.print_tree ~with_grad:false ~depth:9 f5;
   Stdio.printf "\n%!"

-let () =
+let _suspended () =
   (* FIXME: why is this toplevel example broken and the next one working? *)
   Utils.settings.output_debug_files_in_build_directory <- true;
   Rand.init 0;
@@ -75,14 +75,12 @@ let () =
   let step_ref = IDX.find_exn routine.bindings step_sym in
   let ys = Array.create ~len:size 0. and dys = Array.create ~len:size 0. in
   let open Operation.At in
-  let looping () =
-    assert (Backend.to_host routine.context fx.value);
-    assert (Backend.to_host routine.context (Option.value_exn ~here:[%here] x.diff).grad);
-    Backend.await stream;
+  let f () =
+    Train.run routine;
     ys.(!step_ref) <- fx.@[0];
     dys.(!step_ref) <- x.@%[0]
   in
-  Train.sync_run ~looping (module Backend) routine fx;
+  Train.sequential_loop routine.bindings ~f;
   let plot_box =
     let open PrintBox_utils in
     plot ~size:(75, 35) ~x_label:"x" ~y_label:"f(x)"
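Note: the ~looping callback that Train.sync_run invoked per index setting is
superseded by Train.sequential_loop, which by its call shape here enumerates
the settings of routine.bindings and calls ~f once per step. A hedged sketch
of the pattern, using only names from this hunk (step_sym, ys, dys, fx, x
are assumed in scope as above):

    (* Sketch, assuming [Train.sequential_loop bindings ~f] steps every
       value of the bound indices and invokes [f] at each setting. *)
    let step_ref = IDX.find_exn routine.bindings step_sym in
    let f () =
      Train.run routine;
      (* Read the current step's scalar results on the host. *)
      ys.(!step_ref) <- fx.@[0];
      dys.(!step_ref) <- x.@%[0]
    in
    Train.sequential_loop routine.bindings ~f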
@@ -101,13 +99,6 @@ let _suspended () =
   (* Utils.settings.debug_log_from_routines <- true; *)
   Rand.init 0;
   let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
-  let backend =
-    (module Backend : Backend
-      with type buffer_ptr = Backend.buffer_ptr
-       and type dev = Backend.dev
-       and type runner = Backend.runner
-       and type event = Backend.event)
-  in
   let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
   let ctx = Backend.make_context stream in
   let open Operation.At in
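Note: the deleted binding packed Backend as a first-class module with its
type equalities exposed, which was only required to pass the backend to
Train.sync_run; with Train.run routine taking no backend argument, the value
became dead code. For reference, the packing idiom being removed (types
exactly as in the deleted lines):

    (* First-class module packing; needed only by APIs that take the
       backend explicitly, hence removable after this change. *)
    let backend =
      (module Backend : Backend
        with type buffer_ptr = Backend.buffer_ptr
         and type dev = Backend.dev
         and type runner = Backend.runner
         and type event = Backend.event)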
@@ -138,7 +129,7 @@ let _suspended () =
     Array.unzip
     @@ Array.mapi xs ~f:(fun i _ ->
           step_ref := i;
-          Train.sync_run backend fx_routine fx;
+          Train.run fx_routine;
           (fx.@[0], x.@%[0]))
   in
   (* It is fine to loop around the data: it's "next epoch". We redo the work though. *)
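Note: unlike the sequential_loop hunk above, this loop drives the index by
hand via step_ref := i inside Array.mapi, so only the sync_run-to-run
substitution changes here. A sketch of the manual-stepping pattern, assuming
step_ref was obtained with IDX.find_exn fx_routine.bindings step_sym as
elsewhere in this file:

    (* Sketch: set the step index explicitly, run the routine, then read
       the host-side scalars for that step. *)
    let ys, dys =
      Array.unzip
      @@ Array.mapi xs ~f:(fun i _ ->
             step_ref := i;
             Train.run fx_routine;
             (fx.@[0], x.@%[0]))
    in
    ignore (ys, dys)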
@@ -155,7 +146,7 @@ let _suspended () =
   in
   ()

-let _suspended () =
+let () =
   Rand.init 0;
   Utils.set_log_level 2;
   Utils.settings.output_debug_files_in_build_directory <- true;
@@ -172,8 +163,8 @@ let _suspended () =
   in
   Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.from_host routine.context a : bool));
   Train.run routine;
-  Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
-  Backend.await stream;
+  (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
+     Backend.await stream; *)
   Stdio.print_endline
     {|
     We did not update the params: all values and gradients will be at initial points,
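Note: this and the next two hunks comment out the explicit device-to-host
copy and the blocking Backend.await stream after Train.run. Since the code
that follows still reports host-visible values, the implicit assumption is
that Train.run now synchronizes the stream and copies hosted tensors back by
itself; the commented-out lines preserve the explicit variant in case that
assumption does not hold. For reference, the explicit form being retired:

    (* Explicit variant, now presumed redundant: copy every embedded node
       of [l] back to host, then block until the stream is idle. *)
    Tensor.iter_embedded l ~f:(fun a ->
        ignore (Backend.to_host routine.context a : bool));
    Backend.await stream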
@@ -195,8 +186,8 @@ let _suspended () =
   List.iter [ a.value; b.value; c.value; f.value ] ~f:(fun a ->
       assert (Backend.from_host routine.context a));
   Train.run routine;
-  Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
-  Backend.await stream;
+  (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
+     Backend.await stream; *)
   Stdio.print_endline
     {|
     Now we updated the params, but after the forward and backward passes:
@@ -206,8 +197,8 @@ let _suspended () =
   let update = Train.grad_update l in
   let routine = Train.to_routine (module Backend) routine.context IDX.empty update.fwd_bprop in
   Train.run routine;
-  Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
-  Backend.await stream;
+  (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
+     Backend.await stream; *)
   Stdio.print_endline
     {|
     Now again we did not update the params, they will remain as above, but both param