lib/syntax_extensions.md
Lines changed: 54 additions & 3 deletions
Original file line number
Diff line number
Diff line change
@@ -2,6 +2,7 @@
2
2
3
3
- Table of contents
4
4
-[Preliminaries](#preliminaries)
5
+
-[Primitive operations](#primitive-operations)
5
6
-[The syntax for %op](#the-syntax-for-op)
6
7
-[The syntax for %cd](#the-syntax-for-cd)
7
8
-[Numeric and N-dimensional array literals](#numeric-and-n-dimensional-array-literals)
@@ -26,7 +27,7 @@
26
27
27
28
## Preliminaries
28
29
29
-
OCANNL, and arrayjit specifically, is built around a fixed number of numeric operations, declared in `arrayjit/ops.ml`. We assign lexical operators to many of the operations, inventing novel operators if needed. For example, Rectified Linear Unit `Relu` operation, which computes `f(x) = max(0,x)`, gets the operator `relu`, and the ReLU-Gate `Relu_gate` operation, which computes `f(x,y) = if x > 0.0 then y else 0.0`, gets the operator `-?/`. These built-in numeric operations are used to construct assignments (`Assignments.t` packaged as `Assignments.comp`). The syntax `%cd` is needed to build assignments concisely. On the other hand, while the syntax `%op` helps build tensors (`Tensor.t`), they can be expressed concisely in pure OCaml. Unlike for assignments, the building blocks for tensor expressions are easy to extend. The meaningful basic ones are provided in `lib/operation.ml`.
30
+
OCANNL, and arrayjit specifically, is built around a fixed number of numeric operations, declared in `arrayjit/ops.ml`. We assign lexical operators to the binary operations, inventing novel operators if needed. For example, the Rectified Linear Unit `Relu` operation, which computes `f(x) = max(0,x)`, is called `relu`, while the ReLU-Gate `Relu_gate` operation, which computes `f(x,y) = if x > 0.0 then y else 0.0`, gets the operator `-?/` in addition to the name `relu_gate`. These built-in numeric operations are used to construct assignments (`Assignments.t` packaged as `Assignments.comp`). The syntax `%cd` is needed to build assignments concisely, and the assignment operators always start with `=` (unlike in C, where they end with `=`). On the other hand, while the syntax `%op` helps build tensors (`Tensor.t`), they can be expressed concisely in pure OCaml. Unlike for assignments, the building blocks for tensor expressions are easy to extend. The meaningful basic ones are provided in `lib/operation.ml`.
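As a rough illustration of the two formulas quoted above, here is a minimal scalar model in plain OCaml. It does not use OCANNL or arrayjit; the function names simply mirror the identifiers mentioned in the text, and the operator definition only shows that `-?/` is a legal OCaml infix symbol.

```ocaml
(* Plain-OCaml scalar model of the two operations; this is an illustration,
   not OCANNL's implementation. *)
let relu x = Float.max 0.0 x

(* Named form of ReLU-Gate: lets y through only where x is positive. *)
let relu_gate x y = if x > 0.0 then y else 0.0

(* Operator form, mirroring the `-?/` spelling from the text. *)
let ( -?/ ) = relu_gate

let () =
  Printf.printf "relu (-2.) = %g\n" (relu (-2.0));
  Printf.printf "3. -?/ 5. = %g\n" (3.0 -?/ 5.0);
  Printf.printf "(-1.) -?/ 5. = %g\n" ((-1.0) -?/ 5.0)
```

Because `-?/` starts with `-`, OCaml gives it the precedence and associativity of the additive operators, which is worth keeping in mind when mixing it with `*` or `/`.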
30
31
31
32
In OCANNL, we call a tensor that is prohibited from propagating gradients, does not have a gradient node nor backprop code, a _non-differentiable tensor_. Accordingly we can call the "plain" tensors with a gradient node _differentiable tensors_. Expressions in the `%cd` syntax will sometimes build new non-differentiable tensors as components of assignments (they will never build new differentiable tensors). The syntax extensions make the following assumption:
32
33
@@ -37,6 +38,56 @@ Functions inside `Operation.NTDSL` use `~grad_spec:Prohibit_grad` when calling i
37
38
38
39
The extension points open `NTDSL.O`, resp. `TDSL.O`, for the scope of the extension point, to expose the corresponding operators.
39
40
41
+
## Primitive operations
42
+
43
+
To accommodate stylistic preferences, OCANNL supports both curried and uncurried syntaxes for primitive operation application. Binary operations are associated with infix operators, in addition to having alphabetic identifiers. This stems from the following restriction: in the `%cd` syntax, the assignment is always an infix operator, and it needs to pick the accumulation operation.
44
+
45
+
The unary primitive operations:
46
+
47
+
| Identifier | Default projection | Constructor in `Arrayjit.Ops`|
| Identifier | Default projection | Constructor in `Arrayjit.Ops`|
87
+
|------------|--------------------|-------------|
88
+
|`where`| pointwise |`Where`|
89
+
|`fma`| compose-accumulate |`FMA`|
90
+
40
91
## The syntax for %op
41
92
42
93
The `%op` syntax is simpler than the `%cd` syntax since it relies more on regular OCaml expressions. For example, we can write without syntax extensions:
@@ -99,9 +150,9 @@ type Assignments.t =
99
150
100
151
For example the binary case in pseudocode: `if initialize_neutral then lhs = 0; lhs = lhs accum (rhs1 op rhs2)` (assuming the neutral element of `accum` is 0). The representation also has a field `projections` which determines which loops should be run and how the tensor nodes should be indexed to perform the computation.
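The pseudocode for the binary case can be modelled in plain OCaml on flat float arrays. This is a sketch under simplifying assumptions: `accum_binop` is a hypothetical helper, all three arrays are assumed to have the same length, and the indexing is pointwise, whereas OCANNL derives the loops and indexing from `projections`.

```ocaml
(* Hypothetical helper modelling a pointwise binary accumulating assignment
   on flat float arrays of equal length. *)
let accum_binop ~initialize_neutral ~neutral ~accum ~op lhs rhs1 rhs2 =
  (* Optionally reset the left-hand side to the neutral element of `accum`. *)
  if initialize_neutral then Array.fill lhs 0 (Array.length lhs) neutral;
  (* lhs = lhs accum (rhs1 op rhs2), one cell at a time. *)
  Array.iteri (fun i v -> lhs.(i) <- accum v (op rhs1.(i) rhs2.(i))) lhs

let () =
  let c = [| 10.0; 10.0 |] in
  let a = [| 1.0; 2.0 |] and b = [| 3.0; 4.0 |] in
  (* Models `c =:+ a * b`: clear c, then add the products. *)
  accum_binop ~initialize_neutral:true ~neutral:0.0 ~accum:( +. ) ~op:( *. ) c a b;
  assert (c = [| 3.0; 8.0 |]);
  (* Models `c =+ a * b`: keep the previous contents and add on top. *)
  accum_binop ~initialize_neutral:false ~neutral:0.0 ~accum:( +. ) ~op:( *. ) c a b;
  assert (c = [| 6.0; 16.0 |])
```

The neutral element is passed explicitly here because it depends on the accumulation operation: 0 for addition, 1 for multiplication.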
101
152
102
-
The basic `%cd` syntax for binary operator assignments has the form: `<lhs> <asgn-op> <rhs1> <op> <rhs2>` (or `<lhs> <asgn-op> <op> <rhs1> <rhs2>` when `<op>` is not an operator). The binary operators in the`<rhs1> <op> <rhs2>`part have a straightfowrad syntax:`<op>`is one of `+`, `-`, `*`, `/`, `**` (to-power-of), `-?/` (ReLU-Gate). `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by one of `+`, `-`, `*`, `/`, `**`, `relu`. The fields `<lhs>`, `<rhs1>`, `<rhs2>`will often be either special-purpose identifiers (e.g. `t`, `t1`, `t2`, `g`, `g1`, `g2`) or identifiers bound to tensors. `<rhs1>`, `<rsh2>` will also often be (non-differentiable) tensor expressions. The notation `<tensor>.grad` stands for the gradient node of the given tensor. For more about "slot fillers", and to learn about the operators `*+` and `++`, see the section [further features of the syntax extension %cd](#further-features-of-the-syntax-extension-cd).
153
+
The basic `%cd` syntax for assignments has the form: `<lhs> <asgn-op> <primitive-op-application[rhs1, rhs2?, rhs3?]>`. See [Primitive operations](#primitive-operations) for the syntax of primitive operation application, where `<rhs1>`, `<rhs2>` (for binary and ternary ops), `<rhs3>` (for ternary ops) are subexpressions. `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by the operator syntax variant of a binary primitive operation. The fields `<lhs>`, `<rhs1>`, `<rhs2>`, `<rhs3>` will often be either special-purpose identifiers (e.g. `t`, `t1`, `t2`, `t3`, `g`, `g1`, `g2`, `g3`) or identifiers bound to tensors. `<rhs1>`, `<rhs2>`, `<rhs3>` will also often be (non-differentiable) tensor expressions. The notation `<tensor>.grad` stands for the gradient node of the given tensor. For more about "slot fillers", and to learn about the operators `*+` and `++`, see the section [further features of the syntax extension %cd](#further-features-of-the-syntax-extension-cd).
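The shape of `<asgn-op>` can be made concrete with a small string decomposition. The function `parse_asgn_op` below is hypothetical and is not how the `%cd` extension is implemented; it only follows the rule stated above (`=`, then an optional `:`, then a binary operator) and does not check that the trailing part names a known primitive operation.

```ocaml
(* Hypothetical decomposition of an assignment operator string into
   (initialize_neutral, accumulation operator). *)
let parse_asgn_op (s : string) : (bool * string) option =
  let n = String.length s in
  if n < 2 || s.[0] <> '=' then None
  else if s.[1] = ':' then
    (* `:` present: initialize_neutral is true; the rest names the accumulator. *)
    if n > 2 then Some (true, String.sub s 2 (n - 2)) else None
  else Some (false, String.sub s 1 (n - 1))

let () =
  assert (parse_asgn_op "=:+" = Some (true, "+"));
  assert (parse_asgn_op "=+" = Some (false, "+"));
  (* A C-style operator ending in `=` does not fit the scheme. *)
  assert (parse_asgn_op "+=" = None)
```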
103
154
104
-
How is the `projections` field determined? `projections` can be given explicitly as a labeled argument `~projections`. If they aren't but `%cd` realizes there is a `~projections` parameter in scope, it uses it -- see `lib/operation.ml` where this option is used to define tensor operations. If instead of `~projections` a `~logic` labeled argument is given, the string passed is used to determine projections. `~logic:"."` means a pointwise operation. `~logic:"@"` means an "output axes of rhs2 match input axes of rhs1" operation (matrix multiplication is a special case). `~logic:"T"` means transpose of input and output axes. The string passed to `~logic` can also use OCANNL's generalization of the einsum notation, allowing arbitrary permutations and reductions of axes. If no information is given, the default is a pointwise operation.
155
+
How is the `projections` field determined? `projections` can be given explicitly as a labeled argument `~projections`. If they aren't but `%cd` realizes there is a `~projections` parameter in scope, it uses it -- see `lib/operation.ml` where this option is used to define tensor operations. If instead of `~projections` a `~logic` labeled argument is given, the string passed is used to determine projections. `~logic:"."` means a pointwise operation. `~logic:"@"` means an "output axes of rhs2 match input axes of rhs1" operation (matrix multiplication is a special case). `~logic:"T"` means transpose of input and output axes. The string passed to `~logic` can also use OCANNL's generalization of the einsum notation, allowing arbitrary permutations and reductions of axes. If no information is given, the default depends on the primitive operation, but it is almost always a pointwise operation.
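The cases of `~logic` listed above can be summarized as a small classification. The type `logic_kind` and the function `classify_logic` are illustrative names only and are not part of OCANNL; the sketch assumes that any string other than the three special ones is meant as an einsum-style specification.

```ocaml
(* Hypothetical classification of `~logic` strings; the constructor names
   are illustrative and are not OCANNL's own types. *)
type logic_kind =
  | Pointwise         (* "." *)
  | Compose           (* "@": output axes of rhs2 match input axes of rhs1 *)
  | Transpose         (* "T": transpose of input and output axes *)
  | Einsum of string  (* anything else: treated as a generalized einsum spec *)

let classify_logic = function
  | "." -> Pointwise
  | "@" -> Compose
  | "T" -> Transpose
  | spec -> Einsum spec

let () =
  assert (classify_logic "." = Pointwise);
  assert (classify_logic "@" = Compose);
  assert (classify_logic "T" = Transpose)
```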
105
156
106
157
Here we see an example of tensor multiplication -- extending matrix multiplication to arbitrary number of axes -- multiplying `a` by `b` to get `c`. In `=:+`, `=` is required to separate the assigned-to part from the computation, `:` clears-out `c` before the computation, `+` selects addition to accumulate the results.