Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoDiff] lift samefile derivative constriant #28790

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ class ASTContext final {
// same parameter indices but different derivative generic signatures.
llvm::DenseMap<
std::tuple<Decl *, IndexSubset *, AutoDiffDerivativeFunctionKind>,
DerivativeAttr *>
llvm::SmallPtrSet<DerivativeAttr *, 1>>
DerivativeAttrs;

private:
Expand Down
7 changes: 5 additions & 2 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3016,12 +3016,15 @@ NOTE(derivative_attr_result_func_type_mismatch_note,none,
"%0 does not have expected type %1", (Identifier, Type))
NOTE(derivative_attr_result_func_original_note,none,
"%0 defined here", (DeclName))
ERROR(derivative_attr_not_in_same_file_as_original,none,
"derivative not in the same file as the original function", ())
ERROR(derivative_attr_original_stored_property_unsupported,none,
"cannot register derivative for stored property %0", (DeclNameRef))
ERROR(derivative_attr_original_already_has_derivative,none,
"a derivative already exists for %0", (DeclName))
ERROR(derivative_attr_visibility_too_broad,none,
"derivative function visibility must be at least as restrictive as original function "
"visibility", ())
NOTE(derivative_attr_visibility_too_broad_note,none,
"original function defined here", ())
NOTE(derivative_attr_duplicate_note,none,
"other attribute declared here", ())

Expand Down
2 changes: 2 additions & 0 deletions include/swift/SIL/SILDifferentiabilityWitness.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class SILDifferentiabilityWitness
SILDifferentiabilityWitnessKey getKey() const;
SILModule &getModule() const { return Module; }
SILLinkage getLinkage() const { return Linkage; }
void setLinkage(SILLinkage linkage) { Linkage = linkage; }
SILFunction *getOriginalFunction() const { return OriginalFunction; }
const AutoDiffConfig &getConfig() const { return Config; }
IndexSubset *getParameterIndices() const {
Expand Down Expand Up @@ -129,6 +130,7 @@ class SILDifferentiabilityWitness
bool isDeclaration() const { return IsDeclaration; }
bool isDefinition() const { return !IsDeclaration; }
bool isSerialized() const { return IsSerialized; }
void setSerialized(bool isSerialized) { IsSerialized = isSerialized; }
const DeclAttribute *getAttribute() const { return Attribute; }

/// Returns the `SILAutoDiffIndices` corresponding to this config's indices.
Expand Down
25 changes: 17 additions & 8 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,8 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
"all original SIL functions with generic signatures");
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
diffAttr->getDerivativeGenericSignature());
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr);
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr,
F->getLinkage());
}
for (auto *derivAttr : Attrs.getAttributes<DerivativeAttr>()) {
SILFunction *jvp = nullptr;
Expand All @@ -802,7 +803,7 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
derivativeGenSig);
emitDifferentiabilityWitness(origAFD, origFn, config, jvp, vjp,
derivAttr);
derivAttr, F->getLinkage());
}
};
if (auto *accessor = dyn_cast<AccessorDecl>(AFD))
Expand All @@ -816,7 +817,7 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
void SILGenModule::emitDifferentiabilityWitness(
AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
const AutoDiffConfig &config, SILFunction *jvp, SILFunction *vjp,
const DeclAttribute *attr) {
const DeclAttribute *attr, SILLinkage witnessLinkage) {
assert(isa<DifferentiableAttr>(attr) || isa<DerivativeAttr>(attr));
auto *origFnType = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
auto origSilFnType = originalFunction->getLoweredFunctionType();
Expand Down Expand Up @@ -855,11 +856,19 @@ void SILGenModule::emitDifferentiabilityWitness(
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
if (!diffWitness) {
diffWitness = SILDifferentiabilityWitness::createDefinition(
M, originalFunction->getLinkage(), originalFunction,
silConfig.parameterIndices, silConfig.resultIndices,
config.derivativeGenericSignature, /*jvp*/ nullptr, /*vjp*/ nullptr,
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
attr);
M, witnessLinkage, originalFunction, silConfig.parameterIndices,
silConfig.resultIndices, config.derivativeGenericSignature,
/*jvp*/ nullptr, /*vjp*/ nullptr,
/*isSerialized*/ hasPublicVisibility(witnessLinkage), attr);
}

// Use the least restrictive declared linkage, so that e.g. a
// `@differentiable` on `public` function with `@derivative`s on `internal`
// functions results in a public witness. (Sema is responsible for diagnosing
// forbidden combinations).
if (witnessLinkage < diffWitness->getLinkage()) {
diffWitness->setLinkage(witnessLinkage);
diffWitness->setSerialized(hasPublicVisibility(witnessLinkage));
}

// Set derivative function in differentiability witness.
Expand Down
3 changes: 2 additions & 1 deletion lib/SILGen/SILGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,8 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
SILFunction *originalFunction,
const AutoDiffConfig &config,
SILFunction *jvp, SILFunction *vjp,
const DeclAttribute *diffAttr);
const DeclAttribute *diffAttr,
SILLinkage witnessLinkage);
// SWIFT_ENABLE_TENSORFLOW END

/// Emit the lazy initializer function for a global pattern binding
Expand Down
4 changes: 0 additions & 4 deletions lib/SILOptimizer/Utils/Differentiation/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,6 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
if (existingWitness)
return existingWitness;

assert(original->isExternalDeclaration() &&
"SILGen should create differentiability witnesses for all function "
"definitions with explicit differentiable attributes");

return SILDifferentiabilityWitness::createDeclaration(
module, SILLinkage::PublicExternal, original,
minimalConfig->parameterIndices, minimalConfig->resultIndices,
Expand Down