@@ -163,9 +163,9 @@ let _suspended () =
163163 let device = new_virtual_device @@ get_device ~ordinal: 0 in
164164 let update = Train. grad_update l in
165165 let routine = link (init device) @@ compile IDX. empty @@ update.fwd_bprop in
166- Tensor. iter_outputs l ~f: (fun a -> ignore (from_host routine.context a : bool ));
166+ Tensor. iter_embedded l ~f: (fun a -> ignore (from_host routine.context a : bool ));
167167 Train. run routine;
168- Tensor. iter_outputs l ~f: (fun a -> ignore (to_host routine.context a : bool ));
168+ Tensor. iter_embedded l ~f: (fun a -> ignore (to_host routine.context a : bool ));
169169 await device;
170170 Stdio. print_endline
171171 {|
@@ -177,7 +177,7 @@ let _suspended () =
177177 link routine.context @@ compile IDX. empty @@ Train. sgd_update ~learning_rate update
178178 in
179179 (* learning_rate is virtual so this will not print anything. *)
180- Tensor. iter_outputs learning_rate ~f: (fun a ->
180+ Tensor. iter_embedded learning_rate ~f: (fun a ->
181181 ignore (from_host routine.context a : bool ));
182182 Stdio. print_endline
183183 {|
@@ -187,7 +187,7 @@ let _suspended () =
187187 List. iter [ a.value; b.value; c.value; f.value ] ~f: (fun a ->
188188 assert (from_host routine.context a));
189189 Train. run routine;
190- Tensor. iter_outputs l ~f: (fun a -> ignore (to_host routine.context a : bool ));
190+ Tensor. iter_embedded l ~f: (fun a -> ignore (to_host routine.context a : bool ));
191191 await device;
192192 Stdio. print_endline
193193 {|
@@ -198,7 +198,7 @@ let _suspended () =
198198 let update = Train. grad_update l in
199199 let routine = link routine.context @@ compile IDX. empty update.fwd_bprop in
200200 Train. run routine;
201- Tensor. iter_outputs l ~f: (fun a -> ignore (to_host routine.context a : bool ));
201+ Tensor. iter_embedded l ~f: (fun a -> ignore (to_host routine.context a : bool ));
202202 await device;
203203 Stdio. print_endline
204204 {|
0 commit comments