Skip to content

Commit 45cf3c7

Browse files
committed
%cd syntax: automatically generate comments; more lightweight comments in Train.grad_update
1 parent a8c22b3 commit 45cf3c7

File tree

3 files changed

+81
-51
lines changed

3 files changed

+81
-51
lines changed

lib/ppx_cd.ml

Lines changed: 75 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,6 @@ let make_vb ~loc ~name ~name_expr ~hint_label =
6363
in
6464
let vb = A.Vb.mk ~loc pat v in
6565
vb
66-
(* let make_code ~loc ~name ~name_expr ~hint_label code_expr = [%expr { asgns = [%e code_expr];
67-
embedded_nodes = Base.Set.empty (module Ir.Tnode) }] *)
68-
69-
let reduce_embs_arr ~loc (rs : array_setup list) =
70-
List.filter_map rs ~f:(fun hs -> hs.fwd_code_or_noop)
71-
|> List.reduce ~f:(fun embs comp -> [%expr Base.Set.union [%e embs] [%e comp].embedded_nodes])
7266

7367
(** The expression argument is of type: [Assignments.t]. *)
7468
let assignment ~punned ~lhs ~rhses body =
@@ -106,7 +100,9 @@ let assignment ~punned ~lhs ~rhses body =
106100
else body
107101
in
108102
let tensor_vbs = List.filter_map rhses ~f:(fun rhs -> rhs.vb) in
109-
let body = [%expr { asgns = [%e body]; embedded_nodes = Base.Set.empty (module Ir.Tnode) }] in
103+
let body =
104+
[%expr { Ir.Assignments.asgns = [%e body]; embedded_nodes = Base.Set.empty (module Ir.Tnode) }]
105+
in
110106
let comps =
111107
List.fold (body :: List.rev forward_args) ~init:[%expr []] ~f:(fun xs x ->
112108
[%expr [%e x] :: [%e xs]])
@@ -193,7 +189,9 @@ let guess_pun_hint ~no_filler_label ~punned ~bad_pun_hints filler_typ filler =
193189
| _, _, true -> None
194190

195191
let empty_tns ~loc = [%expr Base.Set.empty (module Ir.Tnode)]
196-
let empty_comp ~loc = [%expr { asgns = Ir.Assignments.Noop; embedded_nodes = [%e empty_tns ~loc] }]
192+
193+
let empty_comp ~loc =
194+
[%expr { Ir.Assignments.asgns = Ir.Assignments.Noop; embedded_nodes = [%e empty_tns ~loc] }]
197195

198196
let setup_array ~punned ~bad_pun_hints ~is_lhs
199197
{ typ = filler_typ; slot; expr = filler; vbs; array_opt_of_code } =
@@ -298,11 +296,10 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
298296
let fwd_code_or_noop =
299297
Some
300298
[%expr
301-
Ir.Assignments.
302-
{
303-
asgns = Noop;
304-
embedded_nodes = Base.Set.singleton (module Ir.Tnode) [%e filler].Tensor.value;
305-
}]
299+
{
300+
Ir.Assignments.asgns = Ir.Assignments.Noop;
301+
embedded_nodes = Base.Set.singleton (module Ir.Tnode) [%e filler].Tensor.value;
302+
}]
306303
in
307304
{ (default_setup false) with fwd_code_or_noop; tensor = Some filler }
308305
| _, Function ->
@@ -986,47 +983,60 @@ let translate ?ident_label (expr : expression) : result =
986983
let res1 = loop ~proj_in_scope expr1 in
987984
match res1.typ with
988985
| Unknown | Tensor | No_grad_tensor_intro _ ->
989-
{ res1 with typ = Code { is_commented = false }; expr = [%expr Tensor.consume_forward_code [%e res1.expr]] }
986+
{
987+
res1 with
988+
typ = Code { is_commented = false };
989+
expr = [%expr Tensor.consume_forward_code [%e res1.expr]];
990+
}
990991
| _ ->
991992
{
992993
res1 with
993994
expr =
994995
Ast_builder.Default.pexp_extension ~loc
995-
@@ Location.error_extensionf ~loc "ppx_ocannl %%cd: .forward can only be applied to tensors";
996-
}
997-
)
996+
@@ Location.error_extensionf ~loc
997+
"ppx_ocannl %%cd: .forward can only be applied to tensors";
998+
})
998999
| [%expr [%e? expr1].backprop] -> (
9991000
let res1 = loop ~proj_in_scope expr1 in
10001001
match res1.typ with
10011002
| Unknown | Tensor | No_grad_tensor_intro _ ->
1002-
{ res1 with typ = Code { is_commented = false }; expr = [%expr Tensor.consume_backprop_code [%e res1.expr]] }
1003+
{
1004+
res1 with
1005+
typ = Code { is_commented = false };
1006+
expr = [%expr Tensor.consume_backprop_code [%e res1.expr]];
1007+
}
10031008
| _ ->
10041009
{
10051010
res1 with
10061011
expr =
10071012
Ast_builder.Default.pexp_extension ~loc
1008-
@@ Location.error_extensionf ~loc "ppx_ocannl %%cd: .backprop can only be applied to tensors";
1009-
}
1010-
)
1013+
@@ Location.error_extensionf ~loc
1014+
"ppx_ocannl %%cd: .backprop can only be applied to tensors";
1015+
})
10111016
| [%expr [%e? expr1].zero_grads] -> (
10121017
let res1 = loop ~proj_in_scope expr1 in
10131018
match res1.typ with
10141019
| Unknown | Tensor | No_grad_tensor_intro _ ->
1015-
{ res1 with typ = Code { is_commented = false };
1016-
expr = [%expr
1017-
match [%e res1.expr].diff with
1018-
| None ->
1019-
raise (Invalid_argument "ppx_ocannl %cd: .zero_grads requires a differentiable tensor")
1020-
| Some diff -> Ir.Assignments.to_comp diff.zero_grads
1021-
] }
1020+
{
1021+
res1 with
1022+
typ = Code { is_commented = false };
1023+
expr =
1024+
[%expr
1025+
match [%e res1.expr].diff with
1026+
| None ->
1027+
raise
1028+
(Invalid_argument
1029+
"ppx_ocannl %cd: .zero_grads requires a differentiable tensor")
1030+
| Some diff -> Ir.Assignments.to_comp diff.zero_grads];
1031+
}
10221032
| _ ->
10231033
{
10241034
res1 with
10251035
expr =
10261036
Ast_builder.Default.pexp_extension ~loc
1027-
@@ Location.error_extensionf ~loc "ppx_ocannl %%cd: .zero_grads can only be applied to tensors";
1028-
}
1029-
)
1037+
@@ Location.error_extensionf ~loc
1038+
"ppx_ocannl %%cd: .zero_grads can only be applied to tensors";
1039+
})
10301040
| [%expr
10311041
~~([%e? { pexp_desc = Pexp_constant (Pconst_string _); _ } as comment];
10321042
[%e? expr2])] ->
@@ -1481,23 +1491,45 @@ let translate ?ident_label (expr : expression) : result =
14811491
{ res1 with expr = { expr with pexp_desc = Pexp_letmodule (name, module_expr, res1.expr) } }
14821492
| _ -> { default_result with typ = Unknown }
14831493
in
1484-
transl ~proj_in_scope:false ~bad_pun_hints:(Set.empty (module String)) expr
1494+
let res = transl ~proj_in_scope:false ~bad_pun_hints:(Set.empty (module String)) expr in
1495+
match (res.typ, ident_label) with
1496+
| Code { is_commented = false }, Some string_expr ->
1497+
let loc = res.expr.pexp_loc in
1498+
{
1499+
res with
1500+
expr =
1501+
[%expr
1502+
let uncommented_comp = [%e res.expr] in
1503+
{
1504+
Ir.Assignments.embedded_nodes = uncommented_comp.Ir.Assignments.embedded_nodes;
1505+
asgns =
1506+
Ir.Assignments.Block_comment
1507+
([%e string_expr], uncommented_comp.Ir.Assignments.asgns);
1508+
}];
1509+
typ = Code { is_commented = true };
1510+
}
1511+
| _ -> res
14851512

14861513
let translate ?ident_label expr =
1487-
let res = translate ?ident_label:(Option.map ~f:pat2string ident_label) expr in
1514+
let ident_label, is_ignore =
1515+
match ident_label with
1516+
| Some [%pat? _] -> (None, true)
1517+
| Some label -> (Some (pat2string label), false)
1518+
| None -> (None, false)
1519+
in
1520+
let res = translate ?ident_label expr in
14881521
let loc = res.expr.pexp_loc in
14891522
let expr = res.expr in
14901523
( res.vbs,
1491-
match ident_label with
1492-
| Some [%pat? _] ->
1493-
[%expr
1494-
Tensor.with_unchanged_roots ~f:(fun () ->
1495-
let open! NTDSL.O in
1496-
[%e expr])]
1497-
| _ ->
1498-
[%expr
1499-
let open! NTDSL.O in
1500-
[%e expr]] )
1524+
if is_ignore then
1525+
[%expr
1526+
Tensor.with_unchanged_roots ~f:(fun () ->
1527+
let open! NTDSL.O in
1528+
[%e expr])]
1529+
else
1530+
[%expr
1531+
let open! NTDSL.O in
1532+
[%e expr]] )
15011533

15021534
let expr_expander ~loc ~path = expr_expander_with_punning translate ~loc ~path
15031535
let str_expander ~loc ~path = str_expander_with_punning translate ~loc ~path

lib/syntax_extensions.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ type Assignments.t =
392392

393393
Schematic example: `~~("space" "separated" "comment" "tensor p debug_name:" p; <scope of the comment>)`. The content of the comment uses application syntax, must be composed of strings, `<tensor>`, `<tensor>.value` (equivalent to `<tensor>`), `<tensor>.grad` components, where `<tensor>` is any tensor expression or tensor identifier.
394394

395-
This syntax used to be very important, because comments in assignments are used to derive file names for generated code. Now, the `%cd` syntax automatically introduces block comments for code at let-binding points, using the identifier. However, currently the comment does not yet incorporate any tensor node labels -- and for that reason we are not yet adding comments around function bodies if a function is annotated with `%cd` -- so the `~~` syntax is still helpful when the comment needs to be more precise for debugging or naming purposes, or when `%cd` is not used with a let binding. If an explicit comment is provided at the let-binding level, the automatic one is omitted.
395+
This syntax used to be very important, because comments in assignments are used to derive file names for generated code. Now, the `%cd` syntax automatically introduces block comments for code at let-binding points, using the identifier. Currently the comment does not yet incorporate any tensor node labels -- and for that reason we are not yet adding comments around function bodies if a function is annotated with `%cd`. Moreover, we only automatically add comments for code, not for tensors -- so the `~~` syntax is still helpful when the comment needs to be more precise for debugging or naming purposes, when `%cd` is not used with a let binding, or when we want to pass forward code directly instead of let-binding it. If an explicit comment is provided at the let-binding level, the automatic one is omitted.
396396

397397
## Further features of the syntax extension %op
398398

lib/train.ml

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,11 @@ let grad_update ?(setup_for_parallel = false) loss =
8989
set_materialized (Option.value_exn ~here:[%here] p.diff).grad);
9090
(* Note: the %cd syntax for [loss.grad] does not modify roots. *)
9191
[%cd
92-
~~(loss "gradient update";
93-
~~(loss "fwd";
94-
loss.forward);
95-
~~(loss "zero grads";
96-
loss.zero_grads);
97-
loss.grad =: 1;
98-
~~(loss "bprop";
92+
~~(loss "forward and gradient update";
93+
loss.forward;
94+
~~(loss "zero grads and backprop";
95+
loss.zero_grads;
96+
loss.grad =: 1;
9997
loss.backprop))]
10098

10199
(** See: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/optim.py *)

0 commit comments

Comments
 (0)