
Commit 1a33588

Automated to_host transfers
1 parent f48985d commit 1a33588

17 files changed, +176 −261 lines changed

arrayjit/lib/backends.ml

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing
       Tn.prepare_read
         ~is_done:(fun () -> Backend.is_done e)
         ~sync:(fun () -> Backend.sync e)
-        ~transfer:(fun () -> assert (to_host ctx tn))
+        ~transfer:(fun () -> assert (to_host ctx tn); Backend.await s)
         tn);
   (* To be on the safe side, record events for potentially cross-stream nodes. *)
   match tn with
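
For orientation, the callbacks registered above suggest a payload record roughly like the following. This is a sketch inferred from the field accesses visible in this commit (is_done, sync, transfer here; p.sync and p.transfer in tnode.ml below), not the repository's actual declaration:

(* Hypothetical shape of the value stored in tn.prepare_read; the field
   names come from this diff, the record declaration is an assumption. *)
type prepare = {
  is_done : unit -> bool;  (* has the producing stream's event completed? *)
  sync : unit -> unit;     (* block until that event completes *)
  transfer : unit -> unit; (* run to_host, and after this commit, await the stream *)
}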

arrayjit/lib/tnode.ml

Lines changed: 18 additions & 6 deletions
@@ -582,33 +582,45 @@ let find =

 (** {2 Accessors} *)

+let do_read tn =
+  Option.iter
+    ~f:(fun p ->
+      p.sync ();
+      p.transfer ())
+    tn.prepare_read;
+  tn.prepare_read <- None
+
+let do_write tn =
+  Option.iter ~f:(fun p -> p.sync ()) tn.prepare_write;
+  tn.prepare_write <- None
+
 let points_1d ?from_axis ~xdim tn =
-  Option.iter ~f:(fun p -> p.sync ()) tn.prepare_read;
+  do_read tn;
   Option.value_map ~default:[||] ~f:(fun arr -> Nd.retrieve_1d_points ?from_axis ~xdim arr)
   @@ Lazy.force tn.array

 let points_2d ?from_axis ~xdim ~ydim tn =
-  Option.iter ~f:(fun p -> p.sync ()) tn.prepare_read;
+  do_read tn;
   Option.value_map ~default:[||] ~f:(fun arr -> Nd.retrieve_2d_points ?from_axis ~xdim ~ydim arr)
   @@ Lazy.force tn.array

 let set_value tn =
-  Option.iter ~f:(fun p -> p.sync ()) tn.prepare_write;
+  do_write tn;
   Nd.set_from_float @@ Option.value_exn ~here:[%here] @@ Lazy.force tn.array

 let get_value tn =
-  Option.iter ~f:(fun p -> p.sync ()) tn.prepare_read;
+  do_read tn;
   Nd.get_as_float @@ Option.value_exn ~here:[%here] @@ Lazy.force tn.array

 let set_values tn values =
-  Option.iter ~f:(fun p -> p.sync ()) tn.prepare_write;
+  do_write tn;
   Nd.(
     reset (Constant_fill { values; strict = false })
     @@ Option.value_exn ~here:[%here]
     @@ Lazy.force tn.array)

 let get_values tn =
-  Option.iter ~f:(fun p -> p.sync ()) tn.prepare_read;
+  do_read tn;
   Nd.(retrieve_flat_values @@ Option.value_exn ~here:[%here] @@ Lazy.force tn.array)

 let print_accessible_headers () =
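
With do_read routed through every read accessor, a host-side read drains the pending transfer at most once. A minimal sketch of the resulting behavior, assuming the Tn module alias used in backends.ml above:

(* Hedged sketch: get_values calls do_read, which syncs the recorded
   event, runs the to_host transfer, then clears prepare_read, so the
   second read below performs no synchronization work. *)
let read_twice tn =
  let v1 = Tn.get_values tn in  (* sync + transfer happen here *)
  let v2 = Tn.get_values tn in  (* prepare_read is None: plain read *)
  (v1, v2)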

bin/compilation_speed.ml

Lines changed: 0 additions & 3 deletions
@@ -51,10 +51,7 @@ let benchmark_overhead backend () =
       Train.to_routine (module Backend) f_routine.context ~name:"assign_x" IDX.empty update_x
     in
     Train.run assign_x;
-    (* await device; *)
     Train.run f_routine;
-    assert (Backend.to_host f_routine.context f.value);
-    Backend.await stream;
     f.@[0])
   in
   let plot_box =
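
The deletion above is the pattern this commit automates across bin/. A hedged before/after sketch using this file's identifiers (that f.@[0] ends in the new do_read path is an inference from the tnode.ml hunk, not something this diff states explicitly):

(* Before: reading a device-computed value required a manual flush. *)
Train.run f_routine;
assert (Backend.to_host f_routine.context f.value);
Backend.await stream;
Stdio.printf "%f\n" f.@[0];

(* After: the accessor behind f.@[0] presumably reaches Tn.do_read,
   which syncs the event and performs the to_host transfer on demand. *)
Train.run f_routine;
Stdio.printf "%f\n" f.@[0]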

bin/einsum_trivia.ml

Lines changed: 6 additions & 6 deletions
@@ -47,7 +47,7 @@ let _suspended () =
   let%op ho2 = hey2 ++ "ab|cd->ef => cf|ae->db" in
   Utils.capture_stdout_logs @@ fun () ->
   Train.forward_and_forget backend ctx ho2;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ ho2
+  Tensor.print ~with_code:false ~with_grad:false `Default @@ ho2

 let () =
   Utils.set_log_level 2;
@@ -67,16 +67,16 @@ let () =
   let a = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 3 ] ~output_dims:[ 4 ] () in
   let b = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 4 ] ~output_dims:[ 5 ] () in
   let%op a2 = a *+ "b|i->o; b|i->o => b|i->o" a in
-  Tensor.print ~force:false ~with_code:false ~with_grad:false `Default @@ a;
+  Tensor.print ~spy:true ~with_code:false ~with_grad:false `Default @@ a;
   let ctx = Utils.capture_stdout_logs (fun () -> Train.forward_and_ctx backend ctx a2) in
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ a;
+  Tensor.print ~with_code:false ~with_grad:false `Default @@ a;
   let%op c = b *+ "b|h->o; b|i->h => b|i->o" a in
   Utils.capture_stdout_logs (fun () -> Train.forward_and_forget backend ctx c);
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ a;
+  Tensor.print ~with_code:false ~with_grad:false `Default @@ a;
   (* let%op d = a *+ "a|i->h; b|h->o => ab|i->o" b in Utils.capture_stdout_logs (fun () ->
      Train.forward_and_forget backend ctx d); let%op e = a *+ "b|i->h; b|h->o => i->o" b in
      Utils.capture_stdout_logs (fun () -> Train.forward_and_forget backend ctx e); let%op f = a *+
      "a|i->h; b|h->o => i->o" b in Utils.capture_stdout_logs (fun () -> Train.forward_and_forget
      backend ctx f); *)
-  (* Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ a2; *)
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ c
+  (* Tensor.print ~with_code:false ~with_grad:false `Default @@ a2; *)
+  Tensor.print ~with_code:false ~with_grad:false `Default @@ c

bin/hello_world.ml

Lines changed: 5 additions & 7 deletions
@@ -63,8 +63,6 @@ let hello3 () =
   Tensor.print_tree ~with_grad:false ~depth:9 zero_to_twenty;
   Stdlib.Format.print_newline ();
   Train.run routine;
-  assert (Backend.to_host routine.context y.value);
-  Backend.await stream;
   Tensor.print ~with_code:true ~with_grad:false `Default y;
   Stdlib.Format.force_newline ();
   Tensor.print_tree ~with_grad:false ~depth:9 y;
@@ -95,11 +93,11 @@ let hello4 () =
   Train.set_hosted tk.value;
   Train.forward_and_forget backend ctx positions;
   Stdio.print_endline "positions:";
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ positions;
+  Tensor.print ~with_code:false ~with_grad:false `Default @@ positions;
   Stdio.print_endline "tk:";
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ tk;
+  Tensor.print ~with_code:false ~with_grad:false `Default @@ tk;
   Stdio.print_endline "ti:";
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ ti;
+  Tensor.print ~with_code:false ~with_grad:false `Default @@ ti;
   Stdio.printf "\n%!"

 let hello5 () =
@@ -120,8 +118,8 @@ let hello5 () =
   let hey = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 3 ] ~output_dims:[ 4 ] () in
   let%op ho = hey ++ "...|1->... => ...|..." in
   Train.forward_and_forget backend ctx ho;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ hey;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ ho
+  Tensor.print ~with_code:false ~with_grad:false `Default @@ hey;
+  Tensor.print ~with_code:false ~with_grad:false `Default @@ ho

 let hello6 () =
   Utils.set_log_level 2;

bin/hello_world_op.ml

Lines changed: 13 additions & 13 deletions
@@ -72,11 +72,11 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
   Rand.init 0;
   let%op hey = [ (1, 2, 3); (4, 5, 6) ] in
   Train.forward_and_forget backend ctx hey;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ hey;
+  Tensor.print ~with_code:false ~with_grad:false `Inline @@ hey;
   Tensor.print ~with_code:false ~with_grad:false `Default @@ hey;
   let%op hoo = [| [ 1; 2; 3 ]; [ 4; 5; 6 ] |] in
   Train.forward_and_forget backend ctx hoo;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ hoo;
+  Tensor.print ~with_code:false ~with_grad:false `Inline @@ hoo;
   Tensor.print ~with_code:false ~with_grad:false `Default @@ hoo;
   let%op hey2 =
     [
@@ -87,7 +87,7 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
     ]
   in
   Train.forward_and_forget backend ctx hey2;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ hey2;
+  Tensor.print ~with_code:false ~with_grad:false `Inline @@ hey2;
   Tensor.print ~with_code:false ~with_grad:false `Default @@ hey2;
   let%op hoo2 =
     [|
@@ -98,8 +98,8 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
     |]
   in
   Train.forward_and_forget backend ctx hoo2;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ hoo2;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ hoo2;
+  Tensor.print ~with_code:false ~with_grad:false `Inline @@ hoo2;
+  Tensor.print ~with_code:false ~with_grad:false `Default @@ hoo2;
   let%op heyhoo =
     [|
       [| [ 1; 2; 3 ]; [ 4; 5; 6 ] |];
@@ -109,8 +109,8 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
     |]
   in
   Train.forward_and_forget backend ctx heyhoo;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ heyhoo;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ heyhoo;
+  Tensor.print ~with_code:false ~with_grad:false `Inline @@ heyhoo;
+  Tensor.print ~with_code:false ~with_grad:false `Default @@ heyhoo;
   let%op heyhoo2 =
     [|
       [| [ [ 1; 31 ]; [ 2; 32 ]; [ 3; 33 ] ]; [ [ 4; 34 ]; [ 5; 35 ]; [ 6; 36 ] ] |];
@@ -120,8 +120,8 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
     |]
   in
   Train.forward_and_forget backend ctx heyhoo2;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ heyhoo2;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ heyhoo2;
+  Tensor.print ~with_code:false ~with_grad:false `Inline @@ heyhoo2;
+  Tensor.print ~with_code:false ~with_grad:false `Default @@ heyhoo2;
   let%op heyhoo3 =
     [|
       [|
@@ -135,8 +135,8 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
     |]
   in
   Train.forward_and_forget backend ctx heyhoo3;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ heyhoo3;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ heyhoo3;
+  Tensor.print ~with_code:false ~with_grad:false `Inline @@ heyhoo3;
+  Tensor.print ~with_code:false ~with_grad:false `Default @@ heyhoo3;
   let%op heyhoo4 =
     [|
       [
@@ -150,8 +150,8 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
     |]
   in
   Train.forward_and_forget backend ctx heyhoo4;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ heyhoo4;
-  Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ heyhoo4
+  Tensor.print ~with_code:false ~with_grad:false `Inline @@ heyhoo4;
+  Tensor.print ~with_code:false ~with_grad:false `Default @@ heyhoo4

 let%track2_sexp _Matrix_multiplication_dims_2x3 (() : unit) : unit =
   Tensor.unsafe_reinitialize ();

bin/micrograd_basic.ml

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ let%diagn_sexp () =
   ]; *)
   let update = Train.grad_update d in
   let routine = Train.to_routine (module Backend) ctx IDX.empty update.fwd_bprop in
-  Train.sync_run (module Backend) routine d;
+  Train.run routine;
   Tensor.print_tree ~with_grad:true ~depth:9 d;
   Stdio.print_endline "\n";
   Tensor.print ~with_code:false ~with_grad:false `Default @@ d;
@@ -53,7 +53,7 @@ let%diagn_sexp _suspended () : unit =
   (* Train.every_non_literal_on_host g; *)
   let update = Train.grad_update g in
   let routine = Train.to_routine (module Backend) ctx IDX.empty update.fwd_bprop in
-  Train.sync_run (module Backend) routine g;
+  Train.run routine;
   (* Tensor.print_tree ~with_grad:true ~depth:9 g; *)
   Tensor.print ~with_code:false ~with_grad:false `Default @@ g;
   Tensor.print ~with_code:false ~with_grad:true `Default @@ a;
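
Train.sync_run bundled the run with an eager host transfer and await; the bin/zero2hero_1of7.ml hunks below show the explicit form it stood for. A hedged equivalence sketch, with stream as bound earlier in each of these example files:

(* What the replaced Train.sync_run call spelled out, per the lines this
   commit deletes or comments out in bin/zero2hero_1of7.ml: *)
Train.run routine;
Tensor.iter_embedded d ~f:(fun a ->
    ignore (Backend.to_host routine.context a : bool));
Backend.await stream;
(* Now the eager transfer is unnecessary: accessors pull values to the
   host lazily, so Train.run routine alone suffices. *)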

bin/micrograd_demo.ml

Lines changed: 0 additions & 5 deletions
@@ -102,9 +102,6 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
     for batch = 0 to n_batches - 1 do
       batch_ref := batch;
       Train.run routine;
-      assert (Backend.to_host routine.context learning_rate.value);
-      assert (Backend.to_host routine.context scalar_loss.value);
-      Backend.await stream;
       (* Stdio.printf "Data batch=%d, step=%d, lr=%f, batch loss=%f\n%!" !batch_ref !step_ref
          learning_rate.@[0] scalar_loss.@[0]; *)
       learning_rates := learning_rate.@[0] :: !learning_rates;
@@ -141,8 +138,6 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
        needed. *)
     assert (Backend.from_host result_routine.context point.value);
     Train.run result_routine;
-    assert (Backend.to_host result_routine.context mlp_result.value);
-    Backend.await stream;
     Float.(mlp_result.@[0] >= 0.)
   in
   let%track3_sexp _plotting : unit =

bin/moons_demo.ml

Lines changed: 0 additions & 5 deletions
@@ -91,9 +91,6 @@ let demo () =
       batch_ref := batch;
       Utils.capture_stdout_logs @@ fun () ->
       Train.run routine;
-      assert (Backend.to_host routine.context learning_rate.value);
-      assert (Backend.to_host routine.context scalar_loss.value);
-      Backend.await stream;
       epoch_loss := !epoch_loss +. scalar_loss.@[0];
       Int.incr step_ref
     done;
@@ -117,8 +114,6 @@ let demo () =
     Utils.capture_stdout_logs @@ fun () ->
     assert (Backend.from_host result_routine.context point.value);
     Train.run result_routine;
-    assert (Backend.to_host result_routine.context mlp_result.value);
-    Backend.await stream;
     Float.(mlp_result.@[0] >= 0.)
   in

bin/zero2hero_1of7.ml

Lines changed: 13 additions & 22 deletions
@@ -25,7 +25,7 @@ let _suspended () =
   Train.every_non_literal_on_host v;
   let code = Train.grad_update v in
   let routine = Train.to_routine (module Backend) ctx IDX.empty code.fwd_bprop in
-  Train.sync_run (module Backend) routine v;
+  Train.run routine;
   Stdio.printf "\n%!";
   Tensor.print_tree ~with_id:true ~with_grad:true ~depth:9 v;
   Stdlib.Format.printf "\nHigh-level code:\n%!";
@@ -47,7 +47,7 @@ let _suspended () =
   Tensor.print_tree ~with_grad:false ~depth:9 f5;
   Stdio.printf "\n%!"

-let () =
+let _suspended () =
   (* FIXME: why is this toplevel example broken and the next one working? *)
   Utils.settings.output_debug_files_in_build_directory <- true;
   Rand.init 0;
@@ -75,14 +75,12 @@ let () =
   let step_ref = IDX.find_exn routine.bindings step_sym in
   let ys = Array.create ~len:size 0. and dys = Array.create ~len:size 0. in
   let open Operation.At in
-  let looping () =
-    assert (Backend.to_host routine.context fx.value);
-    assert (Backend.to_host routine.context (Option.value_exn ~here:[%here] x.diff).grad);
-    Backend.await stream;
+  let f () =
+    Train.run routine;
     ys.(!step_ref) <- fx.@[0];
     dys.(!step_ref) <- x.@%[0]
   in
-  Train.sync_run ~looping (module Backend) routine fx;
+  Train.sequential_loop routine.bindings ~f;
   let plot_box =
     let open PrintBox_utils in
     plot ~size:(75, 35) ~x_label:"x" ~y_label:"f(x)"
@@ -101,13 +99,6 @@ let _suspended () =
   (* Utils.settings.debug_log_from_routines <- true; *)
   Rand.init 0;
   let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
-  let backend =
-    (module Backend : Backend
-      with type buffer_ptr = Backend.buffer_ptr
-       and type dev = Backend.dev
-       and type runner = Backend.runner
-       and type event = Backend.event)
-  in
   let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
   let ctx = Backend.make_context stream in
   let open Operation.At in
@@ -138,7 +129,7 @@ let _suspended () =
     Array.unzip
     @@ Array.mapi xs ~f:(fun i _ ->
           step_ref := i;
-          Train.sync_run backend fx_routine fx;
+          Train.run fx_routine;
           (fx.@[0], x.@%[0]))
   in
   (* It is fine to loop around the data: it's "next epoch". We redo the work though. *)
@@ -155,7 +146,7 @@ let _suspended () =
   in
   ()

-let _suspended () =
+let () =
   Rand.init 0;
   Utils.set_log_level 2;
   Utils.settings.output_debug_files_in_build_directory <- true;
@@ -172,8 +163,8 @@ let _suspended () =
   in
   Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.from_host routine.context a : bool));
   Train.run routine;
-  Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
-  Backend.await stream;
+  (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
+     Backend.await stream; *)
   Stdio.print_endline
     {|
We did not update the params: all values and gradients will be at initial points,
@@ -195,8 +186,8 @@ let _suspended () =
   List.iter [ a.value; b.value; c.value; f.value ] ~f:(fun a ->
       assert (Backend.from_host routine.context a));
   Train.run routine;
-  Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
-  Backend.await stream;
+  (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
+     Backend.await stream; *)
   Stdio.print_endline
     {|
Now we updated the params, but after the forward and backward passes:
@@ -206,8 +197,8 @@ let _suspended () =
   let update = Train.grad_update l in
   let routine = Train.to_routine (module Backend) routine.context IDX.empty update.fwd_bprop in
   Train.run routine;
-  Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
-  Backend.await stream;
+  (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
+     Backend.await stream; *)
   Stdio.print_endline
     {|
Now again we did not update the params, they will remain as above, but both param
