You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: lib/syntax_extensions.md
+54Lines changed: 54 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -16,6 +16,60 @@ Functions inside `Operation.NTDSL` use `~grad_spec:Prohibit_grad` when calling i
16
16
17
17
The extension points open `NTDSL.O`, resp. `TDSL.O`, for the scope of the extension point, to expose the corresponding iterators.
18
18
19
+
## The syntax for `%cd`
20
+
21
+
The basic building blocks of the `%cd` syntax are individual assignments, separated by semicolons. The assignments, represented via `Assignments.Accum_binop` and `Assignments.Accum_unop`, are in full generality accumulating:
22
+
23
+
```ocaml
24
+
type Assignments.t =
25
+
...
26
+
| Accum_binop of {
27
+
initialize_neutral : bool;
28
+
accum : Ops.binop;
29
+
op : Ops.binop;
30
+
lhs : Tnode.t;
31
+
rhs1 : buffer;
32
+
rhs2 : buffer;
33
+
projections : Indexing.projections Lazy.t;
34
+
}
35
+
| Accum_unop of {
36
+
initialize_neutral : bool;
37
+
accum : Ops.binop;
38
+
op : Ops.unop;
39
+
lhs : Tnode.t;
40
+
rhs : buffer;
41
+
projections : Indexing.projections Lazy.t;
42
+
}
43
+
```
44
+
45
+
For example the binary case in pseudocode: `if initialize_neutral then lhs = 0; lhs = lhs accum (rhs1 op rhs2)` (assuming the neutral element of `accum` is 0). The representation also has a field `projections` which determines which loops should be run and how the tensor nodes should be indexed to perform the computation.
46
+
47
+
The basic `%cd` syntax for binary operator assignments has the form: `<lhs> <asgn-op> <rhs1> <op> <rhs2>` (or `<lhs> <asgn-op> <op> <rhs1> <rhs2>` when `<op>` is not an operator). The binary operators in the `<rhs1> <op> <rhs2>` part have a straightfowrad syntax: `<op>` is one of `+`, `-`, `*`, `/`, `**` (to-power-of), `-?/` (ReLU-Gate). `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by one of `+`, `-`, `*`, `/`, `**`, `?/`. The fields `<lhs>`, `<rhs1>`, `<rhs2>` will often be either special-purpose identifiers (e.g. `t`, `t1`, `t2`, `g`, `g1`, `g2`) or identifiers bound to tensors. `<rhs1>`, `<rsh2>` will also often be (non-differentiable) tensor expressions. Further details about the "slot fillers" are below (TODO: section link).
48
+
49
+
How is the `projections` field determined? `projections` can be given explicitly as a labeled argument `~projections`. If they aren't but `%cd` realizes there is a `~projections` parameter in scope, it uses it -- see `lib/operation.ml` where this option is used to define tensor operations. If instead of `~projections` a `~logic` labeled argument is given, the string passed is used to determine projections. `~logic:"."` means a pointwise operation. `~logic:"@"` means an "output axes of rhs2 match input axes of rhs1" operation (matrix multiplication is a special case). `~logic:"T"` means transpose of input and output axes. The string passed to `~logic` can also use OCANNL's generalization of the einsum notation, allowing arbitrary permutations and reductions of axes. If no information is given, the default is a pointwise operation.
50
+
51
+
Here we see an example of tensor multiplication -- extending matrix multiplication to arbitrary number of axes -- multiplying `a` by `b` to get `c`. In `=:+`, `=` is required to separate the assigned-to part from the computation, `:` clears-out `c` before the computation, `+` selects addition to accumulate the results.
52
+
53
+
```ocaml
54
+
c =:+ a * b ~logic:"@"
55
+
```
56
+
57
+
Compare the following two ways of updating a parameter `p`:
58
+
59
+
```ocaml
60
+
p =+ learning_rate * p.grad ~logic:"."
61
+
```
62
+
63
+
and:
64
+
65
+
```ocaml
66
+
p =+ learning_rate *. p.grad
67
+
```
68
+
69
+
In the first case, we have a binary assignment calculated pointwise. The resulting representation is `Accum_binop` where `accum` is `Add` and `op` is `Mul` (multiplication). In the second case, `*.` is not recognized as one of the built-in operators. This leaves the expression `learning_rate *. p.grad` un-transformed. Since `(*.)` is bound in `NTDSL.O` to pointwise tensor multiplication, this creates an intermediate tensor, that is then added onto p. The resulting representation is `Accum_unop` where `accum` is `Add` and `op` is `Identity`. Both variants end up with the same result, and even with the same computation, because the second variant's computation will get optimized (unless configured not to).
70
+
71
+
Advanced note: when a `~projections` parameter is in scope but no assignment-specific `~projections` argument is given -- the typical case in `lib/operation.ml` -- the actual projections field for an assignment is computed by transforming the projections parameter according to hints regarding how tensor nodes relate to the given projections: e.g. `t1`, `v1`, `g1` are "slot RHS1" of the projections, `t2`, `v2`, `g2` are "slot RHS2", `t`, `g` are "slot LHS".
72
+
19
73
## Numeric and N-dimensional array literals
20
74
21
75
Both `%cd` and `%op` extensions use a shared syntax for N-dimensional array literals. `%cd` uses `NTDSL.number` and `NTDSL.ndarray` functions, while `%op` uses `TDSL.number` and `TDSL.ndarray` functions. (This is just for consistency: `TDSL.ndarray` invokes `Tensor.ndarray ~grad_spec:If_needed`, which will figure out the gradient is not needed and will make the tensor non-differentiable.)
0 commit comments