Skip to content

Commit d1a2868

Browse files
committed
Fixes #279: ident_label in ppx_cd is not helpful
1 parent 38e45b3 commit d1a2868

File tree

1 file changed

+39
-108
lines changed

1 file changed

+39
-108
lines changed

lib/ppx_cd.ml

Lines changed: 39 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ open Ppx_arrayjit.Ppx_helper
44
open Ppx_shared
55
module A = Ppxlib_ast.Ast_helper
66

7-
let ndarray_op ~ident_label ?axis_labels ?label expr =
7+
let ndarray_op ?axis_labels ?label expr =
88
let loc = expr.pexp_loc in
99
let values, batch_dims, output_dims, input_dims = ndarray_constant expr in
1010
let edims dims = Ast_builder.Default.elist ~loc dims in
@@ -14,15 +14,10 @@ let ndarray_op ~ident_label ?axis_labels ?label expr =
1414
| Some axis_labels, None -> [%expr NTDSL.ndarray ~axis_labels:[%e axis_labels]]
1515
| None, Some label -> [%expr NTDSL.ndarray ~label:[%e label]]
1616
| Some axis_labels, Some label ->
17-
[%expr
18-
NTDSL.ndarray
19-
~label:[%e opt_pat2string_list ~loc ident_label]
20-
~axis_labels:[%e axis_labels] ~label:[%e label]]
17+
[%expr NTDSL.ndarray ~axis_labels:[%e axis_labels] ~label:[%e label]]
2118
in
2219
[%expr
23-
[%e op]
24-
~label:[%e opt_pat2string_list ~loc ident_label]
25-
~batch_dims:[%e edims batch_dims] ~input_dims:[%e edims input_dims]
20+
[%e op] ~batch_dims:[%e edims batch_dims] ~input_dims:[%e edims input_dims]
2621
~output_dims:[%e edims output_dims] [%e values]]
2722

2823
type expr_type =
@@ -414,9 +409,9 @@ let args_for ~loc = function
414409
[%expr false],
415410
[%expr false] )
416411

417-
let translate ?ident_label (expr : expression) : result =
412+
let translate (expr : expression) : result =
418413
let punned = Hashtbl.create (module String) in
419-
let rec transl ~bad_pun_hints ?ident_label ~proj_in_scope (expr : expression) : result =
414+
let rec transl ~bad_pun_hints ~proj_in_scope (expr : expression) : result =
420415
let loc = expr.pexp_loc in
421416
let default_result =
422417
{ vbs = no_vbs; typ = Tensor; slot = Undet; expr; array_opt_of_code = None }
@@ -425,16 +420,11 @@ let translate ?ident_label (expr : expression) : result =
425420
let process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ?projections ~proj_in_scope () =
426421
let initialize_neutral, accu_op = assignment_op accu_op in
427422
let setup_l =
428-
setup_array ~punned ~bad_pun_hints ~is_lhs:true
429-
@@ loop ?ident_label ~proj_in_scope:true lhs
423+
setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ~proj_in_scope:true lhs
430424
in
431425
let _, bin_op = binary_op bin_op in
432-
let setup_r1 =
433-
setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs1
434-
in
435-
let setup_r2 =
436-
setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs2
437-
in
426+
let setup_r1 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs1 in
427+
let setup_r2 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs2 in
438428
let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
439429
let projections =
440430
match projections with
@@ -493,9 +483,7 @@ let translate ?ident_label (expr : expression) : result =
493483
(* FIXME: I think this ignores the slot information here! Just assuming [projections] is
494484
as-should-be, but that's not consistent with omitting the projections arg (assuming it
495485
comes from the context). *)
496-
let setup_l =
497-
setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ?ident_label ~proj_in_scope lhs
498-
in
486+
let setup_l = setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ~proj_in_scope lhs in
499487
let setup_r = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs in
500488
let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
501489
let projections =
@@ -548,15 +536,9 @@ let translate ?ident_label (expr : expression) : result =
548536
in
549537
let process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic =
550538
let initialize_neutral, accu_op = assignment_op accu_op in
551-
let setup_l =
552-
setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ?ident_label ~proj_in_scope lhs
553-
in
554-
let setup_r1 =
555-
setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs1
556-
in
557-
let setup_r2 =
558-
setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs2
559-
in
539+
let setup_l = setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ~proj_in_scope lhs in
540+
let setup_r1 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs1 in
541+
let setup_r2 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs2 in
560542
let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
561543
let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
562544
let t1_expr, rhs1_is_grad, rhs1_is_merge = args_for ~loc setup_r1 in
@@ -572,9 +554,7 @@ let translate ?ident_label (expr : expression) : result =
572554
in
573555
let process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic =
574556
let initialize_neutral, accu_op = assignment_op accu_op in
575-
let setup_l =
576-
setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ?ident_label ~proj_in_scope lhs
577-
in
557+
let setup_l = setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ~proj_in_scope lhs in
578558
let setup_r = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs in
579559
let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
580560
let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
@@ -589,31 +569,16 @@ let translate ?ident_label (expr : expression) : result =
589569
in
590570
match expr with
591571
| { pexp_desc = Pexp_constant (Pconst_float _); _ } ->
592-
{
593-
default_result with
594-
expr = [%expr NTDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] [%e expr]];
595-
}
572+
{ default_result with expr = [%expr NTDSL.number [%e expr]] }
596573
| { pexp_desc = Pexp_constant (Pconst_integer _); _ } ->
597-
{
598-
default_result with
599-
expr =
600-
[%expr
601-
NTDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] (Float.of_int [%e expr])];
602-
}
574+
{ default_result with expr = [%expr NTDSL.number (Float.of_int [%e expr])] }
603575
| [%expr
604576
[%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
605577
[%e? { pexp_desc = Pexp_constant (Pconst_float _); _ } as f]] ->
606578
let axis =
607579
Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None))
608580
in
609-
{
610-
default_result with
611-
expr =
612-
[%expr
613-
NTDSL.number
614-
~label:[%e opt_pat2string_list ~loc ident_label]
615-
~axis_label:[%e axis] [%e f]];
616-
}
581+
{ default_result with expr = [%expr NTDSL.number ~axis_label:[%e axis] [%e f]] }
617582
| [%expr
618583
[%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
619584
[%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
@@ -622,12 +587,7 @@ let translate ?ident_label (expr : expression) : result =
622587
in
623588
{
624589
default_result with
625-
expr =
626-
[%expr
627-
NTDSL.number
628-
~label:[%e opt_pat2string_list ~loc ident_label]
629-
~axis_label:[%e axis]
630-
(Float.of_int [%e i])];
590+
expr = [%expr NTDSL.number ~axis_label:[%e axis] (Float.of_int [%e i])];
631591
}
632592
| { pexp_desc = Pexp_constant (Pconst_string (name, str_loc, _)); _ } ->
633593
{
@@ -637,7 +597,7 @@ let translate ?ident_label (expr : expression) : result =
637597
}
638598
| { pexp_desc = Pexp_array _; _ }
639599
| { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } ->
640-
{ default_result with expr = ndarray_op ~ident_label expr }
600+
{ default_result with expr = ndarray_op expr }
641601
| { pexp_desc = Pexp_ident { txt = Lident ("v" | "lhs"); _ }; _ } ->
642602
{ default_result with typ = Array; slot = LHS }
643603
| { pexp_desc = Pexp_ident { txt = Lident "g"; _ }; _ } ->
@@ -677,24 +637,11 @@ let translate ?ident_label (expr : expression) : result =
677637
{
678638
res1 with
679639
typ = Tensor;
680-
expr =
681-
[%expr
682-
NTDSL.O.( **. )
683-
~label:[%e opt_pat2string_list ~loc ident_label]
684-
[%e res1.expr]
685-
(Float.of_int [%e i])];
640+
expr = [%expr NTDSL.O.( **. ) [%e res1.expr] (Float.of_int [%e i])];
686641
}
687642
| [%expr [%e? expr1] **. [%e? expr2]] ->
688643
let res1 = loop ~proj_in_scope expr1 in
689-
{
690-
res1 with
691-
typ = Tensor;
692-
expr =
693-
[%expr
694-
NTDSL.O.( **. )
695-
~label:[%e opt_pat2string_list ~loc ident_label]
696-
[%e res1.expr] [%e expr2]];
697-
}
644+
{ res1 with typ = Tensor; expr = [%expr NTDSL.O.( **. ) [%e res1.expr] [%e expr2]] }
698645
| [%expr
699646
[%e? expr1]
700647
*+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]
@@ -710,29 +657,17 @@ let translate ?ident_label (expr : expression) : result =
710657
vbs = reduce_vbss [ res1.vbs; res2.vbs ];
711658
typ = Tensor;
712659
slot;
713-
expr =
714-
[%expr
715-
NTDSL.einsum
716-
~label:[%e opt_pat2string_list ~loc ident_label]
717-
[%e spec] [%e res1.expr] [%e res2.expr]];
660+
expr = [%expr NTDSL.einsum [%e spec] [%e res1.expr] [%e res2.expr]];
718661
array_opt_of_code = None;
719662
}
720663
| [%expr
721664
[%e? expr1]
722665
++ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]]
723666
when String.contains spec_str '>' ->
724667
let res1 = loop ~proj_in_scope expr1 in
725-
{
726-
res1 with
727-
typ = Tensor;
728-
expr =
729-
[%expr
730-
NTDSL.einsum1
731-
~label:[%e opt_pat2string_list ~loc ident_label]
732-
[%e spec] [%e res1.expr]];
733-
}
668+
{ res1 with typ = Tensor; expr = [%expr NTDSL.einsum1 [%e spec] [%e res1.expr]] }
734669
| [%expr [%e? expr1].grad] -> (
735-
let res1 = loop ?ident_label ~proj_in_scope expr1 in
670+
let res1 = loop ~proj_in_scope expr1 in
736671
match res1.typ with
737672
| Unknown | Tensor | No_grad_tensor_intro _ ->
738673
{
@@ -758,7 +693,7 @@ let translate ?ident_label (expr : expression) : result =
758693
@@ Location.error_extensionf ~loc "ppx_ocannl %%cd: only tensors have a gradient";
759694
})
760695
| [%expr [%e? expr1].value] -> (
761-
let res1 = loop ?ident_label ~proj_in_scope expr1 in
696+
let res1 = loop ~proj_in_scope expr1 in
762697
(* TODO: maybe this is too permissive? E.g. [t1.grad.value] is accepted. *)
763698
match res1.typ with
764699
| Unknown | Tensor | No_grad_tensor_intro _ ->
@@ -780,7 +715,7 @@ let translate ?ident_label (expr : expression) : result =
780715
}
781716
| Array | Value_of_tensor _ | Grad_of_tensor _ | Merge_value _ | Merge_grad _ -> res1)
782717
| [%expr [%e? expr1].merge] -> (
783-
let res1 = loop ?ident_label ~proj_in_scope expr1 in
718+
let res1 = loop ~proj_in_scope expr1 in
784719
match res1.typ with
785720
| Unknown | Tensor | No_grad_tensor_intro _ ->
786721
{ res1 with typ = Merge_value res1.expr; expr = [%expr [%e res1.expr].Tensor.value] }
@@ -901,7 +836,7 @@ let translate ?ident_label (expr : expression) : result =
901836
process_raw_unop ~accu_op ~lhs ~un_op:[%expr Arrayjit.Ops.Identity] ~rhs
902837
~logic:[%expr Shape.Pointwise_un]
903838
| [%expr [%e? expr1] [%e? expr2] [%e? expr3]] ->
904-
let res1 = loop ?ident_label ~proj_in_scope expr1 in
839+
let res1 = loop ~proj_in_scope expr1 in
905840
let res2 = loop ~proj_in_scope expr2 in
906841
let res3 = loop ~proj_in_scope expr3 in
907842
let slot =
@@ -918,7 +853,7 @@ let translate ?ident_label (expr : expression) : result =
918853
array_opt_of_code = None;
919854
}
920855
| [%expr [%e? expr1] [%e? expr2]] ->
921-
let res1 = loop ?ident_label ~proj_in_scope expr1 in
856+
let res1 = loop ~proj_in_scope expr1 in
922857
let res2 = loop ~proj_in_scope expr2 in
923858
let slot =
924859
Option.value ~default:Undet
@@ -940,7 +875,7 @@ let translate ?ident_label (expr : expression) : result =
940875
| _ -> false
941876
in
942877
let bad_pun_hints = Set.union bad_pun_hints @@ collect_pat_idents pat in
943-
let res1 = transl ~bad_pun_hints ?ident_label ~proj_in_scope expr1 in
878+
let res1 = transl ~bad_pun_hints ~proj_in_scope expr1 in
944879
{ res1 with expr = { expr with pexp_desc = Pexp_fun (arg_label, arg, pat, res1.expr) } }
945880
| [%expr
946881
while [%e? _test_expr] do
@@ -992,7 +927,7 @@ let translate ?ident_label (expr : expression) : result =
992927
| [%expr [%e? t].grad] -> [%expr Arrayjit.Tnode.debug_name [%e t].value ^ ".grad"]
993928
| t -> [%expr Arrayjit.Tnode.debug_name [%e t].value])
994929
in
995-
let res2 = loop ?ident_label ~proj_in_scope expr2 in
930+
let res2 = loop ~proj_in_scope expr2 in
996931
{
997932
res2 with
998933
expr =
@@ -1005,7 +940,7 @@ let translate ?ident_label (expr : expression) : result =
1005940
[%e? expr1];
1006941
[%e? expr2]] ->
1007942
let res1 = loop ~proj_in_scope expr1 in
1008-
let res2 = loop ?ident_label ~proj_in_scope expr2 in
943+
let res2 = loop ~proj_in_scope expr2 in
1009944
{
1010945
vbs = reduce_vbss [ res1.vbs; res2.vbs ];
1011946
typ = Code;
@@ -1014,8 +949,8 @@ let translate ?ident_label (expr : expression) : result =
1014949
array_opt_of_code = res2.array_opt_of_code;
1015950
}
1016951
| [%expr if [%e? expr1] then [%e? expr2] else [%e? expr3]] ->
1017-
let res2 = loop ?ident_label ~proj_in_scope expr2 in
1018-
let res3 = loop ?ident_label ~proj_in_scope expr3 in
952+
let res2 = loop ~proj_in_scope expr2 in
953+
let res3 = loop ~proj_in_scope expr3 in
1019954
let typ = if is_unknown res2.typ then res3.typ else res2.typ in
1020955
let slot =
1021956
Option.value ~default:Undet
@@ -1029,7 +964,7 @@ let translate ?ident_label (expr : expression) : result =
1029964
array_opt_of_code = None;
1030965
}
1031966
| [%expr if [%e? expr1] then [%e? expr2]] ->
1032-
let res2 = loop ?ident_label ~proj_in_scope expr2 in
967+
let res2 = loop ~proj_in_scope expr2 in
1033968
{
1034969
vbs = res2.vbs;
1035970
typ = Code;
@@ -1041,7 +976,7 @@ let translate ?ident_label (expr : expression) : result =
1041976
let fields, cases =
1042977
List.unzip
1043978
@@ List.map cases ~f:(fun ({ pc_rhs; _ } as c) ->
1044-
let res = loop ?ident_label ~proj_in_scope pc_rhs in
979+
let res = loop ~proj_in_scope pc_rhs in
1045980
((res.vbs, res.typ, res.slot), { c with pc_rhs = res.expr }))
1046981
in
1047982
let vbss, typs, slots = List.unzip3 fields in
@@ -1069,23 +1004,19 @@ let translate ?ident_label (expr : expression) : result =
10691004
(* let bindings = List.map bindings ~f:(fun binding -> {binding with pvb_expr=loop
10701005
binding.pvb_expr}) in {expr with pexp_desc=Pexp_let (recflag, bindings, loop body)} *)
10711006
| { pexp_desc = Pexp_open (decl, expr1); _ } ->
1072-
let res1 = loop ?ident_label ~proj_in_scope expr1 in
1007+
let res1 = loop ~proj_in_scope expr1 in
10731008
{ res1 with expr = { expr with pexp_desc = Pexp_open (decl, res1.expr) } }
10741009
| { pexp_desc = Pexp_letmodule (name, module_expr, expr1); _ } ->
1075-
let res1 = loop ?ident_label ~proj_in_scope expr1 in
1010+
let res1 = loop ~proj_in_scope expr1 in
10761011
{ res1 with expr = { expr with pexp_desc = Pexp_letmodule (name, module_expr, res1.expr) } }
10771012
| { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } when is_operator op_ident ->
1078-
{
1079-
default_result with
1080-
typ = Unknown;
1081-
expr = [%expr [%e expr] ~label:[%e opt_pat2string_list ~loc ident_label]];
1082-
}
1013+
{ default_result with typ = Unknown; expr = [%expr [%e expr]] }
10831014
| _ -> { default_result with typ = Unknown }
10841015
in
1085-
transl ?ident_label ~proj_in_scope:false ~bad_pun_hints:(Set.empty (module String)) expr
1016+
transl ~proj_in_scope:false ~bad_pun_hints:(Set.empty (module String)) expr
10861017

10871018
let translate ?ident_label expr =
1088-
let res = translate ?ident_label expr in
1019+
let res = translate expr in
10891020
let expr = res.expr in
10901021
let loc = res.expr.pexp_loc in
10911022
( res.vbs,

0 commit comments

Comments
 (0)