Commit 8db6934

Syntax extensions documentation: lots of details about %op, many examples
Two major features of `%op` described are still TODO
1 parent 9b2af95 commit 8db6934

File tree

1 file changed

+162
-13
lines changed

lib/syntax_extensions.md

Lines changed: 162 additions & 13 deletions
@@ -1,7 +1,26 @@
1-
# Syntax extensions `%cd` and `%op`
2-
3-
- Syntax extension `%cd` stands for "code", to express assignments: `Assignments.t`.
4-
- Syntax extension `%op` stands for "operation", to express tensors: `Tensor.t`.
1+
# Syntax extensions `%cd` and `%op` {#syntax-extensions-cd-and-op}
2+
3+
- Table of contents
4+
- [Preliminaries](#preliminaries)
5+
- [The syntax for %op](#syntax-for-op)
6+
- [The syntax for %cd](#syntax-for-cd)
7+
- [Numeric and N-dimensional array literals](#numeric-and-n-dimensional-array-literals)
8+
- [Wildcard bindings](#wildcard-bindings)
9+
- [Inline declarations](#inline-declarations)
10+
- [Using OCANNL's generalized einsum notation](#using-ocannls-generalized-einsum-notation)
11+
- [Further features of the syntax extension %cd](#features-of-syntax-cd)
12+
- [Referencing arrays: tensor value, tensor gradient, merge buffer of a tensor node](#referencing-arrays-tensor-value-tensor-gradient-merge-buffer-of-a-tensor-node)
13+
- [Block comments](#block-comments)
14+
- [Further features of the syntax extension %op](#features-of-syntax-op)
15+
- [Name from binding](#name-from-binding)
16+
- [Label from function argument](#label-from-function-argument)
17+
- [Lifting of the applications of ~config arguments: if it's an error, refactor your code](#lifting-of-the-applications-of-config-arguments-if-its-an-error-refactor-your-code)
18+
- [Implementation details](#implementation-details)
19+
- [Syntax extension %cd](#implementation-extension-cd)
20+
- [Syntax extension %op](#implementation-extension-op)
21+
- In a nutshell
22+
- Syntax extension `%cd` stands for "code", to express assignments: `Assignments.t`.
23+
- Syntax extension `%op` stands for "operation", to express tensors: `Tensor.t`.
524

625
## Preliminaries
726

@@ -16,7 +35,37 @@ Functions inside `Operation.NTDSL` use `~grad_spec:Prohibit_grad` when calling i
1635

1736
The extension points open `NTDSL.O`, resp. `TDSL.O`, for the scope of the extension point, to expose the corresponding iterators.
1837

19-
## The syntax for `%cd`
38+
## The syntax for `%op` {#syntax-for-op}
39+
40+
The `%op` syntax is simpler than the `%cd` syntax since it relies more on regular OCaml expressions. For example, we can write without syntax extensions:
41+
42+
```ocaml
43+
let hid_dim = 8 in
44+
let w = Tensor.param "w" in
45+
let b = Tensor.param ~output_dims:[ hid_dim ] "b" in
46+
let layer x = TDSL.O.( !/(w * x + b) ) in
47+
...
48+
```
49+
50+
Since `TDSL.O` is opened for the scope of an extension point `%op`:
51+
52+
```ocaml
53+
let hid_dim = 8 in
54+
let w = Tensor.param "w" in
55+
let b = Tensor.param ~output_dims:[ hid_dim ] "b" in
56+
let%op layer x = !/(w * x + b) in
57+
...
58+
```
59+
60+
Using [inline declarations](#inline-declarations), this becomes more concise:
61+
62+
```ocaml
63+
let hid_dim = 8 in
64+
let%op mlp_layer x = !/("w" * x + "b" hid_dim) in
65+
...
66+
```
67+
68+
## The syntax for `%cd` {#syntax-for-cd}
2069

2170
The basic building blocks of the `%cd` syntax are individual assignments, separated by semicolons. The assignments, represented via `Assignments.Accum_binop` and `Assignments.Accum_unop`, are in full generality accumulating:
2271

@@ -44,7 +93,7 @@ type Assignments.t =
4493

4594
For example the binary case in pseudocode: `if initialize_neutral then lhs = 0; lhs = lhs accum (rhs1 op rhs2)` (assuming the neutral element of `accum` is 0). The representation also has a field `projections` which determines which loops should be run and how the tensor nodes should be indexed to perform the computation.
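
The pseudocode above can be modelled in plain OCaml. This is only a sketch under stated assumptions: the function `accum_binop`, its `~neutral` argument, and the use of flat `float array` values in place of tensor nodes are all hypothetical, and the `projections` machinery is replaced by a single pointwise loop.

```ocaml
(* Hypothetical model: tensor nodes are flat float arrays and the only
   projection is the pointwise one. Not OCANNL's actual representation. *)
let accum_binop ~initialize_neutral ~neutral ~accum ~op lhs rhs1 rhs2 =
  (* Optionally reset the left-hand side to the neutral element of [accum]. *)
  if initialize_neutral then Array.fill lhs 0 (Array.length lhs) neutral;
  (* lhs = lhs accum (rhs1 op rhs2), element by element. *)
  Array.iteri (fun i _ -> lhs.(i) <- accum lhs.(i) (op rhs1.(i) rhs2.(i))) lhs

let () =
  let lhs = [| 10.; 20. |] in
  let rhs1 = [| 1.; 2. |] and rhs2 = [| 3.; 4. |] in
  (* Without initialization, the products are added onto the old contents. *)
  accum_binop ~initialize_neutral:false ~neutral:0. ~accum:( +. ) ~op:( *. )
    lhs rhs1 rhs2;
  assert (lhs = [| 13.; 28. |]);
  (* With initialization, the old contents are discarded first. *)
  accum_binop ~initialize_neutral:true ~neutral:0. ~accum:( +. ) ~op:( *. )
    lhs rhs1 rhs2;
  assert (lhs = [| 3.; 8. |])
```

The neutral element is passed explicitly here because it depends on `accum`: 0 for addition, 1 for multiplication.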
4695

47-
The basic `%cd` syntax for binary operator assignments has the form: `<lhs> <asgn-op> <rhs1> <op> <rhs2>` (or `<lhs> <asgn-op> <op> <rhs1> <rhs2>` when `<op>` is not an operator). The binary operators in the `<rhs1> <op> <rhs2>` part have a straightfowrad syntax: `<op>` is one of `+`, `-`, `*`, `/`, `**` (to-power-of), `-?/` (ReLU-Gate). `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by one of `+`, `-`, `*`, `/`, `**`, `?/`. The fields `<lhs>`, `<rhs1>`, `<rhs2>` will often be either special-purpose identifiers (e.g. `t`, `t1`, `t2`, `g`, `g1`, `g2`) or identifiers bound to tensors. `<rhs1>`, `<rsh2>` will also often be (non-differentiable) tensor expressions. Further details about the "slot fillers" are below (TODO: section link).
96+
The basic `%cd` syntax for binary operator assignments has the form: `<lhs> <asgn-op> <rhs1> <op> <rhs2>` (or `<lhs> <asgn-op> <op> <rhs1> <rhs2>` when `<op>` is not an operator). The binary operators in the `<rhs1> <op> <rhs2>` part have a straightforward syntax: `<op>` is one of `+`, `-`, `*`, `/`, `**` (to-power-of), `-?/` (ReLU-Gate). `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by one of `+`, `-`, `*`, `/`, `**`, `?/`. The fields `<lhs>`, `<rhs1>`, `<rhs2>` will often be either special-purpose identifiers (e.g. `t`, `t1`, `t2`, `g`, `g1`, `g2`) or identifiers bound to tensors. `<rhs1>` and `<rhs2>` will also often be (non-differentiable) tensor expressions. The notation `<tensor>.grad` stands for the gradient node of the given tensor. For more about "slot fillers", and to learn about the operators `*+` and `++`, see the section [further features of the syntax extension %cd](#features-of-syntax-cd).
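
To make the `<asgn-op>` rule concrete, here is a small stand-alone decoder in plain OCaml. It is a sketch, not part of the syntax extension: the `accum` type, its constructor names, and the function `decode_asgn_op` are hypothetical, and it covers only the forms spelled out in the paragraph above.

```ocaml
(* Hypothetical names; OCANNL's own types may differ. *)
type accum = Add | Sub | Mul | Div | To_pow_of | Relu_gate

(* Decode an assignment operator such as "=+" or "=:+".
   Returns [Some (initialize_neutral, accum)], or [None] for any form
   that the rule above does not cover. *)
let decode_asgn_op s =
  let n = String.length s in
  if n < 2 || s.[0] <> '=' then None
  else
    (* A ':' right after '=' means the lhs is first reset to neutral. *)
    let initialize_neutral = s.[1] = ':' in
    let start = if initialize_neutral then 2 else 1 in
    let accum =
      match String.sub s start (n - start) with
      | "+" -> Some Add
      | "-" -> Some Sub
      | "*" -> Some Mul
      | "/" -> Some Div
      | "**" -> Some To_pow_of
      | "?/" -> Some Relu_gate
      | _ -> None
    in
    Option.map (fun a -> (initialize_neutral, a)) accum

let () =
  assert (decode_asgn_op "=+" = Some (false, Add));
  assert (decode_asgn_op "=:+" = Some (true, Add));
  assert (decode_asgn_op "=:**" = Some (true, To_pow_of))
```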
4897

4998
How is the `projections` field determined? `projections` can be given explicitly as a labeled argument `~projections`. If they aren't but `%cd` realizes there is a `~projections` parameter in scope, it uses it -- see `lib/operation.ml` where this option is used to define tensor operations. If instead of `~projections` a `~logic` labeled argument is given, the string passed is used to determine projections. `~logic:"."` means a pointwise operation. `~logic:"@"` means an "output axes of rhs2 match input axes of rhs1" operation (matrix multiplication is a special case). `~logic:"T"` means transpose of input and output axes. The string passed to `~logic` can also use OCANNL's generalization of the einsum notation, allowing arbitrary permutations and reductions of axes. If no information is given, the default is a pointwise operation.
5099

@@ -68,7 +117,7 @@ p =+ learning_rate *. p.grad
68117

69118
In the first case, we have a binary assignment calculated pointwise. The resulting representation is `Accum_binop` where `accum` is `Add` and `op` is `Mul` (multiplication). In the second case, `*.` is not recognized as one of the built-in operators. This leaves the expression `learning_rate *. p.grad` un-transformed. Since `(*.)` is bound in `NTDSL.O` to pointwise tensor multiplication, this creates an intermediate tensor, which is then added onto `p`. The resulting representation is `Accum_unop` where `accum` is `Add` and `op` is `Identity`. Both variants end up with the same result, and even with the same computation, because the second variant's computation will get optimized (unless configured not to).
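
The equivalence of the two variants can be checked on a toy model in plain OCaml. This is a sketch under stated assumptions: tensor nodes are modelled as flat float arrays, the names `variant_binop` and `variant_unop` are hypothetical, and the optimization that removes the intermediate tensor is not modelled.

```ocaml
(* Hypothetical model on flat float arrays; not OCANNL's representation. *)
let learning_rate = [| 0.5; 0.5 |]
let p_grad = [| 2.; 4. |]

(* Variant 1: one binary accumulating assignment, p = p + (lr * grad). *)
let variant_binop p =
  Array.iteri
    (fun i _ -> p.(i) <- p.(i) +. (learning_rate.(i) *. p_grad.(i)))
    p

(* Variant 2: build an intermediate array first, then accumulate it with
   the identity unary operation, p = p + identity tmp. *)
let variant_unop p =
  let tmp = Array.map2 ( *. ) learning_rate p_grad in
  Array.iteri (fun i _ -> p.(i) <- p.(i) +. Fun.id tmp.(i)) p

let () =
  let p1 = [| 1.; 1. |] and p2 = [| 1.; 1. |] in
  variant_binop p1;
  variant_unop p2;
  (* Both variants leave the same values in p. *)
  assert (p1 = p2);
  assert (p1 = [| 2.; 3. |])
```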
70119

71-
Advanced note: when a `~projections` parameter is in scope but no assignment-specific `~projections` argument is given -- the typical case in `lib/operation.ml` -- the actual projections field for an assignment is computed by transforming the projections parameter according to hints regarding how tensor nodes relate to the given projections: e.g. `t1`, `v1`, `g1` are "slot RHS1" of the projections, `t2`, `v2`, `g2` are "slot RHS2", `t`, `g` are "slot LHS".
120+
Advanced note: when a `~projections` parameter is in scope but no assignment-specific `~projections` argument is given -- the typical case in `lib/operation.ml` -- the actual projections field for an assignment is computed by transforming the projections parameter according to hints regarding how tensor nodes relate to the given projections. Specifically, the identifiers `rhs1`, `t1`, `v1`, `g1` are "slot RHS1" of the projections, `rhs2`, `t2`, `v2`, `g2` are "slot RHS2", and `lhs`, `t`, `v`, `g` are "slot LHS".
72121

73122
## Numeric and N-dimensional array literals
74123

@@ -94,17 +143,117 @@ When an extension is over a wildcard (ignore result) binding: `let%cd _ = ...` a
94143

95144
Both `%cd` and `%op` syntaxes support inline declarations of tensors. For `%op` these are differentiable, for `%cd` non-differentiable tensors. A declaration site uses the string syntax: the content of the string is the identifier that is bound to the newly created tensor, and the string itself functions equivalently to using the newly introduced identifier. The scope of the binding is the full scope of the extension point, even if the declaring string appeared in the body of a function that's inside the extension point scope (except that for `%op` there is a special case of the `~config` labeled argument, discussed below). The first element of the label of the created tensor is the string that introduced it.
96145

97-
For `%cd`, the declaration is (currently) only allowed on the left-hand-side, i.e. in the assigned-to position, of an assignment. If possible, one of the tensors on the right-hand-side is picked to provide additional label information. In particular, tensors that are function parameters inside the scope of the extension point, cannot be picked to provide label information, as they would escape their scope at the point the tensor is created.
146+
For `%cd`, the declaration is (currently) only allowed on the left-hand-side, i.e. in the assigned-to position, of an assignment. If possible, one of the tensors on the right-hand-side is picked to provide additional label information. In particular, tensors that are function parameters inside the scope of the extension point cannot be picked to provide label information, as they would escape their scope at the point the tensor is created. The following example shows two tensor nodes declared inline; both include the label of the param `p` in their labels:
147+
148+
```ocaml
149+
let sgd_one ~learning_rate ?(momentum = 0.0) ?(weight_decay = 0.0) ?(nesterov = false) p =
150+
[%cd
151+
"sgd_delta" =: p.grad + (!.weight_decay *. p);
152+
if Float.(momentum > 0.0) then (
153+
"sgd_momentum" =: (!.momentum *. sgd_momentum) + sgd_delta;
154+
if nesterov then sgd_delta =+ !.momentum *. sgd_momentum else sgd_delta =: sgd_momentum);
155+
p =- learning_rate *. sgd_delta]
156+
```
157+
158+
For `%op`, the declaration is allowed anywhere. If there is a `~config` function parameter used inside the extension scope, for example as `fun ~config ... -> ...` or a more specific example `let%op mlp ~config x = ...`, the scope of an inline-declared tensor is no longer the full scope of the extension point. Instead, the tensor is defined right underneath the introduction of the `~config` parameter: `fun ~config -> let <definitions of the inline-declared tensors> in ...`. The config value passed to the generated code must be a record with at least a field `label : string list`. The inline-declared tensor that's defined under a `~config` parameter is defined as `TDSL.param ~more_label:config.label ...`. The following example shows two param tensors declared inline, including `config.label` in their labels:
159+
160+
```ocaml
161+
type mlp_layer_config = { label : string list; hid_dim : int }
162+
163+
let%op mlp_layer ~config x = !/ ("w" * x + "b" config.hid_dim)
164+
```
165+
166+
## Using OCANNL's generalized einsum notation
167+
168+
## Further features of the syntax extension `%cd` {#features-of-syntax-cd}
169+
170+
### Referencing arrays: tensor value, tensor gradient, merge buffer of a tensor node
171+
172+
The `%cd` syntax uses record-style notation to point to:
173+
174+
- the value tensor node of a tensor `<tensor>.value`,
175+
- the gradient tensor node of a tensor `<tensor>.grad`,
176+
- the merge buffer of a tensor node `<tensor-node>.merge`; `<tensor>.merge` is a shorthand for `<tensor>.value.merge`.
177+
178+
The accessor `.value` can (almost?) always be dropped: by default, tensors in the `%cd` syntax refer to their value nodes.
179+
180+
For example, in a data-parallel computation, gradients of the same param `p` can be merged across devices using the code `p.grad =+ p.grad.merge`, combined with an explicit device-to-device transfer.
98181

99-
For `%op`, the declaration is allowed anywhere. If there is a `~config` function parameter used inside the extension scope, for example as `fun ~config ... -> ...` or a more specific example `let%op mlp ~config x = ...`, the scope of an inline-declared tensor is no longer the full scope of the extension point. Instead, the tensor is defined right underneath the introduction of the `~config` parameter: `fun ~config -> let <definitions of the inline-declared tensors> in ...`. The config value passed to the generated code must be a record with at least a field `label : string list`. The inline-declared tensor that's defined under a `~config` parameter is defined as `TDSL.param ~more_label:config.label ...`
182+
### Block comments
100183

101-
## Features specific to the syntax extension `%cd`
184+
## Further features of the syntax extension `%op` {#features-of-syntax-op}
185+
186+
### Name from binding
187+
188+
When an extension point is applied to a let-binding, e.g. `let%op mlp_layer ~config x = !/ ("w" * x + "b" config.hid_dim)`, it uses the name of the binding (`mlp_layer` in the example) for the label of the primary tensor created by the extension, if any. This is why the resulting layer tensor in the example has its label starting with `"mlp_layer"`. If the extension is over a semicolon-separated sequence of expressions, the primary tensor can only be in the last component of the sequence; other syntax constructs are handled analogously.
189+
190+
### Label from function argument
191+
192+
The resulting (primary) tensor's label also incorporates the label of the input argument, if any. In our example, the resulting `mlp_layer` tensor will also include the label of the actually applied `x`.
193+
194+
Note that we do not include `config.label`, even if `config` is available, because the actually applied input argument will typically have more specific information.
195+
196+
### Lifting of the applications of `~config` arguments: if it's an error, refactor your code
197+
198+
If you recall, inline declared param tensors get lifted out of functions except for the function `fun ~config ->`, where they get defined. Our example `let%op mlp_layer ~config x = !/ ("w" * x + "b" config.hid_dim)` translates as:
199+
200+
```ocaml
201+
let mlp_layer ~config =
202+
let w = Tensor.param "w" and b = Tensor.param ~output_dims:[ config.hid_dim ] "b" in
203+
fun x -> TDSL.O.(!/(w * x + b))
204+
```
205+
206+
For this to work properly, when employing such network blocks, their params also need to be introduced at the right moment. Therefore, the `%op` syntax ensures that this example:
207+
208+
```ocaml
209+
type tlp_config = { label : string list; dim1 : int; dim2 : int; dim3 : int }
210+
211+
let%op three_layer_perceptron ~config x =
212+
mlp_layer ~config:{ label = [ "L3" ]; hid_dim = config.dim3 }
213+
(mlp_layer ~config:{ label = [ "L2" ]; hid_dim = config.dim2 }
214+
(mlp_layer ~config:{ label = [ "L1" ]; hid_dim = config.dim1 } x))
215+
```
216+
217+
gets expanded to:
218+
219+
```ocaml
220+
type tlp_config = { label : string list; dim1 : int; dim2 : int; dim3 : int }
221+
222+
let three_layer_perceptron ~config =
223+
let config_block__1 = mlp_layer ~config:{ label = [ "L3" ]; hid_dim = config.dim3 }
224+
and config_block__2 = mlp_layer ~config:{ label = [ "L2" ]; hid_dim = config.dim2 }
225+
and config_block__3 = mlp_layer ~config:{ label = [ "L1" ]; hid_dim = config.dim1 } in
226+
fun x -> config_block__1 (config_block__2 (config_block__3 x))
227+
```
228+
229+
However, this raises a concern for more complex situations. Consider this code that fails to compile:
230+
231+
```ocaml
232+
type mlp_config = { label : string list; hid_dims : int list }
233+
234+
let%op mlp ~config x =
235+
List.foldi config.hid_dims ~init:x ~f:(fun i x hid_dim ->
236+
mlp_layer ~config:{ label = [ "L" ^ Int.to_string i ]; hid_dim } x)
237+
```
238+
239+
The attempted lifting breaks because of the escaping variables `i` and `hid_dim`. This reminds us to rewrite the example, ensuring the proper introduction of params:
240+
241+
```ocaml
242+
type mlp_config = { label : string list; hid_dims : int list }
243+
244+
let mlp ~config =
245+
let layers =
246+
List.mapi config.hid_dims ~f:(fun i hid_dim ->
247+
mlp_layer ~config:{ label = [ "L" ^ Int.to_string i ]; hid_dim })
248+
in
249+
fun x -> List.fold layers ~init:x ~f:(fun x layer -> layer x)
250+
```
102251

103-
## Features specific to the syntax extension `%op`
252+
Unfortunately, we need to be mindful to introduce params at the right times.
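
Why the staging matters can be seen on a toy model in plain OCaml. This is a sketch, not OCANNL code: `param` is a hypothetical stand-in that only counts how many params get created, a "layer" merely adds a constant, and the standard library's `List` functions are used instead of the labeled ones in the examples above.

```ocaml
(* Hypothetical stand-ins: [param] only counts creations, and a "layer"
   just adds a constant. None of this is OCANNL's API. *)
let created = ref 0

let param (value : float) =
  incr created;
  value

(* Params are created when [~hid_dim] is supplied, not when [x] is. *)
let layer ~hid_dim =
  let b = param (float_of_int hid_dim) in
  fun x -> x +. b

(* Staged: all layers, and so all params, are built once per config. *)
let mlp_staged ~hid_dims =
  let layers = List.map (fun hid_dim -> layer ~hid_dim) hid_dims in
  fun x -> List.fold_left (fun x l -> l x) x layers

(* Unstaged: layers are rebuilt, and params re-created, on every call. *)
let mlp_unstaged ~hid_dims x =
  List.fold_left (fun x hid_dim -> layer ~hid_dim x) x hid_dims

let () =
  let f = mlp_staged ~hid_dims:[ 1; 2 ] in
  ignore (f 0.);
  ignore (f 0.);
  (* Two layers, created once, however often [f] is applied. *)
  assert (!created = 2);
  created := 0;
  let g = mlp_unstaged ~hid_dims:[ 1; 2 ] in
  ignore (g 0.);
  ignore (g 0.);
  (* Two layers re-created on each of the two applications. *)
  assert (!created = 4)
```

In the unstaged version each application would train a fresh set of params, which is the situation the lifting of `~config` applications is meant to rule out.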
104253

105254
## Implementation details
106255

107-
### Syntax extension `%cd`
256+
### Syntax extension `%cd` {#implementation-extension-cd}
108257

109258
The translate function returns a record. The meaning of the `expr` field (filler expression) depends on `typ` (filler type): for `Code`, this is an `Assignments.t` expression. For `Unknown` and `Tensor`, this is a `Tensor.t` expression. For `Array` and `Merge_value`, this is a non-optional `Tnode.t` expression, and for `Grad_of_tensor` and `Merge_grad`, it's an optional `Tnode.t` expression.
110259

@@ -123,4 +272,4 @@ type expr_type =
123272
type projections_slot = LHS | RHS1 | RHS2 | Nonslot | Undet
124273
```
125274

126-
### Syntax extension `%op`
275+
### Syntax extension `%op` {#implementation-extension-op}
