Skip to content

Commit 6d58c98

Browse files
committed
Fixes #210: ppx_op: incorporate the input tensor's label in the resulting tensor's label
1 parent 930bd0c commit 6d58c98

File tree

5 files changed

+68
-9
lines changed

5 files changed

+68
-9
lines changed

lib/ppx_op.ml

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ let make_vb_nd ~has_config ~loc ~str_loc ?axis_labels ~ident ~init_nd string =
5858
let vb = Ast_helper.Vb.mk ~loc pat v in
5959
(pat, vb)
6060

61-
let rec translate ~has_config ?label expr =
61+
let rec translate ~is_toplevel ~has_config ?label expr =
6262
let loc = expr.pexp_loc in
63-
let loop = translate ~has_config in
63+
let loop = translate ~is_toplevel:false ~has_config in
6464
match expr with
6565
| { pexp_desc = Pexp_constant (Pconst_float _); _ } ->
6666
(no_vbs, [%expr TDSL.number ?label:[%e opt_expr ~loc label] [%e expr]])
@@ -140,11 +140,25 @@ let rec translate ~has_config ?label expr =
140140
let vbs2, e2 = loop expr2 in
141141
(Map.merge_skewed vbs1 vbs2 ~combine:(fun ~key:_ _v1 v2 -> v2), [%expr [%e e1] [%e e2]])
142142
| [%expr fun ~config -> [%e? body]] ->
143-
let vbs, body = translate ~has_config:true ?label body in
143+
let vbs, body = translate ~is_toplevel:true ~has_config:true ?label body in
144144
(no_vbs, [%expr fun ~config -> [%e let_opt ~loc vbs body]])
145145
| [%expr fun ~(config : [%typ? config_ty]) -> [%e? body]] ->
146-
let vbs, body = translate ~has_config:true ?label body in
146+
let vbs, body = translate ~is_toplevel:true ~has_config:true ?label body in
147147
(no_vbs, [%expr fun ~(config : [%typ ty]) -> [%e let_opt ~loc vbs body]])
148+
| [%expr fun [%p? pat] -> [%e? body]] when is_toplevel ->
149+
let input_label =
150+
let loc = pat.ppat_loc in
151+
[%expr [%e pat2expr pat].Tensor.value.Arrayjit.Tnode.label]
152+
in
153+
let label =
154+
match label with
155+
| None -> input_label
156+
| Some label ->
157+
let loc = pat.ppat_loc in
158+
[%expr [%e label] @ [%e input_label]]
159+
in
160+
let vbs, body = loop ~label body in
161+
(vbs, [%expr fun [%p pat] -> [%e body]])
148162
| [%expr fun [%p? pat] -> [%e? body]] ->
149163
let vbs, body = loop ?label body in
150164
(vbs, [%expr fun [%p pat] -> [%e body]])
@@ -226,7 +240,9 @@ let rec translate ~has_config ?label expr =
226240

227241
let translate ?ident_label expr =
228242
let vbs, expr =
229-
translate ~has_config:false ~label:(opt_pat2string_list ~loc:expr.pexp_loc ident_label) expr
243+
translate ~is_toplevel:true ~has_config:false
244+
~label:(opt_pat2string_list ~loc:expr.pexp_loc ident_label)
245+
expr
230246
in
231247
let loc = expr.pexp_loc in
232248
( vbs,

lib/syntax_extensions.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ Using [inline declarations](#inline-declarations), this becomes more concise:
6666
...
6767
```
6868

69+
When there is a function directly under the `%op` extension point, like in the example above, or directly under a function taking a `~config` parameter, the function parameter must be a tensor. That's because `%op` uses this tensor's (value's) label to enrich the label of the resulting tensor.
70+
6971
## The syntax for `%cd` {#syntax-for-cd}
7072

7173
The basic building blocks of the `%cd` syntax are individual assignments, separated by semicolons. The assignments, represented via `Assignments.Accum_binop` and `Assignments.Accum_unop`, are in full generality accumulating:

test/micrograd_demo.ml

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,4 +358,43 @@ let%expect_test "Micrograd half-moons example" =
358358
-2.000e-1----
359359
──────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
360360
0.000e+0 3.990e+2
361-
step |}]
361+
step |}];
362+
Tensor.print_tree ~with_grad:true ~depth:9 mlp_result;
363+
[%expect {|
364+
#187 +_mlp_point
365+
-8.86e+0
366+
#188 grad_+_mlp_point <waiting>
367+
<not-in-yet>
368+
[84]│ #185 * <Local 33>
369+
<not-in-yet>
370+
│ #186 grad_* <waiting>
371+
<not-in-yet>
372+
│[90]│ #183 ?/ <Virtual 15>
373+
│ │ <not-in-yet>
374+
│ │ #184 grad_?/ <waiting>
375+
│ │ <not-in-yet>
376+
│ │ #181 + <Virtual 15>
377+
│ │ <not-in-yet>
378+
│ │ #182 grad_+ <waiting>
379+
│ │ <not-in-yet>
380+
│ │[82]│ #179 * <Local 33>
381+
│ │ │ <not-in-yet>
382+
│ │ │ #180 grad_* <waiting>
383+
│ │ │ <not-in-yet>
384+
│ │ │[88]│ #177 ?/ <Local 33>
385+
│ │ │ │ <not-in-yet>
386+
│ │ │ │ #178 grad_?/ <waiting>
387+
│ │ │ │ <not-in-yet>
388+
│ │ │ │ #175 + <Virtual 15>
389+
│ │ │ │ <not-in-yet>
390+
│ │ │ │ #176 grad_+ <waiting>
391+
│ │ │ │ <not-in-yet>
392+
│ │ │ │[80]│ #173 * <Local 33>
393+
│ │ │ │ │ <not-in-yet>
394+
│ │ │ │ │ #174 grad_* <waiting>
395+
│ │ │ │ │ <not-in-yet>
396+
│ │ │ │ │[86]│#171 point
397+
│ │ │ │ │ │ 2.09e+0 -5.88e-1
398+
│ │ │ │ │ │#172 grad_point <Never_virtual 26>
399+
│ │ │ │ │ │<not-in-yet>
400+
|}]

test/zero2hero_1of7.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ let%expect_test "Graph drawing fetch" =
141141
Tensor.print_tree ~with_grad:false ~depth:9 f5;
142142
[%expect
143143
{|
144-
#54 +_f
144+
#54 +_f_5.
145145
6.00e+1
146146
#53 - │#46 5. <Virtual 40>
147147
5.50e+1<not-in-yet>

test_ppx/test_ppx_op_expected.ml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@ let y1 =
1111
let hey2 = TDSL.param ?values:None "hey2" in
1212
let open! TDSL.O in
1313
fun x ->
14-
((+) ?label:(Some ["y1"]))
14+
((+) ?label:(Some (["y1"] @ (x.Tensor.value).Arrayjit.Tnode.label)))
1515
((( * ) ?label:None) hey2 (TDSL.number (Float.of_int 2))) x
1616
let y2 =
1717
let hey3 = TDSL.param ?values:None "hey3" in
1818
let open! TDSL.O in
1919
fun x1 ->
20-
fun x2 -> ((+) ?label:(Some ["y2"])) ((( *. ) ?label:None) x1 hey3) x2
20+
fun x2 ->
21+
((+) ?label:(Some (["y2"] @ (x1.Tensor.value).Arrayjit.Tnode.label)))
22+
((( *. ) ?label:None) x1 hey3) x2
2123
let a =
2224
let open! TDSL.O in
2325
TDSL.ndarray ?label:(Some ["a"]) ~batch_dims:[] ~input_dims:[3]

0 commit comments

Comments
 (0)