Commit f7224a3

By Claude Sonnet: fix missing Affine expansion during inlining, document low_level.ml optimizations
1 parent 982b813 commit f7224a3

2 files changed: +119 -76 lines changed

arrayjit/lib/low_level.ml

Lines changed: 23 additions & 0 deletions
````diff
@@ -416,6 +416,29 @@ let inline_computation ~id traced static_indices call_args =
   in
   let subst env = function
     | Indexing.Iterator s when Map.mem env s -> Map.find_exn env s
+    | Indexing.Affine { symbols; offset } ->
+        (* We need to substitute each symbol in the affine expression.
+           If a symbol maps to a non-Iterator, we need to handle it specially. *)
+        let expand_symbol (coeff, s) =
+          match Map.find env s with
+          | Some (Indexing.Iterator new_s) -> [(coeff, new_s)]
+          | Some (Indexing.Fixed_idx _) -> [] (* Fixed index contributes to offset *)
+          | Some (Indexing.Affine { symbols = inner_symbols; offset = _ }) ->
+              (* Expand nested affine: coeff * (inner_symbols + inner_offset) *)
+              List.map inner_symbols ~f:(fun (inner_coeff, inner_s) -> (coeff * inner_coeff, inner_s))
+          | None -> [(coeff, s)]
+        in
+        let all_terms = List.concat_map symbols ~f:expand_symbol in
+        (* Calculate the new offset by adding contributions from Fixed_idx substitutions *)
+        let offset_additions =
+          List.fold symbols ~init:0 ~f:(fun acc (coeff, s) ->
+              match Map.find env s with
+              | Some (Indexing.Fixed_idx i) -> acc + (coeff * i)
+              | Some (Indexing.Affine { offset = inner_offset; _ }) -> acc + (coeff * inner_offset)
+              | _ -> acc)
+        in
+        let new_offset = offset + offset_additions in
+        Indexing.Affine { symbols = all_terms; offset = new_offset }
     | idx -> idx
   in
   let rec loop env llc : t option =
````
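To make the intent of the new `Affine` case concrete, here is a minimal, self-contained sketch of the same substitution idea. It uses simplified stand-in types (`sym`, `idx`) and a hypothetical `subst_affine` helper rather than OCANNL's actual `Indexing` module, and it relies only on the OCaml standard library:

```ocaml
(* Simplified stand-in types; OCANNL's real Indexing.axis_index differs. *)
type sym = string

type idx =
  | Iterator of sym
  | Fixed_idx of int
  | Affine of { symbols : (int * sym) list; offset : int }

(* Substitute [env] into an affine index: rename iterator symbols, fold fixed
   indices into the offset, and splice nested affine expressions in place. *)
let subst_affine (env : (sym * idx) list) ~symbols ~offset =
  let expand (coeff, s) =
    match List.assoc_opt s env with
    | Some (Iterator s') -> ([ (coeff, s') ], 0)
    | Some (Fixed_idx i) -> ([], coeff * i)
    | Some (Affine { symbols = inner; offset = o }) ->
        (List.map (fun (c, s') -> (coeff * c, s')) inner, coeff * o)
    | None -> ([ (coeff, s) ], 0)
  in
  let terms, extra = List.split (List.map expand symbols) in
  Affine { symbols = List.concat terms; offset = offset + List.fold_left ( + ) 0 extra }

let () =
  (* [i + 2*j + 1] under the inlining environment [j -> Fixed_idx 3] becomes [i + 7]. *)
  match subst_affine [ ("j", Fixed_idx 3) ] ~symbols:[ (1, "i"); (2, "j") ] ~offset:1 with
  | Affine { symbols; offset } ->
      List.iter (fun (c, s) -> Printf.printf "%d*%s + " c s) symbols;
      Printf.printf "%d\n" offset
  | _ -> ()
```

Unlike the patch above, which accumulates the offset contributions in a separate `List.fold`, this sketch folds them in while expanding each term; the resulting affine form is the same.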
Second file changed: 96 additions & 76 deletions
````diff
@@ -1,47 +1,10 @@
 # Compilation to Cross-Backend Low-Level Representation, and Backend-Independent Optimizations
 
-Computation in OCANNL is imperative. At the high-level, we store tensor node assignments as `Assignments.t`:
+Computation in OCANNL is imperative. At the high-level, we store tensor node assignments as `Assignments.t`, which provides high-level operations like `Accum_binop`, `Accum_unop`, and `Fetch`. This is translated to a low-level representation `Low_level.t` which is a C-like mini-language operating on scalars.
 
-```ocaml
-(** Resets a array by performing the specified computation or data fetching. *)
-type fetch_op =
-  | Constant of float
-  | Access of Low_level.dedicated_access
-  | Slice of { batch_idx : Indexing.static_symbol; sliced : Tnode.t }
-  | Embed_symbol of Indexing.static_symbol
-
-and t =
-  | Noop
-  | Seq of t * t
-  | Block_comment of string * t  (** Same as the given code, with a comment. *)
-  | Accum_binop of {
-      initialize_neutral : bool;
-      accum : Ops.binop;
-      op : Ops.binop;
-      lhs : Tnode.t;
-      rhs1 : Tnode.t;
-      rhs2 : Tnode.t;
-      projections : Indexing.projections Lazy.t;
-    }
-  | Accum_unop of {
-      initialize_neutral : bool;
-      accum : Ops.binop;
-      op : Ops.unop;
-      lhs : Tnode.t;
-      rhs : Tnode.t;
-      projections : Indexing.projections Lazy.t;
-    }
-  | Fetch of { array : Tnode.t; fetch_op : fetch_op; dims : int array Lazy.t }
-```
-
-The effect of `Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections }` is:
-
-> if `initialize_neutral` then `lhs` := neutral value of `accum`;
-> `lhs` := `lhs` `accum` (`rhs1` `op` `rhs2`)
+## Low-Level Representation
 
-The `Assignments` module depends on the `Low_level` module and puts the pieces together in the `compile_proc` function. In addition to the assignments, `compile_proc` takes a `Indexing.static_symbol list` of the static indices, currently they are needed for optimization but not remembered in the `Assignments.t` nor `Low_level.t` types.
-
-The low-level representation is a C-like mini-language operating on scalars.
+The `Low_level.t` type represents a C-like imperative language with for loops and scalar operations:
 
 ```ocaml
 type t =
@@ -55,68 +18,125 @@ type t =
   | Set_local of scope_id * float_t
 
 and float_t =
-  | Local_scope of {
-      id : scope_id;
-      prec : Ops.prec;
-      body : t;
-      orig_indices : Indexing.axis_index array;
-    }
+  | Local_scope of { id : scope_id; body : t; orig_indices : Indexing.axis_index array }
   | Get_local of scope_id
-  | Access of Low_level.dedicated_access * Indexing.axis_index array option
+  | Access of dedicated_access * Indexing.axis_index array option
   | Get of Tnode.t * Indexing.axis_index array
+  | Ternop of Ops.ternop * float_t * float_t * float_t
   | Binop of Ops.binop * float_t * float_t
   | Unop of Ops.unop * float_t
   | Constant of float
   | Embed_index of Indexing.axis_index
 ```
 
-The odd part is the `Staged_compilation` element. Backends can use `Staged_compilation` to embed some emitted code within on-the-fly generated `Low_level.t` code. Currently this works only for `PPrint.document` based backends like `C_syntax` derivatives, but this covers almost all backends.
+`t` represents code/statements while `float_t` represents scalar expressions. The `trace_it` flag in `For_loop` indicates whether the loop should be traced for optimization (its initial segment will be unrolled for analysis).
 
-TODO: flesh out explanation.
+## Translation from Assignments
 
-## Translation
+The translation `Assignments.to_low_level` is straightforward:
 
-The translation `Assignments.to_low_level` is straightforward. Commented code blocks are delineated by `Low_level.Comment "end"` statements. Indices into tensor nodes are derived from the `projections` fields. We translate `projections.product_space` elements into for loops. `to_low_level` returns all the data that `Low_level` optimizations generated, so that backends can make more informed decisions when jitting, i.e. emitting the backend-specific code.
+1. **Projections to Loops**: `projections.product_space` elements become nested for loops
+2. **Index Translation**: Tensor indices are derived from `projections.project_lhs` and `projections.project_rhs`
+3. **Operations**: High-level operations like `Accum_binop` become loops over scalar operations
+4. **Initialization**: If `initialize_neutral` is true and the operation isn't total, we initialize with the neutral element
 
-## Inlining
+## Backend-Independent Optimizations
 
-Inlining is a process where we take the computations pertaining to a tensor node, and inline them at the `Get` access sites on a per-scalar basis.
+The optimization pipeline in `optimize_proc` consists of three main phases:
 
-```ocaml
-type virtualize_settings = {
-  mutable enable_device_only : bool;
-  mutable max_visits : int;
-  mutable max_tracing_dim : int;
-}
+### 1. Tracing Phase (`visit_llc`)
 
-type visits =
-  | Visits of int
-  | Recurrent  (** A [Recurrent] visit is when there is an access prior to any assignment in an update. *)
+This phase symbolically executes the computation to build a `traced_store` mapping each tensor node to a `traced_array`:
 
+```ocaml
 type traced_array = {
-  nd : Tn.t;
+  tn : Tn.t;
   mutable computations : (Indexing.axis_index array option * t) list;
-      (** The computations (of the tensor node) are retrieved for optimization just as they are populated,
-          so that the inlined code corresponds precisely to the changes to the arrays that would happen
-          up till that point. Within the code blocks paired with an index tuple, all assignments and accesses
-          must happen via the index tuple; if this is not the case for some assignment, the node cannot
-          be virtual. Currently, we only allow for-loop symbols in assignment indices of virtual nodes. *)
   assignments : int array Hash_set.t;
   accesses : (int array, visits) Hashtbl.t;
-      (** For dynamic indexes, we take a value of 0. This leads to an overestimate of visits, which is safe. *)
   mutable zero_initialized : bool;
   mutable zeroed_out : bool;
-  mutable read_before_write : bool;  (** The node is read before it is written (i.e. it is recurrent). *)
+  mutable read_before_write : bool;
  mutable read_only : bool;
  mutable is_scalar_constexpr : bool;
-      (** True only if the tensor node has all axes of dimension 1, is either zeroed-out or assigned
-          before accessed, is assigned at most once, and from an expression involving only constants
-          or tensor nodes that were at the time is_scalar_constexpr. *)
 }
 ```
 
-- The `visit_llc` function interprets (symbolically executes) the given computation, and fills in `traced_store`: the map of `traced_array`s.
--
+Key analyses performed:
+
+- **Access Pattern Analysis**: Tracks which positions are read/written and how many times (`visits`)
+- **Dependency Analysis**: Determines read-before-write patterns (recurrence)
+- **Scalar Constant Expression Detection**: Identifies tensor nodes that are constant scalars
+- **Memory Mode Inference**: Decides whether tensors should be virtual, materialized, etc.
+
+### 2. Virtualization and Inlining Phase (`virtual_llc`)
+
+This is the core optimization phase that implements **computation inlining**:
+
+#### Virtualization Decision
+
+- Tensors with too many accesses (`> max_visits`) are marked `Never_virtual`
+- Read-only tensors are typically materialized
+- Recurrent tensors (read-before-write) are materialized
+
+#### Inlining Process (`inline_computation`)
+
+When a tensor node is accessed via `Get`, if it's determined to be virtual:
+
+1. **Retrieve Computations**: Get the stored computations for the tensor from `traced_array.computations`
+2. **Symbol Freshening**: Create fresh symbols to avoid variable capture when inlining
+3. **Substitution**: Replace the definition's indices with the call site's indices
+4. **Code Generation**: Generate a `Local_scope` that computes the value inline
+
+#### Critical Invariant: Symbol Freshening
+
+When inlining, we must ensure that loop variables don't clash. The `subst` function handles index substitution, mapping old symbols to new ones. This is crucial for correctness.
+
+### 3. Cleanup and Simplification Phase (`cleanup_virtual_llc` + `simplify_llc`)
+
+#### Cleanup (`cleanup_virtual_llc`)
+
+- **Environment Validation**: Ensures all symbols are properly bound in their scope
+- **Virtual Tensor Removal**: Removes references to virtual tensors that were successfully inlined
+- **Constraint Checking**: Validates that symbol substitution was correct
+
+#### Simplification (`simplify_llc`)
+
+A traditional optimizing compiler pass that performs:
+
+- **Constant Folding**: `Constant 2.0 + Constant 3.0` → `Constant 5.0`
+- **Algebraic Simplification**: `x + 0` → `x`, `x * 1` → `x`, etc.
+- **Dead Code Elimination**: Removes `Local_scope` that just return values
+- **Integer Power Unrolling**: `x ** 3` → `x * x * x` for small integer powers
+
+## Optimization Settings
+
+The optimization behavior is controlled by `virtualize_settings`:
+
+- `max_visits`: Maximum number of times a tensor can be accessed before being materialized
+- `max_tracing_dim`: Maximum dimension size for loop unrolling during analysis
+- `enable_device_only`: Whether to prefer device-only storage when possible
+- `inline_scalar_constexprs`: Whether to inline scalar constant expressions
+
+## Memory Mode Management
+
+The optimization process works closely with OCANNL's memory mode system:
+
+- **Virtual**: Computations are inlined, no storage allocated
+- **Materialized**: Tensor is stored and reused
+- **Device_only**: Stored only on device, not accessible from host
+- **Hosted**: Stored on both host and device
+
+The optimizer uses provenance tracking (the `int` in memory mode tuples) to debug conflicts in memory mode decisions.
+
+## Code Generation Integration
+
+The optimized `Low_level.t` can be:
+
+1. **Printed** using `to_doc` (OCANNL %cd syntax) or `to_doc_cstyle` (C-like syntax)
+2. **Backend Compilation**: Each backend pattern-matches on `Low_level.t` to generate device-specific code
+3. **Staged Compilation**: `Staged_compilation` nodes allow backends to embed generated code during optimization
 
+The `Staged_compilation` construct is particularly important for backends that need to emit complex code patterns that can't be easily represented in the simple `Low_level.t` grammar.
 
-## Rewriting
+This optimization pipeline enables OCANNL to achieve high performance by eliminating intermediate tensor allocations and generating specialized code for each computation pattern.
````
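The simplification rules the rewritten documentation lists for `simplify_llc` (constant folding, `x + 0` and `x * 1` identities, integer power unrolling) are conventional rewrite-on-the-AST passes. The following is a minimal illustrative sketch over a hypothetical toy `expr` type, not OCANNL's `Low_level.float_t`:

```ocaml
(* Toy expression type and one-pass simplifier illustrating the kinds of
   rewrites named above; simplify_llc itself works on Low_level.float_t. *)
type expr =
  | Constant of float
  | Var of string
  | Add of expr * expr
  | Mul of expr * expr
  | Pow of expr * int

let rec simplify e =
  match e with
  | Add (a, b) -> (
      match (simplify a, simplify b) with
      | Constant x, Constant y -> Constant (x +. y) (* constant folding *)
      | Constant 0., e' | e', Constant 0. -> e' (* x + 0 -> x *)
      | a', b' -> Add (a', b'))
  | Mul (a, b) -> (
      match (simplify a, simplify b) with
      | Constant x, Constant y -> Constant (x *. y)
      | Constant 1., e' | e', Constant 1. -> e' (* x * 1 -> x *)
      | a', b' -> Mul (a', b'))
  | Pow (a, n) when n >= 2 && n <= 4 ->
      (* integer power unrolling: x ** 3 -> x * x * x *)
      let a' = simplify a in
      List.fold_left (fun acc _ -> Mul (acc, a')) a' (List.init (n - 1) (fun _ -> ()))
  | Pow (a, n) -> Pow (simplify a, n)
  | Constant _ | Var _ -> e

let () =
  (* (2.0 + 3.0) * x ** 3 simplifies to Constant 5.0 * (x * x * x). *)
  ignore (simplify (Mul (Add (Constant 2., Constant 3.), Pow (Var "x", 3))))
```

A single bottom-up pass suffices for the examples listed; a production pass would typically iterate to a fixed point and preserve precision information alongside each rewrite.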
