Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoDiff] Fix @differentiable attribute SILGen and serialization. #21837

Merged
merged 12 commits into from
Jan 16, 2019
42 changes: 22 additions & 20 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1316,7 +1316,8 @@ class ClangImporterSynthesizedTypeAttr : public DeclAttribute {
/// @differentiable(reverse, wrt: (self, .0, .1), adjoint: bar(_:_:_:seed:))
class DifferentiableAttr final
: public DeclAttribute,
private llvm::TrailingObjects<DifferentiableAttr, AutoDiffParameter> {
private llvm::TrailingObjects<DifferentiableAttr,
ParsedAutoDiffParameter> {
public:
struct DeclNameWithLoc {
DeclName Name;
Expand All @@ -1326,7 +1327,7 @@ class DifferentiableAttr final
friend TrailingObjects;

/// The number of parameters specified in 'wrt:'.
unsigned NumParameters;
unsigned NumParsedParameters = 0;
/// The primal function.
Optional<DeclNameWithLoc> Primal;
/// The adjoint function.
Expand All @@ -1347,8 +1348,9 @@ class DifferentiableAttr final
/// The VJP function (optional), to be resolved by the type checker if
/// specified.
FuncDecl *VJPFunction = nullptr;
/// Checked parameter indices, to be resolved by the type checker.
AutoDiffParameterIndices *CheckedParameterIndices = nullptr;
/// The differentiation parameters' indices, to be resolved by the type
/// checker.
AutoDiffParameterIndices *ParameterIndices = nullptr;
/// The trailing where clause, if it exists.
TrailingWhereClause *WhereClause = nullptr;
/// The requirements for autodiff associated functions. Resolved by the type
Expand All @@ -1359,7 +1361,7 @@ class DifferentiableAttr final

explicit DifferentiableAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
ArrayRef<AutoDiffParameter> parameters,
ArrayRef<ParsedAutoDiffParameter> parameters,
Optional<DeclNameWithLoc> primal,
Optional<DeclNameWithLoc> adjoint,
Optional<DeclNameWithLoc> jvp,
Expand All @@ -1368,7 +1370,7 @@ class DifferentiableAttr final

explicit DifferentiableAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
ArrayRef<AutoDiffParameter> parameters,
AutoDiffParameterIndices *indices,
Optional<DeclNameWithLoc> primal,
Optional<DeclNameWithLoc> adjoint,
Optional<DeclNameWithLoc> jvp,
Expand All @@ -1378,7 +1380,7 @@ class DifferentiableAttr final
public:
static DifferentiableAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
ArrayRef<AutoDiffParameter> parameters,
ArrayRef<ParsedAutoDiffParameter> params,
Optional<DeclNameWithLoc> primal,
Optional<DeclNameWithLoc> adjoint,
Optional<DeclNameWithLoc> jvp,
Expand All @@ -1387,7 +1389,7 @@ class DifferentiableAttr final

static DifferentiableAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
ArrayRef<AutoDiffParameter> parameters,
AutoDiffParameterIndices *indices,
Optional<DeclNameWithLoc> primal,
Optional<DeclNameWithLoc> adjoint,
Optional<DeclNameWithLoc> jvp,
Expand All @@ -1399,23 +1401,23 @@ class DifferentiableAttr final
Optional<DeclNameWithLoc> getJVP() const { return JVP; }
Optional<DeclNameWithLoc> getVJP() const { return VJP; }

AutoDiffParameterIndices *getCheckedParameterIndices() const {
return CheckedParameterIndices;
AutoDiffParameterIndices *getParameterIndices() const {
return ParameterIndices;
}
void setCheckedParameterIndices(AutoDiffParameterIndices *pi) {
CheckedParameterIndices = pi;
void setParameterIndices(AutoDiffParameterIndices *pi) {
ParameterIndices = pi;
}

/// The differentiation parameters, i.e. the list of parameters specified in
/// 'wrt:'.
ArrayRef<AutoDiffParameter> getParameters() const {
return { getTrailingObjects<AutoDiffParameter>(), NumParameters };
/// The parsed differentiation parameters, i.e. the list of parameters
/// specified in 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
}
MutableArrayRef<AutoDiffParameter> getParameters() {
return { getTrailingObjects<AutoDiffParameter>(), NumParameters };
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
}
size_t numTrailingObjects(OverloadToken<AutoDiffParameter>) const {
return NumParameters;
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
return NumParsedParameters;
}

TrailingWhereClause *getWhereClause() const { return WhereClause; }
Expand Down
21 changes: 12 additions & 9 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

namespace swift {

class AutoDiffParameter {
class ParsedAutoDiffParameter {
public:
enum class Kind { Index, Self };

Expand All @@ -38,14 +38,15 @@ class AutoDiffParameter {
} V;

public:
AutoDiffParameter(SourceLoc loc, enum Kind kind, Value value)
ParsedAutoDiffParameter(SourceLoc loc, enum Kind kind, Value value)
: Loc(loc), Kind(kind), V(value) {}

static AutoDiffParameter getIndexParameter(SourceLoc loc, unsigned index) {
static ParsedAutoDiffParameter getIndexParameter(SourceLoc loc,
unsigned index) {
return { loc, Kind::Index, index };
}

static AutoDiffParameter getSelfParameter(SourceLoc loc) {
static ParsedAutoDiffParameter getSelfParameter(SourceLoc loc) {
return { loc, Kind::Self, {} };
}

Expand All @@ -62,7 +63,7 @@ class AutoDiffParameter {
return Loc;
}

bool isEqual(const AutoDiffParameter &other) const {
bool isEqual(const ParsedAutoDiffParameter &other) const {
if (getKind() == other.getKind() && getKind() == Kind::Index)
return getIndex() == other.getIndex();
return getKind() == other.getKind() && getKind() == Kind::Self;
Expand All @@ -88,6 +89,7 @@ class Type;
class AutoDiffParameterIndices : public llvm::FoldingSetNode {
friend AutoDiffParameterIndicesBuilder;

public:
/// Bits corresponding to parameters in the set are "on", and bits
/// corresponding to parameters not in the set are "off".
///
Expand All @@ -109,12 +111,13 @@ class AutoDiffParameterIndices : public llvm::FoldingSetNode {
///
const llvm::SmallBitVector parameters;

AutoDiffParameterIndices(llvm::SmallBitVector parameters)
: parameters(parameters) {}

static AutoDiffParameterIndices *get(llvm::SmallBitVector parameters,
ASTContext &C);

private:
AutoDiffParameterIndices(const llvm::SmallBitVector &parameters)
: parameters(parameters) {}

public:
/// Allocates and initializes an `AutoDiffParameterIndices` corresponding to
/// the given `string` generated by `getString()`. If the string is invalid,
Expand Down Expand Up @@ -230,7 +233,7 @@ struct SILAutoDiffIndices {
: source(source), parameters(parameters) {}

/// Creates a set of AD indices from the given source index and an array of
/// parameter indices. Elements in `parameters` must be acending integers.
/// parameter indices. Elements in `parameters` must be ascending integers.
/*implicit*/ SILAutoDiffIndices(unsigned source,
ArrayRef<unsigned> parameters);

Expand Down
2 changes: 1 addition & 1 deletion include/swift/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,7 @@ class Parser {

/// Parse the arguments inside the @differentiable attribute.
bool parseDifferentiableAttributeArguments(
SmallVectorImpl<AutoDiffParameter> &params,
SmallVectorImpl<ParsedAutoDiffParameter> &params,
Optional<DifferentiableAttr::DeclNameWithLoc> &primalSpec,
Optional<DifferentiableAttr::DeclNameWithLoc> &adjointSpec,
Optional<DifferentiableAttr::DeclNameWithLoc> &jvpSpec,
Expand Down
8 changes: 8 additions & 0 deletions include/swift/SIL/SILFunctionBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ class SILFunctionBuilder {
ProfileCounter entryCount,
IsThunk_t isThunk);

// SWIFT_ENABLE_TENSORFLOW
// `addFunctionAttributes` edited because @differentiable attribute
// propagation requires access to original function declaration (via
// SILDeclRef).
void addFunctionAttributes(SILFunction *F, SILDeclRef constant,
DeclAttributes &Attrs, SILModule &M);


/// Return the declaration of a function, or create it if it doesn't exist.
SILFunction *getOrCreateFunction(
SILLocation loc, StringRef name, SILLinkage linkage,
Expand Down
12 changes: 4 additions & 8 deletions include/swift/SIL/SILWitnessVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,12 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
if (auto *DA = func->getAttrs().getAttribute<DifferentiableAttr>()) {
asDerived().addMethod(funcDeclRef.asAutoDiffAssociatedFunction(
AutoDiffAssociatedFunctionIdentifier::get(
AutoDiffAssociatedFunctionKind::JVP,
/*differentiationOrder*/ 1,
DA->getCheckedParameterIndices(),
func->getASTContext())));
AutoDiffAssociatedFunctionKind::JVP, /*differentiationOrder*/ 1,
DA->getParameterIndices(), func->getASTContext())));
asDerived().addMethod(funcDeclRef.asAutoDiffAssociatedFunction(
AutoDiffAssociatedFunctionIdentifier::get(
AutoDiffAssociatedFunctionKind::VJP,
/*differentiationOrder*/ 1,
DA->getCheckedParameterIndices(),
func->getASTContext())));
AutoDiffAssociatedFunctionKind::VJP, /*differentiationOrder*/ 1,
DA->getParameterIndices(), func->getASTContext())));
}
}

Expand Down
4 changes: 2 additions & 2 deletions include/swift/Serialization/ModuleFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
/// describe what change you made. The content of this comment isn't important;
/// it just ensures a conflict if two people change the module format.
/// Don't worry about adhering to the 80-column limit for this line.
const uint16_t SWIFTMODULE_VERSION_MINOR = 461; // Last change: delete differentiation mode
const uint16_t SWIFTMODULE_VERSION_MINOR = 462; // Last change: serialize differentiation indices

using DeclIDField = BCFixed<31>;

Expand Down Expand Up @@ -1584,7 +1584,7 @@ namespace decls_block {
DeclIDField, // JVP function declaration.
IdentifierIDField, // VJP name.
DeclIDField, // VJP function declaration.
BCArray<BCFixed<32>> // Differentiation parameters.
BCArray<BCFixed<1>> // Differentiation parameter indices' bitvector.
>;

#define SIMPLE_DECL_ATTR(X, CLASS, ...) \
Expand Down
72 changes: 40 additions & 32 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,16 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
Printer.printAttrName("@differentiable");
Printer << '(';
auto *attr = cast<DifferentiableAttr>(this);
auto params = attr->getParameters();
auto parsedParams = attr->getParsedParameters();

// Get original function.
auto *original = dyn_cast_or_null<AbstractFunctionDecl>(D);
bool isProperty = original && isa<AccessorDecl>(original);
if (auto *varDecl = dyn_cast_or_null<VarDecl>(D)) {
isProperty = true;
original = varDecl->getGetter();
}
bool isMethod = original && original->getImplicitSelfDecl() ? true : false;

// Print comma if not leading clause.
bool isLeadingClause = true;
Expand All @@ -564,21 +573,29 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
};

// Print differentiation parameters, if any.
if (!params.empty()) {
if (auto indices = attr->getParameterIndices()) {
printCommaIfNecessary();
Printer << "wrt: (";
interleave(indices->parameters.set_bits(), [&](unsigned index) {
if (isProperty || (isMethod && index == indices->parameters.size() - 1))
Printer << "self";
else
Printer << "." << index;
}, [&] { Printer << ", "; });
Printer << ")";
} else if (!parsedParams.empty()) {
printCommaIfNecessary();
Printer << "wrt: (";
interleave(params, [&](const AutoDiffParameter &param) {
interleave(parsedParams, [&](const ParsedAutoDiffParameter &param) {
switch (param.getKind()) {
case AutoDiffParameter::Kind::Index:
case ParsedAutoDiffParameter::Kind::Index:
Printer << '.' << param.getIndex();
break;
case AutoDiffParameter::Kind::Self:
case ParsedAutoDiffParameter::Kind::Self:
Printer << "self";
break;
}
}, [&] {
Printer << ", ";
});
}, [&] { Printer << ", "; });
Printer << ")";
}
// Print primal function name if any.
Expand All @@ -605,13 +622,9 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
if (!attr->getRequirements().empty()) {
Printer << " where ";
std::function<Type(Type)> getInterfaceType;
auto *original = dyn_cast<AbstractFunctionDecl>(D);
if (auto varDecl = dyn_cast<VarDecl>(D)) {
original = varDecl->getGetter();
}
if (!original || !original->getGenericEnvironment())
if (!original || !original->getGenericEnvironment()) {
getInterfaceType = [](Type Ty) -> Type { return Ty; };
else {
} else {
// Use GenericEnvironment to produce user-friendly
// names instead of something like 't_0_0'.
auto *genericEnv = original->getGenericEnvironment();
Expand Down Expand Up @@ -1046,47 +1059,44 @@ SpecializeAttr *SpecializeAttr::create(ASTContext &Ctx, SourceLoc atLoc,
// SWIFT_ENABLE_TENSORFLOW
DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
ArrayRef<AutoDiffParameter> parameters,
ArrayRef<ParsedAutoDiffParameter> params,
Optional<DeclNameWithLoc> primal,
Optional<DeclNameWithLoc> adjoint,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause)
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
NumParameters(parameters.size()),
NumParsedParameters(params.size()),
Primal(std::move(primal)), Adjoint(std::move(adjoint)),
JVP(std::move(jvp)), VJP(std::move(vjp)), WhereClause(clause) {
std::copy(parameters.begin(), parameters.end(),
getTrailingObjects<AutoDiffParameter>());
std::copy(params.begin(), params.end(),
getTrailingObjects<ParsedAutoDiffParameter>());
}

DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
ArrayRef<AutoDiffParameter> parameters,
AutoDiffParameterIndices *indices,
Optional<DeclNameWithLoc> primal,
Optional<DeclNameWithLoc> adjoint,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
ArrayRef<Requirement> requirements)
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
NumParameters(parameters.size()),
Primal(std::move(primal)), Adjoint(std::move(adjoint)),
JVP(std::move(jvp)), VJP(std::move(vjp)) {
std::copy(parameters.begin(), parameters.end(),
getTrailingObjects<AutoDiffParameter>());
JVP(std::move(jvp)), VJP(std::move(vjp)), ParameterIndices(indices) {
setRequirements(context, requirements);
}

DifferentiableAttr *
DifferentiableAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
ArrayRef<AutoDiffParameter> parameters,
ArrayRef<ParsedAutoDiffParameter> parameters,
Optional<DeclNameWithLoc> primal,
Optional<DeclNameWithLoc> adjoint,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
TrailingWhereClause *clause) {
unsigned size = totalSizeToAlloc<AutoDiffParameter>(parameters.size());
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(parameters.size());
void *mem = context.Allocate(size, alignof(DifferentiableAttr));
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
parameters, std::move(primal),
Expand All @@ -1097,25 +1107,23 @@ DifferentiableAttr::create(ASTContext &context, bool implicit,
DifferentiableAttr *
DifferentiableAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
ArrayRef<AutoDiffParameter> parameters,
AutoDiffParameterIndices *indices,
Optional<DeclNameWithLoc> primal,
Optional<DeclNameWithLoc> adjoint,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
ArrayRef<Requirement> requirements) {
unsigned size = totalSizeToAlloc<AutoDiffParameter>(parameters.size());
void *mem = context.Allocate(size, alignof(DifferentiableAttr));
void *mem = context.Allocate(sizeof(DifferentiableAttr),
alignof(DifferentiableAttr));
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
parameters, std::move(primal),
indices, std::move(primal),
std::move(adjoint), std::move(jvp),
std::move(vjp), requirements);
}

void DifferentiableAttr::setRequirements(ASTContext &context,
ArrayRef<Requirement> requirements) {
Requirements =
context.AllocateUninitialized<Requirement>(requirements.size());
std::copy(requirements.begin(), requirements.end(), Requirements.data());
Requirements = context.AllocateCopy(requirements);
}

ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,
Expand Down