You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: lib/syntax_extensions.md
+6-6Lines changed: 6 additions & 6 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -212,7 +212,7 @@ type Assignments.t =
212
212
213
213
For example the binary case in pseudocode: `if initialize_neutral then lhs = 0; lhs = lhs accum (rhs1 op rhs2)` (assuming the neutral element of `accum` is 0). The representation also has a field `projections` which determines which loops should be run and how the tensor nodes should be indexed to perform the computation.
214
214
215
-
The basic `%cd` syntax for assignments has the form: `<lhs> <asgn-op> <primitive-op-application[rhs1, rhs2?, rhs3?]>`. See [Primitive operations](#primitive-operations) for the syntax of primitive operation application, where `<rhs1>`, `<rhs2>` (for binary and ternary ops), `<rhs3>` (for ternary ops) are subexpressions. `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by the operator syntax variant of a binary primitive operation. The fields `<lhs>`, `<rhs1>`, `<rhs2>`, `<rhs3>` will often be either special-purpose identifiers (specifically `v`, `t`, `t1`, `t2`, `t3`, `g`, `g1`, `g2`, `g3`) or identifiers bound to tensors. `<rhs1>`, `<rsh2>`, `<rsh3>` will also often be (non-differentiable) tensor expressions. The notation `<tensor>.grad` stands for the gradient node of the given tensor. For more about "slot fillers", and to learn about the operators `*+` and `++`, see the section [further features of the syntax extension %cd](#further-features-of-the-syntax-extension-cd).
215
+
The basic `%cd` syntax for assignments has the form: `<lhs> <asgn-op> <primitive-op-application[rhs1, rhs2?, rhs3?]>`. See [Primitive operations](#primitive-operations) for the syntax of primitive operation application, where `<rhs1>`, `<rhs2>` (for binary and ternary ops), `<rhs3>` (for ternary ops) are subexpressions. `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by the operator syntax variant of a binary primitive operation. The fields `<lhs>`, `<rhs1>`, `<rhs2>`, `<rhs3>` will often be either special-purpose identifiers (specifically `v`, `t`, `t1`, `t2`, `t3`, `g`, `g1`, `g2`, `g3`) or identifiers bound to tensors. `<rhs1>`, `<rsh2>`, `<rsh3>` will also often be (non-differentiable) tensor expressions. The notation `<tensor>.grad` stands for the gradient node of the given tensor. For more about "slot fillers", and to learn about the operators `+*` and `++`, see the section [further features of the syntax extension %cd](#further-features-of-the-syntax-extension-cd).
216
216
217
217
How is the `projections` field determined? `projections` can be given explicitly as a labeled argument `~projections`. If they aren't but `%cd` realizes there is a `~projections` parameter in scope, it uses it -- see `lib/operation.ml` where this option is used to define tensor operations. If instead of `~projections` a `~logic` labeled argument is given, the string passed is used to determine projections. `~logic:"."` means a pointwise operation. `~logic:"@"` means an "output axes of rhs2 match input axes of rhs1" operation (matrix multiplication is a special case). `~logic:"T"` means transpose of input and output axes. The string passed to `~logic` can also use OCANNL's generalization of the einsum notation, allowing arbitrary permutations and reductions of axes. If no information is given, the default depends on the primitive operation, but it is almost always a pointwise operation.
218
218
@@ -318,9 +318,9 @@ let%op mlp_layer ~label ~hid_dim () x = relu ({ w } * x + { b; o = [ hid_dim ] }
318
318
319
319
## Using OCANNL's generalized einsum notation
320
320
321
-
As we mentioned above, in the `%cd` syntax you can set up an arbitrary assignment with projections derived from a generalized einsum specification, by passing the specification as a string with the `~logic` label. However, both the `%cd` and `%op` syntaxes support built-in operators that take an einsum specification: `*+` binding to `NTDSL.einsum` resp. `TDSL.einsum`, and `++` binding to `NTDSL.einsum1` resp. `TDSL.einsum1`. `*+` is a "ternary" operator, binary wrt. tensor arguments, and `++` is a binary operator, unary postfix wrt. tensor arguments. The einsum specification string should directly follow `*+` and `++`.
321
+
As we mentioned above, in the `%cd` syntax you can set up an arbitrary assignment with projections derived from a generalized einsum specification, by passing the specification as a string with the `~logic` label. However, both the `%cd` and `%op` syntaxes support built-in operators that take an einsum specification: `+*` binding to `NTDSL.einsum` resp. `TDSL.einsum`, and `++` binding to `NTDSL.einsum1` resp. `TDSL.einsum1`. `+*` is a "ternary" operator, binary wrt. tensor arguments, and `++` is a binary operator, unary postfix wrt. tensor arguments. The einsum specification string should directly follow `+*` and `++`.
322
322
323
-
Both `*+` and `++` use addition for the accumulation operation; `*+` uses multiplication. You can verify that looking at the `Operation.einsum` and `Operation.einsum1` definitions. You can find examples of `*+` and `++` behavior in the test suite [einsum_trivia.ml](test/einsum_trivia.ml). A frequent use-case for `++` is to sum out all axes of a tensor:
323
+
Both `+*` and `++` use addition for the accumulation operation; `+*` uses multiplication. You can verify that looking at the `Operation.einsum` and `Operation.einsum1` definitions. You can find examples of `+*` and `++` behavior in the test suite [einsum_trivia.ml](test/einsum_trivia.ml). A frequent use-case for `++` is to sum out all axes of a tensor:
- A number specifies the particular dimension within the axis,
366
366
- A `+` sign specifies a convolution input axis with the output on the left of `+` and the kernel on the right of `+`.
367
367
- In both the output part and the kernel part you can prefix the axis variable by a constant coefficient with the `*` sign.
368
-
- The coefficient can directly only be an integer, e.g. `"2*i+3*k"`, but under the `%op` and `%cd` syntax extensions, it can also be an identifier of an integer value, e.g. `let stride = 2 and dilation = 3 in [%op "input" *+ "stride * a + dilation * b; b=>a," "kernel"]`.
369
-
- Note the comma above. The syntax extension's expansion of stride and dilation respects the "multichar" mode. Without the comma we are limited to single-character identifiers, e.g. `let s = 2 and d = 3 in [%op "input" *+ "is*a+d*bc;b=>iac" "kernel"]`.
368
+
- The coefficient can directly only be an integer, e.g. `"2*i+3*k"`, but under the `%op` and `%cd` syntax extensions, it can also be an identifier of an integer value, e.g. `let stride = 2 and dilation = 3 in [%op "input" +* "stride * a + dilation * b; b=>a," "kernel"]`.
369
+
- Note the comma above. The syntax extension's expansion of stride and dilation respects the "multichar" mode. Without the comma we are limited to single-character identifiers, e.g. `let s = 2 and d = 3 in [%op "input" +* "is*a+d*bc;b=>iac" "kernel"]`.
370
370
371
371
Examples:
372
372
@@ -385,7 +385,7 @@ Examples:
385
385
386
386
### Capturing the dimensions of selected axes for further computation or to add shape constraints
387
387
388
-
The syntaxes `*+` and `++` accept an optional list of strings argument after the specification string. When passed, the strings should be some of the identifiers used in the specification. Both dimension variable and row variable labels are supported. This will introduce bindings for `Indexing.variable_ref` objects at the same point as the inline parameter definition bindings, and will pass these objects with the `~capture_dims` argument to `einsum` resp. `einsum1`. The bound objects can later be used with `Operation.embed_dim` or its alias `Operation.TDSL.O.dim` to embed the solved dimension of the corresponding variable (as a number) into a tensor expression. For a row variable, the number will be the product of the dimensions it resolved into.
388
+
The syntaxes `+*` and `++` accept an optional list of strings argument after the specification string. When passed, the strings should be some of the identifiers used in the specification. Both dimension variable and row variable labels are supported. This will introduce bindings for `Indexing.variable_ref` objects at the same point as the inline parameter definition bindings, and will pass these objects with the `~capture_dims` argument to `einsum` resp. `einsum1`. The bound objects can later be used with `Operation.embed_dim` or its alias `Operation.TDSL.O.dim` to embed the solved dimension of the corresponding variable (as a number) into a tensor expression. For a row variable, the number will be the product of the dimensions it resolved into.
0 commit comments