Skip to content

Commit 864cc24

Browse files
committed
Formatting update
1 parent 233a7b2 commit 864cc24

File tree

6 files changed

+16
-19
lines changed

6 files changed

+16
-19
lines changed

arrayjit/lib/backends.mli

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@ val reinitialize : (module Backend_types.Backend) -> Backend_types.config -> uni
88
(** Initializes the backend, and if it was already initialized, performs garbage collection. *)
99

1010
val fresh_backend :
11-
?backend_name:string ->
12-
?config:Backend_types.config ->
13-
unit ->
14-
(module Backend_types.Backend)
11+
?backend_name:string -> ?config:Backend_types.config -> unit -> (module Backend_types.Backend)
1512
(** Reinitializes and returns a backend corresponding to [backend_name], or if omitted, selected via
1613
the global [backend] setting. See {!reinitialize}. *)

arrayjit/lib/dune

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,14 @@
2424
(-> cuda_backend.missing.ml))
2525
ppx_minidebug.runtime)
2626
(preprocess
27-
(pps ppx_compare ppx_hash ppx_here ppx_sexp_conv ppx_string ppx_variants_conv ppx_minidebug))
27+
(pps
28+
ppx_compare
29+
ppx_hash
30+
ppx_here
31+
ppx_sexp_conv
32+
ppx_string
33+
ppx_variants_conv
34+
ppx_minidebug))
2835
(modules
2936
utils
3037
rand

arrayjit/lib/task.ml

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,7 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
88
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
99

1010
type t =
11-
| Task : {
12-
context_lifetime : ('a[@sexp.opaque]);
13-
description : string;
14-
work : unit -> unit;
15-
}
16-
-> t
11+
| Task : { context_lifetime : ('a[@sexp.opaque]); description : string; work : unit -> unit } -> t
1712
[@@deriving sexp_of]
1813

1914
let describe (Task task) = task.description
@@ -32,7 +27,8 @@ let prepend ~work (Task task) =
3227
task.work ());
3328
}
3429

35-
let%track3_l_sexp enschedule ~schedule_task ~get_stream_name stream (Task { description; _ } as task) =
30+
let%track3_l_sexp enschedule ~schedule_task ~get_stream_name stream
31+
(Task { description; _ } as task) =
3632
[%log_result "enschedule", description, "on", get_stream_name stream];
3733
let work () = schedule_task stream task in
3834
Task

bin/moons_benchmark.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ let _mem_benchmarks =
216216
List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
217217
List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed ->
218218
List.concat_map [ (* "gccjit" ; *) "cc"; "cuda" ] ~f:(fun backend_name ->
219-
List.concat_map [ (* CDSL.double; *) CDSL.single ; CDSL.half ]
219+
List.concat_map [ (* CDSL.double; *) CDSL.single; CDSL.half ]
220220
~f:(fun value_prec ->
221221
[
222222
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams

bin/zero2hero_1of7.ml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,7 @@ let _suspended () =
3737
let%op f5 = f 5 in
3838
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
3939
Train.every_non_literal_on_host f5;
40-
Train.forward_and_forget
41-
(module Backend)
42-
Backend.(init @@ new_stream @@ get_device ~ordinal:0)
43-
f5;
40+
Train.forward_and_forget (module Backend) Backend.(init @@ new_stream @@ get_device ~ordinal:0) f5;
4441
Stdio.printf "\n%!";
4542
Tensor.print_tree ~with_grad:false ~depth:9 f5;
4643
Stdio.printf "\n%!"

lib/tensor.mli

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ val consume_backprop_code : t -> asgns * comp
194194

195195
val iter_embedded : f:(tn -> unit) -> t -> unit
196196
(** [iter_embedded t] iterates over all descendant nodes that are embedded, i.e. are not members of
197-
[t.forward.embedded_nodes] or '[t.diff.backprop.embedded_nodes]' (if any). Note: [iter_embedded] should only be
198-
called after shape inference finishes. *)
197+
[t.forward.embedded_nodes] or '[t.diff.backprop.embedded_nodes]' (if any). Note: [iter_embedded]
198+
should only be called after shape inference finishes. *)
199199

200200
val unsafe_reinitialize : unit -> unit
201201
(** Bring global state to its initialization values. This invalidates any previously defined tensors

0 commit comments

Comments
 (0)