
Commit 6d41b75

Automated from_host transfers
1 parent: 1a33588

File tree

11 files changed (+12, -48 lines)


CHANGES.md

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 ## Added
 
 - Automatic transfers to host from the context that most recently updated a node.
+- Automatic transfers of a routine's inputs from host to the routine's context if the host array modification was not yet transferred.
 
 ## Fixed
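
In practice, this removes the boilerplate transfer calls seen in the file diffs below. A minimal before/after sketch, assuming the Train, Backend and Tensor APIs as they appear in those diffs; the tensor y and context ctx are illustrative:

(* Before this commit: hosted tensors had to be pushed to the device explicitly. *)
let routine = Train.to_routine (module Backend) ctx IDX.empty @@ Train.forward y in
assert (Backend.from_host routine.context y.value);
Train.run routine

(* After this commit: the routine's pre-run hook transfers any hosted input whose
   host array was modified since the last transfer, and reading values back on the
   host (e.g. for printing) triggers the to-host transfer via prepare_read. *)
let routine = Train.to_routine (module Backend) ctx IDX.empty @@ Train.forward y in
Train.run routine;
Tensor.print ~with_code:false ~with_grad:false `Default y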

arrayjit/lib/backends.ml

Lines changed: 6 additions & 1 deletion
@@ -70,7 +70,9 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
       Tn.prepare_read
         ~is_done:(fun () -> Backend.is_done e)
         ~sync:(fun () -> Backend.sync e)
-        ~transfer:(fun () -> assert (to_host ctx tn); Backend.await s)
+        ~transfer:(fun () ->
+          assert (to_host ctx tn);
+          Backend.await s)
         tn);
   (* To be on the safe side, record events for potentially cross-stream nodes. *)
   match tn with
@@ -92,6 +94,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
       (* Stdio.printf "copying: %s from_host\n" (Tn.debug_name tn); *)
       Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
       update_writer_event ~from:`Host ctx @@ Node tn;
+      tn.host_modified <- false;
       true
   | _ -> false
 
@@ -140,6 +143,8 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
   let s = r.context.stream in
   let hosted_inputs = Set.filter r.inputs ~f:(fun tn -> Tn.is_hosted_force tn 47) in
   let pre () =
+    assert (Domain.is_main_domain ());
+    Set.iter hosted_inputs ~f:(fun tn -> if tn.host_modified then assert (from_host r.context tn));
     Set.iter r.inputs ~f:(fun tn ->
         if Tn.potentially_cross_stream tn then
           Option.iter (Hashtbl.find s.device.shared_writer_streams tn) ~f:(fun data ->
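
The pre () hook above implements a dirty-flag scheme: before a routine runs, every hosted input whose host_modified flag is still set is copied into the routine's context, and from_host clears the flag. A self-contained toy sketch of the pattern (hypothetical record fields and a stubbed transfer, not the library's actual types):

type node = {
  mutable host_modified : bool;          (* host copy changed since last transfer *)
  mutable device_copy : float array option;
  host : float array;
}

(* Stub transfer: copy host data to the "device" and clear the dirty flag. *)
let from_host n =
  n.device_copy <- Some (Array.copy n.host);
  n.host_modified <- false;
  true

(* Pre-run hook: transfer only the inputs that are actually stale on the device. *)
let pre inputs = List.iter (fun n -> if n.host_modified then assert (from_host n)) inputs

let () =
  let x = { host_modified = true; device_copy = None; host = [| 1.; 2.; 3. |] } in
  pre [ x ];
  assert (not x.host_modified)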

arrayjit/lib/tnode.ml

Lines changed: 5 additions & 1 deletion
@@ -84,6 +84,7 @@ type t = {
   mutable code_name : string option;
   mutable prepare_read : prepare option;
   mutable prepare_write : prepare option;
+  mutable host_modified : bool;
 }
 [@@deriving sexp_of]
 
@@ -553,6 +554,7 @@ let create ?default_prec ~id ~label ~dims init_op =
       code_name = None;
       prepare_read = None;
       prepare_write = None;
+      host_modified = true;
     }
   in
   (* Note: if tensor nodes get non-trivial finalizers, remember to either add an is_finalized flag
@@ -576,6 +578,7 @@ let find =
       code_name = None;
       prepare_read = None;
       prepare_write = None;
+      host_modified = false;
     }
   in
   fun ~id -> Registry.find_opt registry { mock with id }
@@ -592,7 +595,8 @@ let do_read tn =
 
 let do_write tn =
   Option.iter ~f:(fun p -> p.sync ()) tn.prepare_write;
-  tn.prepare_write <- None
+  tn.prepare_write <- None;
+  tn.host_modified <- true
 
 let points_1d ?from_axis ~xdim tn =
   do_read tn;
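
Taken together with backends.ml above, the flag's lifecycle is: create starts at true (fresh host contents have never been transferred), the find mock uses false, do_write marks the host copy dirty, and a successful from_host clears it. Illustratively, for the demo callbacks further down (assuming Tn.set_values reaches do_write):

Tn.set_values point.value [| x; y |];  (* host write: do_write sets host_modified *)
Train.run result_routine               (* pre () sees the flag and re-runs from_host *)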

bin/compilation_speed.ml

Lines changed: 0 additions & 1 deletion
@@ -39,7 +39,6 @@ let benchmark_overhead backend () =
     Train.to_routine (module Backend) init_assign_x.context IDX.empty update_f.fwd_bprop
   in
   Tensor.print_tree ~with_grad:true ~with_backend_info:true ~depth:9 f;
-  Tensor.iter_embedded f ~f:(fun a -> ignore (Backend.from_host f_routine.context a : bool));
 
   let xs = Array.init n_data ~f:Float.(fun i -> of_int i - (of_int n_data /. 2.)) in
   let open Operation.At in

bin/hello_world.ml

Lines changed: 0 additions & 2 deletions
@@ -56,8 +56,6 @@ let hello3 () =
   let y = TDSL.O.(( + ) ~label:[ "y" ] (hey * zero_to_twenty) zero_to_twenty) in
   Train.set_hosted hey.value;
   let routine = Train.to_routine (module Backend) ctx IDX.empty @@ Train.forward y in
-  assert (Backend.from_host routine.context hey.value);
-  assert (Backend.from_host routine.context zero_to_twenty.value);
   Tensor.print ~with_code:true ~with_grad:false `Inline zero_to_twenty;
   Tensor.print ~with_code:true ~with_grad:false `Default zero_to_twenty;
   Tensor.print_tree ~with_grad:false ~depth:9 zero_to_twenty;

bin/micrograd_demo.ml

Lines changed: 0 additions & 3 deletions
@@ -85,8 +85,6 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
   let routine =
     Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update.fwd_bprop; sgd ])
   in
-  Train.all_host_to_device (module Backend) routine.context scalar_loss;
-  Train.all_host_to_device (module Backend) routine.context learning_rate;
   (* Stdio.print_endline "\n******** scalar_loss **********"; Tensor.print_tree ~with_id:true
      ~with_grad:false ~depth:9 scalar_loss; Stdio.print_endline "\n******** learning_rate
      **********"; Tensor.print_tree ~with_id:true ~with_grad:false ~depth:9 learning_rate;
@@ -136,7 +134,6 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
     Tn.set_values point.value [| x; y |];
     (* For the gccjit backend, point is only on host, not on device. For cuda, this will be
        needed. *)
-    assert (Backend.from_host result_routine.context point.value);
     Train.run result_routine;
     Float.(mlp_result.@[0] >= 0.)
   in

bin/moons_demo.ml

Lines changed: 0 additions & 3 deletions
@@ -78,8 +78,6 @@ let demo () =
   PrintBox_text.output Stdio.stdout plot_moons;
   Stdio.print_endline "\n";
 
-  Train.all_host_to_device (module Backend) routine.context scalar_loss;
-  Train.all_host_to_device (module Backend) routine.context learning_rate;
   let open Operation.At in
   let step_ref = IDX.find_exn routine.bindings step_n in
   let batch_ref = IDX.find_exn routine.bindings batch_n in
@@ -112,7 +110,6 @@ let demo () =
   let callback (x, y) =
     Tn.set_values point.value [| x; y |];
     Utils.capture_stdout_logs @@ fun () ->
-    assert (Backend.from_host result_routine.context point.value);
     Train.run result_routine;
     Float.(mlp_result.@[0] >= 0.)
   in

bin/zero2hero_1of7.ml

Lines changed: 0 additions & 7 deletions
@@ -161,7 +161,6 @@ let () =
   let routine =
     Train.to_routine (module Backend) (Backend.make_context stream) IDX.empty update.fwd_bprop
   in
-  Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.from_host routine.context a : bool));
   Train.run routine;
   (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
      Backend.await stream; *)
@@ -176,18 +175,12 @@ let () =
     @@ Train.sgd_update ~learning_rate update
   in
   (* learning_rate is virtual so this will not print anything. *)
-  Tensor.iter_embedded learning_rate ~f:(fun a ->
-      ignore (Backend.from_host routine.context a : bool));
   Stdio.print_endline
     {|
     Due to how the gccjit backend works, since the parameters were constant in the grad_update
     computation, they did not exist on the device before. Now they do. This would not be needed
     on the cuda backend.|};
-  List.iter [ a.value; b.value; c.value; f.value ] ~f:(fun a ->
-      assert (Backend.from_host routine.context a));
   Train.run routine;
-  (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
-     Backend.await stream; *)
   Stdio.print_endline
     {|
     Now we updated the params, but after the forward and backward passes:

lib/train.ml

Lines changed: 0 additions & 14 deletions
@@ -269,16 +269,6 @@ let every_non_literal_on_host =
   Tensor.iter_embedded ~f:(fun a ->
       if Tn.mode_is_unspecified a && not (Tn.known_constant a) then set_hosted a)
 
-(* Note: this will get nicer with modular explicits. *)
-let%debug2_sexp all_host_to_device (type buffer_ptr dev runner event)
-    (module Backend : Backend
-      with type buffer_ptr = buffer_ptr
-       and type dev = dev
-       and type runner = runner
-       and type event = event) (context : Backend.context) =
-  let f tn = ignore (Backend.from_host context tn : bool) in
-  Tensor.iter_embedded ~f
-
 module Lazy = Utils.Lazy
 
 (** Performs one optimization step, potentially in parallel (if [grad_updates] are linked with
@@ -469,8 +459,6 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
   let sgd_update = to_routine (module Backend) grad_updates.(0).context bindings sgd in
   Tensor.log_debug_info ~from_log_level:2 inputs;
   Tensor.log_debug_info ~from_log_level:2 outputs;
-  all_host_to_device (module Backend) sgd_update.context scalar_loss;
-  all_host_to_device (module Backend) sgd_update.context learning_rate;
   let open Operation.At in
   let epoch_loss = ref 0. in
   let step_ref = IDX.find_exn sgd_update.bindings step_n in
@@ -531,7 +519,6 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
     (* For the gccjit backend, infer is only on host, not on device. For cuda, this will be
        needed. *)
     Utils.capture_stdout_logs @@ fun () ->
-    assert (Backend.from_host routine.context infer.value);
     run routine;
     Tn.get_values model_result.value
   in
@@ -558,7 +545,6 @@ let%track3_sexp forward_and_ctx ?(disable_rootness_check = false) (type buffer_p
       and type event = event) ctx ?(bindings = IDX.empty) t =
   let routine = Backend.(link ctx @@ compile bindings @@ forward ~disable_rootness_check t) in
   if not disable_rootness_check then Tensor.remove_bprop_root t;
-  Tensor.iter_embedded t ~f:(fun a -> ignore (Backend.from_host routine.context a : bool));
   Task.run routine.schedule;
   routine.context
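
For callers of the removed all_host_to_device helper, the migration is mechanical: drop the explicit transfer calls and rely on the routine's pre-run transfers. A sketch using names from test/micrograd_demo.ml below (where the backend first-class-module value also becomes unnecessary):

(* Before: *)
Train.all_host_to_device backend sgd_routine.context scalar_loss;
Train.all_host_to_device backend sgd_routine.context learning_rate;
Train.run sgd_routine
(* After: the explicit transfers are simply removed. *)
Train.run sgd_routine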

test/micrograd_demo.ml

Lines changed: 0 additions & 11 deletions
@@ -29,7 +29,6 @@ let%expect_test "Micrograd README basic example" =
   List.iter ~f:(Option.iter ~f:(fun diff -> Train.set_hosted diff.Tensor.grad)) [ a.diff; b.diff ];
   let update = Train.grad_update g in
   let step = Train.to_routine (module Backend) ctx IDX.empty update.fwd_bprop in
-  Tensor.iter_embedded g ~f:(fun a -> ignore (Backend.from_host step.context a : bool));
   Train.run step;
   Tensor.print ~with_code:false ~with_grad:false `Default g;
   [%expect
@@ -89,13 +88,6 @@ let%expect_test "Micrograd half-moons example" =
   (* Note: for as-yet unknown reason, this test can lead to different resuls on different versions
      of dependencies. *)
   let module Backend = (val Arrayjit.Backends.fresh_backend ~backend_name:"cc" ()) in
-  let backend =
-    (module Backend : Backend
-      with type buffer_ptr = Backend.buffer_ptr
-       and type dev = Backend.dev
-       and type runner = Backend.runner
-       and type event = Backend.event)
-  in
   let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
   let ctx = Backend.make_context stream in
   let open Operation.At in
@@ -148,8 +140,6 @@ let%expect_test "Micrograd half-moons example" =
   let sgd_routine =
     Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update.fwd_bprop; sgd ])
   in
-  Train.all_host_to_device backend sgd_routine.context scalar_loss;
-  Train.all_host_to_device backend sgd_routine.context learning_rate;
   let step_ref = IDX.find_exn sgd_routine.bindings step_n in
   step_ref := 0;
   for _epoch = 1 to epochs do
@@ -180,7 +170,6 @@ let%expect_test "Micrograd half-moons example" =
     Tn.set_values point.value [| x; y |];
     (* For the gccjit backend, point is only on host, not on device. For cuda, this will be
        needed. *)
-    assert (Backend.from_host result_routine.context point.value);
     Train.run result_routine;
     Float.(mlp_result.@[0] >= 0.)
   in
