Skip to content

Commit 874fa31

Browse files
committed
Refactoring progress: implement code expansion for the new fetch ops Constant_fill and Range_over_offsets
1 parent 45809ac commit 874fa31

File tree

12 files changed

+108
-93
lines changed

12 files changed

+108
-93
lines changed

arrayjit/lib/assignments.ml

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ type buffer = Node of Tn.t | Merge_buffer of Tn.t [@@deriving sexp_of, equal]
1414
(** Resets an array by performing the specified computation or data fetching. *)
1515
type fetch_op =
1616
| Constant of float
17-
| Constant_fill of { values : float array; strict : bool }
18-
(** Fills in the numbers where the rightmost axis is contiguous. If [strict=false], loops over
19-
the provided values. *)
17+
| Constant_fill of float array
18+
(** Fills in the numbers where the rightmost axis is contiguous. Does not loop over the
19+
provided values; shape inference will require the assigned tensor to have the same number
20+
of elements. This unrolls all assignments and should be used only for small arrays.
21+
Consider using {!Tnode.set_values} instead for larger arrays. *)
2022
| Range_over_offsets
2123
(** Fills in the offset number of each cell, i.e. how many cells away it is from the
2224
beginning, in the logical representation of the tensor node. (The actual in-memory
@@ -141,8 +143,7 @@ let%diagn2_sexp to_low_level code =
141143
assert (Array.length idcs = Array.length (Lazy.force tn.Tn.dims));
142144
match buffer with
143145
| Node tn -> Low_level.Get (tn, idcs)
144-
| Merge_buffer tn ->
145-
Low_level.Access (Low_level.Merge_buffer { source = tn }, Some idcs)
146+
| Merge_buffer tn -> Low_level.Access (Low_level.Merge_buffer { source = tn }, Some idcs)
146147
in
147148
let set tn idcs llv =
148149
if not (Array.length idcs = Array.length (Lazy.force tn.Tn.dims)) then
@@ -239,18 +240,13 @@ let%diagn2_sexp to_low_level code =
239240
| Fetch { array; fetch_op = Access global; dims } ->
240241
Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs ->
241242
set array idcs @@ Access (global, Some idcs))
242-
| Fetch { array; fetch_op = Range_over_offsets; dims } ->
243-
Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs ->
244-
let offset = Array.foldi idcs ~init:0 ~f:(fun _i acc idx ->
245-
match idx with
246-
| Fixed_idx j -> acc + j
247-
| Iterator _ -> acc (* Will be computed dynamically *)
248-
| Affine _ -> acc (* Will be computed dynamically *)) in
249-
set array idcs @@ Constant (Float.of_int offset))
250-
| Fetch { array; fetch_op = Constant_fill { values; strict }; dims } ->
251-
Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs ->
252-
let value = if strict then values.(0) else values.(0) in (* TODO: implement proper indexing *)
253-
set array idcs @@ Constant value)
243+
| Fetch { array; fetch_op = Range_over_offsets; dims = (lazy dims) } ->
244+
Low_level.loop_over_dims dims ~body:(fun idcs ->
245+
let offset = Indexing.reflect_projection ~dims ~projection:idcs in
246+
set array idcs @@ Embed_index offset)
247+
| Fetch { array; fetch_op = Constant_fill values; dims = (lazy dims) } ->
248+
Low_level.unroll_dims dims ~body:(fun idcs ~offset ->
249+
set array idcs @@ Constant values.(offset))
254250
in
255251
loop code
256252

@@ -315,13 +311,14 @@ let to_doc ?name ?static_indices () c =
315311
let doc_of_fetch_op (op : fetch_op) =
316312
match op with
317313
| Constant f -> string (Float.to_string f)
318-
| Constant_fill { values; strict } ->
319-
let values_str = String.concat ~sep:", " (Array.to_list (Array.map values ~f:Float.to_string)) in
320-
string ("constant_fill([" ^ values_str ^ "], strict=" ^ Bool.to_string strict ^ ")")
314+
| Constant_fill values ->
315+
let values_str =
316+
String.concat ~sep:", " (Array.to_list (Array.map values ~f:Float.to_string))
317+
in
318+
string ("constant_fill([" ^ values_str ^ "])")
321319
| Range_over_offsets -> string "range_over_offsets"
322320
| Access (Low_level.C_function c) -> string (c ^ "()")
323-
| Access (Low_level.Merge_buffer { source }) ->
324-
string (ident source ^ ".merge")
321+
| Access (Low_level.Merge_buffer { source }) -> string (ident source ^ ".merge")
325322
| Access (Low_level.External_unsafe { ptr; prec; dims = _ }) ->
326323
string (Ops.ptr_to_string_hum ptr prec)
327324
| Access (Low_level.File_mapped (file, file_prec)) ->

arrayjit/lib/indexing.ml

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -198,32 +198,18 @@ let identity_projections ?debug_info ?derived_for ~lhs_dims () =
198198
debug_info;
199199
}
200200

201-
let derive_index ~product_syms ~(projection : axis_index array) =
202-
let sym_to_i =
203-
Array.mapi product_syms ~f:(fun i s -> (s, i))
204-
|> Array.to_list
205-
|> Map.of_alist_exn (module Symbol)
206-
in
207-
let positions =
208-
Array.map projection ~f:(function
209-
| Iterator s when Map.mem sym_to_i s -> Either.First (Map.find_exn sym_to_i s)
210-
| Fixed_idx _ as it -> Second it
211-
| Affine _ as it -> Second it
212-
| Iterator _ as it -> Second it)
213-
in
214-
fun ~product ->
215-
Array.map positions ~f:(function
216-
| First p -> product.(p)
217-
| Second (Fixed_idx i) -> i
218-
| Second (Iterator s) ->
219-
(* This shouldn't happen if sym_to_i is complete *)
220-
failwith ("derive_index: unresolved iterator " ^ symbol_ident s)
221-
| Second (Affine { symbols; offset }) ->
222-
List.fold symbols ~init:offset ~f:(fun acc (coeff, s) ->
223-
match Map.find sym_to_i s with
224-
| Some idx -> acc + (coeff * product.(idx))
225-
| None ->
226-
failwith ("derive_index: unresolved symbol in affine index " ^ symbol_ident s)))
201+
let reflect_projection ~(dims : int array) ~(projection : axis_index array) =
202+
Array.zip_exn dims projection
203+
|> Array.fold_right ~init:(1, [], 0) ~f:(fun (dim, idx) (stride, symbols, offset) ->
204+
match idx with
205+
| Fixed_idx fixed_offset -> (stride * dim, symbols, offset + (fixed_offset * stride))
206+
| Iterator sym -> (stride * dim, (stride, sym) :: symbols, offset)
207+
| Affine { symbols = affine_symbols; offset = affine_offset } ->
208+
let new_symbols =
209+
List.map affine_symbols ~f:(fun (coeff, sym) -> (coeff * stride, sym))
210+
in
211+
(stride * dim, new_symbols @ symbols, offset + (affine_offset * stride)))
212+
|> fun (_, symbols, offset) -> Affine { symbols; offset }
227213

228214
module Pp_helpers = struct
229215
open PPrint

arrayjit/lib/low_level.ml

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ let%diagn2_sexp check_and_store_virtual traced static_indices top_llc =
295295
function
296296
| Fixed_idx _ -> None
297297
| Iterator s -> Option.some_if (not @@ Set.mem static_indices s) s
298-
| Affine { symbols; offset } -> (
298+
| Affine { symbols; offset = _ } -> (
299299
(* For affine indices, collect all symbols that are not static *)
300300
List.filter_map symbols ~f:(fun (_, s) ->
301301
Option.some_if (not @@ Set.mem static_indices s) s)
@@ -991,18 +991,21 @@ let to_doc_cstyle ?name ?static_indices () llc =
991991
string (Ops.ptr_to_string_hum ptr prec)
992992
| Access (External_unsafe { ptr; prec; dims = _ }, Some idcs) ->
993993
string (Ops.ptr_to_string_hum ptr prec) ^^ brackets (pp_indices idcs)
994-
| Access (Merge_buffer { source }, None) ->
995-
doc_ident source ^^ string ".merge"
994+
| Access (Merge_buffer { source }, None) -> doc_ident source ^^ string ".merge"
996995
| Access (Merge_buffer { source }, Some idcs) ->
997996
group (doc_ident source ^^ string ".merge" ^^ brackets (pp_indices idcs))
998997
| Access (File_mapped (file, prec), None) ->
999-
string ("file_mapped(\"" ^ file ^ "\", " ^ Ops.precision_to_string prec ^ ")")
998+
string ("file_mapped(\"" ^ file ^ "\", " ^ Ops.prec_string prec ^ ")")
1000999
| Access (File_mapped (file, prec), Some idcs) ->
1001-
string ("file_mapped(\"" ^ file ^ "\", " ^ Ops.precision_to_string prec ^ ")") ^^ brackets (pp_indices idcs)
1000+
string ("file_mapped(\"" ^ file ^ "\", " ^ Ops.prec_string prec ^ ")")
1001+
^^ brackets (pp_indices idcs)
10021002
| Access (Uint4x32_to_prec_uniform { source; prec }, None) ->
1003-
string ("uint4x32_to_" ^ Ops.precision_to_string prec ^ "_uniform(") ^^ doc_ident source ^^ string ")"
1003+
string ("uint4x32_to_" ^ Ops.prec_string prec ^ "_uniform(")
1004+
^^ doc_ident source ^^ string ")"
10041005
| Access (Uint4x32_to_prec_uniform { source; prec }, Some idcs) ->
1005-
string ("uint4x32_to_" ^ Ops.precision_to_string prec ^ "_uniform(") ^^ doc_ident source ^^ string ")" ^^ brackets (pp_indices idcs)
1006+
string ("uint4x32_to_" ^ Ops.prec_string prec ^ "_uniform(")
1007+
^^ doc_ident source ^^ string ")"
1008+
^^ brackets (pp_indices idcs)
10061009
| Get (tn, idcs) -> group (doc_ident tn ^^ brackets (pp_indices idcs))
10071010
| Constant c -> string (Printf.sprintf "%.16g" c)
10081011
| Embed_index idx -> pp_axis_index idx
@@ -1075,18 +1078,21 @@ let to_doc ?name ?static_indices () llc =
10751078
string (Ops.ptr_to_string_hum ptr prec)
10761079
| Access (External_unsafe { ptr; prec; dims = _ }, Some idcs) ->
10771080
string (Ops.ptr_to_string_hum ptr prec) ^^ brackets (pp_indices idcs)
1078-
| Access (Merge_buffer { source }, None) ->
1079-
doc_ident source ^^ string ".merge"
1081+
| Access (Merge_buffer { source }, None) -> doc_ident source ^^ string ".merge"
10801082
| Access (Merge_buffer { source }, Some idcs) ->
10811083
group (doc_ident source ^^ string ".merge" ^^ brackets (pp_indices idcs))
10821084
| Access (File_mapped (file, prec), None) ->
1083-
string ("file_mapped(\"" ^ file ^ "\", " ^ Ops.precision_to_string prec ^ ")")
1085+
string ("file_mapped(\"" ^ file ^ "\", " ^ Ops.prec_string prec ^ ")")
10841086
| Access (File_mapped (file, prec), Some idcs) ->
1085-
string ("file_mapped(\"" ^ file ^ "\", " ^ Ops.precision_to_string prec ^ ")") ^^ brackets (pp_indices idcs)
1087+
string ("file_mapped(\"" ^ file ^ "\", " ^ Ops.prec_string prec ^ ")")
1088+
^^ brackets (pp_indices idcs)
10861089
| Access (Uint4x32_to_prec_uniform { source; prec }, None) ->
1087-
string ("uint4x32_to_" ^ Ops.precision_to_string prec ^ "_uniform(") ^^ doc_ident source ^^ string ")"
1090+
string ("uint4x32_to_" ^ Ops.prec_string prec ^ "_uniform(")
1091+
^^ doc_ident source ^^ string ")"
10881092
| Access (Uint4x32_to_prec_uniform { source; prec }, Some idcs) ->
1089-
string ("uint4x32_to_" ^ Ops.precision_to_string prec ^ "_uniform(") ^^ doc_ident source ^^ string ")" ^^ brackets (pp_indices idcs)
1093+
string ("uint4x32_to_" ^ Ops.prec_string prec ^ "_uniform(")
1094+
^^ doc_ident source ^^ string ")"
1095+
^^ brackets (pp_indices idcs)
10901096
| Get (tn, idcs) -> group (doc_ident tn ^^ brackets (pp_indices idcs))
10911097
| Constant c -> string (Printf.sprintf "%.16g" c)
10921098
| Embed_index idx -> pp_axis_index idx
@@ -1138,3 +1144,33 @@ let loop_over_dims dims ~body =
11381144
}
11391145
in
11401146
for_loop [] (Array.to_list dims)
1147+
1148+
let unroll_dims dims ~body =
1149+
if Array.is_empty dims then body [||] ~offset:0
1150+
else
1151+
(* Calculate strides for each dimension (rightmost changes fastest) *)
1152+
let strides = Array.create ~len:(Array.length dims) 1 in
1153+
for i = Array.length dims - 2 downto 0 do
1154+
strides.(i) <- strides.(i + 1) * dims.(i + 1)
1155+
done;
1156+
1157+
(* Generate all combinations of indices *)
1158+
let rec generate_all_combinations indices_so_far offset dim_index =
1159+
if dim_index >= Array.length dims then
1160+
(* We have a complete combination, call the body *)
1161+
body (Array.of_list_rev indices_so_far) ~offset
1162+
else
1163+
(* Generate all values for current dimension *)
1164+
let results = ref [] in
1165+
for i = 0 to dims.(dim_index) - 1 do
1166+
let new_offset = offset + (i * strides.(dim_index)) in
1167+
let result =
1168+
generate_all_combinations
1169+
(Indexing.Fixed_idx i :: indices_so_far)
1170+
new_offset (dim_index + 1)
1171+
in
1172+
results := result :: !results
1173+
done;
1174+
unflat_lines (List.rev !results)
1175+
in
1176+
generate_all_combinations [] 0 0

arrayjit/lib/low_level.mli

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ val apply_op : Ops.op -> float_t array -> float_t
6666
val flat_lines : t list -> t list
6767
val unflat_lines : t list -> t
6868
val loop_over_dims : int array -> body:(Indexing.axis_index array -> t) -> t
69+
val unroll_dims : int array -> body:(Indexing.axis_index array -> offset:int -> t) -> t
6970

7071
(** {2 Optimization} *)
7172

arrayjit/lib/lowering_and_inlining.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ TODO: flesh out explanation.
7676

7777
## Translation
7878

79-
The translation `Assignments.to_low_level` is straightforward. Commented code blocks are delineated by `Low_level.Comment "end"` statements. Indices into tensor nodes are derived from the `projections` fields by the `Indexing.derive_index` function. We translate `projections.product_space` elements into for loops. `to_low_level` returns all the data that `Low_level` optimizations generated, so that backends can make more informed decisions when jitting, i.e. emitting the backend-specific code.
79+
The translation `Assignments.to_low_level` is straightforward. Commented code blocks are delineated by `Low_level.Comment "end"` statements. Indices into tensor nodes are derived from the `projections` fields. We translate `projections.product_space` elements into for loops. `to_low_level` returns all the data that `Low_level` optimizations generated, so that backends can make more informed decisions when jitting, i.e. emitting the backend-specific code.
8080

8181
## Inlining
8282

arrayjit/test/test_numerical_types.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ let test_bfloat16_conversions () =
1616
(* Test round-trip through ndarray *)
1717
let arr =
1818
Ndarray.create_array ~debug:"test" Ops.bfloat16 ~dims:[| 3; 2 |] ~padding:None
19-
(Ops.Constant_fill { values = [| 1.0; 2.0; 3.14; -1.5; 0.125; 1000.0 |]; strict = true })
19+
(Assignments.Constant_fill [| 1.0; 2.0; 3.14; -1.5; 0.125; 1000.0 |])
2020
in
2121

2222
Stdio.printf "\nBFloat16 array values:\n";

lib/operation.ml

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -386,9 +386,10 @@ let embed_symbol ?(label = []) static_sym : Tensor.t =
386386

387387
let random_seed =
388388
let seed = Option.value ~default:42 @@ Utils.settings.fixed_state_for_init in
389-
let res = Tensor.term ~label:[ "random_seed" ] ~grad_spec:Prohibit_grad
390-
~fetch_op:(Asgns.Constant_fill { values = [| seed |]; strict = true })
391-
() in
389+
let res =
390+
Tensor.term ~label:[ "random_seed" ] ~grad_spec:Prohibit_grad
391+
~fetch_op:(Asgns.Constant_fill [| seed |]) ()
392+
in
392393
Tn.update_memory_mode res.value Tn.Effectively_constant 24;
393394
Tn.update_prec res.value Ir.Ops.uint4x32;
394395
ref res
@@ -462,21 +463,15 @@ module TDSL = struct
462463
let stop_gradient = stop_gradient
463464

464465
(** The input [i] dimensions default to empty. The batch dimensions will be inferred if omitted.
465-
[strict] controls whether [Constant_fill] will try to fit the given values in the tensor and
466-
contribute to shape inference. If it is not provided explicitly, it will be [true] if [b] is
467-
omitted, and [false] otherwise. *)
468-
let init_const ~l ?strict ?b ?(i = []) ~o values =
469-
let strict =
470-
match (strict, b) with Some s, _ -> s | None, Some _ -> false | None, None -> true
471-
in
466+
*)
467+
let init_const ~l ?b ?(i = []) ~o values =
472468
Tensor.term ~label:[ l ] ~grad_spec:Prohibit_grad ?batch_dims:b ~input_dims:i ~output_dims:o
473-
~fetch_op:(Constant_fill { values; strict })
474-
()
469+
~fetch_op:(Asgns.Constant_fill values) ()
475470

476471
(** It's like `Tensor.param` but without shape inference. *)
477472
let init_param ~l ?(b = []) ?(i = []) ?(o = []) values =
478473
Tensor.term ~label:[ l ] ~grad_spec:Require_grad ~batch_dims:b ~input_dims:i ~output_dims:o
479-
~fetch_op:(Constant_fill { values; strict = false })
474+
~fetch_op:(Asgns.Constant_fill values)
480475
()
481476
end
482477

lib/shape.ml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -394,16 +394,16 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
394394
[ Terminal_row cur_sh.batch; Terminal_row cur_sh.input; Terminal_row cur_sh.output ]
395395
in
396396
match logic with
397-
| Terminal (Range_over_offsets | Standard_uniform | Constant_fill { strict = false; _ }) ->
397+
| Terminal (Range_over_offsets | Standard_uniform) ->
398398
(Row.dim_map_empty, mark_terminal ())
399-
| Terminal (Constant_fill { values; strict = true }) ->
399+
| Terminal (Constant_fill values) ->
400400
let len = Array.length values in
401401
let io_dims =
402402
try List.map ~f:dim_to_int_exn @@ cur_sh.output.dims @ cur_sh.input.dims
403403
with Invalid_argument _ ->
404404
raise
405405
@@ Shape_error
406-
( "unify_shapes Constant_fill strict: non-batch dimensions must be known",
406+
( "unify_shapes Constant_fill: non-batch dimensions must be known",
407407
[ Shape_mismatch [ cur_sh ] ] )
408408
in
409409
let batch_elems = len / abs (List.fold ~init:1 ~f:( * ) io_dims) in
@@ -423,7 +423,7 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
423423
with Invalid_argument _ ->
424424
raise
425425
@@ Shape_error
426-
( "unify_shapes Constant_fill strict: non-batch dimensions must be known",
426+
( "unify_shapes Constant_fill: non-batch dimensions must be known",
427427
[ Shape_mismatch [ cur_sh ] ] )
428428
in
429429
let batch_elems = len / abs (List.fold ~init:1 ~f:( * ) io_dims) in

lib/shape_inference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ There is an important and intentional difference between `dims` in the `arrayjit
210210
Other important functions in the `Shape` module.
211211

212212
* `einsum_slot_spec_to_dims_bio ~generative` parses an einsum spec for a single shape, returns the three rows and a mapping from axis (`dim`) variables to indices where the einsum specifies fixed indexing. When `generative` is true for the kind of a row, when an axis has a fixed projection to dimension 0, the axis is not a variable added to the fixed indexing mapping, but is instead dimension-1 (solved). The "generative" rows are the ones with no initial user-provided shape information. This is just a heuristic to avoid surprises where a tensor axis with only dimension 0 populated gets inferred a bigger dimension size -- it might be revisited in the future.
213-
* `get_inequalities` builds row inequalities by pairing the rows of the current shape (as `cur`) with the rows of sub-shapes (as `subr`). It also derives a batch row constraint for terminals initialized with `Constant_fill { values; strict = true }` and `File_mapped (filename, prec)` (where the file is scanned to get its length). For `Batch_slice` (the `@|` operation) it waits till the batch row variables (if any) are solved, and derives row equations (not inequalities) between the current shape and the sub-shape, with `cur_sh.batch.dims` expanded to account for the slicing / indexing. For einsum specs, it derives inequalities, roughly: _current shape ≥ lhs spec shape_, and _rhs spec shape ≥ sub-shape_.
213+
* `get_inequalities` builds row inequalities by pairing the rows of the current shape (as `cur`) with the rows of sub-shapes (as `subr`). It also derives a batch row constraint for terminals initialized with `Constant_fill values` and `File_mapped (filename, prec)` (where the file is scanned to get its length). For `Batch_slice` (the `@|` operation) it waits till the batch row variables (if any) are solved, and derives row equations (not inequalities) between the current shape and the sub-shape, with `cur_sh.batch.dims` expanded to account for the slicing / indexing. For einsum specs, it derives inequalities, roughly: _current shape ≥ lhs spec shape_, and _rhs spec shape ≥ sub-shape_.
214214
* `propagate_shapes` gets and then solves the inequalities, using a global state for the environment. It updates the shapes in-place with the partial solution. It is invoked twice for each `update_step`: first during the bottom-up process of building tensors, and then in reverse order from `finish_inference`.
215215
* `finish_inference` is called right before some projections or array dimensions are required (typically, because of jitting). It performs a second round of `propagate_shapes`, and then once again attempts to solve any remaining constraints that `propagate_shapes` didn't solve. Then it "closes the shapes": substitutes out remaining shape variables by their LUBs if any, or dimension-1 / `Broadcastable` (no-more-axes). Then it resets the environment state, since the shapes are now guaranteed to not have variables.
216216
* `derive_projections` starts by freshening the `proj_id`s in the `update_step`. Then it generates and solves shape inequalities, and then generates and solves projection equations, and constructs the `projections` record.

0 commit comments

Comments
 (0)