Skip to content

Commit 5bd3107

Browse files
committed
%cd syntax: Provide projections for scalar constants directly
TODO: also special handle operators `!.` and `!..` to use the Scalar slot.
1 parent 1511626 commit 5bd3107

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

lib/operation.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,8 +308,8 @@ let where ?(label = []) ~grad_spec t1 t2 t3 =
308308
let%cd op_asn ~v ~t1 ~t2 ~t3 ~projections = v =: where v1 v2 v3 in
309309
(* TODO: introduce a special-case projection for constants *)
310310
let%cd grad_asn ~t:_ ~g ~t1 ~t2 ~t3 ~projections =
311-
g2 =+ where v1 g (t3 - t3);
312-
g3 =+ where v1 (t2 - t2) g
311+
g2 =+ where v1 g 0;
312+
g3 =+ where v1 0 g
313313
in
314314
Tensor.ternop ~label:("where" :: label) ~ternary_op:Pointwise_tern ~op_asn ~grad_asn ~grad_spec t1
315315
t2 t3

lib/ppx_cd.ml

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ type expr_type =
3333

3434
let is_unknown = function Unknown -> true | _ -> false
3535

36-
type projections_slot = LHS | RHS1 | RHS2 | RHS3 | Nonslot | Undet [@@deriving equal, sexp]
36+
type projections_slot = LHS | RHS1 | RHS2 | RHS3 | Scalar | Nonslot | Undet
37+
[@@deriving equal, sexp]
3738

3839
type result = {
3940
vbs : value_binding Map.M(String).t;
@@ -136,6 +137,7 @@ let project_p_slot debug loc slot =
136137
| RHS1 -> [%expr p.project_rhs.(0)]
137138
| RHS2 -> [%expr p.project_rhs.(1)]
138139
| RHS3 -> [%expr p.project_rhs.(2)]
140+
| Scalar -> [%expr [| Arrayjit.Indexing.Fixed_idx 0 |]]
139141
| Nonslot ->
140142
Ast_builder.Default.pexp_extension ~loc
141143
@@ Location.error_extensionf ~loc
@@ -152,6 +154,7 @@ let project_p_dims debug loc slot =
152154
| RHS1 -> [%expr p.rhs_dims.(0)]
153155
| RHS2 -> [%expr p.rhs_dims.(1)]
154156
| RHS3 -> [%expr p.rhs_dims.(2)]
157+
| Scalar -> [%expr [| 1 |]]
155158
| Nonslot ->
156159
Ast_builder.Default.pexp_extension ~loc
157160
@@ Location.error_extensionf ~loc
@@ -276,7 +279,7 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
276279
| RHS1 -> [%pat? nondiff__rhs1]
277280
| RHS2 -> [%pat? nondiff__rhs2]
278281
| RHS3 -> [%pat? nondiff__rhs3]
279-
| Nonslot | Undet -> [%pat? nondiff__tensor]
282+
| Scalar | Nonslot | Undet -> [%pat? nondiff__tensor]
280283
in
281284
let t = pat2expr v in
282285
let vb = Some (A.Vb.mk ~loc v filler) in
@@ -659,16 +662,20 @@ let translate (expr : expression) : result =
659662
in
660663
match expr with
661664
| { pexp_desc = Pexp_constant (Pconst_float _); _ } ->
662-
{ default_result with expr = [%expr NTDSL.number [%e expr]] }
665+
{ default_result with expr = [%expr NTDSL.number [%e expr]]; slot = Scalar }
663666
| { pexp_desc = Pexp_constant (Pconst_integer _); _ } ->
664-
{ default_result with expr = [%expr NTDSL.number (Float.of_int [%e expr])] }
667+
{ default_result with expr = [%expr NTDSL.number (Float.of_int [%e expr])]; slot = Scalar }
665668
| [%expr
666669
[%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
667670
[%e? { pexp_desc = Pexp_constant (Pconst_float _); _ } as f]] ->
668671
let axis =
669672
Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None))
670673
in
671-
{ default_result with expr = [%expr NTDSL.number ~axis_label:[%e axis] [%e f]] }
674+
{
675+
default_result with
676+
expr = [%expr NTDSL.number ~axis_label:[%e axis] [%e f]];
677+
slot = Scalar;
678+
}
672679
| [%expr
673680
[%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
674681
[%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
@@ -678,6 +685,7 @@ let translate (expr : expression) : result =
678685
{
679686
default_result with
680687
expr = [%expr NTDSL.number ~axis_label:[%e axis] (Float.of_int [%e i])];
688+
slot = Scalar;
681689
}
682690
| { pexp_desc = Pexp_constant (Pconst_string (name, str_loc, _)); _ } ->
683691
{

lib/syntax_extensions.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ type Assignments.t =
205205

206206
For example the binary case in pseudocode: `if initialize_neutral then lhs = 0; lhs = lhs accum (rhs1 op rhs2)` (assuming the neutral element of `accum` is 0). The representation also has a field `projections` which determines which loops should be run and how the tensor nodes should be indexed to perform the computation.
207207

208-
The basic `%cd` syntax for assignments has the form: `<lhs> <asgn-op> <primitive-op-application[rhs1, rhs2?, rhs3?]>`. See [Primitive operations](#primitive-operations) for the syntax of primitive operation application, where `<rhs1>`, `<rhs2>` (for binary and ternary ops), `<rhs3>` (for ternary ops) are subexpressions. `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by the operator syntax variant of a binary primitive operation. The fields `<lhs>`, `<rhs1>`, `<rhs2>`, `<rhs3>` will often be either special-purpose identifiers (e.g. `t`, `t1`, `t2`, `t3`, `g`, `g1`, `g2`, `g3`) or identifiers bound to tensors. `<rhs1>`, `<rsh2>`, `<rsh3>` will also often be (non-differentiable) tensor expressions. The notation `<tensor>.grad` stands for the gradient node of the given tensor. For more about "slot fillers", and to learn about the operators `*+` and `++`, see the section [further features of the syntax extension %cd](#further-features-of-the-syntax-extension-cd).
208+
The basic `%cd` syntax for assignments has the form: `<lhs> <asgn-op> <primitive-op-application[rhs1, rhs2?, rhs3?]>`. See [Primitive operations](#primitive-operations) for the syntax of primitive operation application, where `<rhs1>`, `<rhs2>` (for binary and ternary ops), `<rhs3>` (for ternary ops) are subexpressions. `<asgn-op>` starts with `=`, followed by `:` only if `initialize_neutral` is true, then followed by the operator syntax variant of a binary primitive operation. The fields `<lhs>`, `<rhs1>`, `<rhs2>`, `<rhs3>` will often be either special-purpose identifiers (specifically `v`, `t`, `t1`, `t2`, `t3`, `g`, `g1`, `g2`, `g3`) or identifiers bound to tensors. `<rhs1>`, `<rsh2>`, `<rsh3>` will also often be (non-differentiable) tensor expressions. The notation `<tensor>.grad` stands for the gradient node of the given tensor. For more about "slot fillers", and to learn about the operators `*+` and `++`, see the section [further features of the syntax extension %cd](#further-features-of-the-syntax-extension-cd).
209209

210210
How is the `projections` field determined? `projections` can be given explicitly as a labeled argument `~projections`. If they aren't but `%cd` realizes there is a `~projections` parameter in scope, it uses it -- see `lib/operation.ml` where this option is used to define tensor operations. If instead of `~projections` a `~logic` labeled argument is given, the string passed is used to determine projections. `~logic:"."` means a pointwise operation. `~logic:"@"` means an "output axes of rhs2 match input axes of rhs1" operation (matrix multiplication is a special case). `~logic:"T"` means transpose of input and output axes. The string passed to `~logic` can also use OCANNL's generalization of the einsum notation, allowing arbitrary permutations and reductions of axes. If no information is given, the default depends on the primitive operation, but it is almost always a pointwise operation.
211211

@@ -229,7 +229,7 @@ p =+ learning_rate *. p.grad
229229

230230
In the first case, we have a binary assignment calculated pointwise. The resulting representation is `Accum_binop` where `accum` is `Add` and `op` is `Mul` (multiplication). In the second case, `*.` is not recognized as one of the built-in operators. This leaves the expression `learning_rate *. p.grad` un-transformed. Since `(*.)` is bound in `NTDSL.O` to pointwise tensor multiplication, this creates an intermediate tensor, that is then added onto p. The resulting representation is `Accum_unop` where `accum` is `Add` and `op` is `Identity`. Both variants end up with the same result, and even with the same computation, because the second variant's computation will get optimized (unless configured not to).
231231

232-
Advanced note: when a `~projections` parameter is in scope but no assignment-specific `~projections` argument is given -- the typical case in `lib/operation.ml` -- the actual projections field for an assignment is computed by transforming the projections parameter according to hints regarding how tensor nodes relate to the given projections. Specifically, the identifiers `rhs1`, `t1`, `v1`, `g1` are "slot RHS1" of the projections, `rhs2`, `t2`, `v2`, `g2` are "slot RHS2", `lhs,`, `t`, `v`, `g` are "slot LHS".
232+
Advanced note: when a `~projections` parameter is in scope but no assignment-specific `~projections` argument is given -- the typical case in `lib/operation.ml` -- the actual projections field for an assignment is computed by transforming the projections parameter according to hints regarding how tensor nodes relate to the given projections. Specifically, the identifiers `rhs1`, `t1`, `v1`, `g1` are "slot RHS1" of the projections, `rhs2`, `t2`, `v2`, `g2` are "slot RHS2", `lhs,`, `t`, `v`, `g` are "slot LHS". Scalar constants are provided the projection directly, to make the automated derivation more expressive.
233233

234234
## Numeric and N-dimensional array literals
235235

0 commit comments

Comments
 (0)