Skip to content

Commit 0a46459

Browse files
committed
inline_complex_computations: when enabled, don't count self-defining accesses toward virtualize_max_visits
1 parent 5c1b2d6 commit 0a46459

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

arrayjit/lib/low_level.ml

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,9 @@ let is_complex_comp traced_store llsc =
220220
(** [is_scalar_dims tn] forces the lazy [dims] of [tn] and returns [true] when every dimension
    equals 1 (vacuously [true] if there are no dimensions). *)
let is_scalar_dims tn =
  let dims = Lazy.force tn.Tn.dims in
  Array.for_all dims ~f:(fun dim -> dim = 1)
221221

222222
let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
223+
let inline_complex_computations =
224+
Utils.get_global_flag ~default:true ~arg_name:"inline_complex_computations"
225+
in
223226
let is_too_many = function Visits i -> i > max_visits | Recurrent -> true in
224227
(* FIXME: migrate hashtable to use offsets instead of indices *)
225228
let lookup env indices =
@@ -256,7 +259,7 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
256259
if is_scalar_dims tn then traced.is_scalar_constexpr <- true);
257260
traced.zeroed_out <- true
258261
| Set { tn; idcs; llsc; debug = _ } ->
259-
loop_scalar env llsc;
262+
loop_scalar env (Some (lookup env idcs)) llsc;
260263
let traced : traced_array = get_node traced_store tn in
261264
if
262265
Hash_set.is_empty traced.assignments
@@ -280,7 +283,7 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
280283
let old_tn = Hashtbl.find_or_add reverse_node_map s ~default:(fun () -> tn) in
281284
assert (Tn.equal old_tn tn)))
282285
| Set_from_vec { tn; idcs; length; vec_unop = _; arg; debug = _ } ->
283-
loop_scalar env arg;
286+
loop_scalar env (Some (lookup env idcs)) arg;
284287
let traced : traced_array = get_node traced_store tn in
285288
(* Vector operations cannot be scalar constexpr *)
286289
traced.is_scalar_constexpr <- false;
@@ -323,18 +326,23 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
323326
List.iter symbols ~f:(fun (_, s) ->
324327
let old_tn = Hashtbl.find_or_add reverse_node_map s ~default:(fun () -> tn) in
325328
assert (Tn.equal old_tn tn)))
326-
| Set_local (_, llsc) -> loop_scalar env llsc
329+
| Set_local (_, llsc) -> loop_scalar env None llsc
327330
| Comment _ -> ()
328331
| Staged_compilation _ -> ()
329-
and loop_scalar env llsc =
330-
let loop = loop_scalar env in
332+
and loop_scalar env (access_pos : int array option) llsc =
333+
let loop = loop_scalar env access_pos in
331334
match llsc with
332335
| Constant _ -> ()
333336
| Get (ptr, indices) ->
334337
let traced : traced_array = get_node traced_store ptr in
335338
let at_pos = lookup env indices in
336-
Hashtbl.update traced.accesses at_pos
337-
~f:(visit ~is_assigned:(traced.zeroed_out || Hash_set.mem traced.assignments at_pos))
339+
if
340+
(not inline_complex_computations)
341+
|| Option.value_map access_pos ~default:true ~f:(fun pos ->
342+
not ([%equal: int array] pos at_pos))
343+
then
344+
Hashtbl.update traced.accesses at_pos
345+
~f:(visit ~is_assigned:(traced.zeroed_out || Hash_set.mem traced.assignments at_pos))
338346
| Local_scope { body; _ } -> loop_proc ~first_visit:true env body
339347
| Get_local _ -> ()
340348
| Get_merge_buffer (source, _) ->

ocannl_config.example

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ inline_scalar_constexprs=true
105105
# scalar constant expressions, regardless of accesses, it will be inlined.
106106
inline_simple_computations=true
107107

108+
# If true, virtualize_max_visits only counts accesses that are not used for assignment of
109+
# the same cell (typically accumulation). Otherwise, all accesses are counted, so computations
110+
# that reduce an axis will rarely be inlined.
111+
inline_complex_computations=true
112+
108113
# The random number library. Options: `stdlib` -- `Base.Random`;
109114
# `for_tests` -- simplistic randomness with 32 bit seed, focused on reproducibility.
110115
randomness_lib=stdlib

0 commit comments

Comments
 (0)