@@ -226,17 +226,26 @@ let project_p_dims debug loc slot =
226226 " ppx_ocannl %%cd: insufficient slot filler information at %s %s" debug
227227 " (incorporate one of: v, v1, v2, g, g1, g2, lhs, rhs, rhs1, rhs2)"
228228
229- let guess_pun_hint ~punned filler_typ filler =
229+ let guess_pun_hint ~punned ~ bad_pun_hints filler_typ filler =
230230 let loc = filler.pexp_loc in
231231 let hint = [% expr [% e filler].Arrayjit.Tnode. label] in
232232 match (filler_typ, filler) with
233233 | Code , _ -> None
234+ | _ , { pexp_desc = Pexp_ident { txt = Lident name ; _ } ; _ } when Set. mem bad_pun_hints name ->
235+ None
234236 | Array , _ -> Some (hint, false )
235237 | (Tensor | Unknown ), { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
236238 when Hashtbl. mem punned name ->
237239 Hashtbl. find punned name
238240 | (Tensor | Unknown ), { pexp_desc = Pexp_ident _ ; _ } -> Some (hint, true )
239241 | (Tensor | Unknown ), _ -> Some (hint, false )
242+ | ( ( Value_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
243+ | Grad_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
244+ | Merge_value { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
245+ | Merge_grad { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ } ),
246+ _ )
247+ when Set. mem bad_pun_hints name ->
248+ None
240249 | ( ( Value_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
241250 | Grad_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
242251 | Merge_value { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
@@ -249,7 +258,8 @@ let guess_pun_hint ~punned filler_typ filler =
249258 match t with { pexp_desc = Pexp_ident _ ; _ } -> Some (hint, true ) | _ -> Some (hint, false ))
250259 | No_grad_tensor_intro { name; _ } , _ -> Hashtbl. find punned name
251260
252- let setup_array ~punned ~is_lhs { typ = filler_typ ; slot; expr = filler ; vbs; array_opt_of_code } =
261+ let setup_array ~punned ~bad_pun_hints ~is_lhs
262+ { typ = filler_typ ; slot; expr = filler ; vbs; array_opt_of_code } =
253263 assert (Map. is_empty vbs);
254264 let loc = filler.pexp_loc in
255265 let opt_buffer tn =
@@ -259,7 +269,7 @@ let setup_array ~punned ~is_lhs { typ = filler_typ; slot; expr = filler; vbs; ar
259269 if is_lhs then opt_tn
260270 else [% expr Option. map [% e opt_tn] ~f: (fun tn -> Arrayjit.Assignments. Node tn)]
261271 in
262- let pun_hint_tnode = guess_pun_hint ~punned filler_typ filler in
272+ let pun_hint_tnode = guess_pun_hint ~punned ~bad_pun_hints filler_typ filler in
263273 let default_setup =
264274 {
265275 vb = None ;
@@ -406,19 +416,25 @@ let args_for ~loc = function
406416
407417let translate ?ident_label (expr : expression ) : result =
408418 let punned = Hashtbl. create (module String ) in
409- let rec transl ?ident_label ~proj_in_scope (expr : expression ) : result =
419+ let rec transl ~ bad_pun_hints ?ident_label ~proj_in_scope (expr : expression ) : result =
410420 let loc = expr.pexp_loc in
411421 let default_result =
412422 { vbs = no_vbs; typ = Tensor ; slot = Undet ; expr; array_opt_of_code = None }
413423 in
424+ let loop = transl ~bad_pun_hints in
414425 let process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ?projections ~proj_in_scope () =
415426 let initialize_neutral, accu_op = assignment_op accu_op in
416427 let setup_l =
417- setup_array ~punned ~is_lhs: true @@ transl ?ident_label ~proj_in_scope: true lhs
428+ setup_array ~punned ~bad_pun_hints ~is_lhs: true
429+ @@ loop ?ident_label ~proj_in_scope: true lhs
418430 in
419431 let _, bin_op = binary_op bin_op in
420- let setup_r1 = setup_array ~punned ~is_lhs: false @@ transl ~proj_in_scope rhs1 in
421- let setup_r2 = setup_array ~punned ~is_lhs: false @@ transl ~proj_in_scope rhs2 in
432+ let setup_r1 =
433+ setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs1
434+ in
435+ let setup_r2 =
436+ setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs2
437+ in
422438 let initialize_neutral = if initialize_neutral then [% expr true ] else [% expr false ] in
423439 let projections =
424440 match projections with
@@ -477,8 +493,10 @@ let translate ?ident_label (expr : expression) : result =
477493 (* FIXME: I think this ignores the slot information here! Just assuming [projections] is
478494 as-should-be, but that's not consistent with omitting the projections arg (assuming it
479495 comes from the context). *)
480- let setup_l = setup_array ~punned ~is_lhs: true @@ transl ?ident_label ~proj_in_scope lhs in
481- let setup_r = setup_array ~punned ~is_lhs: false @@ transl ~proj_in_scope rhs in
496+ let setup_l =
497+ setup_array ~punned ~bad_pun_hints ~is_lhs: true @@ loop ?ident_label ~proj_in_scope lhs
498+ in
499+ let setup_r = setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs in
482500 let initialize_neutral = if initialize_neutral then [% expr true ] else [% expr false ] in
483501 let projections =
484502 match projections with
@@ -530,9 +548,15 @@ let translate ?ident_label (expr : expression) : result =
530548 in
531549 let process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic =
532550 let initialize_neutral, accu_op = assignment_op accu_op in
533- let setup_l = setup_array ~punned ~is_lhs: true @@ transl ?ident_label ~proj_in_scope lhs in
534- let setup_r1 = setup_array ~punned ~is_lhs: false @@ transl ~proj_in_scope rhs1 in
535- let setup_r2 = setup_array ~punned ~is_lhs: false @@ transl ~proj_in_scope rhs2 in
551+ let setup_l =
552+ setup_array ~punned ~bad_pun_hints ~is_lhs: true @@ loop ?ident_label ~proj_in_scope lhs
553+ in
554+ let setup_r1 =
555+ setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs1
556+ in
557+ let setup_r2 =
558+ setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs2
559+ in
536560 let initialize_neutral = if initialize_neutral then [% expr true ] else [% expr false ] in
537561 let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
538562 let t1_expr, rhs1_is_grad, rhs1_is_merge = args_for ~loc setup_r1 in
@@ -548,8 +572,10 @@ let translate ?ident_label (expr : expression) : result =
548572 in
549573 let process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic =
550574 let initialize_neutral, accu_op = assignment_op accu_op in
551- let setup_l = setup_array ~punned ~is_lhs: true @@ transl ?ident_label ~proj_in_scope lhs in
552- let setup_r = setup_array ~punned ~is_lhs: false @@ transl ~proj_in_scope rhs in
575+ let setup_l =
576+ setup_array ~punned ~bad_pun_hints ~is_lhs: true @@ loop ?ident_label ~proj_in_scope lhs
577+ in
578+ let setup_r = setup_array ~punned ~bad_pun_hints ~is_lhs: false @@ loop ~proj_in_scope rhs in
553579 let initialize_neutral = if initialize_neutral then [% expr true ] else [% expr false ] in
554580 let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
555581 let t1_expr, rhs_is_grad, rhs_is_merge = args_for ~loc setup_r in
@@ -647,7 +673,7 @@ let translate ?ident_label (expr : expression) : result =
647673 | [% expr [% e? expr1] **. [% e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
648674 (* FIXME: `**.` should take a tensor and require that it's a literal. *)
649675 (* We need to hardcode these two patterns to prevent the numbers from being converted to tensors. *)
650- let res1 = transl ~proj_in_scope expr1 in
676+ let res1 = loop ~proj_in_scope expr1 in
651677 {
652678 res1 with
653679 typ = Tensor ;
@@ -659,7 +685,7 @@ let translate ?ident_label (expr : expression) : result =
659685 (Float. of_int [% e i])];
660686 }
661687 | [% expr [% e? expr1] **. [% e? expr2]] ->
662- let res1 = transl ~proj_in_scope expr1 in
688+ let res1 = loop ~proj_in_scope expr1 in
663689 {
664690 res1 with
665691 typ = Tensor ;
@@ -674,8 +700,8 @@ let translate ?ident_label (expr : expression) : result =
674700 *+ [% e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]
675701 [% e? expr2]]
676702 when String. contains spec_str '>' ->
677- let res1 = transl ~proj_in_scope expr1 in
678- let res2 = transl ~proj_in_scope expr2 in
703+ let res1 = loop ~proj_in_scope expr1 in
704+ let res2 = loop ~proj_in_scope expr2 in
679705 let slot =
680706 Option. value ~default: Undet
681707 @@ List. find ~f: (function Undet -> false | _ -> true ) [ res1.slot; res2.slot ]
@@ -695,7 +721,7 @@ let translate ?ident_label (expr : expression) : result =
695721 [% e? expr1]
696722 ++ [% e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]]
697723 when String. contains spec_str '>' ->
698- let res1 = transl ~proj_in_scope expr1 in
724+ let res1 = loop ~proj_in_scope expr1 in
699725 {
700726 res1 with
701727 typ = Tensor ;
@@ -706,7 +732,7 @@ let translate ?ident_label (expr : expression) : result =
706732 [% e spec] [% e res1.expr]];
707733 }
708734 | [% expr [% e? expr1].grad] -> (
709- let res1 = transl ?ident_label ~proj_in_scope expr1 in
735+ let res1 = loop ?ident_label ~proj_in_scope expr1 in
710736 match res1.typ with
711737 | Unknown | Tensor | No_grad_tensor_intro _ ->
712738 {
@@ -732,7 +758,7 @@ let translate ?ident_label (expr : expression) : result =
732758 @@ Location. error_extensionf ~loc " ppx_ocannl %%cd: only tensors have a gradient" ;
733759 })
734760 | [% expr [% e? expr1].value] -> (
735- let res1 = transl ?ident_label ~proj_in_scope expr1 in
761+ let res1 = loop ?ident_label ~proj_in_scope expr1 in
736762 (* TODO: maybe this is too permissive? E.g. [t1.grad.value] is accepted. *)
737763 match res1.typ with
738764 | Unknown | Tensor | No_grad_tensor_intro _ ->
@@ -754,7 +780,7 @@ let translate ?ident_label (expr : expression) : result =
754780 }
755781 | Array | Value_of_tensor _ | Grad_of_tensor _ | Merge_value _ | Merge_grad _ -> res1)
756782 | [% expr [% e? expr1].merge] -> (
757- let res1 = transl ?ident_label ~proj_in_scope expr1 in
783+ let res1 = loop ?ident_label ~proj_in_scope expr1 in
758784 match res1.typ with
759785 | Unknown | Tensor | No_grad_tensor_intro _ ->
760786 { res1 with typ = Merge_value res1.expr; expr = [% expr [% e res1.expr].Tensor. value] }
@@ -875,9 +901,9 @@ let translate ?ident_label (expr : expression) : result =
875901 process_raw_unop ~accu_op ~lhs ~un_op: [% expr Arrayjit.Ops. Identity ] ~rhs
876902 ~logic: [% expr Shape. Pointwise_un ]
877903 | [% expr [% e? expr1] [% e? expr2] [% e? expr3]] ->
878- let res1 = transl ?ident_label ~proj_in_scope expr1 in
879- let res2 = transl ~proj_in_scope expr2 in
880- let res3 = transl ~proj_in_scope expr3 in
904+ let res1 = loop ?ident_label ~proj_in_scope expr1 in
905+ let res2 = loop ~proj_in_scope expr2 in
906+ let res3 = loop ~proj_in_scope expr3 in
881907 let slot =
882908 Option. value ~default: Undet
883909 @@ List. find
@@ -892,8 +918,8 @@ let translate ?ident_label (expr : expression) : result =
892918 array_opt_of_code = None ;
893919 }
894920 | [% expr [% e? expr1] [% e? expr2]] ->
895- let res1 = transl ?ident_label ~proj_in_scope expr1 in
896- let res2 = transl ~proj_in_scope expr2 in
921+ let res1 = loop ?ident_label ~proj_in_scope expr1 in
922+ let res2 = loop ~proj_in_scope expr2 in
897923 let slot =
898924 Option. value ~default: Undet
899925 @@ List. find ~f: (function Undet -> false | _ -> true ) [ res1.slot; res2.slot ]
@@ -905,16 +931,17 @@ let translate ?ident_label (expr : expression) : result =
905931 expr = [% expr [% e res1.expr] [% e res2.expr]];
906932 array_opt_of_code = None ;
907933 }
908- | { pexp_desc = Pexp_fun ((arg_label : arg_label ), arg , opt_val , expr1 ); _ } as expr ->
934+ | { pexp_desc = Pexp_fun ((arg_label : arg_label ), arg , pat , expr1 ); _ } as expr ->
909935 let proj_in_scope =
910936 proj_in_scope
911937 ||
912938 match arg_label with
913939 | (Labelled s | Optional s ) when String. equal s " projections" -> true
914940 | _ -> false
915941 in
916- let res1 = transl ?ident_label ~proj_in_scope expr1 in
917- { res1 with expr = { expr with pexp_desc = Pexp_fun (arg_label, arg, opt_val, res1.expr) } }
942+ let bad_pun_hints = Set. union bad_pun_hints @@ collect_pat_idents pat in
943+ let res1 = transl ~bad_pun_hints ?ident_label ~proj_in_scope expr1 in
944+ { res1 with expr = { expr with pexp_desc = Pexp_fun (arg_label, arg, pat, res1.expr) } }
918945 | [% expr
919946 while [% e? _test_expr] do
920947 [% e? _body]
@@ -965,7 +992,7 @@ let translate ?ident_label (expr : expression) : result =
965992 | [% expr [% e? t].grad] -> [% expr Arrayjit.Tnode. debug_name [% e t].value ^ " .grad" ]
966993 | t -> [% expr Arrayjit.Tnode. debug_name [% e t].value])
967994 in
968- let res2 = transl ?ident_label ~proj_in_scope expr2 in
995+ let res2 = loop ?ident_label ~proj_in_scope expr2 in
969996 {
970997 res2 with
971998 expr =
@@ -977,8 +1004,8 @@ let translate ?ident_label (expr : expression) : result =
9771004 | [% expr
9781005 [% e? expr1];
9791006 [% e? expr2]] ->
980- let res1 = transl ~proj_in_scope expr1 in
981- let res2 = transl ?ident_label ~proj_in_scope expr2 in
1007+ let res1 = loop ~proj_in_scope expr1 in
1008+ let res2 = loop ?ident_label ~proj_in_scope expr2 in
9821009 {
9831010 vbs = reduce_vbss [ res1.vbs; res2.vbs ];
9841011 typ = Code ;
@@ -987,8 +1014,8 @@ let translate ?ident_label (expr : expression) : result =
9871014 array_opt_of_code = res2.array_opt_of_code;
9881015 }
9891016 | [% expr if [% e? expr1] then [% e? expr2] else [% e? expr3]] ->
990- let res2 = transl ?ident_label ~proj_in_scope expr2 in
991- let res3 = transl ?ident_label ~proj_in_scope expr3 in
1017+ let res2 = loop ?ident_label ~proj_in_scope expr2 in
1018+ let res3 = loop ?ident_label ~proj_in_scope expr3 in
9921019 let typ = if is_unknown res2.typ then res3.typ else res2.typ in
9931020 let slot =
9941021 Option. value ~default: Undet
@@ -1002,7 +1029,7 @@ let translate ?ident_label (expr : expression) : result =
10021029 array_opt_of_code = None ;
10031030 }
10041031 | [% expr if [% e? expr1] then [% e? expr2]] ->
1005- let res2 = transl ?ident_label ~proj_in_scope expr2 in
1032+ let res2 = loop ?ident_label ~proj_in_scope expr2 in
10061033 {
10071034 vbs = res2.vbs;
10081035 typ = Code ;
@@ -1014,7 +1041,7 @@ let translate ?ident_label (expr : expression) : result =
10141041 let fields, cases =
10151042 List. unzip
10161043 @@ List. map cases ~f: (fun ({ pc_rhs; _ } as c ) ->
1017- let res = transl ?ident_label ~proj_in_scope pc_rhs in
1044+ let res = loop ?ident_label ~proj_in_scope pc_rhs in
10181045 ((res.vbs, res.typ, res.slot), { c with pc_rhs = res.expr }))
10191046 in
10201047 let vbss, typs, slots = List. unzip3 fields in
@@ -1039,13 +1066,13 @@ let translate ?ident_label (expr : expression) : result =
10391066 @@ Location. error_extensionf ~loc
10401067 " ppx_ocannl %%cd: let-in: local let-bindings not implemented yet" ;
10411068 }
1042- (* let bindings = List.map bindings ~f:(fun binding -> {binding with pvb_expr=transl
1043- binding.pvb_expr}) in {expr with pexp_desc=Pexp_let (recflag, bindings, transl body)} *)
1069+ (* let bindings = List.map bindings ~f:(fun binding -> {binding with pvb_expr=loop
1070+ binding.pvb_expr}) in {expr with pexp_desc=Pexp_let (recflag, bindings, loop body)} *)
10441071 | { pexp_desc = Pexp_open (decl , expr1 ); _ } ->
1045- let res1 = transl ?ident_label ~proj_in_scope expr1 in
1072+ let res1 = loop ?ident_label ~proj_in_scope expr1 in
10461073 { res1 with expr = { expr with pexp_desc = Pexp_open (decl, res1.expr) } }
10471074 | { pexp_desc = Pexp_letmodule (name , module_expr , expr1 ); _ } ->
1048- let res1 = transl ?ident_label ~proj_in_scope expr1 in
1075+ let res1 = loop ?ident_label ~proj_in_scope expr1 in
10491076 { res1 with expr = { expr with pexp_desc = Pexp_letmodule (name, module_expr, res1.expr) } }
10501077 | { pexp_desc = Pexp_ident { txt = Lident op_ident ; _ } ; _ } when is_operator op_ident ->
10511078 {
@@ -1055,7 +1082,7 @@ let translate ?ident_label (expr : expression) : result =
10551082 }
10561083 | _ -> { default_result with typ = Unknown }
10571084 in
1058- transl ?ident_label ~proj_in_scope: false expr
1085+ transl ?ident_label ~proj_in_scope: false ~bad_pun_hints: ( Set. empty ( module String )) expr
10591086
10601087let translate ?ident_label expr =
10611088 let res = translate ?ident_label expr in
0 commit comments