From e2155bf3a6303b974b5dec791a00f69fea31a5d7 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 20 Jan 2019 17:23:27 -0800 Subject: [PATCH 1/3] [AutoDiff] Enable differentiation of generic functions. - Relax differentiability diagnostic for generic functions. - Previously, an error was emitted when attempting to differentiate any generic function. Now, diagnose only functions with indirect differentiation parameters/result. - Propagate differentiation associated function generic signature throughout differentiation pass. - Change `PrimalGenCloner` to inherit `TypeSubstCloner`. - Make primal value structs inherit primal function's generic parameters and signature. - Calculate correct substitution map for `PrimalGenCloner::visitApplyInst`. Emit diagnostic when apply instruction's associated function (e.g. VJP) has generic requirements unmet by the primal generic environment. - Remap types in `AdjointEmitter`. - Remove manually `@differentiable` attribute where clause conformance requirement checks. - `GenericSignatureBuilder` already performs checks so manual checks are unnecessary. --- include/swift/AST/DiagnosticsSIL.def | 8 +- lib/SIL/SILFunctionType.cpp | 3 - .../Mandatory/Differentiation.cpp | 398 +++++++++++------- lib/Sema/TypeCheckAttr.cpp | 38 +- test/AutoDiff/autodiff_diagnostics.swift | 2 +- .../differentiable_attr_type_checking.swift | 2 +- 6 files changed, 264 insertions(+), 187 deletions(-) diff --git a/include/swift/AST/DiagnosticsSIL.def b/include/swift/AST/DiagnosticsSIL.def index 20feb837525bf..b8926af90a2e8 100644 --- a/include/swift/AST/DiagnosticsSIL.def +++ b/include/swift/AST/DiagnosticsSIL.def @@ -368,8 +368,9 @@ ERROR(autodiff_unsupported_type,none, "differentiating '%0' is not supported yet", (Type)) ERROR(autodiff_function_not_differentiable,none, "function is not differentiable", ()) -NOTE(autodiff_function_generic_functions_unsupported,none, - "differentiating generic functions is not supported yet", ()) +NOTE(autodiff_function_indirect_params_or_result_unsupported,none, + "differentiating functions with parameters or result of unknown size " + "is not supported yet", ()) NOTE(autodiff_external_nondifferentiable_function,none, "cannot differentiate an external function that has not been marked " "'@differentiable'", ()) @@ -386,6 +387,9 @@ NOTE(autodiff_protocol_member_subset_indices_not_differentiable,none, NOTE(autodiff_function_subset_indices_not_differentiable,none, "function is differentiable only with respect to a smaller subset of " "arguments", ()) +NOTE(autodiff_function_assoc_func_requirements_unmet,none, + "function call is not differentiate because generic requirements are not " + "met", ()) NOTE(autodiff_opaque_function_not_differentiable,none, "opaque non-'@autodiff' function is not differentiable", ()) NOTE(autodiff_property_not_differentiable,none, diff --git a/lib/SIL/SILFunctionType.cpp b/lib/SIL/SILFunctionType.cpp index 74f3fd1f4443b..fd78dfdff9be9 100644 --- a/lib/SIL/SILFunctionType.cpp +++ b/lib/SIL/SILFunctionType.cpp @@ -289,9 +289,6 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( /*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None, ParameterConvention::Direct_Guaranteed, tangentParams, {}, tangentResults, None, ctx); - SmallVector jvpResults( - curryLevels.back()->getResults().begin(), - curryLevels.back()->getResults().end()); break; } case AutoDiffAssociatedFunctionKind::VJP: { diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 6bb1a0116a64a..53dd88b27aa7c 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -36,7 +36,7 @@ #include "swift/SIL/FormalLinkage.h" #include "swift/SIL/LoopInfo.h" #include "swift/SIL/SILBuilder.h" -#include "swift/SIL/SILCloner.h" +#include "swift/SIL/TypeSubstCloner.h" #include "swift/SILOptimizer/Analysis/DominanceAnalysis.h" #include "swift/SILOptimizer/Analysis/LoopAnalysis.h" #include "swift/SILOptimizer/PassManager/Passes.h" @@ -246,6 +246,47 @@ static SILType getCotangentType(SILType type, SILModule &mod) { return getCotangentType(type.getASTType(), mod); } +// Return the expected generic signature for autodiff associated functions given +// a SILDifferentiableAttr. The expected generic signature is built from the +// original generic signature and the attribute's requirements. +static CanGenericSignature +getAssociatedFunctionGenericSignature(SILDifferentiableAttr *attr, + SILFunction *original) { + auto originalGenSig = + original->getLoweredFunctionType()->getGenericSignature(); + if (!originalGenSig) + return nullptr; + GenericSignatureBuilder builder(original->getASTContext()); + // Add original generic signature. + builder.addGenericSignature(originalGenSig); + // Add where clause requirements. + auto source = + GenericSignatureBuilder::FloatingRequirementSource::forAbstract(); + for (auto &req : attr->getRequirements()) + builder.addRequirement(req, source, original->getModule().getSwiftModule()); + return std::move(builder) + .computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams=*/true) + ->getCanonicalSignature(); +} + +// Clone the generic parameters of the given generic signature and return a new +// `GenericParamList`. +static GenericParamList *cloneGenericParameters(ASTContext &ctx, + DeclContext *dc, + CanGenericSignature sig) { + SmallVector clonedParams; + for (auto paramType : sig->getGenericParams()) { + auto clonedParam = new (ctx) GenericTypeParamDecl(dc, paramType->getName(), + SourceLoc(), + paramType->getDepth(), + paramType->getIndex()); + clonedParam->setDeclContext(dc); + clonedParam->setImplicit(true); + clonedParams.push_back(clonedParam); + } + return GenericParamList::create(ctx, SourceLoc(), clonedParams, SourceLoc()); +} + //===----------------------------------------------------------------------===// // Auxiliary data structures //===----------------------------------------------------------------------===// @@ -394,14 +435,6 @@ class PrimalInfo { /// The primal value struct declaration. StructDecl *primalValueStruct = nullptr; - /// The SIL module; - const SILModule &module; - - /// The corresponding type of the primal value struct. This is initially - /// null. After this field is computed, mutation of primal value will lead to - /// unexpected behavior. - StructType *primalValueStructType = nullptr; - /// Mapping from `apply` and `struct_extract` instructions in the original /// function to the corresponding pullback decl in the primal struct. DenseMap pullbackValueMap; @@ -418,7 +451,10 @@ class PrimalInfo { new (ctx) UsableFromInlineAttr(/*implicit*/ true)); else varDecl->setAccess(AccessLevel::Public); - varDecl->setInterfaceType(type); + if (type->hasArchetype()) + varDecl->setInterfaceType(type->mapTypeOutOfContext()); + else + varDecl->setInterfaceType(type); primalValueStruct->addMember(varDecl); return varDecl; } @@ -427,35 +463,13 @@ class PrimalInfo { PrimalInfo(const PrimalInfo &) = delete; PrimalInfo &operator=(const PrimalInfo &) = delete; - explicit PrimalInfo(StructDecl *primalValueStruct, const SILModule &module) - : primalValueStruct(&*primalValueStruct), module(module) {} + explicit PrimalInfo(StructDecl *primalValueStruct) + : primalValueStruct(&*primalValueStruct) {} /// Returns the primal value struct that the primal info is established /// around. StructDecl *getPrimalValueStruct() const { return primalValueStruct; } - /// Computes the primal value struct type. - StructType *computePrimalValueStructType() { - assert(!primalValueStructType && - "The primal value struct type has been computed before"); - primalValueStructType = StructType::get(primalValueStruct, Type(), - primalValueStruct->getASTContext()); - return primalValueStructType; - } - - /// Returns the primal value struct type, assuming the primal value struct - /// type has already been computed before. - StructType *getPrimalValueStructType() const { - assert(primalValueStructType && - "The primal value struct type has not been computed"); - return primalValueStructType; - } - - /// Returns the lowered SIL type for the primal value struct. - SILType getLoweredPrimalValueStructType() const { - return module.Types.getLoweredType(getPrimalValueStructType()); - } - /// Add a pullback to the primal value struct. VarDecl *addPullbackDecl(SILInstruction *inst, Type pullbackType) { // Decls must have AST types (not `SILFunctionType`), so we convert the @@ -866,7 +880,8 @@ class ADContext { /// Creates a struct declaration (without contents) for storing primal values /// of a function. The newly created struct will have the same generic /// parameters as the function. - StructDecl *createPrimalValueStruct(const DifferentiationTask *task); + StructDecl *createPrimalValueStruct(const DifferentiationTask *task, + CanGenericSignature primalGenericSig); /// Finds the `[differentiable]` attribute on the specified original function /// corresponding to the specified parameter indices. Returns nullptr if it @@ -977,6 +992,11 @@ class ADContext { return getASTContext().Diags.diagnose(loc, diag, std::forward(args)...); } + /// Emit a "not differentiable" error based on the given differentiation task + /// and diagnostic. + void emitNondifferentiabilityError(const DifferentiationTask *task, + Diag<> diag); + /// Given an instruction and a differentiation task associated with the /// parent function, emits a "not differentiable" error based on the task. If /// the task is indirect, emits notes all the way up to the outermost task, @@ -999,6 +1019,13 @@ ADContext::ADContext(SILModuleTransform &transform) : transform(transform), module(*transform.getModule()), passManager(*transform.getPassManager()) {} +void ADContext::emitNondifferentiabilityError(const DifferentiationTask *task, + Diag<> diag) { + auto invoker = task->getInvoker(); + diagnose(invoker.getLocation(), diag); + diagnose(invoker.getLocation(), diag::autodiff_function_not_differentiable); +} + void ADContext::emitNondifferentiabilityError(SILValue value, const DifferentiationTask *task, Optional> diag) { @@ -1452,6 +1479,52 @@ static void dumpActivityInfo(SILFunction &fn, } } +/// If the original function in the differentiation task has more than one basic +/// blocks, emit a "control flow unsupported" error at appropriate source +/// locations. Returns true if error is emitted. +static bool diagnoseUnsupportedControlFlow(ADContext &context, + DifferentiationTask *task) { + if (task->getOriginal()->getBlocks().size() <= 1) + return false; + // Find any control flow node and diagnose. + for (auto &bb : *task->getOriginal()) { + auto *term = bb.getTerminator(); + if (term->isBranch()) { + context.emitNondifferentiabilityError( + term, task, diag::autodiff_control_flow_not_supported); + return true; + } + } + return false; +} + +/// If the original function in the differentiation task has indirect +/// differentiation parameters/result, emit a "unknown parameter or result +/// size" error at appropriate source locations. Returns true if error is +/// emitted. +static bool diagnoseIndirectParamsOrResult(ADContext &context, + DifferentiationTask *task) { + auto originalFnTy = task->getOriginal()->getLoweredFunctionType(); + auto indices = task->getIndices(); + // Check whether differentiation result or parameters are indirect. + bool originalHasIndirectParamOrResult = + originalFnTy->getResults()[indices.source].isFormalIndirect(); + for (unsigned i : swift::indices(originalFnTy->getParameters())) { + if (indices.isWrtParameter(i)) { + if (originalFnTy->getParameters()[i].isFormalIndirect()) { + originalHasIndirectParamOrResult = true; + break; + } + } + } + if (originalHasIndirectParamOrResult) { + context.emitNondifferentiabilityError( + task, diag::autodiff_function_indirect_params_or_result_unsupported); + return true; + } + return false; +} + //===----------------------------------------------------------------------===// // Code emission utilities //===----------------------------------------------------------------------===// @@ -1803,7 +1876,8 @@ class PrimalGen { } // end anonymous namespace StructDecl * -ADContext::createPrimalValueStruct(const DifferentiationTask *task) { +ADContext::createPrimalValueStruct(const DifferentiationTask *task, + CanGenericSignature primalGenericSig) { auto *function = task->getOriginal(); assert(&function->getModule() == &module && "The function must be in the same module"); @@ -1818,6 +1892,13 @@ ADContext::createPrimalValueStruct(const DifferentiationTask *task) { /*NameLoc*/ loc, /*Inherited*/ {}, /*GenericParams*/ nullptr, // to be set later /*DC*/ &file); + if (primalGenericSig) { + auto genericParams = + cloneGenericParameters(astCtx, pvStruct, primalGenericSig); + pvStruct->setGenericParams(genericParams); + pvStruct->setGenericEnvironment( + primalGenericSig->createGenericEnvironment()); + } pvStruct->computeType(); if (auto *dc = function->getDeclContext()) { if (auto *afd = dyn_cast(dc)) { @@ -1833,9 +1914,6 @@ ADContext::createPrimalValueStruct(const DifferentiationTask *task) { pvStruct->getAttrs().add( new (astCtx) UsableFromInlineAttr(/*implicit*/ true)); } - if (auto originalGenSig = - task->getOriginal()->getLoweredFunctionType()->getGenericSignature()) - pvStruct->setGenericEnvironment(originalGenSig->createGenericEnvironment()); file.addVisibleDecl(pvStruct); LLVM_DEBUG({ auto &s = getADDebugStream(); @@ -1922,27 +2000,9 @@ static void collectMinimalIndicesForFunctionCall( } } -/// If the original function in the differentiation task has more than one basic -/// blocks, emit a "control flow unsupported" error at appropriate source -/// locations. Returns true if error is emitted. -static bool diagnoseUnsupportedControlFlow(ADContext &context, - DifferentiationTask *task) { - if (task->getOriginal()->getBlocks().size() <= 1) - return false; - // Find any control flow node and diagnose. - for (auto &bb : *task->getOriginal()) { - auto *term = bb.getTerminator(); - if (term->isBranch()) { - context.emitNondifferentiabilityError( - term, task, diag::autodiff_control_flow_not_supported); - return true; - } - } - return false; -} - namespace { -class PrimalGenCloner final : public SILClonerWithScopes { +class PrimalGenCloner final + : public TypeSubstCloner { private: /// A reference to this function synthesis item. const FunctionSynthesisItem &synthesis; @@ -1976,8 +2036,10 @@ class PrimalGenCloner final : public SILClonerWithScopes { public: explicit PrimalGenCloner(const FunctionSynthesisItem &synthesis, const DifferentiableActivityInfo &activityInfo, + SubstitutionMap substMap, PrimalGen &primalGen, ADContext &context) - : SILClonerWithScopes(*synthesis.target), synthesis(synthesis), + : TypeSubstCloner(*synthesis.target, *synthesis.original, substMap), + synthesis(synthesis), activityInfo(activityInfo), primalGen(primalGen) {} @@ -1990,17 +2052,15 @@ class PrimalGenCloner final : public SILClonerWithScopes { // Run primal generation. Returns true on error. bool run() { auto *original = getOriginal(); + auto *primal = getPrimal(); LLVM_DEBUG(getADDebugStream() - << "Cloning original @" << getOriginal()->getName() + << "Cloning original @" << original->getName() << " to primal @" << synthesis.target->getName() << '\n'); // Create entry BB and arguments. - auto *entry = getPrimal()->createBasicBlock(); - // Map the original's arguments to the new function's arguments. - SmallVector entryArgs; - for (auto *origArg : original->getArguments()) { - auto *newArg = entry->createFunctionArgument(origArg->getType()); - entryArgs.push_back(newArg); - } + auto *entry = primal->createBasicBlock(); + createEntryArguments(primal); + auto entryArgs = map>( + entry->getArguments(), [](SILArgument *arg) { return arg; }); // Clone. cloneFunctionBody(original, entry, entryArgs); // If errors occurred, back out. @@ -2018,6 +2078,9 @@ class PrimalGenCloner final : public SILClonerWithScopes { auto loc = getPrimal()->getLocation(); auto structTy = getPrimalInfo().getPrimalValueStruct()->getDeclaredInterfaceType(); + // TODO: Replace line above. + if (auto primalGenericEnv = getPrimal()->getGenericEnvironment()) + structTy = primalGenericEnv->mapTypeIntoContext(structTy); auto &builder = getBuilder(); builder.setInsertionPoint(exit); auto structLoweredTy = @@ -2067,7 +2130,7 @@ class PrimalGenCloner final : public SILClonerWithScopes { void visit(SILInstruction *inst) { if (errorOccurred) return; - SILClonerWithScopes::visit(inst); + TypeSubstCloner::visit(inst); } void visitSILInstruction(SILInstruction *inst) { @@ -2146,6 +2209,76 @@ class PrimalGenCloner final : public SILClonerWithScopes { primalValues.push_back(pullback); } + // Return the substitution map for the associated function of an apply + // instruction. If the associated function has generic requirements that are + // unfulfilled by the primal function, emit "callee requirements unmet" + // diagnostics for each unmet requirement and return `None`. + Optional getOrDiagnoseAssociatedFunctionSubstitutionMap( + ApplyInst *ai, CanSILFunctionType assocFnTy) { + auto &context = getContext(); + auto origSubstMap = ai->getSubstitutionMap(); + auto assocGenSig = assocFnTy->getGenericSignature(); + if (!assocGenSig) + return origSubstMap; + + auto assocSubstMap = assocGenSig->createGenericEnvironment() + ->getForwardingSubstitutionMap(); + SubstitutionMap primalSubstMap; + auto primalGenEnv = getPrimal()->getGenericEnvironment(); + if (primalGenEnv) + primalSubstMap = primalGenEnv->getForwardingSubstitutionMap(); + + // Jointly iterate through requirements and conformances of VJP callee. + SmallVector unsatisfiedRequirements; + auto conformances = assocSubstMap.getConformances(); + for (auto req : assocGenSig->getRequirements()) { + if (req.getKind() != RequirementKind::Conformance) + continue; + auto conformance = conformances.front(); + auto *proto = conformance.getAbstract(); + assert(proto && "Expected protocol in generic signature requirement"); + auto reqType = req.getFirstType(); + // If requirement type can be substituted in original substutition map to + // form a non-archetype type, use the ssubstituted type. + if (auto origFirstType = reqType.subst(origSubstMap)) + if (!origFirstType->hasArchetype()) + reqType = origFirstType; + // If requirement type has no type parameters and is not an archetype type, + // it is valid. Continue. + if (!reqType->isTypeParameter() && !reqType->hasArchetype()) + continue; + auto isConformanceMet = + origSubstMap.lookupConformance(reqType->getCanonicalType(), proto) || + primalSubstMap.lookupConformance(reqType->getCanonicalType(), proto); + if (!isConformanceMet) + unsatisfiedRequirements.push_back(req); + conformances = conformances.slice(1); + } + // Diagnose unsatisfied requirements. + if (!unsatisfiedRequirements.empty()) { + context.emitNondifferentiabilityError( + ai, getDifferentiationTask(), + diag::autodiff_function_assoc_func_requirements_unmet); + return None; + } + + // If all requirements are satisfied, return associated function + // substitution map. + if (!assocSubstMap.empty()) { + return assocSubstMap.subst( + [&](SubstitutableType *ty) -> Type { + Type type(ty); + if (!primalSubstMap.empty()) + type = type.subst(primalSubstMap); + if (type->hasArchetype() && primalGenEnv) + return type; + return type.subst(origSubstMap); + }, + LookUpConformanceInModule(context.getModule().getSwiftModule())); + } + return origSubstMap; + } + void visitApplyInst(ApplyInst *ai) { auto &context = getContext(); // Special handling logic only applies when `apply` is active. If not, just @@ -2214,16 +2347,13 @@ class PrimalGenCloner final : public SILClonerWithScopes { for (auto origArg : ai->getArguments()) newArgs.push_back(getOpValue(origArg)); assert(newArgs.size() == numVJPParams); - // Apply the VJP. - auto substMap = ai->getSubstitutionMap(); - if (auto vjpGenSig = vjpFnTy->getGenericSignature()) { - auto vjpSubstMap = - vjpGenSig->createGenericEnvironment()->getForwardingSubstitutionMap(); - substMap = vjpSubstMap.subst( - [&](SubstitutableType *ty) { return Type(ty).subst(substMap); }, - LookUpConformanceInModule(context.getModule().getSwiftModule())); + // Get the VJP substitution map and apply the VJP. + auto substMap = getOrDiagnoseAssociatedFunctionSubstitutionMap(ai, vjpFnTy); + if (!substMap) { + errorOccurred = true; + return; } - auto *vjpCall = getBuilder().createApply(ai->getLoc(), vjp, substMap, + auto *vjpCall = getBuilder().createApply(ai->getLoc(), vjp, *substMap, newArgs, ai->isNonThrowing()); LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall); @@ -2282,17 +2412,11 @@ bool PrimalGen::performSynthesis(FunctionSynthesisItem item) { << item.target->getName() << '\n'); // FIXME: If the original function has multiple basic blocks, bail out since // AD does not support control flow yet. - if (diagnoseUnsupportedControlFlow(context, item.task)) { - errorOccurred = true; - return true; - } - // FIXME: Support generics. - auto *original = item.original; - if (original->getLoweredFunctionType()->getGenericSignature()) { - context.diagnose(original->getLocation().getSourceLoc(), - diag::autodiff_function_generic_functions_unsupported); - context.diagnose(original->getLocation().getSourceLoc(), - diag::autodiff_function_not_differentiable); + // FIXME: If the original function has indirect differentiation + // parameters/result, bail out since AD does not support side-effecting + // instructions yet. + if (diagnoseUnsupportedControlFlow(context, item.task) || + diagnoseIndirectParamsOrResult(context, item.task)) { errorOccurred = true; return true; } @@ -2305,7 +2429,12 @@ bool PrimalGen::performSynthesis(FunctionSynthesisItem item) { LLVM_DEBUG(dumpActivityInfo(*item.original, item.task->getIndices(), activityInfo, getADDebugStream())); // Synthesize primal. - PrimalGenCloner cloner(item, activityInfo, *this, context); + auto substMap = item.original->getForwardingSubstitutionMap(); + if (auto primalGenEnv = item.target->getGenericEnvironment()) { + auto primalSubstMap = primalGenEnv->getForwardingSubstitutionMap(); + substMap = substMap.subst(primalSubstMap); + } + PrimalGenCloner cloner(item, activityInfo, substMap, *this, context); // Run the cloner. return cloner.run(); } @@ -2333,7 +2462,6 @@ bool PrimalGen::run() { errorOccurred = true; continue; } - synthesis.task->getPrimalInfo()->computePrimalValueStructType(); synthesis.task->setPrimalSynthesisState(FunctionSynthesisState::Done); } return errorOccurred; @@ -2673,6 +2801,16 @@ class AdjointEmitter final : public SILInstructionVisitor { return insertion.first->getSecond(); } + SILType remapType(SILType ty) { + if (!ty.hasArchetype()) + return ty; + auto *adjointGenEnv = getAdjoint().getGenericEnvironment(); + if (!adjointGenEnv) + return ty; + return ty.subst(getAdjoint().getModule(), + adjointGenEnv->getForwardingSubstitutionMap()); + } + /// Add an adjoint value for the given original value. AdjointValue &addAdjointValue(SILValue originalValue, AdjointValue adjointValue) { @@ -2681,7 +2819,7 @@ class AdjointEmitter final : public SILInstructionVisitor { assert(originalValue->getFunction() == &getOriginal()); LLVM_DEBUG(getADDebugStream() << "Adding adjoint for " << originalValue); #ifndef NDEBUG - auto origTy = originalValue->getType().getASTType(); + auto origTy = remapType(originalValue->getType()).getASTType(); auto cotanSpace = origTy->getAutoDiffAssociatedVectorSpace( AutoDiffAssociatedVectorSpaceKind::Cotangent, LookUpConformanceInModule(getModule().getSwiftModule())); @@ -2897,8 +3035,6 @@ class AdjointEmitter final : public SILInstructionVisitor { SILLocation remapLocation(SILLocation loc) { return loc; } - SILType remapType(SILType type) { return type; } - void visitApplyInst(ApplyInst *ai) { // Replace a call to a function with a call to its pullback. auto loc = remapLocation(ai->getLoc()); @@ -3012,7 +3148,7 @@ class AdjointEmitter final : public SILInstructionVisitor { for (auto *field : decl->getStoredProperties()) { auto fv = si->getFieldValue(field); addAdjointValue( - fv, AdjointValue::getZero(getCotangentType(fv->getType(), + fv, AdjointValue::getZero(getCotangentType(remapType(fv->getType()), getModule()))); } break; @@ -3060,7 +3196,7 @@ class AdjointEmitter final : public SILInstructionVisitor { assert(!getModule().Types.getTypeLowering(cotangentVectorTy) .isAddressOnly()); auto cotangentVectorSILTy = - SILType::getPrimitiveObjectType(cotangentVectorTy); + remapType(SILType::getPrimitiveObjectType(cotangentVectorTy)); auto *cotangentVectorDecl = cotangentVectorTy->getStructOrBoundGenericStruct(); assert(cotangentVectorDecl); @@ -3091,8 +3227,8 @@ class AdjointEmitter final : public SILInstructionVisitor { eltVals.push_back(av); else eltVals.push_back(AdjointValue::getZero( - SILType::getPrimitiveObjectType( - field->getType()->getCanonicalType()))); + remapType(SILType::getPrimitiveObjectType(field->getType() + ->getCanonicalType())))); } addAdjointValue(sei->getOperand(), AdjointValue::getAggregate(cotangentVectorSILTy, @@ -3138,8 +3274,8 @@ class AdjointEmitter final : public SILInstructionVisitor { case AdjointValue::Kind::Zero: for (auto eltVal : ti->getElements()) addAdjointValue(eltVal, - AdjointValue::getZero(getCotangentType(eltVal->getType(), - getModule()))); + AdjointValue::getZero(remapType(getCotangentType(eltVal->getType(), + getModule())))); break; case AdjointValue::Kind::Materialized: for (auto i : range(ti->getNumOperands())) @@ -3450,7 +3586,7 @@ AdjointValue AdjointEmitter::accumulateAdjointsDirect(AdjointValue lhs, newElements.push_back(newElt); } return AdjointValue::getAggregate( - lhsVal->getType(), newElements, allocator); + remapType(lhsVal->getType()), newElements, allocator); } } // 0 @@ -3472,7 +3608,8 @@ AdjointValue AdjointEmitter::accumulateAdjointsDirect(AdjointValue lhs, rhs.getAggregateElements())) newElements.push_back( accumulateAdjointsDirect(std::get<0>(elt), std::get<1>(elt))); - return AdjointValue::getAggregate(lhs.getType(), newElements, allocator); + return AdjointValue::getAggregate( + remapType(lhs.getType()), newElements, allocator); } } } @@ -3629,29 +3766,6 @@ bool AdjointGen::performSynthesis(FunctionSynthesisItem item) { // DifferentiationTask //===----------------------------------------------------------------------===// -// Return the expected generic signature for autodiff associated functions given -// a SILDifferentiableAttr. The expected generic signature is built from the -// original generic signature and the attribute's requirements. -static CanGenericSignature -getAutoDiffAssociatedFunctionGenericSignature(SILDifferentiableAttr *attr, - SILFunction *original) { - auto originalGenSig = - original->getLoweredFunctionType()->getGenericSignature(); - if (!originalGenSig) - return nullptr; - GenericSignatureBuilder builder(original->getASTContext()); - // Add original generic signature. - builder.addGenericSignature(originalGenSig); - // Add where clause requirements. - auto source = - GenericSignatureBuilder::FloatingRequirementSource::forAbstract(); - for (auto &req : attr->getRequirements()) - builder.addRequirement(req, source, original->getModule().getSwiftModule()); - return std::move(builder) - .computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams=*/true) - ->getCanonicalSignature(); -} - SILFunction * ADContext::declareExternalAssociatedFunction( SILFunction *original, SILDifferentiableAttr *attr, @@ -3669,8 +3783,7 @@ ADContext::declareExternalAssociatedFunction( name = attr->getVJPName(); break; } - auto assocGenSig = - getAutoDiffAssociatedFunctionGenericSignature(attr, original); + auto assocGenSig = getAssociatedFunctionGenericSignature(attr, original); auto assocFnTy = originalTy->getAutoDiffAssociatedFunctionType( indices.parameters, indices.source, /*differentiationOrder*/ 1, kind, module, LookUpConformanceInModule(module.getSwiftModule()), assocGenSig); @@ -3732,23 +3845,23 @@ void DifferentiationTask::createEmptyPrimal() { auto indices = getIndices(); auto *original = getOriginal(); - auto &module = context.getModule(); auto primalName = original->getASTContext() .getIdentifier("AD__" + original->getName().str() + "__primal_" + indices.mangle()) .str(); - auto primalGenericSig = - getAutoDiffAssociatedFunctionGenericSignature(attr, original); - StructDecl *primalValueStructDecl = context.createPrimalValueStruct(this); + auto primalGenericSig = getAssociatedFunctionGenericSignature(attr, original); + auto *primalGenericEnv = primalGenericSig + ? primalGenericSig->createGenericEnvironment() + : nullptr; + StructDecl *primalValueStructDecl = + context.createPrimalValueStruct(this, primalGenericSig); primalInfo = std::unique_ptr( - new PrimalInfo(primalValueStructDecl, module)); - auto pvType = primalValueStructDecl->getDeclaredType()->getCanonicalType(); - auto objTy = SILType::getPrimitiveObjectType(pvType); - auto resultConv = objTy.isLoadable(module) ? ResultConvention::Owned - : ResultConvention::Indirect; + new PrimalInfo(primalValueStructDecl)); + auto pvType = + primalValueStructDecl->getDeclaredInterfaceType()->getCanonicalType(); auto origResults = original->getLoweredFunctionType()->getResults(); SmallVector results; - results.push_back({pvType, resultConv}); + results.push_back({pvType, ResultConvention::Owned}); results.append(origResults.begin(), origResults.end()); // Create result info for checkpoints. auto originalTy = original->getLoweredFunctionType(); @@ -3769,6 +3882,8 @@ void DifferentiationTask::createEmptyPrimal() { original->getLocation(), primalName, linkage, primalTy, original->isBare(), IsNotTransparent, original->isSerialized()); primal->setUnqualifiedOwnership(); + if (primalGenericEnv) + primal->setGenericEnvironment(primalGenericEnv); LLVM_DEBUG(getADDebugStream() << "Primal function created \n" << *primal << '\n'); } @@ -3876,8 +3991,7 @@ void DifferentiationTask::createEmptyAdjoint() { .getIdentifier("AD__" + original->getName().str() + "__adjoint_" + getIndices().mangle()) .str(); - auto adjGenericSig = - getAutoDiffAssociatedFunctionGenericSignature(attr, original); + auto adjGenericSig = getAssociatedFunctionGenericSignature(attr, original); auto *adjGenericEnv = adjGenericSig ? adjGenericSig->createGenericEnvironment() : nullptr; @@ -3917,8 +4031,7 @@ void DifferentiationTask::createJVP() { .getIdentifier("AD__" + original->getName().str() + "__jvp_" + getIndices().mangle()) .str(); - auto jvpGenericSig = - getAutoDiffAssociatedFunctionGenericSignature(attr, original); + auto jvpGenericSig = getAssociatedFunctionGenericSignature(attr, original); auto *jvpGenericEnv = jvpGenericSig ? jvpGenericSig->createGenericEnvironment() : nullptr; @@ -3976,8 +4089,7 @@ void DifferentiationTask::createVJP() { .getIdentifier("AD__" + original->getName().str() + "__vjp_" + getIndices().mangle()) .str(); - auto vjpGenericSig = - getAutoDiffAssociatedFunctionGenericSignature(attr, original); + auto vjpGenericSig = getAssociatedFunctionGenericSignature(attr, original); auto *vjpGenericEnv = vjpGenericSig ? vjpGenericSig->createGenericEnvironment() : nullptr; diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 17934b6c560c8..bde7adf63c985 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -2188,23 +2188,6 @@ static FuncDecl *resolveAutoDiffAssociatedFunction( return candidate; } -// SWIFT_ENABLE_TENSORFLOW -/// Require that the given type either not involve type parameters or be -/// a type parameter. -// TODO: Generalize function to take a `Diagnostic` and merge with -// `diagnoseIndirectGenericTypeParam`. -static bool diagnoseDifferentiableAttrIndirectGenericType(SourceLoc loc, - Type type, - TypeRepr *typeRepr) { - if (type->hasTypeParameter() && !type->is()) { - type->getASTContext() - .Diags.diagnose(loc, diag::differentiable_attr_only_generic_param_req) - .highlight(typeRepr->getSourceRange()); - return true; - } - return false; -} - void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { auto &ctx = TC.Context; auto lookupConformance = @@ -2303,11 +2286,10 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { RequirementRequest::visitRequirements( WhereClauseOwner(original, attr), TypeResolutionStage::Structural, [&](const Requirement &req, RequirementRepr *reqRepr) { - // Check additional constraints. - // TODO: refine constraints. switch (req.getKind()) { case RequirementKind::SameType: case RequirementKind::Superclass: + case RequirementKind::Conformance: break; // Layout requirements are not supported. @@ -2316,24 +2298,6 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { diag::differentiable_attr_unsupported_req_kind) .highlight(reqRepr->getSourceRange()); return false; - - // Conformance requirements are valid if: - // - The first type is a generic type parameter type. - // - The second type is a protocol type or protocol composition type. - case RequirementKind::Conformance: - if (diagnoseDifferentiableAttrIndirectGenericType( - attr->getLocation(), req.getFirstType(), - reqRepr->getSubjectRepr())) - return false; - - if (!req.getSecondType()->is() && - !req.getSecondType()->is()) { - TC.diagnose(attr->getLocation(), - diag::differentiable_attr_non_protocol_type_constraint_req) - .highlight(reqRepr->getSourceRange()); - return false; - } - break; } // Add requirement to generic signature builder. diff --git a/test/AutoDiff/autodiff_diagnostics.swift b/test/AutoDiff/autodiff_diagnostics.swift index 84878d3bd86fc..8d6bb83bdf3ab 100644 --- a/test/AutoDiff/autodiff_diagnostics.swift +++ b/test/AutoDiff/autodiff_diagnostics.swift @@ -24,7 +24,7 @@ _ = gradient(at: 0, in: one_to_one_0) // okay! // Generics //===----------------------------------------------------------------------===// -// expected-note @+3 {{differentiating generic functions is not supported yet}} +// expected-note @+3 {{differentiating functions with parameters or result of unknown size is not supported yet}} // expected-error @+2 {{function is not differentiable}} @differentiable() func generic(_ x: T) -> T { diff --git a/test/AutoDiff/differentiable_attr_type_checking.swift b/test/AutoDiff/differentiable_attr_type_checking.swift index 5d529b1424914..bd4056e6987ab 100644 --- a/test/AutoDiff/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/differentiable_attr_type_checking.swift @@ -391,7 +391,7 @@ extension FloatingPoint { } } -// expected-error @+2 {{only conformances to protocol types are supported by @differentiable attribute}} +// expected-error @+2 {{type 'Scalar' constrained to non-protocol, non-class type 'Float'}} // expected-error @+1 {{can only differentiate with respect to parameters that conform to 'Differentiable', but 'Scalar' does not conform to 'Differentiable'}} @differentiable(where Scalar : Float) func invalidRequirementConformance(x: Scalar) -> Scalar { From a976a7cb7a02b2f1be141b73b270124bb89645b9 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Mon, 21 Jan 2019 04:10:36 -0800 Subject: [PATCH 2/3] Add tests. TODO: Add more tests. --- test/AutoDiff/generics.swift | 14 ++++++++++++++ .../tensor_autodiff_runtime.swift | 7 +++++++ 2 files changed, 21 insertions(+) create mode 100644 test/AutoDiff/generics.swift diff --git a/test/AutoDiff/generics.swift b/test/AutoDiff/generics.swift new file mode 100644 index 0000000000000..f26d33de0d727 --- /dev/null +++ b/test/AutoDiff/generics.swift @@ -0,0 +1,14 @@ +// RUN: %target-swift-frontend -emit-sil -verify %s + +struct Tensor : VectorNumeric, Differentiable { + var value: Float + init(_ value: Float) { self.value = value } +} + +func generic(_ x: Tensor) -> Float { + return x.value + x.value +} +print(pullback(at: Tensor(1), in: generic)) +print(pullback(at: Tensor(3), in: generic)) + +// TODO: add more tests. diff --git a/test/TensorFlowRuntime/tensor_autodiff_runtime.swift b/test/TensorFlowRuntime/tensor_autodiff_runtime.swift index cf7c7c8bdd5a9..49e5ce921bf4f 100644 --- a/test/TensorFlowRuntime/tensor_autodiff_runtime.swift +++ b/test/TensorFlowRuntime/tensor_autodiff_runtime.swift @@ -22,6 +22,13 @@ TensorADTests.testAllBackends("TestSimpleGrad") { expectTrue(gradient(at: [[10], [20]], in: square) == [[20], [40]]) } +TensorADTests.testAllBackends("TestGenericGrad") { + func square(_ x: Tensor) -> Tensor { + return x * x + } + expectEqual([0.2, 0.4, 0.6], gradient(at: Tensor([0.1, 0.2, 0.3]), in: square)) +} + TensorADTests.testAllBackends("+") { let f = { (a: Tensor, b: Tensor) in a + b } expectTrue((Tensor(1), Tensor(1)) == gradient(at: Tensor(0), Tensor(0), in: f)) From ca9e922e4f0f54529e8fa1941868f6c6699b8282 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Mon, 21 Jan 2019 04:18:39 -0800 Subject: [PATCH 3/3] Address comments by @rxwei. --- .../Mandatory/Differentiation.cpp | 46 +++++++++---------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 53dd88b27aa7c..90c9be871b050 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -879,7 +879,7 @@ class ADContext { /// Creates a struct declaration (without contents) for storing primal values /// of a function. The newly created struct will have the same generic - /// parameters as the function. + /// signature as the given primal generic signature. StructDecl *createPrimalValueStruct(const DifferentiationTask *task, CanGenericSignature primalGenericSig); @@ -1502,19 +1502,19 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context, /// differentiation parameters/result, emit a "unknown parameter or result /// size" error at appropriate source locations. Returns true if error is /// emitted. -static bool diagnoseIndirectParamsOrResult(ADContext &context, - DifferentiationTask *task) { +static bool diagnoseIndirectParametersOrResult(ADContext &context, + DifferentiationTask *task) { auto originalFnTy = task->getOriginal()->getLoweredFunctionType(); auto indices = task->getIndices(); // Check whether differentiation result or parameters are indirect. bool originalHasIndirectParamOrResult = originalFnTy->getResults()[indices.source].isFormalIndirect(); for (unsigned i : swift::indices(originalFnTy->getParameters())) { - if (indices.isWrtParameter(i)) { - if (originalFnTy->getParameters()[i].isFormalIndirect()) { - originalHasIndirectParamOrResult = true; - break; - } + if (!indices.isWrtParameter(i)) + continue; + if (originalFnTy->getParameters()[i].isFormalIndirect()) { + originalHasIndirectParamOrResult = true; + break; } } if (originalHasIndirectParamOrResult) { @@ -2078,7 +2078,6 @@ class PrimalGenCloner final auto loc = getPrimal()->getLocation(); auto structTy = getPrimalInfo().getPrimalValueStruct()->getDeclaredInterfaceType(); - // TODO: Replace line above. if (auto primalGenericEnv = getPrimal()->getGenericEnvironment()) structTy = primalGenericEnv->mapTypeIntoContext(structTy); auto &builder = getBuilder(); @@ -2264,19 +2263,18 @@ class PrimalGenCloner final // If all requirements are satisfied, return associated function // substitution map. - if (!assocSubstMap.empty()) { - return assocSubstMap.subst( - [&](SubstitutableType *ty) -> Type { - Type type(ty); - if (!primalSubstMap.empty()) - type = type.subst(primalSubstMap); - if (type->hasArchetype() && primalGenEnv) - return type; - return type.subst(origSubstMap); - }, - LookUpConformanceInModule(context.getModule().getSwiftModule())); - } - return origSubstMap; + if (assocSubstMap.empty()) + return origSubstMap; + return assocSubstMap.subst( + [&](SubstitutableType *ty) -> Type { + Type type(ty); + if (!primalSubstMap.empty()) + type = type.subst(primalSubstMap); + if (type->hasArchetype() && primalGenEnv) + return type; + return type.subst(origSubstMap); + }, + LookUpConformanceInModule(context.getModule().getSwiftModule())); } void visitApplyInst(ApplyInst *ai) { @@ -2416,7 +2414,7 @@ bool PrimalGen::performSynthesis(FunctionSynthesisItem item) { // parameters/result, bail out since AD does not support side-effecting // instructions yet. if (diagnoseUnsupportedControlFlow(context, item.task) || - diagnoseIndirectParamsOrResult(context, item.task)) { + diagnoseIndirectParametersOrResult(context, item.task)) { errorOccurred = true; return true; } @@ -3853,7 +3851,7 @@ void DifferentiationTask::createEmptyPrimal() { auto *primalGenericEnv = primalGenericSig ? primalGenericSig->createGenericEnvironment() : nullptr; - StructDecl *primalValueStructDecl = + auto *primalValueStructDecl = context.createPrimalValueStruct(this, primalGenericSig); primalInfo = std::unique_ptr( new PrimalInfo(primalValueStructDecl));