diff --git a/cpp2rust/converter/converter.cpp b/cpp2rust/converter/converter.cpp index c0e7f693..67d70818 100644 --- a/cpp2rust/converter/converter.cpp +++ b/cpp2rust/converter/converter.cpp @@ -3700,6 +3700,11 @@ std::string Converter::ConvertPlaceholder(clang::Expr *expr, clang::Expr *arg, if (ph_ctx.needs_object_receiver()) { Buffer buf(*this); + PushExplicitAutoref autoref( + *this, + ph_ctx.is_index_base + ? std::optional(ph_ctx.access == TranslationRule::Access::kWrite) + : std::nullopt); ConvertDeref(arg); return std::move(buf).str(); } @@ -3774,6 +3779,7 @@ std::string Converter::ConvertIRFragment( .maps_to_rust_ptr = Mapper::MapsToPointer(arg->getType()), .declared_in_rule_as_rust_ptr = Mapper::ParamIsPointer(GetCalleeOrExpr(expr), arg_idx), + .is_index_base = ph->is_index_base, }; result += ConvertPlaceholder(expr, arg, ph_ctx); } else if (auto *mc = diff --git a/cpp2rust/converter/converter.h b/cpp2rust/converter/converter.h index a5d8822f..83ce8850 100644 --- a/cpp2rust/converter/converter.h +++ b/cpp2rust/converter/converter.h @@ -185,6 +185,7 @@ class Converter : public clang::RecursiveASTVisitor { bool is_cpp_ptr; bool maps_to_rust_ptr; bool declared_in_rule_as_rust_ptr; + bool is_index_base; bool needs_materialization() const { return materialize_ctx && materialize_idx >= 0 && diff --git a/cpp2rust/converter/translation_rule.cpp b/cpp2rust/converter/translation_rule.cpp index 06f21921..27cb1ddb 100644 --- a/cpp2rust/converter/translation_rule.cpp +++ b/cpp2rust/converter/translation_rule.cpp @@ -45,7 +45,11 @@ Access ParseAccessJSON(llvm::StringRef value) { PlaceholderFragment ParsePlaceholderFragmentJSON(const llvm::json::Object &obj) { auto access = obj.getString("access"); - return {(unsigned)*obj.getInteger("arg"), ParseAccessJSON(*access)}; + return { + (unsigned)*obj.getInteger("arg"), + ParseAccessJSON(*access), + obj.getBoolean("is_index_base").value_or(false), + }; } std::vector ParseBodyFragmentsJSON(const llvm::json::Array &arr); diff --git a/cpp2rust/converter/translation_rule.h b/cpp2rust/converter/translation_rule.h index 91ac72f9..50ea45a4 100644 --- a/cpp2rust/converter/translation_rule.h +++ b/cpp2rust/converter/translation_rule.h @@ -26,6 +26,7 @@ enum class Access { kRead, kWrite, kMove }; struct PlaceholderFragment { unsigned n; // "a0", "a1", ... Access access; + bool is_index_base = false; void dump() const; }; diff --git a/rule-preprocessor/src/ir.rs b/rule-preprocessor/src/ir.rs index 9a0c5f58..362aeb07 100644 --- a/rule-preprocessor/src/ir.rs +++ b/rule-preprocessor/src/ir.rs @@ -63,17 +63,17 @@ pub struct FnIr { impl FnIr { /// Find the next unvisited placeholder for `param`, mark it visited, - /// and if it was "unknown", patch it with the given access. - /// Searches inside MethodCall bodies recursively. + /// and apply `patch` to it. Searches inside MethodCall bodies + /// recursively. pub fn resolve_next_param( &mut self, param: &str, - access: Access, visited: &mut HashMap, + patch: impl Fn(&mut PlaceholderInner), ) { let n = visited.entry(param.to_string()).or_insert(0); let nth = std::mem::replace(n, *n + 1); - resolve_nth_unknown(&mut self.body, param, access, nth); + resolve_nth_unknown(&mut self.body, param, nth, patch); } pub fn has_unknowns(&self) -> bool { @@ -139,6 +139,8 @@ pub enum Access { pub struct PlaceholderInner { pub arg: i32, pub access: Access, + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + pub is_index_base: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -157,15 +159,21 @@ impl BodyFragment { } } -/// Resolve the nth Unknown placeholder for `param` in a body fragment list. -fn resolve_nth_unknown(body: &mut [BodyFragment], param: &str, access: Access, nth: usize) { +/// Find the nth occurrence of `param` in a body fragment list and apply +/// `patch` to it. +fn resolve_nth_unknown( + body: &mut [BodyFragment], + param: &str, + nth: usize, + patch: impl Fn(&mut PlaceholderInner), +) { let mut count = 0; fn resolve( body: &mut [BodyFragment], param: &str, - access: Access, nth: usize, count: &mut usize, + patch: &impl Fn(&mut PlaceholderInner), ) -> bool { for frag in body { match frag { @@ -173,18 +181,16 @@ fn resolve_nth_unknown(body: &mut [BodyFragment], param: &str, access: Access, n if placeholder.arg == param[1..].parse().unwrap_or(0) => { if *count == nth { - if placeholder.access == Access::Unknown { - placeholder.access = access; - } + patch(placeholder); return true; } *count += 1; } BodyFragment::MethodCall { method_call } => { - if resolve(&mut method_call.receiver, param, access, nth, count) { + if resolve(&mut method_call.receiver, param, nth, count, patch) { return true; } - if resolve(&mut method_call.body, param, access, nth, count) { + if resolve(&mut method_call.body, param, nth, count, patch) { return true; } } @@ -193,7 +199,7 @@ fn resolve_nth_unknown(body: &mut [BodyFragment], param: &str, access: Access, n } false } - resolve(body, param, access, nth, &mut count); + resolve(body, param, nth, &mut count, &patch); } // A rule file's IR: mix of function rules (f1, f2, ...) and type rules (t1, t2, ...) @@ -211,7 +217,6 @@ pub type FileIr = BTreeMap; /// All IR for all rule files. pub struct RulesIR { pub all_ir: HashMap, - pub has_unknowns: bool, pub crate_root: PathBuf, } diff --git a/rule-preprocessor/src/semantic.rs b/rule-preprocessor/src/semantic.rs index 0fd4fd7a..6dbef216 100644 --- a/rule-preprocessor/src/semantic.rs +++ b/rule-preprocessor/src/semantic.rs @@ -10,10 +10,6 @@ pub struct SemanticAnalysis; impl SemanticAnalysis { pub fn run(ir: RulesIR) -> RulesIR { - if !ir.has_unknowns { - return ir; - } - let args = build_rustc_args(&ir.crate_root); let mut resolver = MethodResolver { ir }; @@ -131,7 +127,6 @@ impl MethodResolver { fn resolve_fn_decl<'tcx>(&mut self, tcx: rustc_middle::ty::TyCtxt<'tcx>, f: &FnDecl<'tcx>) { if let Some(file_ir) = self.ir.all_ir.get_mut(&f.source_file) && let Some(RuleIr::Fn(fn_ir)) = file_ir.get_mut(&f.name) - && fn_ir.has_unknowns() { f.resolve_unknowns(tcx, fn_ir); } @@ -218,11 +213,29 @@ struct AstVisitor<'a, 'tcx> { } impl<'a, 'tcx> AstVisitor<'a, 'tcx> { + fn visit_expr_as_index_base(&mut self, expr: &'tcx rustc_hir::Expr<'tcx>, context: Access) { + if let Some(param) = self.expr_as_decl_ref(expr) { + self.fn_ir + .resolve_next_param(¶m, &mut self.visited, |p| { + if p.access == Access::Unknown { + p.access = context; + } + p.is_index_base = true; + }); + return; + } + self.visit_expr(expr, context); + } + fn visit_expr(&mut self, expr: &'tcx rustc_hir::Expr<'tcx>, context: Access) { // Reached an argument used inside the rule body if let Some(param) = self.expr_as_decl_ref(expr) { self.fn_ir - .resolve_next_param(¶m, context, &mut self.visited); + .resolve_next_param(¶m, &mut self.visited, |p| { + if p.access == Access::Unknown { + p.access = context; + } + }); return; } @@ -311,7 +324,11 @@ impl<'a, 'tcx> AstVisitor<'a, 'tcx> { | rustc_hir::ExprKind::Repeat(e, _) => { self.visit_expr(e, context); } - rustc_hir::ExprKind::Index(a, b, _) | rustc_hir::ExprKind::Binary(_, a, b) => { + rustc_hir::ExprKind::Index(base, idx, _) => { + self.visit_expr_as_index_base(base, context); + self.visit_expr(idx, context); + } + rustc_hir::ExprKind::Binary(_, a, b) => { self.visit_expr(a, context); self.visit_expr(b, context); } diff --git a/rule-preprocessor/src/syntactic.rs b/rule-preprocessor/src/syntactic.rs index 6f3ba7d6..46f76073 100644 --- a/rule-preprocessor/src/syntactic.rs +++ b/rule-preprocessor/src/syntactic.rs @@ -44,7 +44,6 @@ impl SyntacticAnalysis { pub fn run(crate_root: &Path) -> RulesIR { let rule_files = Self::collect_rule_files(crate_root); let mut all_ir = HashMap::new(); - let mut has_unknowns = false; for rule_file in &rule_files { let source = std::fs::read_to_string(rule_file).unwrap(); @@ -56,10 +55,6 @@ impl SyntacticAnalysis { rule_file.display() ); - has_unknowns |= file_ir.values().any(|r| match r { - RuleIr::Fn(f) => f.has_unknowns(), - RuleIr::Type(_) => false, - }); let canonical = rule_file .canonicalize() .unwrap_or_else(|_| rule_file.clone()) @@ -70,7 +65,6 @@ impl SyntacticAnalysis { RulesIR { all_ir, - has_unknowns, crate_root: crate_root.to_path_buf(), } } @@ -196,6 +190,7 @@ impl<'a> FragmentCtx<'a> { placeholder: PlaceholderInner { arg: token.text()[1..].parse().unwrap_or(0), access, + is_index_base: false, }, }); return; diff --git a/rules/string/ir_refcount.json b/rules/string/ir_refcount.json index db87be05..1cb69c54 100644 --- a/rules/string/ir_refcount.json +++ b/rules/string/ir_refcount.json @@ -10,7 +10,8 @@ { "placeholder": { "arg": 0, - "access": "read" + "access": "read", + "is_index_base": true } }, { diff --git a/rules/string/ir_unsafe.json b/rules/string/ir_unsafe.json index 8f50170b..ceafca88 100644 --- a/rules/string/ir_unsafe.json +++ b/rules/string/ir_unsafe.json @@ -10,7 +10,8 @@ { "placeholder": { "arg": 0, - "access": "read" + "access": "read", + "is_index_base": true } }, { @@ -1194,7 +1195,8 @@ { "placeholder": { "arg": 0, - "access": "write" + "access": "write", + "is_index_base": true } }, { diff --git a/rules/vector/ir_refcount.json b/rules/vector/ir_refcount.json index 0f533932..5ba5b38b 100644 --- a/rules/vector/ir_refcount.json +++ b/rules/vector/ir_refcount.json @@ -2098,7 +2098,8 @@ { "placeholder": { "arg": 0, - "access": "read" + "access": "read", + "is_index_base": true } }, { diff --git a/rules/vector/ir_unsafe.json b/rules/vector/ir_unsafe.json index f3b3ed37..d4434d85 100644 --- a/rules/vector/ir_unsafe.json +++ b/rules/vector/ir_unsafe.json @@ -2775,7 +2775,8 @@ { "placeholder": { "arg": 0, - "access": "write" + "access": "write", + "is_index_base": true } }, { diff --git a/tests/unit/implicit_autoref.cpp b/tests/unit/implicit_autoref.cpp index 78f0b7b8..8d47cec1 100644 --- a/tests/unit/implicit_autoref.cpp +++ b/tests/unit/implicit_autoref.cpp @@ -5,6 +5,8 @@ struct Holder { std::vector v; }; +static void write_through(int *p) { *p = 42; } + int main() { std::vector v; v.push_back(10); @@ -25,5 +27,8 @@ int main() { assert((*p)[1] == 30); assert(b == 40); assert((*hp).v[1] == 60); + + write_through(&p->at(0)); + assert((*p)[0] == 42); return 0; } diff --git a/tests/unit/out/refcount/implicit_autoref.rs b/tests/unit/out/refcount/implicit_autoref.rs index 61e15839..43279334 100644 --- a/tests/unit/out/refcount/implicit_autoref.rs +++ b/tests/unit/out/refcount/implicit_autoref.rs @@ -19,6 +19,10 @@ impl Clone for Holder { } } impl ByteRepr for Holder {} +pub fn write_through_0(p: Ptr) { + let p: Value> = Rc::new(RefCell::new(p)); + (*p.borrow()).write(42); +} pub fn main() { std::process::exit(main_0()); } @@ -61,5 +65,16 @@ fn main_0() -> i32 { .read()) == 60) ); + ({ + let _p: Ptr = + (((*p.borrow()).to_strong().as_pointer() as Ptr).offset(0_u64 as isize)); + write_through_0(_p) + }); + assert!( + (((((*p.borrow()).to_strong().as_pointer()) as Ptr) + .offset(0_u64 as isize) + .read()) + == 42) + ); return 0; } diff --git a/tests/unit/out/unsafe/implicit_autoref.rs b/tests/unit/out/unsafe/implicit_autoref.rs index 709b4eca..f8ffeafb 100644 --- a/tests/unit/out/unsafe/implicit_autoref.rs +++ b/tests/unit/out/unsafe/implicit_autoref.rs @@ -11,6 +11,9 @@ use std::rc::Rc; pub struct Holder { pub v: Vec, } +pub unsafe fn write_through_0(mut p: *mut i32) { + (*p) = 42; +} pub fn main() { unsafe { std::process::exit(main_0() as i32); @@ -33,5 +36,10 @@ unsafe fn main_0() -> i32 { assert!((((&mut (*p))[(1_u64) as usize]) == (30))); assert!(((b) == (40))); assert!((((&mut (*hp)).v[(1_u64) as usize]) == (60))); + (unsafe { + let _p: *mut i32 = (&mut (&mut (*p))[0_u64 as usize]); + write_through_0(_p) + }); + assert!((((&mut (*p))[(0_u64) as usize]) == (42))); return 0; }