Skip to content

Commit 04aebe8

Browse files
committed
In progress: preparations for threefry, get rid of File_mapped
1 parent f24a1e6 commit 04aebe8

File tree

8 files changed

+47
-91
lines changed

8 files changed

+47
-91
lines changed

arrayjit/lib/assignments.ml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,7 @@ let to_doc ?name ?static_indices () c =
362362
| Access (Low_level.Merge_buffer { source }) -> string (ident source ^ ".merge")
363363
| Access (Low_level.External_unsafe { ptr; prec; dims = _ }) ->
364364
string (Ops.ptr_to_string_hum ptr prec)
365-
| Access (Low_level.File_mapped (file, file_prec)) ->
366-
string ("file_mapped(\"" ^ file ^ "\", " ^ Ops.prec_string file_prec ^ ")")
367-
| Access (Low_level.Uint4x32_to_prec_uniform { source; prec = target_prec }) ->
365+
| Access (Low_level.Uint4x32_to_prec_uniform { source; target_prec; target_dims = _ }) ->
368366
string ("uint4x32_to_" ^ Ops.prec_string target_prec ^ "_uniform(" ^ ident source ^ ")")
369367
| Slice { batch_idx; sliced } ->
370368
string (ident sliced ^ " @| " ^ Indexing.symbol_ident batch_idx.static_symbol)

arrayjit/lib/c_syntax.ml

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -617,23 +617,14 @@ module C_syntax (B : C_syntax_config) = struct
617617
^^ offset_doc ^^ string "))" ^^ string postfix
618618
in
619619
(empty, expr)
620-
| Access (Low_level.File_mapped (file, source_prec), Some idcs) ->
621-
let prefix, postfix = B.convert_precision ~from:source_prec ~to_:prec in
622-
let expr =
623-
string prefix
624-
^^ string ("file_mapped_data_" ^ file ^ "[")
625-
^^ pp_array_offset (idcs, [||])
626-
^^ string "]" ^^ string postfix
627-
in
628-
(empty, expr)
629-
| Access (Low_level.Uint4x32_to_prec_uniform { source; prec = source_prec }, Some idcs) ->
620+
| Access (Low_level.Uint4x32_to_prec_uniform { source; target_prec; target_dims }, Some idcs) ->
630621
let tn = source in
631-
let prefix, postfix = B.convert_precision ~from:source_prec ~to_:prec in
632-
let offset_doc = pp_array_offset (idcs, Lazy.force tn.dims) in
622+
let prefix, postfix = B.convert_precision ~from:target_prec ~to_:prec in
623+
let offset_doc = pp_array_offset (idcs, Lazy.force target_dims) in
633624
let source_ident = string (get_ident tn) in
634625
let expr =
635626
string prefix
636-
^^ string ("uint4x32_to_" ^ Ops.prec_string source_prec ^ "_uniform(")
627+
^^ string ("uint4x32_to_" ^ Ops.prec_string target_prec ^ "_uniform(")
637628
^^ source_ident ^^ brackets offset_doc ^^ string ")" ^^ string postfix
638629
in
639630
(empty, expr)
@@ -736,34 +727,20 @@ module C_syntax (B : C_syntax_config) = struct
736727
string prefix ^^ string ("external[%u]{=" ^ B.float_log_style ^ "}") ^^ string postfix
737728
in
738729
(expr_doc, [ `Accessor (idcs, dims_val); `Value access_doc ])
739-
| Access (Low_level.File_mapped (file, source_prec), Some idcs) ->
740-
let prefix, postfix = B.convert_precision ~from:source_prec ~to_:prec in
741-
let access_doc =
742-
string prefix
743-
^^ string ("file_mapped_data_" ^ file ^ "[")
744-
^^ pp_array_offset (idcs, [||])
745-
^^ string "]" ^^ string postfix
746-
in
747-
let expr_doc =
748-
string prefix
749-
^^ string ("file_mapped_" ^ file ^ "[%u]{=" ^ B.float_log_style ^ "}")
750-
^^ string postfix
751-
in
752-
(expr_doc, [ `Accessor (idcs, [||]); `Value access_doc ])
753-
| Access (Low_level.Uint4x32_to_prec_uniform { source; prec = source_prec }, Some idcs) ->
730+
| Access (Low_level.Uint4x32_to_prec_uniform { source; target_prec; target_dims }, Some idcs) ->
754731
let tn = source in
755-
let prefix, postfix = B.convert_precision ~from:source_prec ~to_:prec in
756-
let dims = Lazy.force tn.dims in
732+
let prefix, postfix = B.convert_precision ~from:target_prec ~to_:prec in
733+
let dims = Lazy.force target_dims in
757734
let offset_doc = pp_array_offset (idcs, dims) in
758735
let source_ident = string (get_ident tn) in
759736
let access_doc =
760737
string prefix
761-
^^ string ("uint4x32_to_" ^ Ops.prec_string source_prec ^ "_uniform(")
738+
^^ string ("uint4x32_to_" ^ Ops.prec_string target_prec ^ "_uniform(")
762739
^^ source_ident ^^ brackets offset_doc ^^ string ")" ^^ string postfix
763740
in
764741
let expr_doc =
765742
string prefix
766-
^^ string ("uint4x32_to_" ^ Ops.prec_string source_prec ^ "_uniform(")
743+
^^ string ("uint4x32_to_" ^ Ops.prec_string target_prec ^ "_uniform(")
767744
^^ source_ident
768745
^^ brackets (string "%u")
769746
^^ string "){=" ^^ string B.float_log_style ^^ string "}" ^^ string postfix

arrayjit/lib/low_level.ml

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,12 @@ let _get_local_debug_runtime = Utils.get_local_debug_runtime
1010

1111
type dedicated_access =
1212
| C_function of string
13-
| External_unsafe of {
14-
ptr : Ops.voidptr;
15-
prec : (Ops.prec[@equal.ignore] [@compare.ignore]);
16-
dims : int array Lazy.t;
17-
}
13+
| External_unsafe of { ptr : Ops.voidptr; prec : Ops.prec; dims : int array Lazy.t }
1814
| Merge_buffer of { source : Tnode.t }
19-
| File_mapped of string * Ops.prec
2015
| Uint4x32_to_prec_uniform of {
2116
source : Tnode.t;
22-
prec : (Ops.prec[@equal.ignore] [@compare.ignore]);
17+
target_prec : Ops.prec;
18+
target_dims : int array Lazy.t;
2319
}
2420
[@@deriving sexp_of, equal, compare]
2521

@@ -100,9 +96,7 @@ let virtualize_settings =
10096
let max_tracing_dim =
10197
Int.of_string @@ Utils.get_global_arg ~arg_name:"virtualize_max_tracing_dim" ~default:"5"
10298
in
103-
let enable_device_only =
104-
Utils.get_global_flag ~default:true ~arg_name:"enable_device_only"
105-
in
99+
let enable_device_only = Utils.get_global_flag ~default:true ~arg_name:"enable_device_only" in
106100
let inline_scalar_constexprs =
107101
Utils.get_global_flag ~default:true ~arg_name:"inline_scalar_constexprs"
108102
in
@@ -1083,16 +1077,11 @@ let to_doc_cstyle ?name ?static_indices () llc =
10831077
| Access (Merge_buffer { source }, None) -> doc_ident source ^^ string ".merge"
10841078
| Access (Merge_buffer { source }, Some idcs) ->
10851079
group (doc_ident source ^^ string ".merge" ^^ brackets (pp_indices idcs))
1086-
| Access (File_mapped (file, prec), None) ->
1087-
string ("file_mapped(\"" ^ file ^ "\", " ^ Ops.prec_string prec ^ ")")
1088-
| Access (File_mapped (file, prec), Some idcs) ->
1089-
string ("file_mapped(\"" ^ file ^ "\", " ^ Ops.prec_string prec ^ ")")
1090-
^^ brackets (pp_indices idcs)
1091-
| Access (Uint4x32_to_prec_uniform { source; prec }, None) ->
1092-
string ("uint4x32_to_" ^ Ops.prec_string prec ^ "_uniform(")
1080+
| Access (Uint4x32_to_prec_uniform { source; target_prec; target_dims = _ }, None) ->
1081+
string ("uint4x32_to_" ^ Ops.prec_string target_prec ^ "_uniform(")
10931082
^^ doc_ident source ^^ string ")"
1094-
| Access (Uint4x32_to_prec_uniform { source; prec }, Some idcs) ->
1095-
string ("uint4x32_to_" ^ Ops.prec_string prec ^ "_uniform(")
1083+
| Access (Uint4x32_to_prec_uniform { source; target_prec; target_dims = _ }, Some idcs) ->
1084+
string ("uint4x32_to_" ^ Ops.prec_string target_prec ^ "_uniform(")
10961085
^^ doc_ident source ^^ string ")"
10971086
^^ brackets (pp_indices idcs)
10981087
| Get (tn, idcs) -> group (doc_ident tn ^^ brackets (pp_indices idcs))
@@ -1170,16 +1159,11 @@ let to_doc ?name ?static_indices () llc =
11701159
| Access (Merge_buffer { source }, None) -> doc_ident source ^^ string ".merge"
11711160
| Access (Merge_buffer { source }, Some idcs) ->
11721161
group (doc_ident source ^^ string ".merge" ^^ brackets (pp_indices idcs))
1173-
| Access (File_mapped (file, prec), None) ->
1174-
string ("file_mapped(\"" ^ file ^ "\", " ^ Ops.prec_string prec ^ ")")
1175-
| Access (File_mapped (file, prec), Some idcs) ->
1176-
string ("file_mapped(\"" ^ file ^ "\", " ^ Ops.prec_string prec ^ ")")
1177-
^^ brackets (pp_indices idcs)
1178-
| Access (Uint4x32_to_prec_uniform { source; prec }, None) ->
1179-
string ("uint4x32_to_" ^ Ops.prec_string prec ^ "_uniform(")
1162+
| Access (Uint4x32_to_prec_uniform { source; target_prec; target_dims = _ }, None) ->
1163+
string ("uint4x32_to_" ^ Ops.prec_string target_prec ^ "_uniform(")
11801164
^^ doc_ident source ^^ string ")"
1181-
| Access (Uint4x32_to_prec_uniform { source; prec }, Some idcs) ->
1182-
string ("uint4x32_to_" ^ Ops.prec_string prec ^ "_uniform(")
1165+
| Access (Uint4x32_to_prec_uniform { source; target_prec; target_dims = _ }, Some idcs) ->
1166+
string ("uint4x32_to_" ^ Ops.prec_string target_prec ^ "_uniform(")
11831167
^^ doc_ident source ^^ string ")"
11841168
^^ brackets (pp_indices idcs)
11851169
| Get (tn, idcs) -> group (doc_ident tn ^^ brackets (pp_indices idcs))

arrayjit/lib/low_level.mli

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,20 @@ open Base
77
(** A dedicated access that might need to be implemented differently for each backend. *)
88
type dedicated_access =
99
| C_function of string (** Calls a no-argument or indices-arguments C function. *)
10-
| External_unsafe of {
11-
ptr : Ops.voidptr;
12-
prec : (Ops.prec[@equal.ignore] [@compare.ignore]);
13-
dims : int array Lazy.t;
14-
}
10+
| External_unsafe of { ptr : Ops.voidptr; prec : Ops.prec; dims : int array Lazy.t }
1511
| Merge_buffer of { source : Tnode.t }
1612
(** Each device has at most one merge buffer, which is re-used, and re-allocated as needed, by
1713
merge operations. The merge buffer is associated with the source node of the device's most
1814
recent [device_to_device ~into_merge_buffer:true] operation. *)
19-
| File_mapped of string * Ops.prec
20-
(** Reads the data using [Unix.openfile] and [Unix.map_file]. *)
2115
| Uint4x32_to_prec_uniform of {
2216
source : Tnode.t;
23-
prec : (Ops.prec[@equal.ignore] [@compare.ignore]);
17+
target_prec : Ops.prec;
18+
target_dims : int array Lazy.t;
2419
}
2520
(** Converts the given Uint4x32 to the given precision in a bit-efficient manner. For random
2621
bits, the result is uniform over the range of the precision for integer precisions, and
27-
over the range [[0.0, 1.0)] for floating point precisions. *)
22+
over the range \[0.0, 1.0) for floating point precisions. When used in an access pattern,
23+
the indices are converted to a byte offset depending on the given precision. *)
2824
[@@deriving sexp_of, equal, compare]
2925

3026
module Scope_id : sig

arrayjit/lib/ops.ml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,15 @@ let interpret_binop op v1 v2 =
343343
(* | Shr -> v1 / (int_pow 2. @@ to_int v2) *)
344344
| Or -> if v1 <> 0. || v2 <> 0. then 1. else 0.
345345
| And -> if v1 <> 0. && v2 <> 0. then 1. else 0.
346-
| Threefry4x32 -> invalid_arg "interpret_binop: Threefry4x32 requires hardware implementation"
346+
| Threefry4x32 ->
347+
(* NOTE: the purpose of this code is to reflect the reference implementation that all backends
348+
should implement. Due to precision constraints and the fact that threefry is inherently a
349+
non-numerical, bit-level operation, this code reinterprets the arguments and the result as
350+
the lower 64 bits of the 128-bit values; with the upper 64 bits being 0 for the arguments
351+
and ignored for the result. This agrees with the Bigarray reinterpretation of the
352+
[Uint4x32] precision as the [Complex.t] type with the real part exposed. *)
353+
(* FIXME: NOT IMPLEMENTED YET *)
354+
failwith "NOT IMPLEMENTED YET: Threefry4x32"
347355

348356
let interpret_unop op v =
349357
let open Float in
@@ -461,7 +469,9 @@ let binop_c_syntax prec v =
461469
(* | Shr, _ -> ("((", ") / exp2(", "))") *)
462470
| Or, _ -> ("(", " ||", ")")
463471
| And, _ -> ("(", " &&", ")")
464-
| Threefry4x32, _ -> ("threefry4x32(", ",", ")")
472+
| Threefry4x32, _ ->
473+
(* This corresponds to the pure C implementation in arrayjit_stubs.c. *)
474+
("arrayjit_threefry4x32(", ",", ")")
465475

466476
let is_assign_op = function
467477
| Arg1 | Mod | Threefry4x32 (* | Shl | Shr *) | Cmplt | Cmpeq | Cmpne -> false

lib/shape.ml

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -442,18 +442,11 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
442442
| Terminal (`Fetch (Access (C_function _))) -> (Row.dim_map_empty, mark_terminal ())
443443
| Terminal (`Fetch (Access (External_unsafe _))) -> (Row.dim_map_empty, mark_terminal ())
444444
| Terminal (`Fetch (Access (Merge_buffer _))) -> (Row.dim_map_empty, mark_terminal ())
445-
| Terminal (`Fetch (Access (Uint4x32_to_prec_uniform _))) -> (Row.dim_map_empty, mark_terminal ())
446-
| Terminal (`Fetch (Access (File_mapped (filename, prec)))) ->
447-
let fd = Unix.openfile filename [ Unix.O_RDONLY ] 0o640 in
448-
let len = Unix.lseek fd 0 Unix.SEEK_END / Ir.Ops.prec_in_bytes prec in
449-
Unix.close fd;
450-
( Row.dim_map_empty,
451-
Rows_constr
452-
{
453-
r = [ cur_sh.batch; cur_sh.output; cur_sh.input ];
454-
constr = Total_elems { nominator = len; divided_by = dim_var_set_empty };
455-
}
456-
:: mark_terminal () )
445+
| Terminal (`Fetch (Access (Uint4x32_to_prec_uniform _))) ->
446+
(* FIXME: NOT IMPLEMENTED YET -- we need to propagate the precision-adjusted dimensions
447+
between the source tensor and the target tensor. This is tricky because the dimensions
448+
are not known at the time of the shape inference. *)
449+
(Row.dim_map_empty, mark_terminal ())
457450
| Terminal (`Fetch (Slice { sliced = tn; batch_idx = _ })) ->
458451
if Lazy.is_val tn.dims then
459452
( dim_map_empty,

lib/shape.mli

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,7 @@ type logic =
149149
[s1], hence the name. *)
150150
| Broadcast_tern of ternary_type * t * t * t (** Matches the shapes for a ternary operation. *)
151151
| Terminal of [ `Data of Ir.Assignments.init_data | `Fetch of Ir.Assignments.fetch_op ]
152-
(** Extracts any available shape information from the initialization. E.g. for
153-
[`Fetch (File_mapped fn)], opens the file [fn] to check its length. *)
152+
(** Extracts any available shape information from the initialization. *)
154153
[@@deriving equal, sexp_of]
155154

156155
type update_id [@@deriving equal, compare, hash, sexp]

lib/shape_inference.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,7 @@ type logic =
126126
(** Permutes the axes of a shape. One case of [Transpose] is to swap inputs with outputs of [s1],
127127
hence the name. *)
128128
| Terminal of [ `Data of Ir.Assignments.init_data | `Fetch of Ir.Assignments.fetch_op ]
129-
(** Extracts any available shape information from the initialization. E.g.
130-
for [File_mapped fn], opens the file [fn] to check its length. *)
129+
(** Extracts any available shape information from the initialization, e.g. the number of elements. *)
131130
```
132131

133132
### Non-tensor-like constraints
@@ -208,7 +207,7 @@ There is an important and intentional difference between `dims` in the `arrayjit
208207
Other important functions in the `Shape` module.
209208

210209
* `einsum_slot_spec_to_dims_bio ~generative` parses an einsum spec for a single shape and returns the three rows together with a mapping from axis (`dim`) variables to indices where the einsum specifies fixed indexing. When `generative` is true for the kind of a row and an axis has a fixed projection to dimension 0, the axis is not added as a variable to the fixed indexing mapping, but is instead dimension-1 (solved). The "generative" rows are the ones with no initial user-provided shape information. This is just a heuristic to avoid surprises where a tensor axis with only dimension 0 populated is inferred to have a bigger dimension size -- it might be revisited in the future.
211-
* `get_inequalities` builds row inequalities by pairing the rows of the current shape (as `cur`) with the rows of sub-shapes (as `subr`). It also derives a batch row constraint for terminals initialized with `Constant_fill values` and `File_mapped (filename, prec)` (where the file is scanned to get its length). For `Batch_slice` (the `@|` operation) it waits till the batch row variables (if any) are solved, and derives row equations (not inequalities) between the current shape and the sub-shape, with `cur_sh.batch.dims` expanded to account for the slicing / indexing. For einsum specs, it derives inequalities, roughly: _current shape ≥ lhs spec shape_, and _rhs spec shape ≥ sub-shape_.
210+
* `get_inequalities` builds row inequalities by pairing the rows of the current shape (as `cur`) with the rows of sub-shapes (as `subr`). It also derives a batch row constraint for terminals initialized with `Constant_fill values`. For `Batch_slice` (the `@|` operation) it waits till the batch row variables (if any) are solved, and derives row equations (not inequalities) between the current shape and the sub-shape, with `cur_sh.batch.dims` expanded to account for the slicing / indexing. For einsum specs, it derives inequalities, roughly: _current shape ≥ lhs spec shape_, and _rhs spec shape ≥ sub-shape_.
212211
* `propagate_shapes` gets and then solves the inequalities, using a global state for the environment. It updates the shapes in-place with the partial solution. It is invoked twice for each `update_step`: first during the bottom-up process of building tensors, and then in reverse order from `finish_inference`.
213212
* `finish_inference` is called right before some projections or array dimensions are required (typically, because of jitting). It performs a second round of `propagate_shapes`, and then once again attempts to solve any remaining constraints that `propagate_shapes` didn't solve. Then it "closes the shapes": substitutes out remaining shape variables by their LUBs if any, or dimension-1 / `Broadcastable` (no-more-axes). Then it resets the environment state, since the shapes are now guaranteed to not have variables.
214213
* `derive_projections` starts by freshening the `proj_id`s in the `update_step`. Then it generates and solves shape inequalities, and then generates and solves projection equations, and constructs the `projections` record.

0 commit comments

Comments
 (0)