Skip to content

Commit 2032408

Browse files
committed
More %cd flexibility: derive projections for !. and !..
1 parent 5bd3107 commit 2032408

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

lib/operation.ml

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -306,10 +306,11 @@ let fma ?(label = []) ~grad_spec t1 t2 t3 =
306306
let where ?(label = []) ~grad_spec t1 t2 t3 =
307307
let module NTDSL = NTDSL_before_div in
308308
let%cd op_asn ~v ~t1 ~t2 ~t3 ~projections = v =: where v1 v2 v3 in
309-
(* TODO: introduce a special-case projection for constants *)
309+
(* Just to illustrate that both [0] and [!..0] are handled. *)
310+
let zero_cst = 0 in
310311
let%cd grad_asn ~t:_ ~g ~t1 ~t2 ~t3 ~projections =
311312
g2 =+ where v1 g 0;
312-
g3 =+ where v1 0 g
313+
g3 =+ where v1 !..zero_cst g
313314
in
314315
Tensor.ternop ~label:("where" :: label) ~ternary_op:Pointwise_tern ~op_asn ~grad_asn ~grad_spec t1
315316
t2 t3
@@ -410,9 +411,9 @@ module DO = struct
410411
let recip_sqrt = recip_sqrt ~grad_spec:If_needed
411412
let tanh = tanh ~grad_spec:If_needed
412413
let where = where ~grad_spec:If_needed
413-
let (<) = lt ~grad_spec:Prohibit_grad
414-
let (=) = eq ~grad_spec:Prohibit_grad
415-
let (<>) = ne ~grad_spec:Prohibit_grad
414+
let ( < ) = lt ~grad_spec:Prohibit_grad
415+
let ( = ) = eq ~grad_spec:Prohibit_grad
416+
let ( <> ) = ne ~grad_spec:Prohibit_grad
416417
end
417418

418419
module NDO = struct
@@ -435,9 +436,9 @@ module NDO = struct
435436
let recip_sqrt = recip_sqrt ~grad_spec:Prohibit_grad
436437
let tanh = tanh ~grad_spec:Prohibit_grad
437438
let where = where ~grad_spec:Prohibit_grad
438-
let (<) = lt ~grad_spec:Prohibit_grad
439-
let (=) = eq ~grad_spec:Prohibit_grad
440-
let (<>) = ne ~grad_spec:Prohibit_grad
439+
let ( < ) = lt ~grad_spec:Prohibit_grad
440+
let ( = ) = eq ~grad_spec:Prohibit_grad
441+
let ( <> ) = ne ~grad_spec:Prohibit_grad
441442
end
442443

443444
module TDSL = struct

lib/ppx_cd.ml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,13 @@ let translate (expr : expression) : result =
741741
}
742742
| { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } when is_primitive_op op_ident ->
743743
default_result
744+
| [%expr !.[%e? expr1]] ->
745+
(* Hardcoding these two patterns to improve projection derivation expressivity. *)
746+
let res1 = loop ~proj_in_scope expr1 in
747+
{ res1 with typ = Tensor; slot = Scalar; expr = [%expr NTDSL.O.( !. ) [%e res1.expr]] }
748+
| [%expr !..[%e? expr1]] ->
749+
let res1 = loop ~proj_in_scope expr1 in
750+
{ res1 with typ = Tensor; slot = Scalar; expr = [%expr NTDSL.O.( !.. ) [%e res1.expr]] }
744751
| [%expr [%e? expr1] **. [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
745752
(* FIXME: `**.` should take a tensor and require that it's a literal. *)
746753
(* We need to hardcode these two patterns to prevent the numbers from being converted to tensors. *)

lib/syntax_extensions.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ p =+ learning_rate *. p.grad
229229

230230
In the first case, we have a binary assignment calculated pointwise. The resulting representation is `Accum_binop` where `accum` is `Add` and `op` is `Mul` (multiplication). In the second case, `*.` is not recognized as one of the built-in operators. This leaves the expression `learning_rate *. p.grad` un-transformed. Since `(*.)` is bound in `NTDSL.O` to pointwise tensor multiplication, this creates an intermediate tensor, that is then added onto p. The resulting representation is `Accum_unop` where `accum` is `Add` and `op` is `Identity`. Both variants end up with the same result, and even with the same computation, because the second variant's computation will get optimized (unless configured not to).
231231

232-
Advanced note: when a `~projections` parameter is in scope but no assignment-specific `~projections` argument is given -- the typical case in `lib/operation.ml` -- the actual projections field for an assignment is computed by transforming the projections parameter according to hints regarding how tensor nodes relate to the given projections. Specifically, the identifiers `rhs1`, `t1`, `v1`, `g1` are "slot RHS1" of the projections, `rhs2`, `t2`, `v2`, `g2` are "slot RHS2", `lhs,`, `t`, `v`, `g` are "slot LHS". Scalar constants are provided the projection directly, to make the automated derivation more expressive.
232+
Advanced note: when a `~projections` parameter is in scope but no assignment-specific `~projections` argument is given -- the typical case in `lib/operation.ml` -- the actual projections field for an assignment is computed by transforming the projections parameter according to hints regarding how tensor nodes relate to the given projections. Specifically, the identifiers `rhs1`, `t1`, `v1`, `g1` are "slot RHS1" of the projections, `rhs2`, `t2`, `v2`, `g2` are "slot RHS2", `lhs,`, `t`, `v`, `g` are "slot LHS". Scalar constants are provided the projection directly, to make the automated derivation more expressive; this is supported both for literals, and (heuristically) for `!.` and `!..` embedding operators.
233233

234234
## Numeric and N-dimensional array literals
235235

0 commit comments

Comments
 (0)