Commit ad9a53e
Big refactoring: Uint4x32_to_prec_uniform moves from a fetch op to a proper unary op (Ops), with dedicated shape and projections inference support to follow; the remaining dedicated_access fetch ops are removed, with Merge_buffer migrating to a stand-alone Get_merge_buffer variant of float_t (Low_level); and the new terminal_type (Shape) brings better consistency.
- Introduced a new `uint4x32_t` structure and a stub for the `arrayjit_threefry4x32` function.
- Updated the `float_t` type to include `Get_merge_buffer` and removed the `dedicated_access` type.
- Added the `Uint4x32_to_prec_uniform` operation in `ops.ml` and updated related type definitions in `shape.ml` and `shape.mli`.
- Modified tensor operation signatures to include a new `terminal_op` parameter for better expressivity.
1 parent 04aebe8 commit ad9a53e

File tree: 7 files changed, +46 −39 lines


arrayjit/lib/arrayjit_stubs.c

Lines changed: 17 additions & 1 deletion
@@ -113,6 +113,22 @@ static inline uint8_t float_to_fp8(float f)
   return (uint8_t)((sign << 7) | ((exp & 0x1F) << 2) | (mant_bits & 0x3));
 }
 
+typedef struct {
+  uint32_t v[4];
+} uint4x32_t;
+
+/* Threefry4x32 implementation (C function) */
+uint4x32_t arrayjit_threefry4x32(uint4x32_t v1, uint4x32_t v2)
+{
+  /* FIXME: NOT IMPLEMENTED YET */
+  uint4x32_t result;
+  result.v[0] = 0;
+  result.v[1] = 0;
+  result.v[2] = 0;
+  result.v[3] = 0;
+  return result;
+}
+
 /* OCaml wrapper functions */
 
 /* BFloat16 to Float conversion (OCaml wrapper) */
@@ -188,7 +204,7 @@ CAMLprim value arrayjit_copy_with_padding(value v_source, value v_target,
     source_total *= source_dims_ba[i];
   }
 
-  /* FIXME: Simple memcpy for now - must be optimized later for proper padding */
+  /* FIXME: Simple memcpy for now - must implement proper padding */
   memcpy(target_data, source_data, source_total * elem_size);
 
   CAMLreturn(Val_unit);
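
The committed `arrayjit_threefry4x32` is deliberately a zero-returning stub. For orientation, here is a standalone OCaml sketch of the Threefry-4x32-20 algorithm the stub is presumably meant to compute eventually, following the public Random123 description; the rotation constants, the key schedule, and the assumption that the two arguments act as key and counter are recalled from that reference and should be verified against it before use.

let threefry4x32_20 (key : int32 array) (ctr : int32 array) : int32 array =
  let ( +% ) = Int32.add and ( ^% ) = Int32.logxor in
  let rotl x r = Int32.logor (Int32.shift_left x r) (Int32.shift_right_logical x (32 - r)) in
  (* Per-round rotation amounts, cycled mod 8; the first of each pair mixes lanes (0,1), the second (2,3). *)
  let rot = [| (10, 26); (11, 21); (13, 27); (23, 5); (6, 20); (17, 11); (25, 10); (18, 20) |] in
  (* Extended key schedule: the fifth word is the Skein parity constant XORed with all key words. *)
  let ks = Array.append key [| Array.fold_left ( ^% ) 0x1BD11BDAl key |] in
  (* Initial key injection into the counter words. *)
  let x = Array.mapi (fun i c -> c +% ks.(i)) ctr in
  for r = 0 to 19 do
    let ra, rb = rot.(r mod 8) in
    x.(0) <- x.(0) +% x.(1);
    x.(1) <- rotl x.(1) ra ^% x.(0);
    x.(2) <- x.(2) +% x.(3);
    x.(3) <- rotl x.(3) rb ^% x.(2);
    (* Word permutation (0 3 2 1): swap lanes 1 and 3 between rounds. *)
    let t = x.(1) in x.(1) <- x.(3); x.(3) <- t;
    if (r + 1) mod 4 = 0 then begin
      (* Key injection every four rounds; the injection counter is also added to lane 3. *)
      let s = (r + 1) / 4 in
      for i = 0 to 3 do x.(i) <- x.(i) +% ks.((s + i) mod 5) done;
      x.(3) <- x.(3) +% Int32.of_int s
    end
  done;
  x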

arrayjit/lib/low_level.ml

Lines changed: 1 addition & 12 deletions
@@ -8,17 +8,6 @@ let _get_local_debug_runtime = Utils.get_local_debug_runtime
 [%%global_debug_log_level 9]
 [%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
 
-type dedicated_access =
-  | C_function of string
-  | External_unsafe of { ptr : Ops.voidptr; prec : Ops.prec; dims : int array Lazy.t }
-  | Merge_buffer of { source : Tnode.t }
-  | Uint4x32_to_prec_uniform of {
-      source : Tnode.t;
-      target_prec : Ops.prec;
-      target_dims : int array Lazy.t;
-    }
-[@@deriving sexp_of, equal, compare]
-
 module Scope_id = struct
   type t = { tn : Tn.t; scope_id : int } [@@deriving sexp_of, equal, hash, compare]
 
@@ -53,8 +42,8 @@ type t =
 and float_t =
   | Local_scope of { id : scope_id; body : t; orig_indices : Indexing.axis_index array }
   | Get_local of scope_id
-  | Access of dedicated_access * Indexing.axis_index array option
   | Get of Tn.t * Indexing.axis_index array
+  | Get_merge_buffer of Tn.t * Indexing.axis_index array
   | Ternop of Ops.ternop * float_t * float_t * float_t
   | Binop of Ops.binop * float_t * float_t
   | Unop of Ops.unop * float_t
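
For downstream code, the practical effect is that what was written roughly as `Access (Merge_buffer { source }, Some idcs)` becomes the first-class `Get_merge_buffer (source, idcs)`, mirroring the existing `Get` case. A minimal hypothetical sketch, not part of this commit, assuming the module paths used inside arrayjit and that `Ops.binop` has an `Add` constructor, of building a `float_t` expression that adds the merge buffer's contribution to a destination cell:

let merge_add (dst : Tnode.t) (source : Tnode.t) (idcs : Indexing.axis_index array) :
    Low_level.float_t =
  let open Low_level in
  (* Read the destination cell and the device's merge buffer (associated with [source])
     at the same indices, and add them. *)
  Binop (Ops.Add, Get (dst, idcs), Get_merge_buffer (source, idcs))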

arrayjit/lib/low_level.mli

Lines changed: 1 addition & 20 deletions
@@ -4,25 +4,6 @@ open Base
 
 (** {2 Global references} *)
 
-(** A dedicated access that might need to be implemented differently for each backend. *)
-type dedicated_access =
-  | C_function of string (** Calls a no-argument or indices-arguments C function. *)
-  | External_unsafe of { ptr : Ops.voidptr; prec : Ops.prec; dims : int array Lazy.t }
-  | Merge_buffer of { source : Tnode.t }
-      (** Each device has at most one merge buffer, which is re-used, and re-allocated as needed, by
-          merge operations. The merge buffer is associated with the source node of the device's most
-          recent [device_to_device ~into_merge_buffer:true] operation. *)
-  | Uint4x32_to_prec_uniform of {
-      source : Tnode.t;
-      target_prec : Ops.prec;
-      target_dims : int array Lazy.t;
-    }
-      (** Converts the given Uint4x32 to the given precision in a bit-efficient manner. For random
-          bits, the result is uniform over the range of the precision for integer precisions, and
-          over the range \[0.0, 1.0) for floating point precisions. When used in an access pattern,
-          the indices are converted to a byte offset depending on the given precision. *)
-[@@deriving sexp_of, equal, compare]
-
 module Scope_id : sig
   type t = { tn : Tnode.t; scope_id : int } [@@deriving sexp_of, equal, hash, compare]
   type comparator_witness
@@ -50,8 +31,8 @@ type t =
 and float_t =
   | Local_scope of { id : scope_id; body : t; orig_indices : Indexing.axis_index array }
   | Get_local of scope_id
-  | Access of dedicated_access * Indexing.axis_index array option
   | Get of Tnode.t * Indexing.axis_index array
+  | Get_merge_buffer of Tnode.t * Indexing.axis_index array
   | Ternop of Ops.ternop * float_t * float_t * float_t
   | Binop of Ops.binop * float_t * float_t
   | Unop of Ops.unop * float_t

arrayjit/lib/ops.ml

Lines changed: 7 additions & 0 deletions
@@ -295,6 +295,13 @@ type unop =
   | Neg
   | Tanh_approx
   | Not (** 0. -> 1. | _ -> 0. *)
+  | Uint4x32_to_prec_uniform of prec
+      (** Converts the given Uint4x32 to the given precision in a bit-efficient manner. For random
+          bits, the result is uniform over the range of the precision for integer precisions, and
+          over the range \[0.0, 1.0) for floating point precisions. When used in an access pattern,
+          the indices are converted to a byte offset depending on the given precision. NOTE: this
+          operation, unlike any others, impacts projections and shape inference (one input cell
+          corresponds to a few output cells). *)
 [@@deriving sexp, compare, equal]
 
 type ternop =
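
To make the NOTE concrete: each input cell holds 128 random bits (a Uint4x32), so a bit-efficient conversion presumably emits 16 / (bytes per element of the target precision) output cells per input cell, which is why this unop alone has to influence shape and projections inference. A rough standalone sketch of that bookkeeping, using illustrative byte widths rather than the project's `prec` API:

(* Illustrative only: output cells produced per 16-byte Uint4x32 input cell. *)
let cells_per_uint4x32 ~target_bytes = 16 / target_bytes

let () =
  (* e.g. 4 single-precision floats, or 8 half-precision floats, per Uint4x32 cell *)
  assert (cells_per_uint4x32 ~target_bytes:4 = 4);
  assert (cells_per_uint4x32 ~target_bytes:2 = 8)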

lib/shape.ml

Lines changed: 6 additions & 0 deletions
@@ -94,8 +94,14 @@ type transpose_type =
   | Pointwise_un
   | Permute of string
   | Batch_slice of Idx.static_symbol
+  | Uint4x32_to_prec of Ir.Ops.prec Lazy.t
 [@@deriving equal, sexp]
 
+type terminal_type =
+  | Data of Ir.Assignments.init_data
+  | Fetch of Ir.Assignments.fetch_op
+[@@deriving equal, sexp_of]
+
 type ternary_type = Pointwise_tern | Compose_accumulate [@@deriving sexp, equal]
 
 let identifier ~multichar =

lib/shape.mli

Lines changed: 9 additions & 1 deletion
@@ -101,6 +101,10 @@ type transpose_type =
   | Pointwise_un (** Preserves the shape. *)
   | Permute of string (** The unary "einsum" syntax: RHS1=>LHS. *)
   | Batch_slice of Ir.Indexing.static_symbol (** Removes the leftmost batch axis. *)
+  | Uint4x32_to_prec of Ir.Ops.prec Lazy.t
+      (** Converts precision in a bit-efficient way, with a corresponding conversion in total number
+          of elements. Currently, assumes the incoming tensor (RHS) has just a single axis to not
+          force unnecessary minimum sizes on output axes. *)
 [@@deriving equal, sexp]
 
 (** If you miss expressivity here, leave a note on
@@ -110,6 +114,10 @@ type ternary_type =
   | Compose_accumulate (** As in the operation [FMA]. *)
 [@@deriving equal, sexp]
 
+(** Extracts any available shape information from the initialization or fetch. *)
+type terminal_type = Data of Ir.Assignments.init_data | Fetch of Ir.Assignments.fetch_op
+[@@deriving equal, sexp_of]
+
 val make :
   ?batch_dims:int list ->
   ?input_dims:int list ->
@@ -148,7 +156,7 @@ type logic =
       (** Permutes the axes of a shape. One case of [Transpose] is to swap inputs with outputs of
          [s1], hence the name. *)
   | Broadcast_tern of ternary_type * t * t * t (** Matches the shapes for a ternary operation. *)
-  | Terminal of [ `Data of Ir.Assignments.init_data | `Fetch of Ir.Assignments.fetch_op ]
+  | Terminal of terminal_type
      (** Extracts any available shape information from the initialization. *)
 [@@deriving equal, sexp_of]
 
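
With `terminal_type` now a named type, the `Terminal` case of `logic` carries `Data ...` or `Fetch ...` rather than the previous inline polymorphic variants, so call sites move from `Terminal (`Fetch fetch_op)` to `Terminal (Fetch fetch_op)`. A small hypothetical illustration, not part of this commit, of constructing such a value from outside the module:

let terminal_of_fetch (fetch_op : Ir.Assignments.fetch_op) : Shape.logic =
  Shape.Terminal (Shape.Fetch fetch_op)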

lib/tensor.mli

Lines changed: 5 additions & 5 deletions
@@ -129,19 +129,19 @@ val op :
   ?ternary_op:Shape.ternary_type ->
   ?compose_op:Shape.compose_type ->
   ?transpose_op:Shape.transpose_type ->
-  ?init_data:Ir.Assignments.init_data ->
-  ?fetch_op:fetch_op ->
+  ?terminal_op:Shape.terminal_type ->
   op_asn:(v:tn -> projections:projections Lazy.t -> comp) ->
   grad_asn:(t:t -> g:tn -> projections:projections Lazy.t -> comp) ->
   ?grad_spec:grad_spec ->
   (debug_name:string -> id:int -> Shape.t) ->
   t list ->
   t
-(** At most one of [?ternary_op] or [?compose_op] or [?transpose_op] or [?init_data] or [?fetch_op]
-    should be provided, except when the operation takes more than three arguments which uses both
+(** At most one of [?ternary_op] or [?compose_op] or [?transpose_op] or [?terminal_op] should be
+    provided, except when the operation takes more than three arguments which uses both
     [?compose_op] or [?transpose_op]. The defaults are pointwise operations. The [grad_asn] function
     receives the non-differentiable variant of the tensor as an argument, which can be used to
-    access the tensor's value in a tensor expression. *)
+    access the tensor's value in a tensor expression. The [terminal_op] is used to specify the
+    terminal operation of the tensor. *)
 
 val binop :
   label:string list ->
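
A hypothetical call site, not part of this commit, showing the effect of the signature change: what was previously passed as `?init_data` or `?fetch_op` is now wrapped in `Shape.terminal_type` and passed as the single `?terminal_op`. The `make_shape`, `op_asn` and `grad_asn` arguments are placeholders, and any required parameters of `op` outside this diff hunk are omitted.

let fetched ~fetch_op ~op_asn ~grad_asn make_shape inputs =
  Tensor.op ~terminal_op:(Shape.Fetch fetch_op) ~op_asn ~grad_asn make_shape inputs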
