
Commit 84a2bce

Consistency for print_tree, and refactor of where forcing happens in "quick exec" helpers and printing
Refactor tensor operations to use `Train.forward_once` and `Train.printf` for improved clarity and consistency. Remove the now-unnecessary context initialization in multiple files, streamlining the codebase.
1 parent e180610 commit 84a2bce

22 files changed: +574 -608 lines changed
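The same substitution repeats at nearly every call site below. Here is a minimal before/after sketch of the pattern, inferred from those call sites; the _old_style/_new_style wrappers are illustrative scaffolding (not part of the commit), and the claim that the helper sets up its own context is a reading of the diffs, not a statement of its implementation:

(* Old style: build a stream and context by hand, force the tensor,
   then print through Tensor.print. *)
let _old_style (module Backend : Backend) t =
  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
  let ctx = Backend.make_context stream in
  Train.forward_and_force (module Backend) ctx t;
  Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ t

(* New style: the "quick exec" helper Train.forward_once apparently takes
   over context setup and forcing (its result is ignored at these call
   sites), and printing goes through Train.printf. *)
let _new_style (module Backend : Backend) t =
  ignore (Train.forward_once (module Backend) t);
  Train.printf ~here:[%here] ~with_code:false ~with_grad:false t

The tree printer follows the same move: Tensor.print_tree becomes Train.printf_tree. Relatedly, in bin/compilation_speed.ml below, Train.init_params no longer takes a ~ctx argument, and memory queries use a device handle directly (Backend.get_used_memory device) instead of stream.device.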

bin/compilation_speed.ml

Lines changed: 5 additions & 6 deletions

@@ -22,13 +22,12 @@ let benchmark_overhead backend () =
   Train.set_hosted f.value;

   (* Train.every_non_literal_on_host f; *)
-  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
-  let ctx = Backend.make_context stream in
-  let init_mem = Backend.(get_used_memory stream.device) in
+  let device = Backend.(get_device ~ordinal:0) in
+  let init_mem = Backend.get_used_memory device in
   let update_f = Train.grad_update f in
-  let ctx = Train.init_params (module Backend) ~ctx IDX.empty f in
+  let ctx = Train.init_params (module Backend) IDX.empty f in
   let f_routine = Train.to_routine (module Backend) ctx IDX.empty update_f in
-  Tensor.print_tree ~with_grad:true ~with_backend_info:true ~depth:9 f;
+  Train.printf_tree ~with_grad:true ~depth:9 f;

   let xs = Array.init n_data ~f:Float.(fun i -> of_int i - (of_int n_data /. 2.)) in
   let open Operation.At in
@@ -49,7 +48,7 @@ let benchmark_overhead backend () =
   in
   let final_time = Time_now.nanoseconds_since_unix_epoch () in
   let time_in_sec = Int63.(to_float @@ (final_time - init_time)) /. 1000_000_000. in
-  let mem_in_bytes = Backend.(get_used_memory stream.device) - init_mem in
+  let mem_in_bytes = Backend.get_used_memory device - init_mem in
   let result =
     PrintBox_utils.Benchmark
       {

bin/einsum_trivia.ml

Lines changed: 19 additions & 21 deletions

@@ -10,13 +10,13 @@ module type Backend = Ir.Backend_intf.Backend
 let _suspended () =
   Rand.init 0;
   let module Backend = (val Backends.fresh_backend ()) in
-  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
-  let ctx = Backend.make_context stream in
+
+
   let a = TDSL.range_of_shape ~label:[ "a" ] ~input_dims:[ 2 ] ~output_dims:[ 2 ] () in
   let b = TDSL.range_of_shape ~label:[ "b" ] ~input_dims:[ 2; 3; 4 ] ~output_dims:[ 2 ] () in
   let%op c = a *+ "i->1; ij...->0 => ...->ji" b in
-  Train.forward_and_force (module Backend) ctx c;
-  Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ c;
+  ignore (Train.forward_once (module Backend) c);
+  Train.printf ~here:[%here] ~with_code:false ~with_grad:false c;
   Stdio.printf "\n%!"

 let _suspended () =
@@ -29,19 +29,19 @@ let _suspended () =
        and type event = Backend.event
        and type optimize_ctx = Backend.optimize_ctx)
   in
-  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
-  let ctx = Backend.make_context stream in
+
+
   Rand.init 0;
   let hey = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 3 ] ~output_dims:[ 4 ] () in
-  let%op ho = hey ++ "b|i->o => o|b->i" in
-  let ctx = Utils.capture_stdout_logs (fun () -> Train.forward_and_ctx backend ctx ho) in
+  let%op _ho = hey ++ "b|i->o => o|b->i" in
+
   let hey2 =
     TDSL.range_of_shape ~batch_dims:[ 2; 3 ] ~input_dims:[ 4; 5 ] ~output_dims:[ 6; 7 ] ()
   in
   let%op ho2 = hey2 ++ "ab|cd->ef => cf|ae->db" in
   Utils.capture_stdout_logs @@ fun () ->
-  Train.forward_and_force backend ctx ho2;
-  Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ ho2
+  ignore (Train.forward_once backend ho2);
+  Train.printf ~here:[%here] ~with_code:false ~with_grad:false ho2

 let () =
   let module Backend = (val Backends.fresh_backend ()) in
@@ -53,19 +53,17 @@ let () =
        and type event = Backend.event
        and type optimize_ctx = Backend.optimize_ctx)
   in
-  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
-  let ctx = Backend.make_context stream in
+
+
   Rand.init 0;
   let a = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 3 ] ~output_dims:[ 4 ] () in
   let b = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 4 ] ~output_dims:[ 5 ] () in
-  let%op a2 = a *+ "b|i->o; b|i->o => b|i->o" a in
-  let ctx = Utils.capture_stdout_logs (fun () -> Train.forward_and_ctx backend ctx a2) in
+  let%op _ = a *+ "b|i->o; b|i->o => b|i->o" a in
   let%op c = b *+ "b|h->o; b|i->h => b|i->o" a in
-  Utils.capture_stdout_logs (fun () -> Train.forward_and_force backend ctx c);
+  Utils.capture_stdout_logs (fun () -> ignore (Train.forward_once backend c));
   (* let%op d = a *+ "a|i->h; b|h->o => ab|i->o" b in Utils.capture_stdout_logs (fun () ->
-     Train.forward_and_force backend ctx d); let%op e = a *+ "b|i->h; b|h->o => i->o" b in
-     Utils.capture_stdout_logs (fun () -> Train.forward_and_force backend ctx e); let%op f = a *+
-     "a|i->h; b|h->o => i->o" b in Utils.capture_stdout_logs (fun () -> Train.forward_and_force
-     backend ctx f); *)
-  (* Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ a2; *)
-  Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ c
+     ignore (Train.forward_once backend d)); let%op e = a *+ "b|i->h; b|h->o => i->o" b in
+     Utils.capture_stdout_logs (fun () -> ignore (Train.forward_once backend e)); let%op f = a *+
+     "a|i->h; b|h->o => i->o" b in Utils.capture_stdout_logs (fun () -> ignore (Train.forward_once backend f)); *)
+  (* Train.printf ~here:[%here] ~with_code:false ~with_grad:false a2; *)
+  Train.printf ~here:[%here] ~with_code:false ~with_grad:false c

bin/hello_world.ml

Lines changed: 31 additions & 29 deletions

@@ -12,47 +12,49 @@ module type Backend = Ir.Backend_intf.Backend
 let hello1 () =
   Rand.init 0;
   let module Backend = (val Backends.fresh_backend ()) in
-  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
-  let ctx = Backend.make_context stream in
+
+
   let open Operation.TDSL in
   (* Hey is inferred to be a matrix. *)
   let hey = range_of_shape ~batch_dims:[ 7 ] ~input_dims:[ 9; 10; 11 ] ~output_dims:[ 13; 14 ] () in
   let%op hoo = ((1 + 1) * hey) - 10 in
   (* For convenience, Train.forward will set hoo.value as fully on host. *)
-  Train.forward_and_force (module Backend) ctx hoo;
+  ignore (Train.forward_once (module Backend) hoo);
   (* Disable line wrapping for viewing the output. In VSCode: `View: Toggle Word Wrap`. *)
-  Tensor.print_tree ~with_grad:false ~depth:99 hoo;
-  Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default hoo
+  Train.printf_tree ~with_grad:false ~depth:99 hoo;
+  Train.printf ~here:[%here] ~with_code:false ~with_grad:false hoo

 let hello2 () =
   Rand.init 0;
   let module Backend = (val Backends.fresh_backend ()) in
-  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
-  let ctx = Backend.make_context stream in
+
+
   (* Hey is inferred to be a matrix. *)
   let%op y = ("hey" * 'q' 2.0) + 'p' 1.0 in
   (* Punning for ["hey"] above introduced the [hey] identifier. *)
   Train.every_non_literal_on_host y;
-  Train.forward_and_force (module Backend) ctx y;
-  Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ hey;
-  Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ y
+  ignore (Train.forward_once (module Backend) y);
+  Train.printf ~here:[%here] ~with_code:false ~with_grad:false hey;
+  Train.printf ~here:[%here] ~with_code:false ~with_grad:false y

 let hello3 () =
   Rand.init 0;
   let module Backend = (val Backends.fresh_backend ()) in
-  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
-  let ctx = Backend.make_context stream in
+
+
   (* Hey is inferred to be a matrix. *)
   let hey = TDSL.param "hey" in
   let zero_to_twenty = TDSL.range 20 in
   let y = TDSL.O.(( + ) ~label:[ "y" ] (hey * zero_to_twenty) zero_to_twenty) in
   Train.set_hosted hey.value;
+  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
+  let ctx = Backend.make_context stream in
   let routine = Train.to_routine (module Backend) ctx IDX.empty @@ Train.forward y in
   Stdio.printf "\n%!";
   Train.run routine;
-  Tensor.print ~here:[%here] ~with_code:true ~with_grad:false `Default y;
+  Train.printf ~here:[%here] ~with_code:true ~with_grad:false y;
   Stdio.printf "\n%!";
-  Tensor.print_tree ~with_grad:false ~depth:9 y;
+  Train.printf_tree ~with_grad:false ~depth:9 y;
   Stdio.printf "\n%!"

 let hello4 () =
@@ -65,8 +67,8 @@ let hello4 () =
        and type event = Backend.event
        and type optimize_ctx = Backend.optimize_ctx)
   in
-  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
-  let ctx = Backend.make_context stream in
+
+
   Rand.init 0;
   let ri = TDSL.range 3 in
   let%op ti = ri ++ "i=>i0" in
@@ -79,13 +81,13 @@ let hello4 () =
   let positions = TDSL.outer_sum "ijl;kl=>ijkl" (TDSL.outer_sum "il;jl=>ijl" ti tj) tk in
   Train.set_hosted ti.value;
   Train.set_hosted tk.value;
-  Train.forward_and_force backend ctx positions;
+  ignore (Train.forward_once backend positions);
   Stdio.print_endline "positions:";
-  Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ positions;
+  Train.printf ~here:[%here] ~with_code:false ~with_grad:false positions;
   Stdio.print_endline "tk:";
-  Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ tk;
+  Train.printf ~here:[%here] ~with_code:false ~with_grad:false tk;
   Stdio.print_endline "ti:";
-  Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ ti;
+  Train.printf ~here:[%here] ~with_code:false ~with_grad:false ti;
   Stdio.printf "\n%!"

 let hello5 () =
@@ -98,13 +100,13 @@ let hello5 () =
        and type event = Backend.event
        and type optimize_ctx = Backend.optimize_ctx)
   in
-  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
-  let ctx = Backend.make_context stream in
+
+
   Rand.init 0;
   let hey = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 3 ] ~output_dims:[ 4 ] () in
   let%op ho = hey ++ "...|1->... => ...|..." in
-  Train.forward_and_force backend ctx ho;
-  Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ ho
+  ignore (Train.forward_once backend ho);
+  Train.printf ~here:[%here] ~with_code:false ~with_grad:false ho

 let hello6 () =
   let module Backend = (val Backends.fresh_backend ()) in
@@ -116,14 +118,14 @@ let hello6 () =
        and type event = Backend.event
        and type optimize_ctx = Backend.optimize_ctx)
   in
-  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
-  let ctx = Backend.make_context stream in
+
+
   Rand.init 0;
   (* "Hey" is inferred to be a scalar. *)
   let%op y = 2 *. "hey" in
-  Train.forward_and_force backend ctx y;
-  (* Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ hey; *)
-  Tensor.print ~here:[%here] ~with_code:false ~with_grad:false `Default @@ y
+  ignore (Train.forward_once backend y);
+  (* Train.printf ~here:[%here] ~with_code:false ~with_grad:false hey; *)
+  Train.printf ~here:[%here] ~with_code:false ~with_grad:false y

 let () =
   ignore (hello1, hello2, hello3, hello4, hello5, hello6);
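Of the examples above, hello3 is the one that keeps the explicit path: a compiled routine still needs a caller-made context, so the stream and context are still created by hand, just later, after the tensors are declared. A condensed sketch of that retained style, excerpted from the hello3 hunk (surrounding definitions elided):

(* Explicit-context style retained for compiled routines: create the
   stream/context right before compiling, then run and print. *)
let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
let ctx = Backend.make_context stream in
let routine = Train.to_routine (module Backend) ctx IDX.empty @@ Train.forward y in
Train.run routine;
Train.printf ~here:[%here] ~with_code:true ~with_grad:false y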
