Skip to content

Commit 3b98f9b

Browse files
committed
Inline single getters i.e. "views"
1 parent e0f5eb1 commit 3b98f9b

File tree

7 files changed

+88
-137
lines changed

7 files changed

+88
-137
lines changed

arrayjit/lib/low_level.ml

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ type traced_array = {
119119
mutable read_before_write : bool;
120120
mutable read_only : bool;
121121
mutable is_scalar_constexpr : bool;
122+
mutable is_accessing : bool;
122123
mutable is_complex : bool;
123124
}
124125
[@@deriving sexp_of]
@@ -149,6 +150,7 @@ let get_node store tn =
149150
read_before_write = false;
150151
read_only = false;
151152
is_scalar_constexpr = false;
153+
is_accessing = false;
152154
is_complex = false;
153155
})
154156

@@ -180,18 +182,19 @@ let is_constexpr_comp traced_store llsc =
180182
in
181183
loop llsc
182184

183-
let is_complex_comp traced_store llsc =
185+
let is_accessing_comp traced_store llsc =
184186
let rec loop llsc =
185187
match llsc with
186188
| Get_local { tn; _ } | Local_scope { id = { tn; _ }; _ } ->
187189
let traced = get_node traced_store tn in
188-
traced.is_complex
190+
traced.is_accessing
189191
| Get (tn, _) ->
190192
let traced = get_node traced_store tn in
191193
not traced.is_scalar_constexpr
192194
| Get_merge_buffer (tn, _) ->
193195
let traced = get_node traced_store tn in
194-
not traced.is_scalar_constexpr
196+
traced.is_accessing <- true;
197+
true
195198
| Ternop (_, v1, v2, v3) -> loop v1 || loop v2 || loop v3
196199
| Binop (_, v1, v2) -> loop v1 || loop v2
197200
| Unop (_, v) -> loop v
@@ -200,6 +203,20 @@ let is_complex_comp traced_store llsc =
200203
in
201204
loop llsc
202205

206+
let is_complex_comp traced_store llsc =
207+
let accessing = is_accessing_comp traced_store in
208+
match llsc with
209+
| Get_local { tn; _ } | Local_scope { id = { tn; _ }; _ } ->
210+
let traced = get_node traced_store tn in
211+
traced.is_complex
212+
| Get _ -> false
213+
| Get_merge_buffer _ -> false
214+
| Ternop (_, v1, v2, v3) -> accessing v1 || accessing v2 || accessing v3
215+
| Binop (_, v1, v2) -> accessing v1 || accessing v2
216+
| Unop (_, v) -> accessing v
217+
| Constant _ -> false
218+
| Embed_index _ -> false
219+
203220
let is_scalar_dims tn = Array.for_all ~f:(( = ) 1) @@ Lazy.force tn.Tn.dims
204221

205222
let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
@@ -234,6 +251,7 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
234251
let traced : traced_array = get_node traced_store tn in
235252
if Hash_set.is_empty traced.assignments && Hashtbl.is_empty traced.accesses then (
236253
traced.zero_initialized <- true;
254+
traced.is_accessing <- false;
237255
traced.is_complex <- false;
238256
if is_scalar_dims tn then traced.is_scalar_constexpr <- true);
239257
traced.zeroed_out <- true
@@ -246,8 +264,9 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
246264
then traced.is_scalar_constexpr <- is_constexpr_comp traced_store llsc
247265
(* Note: this prevents detection if the same constant is assigned inside a loop. *)
248266
else if not @@ Hash_set.is_empty traced.assignments then traced.is_scalar_constexpr <- false;
249-
if first_visit then
250-
traced.is_complex <- traced.is_complex || is_complex_comp traced_store llsc;
267+
if first_visit then (
268+
traced.is_accessing <- traced.is_accessing || is_accessing_comp traced_store llsc;
269+
traced.is_complex <- traced.is_complex || is_complex_comp traced_store llsc);
251270
Hash_set.add traced.assignments (lookup env idcs);
252271
Array.iter idcs ~f:(function
253272
| Fixed_idx _ -> ()
@@ -265,7 +284,9 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
265284
let traced : traced_array = get_node traced_store tn in
266285
(* Vector operations cannot be scalar constexpr *)
267286
traced.is_scalar_constexpr <- false;
268-
if first_visit then traced.is_complex <- false;
287+
if first_visit then (
288+
traced.is_accessing <- traced.is_accessing || is_accessing_comp traced_store arg;
289+
traced.is_complex <- traced.is_complex || not (is_constexpr_comp traced_store arg));
269290
(* Mark all positions that will be written to *)
270291
for i = 0 to length - 1 do
271292
let pos_idcs = Array.copy idcs in

arrayjit/lib/low_level.mli

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,12 @@ type traced_array = {
9090
(** True only if the tensor node has all axes of dimension 1, is either zeroed-out or assigned
9191
before accessed, is assigned at most once, and from an expression involving only constants
9292
or tensor nodes that were at the time is_scalar_constexpr. *)
93-
mutable is_complex : bool;
93+
mutable is_accessing : bool;
9494
(** False only if the tensor node is built from index embeddings and scalar constant
9595
expressions. *)
96+
mutable is_complex : bool;
97+
(** True only if the tensor node is built acciessing computations that are not a single
98+
getter. *)
9699
}
97100
[@@deriving sexp_of]
98101

arrayjit/lib/lowering_and_inlining.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ The optimization behavior is controlled by `virtualize_settings`:
117117
- `max_tracing_dim`: Maximum dimension size for loop unrolling during analysis
118118
- `enable_device_only`: Whether to prefer device-only storage when possible
119119
- `inline_scalar_constexprs`: Whether to inline scalar constant expressions regardless of accesses
120-
- `inline_simple_computations`: Currently, whether to inline computations built from index embeddings and scalar constant expressions, regardless of accesses
120+
- `inline_simple_computations`: Currently, whether to inline computations built from either single getters, or index embeddings and scalar constant expressions, regardless of accesses
121121

122122
## Memory Mode Management
123123

ocannl_config.example

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ enable_device_only=true
101101
# If true, scalar constant expressions will always be inlined.
102102
inline_scalar_constexprs=true
103103

104-
# If true, if the tensor node is built from index embeddings and scalar constant expressions, regardless of accesses, it will be inlined.
104+
# If true, if the tensor node is built from either single getters, or index embeddings and
105+
# scalar constant expressions, regardless of accesses, it will be inlined.
105106
inline_simple_computations=true
106107

107108
# The random number library. Options: `stdlib` -- `Base.Random`;

test/einsum/moons_demo_variant.expected

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ n17 threefry4x32 as threefry4x32: Virt/15; uint4x32 prec 4; mem in bytes: <not-i
2222
n18 threefry4x32 as threefry4x32: Virt/15; uint4x32 prec 4; mem in bytes: <not-in-yet>
2323
n19 w2 as w2: Host&stream/412410; single prec 1x16; mem in bytes: <not-in-yet>
2424
n20 grad_w2 as w2.grad: Local/26046; single prec 1x16; mem in bytes: <not-in-yet>
25-
n21 @|_moons_input as moons_input: Local/1046; single prec 10x2; mem in bytes: <not-in-yet>
26-
n24 @|_moons_class as moons_class: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
25+
n21 @|_moons_input as moons_input: Virt/15; single prec 10x2; mem in bytes: <not-in-yet>
26+
n24 @|_moons_class as moons_class: Virt/15; single prec 10x1; mem in bytes: <not-in-yet>
2727
n27 * as n27: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
2828
n28 grad_* as n27.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
2929
n29 + as n29: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>

0 commit comments

Comments
 (0)