Skip to content

Commit f650c0c

Browse files
committed
Embedding of dimensions in tensor expressions: track variables so references can be updated by an upcoming apply_env_step
1 parent 2a15301 commit f650c0c

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

lib/operation.ml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ let matmul ?(label = []) =
8080
let%cd op_asn ~v ~t1 ~t2 ~projections = v =:+ v1 * v2 in
8181
mul Compose ~op_asn ~label:("*" :: label)
8282

83+
let capture_dims_to_refs = List.map ~f:(fun var_ref -> { Shape.var_ref; var = `Not_set_yet })
84+
8385
(** Similar to the explicit mode of [numpy.einsum], the binary variant. Can compute various forms of
8486
matrix multiplication, inner and outer products, etc.
8587
@@ -92,7 +94,9 @@ let einsum ?(label = []) ?(capture_dims = []) spec =
9294
g1 =+ g * v2;
9395
g2 =+ v1 * g
9496
in
95-
Tensor.binop ~label:(";=>" :: label) ~compose_op:(Einsum (spec, capture_dims)) ~op_asn ~grad_asn
97+
Tensor.binop ~label:(";=>" :: label)
98+
~compose_op:(Einsum (spec, capture_dims_to_refs capture_dims))
99+
~op_asn ~grad_asn
96100

97101
(** Like [einsum], but adds instead than multiplying the resulting values. *)
98102
let outer_sum ?(label = []) ?(capture_dims = []) spec =
@@ -102,7 +106,9 @@ let outer_sum ?(label = []) ?(capture_dims = []) spec =
102106
g1 =+ g;
103107
g2 =+ g
104108
in
105-
Tensor.binop ~label:(";=>+" :: label) ~compose_op:(Einsum (spec, capture_dims)) ~op_asn ~grad_asn
109+
Tensor.binop ~label:(";=>+" :: label)
110+
~compose_op:(Einsum (spec, capture_dims_to_refs capture_dims))
111+
~op_asn ~grad_asn
106112

107113
(** Similar to the explicit mode of [numpy.einsum], the unary variant. Can permute axes, extract
108114
diagonals, compute traces etc.
@@ -114,7 +120,7 @@ let einsum1 ?(label = []) ?(capture_dims = []) spec =
114120
let%cd op_asn ~v ~t1 ~projections = v =:+ v1 in
115121
let%cd grad_asn ~t:_ ~g ~t1 ~projections = g1 =+ g in
116122
Tensor.unop
117-
~transpose_op:(Shape.Permute (spec, capture_dims))
123+
~transpose_op:(Shape.Permute (spec, capture_dims_to_refs capture_dims))
118124
~op_asn ~grad_asn ~label:("=>" :: label)
119125

120126
module NDO_before_pow = struct

lib/shape.ml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,19 @@ let row_of_kind = function `Batch -> batch | `Input -> input | `Output -> output
9292
type deduce_within_shape = Not_constrained | Input_equals_output
9393
[@@deriving compare, sexp, variants]
9494

95-
type compose_type = Pointwise_bin | Compose | Einsum of string * Idx.variable_ref list
95+
type delayed_var_ref = {
96+
var_ref : Ir.Indexing.variable_ref;
97+
mutable var : [ `Row of Row.row_var | `Dim of Row.dim_var | `Not_set_yet ];
98+
}
99+
[@@deriving equal, sexp_of]
100+
101+
type compose_type = Pointwise_bin | Compose | Einsum of string * delayed_var_ref list
96102
[@@deriving sexp_of, equal]
97103

98104
type transpose_type =
99105
| Transpose
100106
| Pointwise_un
101-
| Permute of string * Idx.variable_ref list
107+
| Permute of string * delayed_var_ref list
102108
| Batch_slice of Idx.static_symbol
103109
| Uint4x32_to_prec of Ir.Ops.prec Lazy.t
104110
[@@deriving equal, sexp_of]
@@ -705,6 +711,7 @@ let%debug4_sexp propagate_shapes (update_step : update_step) : unit =
705711
active_constraints := ineqs @ !active_constraints;
706712
let ineqs', env = Row.solve_inequalities ~stage:Row.Stage1 ineqs !state in
707713
let _debug_remaining_constraints : Row.constraint_ list = ineqs' in
714+
(* FIXME: call apply_env_step instead *)
708715
iter_shapes update_step ~f:(apply_env_t env);
709716
state := env
710717

@@ -725,6 +732,7 @@ let%debug4_sexp finish_inference (() : unit) : unit =
725732
let unsolved, env = Row.solve_inequalities ~stage:Stage7 unsolved env in
726733
assert (List.is_empty unsolved);
727734
let _active_update_steps : update_step list = !active_update_steps in
735+
(* FIXME: call apply_env_step instead *)
728736
List.iter ~f:(iter_shapes ~f:(apply_env_t env)) !active_update_steps;
729737
let _applied_update_steps : update_step list = !active_update_steps in
730738
active_constraints := [];

lib/shape.mli

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,20 @@ type t = {
8080

8181
type deduce_within_shape = Not_constrained | Input_equals_output [@@deriving compare, sexp]
8282

83+
type delayed_var_ref = {
84+
var_ref : Ir.Indexing.variable_ref;
85+
mutable var : [ `Row of Row.row_var | `Dim of Row.dim_var | `Not_set_yet ];
86+
}
87+
[@@deriving equal, sexp_of]
88+
8389
type compose_type =
8490
| Pointwise_bin
8591
(** NumPy-style broadcast matching batch, input and output axes, e.g. as in [s1 + s2]. *)
8692
| Compose
8793
(** Compose the outputs of the second shape with the inputs of the first shape, i.e. the shape
8894
of [fun x -> s1(s2(x))], or [s1 * s2] where [*] is the inner product (e.g. matrix
8995
multiply). *)
90-
| Einsum of string * Ir.Indexing.variable_ref list
96+
| Einsum of string * delayed_var_ref list
9197
(** The binary "einsum" syntax: RHS1;RHS2=>LHS, where RHSi, LHS are labels specifications.
9298
Since OCANNL's extended einsum notation supports both axis variables and row variables, it
9399
makes other compose types redundant. The [axis_labels] use pseudo-labels local to the
@@ -104,7 +110,7 @@ type compose_type =
104110
type transpose_type =
105111
| Transpose (** Swaps inputs and outputs of a shape, preserves batch axes. *)
106112
| Pointwise_un (** Preserves the shape. *)
107-
| Permute of string * Ir.Indexing.variable_ref list
113+
| Permute of string * delayed_var_ref list
108114
(** The unary "einsum" syntax: RHS1=>LHS.
109115
110116
The optional {!Ir.Indexing.variable_ref}s will capture the solutions of the dimensions

0 commit comments

Comments
 (0)