Commit 3c0a07c
Fixes #324: Make Tensor.print non-forcing by default; refactor forward_and_forget to forward_and_force
This change ensures that tensor values are forced and transferred to the host when needed, but are never forced by accident.
1 parent 5caec3b commit 3c0a07c
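Below is a minimal before/after sketch of the renamed API (an editorial illustration, not part of the commit; `Backend`, `ctx`, and the `%op` syntax are assumed to be set up as in the repository's `bin/` examples):

    (* Before this commit, forward_and_forget ran the forward pass and discarded
       the context, and Tensor.print forced un-computed values as a side effect. *)
    let%op y = 2 *. "hey" in
    (* Runs the forward pass, forces y.value, and reads it back to the host: *)
    Train.forward_and_force (module Backend) ctx y;
    (* Printing is now non-forcing by default: had the value not been computed,
       this would render a "<not-in-yet>" placeholder instead of forcing it. *)
    Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default y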

File tree: 12 files changed, +192 −139 lines

bin/einsum_trivia.ml

Lines changed: 6 additions & 6 deletions
@@ -15,7 +15,7 @@ let _suspended () =
   let a = TDSL.range_of_shape ~label:[ "a" ] ~input_dims:[ 2 ] ~output_dims:[ 2 ] () in
   let b = TDSL.range_of_shape ~label:[ "b" ] ~input_dims:[ 2; 3; 4 ] ~output_dims:[ 2 ] () in
   let%op c = a *+ "i->1; ij...->0 => ...->ji" b in
-  Train.forward_and_forget (module Backend) ctx c;
+  Train.forward_and_force (module Backend) ctx c;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ c;
   Stdio.printf "\n%!"

@@ -40,7 +40,7 @@ let _suspended () =
   in
   let%op ho2 = hey2 ++ "ab|cd->ef => cf|ae->db" in
   Utils.capture_stdout_logs @@ fun () ->
-  Train.forward_and_forget backend ctx ho2;
+  Train.forward_and_force backend ctx ho2;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ ho2

 let () =
@@ -61,11 +61,11 @@ let () =
   let%op a2 = a *+ "b|i->o; b|i->o => b|i->o" a in
   let ctx = Utils.capture_stdout_logs (fun () -> Train.forward_and_ctx backend ctx a2) in
   let%op c = b *+ "b|h->o; b|i->h => b|i->o" a in
-  Utils.capture_stdout_logs (fun () -> Train.forward_and_forget backend ctx c);
+  Utils.capture_stdout_logs (fun () -> Train.forward_and_force backend ctx c);
   (* let%op d = a *+ "a|i->h; b|h->o => ab|i->o" b in Utils.capture_stdout_logs (fun () ->
-     Train.forward_and_forget backend ctx d); let%op e = a *+ "b|i->h; b|h->o => i->o" b in
-     Utils.capture_stdout_logs (fun () -> Train.forward_and_forget backend ctx e); let%op f = a *+
-     "a|i->h; b|h->o => i->o" b in Utils.capture_stdout_logs (fun () -> Train.forward_and_forget
+     Train.forward_and_force backend ctx d); let%op e = a *+ "b|i->h; b|h->o => i->o" b in
+     Utils.capture_stdout_logs (fun () -> Train.forward_and_force backend ctx e); let%op f = a *+
+     "a|i->h; b|h->o => i->o" b in Utils.capture_stdout_logs (fun () -> Train.forward_and_force
      backend ctx f); *)
   (* Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ a2; *)
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ c

bin/hello_world.ml

Lines changed: 5 additions & 5 deletions
@@ -19,7 +19,7 @@ let hello1 () =
   let hey = range_of_shape ~batch_dims:[ 7 ] ~input_dims:[ 9; 10; 11 ] ~output_dims:[ 13; 14 ] () in
   let%op hoo = ((1 + 1) * hey) - 10 in
   (* For convenience, Train.forward will set hoo.value as fully on host. *)
-  Train.forward_and_forget (module Backend) ctx hoo;
+  Train.forward_and_force (module Backend) ctx hoo;
   (* Disable line wrapping for viewing the output. In VSCode: `View: Toggle Word Wrap`. *)
   Tensor.print_tree ~with_grad:false ~depth:99 hoo;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default hoo
@@ -33,7 +33,7 @@ let hello2 () =
   let%op y = ("hey" * 'q' 2.0) + 'p' 1.0 in
   (* Punning for ["hey"] above introduced the [hey] identifier. *)
   Train.every_non_literal_on_host y;
-  Train.forward_and_forget (module Backend) ctx y;
+  Train.forward_and_force (module Backend) ctx y;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ hey;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ y

@@ -79,7 +79,7 @@ let hello4 () =
   let positions = TDSL.outer_sum "ijl;kl=>ijkl" (TDSL.outer_sum "il;jl=>ijl" ti tj) tk in
   Train.set_hosted ti.value;
   Train.set_hosted tk.value;
-  Train.forward_and_forget backend ctx positions;
+  Train.forward_and_force backend ctx positions;
   Stdio.print_endline "positions:";
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ positions;
   Stdio.print_endline "tk:";
@@ -103,7 +103,7 @@ let hello5 () =
   Rand.init 0;
   let hey = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 3 ] ~output_dims:[ 4 ] () in
   let%op ho = hey ++ "...|1->... => ...|..." in
-  Train.forward_and_forget backend ctx ho;
+  Train.forward_and_force backend ctx ho;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ ho

 let hello6 () =
@@ -121,7 +121,7 @@ let hello6 () =
   Rand.init 0;
   (* "Hey" is inferred to be a scalar. *)
   let%op y = 2 *. "hey" in
-  Train.forward_and_forget backend ctx y;
+  Train.forward_and_force backend ctx y;
   (* Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ hey; *)
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ y

bin/hello_world_op.ml

Lines changed: 28 additions & 17 deletions
@@ -29,7 +29,7 @@ let%track2_sexp _Pointwise_multiplication_dims_1 (() : unit) : unit =
   Rand.init 0;
   (* "Hey" is inferred to be a scalar. *)
   let%op ya = 2 *. "hey" 7.0 in
-  Train.forward_and_forget backend ctx ya;
+  Train.forward_and_force backend ctx ya;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ ya

 let%track2_sexp _Matrix_multiplication_dims_1x1 (() : unit) : unit =
@@ -48,11 +48,25 @@ let%track2_sexp _Matrix_multiplication_dims_1x1 (() : unit) : unit =
   Rand.init 0;
   (* Hey is inferred to be a matrix because of matrix multiplication [*]. *)
   let%op yb = ("hey" 7.0 * 'q' 2.0) + 'p' 1.0 in
-  Train.forward_and_forget backend ctx yb;
+  Train.forward_and_force backend ctx yb;
   (* Punning for ["hey"] above introduced the [hey] identifier. *)
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ hey;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ yb

+let%track2_sexp _Print_constant_tensor_too_early (() : unit) : unit =
+  Tensor.unsafe_reinitialize ();
+  let module Backend = (val Backends.fresh_backend ()) in
+  let print_tensor = Tensor.print ~with_code:false ~with_grad:false in
+
+  let%op a = [| 1.; 2.; 3.; 4. |] in
+  let%op b = [| 2.; 3.; 4.; 5. |] in
+  print_tensor ~here:[%here] `Default a;
+  print_tensor ~here:[%here] `Default b;
+  let%op c = a *. b in
+  let ctx = Train.init_params (module Backend) IDX.empty c in
+  Train.forward_and_force (module Backend) ctx c;
+  print_tensor ~here:[%here] `Default c
+
 let%track2_sexp _Print_constant_tensor (() : unit) : unit =
   Tensor.unsafe_reinitialize ();
   let module Backend = (val Backends.fresh_backend ()) in
@@ -68,11 +82,11 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
   let ctx = Backend.make_context stream in
   Rand.init 0;
   let%op hey = [ (1, 2, 3); (4, 5, 6) ] in
-  Train.forward_and_forget backend ctx hey;
+  Train.forward_and_force backend ctx hey;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Inline @@ hey;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ hey;
   let%op hoo = [| [ 1; 2; 3 ]; [ 4; 5; 6 ] |] in
-  Train.forward_and_forget backend ctx hoo;
+  Train.forward_and_force backend ctx hoo;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Inline @@ hoo;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ hoo;
   let%op hey2 =
@@ -83,7 +97,7 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
       ((19, 20, 21), (22, 23, 24));
     ]
   in
-  Train.forward_and_forget backend ctx hey2;
+  Train.forward_and_force backend ctx hey2;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Inline @@ hey2;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ hey2;
   let%op hoo2 =
@@ -94,7 +108,7 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
       [ [ 19; 20; 21 ]; [ 22; 23; 24 ] ];
     |]
   in
-  Train.forward_and_forget backend ctx hoo2;
+  Train.forward_and_force backend ctx hoo2;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Inline @@ hoo2;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ hoo2;
   let%op heyhoo =
@@ -105,7 +119,7 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
       [| [ 19; 20; 21 ]; [ 22; 23; 24 ] |];
     |]
   in
-  Train.forward_and_forget backend ctx heyhoo;
+  Train.forward_and_force backend ctx heyhoo;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Inline @@ heyhoo;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ heyhoo;
   let%op heyhoo2 =
@@ -116,7 +130,7 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
       [| [ [ 19; 49 ]; [ 20; 50 ]; [ 21; 51 ] ]; [ [ 22; 52 ]; [ 23; 53 ]; [ 24; 54 ] ] |];
     |]
   in
-  Train.forward_and_forget backend ctx heyhoo2;
+  Train.forward_and_force backend ctx heyhoo2;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Inline @@ heyhoo2;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ heyhoo2;
   let%op heyhoo3 =
@@ -131,7 +145,7 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
       |];
     |]
   in
-  Train.forward_and_forget backend ctx heyhoo3;
+  Train.forward_and_force backend ctx heyhoo3;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Inline @@ heyhoo3;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ heyhoo3;
   let%op heyhoo4 =
@@ -146,7 +160,7 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
       ];
     |]
   in
-  Train.forward_and_forget backend ctx heyhoo4;
+  Train.forward_and_force backend ctx heyhoo4;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Inline @@ heyhoo4;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ heyhoo4

@@ -166,7 +180,7 @@ let%track2_sexp _Matrix_multiplication_dims_2x3 (() : unit) : unit =
   Rand.init 0;
   (* Hey is inferred to be a matrix. *)
   let%op yc = ("hey" 7.0 * [ 2; 3 ]) + [ 4; 5; 6 ] in
-  Train.forward_and_forget backend ctx yc;
+  Train.forward_and_force backend ctx yc;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ hey;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ yc

@@ -188,7 +202,7 @@ let%track2_sexp _Big_matrix (() : unit) : unit =
   let hey = TDSL.param ~value:0.5 "hey" in
   let zero_to_twenty = TDSL.range 20 in
   let%op yd = (hey * zero_to_twenty) + zero_to_twenty in
-  Train.forward_and_forget backend ctx yd;
+  Train.forward_and_force backend ctx yd;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default hey;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default yd

@@ -208,7 +222,7 @@ let%track2_sexp _Very_big_tensor (() : unit) : unit =
   Rand.init 0;
   let hey = TDSL.range_of_shape ~batch_dims:[ 6 ] ~input_dims:[ 7; 8 ] ~output_dims:[ 9 ] () in
   let%op ye = (hey * (1 + 1)) - 10 in
-  Train.forward_and_forget backend ctx ye;
+  Train.forward_and_force backend ctx ye;
   Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default ye

 let _suspended (() : unit) : unit =
@@ -223,7 +237,4 @@ let _suspended (() : unit) : unit =
   _Big_matrix ();
   _Very_big_tensor ()

-let (() : unit) : unit =
-  _Matrix_multiplication_dims_2x3 ();
-  _Big_matrix ();
-  _Very_big_tensor ()
+let (() : unit) : unit = _Print_constant_tensor_too_early ()
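The new `_Print_constant_tensor_too_early` test (now the sole entry point) exercises the non-forcing default: it prints `a` and `b` before any routine has been compiled or run. A sketch of what this is expected to demonstrate, based on the placeholder introduced in lib/tensor.ml below (the exact output is an assumption):

    (* Printing before any forward pass: the lazy arrays are not yet forced,
       so the non-forcing print should render a "<not-in-yet>" placeholder. *)
    print_tensor ~here:[%here] `Default a;
    (* After forward_and_force, the values are materialized on the host: *)
    Train.forward_and_force (module Backend) ctx c;
    print_tensor ~here:[%here] `Default c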

bin/zero2hero_1of7.ml

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@ let _suspended () =
   let%op f5 = f 5 in
   let module Backend = (val Backends.fresh_backend ()) in
   Train.every_non_literal_on_host f5;
-  Train.forward_and_forget
+  Train.forward_and_force
     (module Backend)
     Backend.(make_context @@ new_stream @@ get_device ~ordinal:0)
     f5;
@@ -100,7 +100,7 @@ let _suspended () =
   let%op f x = (3 *. (x **. 2)) - (4 *. x) + 5 in
   let%op f5 = f 5 in
   Train.every_non_literal_on_host f5;
-  Train.forward_and_forget (module Backend) ctx f5;
+  Train.forward_and_force (module Backend) ctx f5;
   Tensor.print_tree ~with_grad:false ~depth:9 f5;
   let size = 100 in
   let xs = Array.init size ~f:Float.(fun i -> (of_int i / 10.) - 5.) in

lib/operation.ml

Lines changed: 2 additions & 1 deletion
@@ -375,6 +375,7 @@ let embed_symbol ?(label = []) static_sym : Tensor.t =
     (Shape.make ~batch_dims:[] ~input_dims:[] ~output_dims:[ 1 ] ())
     []

+(*
 let random_seed =
   let seed = Option.value ~default:42 @@ Utils.settings.fixed_state_for_init in
   let res =
@@ -384,7 +385,7 @@ let random_seed =
   in
   Tn.update_memory_mode res.value Tn.Effectively_constant 24;
   Tn.update_prec res.value Ir.Ops.uint4x32;
-  ref res
+  ref res *)

 module DO = struct
   let ( * ) = matmul ~grad_spec:If_needed

lib/tensor.ml

Lines changed: 7 additions & 5 deletions
@@ -610,11 +610,13 @@ let to_dag ?(single_node = false) ?(embedded_only = false) ?entries_per_axis ~sp
           grad_txt diff ^ if (not should_elide) && not embedded then " non-emb" else ""
         in
         let node =
+          if Lazy.is_val diff.grad.array then
           match Lazy.force diff.grad.array with
           | Some g_array ->
              Tn.do_read diff.grad;
              `Box (Nd.render_array ~brief:true ~prefix ?entries_per_axis ~labels ~indices g_array)
          | None -> `Text (prefix ^ " " ^ where_located diff.grad)
+          else `Text (prefix ^ " <not-in-yet> " ^ where_located diff.grad)
        in
        `Subtree_with_ID (id, `Tree (add_shape [ node ], children))
    | _, true, true, Some diff ->
@@ -666,7 +668,7 @@ let log_debug_info ~from_log_level t =
             Tn.log_debug_info ~from_log_level diff.grad]);
       List.iter ~f:log_child t.children]]

-let to_doc ?(spy = false) ~with_grad ~with_code ?(with_low_level = false)
+let to_doc ?(force_read = false) ~with_grad ~with_code ?(with_low_level = false)
     (style : array_print_style) t =
   let sh = t.shape in
   let label = Tn.label t.value in
@@ -724,7 +726,7 @@ let to_doc ?(spy = false) ~with_grad ~with_code ?(with_low_level = false)
   let open PPrint in
   (* Create document for tensor value *)
   let value_doc =
-    if spy && not (Lazy.is_val t.value.array) then
+    if not force_read && not (Lazy.is_val t.value.array) then
       string prefix_str ^^ string " <not-in-yet>" ^^ space
     else
       match (style, Lazy.force t.value.array) with
@@ -743,7 +745,7 @@ let to_doc ?(spy = false) ~with_grad ~with_code ?(with_low_level = false)
     if with_grad then
       match t.diff with
       | Some diff -> (
-          if spy && not (Lazy.is_val diff.grad.array) then
+          if not force_read && not (Lazy.is_val diff.grad.array) then
            string (grad_txt diff) ^^ string " <not-in-yet>" ^^ space
          else
            match Lazy.force diff.grad.array with
@@ -816,12 +818,12 @@ let to_doc ?(spy = false) ~with_grad ~with_code ?(with_low_level = false)
   (* Combine all documents and print *)
   group (value_doc ^^ break 1 ^^ grad_doc ^^ break 1 ^^ code_doc ^^ break 1 ^^ low_level_doc)

-let print ?here ?(spy = false) ~with_grad ~with_code ?(with_low_level = false)
+let print ?here ?(force_read = false) ~with_grad ~with_code ?(with_low_level = false)
     (style : array_print_style) t =
   Option.iter here ~f:(fun here ->
       Stdio.printf "HERE: %s\n%!" (Source_code_position.to_string here));
   PPrint.ToChannel.pretty 0.7 100 Stdio.stdout
-    (to_doc ~spy ~with_grad ~with_code ~with_low_level style t)
+    (to_doc ~force_read ~with_grad ~with_code ~with_low_level style t)

 let print_forward_roots ~with_grad ~with_code (style : array_print_style) =
   List.iter (Map.to_alist ~key_order:`Increasing session_state.forward_roots) ~f:(fun (id, root) ->
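Note the polarity flip: the old `?spy` defaulted to `false`, so printing forced un-computed arrays unless callers opted out with `~spy:true`; the new `?force_read` also defaults to `false`, so printing is now non-forcing unless callers opt in. A hedged call-site sketch against the new signature, for some tensor `t`:

    (* Default: renders only what is already forced, else "<not-in-yet>": *)
    Tensor.print ~here:[%here] ~with_grad:false ~with_code:false `Default t;
    (* Opt in to forcing when the value is genuinely needed: *)
    Tensor.print ~here:[%here] ~force_read:true ~with_grad:false ~with_code:false `Default t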

lib/tensor.mli

Lines changed: 2 additions & 2 deletions
@@ -345,7 +345,7 @@ val to_printbox :
   PrintBox.t

 val to_doc :
-  ?spy:bool ->
+  ?force_read:bool ->
   with_grad:bool ->
   with_code:bool ->
   ?with_low_level:bool ->
@@ -355,7 +355,7 @@ val to_doc :

 val print :
   ?here:Ppx_here_lib.position ->
-  ?spy:bool ->
+  ?force_read:bool ->
   with_grad:bool ->
   with_code:bool ->
   ?with_low_level:bool ->

lib/train.ml

Lines changed: 11 additions & 2 deletions
@@ -550,6 +550,15 @@ let%track3_sexp forward_and_ctx ?(hosted = true) ?(skip_init = false)
   Task.run routine.schedule;
   routine.context

-let forward_and_forget ?hosted ?skip_init ?disable_rootness_check backend ctx ?bindings t =
+(** [forward_and_force] is a wrapper around {!forward_and_ctx} that additionally forces the
+    tensor's value and ensures it is transferred back to host as needed, see the setting
+    {!Utils.settings.automatic_host_transfers}. The resulting context is ignored.
+
+    Note: [Tensor.print ~force_read:true] also has this effect, so: using [forward_and_force] you
+    don't need to pass [~force_read:true], and if you need the context and also to print the result,
+    you can combine {!forward_and_ctx} and [Tensor.print ~force_read:true]. *)
+let forward_and_force ?hosted ?skip_init ?disable_rootness_check backend ctx ?bindings t =
   (* FIXME: to properly forget we need to free the incrementally-allocated memory! *)
-  ignore @@ forward_and_ctx ?hosted ?skip_init ?disable_rootness_check backend ctx ?bindings t
+  ignore @@ forward_and_ctx ?hosted ?skip_init ?disable_rootness_check backend ctx ?bindings t;
+  ignore (Lazy.force t.value.array);
+  Tn.do_read t.value
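As the new doc comment spells out, callers that still need the resulting context can keep `forward_and_ctx` and move the forcing to print time; a sketch under the same assumptions as the `bin/` examples:

    (* Keep the context from the forward pass ... *)
    let ctx = Train.forward_and_ctx (module Backend) ctx t in
    (* ... and force/read the value only when printing it: *)
    Tensor.print ~here:[%here] ~force_read:true ~with_grad:false ~with_code:false `Default t;
    (* ctx remains available for further routines. *)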
