Skip to content

Commit 9dcfcaa

Browse files
committed
In progress: New shape specification functionality: add equality constraints via captured variables
Move to using the `Shape`-level `delayed_var_ref` for variable capture. Extend the Shape API with `set_dim` and `set_equal`, which introduce equality constraints. Here "equal" between a row and a dimension is interpreted via `Total_elems`, i.e. the row's total number of elements equals the dimension -- not by assuming the row is that single axis!
1 parent 500bf0f commit 9dcfcaa

File tree

7 files changed

+78
-24
lines changed

7 files changed

+78
-24
lines changed

lib/operation.ml

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,6 @@ let matmul ?(label = []) =
8080
let%cd op_asn ~v ~t1 ~t2 ~projections = v =:+ v1 * v2 in
8181
mul Compose ~op_asn ~label:("*" :: label)
8282

83-
let capture_dims_to_refs = List.map ~f:(fun var_ref -> { Shape.var_ref; var = `Not_set_yet })
84-
8583
(** Similar to the explicit mode of [numpy.einsum], the binary variant. Can compute various forms of
8684
matrix multiplication, inner and outer products, etc.
8785
@@ -94,9 +92,7 @@ let einsum ?(label = []) ?(capture_dims = []) spec =
9492
g1 =+ g * v2;
9593
g2 =+ v1 * g
9694
in
97-
Tensor.binop ~label:(";=>" :: label)
98-
~compose_op:(Einsum (spec, capture_dims_to_refs capture_dims))
99-
~op_asn ~grad_asn
95+
Tensor.binop ~label:(";=>" :: label) ~compose_op:(Einsum (spec, capture_dims)) ~op_asn ~grad_asn
10096

10197
(** Like [einsum], but adds instead of multiplying the resulting values. *)
10298
let outer_sum ?(label = []) ?(capture_dims = []) spec =
@@ -106,9 +102,7 @@ let outer_sum ?(label = []) ?(capture_dims = []) spec =
106102
g1 =+ g;
107103
g2 =+ g
108104
in
109-
Tensor.binop ~label:(";=>+" :: label)
110-
~compose_op:(Einsum (spec, capture_dims_to_refs capture_dims))
111-
~op_asn ~grad_asn
105+
Tensor.binop ~label:(";=>+" :: label) ~compose_op:(Einsum (spec, capture_dims)) ~op_asn ~grad_asn
112106

113107
(** Similar to the explicit mode of [numpy.einsum], the unary variant. Can permute axes, extract
114108
diagonals, compute traces etc.
@@ -120,7 +114,7 @@ let einsum1 ?(label = []) ?(capture_dims = []) spec =
120114
let%cd op_asn ~v ~t1 ~projections = v =:+ v1 in
121115
let%cd grad_asn ~t:_ ~g ~t1 ~projections = g1 =+ g in
122116
Tensor.unop
123-
~transpose_op:(Shape.Permute (spec, capture_dims_to_refs capture_dims))
117+
~transpose_op:(Shape.Permute (spec, capture_dims))
124118
~op_asn ~grad_asn ~label:("=>" :: label)
125119

126120
module NDO_before_pow = struct
@@ -471,8 +465,8 @@ let embed_self_id ?grad_spec ?(label = []) () =
471465
~input_dims:[] ~output_dims:[ 1 ] ()
472466

473467
let embed_dim ?grad_spec ?(label = []) variable_ref =
474-
Tensor.term ~fetch_op:(Embed_dim variable_ref) ?grad_spec ~label:("!@self_id" :: label)
475-
~batch_dims:[] ~input_dims:[] ~output_dims:[ 1 ] ()
468+
Tensor.term ~fetch_op:(Embed_dim variable_ref.Shape.var_ref) ?grad_spec
469+
~label:("!@self_id" :: label) ~batch_dims:[] ~input_dims:[] ~output_dims:[ 1 ] ()
476470

477471
let uniform ?grad_spec () =
478472
uint4x32_to_prec_uniform ?grad_spec
@@ -674,6 +668,7 @@ end
674668
module DSL_modules = struct
675669
module Shape = Shape
676670
module Tensor = Tensor
671+
677672
module TDSL = Make_DSL (struct
678673
let grad_spec = Tensor.If_needed
679674
end)

lib/ppx_shared.ml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -334,13 +334,7 @@ let collect_capture_labels ~loc head rest =
334334
in
335335
let capture_refs, capture_bindings =
336336
List.map capture_labels ~f:(fun (loc, label) ->
337-
let ref_expr =
338-
[%expr
339-
{
340-
Ir.Indexing.ref_label = [%e Ast_builder.Default.estring ~loc label];
341-
solved_dim = None;
342-
}]
343-
in
337+
let ref_expr = [%expr Shape.get_variable_ref [%e Ast_builder.Default.estring ~loc label]] in
344338
let binding =
345339
Ast_builder.Default.value_binding ~loc
346340
~pat:(Ast_builder.Default.pvar ~loc label)

lib/row.ml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,9 @@ let phantom_row_id = row_id ~sh_id:(-1) ~kind:`Output
169169
type t = { dims : dim list; bcast : bcast; id : row_id } [@@deriving equal, hash, compare, sexp]
170170
type row = t [@@deriving equal, sexp]
171171

172+
let get_row_for_var ?(row_id = phantom_row_id) v =
173+
{ dims = []; bcast = Row_var { v; beg_dims = [] }; id = row_id }
174+
172175
let dims_label_assoc dims =
173176
let f = function Var { label = Some l; _ } as d -> Some (l, d) | _ -> None in
174177
List.filter_map dims.dims ~f

lib/row.mli

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ type bcast =
6161
type t = { dims : dim list; bcast : bcast; id : row_id } [@@deriving equal, hash, compare, sexp]
6262

6363
val dims_label_assoc : t -> (string * dim) list
64+
val get_row_for_var : ?row_id:row_id -> row_var -> t
6465

6566
type environment [@@deriving sexp_of]
6667
type error_trace = ..

lib/shape.ml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ type delayed_var_ref = {
9898
}
9999
[@@deriving equal, sexp_of]
100100

101+
let get_variable_ref ref_label =
102+
{ var_ref = Ir.Indexing.{ ref_label; solved_dim = None }; var = `Not_set_yet }
103+
101104
type compose_type = Pointwise_bin | Compose | Einsum of string * delayed_var_ref list
102105
[@@deriving sexp_of, equal]
103106

@@ -685,6 +688,46 @@ let state = ref Row.empty_env
685688
let active_update_steps = ref []
686689
let active_constraints = ref []
687690

691+
let set_dim delayed_var_ref dim =
692+
match delayed_var_ref with
693+
| { var_ref = { solved_dim = Some dim2; _ }; _ } when dim2 = dim -> ()
694+
| { var_ref = { solved_dim = Some dim2; ref_label; _ }; _ } ->
695+
raise
696+
@@ Row.Shape_error
697+
( "Cannot set dimension for variable reference with label " ^ ref_label,
698+
[ Row.Dim_mismatch [ Row.get_dim ~d:dim2 (); Row.get_dim ~d:dim () ] ] )
699+
| { var_ref = { solved_dim = None; _ }; var = `Not_set_yet } ->
700+
delayed_var_ref.var_ref.solved_dim <- Some dim
701+
| { var_ref = { solved_dim = None; _ }; var = `Dim dim_var } ->
702+
delayed_var_ref.var_ref.solved_dim <- Some dim;
703+
active_constraints :=
704+
Row.Dim_eq { d1 = Row.Var dim_var; d2 = Row.get_dim ~d:dim () } :: !active_constraints
705+
| { var_ref = { solved_dim = None; _ }; var = `Row row_var } ->
706+
delayed_var_ref.var_ref.solved_dim <- Some dim;
707+
active_constraints :=
708+
Row.Rows_constr
709+
{
710+
(* TODO: the Row.row_id here should be that of the shape containing the row variable;
711+
it should be stored in `Row and in env_row_var. *)
712+
r = [ Row.get_row_for_var row_var ];
713+
constr = Total_elems { numerator = Num_elems dim; divided_by = [] };
714+
}
715+
:: !active_constraints
716+
717+
let set_equal delayed_ref1 delayed_ref2 =
718+
match delayed_ref1, delayed_ref2 with
719+
| { var_ref = { solved_dim = Some dim1; _ }; _ },
720+
{ var_ref = { solved_dim = Some dim2; _ }; _ } ->
721+
if dim1 = dim2 then ()
722+
else
723+
raise
724+
@@ Row.Shape_error
725+
( "Cannot set equal dimensions for variable references with different values",
726+
[ Row.Dim_mismatch [ Row.get_dim ~d:dim1 (); Row.get_dim ~d:dim2 () ] ] )
727+
| _ ->
728+
(* FIXME: NOT IMPLEMENTED YET *)
729+
()
730+
688731
let unsafe_reinitialize () =
689732
update_uid := 0;
690733
state := Row.empty_env;

lib/shape.mli

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,24 @@ type delayed_var_ref = {
8686
}
8787
[@@deriving equal, sexp_of]
8888

89+
val get_variable_ref : string -> delayed_var_ref
90+
(** Returns a fully unset variable reference with the given label. *)
91+
92+
val set_dim : delayed_var_ref -> int -> unit
93+
(** Sets the dimension of a dim variable reference, or the total elements of a row variable reference, to the given
94+
value. This will propagate through shape inference.
95+
96+
For row variables, this means the product of the dimensions, via the [Total_elems] constraint.
97+
*)
98+
99+
val set_equal : delayed_var_ref -> delayed_var_ref -> unit
100+
(** Sets the two variable references to be equal (in some sense). This will propagate through shape
101+
inference.
102+
103+
When both references are dimension variables or both are row variables, this means they are
104+
precisely equal. When one is a dimension variable and the other is a row variable, this means
105+
they have the same number of total elements. *)
106+
89107
type compose_type =
90108
| Pointwise_bin
91109
(** NumPy-style broadcast matching batch, input and output axes, e.g. as in [s1 + s2]. *)

test/operations/test_einsum_capture.ml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ let () =
1313

1414
(* Check if dimensions were captured *)
1515
Stdio.printf "Dimension a: %s\n"
16-
(match a.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
16+
(match a.var_ref.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
1717
Stdio.printf "Dimension b: %s\n"
18-
(match b.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
18+
(match b.var_ref.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
1919
Stdio.printf "Dimension c: %s\n"
20-
(match c.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
20+
(match c.var_ref.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
2121

2222
let%op x2 = { x2 = uniform1 (); o = [ 5; 7 ] } in
2323
(* Manually call einsum1 with capture_dims for now *)
@@ -28,9 +28,9 @@ let () =
2828

2929
(* Check if dimensions were captured *)
3030
Stdio.printf "Dimension i: %s\n"
31-
(match i.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
31+
(match i.var_ref.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
3232
Stdio.printf "Dimension j: %s\n"
33-
(match j.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
33+
(match j.var_ref.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
3434

3535
(* Test capturing row variables *)
3636
let%op x3 = { x3 = uniform1 (); o = [ 2; 3; 4 ] } in
@@ -42,7 +42,7 @@ let () =
4242

4343
(* Check if row variable was captured *)
4444
Stdio.printf "Row variable r (product of dims): %s\n"
45-
(match r.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
45+
(match r.var_ref.solved_dim with Some d -> Int.to_string d | None -> "not resolved");
4646

4747
let%op dim_calc = dim a + dim j + dim r in
4848
let _ctx = Train.forward_once (module Backend) ~ctx dim_calc in

0 commit comments

Comments
 (0)