Skip to content

Commit

Permalink
[AutoDiff] Rename "assocFn" to "derivativeFn" everywhere except Diffe…
Browse files Browse the repository at this point in the history
…rentiation.cpp. (#27597)
  • Loading branch information
bgogul authored and rxwei committed Oct 10, 2019
1 parent eeeeee2 commit 0aee08a
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 102 deletions.
6 changes: 3 additions & 3 deletions include/swift/SIL/SILCloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -970,14 +970,14 @@ template<typename ImplClass>
void SILCloner<ImplClass>::visitDifferentiableFunctionInst(
DifferentiableFunctionInst *Inst) {
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
Optional<std::pair<SILValue, SILValue>> assocFns = None;
Optional<std::pair<SILValue, SILValue>> derivativeFns = None;
if (Inst->hasDerivativeFunctions())
assocFns = std::make_pair(getOpValue(Inst->getJVPFunction()),
derivativeFns = std::make_pair(getOpValue(Inst->getJVPFunction()),
getOpValue(Inst->getVJPFunction()));
recordClonedInstruction(
Inst, getBuilder().createDifferentiableFunction(
getOpLocation(Inst->getLoc()), Inst->getParameterIndices(),
getOpValue(Inst->getOriginalFunction()), assocFns));
getOpValue(Inst->getOriginalFunction()), derivativeFns));
}

template<typename ImplClass>
Expand Down
4 changes: 2 additions & 2 deletions lib/AST/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1045,10 +1045,10 @@ static ValueDecl *getAutoDiffApplyAssociatedFunction(
// Generator for the resultant function type, i.e. the AD associated function.
BuiltinGenericSignatureBuilder::LambdaGenerator resultGen{
[=, &Context](BuiltinGenericSignatureBuilder &builder) -> Type {
auto assocFnTy = origFnTy->getAutoDiffAssociatedFunctionType(
auto derivativeFnTy = origFnTy->getAutoDiffAssociatedFunctionType(
paramIndices, /*resultIndex*/ 0, kind,
LookUpConformanceInModule(Context.TheBuiltinModule));
return assocFnTy->getResult();
return derivativeFnTy->getResult();
}};
builder.addParameter(firstArgGen);
for (auto argGen : fnArgGens)
Expand Down
4 changes: 2 additions & 2 deletions lib/SIL/SILDeclRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,9 +694,9 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
autoDiffAssociatedFunctionIdentifier->getParameterIndices(),
getDecl()->getInterfaceType()->castTo<AnyFunctionType>());
SILAutoDiffIndices indices(/*source*/ 0, silParameterIndices);
auto assocFnKind = autoDiffAssociatedFunctionIdentifier->getKind();
auto derivativeFnKind = autoDiffAssociatedFunctionIdentifier->getKind();
return mangler.mangleAutoDiffAssociatedFunctionHelper(
originalMangled, assocFnKind, indices);
originalMangled, derivativeFnKind, indices);
}

// As a special case, Clang functions and globals don't get mangled at all.
Expand Down
25 changes: 13 additions & 12 deletions lib/SIL/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,16 @@ CanSILFunctionType SILFunctionType::getWithoutDifferentiability() {
// given an existing associated function generic signature. All differentiation
// parameters are constrained to conform to `Differentiable`.
static CanGenericSignature getAutoDiffAssociatedFunctionGenericSignature(
CanGenericSignature assocFnGenSig,
CanGenericSignature derivativeFnGenSig,
ArrayRef<SILParameterInfo> originalParameters,
AutoDiffIndexSubset *parameterIndices, ModuleDecl *module) {
if (!assocFnGenSig)
if (!derivativeFnGenSig)
return nullptr;
auto &ctx = module->getASTContext();
GenericSignatureBuilder builder(ctx);

// Add associated function generic signature.
builder.addGenericSignature(assocFnGenSig);
builder.addGenericSignature(derivativeFnGenSig);
// Constrain all wrt parameters to conform to `Differentiable`.
auto source =
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
Expand All @@ -182,7 +182,8 @@ static CanGenericSignature getAutoDiffAssociatedFunctionGenericSignature(
CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
AutoDiffIndexSubset *parameterIndices, unsigned resultIndex,
AutoDiffAssociatedFunctionKind kind, TypeConverter &TC,
LookupConformanceFn lookupConformance, CanGenericSignature assocFnGenSig) {
LookupConformanceFn lookupConformance,
CanGenericSignature derivativeFnGenSig) {
// JVP: (T...) -> ((R...),
// (T.TangentVector...) -> (R.TangentVector...))
// VJP: (T...) -> ((R...),
Expand All @@ -203,11 +204,11 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
wrtParams.push_back(valueAndIndex.value());

// Get the canonical associated function generic signature.
if (!assocFnGenSig)
assocFnGenSig = getGenericSignature();
assocFnGenSig = getAutoDiffAssociatedFunctionGenericSignature(
assocFnGenSig, getParameters(), parameterIndices, &TC.M);
Lowering::GenericContextScope genericContextScope(TC, assocFnGenSig);
if (!derivativeFnGenSig)
derivativeFnGenSig = getGenericSignature();
derivativeFnGenSig = getAutoDiffAssociatedFunctionGenericSignature(
derivativeFnGenSig, getParameters(), parameterIndices, &TC.M);
Lowering::GenericContextScope genericContextScope(TC, derivativeFnGenSig);

// Given a type, returns its formal SIL parameter info.
auto getTangentParameterInfoForOriginalResult = [&](
Expand Down Expand Up @@ -310,12 +311,12 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
newResults.reserve(getNumResults() + 1);
for (auto &result : getResults()) {
auto mappedResult = result.getWithType(
result.getType()->getCanonicalType(assocFnGenSig));
result.getType()->getCanonicalType(derivativeFnGenSig));
newResults.push_back(mappedResult);
}
newResults.push_back({closureType->getCanonicalType(assocFnGenSig),
newResults.push_back({closureType->getCanonicalType(derivativeFnGenSig),
ResultConvention::Owned});
return SILFunctionType::get(assocFnGenSig, getExtInfo(),
return SILFunctionType::get(derivativeFnGenSig, getExtInfo(),
getCoroutineKind(), getCalleeConvention(),
getParameters(), getYields(), newResults,
getOptionalErrorResult(), ctx,
Expand Down
4 changes: 2 additions & 2 deletions lib/SIL/TypeLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -899,10 +899,10 @@ namespace {
for (AutoDiffAssociatedFunctionKind kind :
{AutoDiffAssociatedFunctionKind::JVP,
AutoDiffAssociatedFunctionKind::VJP}) {
auto assocFnTy = origFnTy->getAutoDiffAssociatedFunctionType(
auto derivativeFnTy = origFnTy->getAutoDiffAssociatedFunctionType(
paramIndices, 0, kind, TC,
LookUpConformanceInModule(&TC.M));
auto silTy = SILType::getPrimitiveObjectType(assocFnTy);
auto silTy = SILType::getPrimitiveObjectType(derivativeFnTy);
DifferentiableFunctionExtractee extractee(kind);
// Assert that we have the right extractee. A terrible bug in the past
// was caused by implicit conversions from `unsigned` to
Expand Down
14 changes: 7 additions & 7 deletions lib/SILGen/SILGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,16 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
// SWIFT_ENABLE_TENSORFLOW
/// Get or create an autodiff associated function thunk for the given
/// SILDeclRef, SILFunction, and associated function type.
SILFunction *getOrCreateAutoDiffThunk(SILDeclRef assocFnRef,
SILFunction *assocFn,
CanSILFunctionType assocFnTy);
SILFunction *getOrCreateAutoDiffThunk(SILDeclRef derivativeFnRef,
SILFunction *derivativeFn,
CanSILFunctionType derivativeFnTy);

// SWIFT_ENABLE_TENSORFLOW
/// Get or create an autodiff associated function vtable entry thunk for the
/// given SILDeclRef and associated function type.
SILFunction *
getOrCreateAutoDiffClassMethodThunk(SILDeclRef assocFnRef,
CanSILFunctionType assocFnTy);
getOrCreateAutoDiffClassMethodThunk(SILDeclRef derivativeFnRef,
CanSILFunctionType derivativeFnTy);

/// Emit a vtable thunk for a derived method if its natural abstraction level
/// diverges from the overridden base method. If no thunking is needed,
Expand Down Expand Up @@ -187,8 +187,8 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
/// - The last result in the returned pullback.
SILFunction *getOrCreateAutoDiffAssociatedFunctionThunk(
SILFunction *original, SILAutoDiffIndices &indices,
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind,
bool reorderSelf);
SILFunction *derivativeFn,
AutoDiffAssociatedFunctionKind derivativeFnKind, bool reorderSelf);

/// Determine whether the given class has any instance variables that
/// need to be destroyed.
Expand Down
36 changes: 18 additions & 18 deletions lib/SILGen/SILGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1041,19 +1041,19 @@ static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction(
origFnArgVals.push_back(arg.getValue());

// Get the associated function.
SILValue assocFn = SGF.B.createDifferentiableFunctionExtract(
SILValue derivativeFn = SGF.B.createDifferentiableFunctionExtract(
loc, kind, origFnVal);
auto assocFnType = assocFn->getType().castTo<SILFunctionType>();
auto derivativeFnType = derivativeFn->getType().castTo<SILFunctionType>();

// We don't need to destroy the original function or retain the `assocFn`,
// because they are trivial (because they are @noescape).
// We don't need to destroy the original function or retain the
// `derivativeFn`, because they are trivial (because they are @noescape).
assert(origFnVal->getType().isTrivial(SGF.F));
assert(assocFn->getType().isTrivial(SGF.F));
bool assocFnNeedsDestroy = false;
assert(derivativeFn->getType().isTrivial(SGF.F));
bool derivativeFnNeedsDestroy = false;

// Unwrap curry levels.
SmallVector<SILFunctionType *, 2> curryLevels;
SILFunctionType *currentLevel = assocFnType;
SILFunctionType *currentLevel = derivativeFnType;
unsigned numParameters = 0;
while (currentLevel != nullptr) {
curryLevels.push_back(currentLevel);
Expand All @@ -1074,25 +1074,25 @@ static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction(
#endif

// Apply all the curry levels except the last one, whose results we handle
// specially. We overwrite `assocFn` with the application results.
// specially. We overwrite `derivativeFn` with the application results.
unsigned currentParameter = 0;
auto curryLevelsWithoutLast =
ArrayRef<SILFunctionType *>(curryLevels).drop_back(1);
for (auto *curryLevel : curryLevelsWithoutLast) {
auto curryLevelArgVals = ArrayRef<SILValue>(origFnArgVals).slice(
currentParameter, curryLevel->getNumParameters());
auto applyResult = SGF.B.createApply(
loc, assocFn, SubstitutionMap(), curryLevelArgVals,
loc, derivativeFn, SubstitutionMap(), curryLevelArgVals,
/*isNonThrowing*/ false);
currentParameter += curryLevel->getNumParameters();

assocFn = applyResult;
derivativeFn = applyResult;

// Our new `assocFn` needs to be released because it's an owned result from
// a function call.
// Our new `derivativeFn` needs to be released because it's an owned result
// from a function call.
assert(curryLevel->getSingleResult().getConvention() ==
ResultConvention::Owned);
assocFnNeedsDestroy = true;
derivativeFnNeedsDestroy = true;
}

assert(curryLevels.back()->getNumResults() == 2);
Expand All @@ -1109,10 +1109,10 @@ static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction(
currentParameter);
for (auto origFnArgVal : curryLevelArgVals)
applyArgs.push_back(origFnArgVal);
auto differential = SGF.B.createApply(
loc, assocFn, SubstitutionMap(), applyArgs, /*isNonThrowing*/ false);
auto differential = SGF.B.createApply(loc, derivativeFn, SubstitutionMap(),
applyArgs, /*isNonThrowing*/ false);

assocFn = SILValue();
derivativeFn = SILValue();

SGF.B.createStore(loc, differential,
SGF.B.createTupleElementAddr(loc, indResBuffer, 1),
Expand All @@ -1125,10 +1125,10 @@ static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction(
auto curryLevelArgVals = ArrayRef<SILValue>(origFnArgVals).slice(
currentParameter);
auto resultTuple = SGF.B.createApply(
loc, assocFn, SubstitutionMap(), curryLevelArgVals,
loc, derivativeFn, SubstitutionMap(), curryLevelArgVals,
/*isNonThrowing*/ false);

assocFn = SILValue();
derivativeFn = SILValue();

return SGF.emitManagedRValueWithCleanup(resultTuple);
}
Expand Down
72 changes: 39 additions & 33 deletions lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3311,22 +3311,24 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
return AbstractionPattern(
pattern.getGenericSignature(), getAssocFnTy(patternType, kind));
};
auto createAssocFnThunk = [&](AutoDiffAssociatedFunctionKind kind)
-> ManagedValue {
auto assocFnInputOrigType = getAssocFnPattern(inputOrigTypeNotDiff, kind);
auto assocFnInputSubstType = getAssocFnTy(inputSubstTypeNotDiff, kind);
auto assocFnOutputOrigType = getAssocFnPattern(outputOrigTypeNotDiff,
auto createAssocFnThunk =
[&](AutoDiffAssociatedFunctionKind kind) -> ManagedValue {
auto derivativeFnInputOrigType =
getAssocFnPattern(inputOrigTypeNotDiff, kind);
auto derivativeFnInputSubstType = getAssocFnTy(inputSubstTypeNotDiff, kind);
auto derivativeFnOutputOrigType = getAssocFnPattern(outputOrigTypeNotDiff,
kind);
auto assocFnOutputSubstType = getAssocFnTy(outputSubstTypeNotDiff, kind);
auto &assocFnExpectedTL = SGF.getTypeLowering(assocFnOutputOrigType,
assocFnOutputSubstType);
SILValue assocFn = SGF.B.createDifferentiableFunctionExtract(
auto derivativeFnOutputSubstType =
getAssocFnTy(outputSubstTypeNotDiff, kind);
auto &derivativeFnExpectedTL = SGF.getTypeLowering(
derivativeFnOutputOrigType, derivativeFnOutputSubstType);
SILValue derivativeFn = SGF.B.createDifferentiableFunctionExtract(
loc, kind, borrowedFnValue.getValue());
assocFn = SGF.B.emitCopyValueOperation(loc, assocFn);
auto managedAssocFn = SGF.emitManagedRValueWithCleanup(assocFn);
return createThunk(SGF, loc, managedAssocFn, assocFnInputOrigType,
assocFnInputSubstType, assocFnOutputOrigType,
assocFnOutputSubstType, assocFnExpectedTL);
derivativeFn = SGF.B.emitCopyValueOperation(loc, derivativeFn);
auto managedAssocFn = SGF.emitManagedRValueWithCleanup(derivativeFn);
return createThunk(SGF, loc, managedAssocFn, derivativeFnInputOrigType,
derivativeFnInputSubstType, derivativeFnOutputOrigType,
derivativeFnOutputSubstType, derivativeFnExpectedTL);
};

auto jvpThunk = createAssocFnThunk(AutoDiffAssociatedFunctionKind::JVP);
Expand Down Expand Up @@ -3666,59 +3668,61 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
SILFunction *
SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
SILFunction *original, SILAutoDiffIndices &indices,
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind,
SILFunction *derivativeFn, AutoDiffAssociatedFunctionKind derivativeFnKind,
bool reorderSelf) {
auto assocFnType = assocFn->getLoweredFunctionType();
auto derivativeFnType = derivativeFn->getLoweredFunctionType();

// TODO(TF-685): Use principled thunk mangling.
// Do not simply reuse reabstraction thunk mangling.
Mangle::ASTMangler mangler;
auto name = getASTContext().getIdentifier(
mangler.mangleAutoDiffAssociatedFunctionHelper(
original->getName(), assocFnKind, indices)).str();
original->getName(), derivativeFnKind, indices)).str();

Lowering::GenericContextScope genericContextScope(
Types, assocFnType->getGenericSignature());
auto *thunkGenericEnv = assocFnType->getGenericSignature()
? assocFnType->getGenericSignature()->getGenericEnvironment()
Types, derivativeFnType->getGenericSignature());
auto *thunkGenericEnv = derivativeFnType->getGenericSignature()
? derivativeFnType->getGenericSignature()->getGenericEnvironment()
: nullptr;

auto origFnType = original->getLoweredFunctionType();
auto origAssocFnType = origFnType->getAutoDiffAssociatedFunctionType(
indices.parameters, indices.source,
assocFnKind, Types, LookUpConformanceInModule(M.getSwiftModule()),
assocFnType->getGenericSignature());
derivativeFnKind, Types, LookUpConformanceInModule(M.getSwiftModule()),
derivativeFnType->getGenericSignature());
assert(!origAssocFnType->getExtInfo().hasContext());

auto loc = assocFn->getLocation();
auto loc = derivativeFn->getLocation();
SILGenFunctionBuilder fb(*this);
auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage(
original->getLinkage(), /*isAssocFnExported*/ true);
auto *thunk = fb.getOrCreateFunction(
loc, name, linkage, origAssocFnType, IsBare, IsNotTransparent,
assocFn->isSerialized(), assocFn->isDynamicallyReplaceable(),
assocFn->getEntryCount(), assocFn->isThunk(),
assocFn->getClassSubclassScope());
derivativeFn->isSerialized(), derivativeFn->isDynamicallyReplaceable(),
derivativeFn->getEntryCount(), derivativeFn->isThunk(),
derivativeFn->getClassSubclassScope());
if (!thunk->empty())
return thunk;
thunk->setGenericEnvironment(thunkGenericEnv);

SILGenFunction thunkSGF(*this, *thunk, assocFn->getDeclContext());
SILGenFunction thunkSGF(*this, *thunk, derivativeFn->getDeclContext());
SmallVector<ManagedValue, 4> params;
SmallVector<SILArgument *, 4> indirectResults;
thunkSGF.collectThunkParams(loc, params, &indirectResults);

auto *assocFnRef = thunkSGF.B.createFunctionRef(loc, assocFn);
auto assocFnRefType = assocFnRef->getType().castTo<SILFunctionType>();
auto *derivativeFnRef = thunkSGF.B.createFunctionRef(loc, derivativeFn);
auto derivativeFnRefType =
derivativeFnRef->getType().castTo<SILFunctionType>();

// Collect thunk arguments, converting ownership.
SmallVector<SILValue, 8> arguments;
for (auto *indRes : indirectResults)
arguments.push_back(indRes);
forwardFunctionArguments(thunkSGF, loc, assocFnRefType, params, arguments);
forwardFunctionArguments(thunkSGF, loc, derivativeFnRefType, params,
arguments);
// Apply function argument.
auto apply = thunkSGF.emitApplyWithRethrow(
loc, assocFnRef, /*substFnType*/ assocFnRef->getType(),
loc, derivativeFnRef, /*substFnType*/ derivativeFnRef->getType(),
thunk->getForwardingSubstitutionMap(), arguments);

// Create return instruction in the thunk, first deallocating local
Expand All @@ -3734,7 +3738,9 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
// If self ordering is not necessary and linear map types are unchanged,
// return the `apply` instruction.
auto linearMapFnType = cast<SILFunctionType>(
thunk->mapTypeIntoContext(assocFnRefType->getResults().back().getType())
thunk
->mapTypeIntoContext(
derivativeFnRefType->getResults().back().getType())
->getCanonicalType());
auto targetLinearMapFnType = thunk->mapTypeIntoContext(
origAssocFnType->getResults().back().getSILStorageType())
Expand All @@ -3749,7 +3755,7 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
extractAllElements(apply, loc, thunkSGF.B, directResults);
auto linearMap = thunkSGF.emitManagedRValueWithCleanup(directResults.back());
assert(linearMap.getType().castTo<SILFunctionType>() == linearMapFnType);
auto linearMapKind = assocFnKind.getLinearMapKind();
auto linearMapKind = derivativeFnKind.getLinearMapKind();
linearMap = thunkSGF.getThunkedAutoDiffLinearMap(
linearMap, linearMapKind, linearMapFnType, targetLinearMapFnType,
reorderSelf);
Expand Down
Loading

0 comments on commit 0aee08a

Please sign in to comment.