Skip to content

Commit 9cd1261

Browse files
committed
Incorporate let-binding identifier in %cd names; fix handling of syntactic functions
1 parent 1ef3c97 commit 9cd1261

File tree

3 files changed

+33
-14
lines changed

3 files changed

+33
-14
lines changed

lib/ppx_cd.ml

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ type expr_type =
1414
| Merge_value of expression
1515
| Merge_grad of expression
1616
| No_grad_tensor_intro of { name : string; name_expr : expression }
17+
| Function
1718

1819
let is_unknown = function Unknown -> true | _ -> false
1920

@@ -62,6 +63,8 @@ let make_vb ~loc ~name ~name_expr ~hint_label =
6263
in
6364
let vb = A.Vb.mk ~loc pat v in
6465
vb
66+
(* let make_code ~loc ~name ~name_expr ~hint_label code_expr = [%expr { asgns = [%e code_expr];
67+
embedded_nodes = Base.Set.empty (module Ir.Tnode) }] *)
6568

6669
let reduce_embs_arr ~loc (rs : array_setup list) =
6770
List.filter_map rs ~f:(fun hs -> hs.fwd_code_or_noop)
@@ -152,7 +155,7 @@ let guess_pun_hint ~no_filler_label ~punned ~bad_pun_hints filler_typ filler =
152155
let loc = filler.pexp_loc in
153156
let hint = [%expr [%e filler].Ir.Tnode.label] in
154157
match (filler_typ, filler, no_filler_label) with
155-
| Code, _, _ -> None
158+
| (Code | Function), _, _ -> None
156159
| _, { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }, _ when Set.mem bad_pun_hints name ->
157160
None
158161
| Array, _, false -> Some (hint, false)
@@ -296,6 +299,15 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
296299
}]
297300
in
298301
{ (default_setup false) with fwd_code_or_noop; tensor = Some filler }
302+
| _, Function ->
303+
{
304+
(default_setup false) with
305+
fwd_code_or_noop = Some filler;
306+
array_opt =
307+
Ast_builder.Default.pexp_extension ~loc
308+
@@ Location.error_extensionf ~loc
309+
"ppx_ocannl %%cd: a syntactic function in place of an array is not supported";
310+
}
299311
| _, Code when Option.is_none array_opt_of_code ->
300312
{
301313
(default_setup false) with
@@ -332,7 +344,11 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
332344
@@ Location.error_extensionf ~loc "ppx_ocannl %%cd: merge buffers cannot be assigned to";
333345
}
334346
| _, Merge_value t ->
335-
{ (default_setup false) with array_opt = [%expr Some (Merge_buffer [%e filler])]; tensor = Some t }
347+
{
348+
(default_setup false) with
349+
array_opt = [%expr Some (Merge_buffer [%e filler])];
350+
tensor = Some t;
351+
}
336352
| _, Merge_grad t ->
337353
{
338354
(default_setup false) with
@@ -388,7 +404,7 @@ let handle_cases ~bad_pun_hints ~proj_in_scope transl cases =
388404
array_opt_of_code = None;
389405
} )
390406

391-
let translate (expr : expression) : result =
407+
let translate ?ident_label (expr : expression) : result =
392408
let punned = Hashtbl.create (module String) in
393409
let rec transl ~bad_pun_hints ~proj_in_scope (expr : expression) : result =
394410
let loc = expr.pexp_loc in
@@ -778,9 +794,10 @@ let translate (expr : expression) : result =
778794
slot = Scalar;
779795
}
780796
| { pexp_desc = Pexp_constant (Pconst_string (name, str_loc, _)); _ } ->
781-
(* TODO: consider passing toplevel binding name as a hint label *)
782797
let vbs =
783-
Map.singleton (module String) name @@ make_vb ~loc ~name ~name_expr:expr ~hint_label:None
798+
Map.singleton (module String) name
799+
@@ make_vb ~loc ~name ~name_expr:expr
800+
~hint_label:(Option.map ~f:(fun s -> [%expr [ [%e s] ]]) ident_label)
784801
in
785802
{
786803
vbs;
@@ -906,7 +923,7 @@ let translate (expr : expression) : result =
906923
@@ Location.error_extensionf ~loc
907924
"ppx_ocannl %%cd: write .grad.merge instead of .merge.grad";
908925
}
909-
| Code | Array | Value_of_tensor _ | Grad_of_tensor _ | Merge_grad _ ->
926+
| Function | Code | Array | Value_of_tensor _ | Grad_of_tensor _ | Merge_grad _ ->
910927
{
911928
res1 with
912929
typ = Array;
@@ -924,7 +941,7 @@ let translate (expr : expression) : result =
924941
typ = Value_of_tensor res1.expr;
925942
expr = [%expr [%e res1.expr].Tensor.value];
926943
}
927-
| Code ->
944+
| Function | Code ->
928945
{
929946
res1 with
930947
typ = Array;
@@ -942,7 +959,7 @@ let translate (expr : expression) : result =
942959
{ res1 with typ = Merge_value res1.expr; expr = [%expr [%e res1.expr].Tensor.value] }
943960
| Value_of_tensor t ->
944961
{ res1 with typ = Merge_value t; expr = [%expr [%e res1.expr].Tensor.value] }
945-
| Array | Code ->
962+
| Function | Array | Code ->
946963
{
947964
res1 with
948965
typ = Array;
@@ -1275,6 +1292,7 @@ let translate (expr : expression) : result =
12751292
let res = transl ~bad_pun_hints ~proj_in_scope body in
12761293
{
12771294
res with
1295+
typ = Function;
12781296
expr =
12791297
{ expr with pexp_desc = Pexp_function (args, constr, Pfunction_body res.expr) };
12801298
}
@@ -1286,6 +1304,7 @@ let translate (expr : expression) : result =
12861304
in
12871305
{
12881306
cases_result with
1307+
typ = Function;
12891308
expr =
12901309
{
12911310
expr with
@@ -1396,7 +1415,7 @@ let translate (expr : expression) : result =
13961415
transl ~proj_in_scope:false ~bad_pun_hints:(Set.empty (module String)) expr
13971416

13981417
let translate ?ident_label expr =
1399-
let res = translate expr in
1418+
let res = translate ?ident_label:(Option.map ~f:pat2string ident_label) expr in
14001419
let loc = res.expr.pexp_loc in
14011420
let expr = res.expr in
14021421
( res.vbs,

test/einsum/moons_demo_variant.expected

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,15 @@ n66 sgd_delta_w2 as sgd_delta_w2: Virt/15; single prec 1x16; mem in bytes: <not-
6464
n67 sgd_momentum_w2 as sgd_momentum_w2: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
6565
n68 0.0001 as n68: Virt/40; single prec 1; mem in bytes: <not-in-yet>
6666
n69 *. as n69: Virt/15; single prec 1x16; mem in bytes: <not-in-yet>
67-
n70 point as point: Host&shared/38039; single prec 2; mem in bytes: <not-in-yet>
67+
n70 point_mlp_result as point_mlp_result: Host&shared/38039; single prec 2; mem in bytes: <not-in-yet>
6868
n71 * as n71: Local/1046; single prec 16; mem in bytes: <not-in-yet>
6969
n72 grad_* as n71.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
7070
n73 + as n73: Virt/15; single prec 16; mem in bytes: <not-in-yet>
7171
n74 grad_+ as n73.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
7272
n75 relu as relu: Virt/15; single prec 16; mem in bytes: <not-in-yet>
7373
n76 grad_relu as relu.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
74-
n77 *_mlp_point as mlp_point: Host&stream/412410; single prec 1; mem in bytes: <not-in-yet>
75-
n78 grad_*_mlp_point as mlp_point.grad: unknown; single prec 1; mem in bytes: <not-in-yet>
74+
n77 *_mlp_point_mlp_result as mlp_point_mlp_result: Host&stream/412410; single prec 1; mem in bytes: <not-in-yet>
75+
n78 grad_*_mlp_point_mlp_result as mlp_point_mlp_result.grad: unknown; single prec 1; mem in bytes: <not-in-yet>
7676
Tnode: Finished printing headers.
77-
mlp_result's name: mlp_point
77+
mlp_result's name: mlp_point_mlp_result
7878
(mlp moons_input) name: mlp_moons_input

test/training/moons_demo.expected

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,5 +178,5 @@ Learning rate:
178178
├─────────┼────────────────────────────────────────────────────────────────────────────────────────────────────┤
179179
│ │0.00 4.79e+3│
180180
│ │ step │
181-
└─────────┴────────────────────────────────────────────────────────────────────────────────────────────────────┘mlp_result's name: mlp_point
181+
└─────────┴────────────────────────────────────────────────────────────────────────────────────────────────────┘mlp_result's name: mlp_point_mlp_result
182182
(mlp moons_input) name: mlp_moons_input

0 commit comments

Comments
 (0)