Skip to content

Commit 8ed1bdc

Browse files
committed
Move Train.fresh_backend -> Backends.fresh_backend
1 parent 1fcf256 commit 8ed1bdc

15 files changed

+67
-68
lines changed

arrayjit/lib/backends.ml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,3 +1341,24 @@ let reinitialize (module Backend : Backend) config =
13411341
Core.Gc.full_major ();
13421342
Backend.unsafe_cleanup ();
13431343
Backend.initialize config)
1344+
1345+
1346+
(** Reinitializes a backend selected via a global [backend] flag. *)
1347+
let fresh_backend ?backend_name ?(config = Physical_devices_only) () =
1348+
let backend =
1349+
match
1350+
Option.value_or_thunk backend_name ~default:(fun () ->
1351+
Utils.get_global_arg ~arg_name:"backend" ~default:"pipes_cc")
1352+
|> String.lowercase
1353+
with
1354+
| "cc" -> (module Cc_backend : Backend)
1355+
| "gccjit" -> (module Gccjit_backend : Backend)
1356+
| "sync_cc" -> (module Sync_cc_backend : Backend)
1357+
| "sync_gccjit" -> (module Sync_gccjit_backend : Backend)
1358+
| "pipes_cc" -> (module Pipes_cc_backend : Backend)
1359+
| "pipes_gccjit" -> (module Pipes_gccjit_backend : Backend)
1360+
| "cuda" -> (module Cuda_backend : Backend)
1361+
| backend -> invalid_arg [%string "Backends.fresh_backend: unknown backend %{backend}"]
1362+
in
1363+
reinitialize backend config;
1364+
backend

bin/einsum_trivia.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module Rand = Arrayjit.Rand.Lib
88

99
let _suspended () =
1010
Rand.init 0;
11-
let module Backend = (val Train.fresh_backend ()) in
11+
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
1212
let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
1313
let ctx = Backend.init device in
1414
Utils.settings.output_debug_files_in_build_directory <- true;
@@ -25,7 +25,7 @@ let _suspended () =
2525
Utils.settings.log_level <- 2;
2626
Utils.settings.output_debug_files_in_build_directory <- true;
2727
Utils.settings.debug_log_from_routines <- true;
28-
let module Backend = (val Train.fresh_backend ~backend_name:"cuda" ()) in
28+
let module Backend = (val Arrayjit.Backends.fresh_backend ~backend_name:"cuda" ()) in
2929
let backend = (module Backend : Train.Backend_type with type context = Backend.context) in
3030
let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
3131
let ctx = Backend.init device in
@@ -46,7 +46,7 @@ let () =
4646
Utils.settings.log_level <- 2;
4747
Utils.settings.output_debug_files_in_build_directory <- true;
4848
Utils.settings.debug_log_from_routines <- true;
49-
let module Backend = (val Train.fresh_backend ()) in
49+
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
5050
let backend = (module Backend : Train.Backend_type with type context = Backend.context) in
5151
let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
5252
let ctx = Backend.init device in

bin/hello_world.ml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ module Rand = Arrayjit.Rand.Lib
1010

1111
let hello1 () =
1212
Rand.init 0;
13-
let module Backend = (val Train.fresh_backend ()) in
13+
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
1414
Utils.settings.log_level <- 2;
1515
(* Utils.settings.output_debug_files_in_build_directory <- true; *)
1616
let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
@@ -27,7 +27,7 @@ let hello1 () =
2727

2828
let hello2 () =
2929
Rand.init 0;
30-
let module Backend = (val Train.fresh_backend ()) in
30+
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
3131
Utils.settings.log_level <- 2;
3232
(* Utils.settings.output_debug_files_in_build_directory <- true; *)
3333
(* Utils.settings.debug_log_from_routines <- true; *)
@@ -43,7 +43,7 @@ let hello2 () =
4343

4444
let hello3 () =
4545
Rand.init 0;
46-
let module Backend = (val Train.fresh_backend ()) in
46+
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
4747
Utils.settings.output_debug_files_in_build_directory <- true;
4848
(* Utils.settings.debug_log_from_routines <- true; *)
4949
let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
@@ -69,7 +69,7 @@ let hello3 () =
6969
Stdlib.Format.force_newline ()
7070

7171
let hello4 () =
72-
let module Backend = (val Train.fresh_backend ()) in
72+
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
7373
let backend = (module Backend : Train.Backend_type with type context = Backend.context) in
7474
let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
7575
let ctx = Backend.init device in
@@ -98,7 +98,7 @@ let hello5 () =
9898
Utils.settings.log_level <- 2;
9999
Utils.settings.output_debug_files_in_build_directory <- true;
100100
Utils.settings.debug_log_from_routines <- true;
101-
let module Backend = (val Train.fresh_backend ()) in
101+
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
102102
let backend = (module Backend : Train.Backend_type with type context = Backend.context) in
103103
let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
104104
let ctx = Backend.init device in
@@ -113,7 +113,7 @@ let hello6 () =
113113
Utils.settings.log_level <- 2;
114114
Utils.settings.output_debug_files_in_build_directory <- true;
115115
Utils.settings.debug_log_from_routines <- true;
116-
let module Backend = (val Train.fresh_backend ()) in
116+
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
117117
let backend = (module Backend : Train.Backend_type with type context = Backend.context) in
118118
let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
119119
let ctx = Backend.init device in

bin/micrograd_basic.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ module Rand = Arrayjit.Rand.Lib
99
module Debug_runtime = Utils.Debug_runtime
1010

1111
let%diagn_sexp () =
12-
let module Backend = (val Train.fresh_backend ~backend_name:"cc" ()) in
12+
let module Backend = (val Arrayjit.Backends.fresh_backend ~backend_name:"cc" ()) in
1313
let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
1414
let ctx = Backend.init device in
1515
Utils.settings.output_debug_files_in_build_directory <- true;
@@ -34,7 +34,7 @@ let%diagn_sexp () =
3434
Tensor.print ~with_code:false ~with_grad:true `Default @@ b
3535

3636
let%diagn_sexp _suspended () : unit =
37-
let module Backend = (val Train.fresh_backend ()) in
37+
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
3838
let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
3939
let ctx = Backend.init device in
4040
(* Utils.settings.output_debug_files_in_build_directory <- true; *)

bin/micrograd_demo.ml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
7878
Train.set_hosted learning_rate.value;
7979
let sgd = Train.sgd_update ~learning_rate ~weight_decay update in
8080

81-
let module Backend = (val Train.fresh_backend ()) in
81+
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
8282
let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
8383
let ctx = Backend.init device in
8484
let routine = Backend.(link ctx @@ compile bindings (Seq (update.fwd_bprop, sgd))) in
@@ -177,7 +177,6 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
177177
Backend.unsafe_cleanup ()
178178

179179
let () = experiment 4 ~no_batch_shape_inference:true ~use_builtin_weight_decay:true ()
180-
181180
let () = experiment 4 ~no_batch_shape_inference:false ~use_builtin_weight_decay:false ()
182181

183182
let _suspended () =

bin/moons_benchmark.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
6969
let start_time = ref None in
7070
let weight_decay = 0.0002 in
7171
Arrayjit.Backends.sync_suggested_num_virtual_devices := num_devices;
72-
let backend = Train.fresh_backend ~backend_name () in
72+
let backend = Arrayjit.Backends.fresh_backend ~backend_name () in
7373
let per_batch_callback ~at_batch:_ ~at_step:_ ~learning_rate:_ ~batch_loss:_ ~epoch_loss:_ =
7474
if Option.is_none !start_time then start_time := Some (Time_now.nanoseconds_since_unix_epoch ())
7575
in

bin/moons_demo.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ let demo () =
5252

5353
let epoch_loss = ref 0. in
5454

55-
let module Backend = (val Train.fresh_backend ~backend_name:"cuda" ()) in
55+
let module Backend = (val Arrayjit.Backends.fresh_backend ~backend_name:"cuda" ()) in
5656
let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
5757
let ctx = Backend.init device in
5858
let routine = Backend.(link ctx @@ compile bindings (Seq (update.fwd_bprop, sgd))) in

bin/moons_demo_parallel.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ let experiment ~seed ~backend_name ~config () =
4444
computation. *)
4545
let weight_decay = 0.0002 in
4646
(* So that we can inspect them. *)
47-
let backend = Train.fresh_backend ~backend_name ~config () in
47+
let backend = Arrayjit.Backends.fresh_backend ~backend_name ~config () in
4848
let per_batch_callback ~at_batch ~at_step ~learning_rate ~batch_loss ~epoch_loss =
4949
if (at_batch + 1) % 20 = 0 then
5050
Stdio.printf "Batch=%d, step=%d, lr=%f, batch loss=%f, epoch loss=%f\n%!" at_batch at_step

bin/zero2hero_1of7.ml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ let _get_local_debug_runtime = Arrayjit.Utils._get_local_debug_runtime
1616

1717
let _suspended () =
1818
Rand.init 0;
19-
let module Backend = (val Train.fresh_backend ()) in
19+
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
2020
let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
2121
let ctx = Backend.init device in
2222
let%op v = ("w" [ (-3, 1) ] * "x" [ 2; 0 ]) + "b" [ 6.7 ] in
@@ -35,7 +35,7 @@ let _suspended () =
3535
CDSL.virtualize_settings.enable_device_only <- false;
3636
let%op f x = (3 *. (x **. 2)) - (4 *. x) + 5 in
3737
let%op f5 = f 5 in
38-
let module Backend = (val Train.fresh_backend ()) in
38+
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
3939
Train.every_non_literal_on_host f5;
4040
Train.forward_and_forget
4141
(module Backend)
@@ -65,7 +65,7 @@ let () =
6565
(* let x = Operation.slice ~label:[ "x" ] ~grad_spec:Require_grad step_sym x_flat in *)
6666
Train.set_hosted (Option.value_exn ~here:[%here] x.diff).grad;
6767
let%op fx = f x in
68-
let module Backend = (val Train.fresh_backend ()) in
68+
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
6969
let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
7070
let ctx = Backend.init device in
7171
let update = Train.grad_update fx in
@@ -98,7 +98,7 @@ let _suspended () =
9898
Utils.settings.output_debug_files_in_build_directory <- true;
9999
(* Utils.settings.debug_log_from_routines <- true; *)
100100
Rand.init 0;
101-
let module Backend = (val Train.fresh_backend ()) in
101+
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
102102
let backend = (module Backend : Train.Backend_type with type context = Backend.context) in
103103
let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
104104
let ctx = Backend.init device in
@@ -159,7 +159,7 @@ let _suspended () =
159159
let%op d = e + "c" [ 10 ] in
160160
let%op l = d *. "f" [ -2 ] in
161161
Train.every_non_literal_on_host l;
162-
let open (val Train.fresh_backend ()) in
162+
let open (val Arrayjit.Backends.fresh_backend ()) in
163163
let device = new_virtual_device @@ get_device ~ordinal:0 in
164164
let update = Train.grad_update l in
165165
let routine = link (init device) @@ compile IDX.empty @@ update.fwd_bprop in

lib/train.ml

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -43,27 +43,6 @@ end
4343

4444
let run jitted = Tn.run jitted.BT.schedule
4545

46-
(** Reinitializes a backend selected via a global [backend] flag. *)
47-
let fresh_backend ?backend_name ?(config = BT.Physical_devices_only) () =
48-
let module B = Arrayjit.Backends in
49-
let backend =
50-
match
51-
Option.value_or_thunk backend_name ~default:(fun () ->
52-
Arrayjit.Utils.get_global_arg ~arg_name:"backend" ~default:"pipes_cc")
53-
|> String.lowercase
54-
with
55-
| "cc" -> (module B.Cc_backend : B.Backend)
56-
| "gccjit" -> (module B.Gccjit_backend : B.Backend)
57-
| "sync_cc" -> (module B.Sync_cc_backend : B.Backend)
58-
| "sync_gccjit" -> (module B.Sync_gccjit_backend : B.Backend)
59-
| "pipes_cc" -> (module B.Pipes_cc_backend : B.Backend)
60-
| "pipes_gccjit" -> (module B.Pipes_gccjit_backend : B.Backend)
61-
| "cuda" -> (module B.Cuda_backend : B.Backend)
62-
| backend -> invalid_arg [%string "Train.fresh_backend: unknown backend %{backend}"]
63-
in
64-
B.reinitialize backend config;
65-
backend
66-
6746
let is_param t =
6847
match t with
6948
| { Tensor.children = []; diff = Some _; _ } -> not @@ Tn.known_not_param t.value

0 commit comments

Comments
 (0)