Skip to content

Commit

Permalink
[Sema] Fix Differentiable derived conformances. (#22119)
Browse files Browse the repository at this point in the history
* [Sema] Fix `Differentiable` derived conformances.

Associated structs can inherit from `AdditiveArithmetic` or `VectorNumeric`
if their member's associated type inherit from the protocol, not the member's
type itself.

Revert unuseful refactoring from #22114.

* Update `tensorflow-swift-apis` version.

Use updated+fixed version.
  • Loading branch information
dan-zheng committed Jan 25, 2019
1 parent a0402b9 commit d202231
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 66 deletions.
69 changes: 21 additions & 48 deletions lib/Sema/DerivedConformanceAdditiveArithmeticVectorNumeric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,17 @@ static Type getVectorNumericScalarAssocType(VarDecl *decl, DeclContext *DC) {

// Return the `Scalar` associated type for a nominal type with the given
// members, or nullptr if `Scalar` cannot be derived.
static Type deriveVectorNumeric_Scalar(ArrayRef<VarDecl *> members,
static Type deriveVectorNumeric_Scalar(NominalTypeDecl *nominal,
DeclContext *DC) {
auto &C = DC->getASTContext();
// Nominal type must be a struct. (Zero stored properties is okay.)
if (!isa<StructDecl>(nominal))
return nullptr;
// If all stored properties conform to `VectorNumeric` and have the same
// `Scalar` associated type, return that `Scalar` associated type.
// Otherwise, the `Scalar` type cannot be derived.
Type sameScalarType;
for (auto member : members) {
for (auto member : nominal->getStoredProperties()) {
if (!member->hasInterfaceType())
C.getLazyResolver()->resolveDeclSignature(member);
if (!member->hasInterfaceType())
Expand All @@ -126,55 +132,13 @@ static Type deriveVectorNumeric_Scalar(ArrayRef<VarDecl *> members,
return sameScalarType;
}

// Return the `Scalar` associated type for a nominal type with the given
// members, or nullptr if `Scalar` cannot be derived.
static Type deriveVectorNumeric_Scalar(NominalTypeDecl *nominal,
DeclContext *DC) {
// Nominal type must be a struct. (Zero stored properties is okay.)
if (!isa<StructDecl>(nominal))
return nullptr;
// If all stored properties conform to `VectorNumeric` and have the same
// `Scalar` associated type, return that `Scalar` associated type.
// Otherwise, the `Scalar` type cannot be derived.
SmallVector<VarDecl *, 4> storedProps;
storedProps.append(nominal->getStoredProperties().begin(),
nominal->getStoredProperties().end());
return deriveVectorNumeric_Scalar(storedProps, DC);
}

// Return true if a `VectorNumeric` requirement can be derived for the given
// members of a nominal type.
bool DerivedConformance::canDeriveVectorNumeric(ArrayRef<VarDecl *> members,
DeclContext *DC) {
return (bool)deriveVectorNumeric_Scalar(members, DC);
}

// Return true if given nominal type has a `let` stored with an initial value.
static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) {
return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) {
return v->isLet() && v->hasInitialValue();
});
}

// Return true if an `AdditiveArithmetic` requirement can be derived for the
// given members of a nominal type.
bool DerivedConformance::canDeriveAdditiveArithmetic(
ArrayRef<VarDecl *> members, DeclContext *DC) {
auto &C = DC->getASTContext();
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
return llvm::all_of(members, [&](VarDecl *v) {
if (!v->hasInterfaceType() || !v->getType())
C.getLazyResolver()->resolveDeclSignature(v);
if (!v->hasInterfaceType() || !v->getType())
return false;
auto declType = v->getType()->hasArchetype()
? v->getType()
: DC->mapTypeIntoContext(v->getType());
return (bool)TypeChecker::conformsToProtocol(declType, addArithProto, DC,
ConformanceCheckFlags::Used);
});
}

bool DerivedConformance::canDeriveAdditiveArithmetic(NominalTypeDecl *nominal,
DeclContext *DC) {
// Nominal type must be a struct. (Zero stored properties is okay.)
Expand All @@ -188,10 +152,19 @@ bool DerivedConformance::canDeriveAdditiveArithmetic(NominalTypeDecl *nominal,
if (hasLetStoredPropertyWithInitialValue(nominal))
return false;
// All stored properties must conform to `AdditiveArithmetic`.
SmallVector<VarDecl *, 4> storedProps;
storedProps.append(nominal->getStoredProperties().begin(),
nominal->getStoredProperties().end());
return canDeriveAdditiveArithmetic(storedProps, DC);
auto &C = nominal->getASTContext();
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) {
if (!v->hasInterfaceType() || !v->getType())
C.getLazyResolver()->resolveDeclSignature(v);
if (!v->hasInterfaceType() || !v->getType())
return false;
auto declType = v->getType()->hasArchetype()
? v->getType()
: DC->mapTypeIntoContext(v->getType());
return (bool)TypeChecker::conformsToProtocol(declType, addArithProto, DC,
ConformanceCheckFlags::Used);
});
}

bool DerivedConformance::canDeriveVectorNumeric(NominalTypeDecl *nominal,
Expand Down
34 changes: 31 additions & 3 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,13 +649,41 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
SmallVector<VarDecl *, 8> diffProperties;
getStoredPropertiesForDifferentiation(nominal, diffProperties);

// Associated struct can derive `AdditiveArithmetic` if the associated types
// of all members conform to `AdditiveArithmetic`.
bool canDeriveAdditiveArithmetic =
llvm::all_of(diffProperties, [&](VarDecl *var) {
return TC.conformsToProtocol(getAssociatedType(var, nominal, id),
addArithProto, nominal,
ConformanceCheckFlags::Used);
});

// Associated struct can derive `VectorNumeric` if the associated types of all
// members conform to `VectorNumeric` and share the same scalar type.
Type sameScalarType;
bool canDeriveVectorNumeric =
!diffProperties.empty() &&
llvm::all_of(diffProperties, [&](VarDecl *var) {
auto conf = TC.conformsToProtocol(getAssociatedType(var, nominal, id),
vecNumProto, nominal,
ConformanceCheckFlags::Used);
if (!conf)
return false;
Type scalarType = ProtocolConformanceRef::getTypeWitnessByName(
var->getType(), *conf, C.Id_Scalar, C.getLazyResolver());
if (!sameScalarType) {
sameScalarType = scalarType;
return true;
}
return scalarType->isEqual(sameScalarType);
});

// If the associated struct is `AllDifferentiableVariables`, conform it to:
// - `AdditiveArithmetic`, if all members of the parent conform to
// `AdditiveArithmetic`.
// - `KeyPathIterable`, if the parent conforms to to `KeyPathIterable`.
if (id == C.Id_AllDifferentiableVariables) {
if (DerivedConformance::canDeriveAdditiveArithmetic(diffProperties,
nominal))
if (canDeriveAdditiveArithmetic)
inherited.push_back(addArithType);
if (TC.conformsToProtocol(nominal->getDeclaredInterfaceType(),
kpIterableProto, parentDC,
Expand All @@ -665,7 +693,7 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
// If all members also conform to `VectorNumeric` with the same `Scalar` type,
// make the associated struct conform to `VectorNumeric` instead of just
// `AdditiveArithmetic`.
if (DerivedConformance::canDeriveVectorNumeric(diffProperties, nominal))
if (canDeriveVectorNumeric)
inherited.push_back(vecNumType);

auto *structDecl = new (C) StructDecl(SourceLoc(), id, SourceLoc(),
Expand Down
14 changes: 0 additions & 14 deletions lib/Sema/DerivedConformances.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,6 @@ class DerivedConformance {
Type deriveParameterGroup(AssociatedTypeDecl *assocType);

// SWIFT_ENABLE_TENSORFLOW
/// Determine if an AdditiveArithmetic requirement can be derived for the
/// given members of a nominal type.
///
/// \returns True if the requirement can be derived.
static bool canDeriveAdditiveArithmetic(ArrayRef<VarDecl *> members,
DeclContext *DC);

/// Determine if an AdditiveArithmetic requirement can be derived for a type.
///
/// \returns True if the requirement can be derived.
Expand All @@ -226,13 +219,6 @@ class DerivedConformance {
/// \returns the derived member, which will also be added to the type.
ValueDecl *deriveAdditiveArithmetic(ValueDecl *requirement);

/// Determine if a VectorNumeric requirement can be derived for the given
/// members of a nominal type.
///
/// \returns True if the requirement can be derived.
static bool canDeriveVectorNumeric(ArrayRef<VarDecl *> members,
DeclContext *DC);

/// Determine if a VectorNumeric requirement can be derived for a type.
///
/// \returns True if the requirement can be derived.
Expand Down
2 changes: 2 additions & 0 deletions test/Sema/struct_differentiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ struct DifferentiableSubset : Differentiable {
@noDerivative var flag: Bool
@noDerivative let technicallyDifferentiable: Float = .pi
}
assertConformsToAdditiveArithmetic(DifferentiableSubset.AllDifferentiableVariables.self)
assertConformsToVectorNumeric(DifferentiableSubset.AllDifferentiableVariables.self)
assertAllDifferentiableVariablesEqualsCotangentVector(DifferentiableSubset.self)
let tangentSubset = DifferentiableSubset.TangentVector(w: 1, b: 1)
let cotangentSubset = DifferentiableSubset.CotangentVector(w: 1, b: 1)
Expand Down
2 changes: 1 addition & 1 deletion utils/update_checkout/update-checkout-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@
"icu": "release-61-1",
"tensorflow": "a6924e6affd935f537cdaf8977094df0e15a7957",
"tensorflow-swift-bindings": "10e591340134c37a6c3a1df735a7334a77d5cbc7",
"tensorflow-swift-apis": "0942cc25f50aaa4e43f6193d306af4626dc0ece7"
"tensorflow-swift-apis": "18cc613db5ff03a510e3ba835a8932531b4a1a84"
}
}
}
Expand Down

0 comments on commit d202231

Please sign in to comment.