Skip to content

Commit 2285eaf

Browse files
committed
In progress: Allow inlining in more cases
WARNING: performance regression. This commit introduces a new configuration option, `inline_simple_computations`, to control inlining behavior for computations built from index embeddings and scalar constant expressions.
1 parent 5659e7a commit 2285eaf

File tree

5 files changed

+64
-11
lines changed

5 files changed

+64
-11
lines changed

arrayjit/lib/low_level.ml

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ type virtualize_settings = {
9090
mutable max_visits : int;
9191
mutable max_tracing_dim : int;
9292
mutable inline_scalar_constexprs : bool;
93+
mutable inline_simple_computations : bool;
9394
}
9495

9596
let virtualize_settings =
@@ -105,7 +106,16 @@ let virtualize_settings =
105106
let inline_scalar_constexprs =
106107
Bool.of_string @@ Utils.get_global_arg ~arg_name:"inline_scalar_constexprs" ~default:"true"
107108
in
108-
{ enable_device_only; max_visits; max_tracing_dim; inline_scalar_constexprs }
109+
let inline_simple_computations =
110+
Bool.of_string @@ Utils.get_global_arg ~arg_name:"inline_simple_computations" ~default:"true"
111+
in
112+
{
113+
enable_device_only;
114+
max_visits;
115+
max_tracing_dim;
116+
inline_scalar_constexprs;
117+
inline_simple_computations;
118+
}
109119

110120
(** Access record with two cases: a visit count ([Visits n]) or [Recurrent]. NOTE(review): the
    values stored in [traced_array.accesses] appear to be of this type (they are tested with
    [is_recurrent] and [is_too_many]) -- confirm against the full file. *)
type visits = Visits of int | Recurrent [@@deriving sexp, equal, variants]
111121

@@ -118,6 +128,7 @@ type traced_array = {
118128
mutable read_before_write : bool;
119129
mutable read_only : bool;
120130
mutable is_scalar_constexpr : bool;
131+
mutable is_complex : bool;
121132
}
122133
[@@deriving sexp_of]
123134

@@ -147,6 +158,7 @@ let get_node store tn =
147158
read_before_write = false;
148159
read_only = false;
149160
is_scalar_constexpr = false;
161+
is_complex = false;
150162
})
151163

152164
let visit ~is_assigned old =
@@ -175,6 +187,24 @@ let is_constexpr_comp traced_store llv =
175187
in
176188
loop llv
177189

190+
(** [is_complex_comp traced_store llv] decides whether the scalar expression [llv] counts as
    "complex". It is [true] when [llv] contains an [Access], a [Get] of a node that is not marked
    [is_scalar_constexpr], or a [Get_local] / [Local_scope] whose node is already marked
    [is_complex] in [traced_store] (the body of a [Local_scope] is not inspected here). [Constant]
    and [Embed_index] leaves are never complex; an operator is complex if any operand is. *)
let is_complex_comp traced_store llv =
191+
let rec loop llv =
192+
match llv with
193+
| Get_local { tn; _ } | Local_scope { id = { tn; _ }; _ } ->
194+
let traced = get_node traced_store tn in
195+
traced.is_complex
196+
| Access (_, _) -> true
197+
| Get (tn, _) ->
198+
let traced = get_node traced_store tn in
199+
not traced.is_scalar_constexpr
200+
| Ternop (_, v1, v2, v3) -> loop v1 || loop v2 || loop v3
201+
| Binop (_, v1, v2) -> loop v1 || loop v2
202+
| Unop (_, v) -> loop v
203+
| Constant _ -> false
204+
| Embed_index _ -> false
205+
in
206+
loop llv
207+
178208
(** [is_scalar_dims tn] forces [tn.dims] and is [true] when every dimension equals 1 (also
    [true] for an empty dimensions array). *)
let is_scalar_dims tn = Array.for_all ~f:(( = ) 1) @@ Lazy.force tn.Tn.dims
179209

180210
let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
@@ -187,23 +217,27 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
187217
List.fold symbols ~init:offset ~f:(fun acc (coeff, s) ->
188218
acc + (coeff * (Option.value ~default:0 @@ Map.find env s))))
189219
in
190-
let rec loop_proc env llc =
191-
let loop = loop_proc env in
220+
let rec loop_proc ~first_visit env llc =
221+
let loop = loop_proc ~first_visit env in
192222
match llc with
193223
| Noop -> ()
194224
| (Seq (c1, c2) : t) ->
195225
loop c1;
196226
loop c2
197227
| For_loop { index; from_; to_ = _; body; trace_it = false } ->
198-
loop_proc (Map.add_exn ~key:index ~data:from_ env) body
228+
loop_proc ~first_visit (Map.add_exn ~key:index ~data:from_ env) body
199229
| For_loop { index; from_; to_; body; trace_it = true } ->
200230
for data = from_ to min to_ (from_ + virtualize_settings.max_tracing_dim) do
201-
loop_proc (Map.add_exn ~key:index ~data env) body
231+
loop_proc
232+
~first_visit:(first_visit && data = from_)
233+
(Map.add_exn ~key:index ~data env)
234+
body
202235
done
203236
| Zero_out tn ->
204237
let traced : traced_array = get_node traced_store tn in
205238
if Hash_set.is_empty traced.assignments && Hashtbl.is_empty traced.accesses then (
206239
traced.zero_initialized <- true;
240+
traced.is_complex <- false;
207241
if is_scalar_dims tn then traced.is_scalar_constexpr <- true);
208242
traced.zeroed_out <- true
209243
| Set { tn; idcs; llv; debug = _ } ->
@@ -215,6 +249,8 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
215249
then traced.is_scalar_constexpr <- is_constexpr_comp traced_store llv
216250
(* Note: this prevents detection if the same constant is assigned inside a loop. *)
217251
else if not @@ Hash_set.is_empty traced.assignments then traced.is_scalar_constexpr <- false;
252+
if first_visit then
253+
traced.is_complex <- traced.is_complex || is_complex_comp traced_store llv;
218254
Hash_set.add traced.assignments (lookup env idcs);
219255
Array.iter idcs ~f:(function
220256
| Fixed_idx _ -> ()
@@ -238,7 +274,7 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
238274
let at_pos = lookup env indices in
239275
Hashtbl.update traced.accesses at_pos
240276
~f:(visit ~is_assigned:(traced.zeroed_out || Hash_set.mem traced.assignments at_pos))
241-
| Local_scope { body; _ } -> loop_proc env body
277+
| Local_scope { body; _ } -> loop_proc ~first_visit:true env body
242278
| Get_local _ -> ()
243279
| Access (Merge_buffer { source }, _) ->
244280
let source_node_id = source.Tn.id in
@@ -263,15 +299,21 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
263299
loop llv2
264300
| Unop (_, llv) -> loop llv
265301
in
266-
loop_proc Indexing.empty_env llc;
302+
loop_proc ~first_visit:true Indexing.empty_env llc;
267303
Hashtbl.iter traced_store ~f:(fun traced ->
268304
let tn = traced.tn in
269305
if
270306
virtualize_settings.inline_scalar_constexprs && traced.is_scalar_constexpr
271307
&& not (Tn.known_non_virtual tn)
272308
then Tn.update_memory_mode tn Virtual 40;
273-
if Option.is_none tn.memory_mode && Hashtbl.exists traced.accesses ~f:is_too_many then
274-
Tn.update_memory_mode tn Never_virtual 1;
309+
let skip_simple =
310+
virtualize_settings.inline_simple_computations && (not traced.is_complex)
311+
&& not (Tn.known_non_virtual tn)
312+
in
313+
if
314+
(not skip_simple) && Option.is_none tn.memory_mode
315+
&& Hashtbl.exists traced.accesses ~f:is_too_many
316+
then Tn.update_memory_mode tn Never_virtual 1;
275317
if (not traced.zeroed_out) && Hash_set.is_empty traced.assignments then (
276318
(* The tensor node is read-only/recurrent for this computation, but maybe computed or
277319
specified as virtual by another routine. However, if the memory mode is unspecified, we
@@ -287,7 +329,8 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
287329
before the first routine using it gets compiled; another routine re-uses that \
288330
computation. Debug: %{Tn.debug_memory_mode tn.Tn.memory_mode}"]))
289331
else if Tn.known_non_virtual tn then Tn.update_memory_mode tn Materialized 35);
290-
if Hashtbl.exists traced.accesses ~f:is_recurrent then (
332+
(* We allow sharing virtual nodes across routines. *)
333+
if Hashtbl.exists traced.accesses ~f:is_recurrent && not (Tn.known_virtual tn) then (
291334
traced.read_before_write <- true;
292335
if Tn.mode_is_unspecified tn then
293336
Tn.update_memory_mode tn (Hosted (Changed_on_devices Unset)) 38

arrayjit/lib/low_level.mli

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ type virtualize_settings = {
7676
mutable max_visits : int;
7777
mutable max_tracing_dim : int;
7878
mutable inline_scalar_constexprs : bool;
79+
mutable inline_simple_computations : bool;
7980
}
8081

8182
val virtualize_settings : virtualize_settings
@@ -99,6 +100,9 @@ type traced_array = {
99100
(** True only if the tensor node has all axes of dimension 1, is either zeroed-out or assigned
100101
before accessed, is assigned at most once, and from an expression involving only constants
101102
or tensor nodes that were at the time is_scalar_constexpr. *)
103+
mutable is_complex : bool;
104+
(** False only if the tensor node is built from index embeddings and scalar constant
105+
expressions. *)
102106
}
103107
[@@deriving sexp_of]
104108

arrayjit/lib/lowering_and_inlining.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ The optimization behavior is controlled by `virtualize_settings`:
116116
- `max_visits`: Maximum number of times a tensor can be accessed before being materialized
117117
- `max_tracing_dim`: Maximum dimension size for loop unrolling during analysis
118118
- `enable_device_only`: Whether to prefer device-only storage when possible
119-
- `inline_scalar_constexprs`: Whether to inline scalar constant expressions
119+
- `inline_scalar_constexprs`: Whether to inline scalar constant expressions regardless of accesses
120+
- `inline_simple_computations`: Currently, whether to inline computations built from index embeddings and scalar constant expressions, regardless of accesses
120121

121122
## Memory Mode Management
122123

arrayjit/lib/tnode.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,8 @@ let known_volatile tn = match tn.memory_mode with Some (Hosted Volatile, _) -> t
260260
(** [known_non_virtual tn] is [false] when the memory mode of [tn] is unset, [Virtual] or
    [Effectively_constant], and [true] for every other memory mode. *)
let known_non_virtual tn =
261261
match tn.memory_mode with None | Some ((Virtual | Effectively_constant), _) -> false | _ -> true
262262

263+
(** [known_virtual tn] is [true] only when the memory mode of [tn] is already set to [Virtual];
    an unset mode yields [false]. *)
let known_virtual tn = match tn.memory_mode with Some (Virtual, _) -> true | _ -> false
264+
263265
let known_shared_cross_streams tn =
264266
match tn.memory_mode with
265267
| Some

ocannl_config.example

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ enable_device_only=true
9393
# If true, scalar constant expressions will always be inlined.
9494
inline_scalar_constexprs=true
9595

96+
# If true, tensor nodes built only from index embeddings and scalar constant expressions will be inlined, regardless of how many times they are accessed.
97+
inline_simple_computations=true
98+
9699
# The random number library. Options: `stdlib` -- `Base.Random`;
97100
# `for_tests` -- simplistic randomness with 32 bit seed, focused on reproducibility.
98101
randomness_lib=stdlib

0 commit comments

Comments
 (0)