Skip to content

Commit fd8f53c

Browse files
committed
Untested: prevent %cd inline declarations with escaping label sources
1 parent 1565184 commit fd8f53c

File tree

2 files changed

+92
-42
lines changed

2 files changed

+92
-42
lines changed

lib/ppx_cd.ml

Lines changed: 69 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -226,17 +226,26 @@ let project_p_dims debug loc slot =
226226
"ppx_ocannl %%cd: insufficient slot filler information at %s %s" debug
227227
"(incorporate one of: v, v1, v2, g, g1, g2, lhs, rhs, rhs1, rhs2)"
228228

229-
let guess_pun_hint ~punned filler_typ filler =
229+
let guess_pun_hint ~punned ~bad_pun_hints filler_typ filler =
230230
let loc = filler.pexp_loc in
231231
let hint = [%expr [%e filler].Arrayjit.Tnode.label] in
232232
match (filler_typ, filler) with
233233
| Code, _ -> None
234+
| _, { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ } when Set.mem bad_pun_hints name ->
235+
None
234236
| Array, _ -> Some (hint, false)
235237
| (Tensor | Unknown), { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
236238
when Hashtbl.mem punned name ->
237239
Hashtbl.find punned name
238240
| (Tensor | Unknown), { pexp_desc = Pexp_ident _; _ } -> Some (hint, true)
239241
| (Tensor | Unknown), _ -> Some (hint, false)
242+
| ( ( Value_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
243+
| Grad_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
244+
| Merge_value { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
245+
| Merge_grad { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ } ),
246+
_ )
247+
when Set.mem bad_pun_hints name ->
248+
None
240249
| ( ( Value_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
241250
| Grad_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
242251
| Merge_value { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
@@ -249,7 +258,8 @@ let guess_pun_hint ~punned filler_typ filler =
249258
match t with { pexp_desc = Pexp_ident _; _ } -> Some (hint, true) | _ -> Some (hint, false))
250259
| No_grad_tensor_intro { name; _ }, _ -> Hashtbl.find punned name
251260

252-
let setup_array ~punned ~is_lhs { typ = filler_typ; slot; expr = filler; vbs; array_opt_of_code } =
261+
let setup_array ~punned ~bad_pun_hints ~is_lhs
262+
{ typ = filler_typ; slot; expr = filler; vbs; array_opt_of_code } =
253263
assert (Map.is_empty vbs);
254264
let loc = filler.pexp_loc in
255265
let opt_buffer tn =
@@ -259,7 +269,7 @@ let setup_array ~punned ~is_lhs { typ = filler_typ; slot; expr = filler; vbs; ar
259269
if is_lhs then opt_tn
260270
else [%expr Option.map [%e opt_tn] ~f:(fun tn -> Arrayjit.Assignments.Node tn)]
261271
in
262-
let pun_hint_tnode = guess_pun_hint ~punned filler_typ filler in
272+
let pun_hint_tnode = guess_pun_hint ~punned ~bad_pun_hints filler_typ filler in
263273
let default_setup =
264274
{
265275
vb = None;
@@ -406,19 +416,25 @@ let args_for ~loc = function
406416

407417
let translate ?ident_label (expr : expression) : result =
408418
let punned = Hashtbl.create (module String) in
409-
let rec transl ?ident_label ~proj_in_scope (expr : expression) : result =
419+
let rec transl ~bad_pun_hints ?ident_label ~proj_in_scope (expr : expression) : result =
410420
let loc = expr.pexp_loc in
411421
let default_result =
412422
{ vbs = no_vbs; typ = Tensor; slot = Undet; expr; array_opt_of_code = None }
413423
in
424+
let loop = transl ~bad_pun_hints in
414425
let process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ?projections ~proj_in_scope () =
415426
let initialize_neutral, accu_op = assignment_op accu_op in
416427
let setup_l =
417-
setup_array ~punned ~is_lhs:true @@ transl ?ident_label ~proj_in_scope:true lhs
428+
setup_array ~punned ~bad_pun_hints ~is_lhs:true
429+
@@ loop ?ident_label ~proj_in_scope:true lhs
418430
in
419431
let _, bin_op = binary_op bin_op in
420-
let setup_r1 = setup_array ~punned ~is_lhs:false @@ transl ~proj_in_scope rhs1 in
421-
let setup_r2 = setup_array ~punned ~is_lhs:false @@ transl ~proj_in_scope rhs2 in
432+
let setup_r1 =
433+
setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs1
434+
in
435+
let setup_r2 =
436+
setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs2
437+
in
422438
let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
423439
let projections =
424440
match projections with
@@ -477,8 +493,10 @@ let translate ?ident_label (expr : expression) : result =
477493
(* FIXME: I think this ignores the slot information here! Just assuming [projections] is
478494
as-should-be, but that's not consistent with omitting the projections arg (assuming it
479495
comes from the context). *)
480-
let setup_l = setup_array ~punned ~is_lhs:true @@ transl ?ident_label ~proj_in_scope lhs in
481-
let setup_r = setup_array ~punned ~is_lhs:false @@ transl ~proj_in_scope rhs in
496+
let setup_l =
497+
setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ?ident_label ~proj_in_scope lhs
498+
in
499+
let setup_r = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs in
482500
let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
483501
let projections =
484502
match projections with
@@ -530,9 +548,15 @@ let translate ?ident_label (expr : expression) : result =
530548
in
531549
let process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic =
532550
let initialize_neutral, accu_op = assignment_op accu_op in
533-
let setup_l = setup_array ~punned ~is_lhs:true @@ transl ?ident_label ~proj_in_scope lhs in
534-
let setup_r1 = setup_array ~punned ~is_lhs:false @@ transl ~proj_in_scope rhs1 in
535-
let setup_r2 = setup_array ~punned ~is_lhs:false @@ transl ~proj_in_scope rhs2 in
551+
let setup_l =
552+
setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ?ident_label ~proj_in_scope lhs
553+
in
554+
let setup_r1 =
555+
setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs1
556+
in
557+
let setup_r2 =
558+
setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs2
559+
in
536560
let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
537561
let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
538562
let t1_expr, rhs1_is_grad, rhs1_is_merge = args_for ~loc setup_r1 in
@@ -548,8 +572,10 @@ let translate ?ident_label (expr : expression) : result =
548572
in
549573
let process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic =
550574
let initialize_neutral, accu_op = assignment_op accu_op in
551-
let setup_l = setup_array ~punned ~is_lhs:true @@ transl ?ident_label ~proj_in_scope lhs in
552-
let setup_r = setup_array ~punned ~is_lhs:false @@ transl ~proj_in_scope rhs in
575+
let setup_l =
576+
setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ?ident_label ~proj_in_scope lhs
577+
in
578+
let setup_r = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs in
553579
let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
554580
let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
555581
let t1_expr, rhs_is_grad, rhs_is_merge = args_for ~loc setup_r in
@@ -647,7 +673,7 @@ let translate ?ident_label (expr : expression) : result =
647673
| [%expr [%e? expr1] **. [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
648674
(* FIXME: `**.` should take a tensor and require that it's a literal. *)
649675
(* We need to hardcode these two patterns to prevent the numbers from being converted to tensors. *)
650-
let res1 = transl ~proj_in_scope expr1 in
676+
let res1 = loop ~proj_in_scope expr1 in
651677
{
652678
res1 with
653679
typ = Tensor;
@@ -659,7 +685,7 @@ let translate ?ident_label (expr : expression) : result =
659685
(Float.of_int [%e i])];
660686
}
661687
| [%expr [%e? expr1] **. [%e? expr2]] ->
662-
let res1 = transl ~proj_in_scope expr1 in
688+
let res1 = loop ~proj_in_scope expr1 in
663689
{
664690
res1 with
665691
typ = Tensor;
@@ -674,8 +700,8 @@ let translate ?ident_label (expr : expression) : result =
674700
*+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]
675701
[%e? expr2]]
676702
when String.contains spec_str '>' ->
677-
let res1 = transl ~proj_in_scope expr1 in
678-
let res2 = transl ~proj_in_scope expr2 in
703+
let res1 = loop ~proj_in_scope expr1 in
704+
let res2 = loop ~proj_in_scope expr2 in
679705
let slot =
680706
Option.value ~default:Undet
681707
@@ List.find ~f:(function Undet -> false | _ -> true) [ res1.slot; res2.slot ]
@@ -695,7 +721,7 @@ let translate ?ident_label (expr : expression) : result =
695721
[%e? expr1]
696722
++ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]]
697723
when String.contains spec_str '>' ->
698-
let res1 = transl ~proj_in_scope expr1 in
724+
let res1 = loop ~proj_in_scope expr1 in
699725
{
700726
res1 with
701727
typ = Tensor;
@@ -706,7 +732,7 @@ let translate ?ident_label (expr : expression) : result =
706732
[%e spec] [%e res1.expr]];
707733
}
708734
| [%expr [%e? expr1].grad] -> (
709-
let res1 = transl ?ident_label ~proj_in_scope expr1 in
735+
let res1 = loop ?ident_label ~proj_in_scope expr1 in
710736
match res1.typ with
711737
| Unknown | Tensor | No_grad_tensor_intro _ ->
712738
{
@@ -732,7 +758,7 @@ let translate ?ident_label (expr : expression) : result =
732758
@@ Location.error_extensionf ~loc "ppx_ocannl %%cd: only tensors have a gradient";
733759
})
734760
| [%expr [%e? expr1].value] -> (
735-
let res1 = transl ?ident_label ~proj_in_scope expr1 in
761+
let res1 = loop ?ident_label ~proj_in_scope expr1 in
736762
(* TODO: maybe this is too permissive? E.g. [t1.grad.value] is accepted. *)
737763
match res1.typ with
738764
| Unknown | Tensor | No_grad_tensor_intro _ ->
@@ -754,7 +780,7 @@ let translate ?ident_label (expr : expression) : result =
754780
}
755781
| Array | Value_of_tensor _ | Grad_of_tensor _ | Merge_value _ | Merge_grad _ -> res1)
756782
| [%expr [%e? expr1].merge] -> (
757-
let res1 = transl ?ident_label ~proj_in_scope expr1 in
783+
let res1 = loop ?ident_label ~proj_in_scope expr1 in
758784
match res1.typ with
759785
| Unknown | Tensor | No_grad_tensor_intro _ ->
760786
{ res1 with typ = Merge_value res1.expr; expr = [%expr [%e res1.expr].Tensor.value] }
@@ -875,9 +901,9 @@ let translate ?ident_label (expr : expression) : result =
875901
process_raw_unop ~accu_op ~lhs ~un_op:[%expr Arrayjit.Ops.Identity] ~rhs
876902
~logic:[%expr Shape.Pointwise_un]
877903
| [%expr [%e? expr1] [%e? expr2] [%e? expr3]] ->
878-
let res1 = transl ?ident_label ~proj_in_scope expr1 in
879-
let res2 = transl ~proj_in_scope expr2 in
880-
let res3 = transl ~proj_in_scope expr3 in
904+
let res1 = loop ?ident_label ~proj_in_scope expr1 in
905+
let res2 = loop ~proj_in_scope expr2 in
906+
let res3 = loop ~proj_in_scope expr3 in
881907
let slot =
882908
Option.value ~default:Undet
883909
@@ List.find
@@ -892,8 +918,8 @@ let translate ?ident_label (expr : expression) : result =
892918
array_opt_of_code = None;
893919
}
894920
| [%expr [%e? expr1] [%e? expr2]] ->
895-
let res1 = transl ?ident_label ~proj_in_scope expr1 in
896-
let res2 = transl ~proj_in_scope expr2 in
921+
let res1 = loop ?ident_label ~proj_in_scope expr1 in
922+
let res2 = loop ~proj_in_scope expr2 in
897923
let slot =
898924
Option.value ~default:Undet
899925
@@ List.find ~f:(function Undet -> false | _ -> true) [ res1.slot; res2.slot ]
@@ -905,16 +931,17 @@ let translate ?ident_label (expr : expression) : result =
905931
expr = [%expr [%e res1.expr] [%e res2.expr]];
906932
array_opt_of_code = None;
907933
}
908-
| { pexp_desc = Pexp_fun ((arg_label : arg_label), arg, opt_val, expr1); _ } as expr ->
934+
| { pexp_desc = Pexp_fun ((arg_label : arg_label), arg, pat, expr1); _ } as expr ->
909935
let proj_in_scope =
910936
proj_in_scope
911937
||
912938
match arg_label with
913939
| (Labelled s | Optional s) when String.equal s "projections" -> true
914940
| _ -> false
915941
in
916-
let res1 = transl ?ident_label ~proj_in_scope expr1 in
917-
{ res1 with expr = { expr with pexp_desc = Pexp_fun (arg_label, arg, opt_val, res1.expr) } }
942+
let bad_pun_hints = Set.union bad_pun_hints @@ collect_pat_idents pat in
943+
let res1 = transl ~bad_pun_hints ?ident_label ~proj_in_scope expr1 in
944+
{ res1 with expr = { expr with pexp_desc = Pexp_fun (arg_label, arg, pat, res1.expr) } }
918945
| [%expr
919946
while [%e? _test_expr] do
920947
[%e? _body]
@@ -965,7 +992,7 @@ let translate ?ident_label (expr : expression) : result =
965992
| [%expr [%e? t].grad] -> [%expr Arrayjit.Tnode.debug_name [%e t].value ^ ".grad"]
966993
| t -> [%expr Arrayjit.Tnode.debug_name [%e t].value])
967994
in
968-
let res2 = transl ?ident_label ~proj_in_scope expr2 in
995+
let res2 = loop ?ident_label ~proj_in_scope expr2 in
969996
{
970997
res2 with
971998
expr =
@@ -977,8 +1004,8 @@ let translate ?ident_label (expr : expression) : result =
9771004
| [%expr
9781005
[%e? expr1];
9791006
[%e? expr2]] ->
980-
let res1 = transl ~proj_in_scope expr1 in
981-
let res2 = transl ?ident_label ~proj_in_scope expr2 in
1007+
let res1 = loop ~proj_in_scope expr1 in
1008+
let res2 = loop ?ident_label ~proj_in_scope expr2 in
9821009
{
9831010
vbs = reduce_vbss [ res1.vbs; res2.vbs ];
9841011
typ = Code;
@@ -987,8 +1014,8 @@ let translate ?ident_label (expr : expression) : result =
9871014
array_opt_of_code = res2.array_opt_of_code;
9881015
}
9891016
| [%expr if [%e? expr1] then [%e? expr2] else [%e? expr3]] ->
990-
let res2 = transl ?ident_label ~proj_in_scope expr2 in
991-
let res3 = transl ?ident_label ~proj_in_scope expr3 in
1017+
let res2 = loop ?ident_label ~proj_in_scope expr2 in
1018+
let res3 = loop ?ident_label ~proj_in_scope expr3 in
9921019
let typ = if is_unknown res2.typ then res3.typ else res2.typ in
9931020
let slot =
9941021
Option.value ~default:Undet
@@ -1002,7 +1029,7 @@ let translate ?ident_label (expr : expression) : result =
10021029
array_opt_of_code = None;
10031030
}
10041031
| [%expr if [%e? expr1] then [%e? expr2]] ->
1005-
let res2 = transl ?ident_label ~proj_in_scope expr2 in
1032+
let res2 = loop ?ident_label ~proj_in_scope expr2 in
10061033
{
10071034
vbs = res2.vbs;
10081035
typ = Code;
@@ -1014,7 +1041,7 @@ let translate ?ident_label (expr : expression) : result =
10141041
let fields, cases =
10151042
List.unzip
10161043
@@ List.map cases ~f:(fun ({ pc_rhs; _ } as c) ->
1017-
let res = transl ?ident_label ~proj_in_scope pc_rhs in
1044+
let res = loop ?ident_label ~proj_in_scope pc_rhs in
10181045
((res.vbs, res.typ, res.slot), { c with pc_rhs = res.expr }))
10191046
in
10201047
let vbss, typs, slots = List.unzip3 fields in
@@ -1039,13 +1066,13 @@ let translate ?ident_label (expr : expression) : result =
10391066
@@ Location.error_extensionf ~loc
10401067
"ppx_ocannl %%cd: let-in: local let-bindings not implemented yet";
10411068
}
1042-
(* let bindings = List.map bindings ~f:(fun binding -> {binding with pvb_expr=transl
1043-
binding.pvb_expr}) in {expr with pexp_desc=Pexp_let (recflag, bindings, transl body)} *)
1069+
(* let bindings = List.map bindings ~f:(fun binding -> {binding with pvb_expr=loop
1070+
binding.pvb_expr}) in {expr with pexp_desc=Pexp_let (recflag, bindings, loop body)} *)
10441071
| { pexp_desc = Pexp_open (decl, expr1); _ } ->
1045-
let res1 = transl ?ident_label ~proj_in_scope expr1 in
1072+
let res1 = loop ?ident_label ~proj_in_scope expr1 in
10461073
{ res1 with expr = { expr with pexp_desc = Pexp_open (decl, res1.expr) } }
10471074
| { pexp_desc = Pexp_letmodule (name, module_expr, expr1); _ } ->
1048-
let res1 = transl ?ident_label ~proj_in_scope expr1 in
1075+
let res1 = loop ?ident_label ~proj_in_scope expr1 in
10491076
{ res1 with expr = { expr with pexp_desc = Pexp_letmodule (name, module_expr, res1.expr) } }
10501077
| { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } when is_operator op_ident ->
10511078
{
@@ -1055,7 +1082,7 @@ let translate ?ident_label (expr : expression) : result =
10551082
}
10561083
| _ -> { default_result with typ = Unknown }
10571084
in
1058-
transl ?ident_label ~proj_in_scope:false expr
1085+
transl ?ident_label ~proj_in_scope:false ~bad_pun_hints:(Set.empty (module String)) expr
10591086

10601087
let translate ?ident_label expr =
10611088
let res = translate ?ident_label expr in

lib/ppx_shared.ml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,29 @@ let pat2string pat =
3030
in
3131
string_expr ~loc:pat.ppat_loc @@ loop pat
3232

33+
let collect_pat_idents pat =
34+
let one = Set.singleton (module String) in
35+
let none = Set.empty (module String) in
36+
let rec loop pat =
37+
let all pats = Set.union_list (module String) @@ List.map ~f:loop pats in
38+
match pat.ppat_desc with
39+
| Ppat_open (_, pat) | Ppat_lazy pat | Ppat_constraint (pat, _) -> loop pat
40+
| Ppat_alias (_, ident) -> one ident.txt
41+
| Ppat_var ident -> one ident.txt
42+
| Ppat_any -> none
43+
| Ppat_variant (_, None) -> none
44+
| Ppat_variant (_, Some pat) -> loop pat
45+
| Ppat_constant _ -> none
46+
| Ppat_tuple pats | Ppat_array pats -> all pats
47+
| Ppat_construct (_, None) -> none
48+
| Ppat_construct (_, Some (_, pat)) -> loop pat
49+
| Ppat_interval (_, _) -> none
50+
| Ppat_record (lpats, _) -> all @@ List.map ~f:snd lpats
51+
| Ppat_or (p1, p2) -> all [ p1; p2 ]
52+
| Ppat_type _ | Ppat_unpack _ | Ppat_exception _ | Ppat_extension _ -> none
53+
in
54+
loop pat
55+
3356
let expr2string_or_empty expr =
3457
let rec lident = function
3558
| Lident s -> s

0 commit comments

Comments
 (0)