Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 123 additions & 79 deletions cpp2rust/converter/converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,9 +519,11 @@ bool IsPointerType(clang::QualType qual_type) {
->getCanonicalTypeInternal()));
}

bool Converter::RecordDerivesDefault(const clang::CXXRecordDecl *decl) {
if (GetUserDefinedDefaultConstructor(decl)) {
return false;
bool Converter::RecordDerivesDefault(const clang::RecordDecl *decl) {
if (auto cxx_decl = clang::dyn_cast<clang::CXXRecordDecl>(decl)) {
if (GetUserDefinedDefaultConstructor(cxx_decl)) {
return false;
}
}

for (auto f : decl->fields()) {
Expand All @@ -546,7 +548,7 @@ bool Converter::RecordDerivesDefault(const clang::CXXRecordDecl *decl) {
return true;
}

static bool recordDerivesCopy(const clang::CXXRecordDecl *decl) {
static bool recordDerivesCopy(const clang::RecordDecl *decl) {
for (auto f : decl->fields()) {
// Records that contain std::vector, std::array, std::string or anything
// that is translated to Vec<>, do not derive Copy
Expand All @@ -569,8 +571,8 @@ static bool recordDerivesCopy(const clang::CXXRecordDecl *decl) {
}
}

// Look recursively into fields that are CXXRecordDecl
if (auto field_record = f->getType()->getAsCXXRecordDecl()) {
// Look recursively into fields that are RecordDecl
if (auto field_record = f->getType()->getAsRecordDecl()) {
if (!recordDerivesCopy(field_record)) {
return false;
}
Expand All @@ -580,6 +582,109 @@ static bool recordDerivesCopy(const clang::CXXRecordDecl *decl) {
return true;
}

bool Converter::VisitRecordDecl(clang::RecordDecl *decl) {
decl->dumpColor();

// VisitCXXRecordDecl already visited the record
if (clang::isa<clang::CXXRecordDecl>(decl)) {
return true;
}

if (!decl->isCompleteDefinition()) {
return false;
}

if (!record_decls_.insert(GetID(decl)).second) {
return false;
}

Mapper::AddRuleForUserDefinedType(decl);
EmitRustStruct(decl);

return false;
}

void Converter::EmitRustStruct(clang::RecordDecl *decl) {
// Enums and static variables. In rust they live outside the record
for (auto *d : decl->decls()) {
if (auto *enum_decl = llvm::dyn_cast<clang::EnumDecl>(d)) {
VisitEnumDecl(enum_decl);
}
if (auto *var_decl = clang::dyn_cast<clang::VarDecl>(d)) {
VisitVarDecl(var_decl);
}
}

// Inner records. In rust they live outside the record
for (auto *d : decl->decls()) {
if (auto *nested = clang::dyn_cast<clang::RecordDecl>(d)) {
if (!nested->isImplicit()) {
inner_structs_[GetID(nested)] = GetRecordName(nested);
if (auto *cxx = clang::dyn_cast<clang::CXXRecordDecl>(nested)) {
VisitCXXRecordDecl(cxx);
} else {
VisitRecordDecl(nested);
}
}
}
}

// Derived traits
StrCat("#[derive(");
for (auto *attr : GetStructAttributes(decl)) {
StrCat(attr, ",");
}
StrCat(")]");

// Fields
auto access = clang::dyn_cast<clang::CXXRecordDecl>(decl)
? AccessSpecifierAsString(decl->getAccess())
: keyword::kPub;
StrCat(access, keyword::kStruct, GetRecordName(decl),
token::kOpenCurlyBracket);
for (auto *field : decl->fields()) {
VisitFieldDecl(field);
}
StrCat(token::kCloseCurlyBracket);

// C++ method decls
if (auto *cxx = clang::dyn_cast<clang::CXXRecordDecl>(decl)) {
auto struct_name = GetRecordName(cxx);

ConvertCXXMethodDecls(
cxx, std::string(keyword::kImpl) + ' ' + struct_name,
[](const auto *method) {
return !method->isImplicit() &&
!(method->getDefinition() &&
method->getDefinition()->isDefaulted()) &&
(method->isThisDeclarationADefinition() ||
clang::isa<clang::CXXConstructorDecl>(method)) &&
!method->isVirtual() &&
!clang::isa<clang::CXXDestructorDecl>(method);
});

if (cxx->bases_begin() != cxx->bases_end()) {
ConvertCXXMethodDecls(
cxx,
std::format("{} impl {} for {}", keyword_unsafe_,
GetUnsafeTypeAsString(cxx->bases_begin()->getType()),
struct_name),
[](const auto *method) {
return !method->isImplicit() && method->isVirtual();
});
}
}

// Traits
if (auto *cxx = clang::dyn_cast<clang::CXXRecordDecl>(decl)) {
AddOrdTrait(cxx);
AddCloneTrait(cxx);
AddDropTrait(cxx);
AddDefaultTrait(cxx);
}
AddByteReprTrait(decl);
}

bool Converter::VisitCXXRecordDecl(clang::CXXRecordDecl *decl) {
if (clang::isa<clang::ClassTemplateSpecializationDecl>(decl)) {
materializeTemplateSpecialization(decl);
Expand Down Expand Up @@ -623,74 +728,7 @@ bool Converter::VisitCXXRecordDecl(clang::CXXRecordDecl *decl) {
}
}

auto struct_name = GetRecordName(decl);

// First visit the nested enums
for (auto d : decl->decls()) {
if (auto enum_decl = llvm::dyn_cast<clang::EnumDecl>(d)) {
VisitEnumDecl(enum_decl);
}
}

for (auto *decl : decl->decls()) {
if (auto var_decl = clang::dyn_cast<clang::VarDecl>(decl)) {
VisitVarDecl(var_decl);
}
}

auto nested = GetNestedStructs(decl);
for (auto *record_decl : nested) {
auto ID = GetID(record_decl);
inner_structs_[ID] = GetRecordName(record_decl);
VisitCXXRecordDecl(record_decl);
}

StrCat(token::kHash, token::kOpenBracket, "derive", token::kOpenParen);
bool derives_default = RecordDerivesDefault(decl);

for (auto *struct_attr : GetStructAttributes(decl, derives_default)) {
StrCat(struct_attr, token::kComma);
}
StrCat(token::kCloseParen, token::kCloseBracket);

auto access_specifier = decl->getAccess();
StrCat(AccessSpecifierAsString(access_specifier), keyword::kStruct,
struct_name, token::kOpenCurlyBracket);
for (auto *field : decl->fields()) {
VisitFieldDecl(field);
}
StrCat(token::kCloseCurlyBracket);

ConvertCXXMethodDecls(
decl, std::string(keyword::kImpl) + ' ' + struct_name,
[](const auto *method) {
return !method->isImplicit() &&
!(method->getDefinition() &&
method->getDefinition()->isDefaulted()) &&
(method->isThisDeclarationADefinition() ||
clang::isa<clang::CXXConstructorDecl>(method)) &&
!method->isVirtual() &&
!clang::isa<clang::CXXDestructorDecl>(method);
});

AddOrdTrait(decl);
AddCloneTrait(decl);
AddDropTrait(decl);
if (!derives_default) {
AddDefaultTrait(decl);
}
AddByteReprTrait(decl);

if (decl->bases_begin() != decl->bases_end()) {
ConvertCXXMethodDecls(
decl,
std::format("{} impl {} for {}", keyword_unsafe_,
GetUnsafeTypeAsString(decl->bases_begin()->getType()),
struct_name),
[](const auto *method) {
return !method->isImplicit() && method->isVirtual();
});
}
EmitRustStruct(decl);
} else {
// FIXME: improve error handling
assert(0 && "unsupported union");
Expand Down Expand Up @@ -2797,15 +2835,18 @@ std::string Converter::GetRecordName(const clang::NamedDecl *decl) const {
}

std::vector<const char *>
Converter::GetStructAttributes(const clang::CXXRecordDecl *decl,
bool &out_impl_default) {
Converter::GetStructAttributes(const clang::RecordDecl *decl) {
std::vector<const char *> struct_attrs = {};

if (recordDerivesCopy(decl)) {
struct_attrs.emplace_back("Copy");
}

if (!decl->defaultedCopyConstructorIsDeleted()) {
if (auto cxx_decl = clang::dyn_cast<clang::CXXRecordDecl>(decl)) {
if (!cxx_decl->defaultedCopyConstructorIsDeleted()) {
struct_attrs.emplace_back("Clone");
}
} else /* RecordDecl */ {
struct_attrs.emplace_back("Clone");
}

Expand Down Expand Up @@ -3106,11 +3147,14 @@ void Converter::AddCloneTrait(const clang::CXXRecordDecl *decl) {}
void Converter::AddDropTrait(const clang::CXXRecordDecl *decl) {}

void Converter::AddDefaultTrait(const clang::CXXRecordDecl *decl) {
if (RecordDerivesDefault(decl)) {
return;
}
auto struct_name = GetRecordName(decl);
StrCat(std::format("impl Default for {}", struct_name),
token::kOpenCurlyBracket, "fn default() -> Self",
token::kOpenCurlyBracket);
if (auto default_ctor = GetUserDefinedDefaultConstructor(decl)) {
if (auto *default_ctor = GetUserDefinedDefaultConstructor(decl)) {
StrCat(keyword_unsafe_, token::kOpenCurlyBracket);
Convert(clang::CXXConstructExpr::Create(
ctx_, ctx_.getCanonicalTagType(decl), clang::SourceLocation(),
Expand All @@ -3133,7 +3177,7 @@ void Converter::AddDefaultTrait(const clang::CXXRecordDecl *decl) {
StrCat(token::kCloseCurlyBracket, token::kCloseCurlyBracket);
}

void Converter::AddByteReprTrait(const clang::CXXRecordDecl *decl) {}
void Converter::AddByteReprTrait(const clang::RecordDecl *decl) {}

void Converter::ConvertUnsignedArithBinaryOperator(clang::BinaryOperator *op,
clang::Expr *expr) {
Expand Down
10 changes: 7 additions & 3 deletions cpp2rust/converter/converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,12 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {

virtual bool ConvertLambdaVarDecl(clang::VarDecl *decl);

bool VisitRecordDecl(clang::RecordDecl *decl);

virtual bool VisitCXXRecordDecl(clang::CXXRecordDecl *decl);

void EmitRustStruct(clang::RecordDecl *decl);

virtual bool VisitCXXMethodDecl(clang::CXXMethodDecl *decl);
virtual std::string GetSelfMaybeWithMut(const clang::CXXMethodDecl *decl);

Expand Down Expand Up @@ -355,7 +359,7 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
virtual std::string GetRecordName(const clang::NamedDecl *decl) const;

virtual std::vector<const char *>
GetStructAttributes(const clang::CXXRecordDecl *decl, bool &out_impl_default);
GetStructAttributes(const clang::RecordDecl *decl);

virtual std::string GetUnsafeTypeAsString(clang::QualType qual_type) const;

Expand Down Expand Up @@ -410,7 +414,7 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {

virtual void AddDefaultTrait(const clang::CXXRecordDecl *decl);

virtual void AddByteReprTrait(const clang::CXXRecordDecl *decl);
virtual void AddByteReprTrait(const clang::RecordDecl *decl);

virtual void
ConvertUnsignedArithBinaryOperator(clang::BinaryOperator *binary_operator,
Expand Down Expand Up @@ -453,7 +457,7 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {

virtual bool IsReferenceType(const clang::Expr *expr) const;

virtual bool RecordDerivesDefault(const clang::CXXRecordDecl *decl);
virtual bool RecordDerivesDefault(const clang::RecordDecl *decl);

std::string *rs_code_;
clang::ASTContext &ctx_;
Expand Down
7 changes: 3 additions & 4 deletions cpp2rust/converter/models/converter_refcount.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ void ConverterRefCount::AddDropTrait(const clang::CXXRecordDecl *decl) {
StrCat("}");
}

void ConverterRefCount::AddByteReprTrait(const clang::CXXRecordDecl *decl) {
void ConverterRefCount::AddByteReprTrait(const clang::RecordDecl *decl) {
auto struct_name = GetRecordName(decl);
StrCat(std::format("impl ByteRepr for {}", struct_name),
token::kOpenCurlyBracket, token::kCloseCurlyBracket);
Expand Down Expand Up @@ -1604,11 +1604,10 @@ ConverterRefCount::ConvertVarDefaultInit(clang::QualType qual_type) {
}

std::vector<const char *>
ConverterRefCount::GetStructAttributes(const clang::CXXRecordDecl *decl,
bool &out_impl_default) {
ConverterRefCount::GetStructAttributes(const clang::RecordDecl *decl) {
std::vector<const char *> attrs = {};

if (out_impl_default) {
if (RecordDerivesDefault(decl)) {
attrs.emplace_back("Default");
}
return attrs;
Expand Down
5 changes: 2 additions & 3 deletions cpp2rust/converter/models/converter_refcount.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ConverterRefCount final : public Converter {

void AddDropTrait(const clang::CXXRecordDecl *decl) override;

void AddByteReprTrait(const clang::CXXRecordDecl *decl) override;
void AddByteReprTrait(const clang::RecordDecl *decl) override;

void AddDefaultTrait(const clang::CXXRecordDecl *decl) override;

Expand Down Expand Up @@ -121,8 +121,7 @@ class ConverterRefCount final : public Converter {
std::string ConvertVarDefaultInit(clang::QualType qual_type) override;

std::vector<const char *>
GetStructAttributes(const clang::CXXRecordDecl *decl,
bool &out_impl_default) override;
GetStructAttributes(const clang::RecordDecl *decl) override;

bool MayCauseBorrowMutError(const clang::Expr *lhs, const clang::Expr *rhs);

Expand Down
Loading
Loading