diff --git a/cpp2rust/converter/converter.cpp b/cpp2rust/converter/converter.cpp index c0e7f693..12d6898d 100644 --- a/cpp2rust/converter/converter.cpp +++ b/cpp2rust/converter/converter.cpp @@ -2997,12 +2997,64 @@ bool Converter::VisitCXXStdInitializerListExpr( return false; } +std::string Converter::GetArrayDefaultAsString(clang::QualType qual_type) { + if (auto *array_type = clang::dyn_cast(qual_type)) { + auto size_as_string = GetNumAsString(array_type->getSize()); + auto element_type = array_type->getElementType(); + auto element_type_as_string = GetDefaultAsString(element_type); + return std::format("[{}; {}]", element_type_as_string, + size_as_string.c_str()); + } + if (auto *array_type = + clang::dyn_cast(qual_type)) { + return GetDefaultAsString(array_type->getElementType()); + } + if (Mapper::ToString(qual_type).contains("std::array")) { + assert(GetTemplateArgs(qual_type).has_value()); + auto template_args = *GetTemplateArgs(qual_type); + assert(template_args.size() == 2); + auto array_size = template_args[1]; + unsigned size = 0; + switch (array_size.getKind()) { + case clang::TemplateArgument::Expression: { + auto array_size_expr = array_size.getAsExpr(); + assert(array_size_expr && !array_size_expr->isValueDependent()); + clang::Expr::EvalResult result; + ENSURE(array_size_expr->EvaluateAsInt(result, ctx_)); + size = result.Val.getInt().getZExtValue(); + break; + } + case clang::TemplateArgument::Integral: { + size = array_size.getAsIntegral().getZExtValue(); + break; + } + default: + assert(0 && "Unsupported array size kind"); + break; + } + return std::format( + "std::array::from_fn::<_, {}, _>(|_| Default::default()).to_vec()", + size); + } + return {}; +} + std::string Converter::GetDefaultAsString(clang::QualType qual_type) { if (IsVaListType(qual_type)) { computed_expr_type_ = ComputedExprType::FreshValue; return "VaList::default()"; } + if (auto arr = GetArrayDefaultAsString(qual_type); !arr.empty()) { + computed_expr_type_ = ComputedExprType::FreshValue; + return arr; + } + + if (auto init = Mapper::MapInitializer(qual_type); !init.empty()) { + computed_expr_type_ = ComputedExprType::FreshValue; + return init; + } + if (qual_type->isPointerType()) { auto pointee = qual_type->getPointeeType(); if (pointee->isFunctionType()) { @@ -3014,54 +3066,7 @@ std::string Converter::GetDefaultAsString(clang::QualType qual_type) { } computed_expr_type_ = ComputedExprType::FreshValue; - - if (auto *array_type = clang::dyn_cast(qual_type)) { - auto size_as_string = GetNumAsString(array_type->getSize()); - auto element_type = array_type->getElementType(); - auto element_type_as_string = GetDefaultAsString(element_type); - return std::format("[{}; {}]", element_type_as_string, - size_as_string.c_str()); - } else if (auto *array_type = - clang::dyn_cast(qual_type)) { - return GetDefaultAsString(array_type->getElementType()); - } else { - auto qual_type_str = Mapper::ToString(qual_type); - if (qual_type_str == "struct std::pair") { - auto template_args = *GetTemplateArgs(qual_type); - auto first_type = template_args[0].getAsType(); - auto second_type = template_args[1].getAsType(); - return std::format("({}, {})", GetDefaultAsString(first_type), - GetDefaultAsString(second_type)); - } else if (qual_type_str.contains("std::array")) { - assert(GetTemplateArgs(qual_type).has_value()); - auto template_args = *GetTemplateArgs(qual_type); - assert(template_args.size() == 2); - auto array_size = template_args[1]; - unsigned size = 0; - switch (array_size.getKind()) { - case clang::TemplateArgument::Expression: { - auto array_size_expr = array_size.getAsExpr(); - assert(array_size_expr && !array_size_expr->isValueDependent()); - clang::Expr::EvalResult result; - ENSURE(array_size_expr->EvaluateAsInt(result, ctx_)); - size = result.Val.getInt().getZExtValue(); - break; - } - case clang::TemplateArgument::Integral: { - size = array_size.getAsIntegral().getZExtValue(); - break; - } - default: - assert(0 && "Unsupported array size kind"); - break; - } - return std::format( - "std::array::from_fn::<_, {}, _>(|_| Default::default()).to_vec()", - size); - } else { - return GetDefaultAsStringFallback(qual_type); - } - } + return GetDefaultAsStringFallback(qual_type); } std::string Converter::GetDefaultAsStringFallback(clang::QualType qual_type) { diff --git a/cpp2rust/converter/converter.h b/cpp2rust/converter/converter.h index a5d8822f..38b773d6 100644 --- a/cpp2rust/converter/converter.h +++ b/cpp2rust/converter/converter.h @@ -400,6 +400,8 @@ class Converter : public clang::RecursiveASTVisitor { virtual std::string GetDefaultAsString(clang::QualType qual_type); + virtual std::string GetArrayDefaultAsString(clang::QualType qual_type); + virtual std::string GetDefaultAsStringFallback(clang::QualType qual_type); virtual std::string ConvertVarDefaultInit(clang::QualType qual_type); diff --git a/cpp2rust/converter/mapper.cpp b/cpp2rust/converter/mapper.cpp index a0c5ae8c..17c47b1c 100644 --- a/cpp2rust/converter/mapper.cpp +++ b/cpp2rust/converter/mapper.cpp @@ -625,6 +625,18 @@ std::string Map(clang::QualType qual_type) { return {}; } +std::string MapInitializer(clang::QualType qual_type) { + auto type_str = ToString(qual_type); + auto [rule, subs] = search(types_, type_str, GetTypeMapKey(type_str)); + if (rule && !rule->initializer.empty()) { + for (auto &ty : subs) { + ty = mapTypeStringRecursive(ty); + } + return instantiateTgt(subs, rule->initializer); + } + return {}; +} + bool MapsToPointer(clang::QualType qual_type) { auto rule = search(qual_type); return rule && rule->type_info.is_pointer(); diff --git a/cpp2rust/converter/mapper.h b/cpp2rust/converter/mapper.h index 942a39a9..cd4523c6 100644 --- a/cpp2rust/converter/mapper.h +++ b/cpp2rust/converter/mapper.h @@ -28,6 +28,7 @@ bool Contains(clang::QualType qual_type); bool Contains(const clang::Expr *expr); std::string Map(clang::QualType qual_type); +std::string MapInitializer(clang::QualType qual_type); const TranslationRule::ExprRule *GetExprRule(const clang::Expr *expr); std::string MapFunctionName(const clang::FunctionDecl *decl); std::string InstantiateTemplate(const clang::Expr *expr, unsigned n); diff --git a/cpp2rust/converter/models/converter_refcount.cpp b/cpp2rust/converter/models/converter_refcount.cpp index 3bc4b98f..c2bd9e18 100644 --- a/cpp2rust/converter/models/converter_refcount.cpp +++ b/cpp2rust/converter/models/converter_refcount.cpp @@ -1623,11 +1623,36 @@ bool ConverterRefCount::VisitCXXDefaultArgExpr(clang::CXXDefaultArgExpr *expr) { return Converter::VisitCXXDefaultArgExpr(expr); } +std::string +ConverterRefCount::GetArrayDefaultAsString(clang::QualType qual_type) { + if (auto *array_type = clang::dyn_cast(qual_type)) { + const auto &size = array_type->getSize(); + auto size_as_string = GetNumAsString(size); + auto element_type = array_type->getElementType(); + PushConversionKind push(*this, ConversionKind::Unboxed); + auto element_type_as_string = ToString(element_type); + auto default_as_string = GetDefaultAsString(element_type); + return std::format("(0..{}).map(|_| {}).collect::>()", + size_as_string.c_str(), default_as_string, + element_type_as_string); + } + return Converter::GetArrayDefaultAsString(qual_type); +} + std::string ConverterRefCount::GetDefaultAsString(clang::QualType qual_type) { if (IsVaListType(qual_type)) { return BoxValue("VaList::default()"); } + if (auto arr = GetArrayDefaultAsString(qual_type); !arr.empty()) { + return BoxValue(std::move(arr)); + } + + if (auto init = Mapper::MapInitializer(qual_type); !init.empty()) { + computed_expr_type_ = ComputedExprType::FreshValue; + return BoxValue(std::move(init)); + } + std::string ret; if (qual_type->isPointerType()) { auto pointee_type = qual_type->getPointeeType(); @@ -1641,26 +1666,6 @@ std::string ConverterRefCount::GetDefaultAsString(clang::QualType qual_type) { ret = std::format("Ptr::<{}>::null()", ConvertPointeeType(qual_type)); } } - } else if (auto *array_type = - clang::dyn_cast(qual_type)) { - const auto &size = array_type->getSize(); - auto size_as_string = GetNumAsString(size); - auto element_type = array_type->getElementType(); - PushConversionKind push(*this, ConversionKind::Unboxed); - auto element_type_as_string = ToString(element_type); - auto default_as_string = GetDefaultAsString(element_type); - ret = std::format("(0..{}).map(|_| {}).collect::>()", - size_as_string.c_str(), default_as_string, - element_type_as_string); - } else if (Mapper::ToString(qual_type) == "struct std::pair") { - auto template_args = *GetTemplateArgs(qual_type); - auto first_type = template_args[0].getAsType(); - auto second_type = template_args[1].getAsType(); - ret = std::format("(Rc::new(RefCell::new({})), Rc::new(RefCell::new({})))", - GetDefaultAsString(first_type), - GetDefaultAsString(second_type)); - } else if (Mapper::ToString(qual_type).contains("std::array")) { - ret = Converter::GetDefaultAsString(qual_type); } else { return Converter::GetDefaultAsString(qual_type); } diff --git a/cpp2rust/converter/models/converter_refcount.h b/cpp2rust/converter/models/converter_refcount.h index d890586a..c5eafadf 100644 --- a/cpp2rust/converter/models/converter_refcount.h +++ b/cpp2rust/converter/models/converter_refcount.h @@ -124,6 +124,8 @@ class ConverterRefCount final : public Converter { std::string GetDefaultAsString(clang::QualType qual_type) override; + std::string GetArrayDefaultAsString(clang::QualType qual_type) override; + void ConvertEqualsNullPtr(clang::Expr *expr) override; std::string GetDefaultAsStringFallback(clang::QualType qual_type) override; diff --git a/rules/pair/ir_unsafe.json b/rules/pair/ir_unsafe.json index f852a6e7..c4892eb0 100644 --- a/rules/pair/ir_unsafe.json +++ b/rules/pair/ir_unsafe.json @@ -459,7 +459,7 @@ } }, "t1": { - "init": "(T1::default(), T2::default())", + "init": "<(T1, T2)>::default()", "type": "(T1, T2)" } } diff --git a/rules/pair/tgt_unsafe.rs b/rules/pair/tgt_unsafe.rs index 5ac4de00..cae891fa 100644 --- a/rules/pair/tgt_unsafe.rs +++ b/rules/pair/tgt_unsafe.rs @@ -8,7 +8,7 @@ struct T1; struct T2; fn types() { - let t1: (T1, T2) = (T1::default(), T2::default()); + let t1: (T1, T2) = <(T1, T2)>::default(); } unsafe fn f1(a0: (T1, T2)) -> T2 { diff --git a/rules/stdio/ir_unsafe.json b/rules/stdio/ir_unsafe.json index b1f04054..1301ed25 100644 --- a/rules/stdio/ir_unsafe.json +++ b/rules/stdio/ir_unsafe.json @@ -756,7 +756,7 @@ } }, "t1": { - "init": "Default::default()", + "init": "std::ptr::null_mut()", "type": "*mut ::std::fs::File", "is_unsafe_pointer": true } diff --git a/rules/stdio/tgt_unsafe.rs b/rules/stdio/tgt_unsafe.rs index 98502092..771a8515 100644 --- a/rules/stdio/tgt_unsafe.rs +++ b/rules/stdio/tgt_unsafe.rs @@ -5,7 +5,7 @@ use libcc2rs::*; use std::io::prelude::*; fn types() -> Result<(), Box> { - let t1: *mut ::std::fs::File = Default::default(); + let t1: *mut ::std::fs::File = std::ptr::null_mut(); Ok(()) } diff --git a/tests/unit/out/refcount/fflush_null.rs b/tests/unit/out/refcount/fflush_null.rs index a7c883cf..61c6c9c9 100644 --- a/tests/unit/out/refcount/fflush_null.rs +++ b/tests/unit/out/refcount/fflush_null.rs @@ -10,8 +10,7 @@ pub fn main() { std::process::exit(main_0()); } fn main_0() -> i32 { - let file_ptr: Value> = - Rc::new(RefCell::new(Ptr::<::std::fs::File>::null())); + let file_ptr: Value> = Rc::new(RefCell::new(Ptr::null())); return if !(*file_ptr.borrow()).is_null() { match (*file_ptr.borrow()).with_mut(|v| v.sync_all()) { Ok(_) => 0, diff --git a/tests/unit/out/refcount/fn_ptr_stdlib_compare.rs b/tests/unit/out/refcount/fn_ptr_stdlib_compare.rs index 53c82c00..cd56a61f 100644 --- a/tests/unit/out/refcount/fn_ptr_stdlib_compare.rs +++ b/tests/unit/out/refcount/fn_ptr_stdlib_compare.rs @@ -69,7 +69,7 @@ fn main_0() -> i32 { let _arg0: AnyPtr = AnyPtr::default(); let _arg1: u64 = 0_u64; let _arg2: u64 = 0_u64; - let _arg3: Ptr<::std::fs::File> = Ptr::<::std::fs::File>::null(); + let _arg3: Ptr<::std::fs::File> = Ptr::null(); (*(*f3.borrow()))(_arg0, _arg1, _arg2, _arg3) }) == 22_u64) ); @@ -234,7 +234,7 @@ fn main_0() -> i32 { let _arg0: AnyPtr = AnyPtr::default(); let _arg1: u64 = 0_u64; let _arg2: u64 = 0_u64; - let _arg3: Ptr<::std::fs::File> = Ptr::<::std::fs::File>::null(); + let _arg3: Ptr<::std::fs::File> = Ptr::null(); (*(*g3.borrow()))(_arg0, _arg1, _arg2, _arg3) }) == 33_u64) ); diff --git a/tests/unit/out/refcount/global_without_initializer.rs b/tests/unit/out/refcount/global_without_initializer.rs index 52ae0d1e..08d8aa2b 100644 --- a/tests/unit/out/refcount/global_without_initializer.rs +++ b/tests/unit/out/refcount/global_without_initializer.rs @@ -23,8 +23,7 @@ thread_local!( pub static s: Value> = Rc::new(RefCell::new(Ptr::::null())); ); thread_local!( - pub static file: Value> = - Rc::new(RefCell::new(Ptr::<::std::fs::File>::null())); + pub static file: Value> = Rc::new(RefCell::new(Ptr::null())); ); thread_local!( pub static size: Value = >::default();