Skip to content

Commit 7ca2ae2

Browse files
committed
The great renaming: *+ --> +* (einsum operation built-in syntax)
1 parent 7a01b5c commit 7ca2ae2

File tree

8 files changed

+55
-55
lines changed

8 files changed

+55
-55
lines changed

bin/einsum_trivia.ml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ let _suspended () =
1010
let module Backend = (val Backends.fresh_backend ()) in
1111
let a = TDSL.range_of_shape ~label:[ "a" ] ~input_dims:[ 2 ] ~output_dims:[ 2 ] () in
1212
let b = TDSL.range_of_shape ~label:[ "b" ] ~input_dims:[ 2; 3; 4 ] ~output_dims:[ 2 ] () in
13-
let%op c = a *+ "i->1; ij...->0 => ...->ji" b in
13+
let%op c = a +* "i->1; ij...->0 => ...->ji" b in
1414
ignore (Train.forward_once (module Backend) c);
1515
Train.printf ~here:[%here] ~with_code:false ~with_grad:false c;
1616
Stdio.printf "\n%!"
@@ -50,12 +50,12 @@ let _suspended () =
5050

5151
let a = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 3 ] ~output_dims:[ 4 ] () in
5252
let b = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 4 ] ~output_dims:[ 5 ] () in
53-
let%op _ = a *+ "b|i->o; b|i->o => b|i->o" a in
54-
let%op c = b *+ "b|h->o; b|i->h => b|i->o" a in
53+
let%op _ = a +* "b|i->o; b|i->o => b|i->o" a in
54+
let%op c = b +* "b|h->o; b|i->h => b|i->o" a in
5555
Utils.capture_stdout_logs (fun () -> ignore (Train.forward_once backend c));
56-
(* let%op d = a *+ "a|i->h; b|h->o => ab|i->o" b in Utils.capture_stdout_logs (fun () ->
57-
ignore (Train.forward_once backend d)); let%op e = a *+ "b|i->h; b|h->o => i->o" b in
58-
Utils.capture_stdout_logs (fun () -> ignore (Train.forward_once backend e)); let%op f = a *+
56+
(* let%op d = a +* "a|i->h; b|h->o => ab|i->o" b in Utils.capture_stdout_logs (fun () ->
57+
ignore (Train.forward_once backend d)); let%op e = a +* "b|i->h; b|h->o => i->o" b in
58+
Utils.capture_stdout_logs (fun () -> ignore (Train.forward_once backend e)); let%op f = a +*
5959
"a|i->h; b|h->o => i->o" b in Utils.capture_stdout_logs (fun () -> ignore (Train.forward_once backend f)); *)
6060
(* Train.printf ~here:[%here] ~with_code:false ~with_grad:false a2; *)
6161
Train.printf ~here:[%here] ~with_code:false ~with_grad:false c

lib/ppx_cd.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -992,7 +992,7 @@ let translate ?ident_label (expr : expression) : result =
992992
{ res1 with typ = Tensor; expr = [%expr NTDSL.O.( **. ) [%e res1.expr] [%e expr2]] }
993993
| [%expr
994994
[%e? expr1]
995-
*+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }] [%e? expr2]]
995+
+* [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }] [%e? expr2]]
996996
when String.contains spec_str '>' ->
997997
let res1 = loop ~proj_in_scope expr1 in
998998
let res2 = loop ~proj_in_scope expr2 in
@@ -1007,7 +1007,7 @@ let translate ?ident_label (expr : expression) : result =
10071007
}
10081008
| [%expr
10091009
[%e? expr1]
1010-
*+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }]
1010+
+* [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }]
10111011
([%e? { pexp_desc = Pexp_constant (Pconst_string _); _ } as head] :: [%e? rest])
10121012
[%e? expr2]]
10131013
when String.contains spec_str '>' ->

lib/ppx_op.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ let rec translate ~num_configs ~is_toplevel ~opt_label ?label expr =
185185
)
186186
| [%expr
187187
[%e? expr1]
188-
*+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }] [%e? expr2]]
188+
+* [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }] [%e? expr2]]
189189
when String.contains spec_str '>' ->
190190
let vbs1, e1 = loop expr1 in
191191
let vbs2, e2 = loop expr2 in
@@ -199,7 +199,7 @@ let rec translate ~num_configs ~is_toplevel ~opt_label ?label expr =
199199
(vbs1, [%expr einsum1 ?label:[%e opt_expr ~loc label] [%e spec] [%e e1]])
200200
| [%expr
201201
[%e? expr1]
202-
*+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }]
202+
+* [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }]
203203
([%e? { pexp_desc = Pexp_constant (Pconst_string _); _ } as head] :: [%e? rest])
204204
[%e? expr2]]
205205
when String.contains spec_str '>' ->

lib/syntax_extensions.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ type Assignments.t =
212212

213213
For example the binary case in pseudocode: `if initialize_neutral then lhs = 0; lhs = lhs accum (rhs1 op rhs2)` (assuming the neutral element of `accum` is 0). The representation also has a field `projections` which determines which loops should be run and how the tensor nodes should be indexed to perform the computation.
214214

215-
The basic `%cd` syntax for assignments has the form: `<lhs> <asgn-op> <primitive-op-application[rhs1, rhs2?, rhs3?]>`. See [Primitive operations](#primitive-operations) for the syntax of primitive operation application, where `<rhs1>`, `<rhs2>` (for binary and ternary ops), `<rhs3>` (for ternary ops) are subexpressions. `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by the operator syntax variant of a binary primitive operation. The fields `<lhs>`, `<rhs1>`, `<rhs2>`, `<rhs3>` will often be either special-purpose identifiers (specifically `v`, `t`, `t1`, `t2`, `t3`, `g`, `g1`, `g2`, `g3`) or identifiers bound to tensors. `<rhs1>`, `<rsh2>`, `<rsh3>` will also often be (non-differentiable) tensor expressions. The notation `<tensor>.grad` stands for the gradient node of the given tensor. For more about "slot fillers", and to learn about the operators `*+` and `++`, see the section [further features of the syntax extension %cd](#further-features-of-the-syntax-extension-cd).
215+
The basic `%cd` syntax for assignments has the form: `<lhs> <asgn-op> <primitive-op-application[rhs1, rhs2?, rhs3?]>`. See [Primitive operations](#primitive-operations) for the syntax of primitive operation application, where `<rhs1>`, `<rhs2>` (for binary and ternary ops), `<rhs3>` (for ternary ops) are subexpressions. `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by the operator syntax variant of a binary primitive operation. The fields `<lhs>`, `<rhs1>`, `<rhs2>`, `<rhs3>` will often be either special-purpose identifiers (specifically `v`, `t`, `t1`, `t2`, `t3`, `g`, `g1`, `g2`, `g3`) or identifiers bound to tensors. `<rhs1>`, `<rhs2>`, `<rhs3>` will also often be (non-differentiable) tensor expressions. The notation `<tensor>.grad` stands for the gradient node of the given tensor. For more about "slot fillers", and to learn about the operators `+*` and `++`, see the section [further features of the syntax extension %cd](#further-features-of-the-syntax-extension-cd).
216216

217217
How is the `projections` field determined? `projections` can be given explicitly as a labeled argument `~projections`. If they aren't but `%cd` realizes there is a `~projections` parameter in scope, it uses it -- see `lib/operation.ml` where this option is used to define tensor operations. If instead of `~projections` a `~logic` labeled argument is given, the string passed is used to determine projections. `~logic:"."` means a pointwise operation. `~logic:"@"` means an "output axes of rhs2 match input axes of rhs1" operation (matrix multiplication is a special case). `~logic:"T"` means transpose of input and output axes. The string passed to `~logic` can also use OCANNL's generalization of the einsum notation, allowing arbitrary permutations and reductions of axes. If no information is given, the default depends on the primitive operation, but it is almost always a pointwise operation.
218218

@@ -318,9 +318,9 @@ let%op mlp_layer ~label ~hid_dim () x = relu ({ w } * x + { b; o = [ hid_dim ] }
318318

319319
## Using OCANNL's generalized einsum notation
320320

321-
As we mentioned above, in the `%cd` syntax you can set up an arbitrary assignment with projections derived from a generalized einsum specification, by passing the specification as a string with the `~logic` label. However, both the `%cd` and `%op` syntaxes support built-in operators that take an einsum specification: `*+` binding to `NTDSL.einsum` resp. `TDSL.einsum`, and `++` binding to `NTDSL.einsum1` resp. `TDSL.einsum1`. `*+` is a "ternary" operator, binary wrt. tensor arguments, and `++` is a binary operator, unary postfix wrt. tensor arguments. The einsum specification string should directly follow `*+` and `++`.
321+
As we mentioned above, in the `%cd` syntax you can set up an arbitrary assignment with projections derived from a generalized einsum specification, by passing the specification as a string with the `~logic` label. However, both the `%cd` and `%op` syntaxes support built-in operators that take an einsum specification: `+*` binding to `NTDSL.einsum` resp. `TDSL.einsum`, and `++` binding to `NTDSL.einsum1` resp. `TDSL.einsum1`. `+*` is a "ternary" operator, binary wrt. tensor arguments, and `++` is a binary operator, unary postfix wrt. tensor arguments. The einsum specification string should directly follow `+*` and `++`.
322322

323-
Both `*+` and `++` use addition for the accumulation operation; `*+` uses multiplication. You can verify that looking at the `Operation.einsum` and `Operation.einsum1` definitions. You can find examples of `*+` and `++` behavior in the test suite [einsum_trivia.ml](test/einsum_trivia.ml). A frequent use-case for `++` is to sum out all axes of a tensor:
323+
Both `+*` and `++` use addition for the accumulation operation; `+*` additionally uses multiplication to combine its two tensor arguments. You can verify that by looking at the `Operation.einsum` and `Operation.einsum1` definitions. You can find examples of `+*` and `++` behavior in the test suite [einsum_trivia.ml](test/einsum_trivia.ml). A frequent use-case for `++` is to sum out all axes of a tensor:
324324

325325
```ocaml
326326
let%op scalar_loss = (margin_loss ++ "...|... => 0") /. !..batch_size in
@@ -365,8 +365,8 @@ The syntax of an axis spec:
365365
- A number specifies the particular dimension within the axis,
366366
- A `+` sign specifies a convolution input axis with the output on the left of `+` and the kernel on the right of `+`.
367367
- In both the output part and the kernel part you can prefix the axis variable by a constant coefficient with the `*` sign.
368-
- The coefficient can directly only be an integer, e.g. `"2*i+3*k"`, but under the `%op` and `%cd` syntax extensions, it can also be an identifier of an integer value, e.g. `let stride = 2 and dilation = 3 in [%op "input" *+ "stride * a + dilation * b; b=>a," "kernel"]`.
369-
- Note the comma above. The syntax extension's expansion of stride and dilation respects the "multichar" mode. Without the comma we are limited to single-character identifiers, e.g. `let s = 2 and d = 3 in [%op "input" *+ "is*a+d*bc;b=>iac" "kernel"]`.
368+
- Written directly in the specification string, the coefficient can only be an integer, e.g. `"2*i+3*k"`, but under the `%op` and `%cd` syntax extensions, it can also be an identifier of an integer value, e.g. `let stride = 2 and dilation = 3 in [%op "input" +* "stride * a + dilation * b; b=>a," "kernel"]`.
369+
- Note the comma above. The syntax extension's expansion of stride and dilation respects the "multichar" mode. Without the comma we are limited to single-character identifiers, e.g. `let s = 2 and d = 3 in [%op "input" +* "is*a+d*bc;b=>iac" "kernel"]`.
370370

371371
Examples:
372372

@@ -385,7 +385,7 @@ Examples:
385385

386386
### Capturing the dimensions of selected axes for further computation or to add shape constraints
387387

388-
The syntaxes `*+` and `++` accept an optional list of strings argument after the specification string. When passed, the strings should be some of the identifiers used in the specification. Both dimension variable and row variable labels are supported. This will introduce bindings for `Indexing.variable_ref` objects at the same point as the inline parameter definition bindings, and will pass these objects with the `~capture_dims` argument to `einsum` resp. `einsum1`. The bound objects can later be used with `Operation.embed_dim` or its alias `Operation.TDSL.O.dim` to embed the solved dimension of the corresponding variable (as a number) into a tensor expression. For a row variable, the number will be the product of the dimensions it resolved into.
388+
The `+*` and `++` syntaxes accept an optional argument after the specification string: a list of strings. When passed, the strings should be some of the identifiers used in the specification. Both dimension variable and row variable labels are supported. This will introduce bindings for `Indexing.variable_ref` objects at the same point as the inline parameter definition bindings, and will pass these objects with the `~capture_dims` argument to `einsum` resp. `einsum1`. The bound objects can later be used with `Operation.embed_dim` or its alias `Operation.TDSL.O.dim` to embed the solved dimension of the corresponding variable (as a number) into a tensor expression. For a row variable, the number will be the product of the dimensions it resolved into.
389389

390390
## Further features of the syntax extension %cd
391391

test/einsum/einsum_trivia.ml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ let%expect_test "einsum outer product" =
225225

226226
let a = TDSL.range_of_shape ~batch_dims:[] ~input_dims:[] ~output_dims:[ 2 ] () in
227227
let b = TDSL.range_of_shape ~batch_dims:[] ~input_dims:[] ~output_dims:[ 3 ] () in
228-
let%op c = (a + 1) *+ "i; j => i->j" b in
228+
let%op c = (a + 1) +* "i; j => i->j" b in
229229
ignore (Train.forward_once backend c);
230230
Train.printf ~here:[%here] ~with_code:false ~with_grad:false c;
231231
[%expect
@@ -244,7 +244,7 @@ let%expect_test "einsum outer product" =
244244
|}];
245245
let a = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 3 ] ~output_dims:[ 4 ] () in
246246
let b = TDSL.range_of_shape ~batch_dims:[ 5 ] ~input_dims:[ 6 ] ~output_dims:[ 7 ] () in
247-
let%op c = a *+ "i|j->k; l|m->n => il|jm->kn" b in
247+
let%op c = a +* "i|j->k; l|m->n => il|jm->kn" b in
248248
ignore (Train.forward_once backend c);
249249
Train.printf ~here:[%here] ~with_code:false ~with_grad:false c;
250250
[%expect
@@ -413,15 +413,15 @@ let%expect_test "einsum matrix/inner+outer products" =
413413

414414
let a = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 3 ] ~output_dims:[ 4 ] () in
415415
let b = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 4 ] ~output_dims:[ 5 ] () in
416-
let%op a2 = a *+ "b|i->o; b|i->o => b|i->o" a in
416+
let%op a2 = a +* "b|i->o; b|i->o => b|i->o" a in
417417
let ctx = Train.forward_once backend a2 in
418-
let%op c = b *+ "b|h->o; b|i->h => b|i->o" a in
418+
let%op c = b +* "b|h->o; b|i->h => b|i->o" a in
419419
let ctx = Train.forward_once backend ~ctx c in
420-
let%op d = a *+ "a|i->h; b|h->o => ab|i->o" b in
420+
let%op d = a +* "a|i->h; b|h->o => ab|i->o" b in
421421
ignore (Train.forward_once backend ~ctx d);
422-
let%op e = a *+ "b|i->h; b|h->o => i->o" b in
422+
let%op e = a +* "b|i->h; b|h->o => i->o" b in
423423
ignore (Train.forward_once backend ~ctx e);
424-
let%op f = a *+ "a|i->h; b|h->o => i->o" b in
424+
let%op f = a +* "a|i->h; b|h->o => i->o" b in
425425
ignore (Train.forward_once backend ~ctx f);
426426
Train.printf ~here:[%here] ~with_code:false ~with_grad:false a2;
427427
[%expect
@@ -792,7 +792,7 @@ let%expect_test "einsum broadcast or sum out prefix axes" =
792792

793793
let a = TDSL.range_of_shape ~batch_dims:[ 3 ] ~input_dims:[ 4 ] ~output_dims:[ 2 ] () in
794794
let b = TDSL.range_of_shape ~batch_dims:[ 3 ] ~input_dims:[ 1 ] ~output_dims:[ 4 ] () in
795-
let%op c = a *+ "...|i->...; ...|...->i => ...|i" b in
795+
let%op c = a +* "...|i->...; ...|...->i => ...|i" b in
796796
ignore (Train.forward_once backend c);
797797
Train.printf ~here:[%here] ~with_code:false ~with_grad:false c;
798798
[%expect
@@ -812,7 +812,7 @@ let%expect_test "einsum broadcast or sum out prefix axes" =
812812
(* Broadcast with a shift. *)
813813
let d = TDSL.range_of_shape ~input_dims:[ 2 ] ~output_dims:[ 3 ] () in
814814
let e = TDSL.range_of_shape ~input_dims:[ 4 ] ~output_dims:[ 3 ] () in
815-
let%op f = d *+ "i->...;j->... => ...ij" e in
815+
let%op f = d +* "i->...;j->... => ...ij" e in
816816
ignore (Train.forward_once backend f);
817817
Train.printf ~here:[%here] ~with_code:false ~with_grad:false f;
818818
[%expect
@@ -930,7 +930,7 @@ let%expect_test "einsum with fixed dim axes" =
930930

931931
let a = TDSL.range_of_shape ~batch_dims:[ 3 ] ~input_dims:[ 4 ] ~output_dims:[ 2 ] () in
932932
let b = TDSL.range_of_shape ~batch_dims:[ 3 ] ~input_dims:[ 1 ] ~output_dims:[ 4 ] () in
933-
let%op c = a *+ "...|i->1; ...|...->i => ...|i" b in
933+
let%op c = a +* "...|i->1; ...|...->i => ...|i" b in
934934
ignore (Train.forward_once backend c);
935935
Train.printf ~here:[%here] ~with_code:false ~with_grad:false c;
936936
[%expect
@@ -1156,7 +1156,7 @@ let%expect_test "einsum with a leftmost input axis preserved as output axis" =
11561156
let b =
11571157
TDSL.range_of_shape ~label:[ "b" ] ~batch_dims:[ 3 ] ~input_dims:[ 2; 3 ] ~output_dims:[ 4 ] ()
11581158
in
1159-
let%op c = a *+ "...|i->1; ...|j...->i => ...|ij" b in
1159+
let%op c = a +* "...|i->1; ...|j...->i => ...|ij" b in
11601160
ignore (Train.forward_once backend c);
11611161
Train.printf ~here:[%here] ~with_code:false ~with_grad:false c;
11621162
[%expect
@@ -1190,7 +1190,7 @@ let%expect_test "einsum permuting two leftmost input axes as output axes" =
11901190

11911191
let a = TDSL.range_of_shape ~label:[ "a" ] ~input_dims:[ 2 ] ~output_dims:[ 2 ] () in
11921192
let b = TDSL.range_of_shape ~label:[ "b" ] ~input_dims:[ 2; 3; 4 ] ~output_dims:[ 2 ] () in
1193-
let%op c = a *+ "i->1; ij...->0 => ...->ji" b in
1193+
let%op c = a +* "i->1; ij...->0 => ...->ji" b in
11941194
ignore (Train.forward_once backend c);
11951195
Train.printf ~here:[%here] ~with_code:false ~with_grad:false c;
11961196
[%expect

0 commit comments

Comments
 (0)