
Commit 2a15301

In progress: embedding of dimensions in tensor expressions
1 parent 06744a8 commit 2a15301

10 files changed: +92 -26 lines changed

arrayjit/lib/assignments.ml

Lines changed: 6 additions & 0 deletions
@@ -37,6 +37,7 @@ type fetch_op =
   | Slice of { batch_idx : Indexing.static_symbol; sliced : Tn.t }
   | Embed_symbol of Indexing.static_symbol
   | Embed_self_id  (** Embeds the id of the [array] field of the [Fetch] constructor. *)
+  | Embed_dim of Indexing.variable_ref
 [@@deriving sexp_of, equal]
 
 type accum_rhs =
@@ -345,6 +346,10 @@ let%track4_sexp to_low_level code =
     | Fetch { array; fetch_op = Embed_self_id; dims } ->
         Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs ->
            set array idcs @@ Constant_bits (Int64.of_int array.id))
+    | Fetch { array; fetch_op = Embed_dim variable_ref; dims } ->
+        (* Note: we are guaranteed all shape inference is forced before we access variable_ref. *)
+        Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs ->
+            set array idcs @@ Constant (Float.of_int @@ Option.value_exn variable_ref.solved_dim))
     | Fetch { array; fetch_op = Range_over_offsets; dims = (lazy dims) } ->
         Low_level.loop_over_dims dims ~body:(fun idcs ->
            let offset = Indexing.reflect_projection ~dims ~projection:idcs in
@@ -443,6 +448,7 @@ let to_doc ?name ?static_indices () c =
     | Embed_symbol { static_symbol; static_range = _ } ->
         string ("!@" ^ Indexing.symbol_ident static_symbol)
     | Embed_self_id -> string "!@self_id"
+    | Embed_dim { ref_label; _ } -> string ("(dim " ^ ref_label ^ ")")
   in
 
   let rec doc_of_code = function
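
For orientation, here is a minimal standalone sketch (plain OCaml, not OCANNL code) of what the new `Embed_dim` lowering amounts to: read the dimension that shape inference solved into a `variable_ref` and fill the fetched array with it as a constant. The record mirrors `Indexing.variable_ref` from the next file; the hypothetical `embed_dim_into` helper and the plain float array stand in for `Low_level.loop_over_dims` over the tensor node.

```ocaml
(* Standalone model; the record mirrors Indexing.variable_ref below. *)
type variable_ref = { ref_label : string; mutable solved_dim : int option }

(* Analogue of the Embed_dim case above: by the time the fetch runs, shape
   inference has set solved_dim, and every cell receives that constant. *)
let embed_dim_into (r : variable_ref) (dst : float array) : unit =
  match r.solved_dim with
  | None -> invalid_arg ("unsolved dimension variable: " ^ r.ref_label)
  | Some d -> Array.fill dst 0 (Array.length dst) (Float.of_int d)

let () =
  let r = { ref_label = "c"; solved_dim = Some 16 } in
  let dst = Array.make 4 0. in
  embed_dim_into r dst;
  assert (Array.for_all (fun x -> x = 16.) dst)
```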

arrayjit/lib/indexing.ml

Lines changed: 3 additions & 0 deletions
@@ -302,6 +302,9 @@ let reflect_projection ~(dims : int array) ~(projection : axis_index array) =
         | Sub_axis -> (stride * dim, symbols, offset))
   |> fun (_, symbols, offset) -> Affine { symbols; offset }
 
+type variable_ref = { ref_label : string; mutable solved_dim : int option }
+[@@deriving sexp_of, equal]
+
 module Doc_helpers = struct
   let ( ^^ ) = PPrint.( ^^ )
   let ( !^ ) = PPrint.( !^ )

lib/operation.ml

Lines changed: 17 additions & 9 deletions
@@ -85,35 +85,37 @@ let matmul ?(label = []) =
 
     Note that ["a,b->c"] from [numpy] is ["a;b=>c"] in OCANNL, since ["->"] is used to separate the
     input and the output axes. *)
-let einsum ?(label = []) spec =
+let einsum ?(label = []) ?(capture_dims = []) spec =
   let module NTDSL = Initial_NTDSL in
   let%cd op_asn ~v ~t1 ~t2 ~projections = v =:+ v1 * v2 in
   let%cd grad_asn ~t:_ ~g ~t1 ~t2 ~projections =
     g1 =+ g * v2;
     g2 =+ v1 * g
   in
-  Tensor.binop ~label:(";=>" :: label) ~compose_op:(Einsum spec) ~op_asn ~grad_asn
+  Tensor.binop ~label:(";=>" :: label) ~compose_op:(Einsum (spec, capture_dims)) ~op_asn ~grad_asn
 
 (** Like [einsum], but adds instead than multiplying the resulting values. *)
-let outer_sum ?(label = []) spec =
+let outer_sum ?(label = []) ?(capture_dims = []) spec =
   let module NTDSL = Initial_NTDSL in
   let%cd op_asn ~v ~t1 ~t2 ~projections = v =:+ v1 + v2 in
   let%cd grad_asn ~t:_ ~g ~t1 ~t2 ~projections =
     g1 =+ g;
     g2 =+ g
   in
-  Tensor.binop ~label:(";=>+" :: label) ~compose_op:(Einsum spec) ~op_asn ~grad_asn
+  Tensor.binop ~label:(";=>+" :: label) ~compose_op:(Einsum (spec, capture_dims)) ~op_asn ~grad_asn
 
 (** Similar to the explicit mode of [numpy.einsum], the unary variant. Can permute axes, extract
     diagonals, compute traces etc.
 
     Note that ["a->c"] from [numpy] is ["a=>c"] in OCANNL, since ["->"] is used to separate the
     input and the output axes. *)
-let einsum1 ?(label = []) spec =
+let einsum1 ?(label = []) ?(capture_dims = []) spec =
   let module NTDSL = Initial_NTDSL in
   let%cd op_asn ~v ~t1 ~projections = v =:+ v1 in
   let%cd grad_asn ~t:_ ~g ~t1 ~projections = g1 =+ g in
-  Tensor.unop ~transpose_op:(Shape.Permute spec) ~op_asn ~grad_asn ~label:("=>" :: label)
+  Tensor.unop
+    ~transpose_op:(Shape.Permute (spec, capture_dims))
+    ~op_asn ~grad_asn ~label:("=>" :: label)
 
 module NDO_before_pow = struct
   let ( * ) t1 t2 = matmul ~grad_spec:Prohibit_grad t1 t2 ()
@@ -455,11 +457,15 @@ let slice (batch_idx : Idx.static_symbol) =
   Tensor.unop ~transpose_op:(Batch_slice batch_idx) ~op_asn ~grad_asn ~label:("@|" :: label)
 
 let embed_symbol ?grad_spec ?(label = []) static_sym =
-  Tensor.term ~fetch_op:(Embed_symbol static_sym) ?grad_spec ~label:("!@" :: label)
-    ~batch_dims:[] ~input_dims:[] ~output_dims:[ 1 ] ()
+  Tensor.term ~fetch_op:(Embed_symbol static_sym) ?grad_spec ~label:("!@" :: label) ~batch_dims:[]
+    ~input_dims:[] ~output_dims:[ 1 ] ()
 
 let embed_self_id ?grad_spec ?(label = []) () =
-  Tensor.term ~fetch_op:Embed_self_id ?grad_spec ~label:("!@self_id" :: label)
+  Tensor.term ~fetch_op:Embed_self_id ?grad_spec ~label:("!@self_id" :: label) ~batch_dims:[]
+    ~input_dims:[] ~output_dims:[ 1 ] ()
+
+let embed_dim ?grad_spec ?(label = []) variable_ref =
+  Tensor.term ~fetch_op:(Embed_dim variable_ref) ?grad_spec ~label:("!@self_id" :: label)
     ~batch_dims:[] ~input_dims:[] ~output_dims:[ 1 ] ()
 
 let uniform ?grad_spec () =
@@ -590,6 +596,7 @@ struct
     Tensor.number ?label ?axis_label ~grad_spec:Grad_spec.grad_spec (Float.of_int i)
 
   let embed_symbol = embed_symbol ~grad_spec:Grad_spec.grad_spec
+  let embed_dim = embed_dim ~grad_spec:Grad_spec.grad_spec
   let sub = sub ~grad_spec:Grad_spec.grad_spec
   let pointdiv = pointdiv ~grad_spec:Grad_spec.grad_spec
   let slice = slice ~grad_spec:Grad_spec.grad_spec
@@ -627,6 +634,7 @@ struct
   let ( !.. ) ?label i = number ?label @@ Float.of_int i
   let ( !% ) ?label i = bits ?label i
   let ( !@ ) = embed_symbol
+  let dim = embed_dim
   let ( - ) ?label t1 t2 = sub ?label t1 t2 ()
   let ( ~- ) ?label t = pointmul ?label (number (-1.)) t ()
   let ( /. ) ?label t1 t2 = pointdiv ?label t1 t2 ()
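
A rough usage fragment based on the signatures added above (not a complete program: it assumes an OCANNL scope with tensors `a` and `b`, the spec and names are illustrative, the trailing `()` follows the `sub ?label t1 t2 ()` pattern from this file, and in this in-progress commit shape inference does not yet populate the refs, per the FIXMEs in lib/shape.ml):

```ocaml
(* Capture the dimension solved for the einsum label "b" (illustrative). *)
let b_ref = { Ir.Indexing.ref_label = "b"; solved_dim = None }

(* ~capture_dims threads the ref into the (Einsum (spec, capture_dims))
   compose_op so shape inference can record the solution. *)
let c = einsum ~capture_dims:[ b_ref ] "ab;bc=>ac" a b ()

(* After inference, the solved dimension can itself be embedded as a scalar
   tensor; [dim] is the alias for [embed_dim] bound in the O module above. *)
let b_size = embed_dim b_ref
```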

lib/ppx_cd.ml

Lines changed: 2 additions & 2 deletions
@@ -1292,7 +1292,7 @@ let translate ?ident_label (expr : expression) : result =
             let loc = s_loc in
             if String.equal spec "." then [%expr Shape.Pointwise_bin]
             else if String.equal spec "@" then [%expr Shape.Compose]
-            else [%expr Shape.Einsum [%e logic]]
+            else [%expr Shape.Einsum ([%e logic], [])]
           in
           let _, bin_op = binary_op bin_op in
           process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic
@@ -1348,7 +1348,7 @@ let translate ?ident_label (expr : expression) : result =
             let loc = s_loc in
             if String.equal spec "." then [%expr Shape.Pointwise_un]
             else if String.equal spec "T" then [%expr Shape.Transpose]
-            else [%expr Shape.Permute [%e logic]]
+            else [%expr Shape.Permute ([%e logic], [])]
           in
           let _, un_op = Hashtbl.find_exn unary_ops unop_ident loc in
           process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic

lib/ppx_op.ml

Lines changed: 24 additions & 0 deletions
@@ -22,6 +22,7 @@ let operators =
     ("!..", "number_int");
     ("!%", "bits");
     ("!@", "embed_symbol");
+    ("dim", "embed_dim");
     ("-", "sub");
     ("~-", "num_neg");
     ("/.", "pointdiv");
@@ -212,6 +213,29 @@ let rec translate ~num_configs ~is_toplevel ~opt_label ?label expr =
       let vbs1, e1 = loop expr1 in
       let spec = substitute_identifiers_in_einsum_spec ~loc spec_str in
       (vbs1, [%expr einsum1 ?label:[%e opt_expr ~loc label] [%e spec] [%e e1]])
+  | [%expr
+      [%e? expr1]
+        *+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }]
+             [%e? { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ }]
+             [%e? expr2]]
+    when String.contains spec_str '>' ->
+      (* FIXME: introduce inline definitions for new Indexing.variable_ref objects corresponding to
+         the strings in the list, and pass them as ~capture_dims *)
+      let vbs1, e1 = loop expr1 in
+      let vbs2, e2 = loop expr2 in
+      let spec = substitute_identifiers_in_einsum_spec ~loc spec_str in
+      ( reduce_vbss [ vbs1; vbs2 ],
+        [%expr einsum ?label:[%e opt_expr ~loc label] [%e spec] [%e e1] [%e e2]] )
+  | [%expr
+      [%e? expr1]
+        ++ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }]
+             [%e? { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ }]]
+    when String.contains spec_str '>' ->
+      (* FIXME: introduce inline definitions for new Indexing.variable_ref objects corresponding to
+         the strings in the list, and pass them as ~capture_dims *)
+      let vbs1, e1 = loop expr1 in
+      let spec = substitute_identifiers_in_einsum_spec ~loc spec_str in
+      (vbs1, [%expr einsum1 ?label:[%e opt_expr ~loc label] [%e spec] [%e e1]])
   | { pexp_desc = Pexp_record ([], _); _ } ->
       (* Empty record - not a tensor definition *)
       (no_vbs, expr)
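
To make the FIXME concrete, here is a purely hypothetical sketch (names and module path illustrative, and not what this commit generates yet) of the intended rewrite for `x *+ "ab;bc=>ac" [ "b" ] y`: bind one `Indexing.variable_ref` per listed label, collect the bindings like inline parameter definitions, and pass the refs through `~capture_dims`.

```ocaml
(* Hypothetical expansion sketch; the real ppx would hoist the binding into
   the collected value bindings rather than a local let. *)
let b = { Ir.Indexing.ref_label = "b"; solved_dim = None } in
einsum ?label:None ~capture_dims:[ b ] "ab;bc=>ac" x y
```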

lib/shape.ml

Lines changed: 13 additions & 7 deletions
@@ -27,7 +27,9 @@ module AxisKey = struct
 
   type t = {
     in_axes : kind;
-    pos : int;  (** Indices start at [1], counted from the end if [from_end] is true. *)
+    pos : int;
+        (** Indices start at [1] (note this is axis index, dimension indices are always 0-based),
+            counted from the end if [from_end] is true. *)
     from_end : bool;
         (** Axes are indexed from the front (rarely) or from the end (typically), to avoid
             reindexing when broadcasting. *)
@@ -90,15 +92,16 @@ let row_of_kind = function `Batch -> batch | `Input -> input | `Output -> output
 type deduce_within_shape = Not_constrained | Input_equals_output
 [@@deriving compare, sexp, variants]
 
-type compose_type = Pointwise_bin | Compose | Einsum of string [@@deriving sexp, equal]
+type compose_type = Pointwise_bin | Compose | Einsum of string * Idx.variable_ref list
+[@@deriving sexp_of, equal]
 
 type transpose_type =
   | Transpose
   | Pointwise_un
-  | Permute of string
+  | Permute of string * Idx.variable_ref list
   | Batch_slice of Idx.static_symbol
   | Uint4x32_to_prec of Ir.Ops.prec Lazy.t
-[@@deriving equal, sexp]
+[@@deriving equal, sexp_of]
 
 type terminal_type = Data of Ir.Assignments.init_data | Fetch of Ir.Assignments.fetch_op
 [@@deriving equal, sexp_of]
@@ -260,7 +263,7 @@ let logic_to_spec = function
   | Broadcast_tern (Pointwise_tern, _, _, _) ->
       "."
   | Broadcast (Compose, _, _) | Broadcast_tern (Compose_accumulate, _, _, _) -> "@"
-  | Broadcast (Einsum spec, _, _) | Transpose (Permute spec, _) -> spec
+  | Broadcast (Einsum (spec, _), _, _) | Transpose (Permute (spec, _), _) -> spec
   | Transpose (Transpose, _) -> "T"
   | Transpose (Batch_slice _, _) -> "@|"
   | Transpose (Uint4x32_to_prec _, _) -> "U4x32"
@@ -470,6 +473,7 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
           :: mark_terminal () )
       else (Row.dim_map_empty, mark_terminal ())
   | Terminal (Fetch (Embed_symbol _)) -> (Row.dim_map_empty, mark_terminal ())
+  | Terminal (Fetch (Embed_dim _)) -> (Row.dim_map_empty, mark_terminal ())
   | Terminal (Fetch Embed_self_id) -> (Row.dim_map_empty, mark_terminal ())
   | Transpose (Transpose, sh) ->
       ( Row.dim_map_empty,
@@ -560,7 +564,8 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
           Row_eq { r1 = cur_sh.input; r2 = sh.input };
           Row_eq { r1 = cur_sh.output; r2 = sh.output };
         ] )
-  | Transpose (Permute spec, sh) ->
+  | Transpose (Permute (spec, _dim_refs), sh) ->
+      (* FIXME: support dim_refs *)
      let ls_rhs, ls_lhs =
        match einsum_of_spec spec with
        | ls_rhs, None, ls_lhs -> (ls_rhs, ls_lhs)
@@ -610,7 +615,8 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
            { numerator = Row.Strided_var { coeff; var; denom = 1 }; divided_by = [] };
          };
        ] )
-  | Broadcast (Einsum spec, sh1, sh2) ->
+  | Broadcast (Einsum (spec, _dim_refs), sh1, sh2) ->
+      (* FIXME: support dim_refs *)
      let ls_rhs1, ls_rhs2, ls_lhs =
        match einsum_of_spec spec with
        | ls_rhs1, Some ls_rhs2, ls_lhs -> (ls_rhs1, ls_rhs2, ls_lhs)
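
A standalone sketch (plain OCaml, illustrative helper name) of the capture step that the two "FIXME: support dim_refs" markers point at, following the rule documented in lib/syntax_extensions.md further below: a dimension variable captures its solved dimension, and a row variable captures the product of the dimensions it resolved into.

```ocaml
(* Mirrors Idx.variable_ref; [capture] is a stand-in for the inference hook. *)
type variable_ref = { ref_label : string; mutable solved_dim : int option }

let capture (r : variable_ref) ~(solved : int list) : unit =
  r.solved_dim <- Some (List.fold_left ( * ) 1 solved)

let () =
  (* A dim variable resolves to a single dimension... *)
  let b = { ref_label = "b"; solved_dim = None } in
  capture b ~solved:[ 64 ];
  assert (b.solved_dim = Some 64);
  (* ...while a row variable captures the product of its dimensions. *)
  let v = { ref_label = "v"; solved_dim = None } in
  capture v ~solved:[ 2; 3; 4 ];
  assert (v.solved_dim = Some 24)
```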

lib/shape.mli

Lines changed: 12 additions & 5 deletions
@@ -87,34 +87,41 @@ type compose_type =
       (** Compose the outputs of the second shape with the inputs of the first shape, i.e. the shape
           of [fun x -> s1(s2(x))], or [s1 * s2] where [*] is the inner product (e.g. matrix
           multiply). *)
-  | Einsum of string
+  | Einsum of string * Ir.Indexing.variable_ref list
       (** The binary "einsum" syntax: RHS1;RHS2=>LHS, where RHSi, LHS are labels specifications.
           Since OCANNL's extended einsum notation supports both axis variables and row variables, it
           makes other compose types redundant. The [axis_labels] use pseudo-labels local to the
           notation, to line up the axes and row variables. The symmetric difference / disjunctive
           union of RHS1 and RHS2's pseudo-labels should be equal to LHS pseudo-labels.
 
+          The optional {!Ir.Indexing.variable_ref}s will capture the solutions of the dimensions
+          corresponding to the specification labels equal to [ref_label] of a reference.
+
           Note: The "right-hand-side" is on the left! I.e. the syntax is "rhs=>lhs",
           "rhs1;rhs2=>lhs". *)
-[@@deriving sexp, equal]
+[@@deriving sexp_of, equal]
 
 type transpose_type =
   | Transpose  (** Swaps inputs and outputs of a shape, preserves batch axes. *)
   | Pointwise_un  (** Preserves the shape. *)
-  | Permute of string  (** The unary "einsum" syntax: RHS1=>LHS. *)
+  | Permute of string * Ir.Indexing.variable_ref list
+      (** The unary "einsum" syntax: RHS1=>LHS.
+
+          The optional {!Ir.Indexing.variable_ref}s will capture the solutions of the dimensions
+          corresponding to the specification labels equal to [ref_label] of a reference. *)
   | Batch_slice of Ir.Indexing.static_symbol  (** Removes the leftmost batch axis. *)
   | Uint4x32_to_prec of Ir.Ops.prec Lazy.t
       (** Converts precision in a bit-effient way, with a corresponding conversion in total number
          of elements. Currently, assumes the incoming tensor (RHS) has just a single axis to not
          force unnecessary minimum sizes on output axes. *)
-[@@deriving equal, sexp]
+[@@deriving equal, sexp_of]
 
 (** If you miss expressivity here, leave a note on
     {{:https://github.com/ahrefs/ocannl/issues/305}issue 305}. *)
 type ternary_type =
   | Pointwise_tern  (** As in the operation [Where]. *)
   | Compose_accumulate  (** As in the operation [FMA]. *)
-[@@deriving equal, sexp]
+[@@deriving equal, sexp_of]
 
 (** Extracts any available shape information from the initialization or fetch. *)
 type terminal_type = Data of Ir.Assignments.init_data | Fetch of Ir.Assignments.fetch_op

lib/shape_inference.md

Lines changed: 10 additions & 2 deletions
@@ -104,20 +104,28 @@
       (** Compose the outputs of the second shape with the inputs of the first shape, i.e. the shape
           of [fun x -> s1(s2(x))], or [s1 * s2] where [*] is the inner product (e.g. matrix
           multiply). *)
-  | Einsum of string
+  | Einsum of string * Ir.Indexing.variable_ref list
       (** The binary "einsum" syntax: RHS1;RHS2=>LHS, where RHSi, LHS are labels specifications.
           Since OCANNL's extended einsum notation supports both axis variables and row variables, it
           makes other compose types redundant. The [axis_labels] use pseudo-labels local to the
           notation, to line up the axes and row variables. The symmetric difference / disjunctive
           union of RHS1 and RHS2's pseudo-labels should be equal to LHS pseudo-labels.
 
+          The optional {!Ir.Indexing.variable_ref}s will capture the solutions of the dimensions
+          corresponding to the specification labels equal to [ref_label] of a reference.
+
           Note: The "right-hand-side" is on the left! I.e. the syntax is "rhs=>lhs",
           "rhs1;rhs2=>lhs". *)
+[@@deriving sexp, equal]
 
 type transpose_type =
   | Transpose  (** Swaps inputs and outputs of a shape, preserves batch axes. *)
   | Pointwise_un  (** Preserves the shape. *)
-  | Permute of string  (** The unary "einsum" syntax: RHS1=>LHS. *)
+  | Permute of string * Ir.Indexing.variable_ref list
+      (** The unary "einsum" syntax: RHS1=>LHS.
+
+          The optional {!Ir.Indexing.variable_ref}s will capture the solutions of the dimensions
+          corresponding to the specification labels equal to [ref_label] of a reference. *)
   | Batch_slice of Ir.Indexing.static_symbol  (** Removes the leftmost batch axis. *)
   | Uint4x32_to_prec of Ir.Ops.prec Lazy.t
       (** Converts precision in a bit-effient way, with a corresponding conversion in total number

lib/syntax_extensions.md

Lines changed: 4 additions & 0 deletions
@@ -383,6 +383,10 @@ Examples:
 - `..v..|...ijk => ..v..kji`: reverse the three rightmost output axes, reduce any other output axes, pointwise for batch axes, pairing the batch axes with the leftmost output axes of the result. Fails if the argument has input axes.
 - `2..v..|... => ..v..`: slice the tensor at dimension 2 of the leftmost batch axis, reduce all its output axes, preserve its other batch axes as output axes. Fails if the argument has input axes.
 
+### Capturing the dimensions of selected axes for further computation
+
+The syntaxes `*+` and `++` accept an optional list of strings argument after the specification string. When passed, the strings should be some of the identifiers used in the specification. Both dimension variable and row variable labels are supported. This will introduce bindings for `Indexing.variable_ref` objects at the same point as the inline parameter definition bindings, and will pass these objects with the `~capture_dims` argument to `einsum` resp. `einsum1`. The bound objects can later be used with `Operation.embed_dim` or its alias `Operation.TDSL.O.dim` to embed the solved dimension of the corresponding variable (as a number) into a tensor expression. For a row variable, the number will be the product of the dimensions it resolved into.
+
 ## Further features of the syntax extension %cd
 
 ### Referencing arrays: tensor value, tensor gradient, merge buffer of a tensor node
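
A hedged illustration of the surface syntax documented above (the %op side is still marked FIXME in lib/ppx_op.ml in this commit, so this shows the documented intent rather than working code; the spec, tensor names, and the assumption that the capture is bound under the label's name are illustrative):

```ocaml
(* Pass the capture list right after the specification string. *)
let%op c = x *+ "ab;bc=>ac" [ "b" ] y

(* The captured reference can then be embedded as a scalar tensor via the
   [dim] alias (Operation.TDSL.O.dim), e.g. to reuse the contracted size. *)
let%op b_size = dim b
```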

lib/tensor.ml

Lines changed: 1 addition & 1 deletion
@@ -482,7 +482,7 @@ let%track7_sexp term ?init_data ?fetch_op ?grad_spec ?(label = []) ?(top_down_pr
     match fetch_op with
     | None -> Asgns.empty_comp
     | Some
-        (( Constant _ | Constant_bits _ | Slice _ | Embed_symbol _ | Embed_self_id
+        (( Constant _ | Constant_bits _ | Slice _ | Embed_symbol _ | Embed_dim _ | Embed_self_id
          | Range_over_offsets | Constant_fill _ ) as fetch_op) ->
         Asgns.to_comp @@ Fetch { array = v; fetch_op; dims }
   in
