Merge branch 'tensorflow' into differentiable-attr-where-clause
dan-zheng committed Jan 7, 2019
2 parents d22c696 + 11ba9ec commit eb84095
Showing 24 changed files with 401 additions and 172 deletions.
4 changes: 4 additions & 0 deletions lib/IRGen/IRGenSIL.cpp
@@ -3522,6 +3522,10 @@ void IRGenSILFunction::visitFullApplySite(FullApplySite site) {
 
   auto origCalleeType = site.getOrigCalleeType();
   auto substCalleeType = site.getSubstCalleeType();
+
+  // SWIFT_ENABLE_TENSORFLOW
+  assert(!origCalleeType->isDifferentiable() && "Differentiable functions "
+         "should not reach here");
 
   auto args = site.getArguments();
   SILFunctionConventions origConv(origCalleeType, getSILModule());
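For context, the invariant asserted here is that the mandatory differentiation pass rewrites every differentiable-typed callee before IRGen runs, so apply lowering only ever sees original, non-differentiable function types. A minimal Swift sketch of code exercising this path (`square` is a hypothetical example, not part of this commit):

@differentiable
func square(_ x: Float) -> Float {
  return x * x
}

// By the time IRGen lowers this call, the differentiable function type has
// been rewritten away; the apply site sees only the original entry point.
let y = square(3)  // 9.0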
4 changes: 3 additions & 1 deletion lib/SILOptimizer/Mandatory/TFDifferentiation.cpp
@@ -4931,6 +4931,7 @@ bool Differentiation::processAutoDiffFunctionInst(AutoDiffFunctionInst *adfi,
   SILFunction *parent = adfi->getFunction();
   auto origFnOperand = adfi->getOriginalFunction();
   SILBuilder builder(adfi);
+  auto loc = parent->getLocation();
 
   SmallVector<SILValue, 2> assocFns;
   for (auto assocFnKind : {AutoDiffAssociatedFunctionKind::JVP,
@@ -4953,10 +4954,11 @@ bool Differentiation::processAutoDiffFunctionInst(AutoDiffFunctionInst *adfi,
     assert(assocFnAndIndices->second == desiredIndices &&
            "FIXME: We could emit a thunk that converts the VJP to have the "
            "desired indices.");
+    auto assocFn = assocFnAndIndices->first;
+    builder.createRetainValue(loc, assocFn, builder.getDefaultAtomicity());
     assocFns.push_back(assocFnAndIndices->first);
   }
 
-  auto loc = parent->getLocation();
   auto *newADFI = builder.createAutoDiffFunction(
       loc, adfi->getParameterIndices(), adfi->getDifferentiationOrder(),
       origFnOperand, assocFns);
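For intuition, an `autodiff_function` instruction bundles an original function with its associated derivative functions; the added `createRetainValue` keeps each looked-up associated function alive for the new bundle, a plausible reading being that the bundle consumes its operands (the diff does not state this explicitly). A rough Swift analogy of the bundle, purely illustrative since the compiler's real representation is a SIL instruction, not a struct:

// Hypothetical model: a differentiable function value pairs the original
// function with a VJP returning the value and a pullback.
struct DifferentiableFunction {
  let original: (Float) -> Float
  let vjp: (Float) -> (value: Float, pullback: (Float) -> Float)
}

let f = DifferentiableFunction(
  original: { $0 * $0 },
  vjp: { x in (value: x * x, pullback: { v in 2 * x * v }) })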
12 changes: 11 additions & 1 deletion lib/Sema/CSApply.cpp
@@ -7457,7 +7457,17 @@ Expr *ExprRewriter::finishApply(ApplyExpr *apply, Type openedType,
       return special;
     }
   }
 
-
+  // SWIFT_ENABLE_TENSORFLOW
+  if (auto *fnTy = cs.getType(fn)->getAs<AnyFunctionType>()) {
+    if (fnTy->isDifferentiable()) {
+      auto fnTyNoDiff =
+          fnTy->withExtInfo(fnTy->getExtInfo().withDifferentiable(false));
+      fn = new (tc.Context) AutoDiffFunctionExtractOriginalExpr(fn, fnTyNoDiff);
+      cs.setType(fn, fnTyNoDiff);
+      cs.cacheExprTypes(fn);
+    }
+  }
+
   bool unwrapResult = false;
   if (auto *IUOFnTy = dyn_cast<ImplicitlyUnwrappedFunctionConversionExpr>(fn)) {
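The effect is user-visible: applying a value of differentiable function type now type-checks by implicitly extracting and calling the underlying original function, with the differentiable bit stripped from the callee's type. A hedged source-level sketch (hypothetical declarations):

@differentiable
func cube(_ x: Float) -> Float {
  return x * x * x
}

// Calling `cube` routes through the implicitly inserted
// AutoDiffFunctionExtractOriginalExpr, so the application behaves exactly
// like a call to the plain, non-differentiable function.
let v = cube(2)  // 8.0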
4 changes: 2 additions & 2 deletions lib/Sema/DerivedConformanceAdditiveArithmetic.cpp
@@ -158,9 +158,9 @@ bool DerivedConformance::canDeriveAdditiveArithmetic(NominalTypeDecl *nominal,
   auto &C = nominal->getASTContext();
   auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
   return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) {
-    if (!v->getType())
+    if (!v->hasInterfaceType() || !v->getType())
       C.getLazyResolver()->resolveDeclSignature(v);
-    if (!v->getType())
+    if (!v->hasInterfaceType() || !v->getType())
       return false;
     auto declType = v->getType()->hasArchetype()
                         ? v->getType()
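For reference, this eligibility test walks a struct's stored properties and requires each to have a resolved type conforming to AdditiveArithmetic; the added `hasInterfaceType()` check triggers signature resolution for not-yet-validated properties instead of rejecting them. A minimal sketch of a struct the derivation accepts (hypothetical type, assuming this branch's derived conformances):

struct Point : AdditiveArithmetic {
  var x: Float
  var y: Float
  // `zero`, `+`, and `-` are synthesized because every stored property
  // is itself AdditiveArithmetic.
}

let p = Point(x: 1, y: 2) + Point(x: 3, y: 4)  // Point(x: 4, y: 6)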
20 changes: 20 additions & 0 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
@@ -474,6 +474,8 @@ getOrSynthesizeVectorSpaceStruct(DerivedConformance &derived,
     // Skip members with `@noDerivative`.
     if (member->getAttrs().hasAttribute<NoDerivativeAttr>())
       continue;
+    // Add this member's corresponding vector space to the parent's vector space
+    // struct.
     auto memberAssocType = getVectorSpaceType(member, nominal, kind);
     auto newMember = new (C) VarDecl(
         member->isStatic(), member->getSpecifier(), member->isCaptureList(),
@@ -491,6 +493,24 @@ getOrSynthesizeVectorSpaceStruct(DerivedConformance &derived,
     newMember->setValidationToChecked();
     newMember->setSetterAccess(member->getFormalAccess());
     C.addSynthesizedDecl(newMember);
+
+    // Now that this member is in the associated vector space, it should be
+    // marked `@differentiable` so that the differentiation transform will
+    // synthesize associated functions for it. We only add this to public
+    // stored properties, because their access outside the module will go
+    // through a call to the getter.
+    if (member->getEffectiveAccess() > AccessLevel::Internal &&
+        !member->getAttrs().hasAttribute<DifferentiableAttr>()) {
+      auto *diffableAttr = DifferentiableAttr::create(
+          C, SourceLoc(), SourceLoc(), ArrayRef<AutoDiffParameter>(), None,
+          None, None, None, nullptr);
+      member->getAttrs().add(diffableAttr);
+      auto *getterType =
+          member->getGetter()->getInterfaceType()->castTo<AnyFunctionType>();
+      AutoDiffParameterIndicesBuilder builder(getterType);
+      builder.setParameter(0);
+      diffableAttr->setCheckedParameterIndices(builder.build(C));
+    }
   }
 
   // The implicit memberwise constructor must be explicitly created so that it
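In source terms, the synthesized vector space struct mirrors the stored properties, skips any marked `@noDerivative`, and, with this change, implicitly marks public stored properties `@differentiable` so that out-of-module accesses, which go through the getter, get synthesized associated functions. A hedged sketch (hypothetical model type, assuming this branch's TensorFlow library):

import TensorFlow

public struct Dense : Differentiable {
  // Public stored property: its getter is implicitly marked
  // `@differentiable` with respect to `self`.
  public var weight: Tensor<Float>
  // Skipped when synthesizing the vector space struct.
  @noDerivative public var iterationCount: Int
}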
34 changes: 24 additions & 10 deletions lib/Sema/TypeCheckAttr.cpp
@@ -2216,12 +2216,21 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
       LookUpConformanceInModule(D->getDeclContext()->getParentModule());
 
   FuncDecl *original = nullptr;
-  if (isa<VarDecl>(D)) {
+  bool isProperty = false;
+  if (auto *vd = dyn_cast<VarDecl>(D)) {
     // When used on a storage decl, @differentiable refers to its getter.
-    original = cast<VarDecl>(D)->getGetter();
-  } else if (isa<FuncDecl>(D)) {
-    original = cast<FuncDecl>(D);
+    original = vd->getGetter();
+    isProperty = true;
+  } else if (auto *fd = dyn_cast<FuncDecl>(D)) {
+    original = fd;
+    if (auto *accessor = dyn_cast<AccessorDecl>(fd)) {
+      isProperty = true;
+      // We do not support setters yet because inout is not supported yet.
+      if (accessor->isSetter())
+        original = nullptr;
+    }
   }
+
   if (!original) {
     // Global immutable vars, for example, have no getter, and therefore trigger
     // this.
@@ -2448,12 +2457,17 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
                                originalFnTy);
 
   if (uncheckedWrtParams.empty()) {
-    // If 'wrt:' is not specified, the wrt parameters are all the parameters in
-    // the main parameter group. Self is intentionally excluded.
-    unsigned numNonSelfParameters = autoDiffParameterIndicesBuilder.size() -
-        (isMethod ? 1 : 0);
-    for (unsigned i : range(numNonSelfParameters))
-      autoDiffParameterIndicesBuilder.setParameter(i);
+    if (isProperty)
+      autoDiffParameterIndicesBuilder.setParameter(0);
+    else {
+      // If 'wrt:' is not specified, the wrt parameters are all the parameters
+      // in the main parameter group. Self is intentionally excluded except when
+      // it's a property.
+      unsigned numNonSelfParameters = autoDiffParameterIndicesBuilder.size() -
+          (isMethod ? 1 : 0);
+      for (unsigned i : range(numNonSelfParameters))
+        autoDiffParameterIndicesBuilder.setParameter(i);
+    }
   } else {
     // 'wrt:' is specified. Validate and collect the selected parameters.
     int lastIndex = -1;
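Net behavior of the attribute-checker changes: `@differentiable` on a property applies to its getter and rejects setters (since `inout` differentiation is unsupported), and an empty `wrt:` clause on a property getter defaults to parameter 0, the implicit `self`. A hedged sketch of a declaration this now accepts (hypothetical type):

struct Circle {
  var radius: Float

  // The attribute applies to the getter; with no `wrt:` clause,
  // differentiation is with respect to `self`, the getter's only parameter.
  @differentiable
  var area: Float {
    return 3.14159265 * radius * radius
  }
}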
8 changes: 6 additions & 2 deletions lib/Serialization/SerializeSIL.cpp
@@ -420,8 +420,12 @@ void SILSerializer::writeSILFunction(const SILFunction &F, bool DeclOnly) {
       parameters.push_back(indices.parameters[i]);
     SILDifferentiableAttrLayout::emitRecord(
         Out, ScratchRecord, differentiableAttrAbbrCode,
-        S.addDeclBaseNameRef(Ctx.getIdentifier(DA->getPrimalName())),
-        S.addDeclBaseNameRef(Ctx.getIdentifier(DA->getAdjointName())),
+        DA->hasPrimal()
+            ? S.addDeclBaseNameRef(Ctx.getIdentifier(DA->getPrimalName()))
+            : IdentifierID(),
+        DA->hasAdjoint()
+            ? S.addDeclBaseNameRef(Ctx.getIdentifier(DA->getAdjointName()))
+            : IdentifierID(),
         DA->isAdjointPrimitive(),
         // TODO: Once we add synthesis for JVP and VJP, serialized
         // [differentiable] attrs should always have JVP and VJP, so we should
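Previously the serializer unconditionally emitted primal and adjoint name references; it now writes an empty `IdentifierID()` when a `[differentiable]` attribute has no registered primal or adjoint, as happens when those functions are left for the differentiation transform to synthesize. A hedged source-level sketch of the two shapes (`adjointCube` is hypothetical; the adjoint signature follows the `_adjointRelu` convention in this commit's Gradients.swift, and the exact attribute argument labels on this branch may differ):

// No primal/adjoint named: serialized with empty identifiers.
@differentiable
func cube(_ x: Float) -> Float {
  return x * x * x
}

// Explicit adjoint: the referenced name is serialized as before.
@differentiable(adjoint: adjointCube)
func cube2(_ x: Float) -> Float {
  return x * x * x
}

func adjointCube(_ seed: Float, _ originalValue: Float, _ x: Float) -> Float {
  return seed * 3 * x * x
}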
8 changes: 4 additions & 4 deletions stdlib/public/TensorFlow/Gradients.swift
@@ -87,9 +87,9 @@ extension Tensor where Scalar : Numeric {
 func _adjointMinMax<T : Numeric & Comparable>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>, _ y: Tensor<T>
 ) -> (Tensor<T>, Tensor<T>) {
-  let denom = 1 + Tensor<T>(x.elementsEqual(y))
-  let dfdx = seed * Tensor<T>(x.elementsEqual(originalValue)) / denom
-  let dfdy = seed * Tensor<T>(y.elementsEqual(originalValue)) / denom
+  let denom = 1 + Tensor<T>(x .== y)
+  let dfdx = seed * Tensor<T>(x .== originalValue) / denom
+  let dfdy = seed * Tensor<T>(y .== originalValue) / denom
   return (dfdx.unbroadcast(like: x), dfdy.unbroadcast(like: y))
 }
 
@@ -451,5 +451,5 @@ extension Tensor where Scalar : BinaryFloatingPoint {
 func _adjointRelu<T : BinaryFloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
-  return Tensor(x.elementsGreater(0)) * seed
+  return Tensor(x .> 0) * seed
 }
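The change swaps the named comparison methods for the pointwise operators `.==` and `.>`, which yield a `Tensor<Bool>` mask; initializing a numeric `Tensor` from the mask converts it to zeros and ones, so the adjoint arithmetic is unchanged. A hedged usage sketch (assuming this branch's TensorFlow library):

import TensorFlow

let x = Tensor<Float>([1, 2, 3])
let y = Tensor<Float>([3, 2, 1])
let equalMask = x .== y                   // Tensor<Bool> [false, true, false]
let ones = Tensor<Float>(equalMask)       // [0.0, 1.0, 0.0]
let positive = Tensor<Float>(x .> 0)      // [1.0, 1.0, 1.0]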