Skip to content

Commit

Permalink
[6.0][Distributed] Diagnose missing import also for funcs in extensio…
Browse files Browse the repository at this point in the history
…ns (#72928)

* [Distributed] Diagnose missing import also for funcs in extensions

Resolves rdar://125813581

* [Distributed] Offer fixit for import Distributed when it is required
  • Loading branch information
ktoso committed Apr 14, 2024
1 parent c83caf5 commit 6d5fa1e
Show file tree
Hide file tree
Showing 20 changed files with 237 additions and 59 deletions.
2 changes: 2 additions & 0 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,13 @@ enum class DescriptiveDeclKind : uint8_t {
Struct,
Class,
Actor,
DistributedActor,
Protocol,
GenericEnum,
GenericStruct,
GenericClass,
GenericActor,
GenericDistributedActor,
GenericType,
Subscript,
StaticSubscript,
Expand Down
11 changes: 11 additions & 0 deletions include/swift/AST/DiagnosticEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,9 @@ namespace swift {

static const char *fixItStringFor(const FixItID id);

/// Get the best location where an 'import' fixit might be offered.
SourceLoc getBestAddImportFixItLoc(const Decl *Member) const;

/// Add a token-based replacement fix-it to the currently-active
/// diagnostic.
template <typename... ArgTypes>
Expand Down Expand Up @@ -783,6 +786,9 @@ namespace swift {
return fixItReplaceChars(L, L, "%0", {Str});
}

/// Add a fix-it suggesting to 'import' some module.
InFlightDiagnostic &fixItAddImport(StringRef ModuleName);

/// Add an insertion fix-it to the currently-active diagnostic. The
/// text is inserted immediately *after* the token specified.
InFlightDiagnostic &fixItInsertAfter(SourceLoc L, StringRef Str) {
Expand Down Expand Up @@ -1391,6 +1397,11 @@ namespace swift {
SourceLoc getDefaultDiagnosticLoc() const {
return bufferIndirectlyCausingDiagnostic;
}
SourceLoc getBestAddImportFixItLoc(const Decl *Member,
SourceFile *sourceFile) const;
SourceLoc getBestAddImportFixItLoc(const Decl *Member) const {
return getBestAddImportFixItLoc(Member, nullptr);
}
};

/// Remember details about the state of a diagnostic engine and restore them
Expand Down
6 changes: 3 additions & 3 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -5425,9 +5425,9 @@ ERROR(isolated_default_argument_context,none,
ERROR(conflicting_default_argument_isolation,none,
"default argument cannot be both %0 and %1",
(ActorIsolation, ActorIsolation))
ERROR(distributed_actor_needs_explicit_distributed_import,none,
"'Distributed' module not imported, required for 'distributed actor'",
())
ERROR(distributed_decl_needs_explicit_distributed_import,none,
"%kind0 declared without importing module 'Distributed'",
(const ValueDecl *))
NOTE(distributed_func_cannot_overload_on_async_only,none,
"%0 previously declared here, cannot overload distributed methods on effect only", (const ValueDecl *))
NOTE(distributed_func_other_ambiguous_overload_here,none,
Expand Down
4 changes: 2 additions & 2 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -3025,7 +3025,7 @@ class HasCircularRawValueRequest

/// Checks if the Distributed module is available.
class DistributedModuleIsAvailableRequest
: public SimpleRequest<DistributedModuleIsAvailableRequest, bool(Decl *),
: public SimpleRequest<DistributedModuleIsAvailableRequest, bool(const ValueDecl *),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;
Expand All @@ -3034,7 +3034,7 @@ class DistributedModuleIsAvailableRequest
friend SimpleRequest;

// Evaluation.
bool evaluate(Evaluator &evaluator, Decl *decl) const;
bool evaluate(Evaluator &evaluator, const ValueDecl *decl) const;

public:
// Cached.
Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ SWIFT_REQUEST(TypeChecker, InferredGenericSignatureRequest,
SourceLoc, bool, bool),
Uncached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, DistributedModuleIsAvailableRequest,
bool(ModuleDecl *), Cached, NoLocationInfo)
bool(const ValueDecl *), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, InheritedTypeRequest,
Type(llvm::PointerUnion<const TypeDecl *, const ExtensionDecl *>,
unsigned, TypeResolutionStage),
Expand Down
25 changes: 19 additions & 6 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,23 @@ DescriptiveDeclKind Decl::getDescriptiveKind() const {
: DescriptiveDeclKind::Struct;

case DeclKind::Class: {
bool isActor = cast<ClassDecl>(this)->isActor();
return cast<ClassDecl>(this)->getGenericParams()
? (isActor ? DescriptiveDeclKind::GenericActor
: DescriptiveDeclKind::GenericClass)
: (isActor ? DescriptiveDeclKind::Actor
: DescriptiveDeclKind::Class);
auto clazz = cast<ClassDecl>(this);
bool isAnyActor = clazz->isAnyActor();
bool isGeneric = clazz->getGenericParams();

auto kind = isGeneric ? DescriptiveDeclKind::GenericClass
: DescriptiveDeclKind::Class;

if (isAnyActor) {
if (clazz->isDistributedActor()) {
kind = isGeneric ? DescriptiveDeclKind::GenericDistributedActor
: DescriptiveDeclKind::DistributedActor;
} else {
kind = isGeneric ? DescriptiveDeclKind::GenericActor
: DescriptiveDeclKind::Actor;
}
}
return kind;
}

case DeclKind::Var: {
Expand Down Expand Up @@ -332,11 +343,13 @@ StringRef Decl::getDescriptiveKindName(DescriptiveDeclKind K) {
ENTRY(Struct, "struct");
ENTRY(Class, "class");
ENTRY(Actor, "actor");
ENTRY(DistributedActor, "distributed actor");
ENTRY(Protocol, "protocol");
ENTRY(GenericEnum, "generic enum");
ENTRY(GenericStruct, "generic struct");
ENTRY(GenericClass, "generic class");
ENTRY(GenericActor, "generic actor");
ENTRY(GenericDistributedActor, "generic distributed actor");
ENTRY(GenericType, "generic type");
ENTRY(Subscript, "subscript");
ENTRY(StaticSubscript, "static subscript");
Expand Down
69 changes: 69 additions & 0 deletions lib/AST/DiagnosticEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,75 @@ InFlightDiagnostic::fixItReplaceChars(SourceLoc Start, SourceLoc End,
return *this;
}

SourceLoc
DiagnosticEngine::getBestAddImportFixItLoc(const Decl *Member,
SourceFile *sourceFile) const {
auto &SM = SourceMgr;

SourceLoc bestLoc;

auto SF =
sourceFile ? sourceFile : Member->getDeclContext()->getParentSourceFile();
if (!SF) {
return bestLoc;
}

for (auto item : SF->getTopLevelItems()) {
// If we found an import declaration, we want to insert after it.
if (auto importDecl =
dyn_cast_or_null<ImportDecl>(item.dyn_cast<Decl *>())) {
SourceLoc loc = importDecl->getEndLoc();
if (loc.isValid()) {
bestLoc = Lexer::getLocForEndOfLine(SM, loc);
}

// Keep looking for more import declarations.
continue;
}

// If we got a location based on import declarations, we're done.
if (bestLoc.isValid())
break;

// For any other item, we want to insert before it.
SourceLoc loc = item.getStartLoc();
if (loc.isValid()) {
bestLoc = Lexer::getLocForStartOfLine(SM, loc);
break;
}
}

return bestLoc;
}

InFlightDiagnostic &InFlightDiagnostic::fixItAddImport(StringRef ModuleName) {
assert(IsActive && "Cannot modify an inactive diagnostic");
auto Member = Engine->ActiveDiagnostic->getDecl();
SourceLoc bestLoc = Engine->getBestAddImportFixItLoc(Member);

if (bestLoc.isValid()) {
llvm::SmallString<64> importText;

// @_spi imports.
if (Member->isSPI()) {
auto spiGroups = Member->getSPIGroups();
if (!spiGroups.empty()) {
importText += "@_spi(";
importText += spiGroups[0].str();
importText += ") ";
}
}

importText += "import ";
importText += ModuleName;
importText += "\n";

return fixItInsert(bestLoc, importText);
}

return *this;
}

InFlightDiagnostic &InFlightDiagnostic::fixItExchange(SourceRange R1,
SourceRange R2) {
assert(IsActive && "Cannot modify an inactive diagnostic");
Expand Down
5 changes: 3 additions & 2 deletions lib/Basic/SourceLoc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
//
//===----------------------------------------------------------------------===//

#include "swift/Basic/Range.h"
#include "swift/Basic/SourceLoc.h"
#include "swift/AST/SourceFile.h"
#include "swift/Basic/Range.h"
#include "swift/Basic/SourceManager.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/raw_ostream.h"

using namespace swift;

Expand Down
29 changes: 2 additions & 27 deletions lib/Sema/CSDiagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6196,33 +6196,8 @@ bool InaccessibleMemberFailure::diagnoseAsError() {
definingModule->getName());

auto enclosingSF = getDC()->getParentSourceFile();
SourceLoc bestLoc;
SourceManager &srcMgr = Member->getASTContext().SourceMgr;
for (auto item : enclosingSF->getTopLevelItems()) {
// If we found an import declaration, we want to insert after it.
if (auto importDecl =
dyn_cast_or_null<ImportDecl>(item.dyn_cast<Decl *>())) {
SourceLoc loc = importDecl->getEndLoc();
if (loc.isValid()) {
bestLoc = Lexer::getLocForEndOfLine(srcMgr, loc);
}

// Keep looking for more import declarations.
continue;
}

// If we got a location based on import declarations, we're done.
if (bestLoc.isValid())
break;

// For any other item, we want to insert before it.
SourceLoc loc = item.getStartLoc();
if (loc.isValid()) {
bestLoc = Lexer::getLocForStartOfLine(srcMgr, loc);
break;
}
}

SourceLoc bestLoc =
getBestAddImportFixItLocation(Member, getDC()->getParentSourceFile());
if (bestLoc.isValid()) {
llvm::SmallString<64> importText;

Expand Down
6 changes: 6 additions & 0 deletions lib/Sema/CSDiagnostics.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ class FailureDiagnostic {

ConstraintLocator *getLocator() const { return Locator; }

SourceLoc getBestAddImportFixItLocation(const Decl *Member,
SourceFile *sourceFile) const {
auto &engine = Member->getASTContext().Diags;
return engine.getBestAddImportFixItLoc(Member, sourceFile);
}

Type getType(ASTNode node, bool wantRValue = true) const;

/// Get type associated with a given ASTNode without resolving it,
Expand Down
1 change: 1 addition & 0 deletions lib/Sema/CodeSynthesisDistributedActor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ deriveBodyDistributed_thunk(AbstractFunctionDecl *thunk, void *context) {

// === Type:
StructDecl *RCT = C.getRemoteCallTargetDecl();
assert(RCT && "Missing RemoteCallTarget declaration");
Type remoteCallTargetTy = RCT->getDeclaredInterfaceType();

// === __isRemoteActor(self)
Expand Down
34 changes: 24 additions & 10 deletions lib/Sema/TypeCheckDistributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "swift/AST/NameLookupRequests.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/TypeVisitor.h"
#include "swift/AST/ImportCache.h"
#include "swift/AST/ExistentialLayout.h"
#include "swift/Basic/Defer.h"
#include "swift/AST/ASTPrinter.h"
Expand All @@ -33,7 +34,7 @@ using namespace swift;

// ==== ------------------------------------------------------------------------

bool swift::ensureDistributedModuleLoaded(Decl *decl) {
bool swift::ensureDistributedModuleLoaded(const ValueDecl *decl) {
auto &C = decl->getASTContext();
auto moduleAvailable = evaluateOrDefault(
C.evaluator, DistributedModuleIsAvailableRequest{decl}, false);
Expand All @@ -42,14 +43,25 @@ bool swift::ensureDistributedModuleLoaded(Decl *decl) {

bool
DistributedModuleIsAvailableRequest::evaluate(Evaluator &evaluator,
Decl *decl) const {
const ValueDecl *decl) const {
auto &C = decl->getASTContext();

if (C.getLoadedModule(C.Id_Distributed))
auto DistributedModule = C.getLoadedModule(C.Id_Distributed);
if (!DistributedModule) {
decl->diagnose(diag::distributed_decl_needs_explicit_distributed_import,
decl)
.fixItAddImport("Distributed");
return false;
}

auto &importCache = C.getImportCache();
if (importCache.isImportedBy(DistributedModule, decl->getDeclContext())) {
return true;
}

// seems we're missing the Distributed module, ask to import it explicitly
decl->diagnose(diag::distributed_actor_needs_explicit_distributed_import);
decl->diagnose(diag::distributed_decl_needs_explicit_distributed_import,
decl);
return false;
}

Expand Down Expand Up @@ -502,6 +514,10 @@ bool swift::checkDistributedFunction(AbstractFunctionDecl *func) {
if (!func->isDistributed())
return false;

// ==== Ensure the Distributed module is available,
if (!swift::ensureDistributedModuleLoaded(func))
return true;

auto &C = func->getASTContext();
return evaluateOrDefault(C.evaluator,
CheckDistributedFunctionRequest{func},
Expand All @@ -521,13 +537,11 @@ bool CheckDistributedFunctionRequest::evaluate(
auto module = func->getParentModule();

/// If no distributed module is available, then no reason to even try checks.
if (!C.getLoadedModule(C.Id_Distributed))
if (!C.getLoadedModule(C.Id_Distributed)) {
func->diagnose(diag::distributed_decl_needs_explicit_distributed_import,
func);
return true;

// // No checking for protocol requirements because they are not required
// // to have `SerializationRequirement`.
// if (isa<ProtocolDecl>(func->getDeclContext()))
// return false;
}

Type serializationReqType =
getDistributedActorSerializationType(func->getDeclContext());
Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/TypeCheckDistributed.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class NominalTypeDecl;
/******************************************************************************/

// Diagnose an error if the Distributed module is not loaded.
bool ensureDistributedModuleLoaded(Decl *decl);
bool ensureDistributedModuleLoaded(const ValueDecl *decl);

/// Check for illegal property declarations (e.g. re-declaring transport or id)
void checkDistributedActorProperties(const NominalTypeDecl *decl);
Expand Down
7 changes: 7 additions & 0 deletions test/Distributed/Inputs/FakeDistributedActorSystems.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@

import Distributed

// ==== Example Distributed Actors ----------------------------------------------

@available(SwiftStdlib 5.7, *)
public distributed actor FakeRoundtripActorSystemDistributedActor {
public typealias ActorSystem = FakeRoundtripActorSystem
}

// ==== Fake Address -----------------------------------------------------------

public struct ActorAddress: Hashable, Sendable, Codable {
Expand Down

0 comments on commit 6d5fa1e

Please sign in to comment.