@@ -4,7 +4,7 @@ open Ppx_arrayjit.Ppx_helper
44open Ppx_shared
55module A = Ppxlib_ast. Ast_helper
66
7- let ndarray_op ~ ident_label ?axis_labels ?label expr =
7+ let ndarray_op ?axis_labels ?label expr =
88 let loc = expr.pexp_loc in
99 let values, batch_dims, output_dims, input_dims = ndarray_constant expr in
1010 let edims dims = Ast_builder.Default. elist ~loc dims in
@@ -14,15 +14,10 @@ let ndarray_op ~ident_label ?axis_labels ?label expr =
1414 | Some axis_labels , None -> [% expr NTDSL. ndarray ~axis_labels: [% e axis_labels]]
1515 | None , Some label -> [% expr NTDSL. ndarray ~label: [% e label]]
1616 | Some axis_labels , Some label ->
17- [% expr
18- NTDSL. ndarray
19- ~label: [% e opt_pat2string_list ~loc ident_label]
20- ~axis_labels: [% e axis_labels] ~label: [% e label]]
17+ [% expr NTDSL. ndarray ~axis_labels: [% e axis_labels] ~label: [% e label]]
2118 in
2219 [% expr
23- [% e op]
24- ~label: [% e opt_pat2string_list ~loc ident_label]
25- ~batch_dims: [% e edims batch_dims] ~input_dims: [% e edims input_dims]
20+ [% e op] ~batch_dims: [% e edims batch_dims] ~input_dims: [% e edims input_dims]
2621 ~output_dims: [% e edims output_dims] [% e values]]
2722
2823type expr_type =
@@ -414,9 +409,9 @@ let args_for ~loc = function
414409 [% expr false ],
415410 [% expr false ] )
416411
417- let translate ? ident_label (expr : expression ) : result =
412+ let translate (expr : expression ) : result =
418413 let punned = Hashtbl. create (module String ) in
419- let rec transl ~bad_pun_hints ? ident_label ~proj_in_scope (expr : expression ) : result =
414+ let rec transl ~bad_pun_hints ~proj_in_scope (expr : expression ) : result =
420415 let loc = expr.pexp_loc in
421416 let default_result =
422417 { vbs = no_vbs; typ = Tensor ; slot = Undet ; expr; array_opt_of_code = None }
@@ -425,16 +420,11 @@ let translate ?ident_label (expr : expression) : result =
425420 let process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ?projections ~proj_in_scope () =
426421 let initialize_neutral, accu_op = assignment_op accu_op in
427422 let setup_l =
428- setup_array ~punned ~bad_pun_hints ~is_lhs: true
429- @@ loop ?ident_label ~proj_in_scope: true lhs
423+ setup_array ~punned ~bad_pun_hints ~is_lhs: true @@ loop ~proj_in_scope: true lhs
430424 in
431425 let _, bin_op = binary_op bin_op in
432- let setup_r1 =
433- setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs1
434- in
435- let setup_r2 =
436- setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs2
437- in
426+ let setup_r1 = setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs1 in
427+ let setup_r2 = setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs2 in
438428 let initialize_neutral = if initialize_neutral then [% expr true ] else [% expr false ] in
439429 let projections =
440430 match projections with
@@ -493,9 +483,7 @@ let translate ?ident_label (expr : expression) : result =
493483 (* FIXME: I think this ignores the slot information here! Just assuming [projections] is
494484 as-should-be, but that's not consistent with omitting the projections arg (assuming it
495485 comes from the context). *)
496- let setup_l =
497- setup_array ~punned ~bad_pun_hints ~is_lhs: true @@ loop ?ident_label ~proj_in_scope lhs
498- in
486+ let setup_l = setup_array ~punned ~bad_pun_hints ~is_lhs: true @@ loop ~proj_in_scope lhs in
499487 let setup_r = setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs in
500488 let initialize_neutral = if initialize_neutral then [% expr true ] else [% expr false ] in
501489 let projections =
@@ -548,15 +536,9 @@ let translate ?ident_label (expr : expression) : result =
548536 in
549537 let process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic =
550538 let initialize_neutral, accu_op = assignment_op accu_op in
551- let setup_l =
552- setup_array ~punned ~bad_pun_hints ~is_lhs: true @@ loop ?ident_label ~proj_in_scope lhs
553- in
554- let setup_r1 =
555- setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs1
556- in
557- let setup_r2 =
558- setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs2
559- in
539+ let setup_l = setup_array ~punned ~bad_pun_hints ~is_lhs: true @@ loop ~proj_in_scope lhs in
540+ let setup_r1 = setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs1 in
541+ let setup_r2 = setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs2 in
560542 let initialize_neutral = if initialize_neutral then [% expr true ] else [% expr false ] in
561543 let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
562544 let t1_expr, rhs1_is_grad, rhs1_is_merge = args_for ~loc setup_r1 in
@@ -572,9 +554,7 @@ let translate ?ident_label (expr : expression) : result =
572554 in
573555 let process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic =
574556 let initialize_neutral, accu_op = assignment_op accu_op in
575- let setup_l =
576- setup_array ~punned ~bad_pun_hints ~is_lhs: true @@ loop ?ident_label ~proj_in_scope lhs
577- in
557+ let setup_l = setup_array ~punned ~bad_pun_hints ~is_lhs: true @@ loop ~proj_in_scope lhs in
578558 let setup_r = setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs in
579559 let initialize_neutral = if initialize_neutral then [% expr true ] else [% expr false ] in
580560 let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
@@ -589,31 +569,16 @@ let translate ?ident_label (expr : expression) : result =
589569 in
590570 match expr with
591571 | { pexp_desc = Pexp_constant (Pconst_float _ ); _ } ->
592- {
593- default_result with
594- expr = [% expr NTDSL. number ~label: [% e opt_pat2string_list ~loc ident_label] [% e expr]];
595- }
572+ { default_result with expr = [% expr NTDSL. number [% e expr]] }
596573 | { pexp_desc = Pexp_constant (Pconst_integer _ ); _ } ->
597- {
598- default_result with
599- expr =
600- [% expr
601- NTDSL. number ~label: [% e opt_pat2string_list ~loc ident_label] (Float. of_int [% e expr])];
602- }
574+ { default_result with expr = [% expr NTDSL. number (Float. of_int [% e expr])] }
603575 | [% expr
604576 [% e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
605577 [% e? { pexp_desc = Pexp_constant (Pconst_float _); _ } as f]] ->
606578 let axis =
607579 Ast_helper.Exp. constant ~loc: pexp_loc (Pconst_string (String. of_char ch, pexp_loc, None ))
608580 in
609- {
610- default_result with
611- expr =
612- [% expr
613- NTDSL. number
614- ~label: [% e opt_pat2string_list ~loc ident_label]
615- ~axis_label: [% e axis] [% e f]];
616- }
581+ { default_result with expr = [% expr NTDSL. number ~axis_label: [% e axis] [% e f]] }
617582 | [% expr
618583 [% e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
619584 [% e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
@@ -622,12 +587,7 @@ let translate ?ident_label (expr : expression) : result =
622587 in
623588 {
624589 default_result with
625- expr =
626- [% expr
627- NTDSL. number
628- ~label: [% e opt_pat2string_list ~loc ident_label]
629- ~axis_label: [% e axis]
630- (Float. of_int [% e i])];
590+ expr = [% expr NTDSL. number ~axis_label: [% e axis] (Float. of_int [% e i])];
631591 }
632592 | { pexp_desc = Pexp_constant (Pconst_string (name , str_loc , _ )); _ } ->
633593 {
@@ -637,7 +597,7 @@ let translate ?ident_label (expr : expression) : result =
637597 }
638598 | { pexp_desc = Pexp_array _; _ }
639599 | { pexp_desc = Pexp_construct ({ txt = Lident "::" ; _ } , _ ); _ } ->
640- { default_result with expr = ndarray_op ~ident_label expr }
600+ { default_result with expr = ndarray_op expr }
641601 | { pexp_desc = Pexp_ident { txt = Lident ("v" | "lhs" ); _ } ; _ } ->
642602 { default_result with typ = Array ; slot = LHS }
643603 | { pexp_desc = Pexp_ident { txt = Lident "g" ; _ } ; _ } ->
@@ -677,24 +637,11 @@ let translate ?ident_label (expr : expression) : result =
677637 {
678638 res1 with
679639 typ = Tensor ;
680- expr =
681- [% expr
682- NTDSL.O. ( **. )
683- ~label: [% e opt_pat2string_list ~loc ident_label]
684- [% e res1.expr]
685- (Float. of_int [% e i])];
640+ expr = [% expr NTDSL.O. ( **. ) [% e res1.expr] (Float. of_int [% e i])];
686641 }
687642 | [% expr [% e? expr1] **. [% e? expr2]] ->
688643 let res1 = loop ~proj_in_scope expr1 in
689- {
690- res1 with
691- typ = Tensor ;
692- expr =
693- [% expr
694- NTDSL.O. ( **. )
695- ~label: [% e opt_pat2string_list ~loc ident_label]
696- [% e res1.expr] [% e expr2]];
697- }
644+ { res1 with typ = Tensor ; expr = [% expr NTDSL.O. ( **. ) [% e res1.expr] [% e expr2]] }
698645 | [% expr
699646 [% e? expr1]
700647 *+ [% e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]
@@ -710,29 +657,17 @@ let translate ?ident_label (expr : expression) : result =
710657 vbs = reduce_vbss [ res1.vbs; res2.vbs ];
711658 typ = Tensor ;
712659 slot;
713- expr =
714- [% expr
715- NTDSL. einsum
716- ~label: [% e opt_pat2string_list ~loc ident_label]
717- [% e spec] [% e res1.expr] [% e res2.expr]];
660+ expr = [% expr NTDSL. einsum [% e spec] [% e res1.expr] [% e res2.expr]];
718661 array_opt_of_code = None ;
719662 }
720663 | [% expr
721664 [% e? expr1]
722665 ++ [% e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]]
723666 when String. contains spec_str '>' ->
724667 let res1 = loop ~proj_in_scope expr1 in
725- {
726- res1 with
727- typ = Tensor ;
728- expr =
729- [% expr
730- NTDSL. einsum1
731- ~label: [% e opt_pat2string_list ~loc ident_label]
732- [% e spec] [% e res1.expr]];
733- }
668+ { res1 with typ = Tensor ; expr = [% expr NTDSL. einsum1 [% e spec] [% e res1.expr]] }
734669 | [% expr [% e? expr1].grad] -> (
735- let res1 = loop ?ident_label ~proj_in_scope expr1 in
670+ let res1 = loop ~proj_in_scope expr1 in
736671 match res1.typ with
737672 | Unknown | Tensor | No_grad_tensor_intro _ ->
738673 {
@@ -758,7 +693,7 @@ let translate ?ident_label (expr : expression) : result =
758693 @@ Location. error_extensionf ~loc " ppx_ocannl %%cd: only tensors have a gradient" ;
759694 })
760695 | [% expr [% e? expr1].value] -> (
761- let res1 = loop ?ident_label ~proj_in_scope expr1 in
696+ let res1 = loop ~proj_in_scope expr1 in
762697 (* TODO: maybe this is too permissive? E.g. [t1.grad.value] is accepted. *)
763698 match res1.typ with
764699 | Unknown | Tensor | No_grad_tensor_intro _ ->
@@ -780,7 +715,7 @@ let translate ?ident_label (expr : expression) : result =
780715 }
781716 | Array | Value_of_tensor _ | Grad_of_tensor _ | Merge_value _ | Merge_grad _ -> res1)
782717 | [% expr [% e? expr1].merge] -> (
783- let res1 = loop ?ident_label ~proj_in_scope expr1 in
718+ let res1 = loop ~proj_in_scope expr1 in
784719 match res1.typ with
785720 | Unknown | Tensor | No_grad_tensor_intro _ ->
786721 { res1 with typ = Merge_value res1.expr; expr = [% expr [% e res1.expr].Tensor. value] }
@@ -901,7 +836,7 @@ let translate ?ident_label (expr : expression) : result =
901836 process_raw_unop ~accu_op ~lhs ~un_op: [% expr Arrayjit.Ops. Identity ] ~rhs
902837 ~logic: [% expr Shape. Pointwise_un ]
903838 | [% expr [% e? expr1] [% e? expr2] [% e? expr3]] ->
904- let res1 = loop ?ident_label ~proj_in_scope expr1 in
839+ let res1 = loop ~proj_in_scope expr1 in
905840 let res2 = loop ~proj_in_scope expr2 in
906841 let res3 = loop ~proj_in_scope expr3 in
907842 let slot =
@@ -918,7 +853,7 @@ let translate ?ident_label (expr : expression) : result =
918853 array_opt_of_code = None ;
919854 }
920855 | [% expr [% e? expr1] [% e? expr2]] ->
921- let res1 = loop ?ident_label ~proj_in_scope expr1 in
856+ let res1 = loop ~proj_in_scope expr1 in
922857 let res2 = loop ~proj_in_scope expr2 in
923858 let slot =
924859 Option. value ~default: Undet
@@ -940,7 +875,7 @@ let translate ?ident_label (expr : expression) : result =
940875 | _ -> false
941876 in
942877 let bad_pun_hints = Set. union bad_pun_hints @@ collect_pat_idents pat in
943- let res1 = transl ~bad_pun_hints ?ident_label ~proj_in_scope expr1 in
878+ let res1 = transl ~bad_pun_hints ~proj_in_scope expr1 in
944879 { res1 with expr = { expr with pexp_desc = Pexp_fun (arg_label, arg, pat, res1.expr) } }
945880 | [% expr
946881 while [% e? _test_expr] do
@@ -992,7 +927,7 @@ let translate ?ident_label (expr : expression) : result =
992927 | [% expr [% e? t].grad] -> [% expr Arrayjit.Tnode. debug_name [% e t].value ^ " .grad" ]
993928 | t -> [% expr Arrayjit.Tnode. debug_name [% e t].value])
994929 in
995- let res2 = loop ?ident_label ~proj_in_scope expr2 in
930+ let res2 = loop ~proj_in_scope expr2 in
996931 {
997932 res2 with
998933 expr =
@@ -1005,7 +940,7 @@ let translate ?ident_label (expr : expression) : result =
1005940 [% e? expr1];
1006941 [% e? expr2]] ->
1007942 let res1 = loop ~proj_in_scope expr1 in
1008- let res2 = loop ?ident_label ~proj_in_scope expr2 in
943+ let res2 = loop ~proj_in_scope expr2 in
1009944 {
1010945 vbs = reduce_vbss [ res1.vbs; res2.vbs ];
1011946 typ = Code ;
@@ -1014,8 +949,8 @@ let translate ?ident_label (expr : expression) : result =
1014949 array_opt_of_code = res2.array_opt_of_code;
1015950 }
1016951 | [% expr if [% e? expr1] then [% e? expr2] else [% e? expr3]] ->
1017- let res2 = loop ?ident_label ~proj_in_scope expr2 in
1018- let res3 = loop ?ident_label ~proj_in_scope expr3 in
952+ let res2 = loop ~proj_in_scope expr2 in
953+ let res3 = loop ~proj_in_scope expr3 in
1019954 let typ = if is_unknown res2.typ then res3.typ else res2.typ in
1020955 let slot =
1021956 Option. value ~default: Undet
@@ -1029,7 +964,7 @@ let translate ?ident_label (expr : expression) : result =
1029964 array_opt_of_code = None ;
1030965 }
1031966 | [% expr if [% e? expr1] then [% e? expr2]] ->
1032- let res2 = loop ?ident_label ~proj_in_scope expr2 in
967+ let res2 = loop ~proj_in_scope expr2 in
1033968 {
1034969 vbs = res2.vbs;
1035970 typ = Code ;
@@ -1041,7 +976,7 @@ let translate ?ident_label (expr : expression) : result =
1041976 let fields, cases =
1042977 List. unzip
1043978 @@ List. map cases ~f: (fun ({ pc_rhs; _ } as c ) ->
1044- let res = loop ?ident_label ~proj_in_scope pc_rhs in
979+ let res = loop ~proj_in_scope pc_rhs in
1045980 ((res.vbs, res.typ, res.slot), { c with pc_rhs = res.expr }))
1046981 in
1047982 let vbss, typs, slots = List. unzip3 fields in
@@ -1069,23 +1004,19 @@ let translate ?ident_label (expr : expression) : result =
10691004 (* let bindings = List.map bindings ~f:(fun binding -> {binding with pvb_expr=loop
10701005 binding.pvb_expr}) in {expr with pexp_desc=Pexp_let (recflag, bindings, loop body)} *)
10711006 | { pexp_desc = Pexp_open (decl , expr1 ); _ } ->
1072- let res1 = loop ?ident_label ~proj_in_scope expr1 in
1007+ let res1 = loop ~proj_in_scope expr1 in
10731008 { res1 with expr = { expr with pexp_desc = Pexp_open (decl, res1.expr) } }
10741009 | { pexp_desc = Pexp_letmodule (name , module_expr , expr1 ); _ } ->
1075- let res1 = loop ?ident_label ~proj_in_scope expr1 in
1010+ let res1 = loop ~proj_in_scope expr1 in
10761011 { res1 with expr = { expr with pexp_desc = Pexp_letmodule (name, module_expr, res1.expr) } }
10771012 | { pexp_desc = Pexp_ident { txt = Lident op_ident ; _ } ; _ } when is_operator op_ident ->
1078- {
1079- default_result with
1080- typ = Unknown ;
1081- expr = [% expr [% e expr] ~label: [% e opt_pat2string_list ~loc ident_label]];
1082- }
1013+ { default_result with typ = Unknown ; expr = [% expr [% e expr]] }
10831014 | _ -> { default_result with typ = Unknown }
10841015 in
1085- transl ?ident_label ~proj_in_scope: false ~bad_pun_hints: (Set. empty (module String )) expr
1016+ transl ~proj_in_scope: false ~bad_pun_hints: (Set. empty (module String )) expr
10861017
10871018let translate ?ident_label expr =
1088- let res = translate ?ident_label expr in
1019+ let res = translate expr in
10891020 let expr = res.expr in
10901021 let loc = res.expr.pexp_loc in
10911022 ( res.vbs,
0 commit comments