Skip to content

Commit 9b2af95

Browse files
committed
Syntax extensions documentation: The syntax for %cd
1 parent 9331267 commit 9b2af95

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

lib/syntax_extensions.md

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,60 @@ Functions inside `Operation.NTDSL` use `~grad_spec:Prohibit_grad` when calling i
1616

1717
The extension points open `NTDSL.O`, resp. `TDSL.O`, for the scope of the extension point, to expose the corresponding iterators.
1818

19+
## The syntax for `%cd`
20+
21+
The basic building blocks of the `%cd` syntax are individual assignments, separated by semicolons. The assignments, represented via `Assignments.Accum_binop` and `Assignments.Accum_unop`, are in full generality accumulating:
22+
23+
```ocaml
24+
type Assignments.t =
25+
...
26+
| Accum_binop of {
27+
initialize_neutral : bool;
28+
accum : Ops.binop;
29+
op : Ops.binop;
30+
lhs : Tnode.t;
31+
rhs1 : buffer;
32+
rhs2 : buffer;
33+
projections : Indexing.projections Lazy.t;
34+
}
35+
| Accum_unop of {
36+
initialize_neutral : bool;
37+
accum : Ops.binop;
38+
op : Ops.unop;
39+
lhs : Tnode.t;
40+
rhs : buffer;
41+
projections : Indexing.projections Lazy.t;
42+
}
43+
```
44+
45+
For example the binary case in pseudocode: `if initialize_neutral then lhs = 0; lhs = lhs accum (rhs1 op rhs2)` (assuming the neutral element of `accum` is 0). The representation also has a field `projections` which determines which loops should be run and how the tensor nodes should be indexed to perform the computation.
46+
47+
The basic `%cd` syntax for binary operator assignments has the form: `<lhs> <asgn-op> <rhs1> <op> <rhs2>` (or `<lhs> <asgn-op> <op> <rhs1> <rhs2>` when `<op>` is not an operator). The binary operators in the `<rhs1> <op> <rhs2>` part have a straightfowrad syntax: `<op>` is one of `+`, `-`, `*`, `/`, `**` (to-power-of), `-?/` (ReLU-Gate). `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by one of `+`, `-`, `*`, `/`, `**`, `?/`. The fields `<lhs>`, `<rhs1>`, `<rhs2>` will often be either special-purpose identifiers (e.g. `t`, `t1`, `t2`, `g`, `g1`, `g2`) or identifiers bound to tensors. `<rhs1>`, `<rsh2>` will also often be (non-differentiable) tensor expressions. Further details about the "slot fillers" are below (TODO: section link).
48+
49+
How is the `projections` field determined? `projections` can be given explicitly as a labeled argument `~projections`. If they aren't but `%cd` realizes there is a `~projections` parameter in scope, it uses it -- see `lib/operation.ml` where this option is used to define tensor operations. If instead of `~projections` a `~logic` labeled argument is given, the string passed is used to determine projections. `~logic:"."` means a pointwise operation. `~logic:"@"` means an "output axes of rhs2 match input axes of rhs1" operation (matrix multiplication is a special case). `~logic:"T"` means transpose of input and output axes. The string passed to `~logic` can also use OCANNL's generalization of the einsum notation, allowing arbitrary permutations and reductions of axes. If no information is given, the default is a pointwise operation.
50+
51+
Here we see an example of tensor multiplication -- extending matrix multiplication to arbitrary number of axes -- multiplying `a` by `b` to get `c`. In `=:+`, `=` is required to separate the assigned-to part from the computation, `:` clears-out `c` before the computation, `+` selects addition to accumulate the results.
52+
53+
```ocaml
54+
c =:+ a * b ~logic:"@"
55+
```
56+
57+
Compare the following two ways of updating a parameter `p`:
58+
59+
```ocaml
60+
p =+ learning_rate * p.grad ~logic:"."
61+
```
62+
63+
and:
64+
65+
```ocaml
66+
p =+ learning_rate *. p.grad
67+
```
68+
69+
In the first case, we have a binary assignment calculated pointwise. The resulting representation is `Accum_binop` where `accum` is `Add` and `op` is `Mul` (multiplication). In the second case, `*.` is not recognized as one of the built-in operators. This leaves the expression `learning_rate *. p.grad` un-transformed. Since `(*.)` is bound in `NTDSL.O` to pointwise tensor multiplication, this creates an intermediate tensor, that is then added onto p. The resulting representation is `Accum_unop` where `accum` is `Add` and `op` is `Identity`. Both variants end up with the same result, and even with the same computation, because the second variant's computation will get optimized (unless configured not to).
70+
71+
Advanced note: when a `~projections` parameter is in scope but no assignment-specific `~projections` argument is given -- the typical case in `lib/operation.ml` -- the actual projections field for an assignment is computed by transforming the projections parameter according to hints regarding how tensor nodes relate to the given projections: e.g. `t1`, `v1`, `g1` are "slot RHS1" of the projections, `t2`, `v2`, `g2` are "slot RHS2", `t`, `g` are "slot LHS".
72+
1973
## Numeric and N-dimensional array literals
2074

2175
Both `%cd` and `%op` extensions use a shared syntax for N-dimensional array literals. `%cd` uses `NTDSL.number` and `NTDSL.ndarray` functions, while `%op` uses `TDSL.number` and `TDSL.ndarray` functions. (This is just for consistency: `TDSL.ndarray` invokes `Tensor.ndarray ~grad_spec:If_needed`, which will figure out the gradient is not needed and will make the tensor non-differentiable.)

0 commit comments

Comments
 (0)