Skip to content

Commit

Permalink
[AutoDiff] [Sema] Include certain 'let' properties in 'Differentiable…
Browse files Browse the repository at this point in the history
…' derived conformances. (#33700)

In `Differentiable` derived conformances, `let` properties are currently treated as if they had `@noDerivative` and excluded from the derived `Differentiable` conformance implementation. This is limiting to properties that have a non-mutating `move(along:)` (e.g. class properties), which can be mathematically treated as differentiable variables.

This patch changes the derived conformances behavior such that `let` properties will be included as differentiable variables if they have a non-mutating `move(along:)`. This unblocks the following code:

```swift
final class Foo: Differentiable {
   let x: ClassStuff // Class type with a non-mutating 'move(along:)'

   // Synthesized code:
   //   struct TangentVector {
   //     var x: ClassStuff.TangentVector
   //   }
   //   ...
   //   func move(along direction: TangentVector) {
   //     x.move(along: direction.x)
   //   }
}
```

Resolves SR-13474 (rdar://67982207).
  • Loading branch information
rxwei committed Aug 30, 2020
1 parent 1e09ad0 commit 76d0648
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 28 deletions.
7 changes: 4 additions & 3 deletions include/swift/AST/DiagnosticsSema.def
Expand Up @@ -2823,15 +2823,16 @@ WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none,
/*nominalCanDeriveAdditiveArithmetic*/ bool))
WARNING(differentiable_immutable_wrapper_implicit_noderivative_fixit,none,
"synthesis of the 'Differentiable.move(along:)' requirement for %1 "
"requires 'wrappedValue' in property wrapper %0 to be mutable; "
"add an explicit '@noDerivative' attribute"
"requires 'wrappedValue' in property wrapper %0 to be mutable or have a "
"non-mutating 'move(along:)'; add an explicit '@noDerivative' attribute"
"%select{|, or conform %1 to 'AdditiveArithmetic'}2",
(/*wrapperType*/ Identifier, /*nominalName*/ Identifier,
/*nominalCanDeriveAdditiveArithmetic*/ bool))
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
"synthesis of the 'Differentiable.move(along:)' requirement for %0 "
"requires all stored properties not marked with `@noDerivative` to be "
"mutable; use 'var' instead, or add an explicit '@noDerivative' attribute"
"mutable or have a non-mutating 'move(along:)'; use 'var' instead, or "
"add an explicit '@noDerivative' attribute "
"%select{|, or conform %0 to 'AdditiveArithmetic'}1",
(/*nominalName*/ Identifier, /*nominalCanDeriveAdditiveArithmetic*/ bool))

Expand Down
56 changes: 42 additions & 14 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Expand Up @@ -32,12 +32,37 @@

using namespace swift;

/// Return true if `move(along:)` can be invoked on the given `Differentiable`-
/// conforming property.
///
/// If the given property is a `var`, return true because `move(along:)` can be
/// invoked regardless. Otherwise, return true if and only if the property's
/// type's 'Differentiable.move(along:)' witness is non-mutating.
static bool canInvokeMoveAlongOnProperty(
VarDecl *vd, ProtocolConformanceRef diffableConformance) {
assert(diffableConformance && "Property must conform to 'Differentiable'");
// `var` always supports `move(along:)` since it is mutable.
if (vd->getIntroducer() == VarDecl::Introducer::Var)
return true;
// When the property is a `let`, the only case that would be supported is when
// it has a `move(along:)` protocol requirement witness that is non-mutating.
auto interfaceType = vd->getInterfaceType();
auto &C = vd->getASTContext();
auto witness = diffableConformance.getWitnessByName(
interfaceType, DeclName(C, C.Id_move, {C.Id_along}));
if (!witness)
return false;
auto *decl = cast<FuncDecl>(witness.getDecl());
return decl->isNonMutating();
}

/// Get the stored properties of a nominal type that are relevant for
/// differentiation, except the ones tagged `@noDerivative`.
static void
getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC,
SmallVectorImpl<VarDecl *> &result,
bool includeLetProperties = false) {
getStoredPropertiesForDifferentiation(
NominalTypeDecl *nominal, DeclContext *DC,
SmallVectorImpl<VarDecl *> &result,
bool includeLetPropertiesWithNonmutatingMoveAlong = false) {
auto &C = nominal->getASTContext();
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
for (auto *vd : nominal->getStoredProperties()) {
Expand All @@ -53,15 +78,18 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC,
// Skip stored properties with `@noDerivative` attribute.
if (vd->getAttrs().hasAttribute<NoDerivativeAttr>())
continue;
// Skip `let` stored properties if requested.
// `mutating func move(along:)` cannot be synthesized to update `let`
// properties.
if (!includeLetProperties && vd->isLet())
continue;
if (vd->getInterfaceType()->hasError())
continue;
auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType());
if (!TypeChecker::conformsToProtocol(varType, diffableProto, nominal))
auto conformance = TypeChecker::conformsToProtocol(
varType, diffableProto, nominal);
if (!conformance)
continue;
// Skip `let` stored properties with a mutating `move(along:)` if requested.
// `mutating func move(along:)` cannot be synthesized to update `let`
// properties.
if (!includeLetPropertiesWithNonmutatingMoveAlong &&
!canInvokeMoveAlongOnProperty(vd, conformance))
continue;
result.push_back(vd);
}
Expand Down Expand Up @@ -782,18 +810,18 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
continue;
// Check whether to diagnose stored property.
auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType());
bool conformsToDifferentiable =
!TypeChecker::conformsToProtocol(varType, diffableProto, nominal)
.isInvalid();
auto diffableConformance =
TypeChecker::conformsToProtocol(varType, diffableProto, nominal);
// If stored property should not be diagnosed, continue.
if (conformsToDifferentiable && !vd->isLet())
if (diffableConformance &&
canInvokeMoveAlongOnProperty(vd, diffableConformance))
continue;
// Otherwise, add an implicit `@noDerivative` attribute.
vd->getAttrs().add(new (Context) NoDerivativeAttr(/*Implicit*/ true));
auto loc = vd->getAttributeInsertionLoc(/*forModifier*/ false);
assert(loc.isValid() && "Expected valid source location");
// Diagnose properties that do not conform to `Differentiable`.
if (!conformsToDifferentiable) {
if (!diffableConformance) {
Context.Diags
.diagnose(
loc,
Expand Down
48 changes: 42 additions & 6 deletions test/AutoDiff/Sema/DerivedConformances/class_differentiable.swift
Expand Up @@ -29,34 +29,70 @@ func testEmpty() {
assertConformsToAdditiveArithmetic(Empty.TangentVector.self)
}

protocol DifferentiableWithNonmutatingMoveAlong: Differentiable {}
extension DifferentiableWithNonmutatingMoveAlong {
func move(along _: TangentVector) {}
}

class EmptyWithInheritedNonmutatingMoveAlong: DifferentiableWithNonmutatingMoveAlong {
typealias TangentVector = Empty.TangentVector
var zeroTangentVectorInitializer: () -> TangentVector { { .init() } }
static func proof_that_i_have_nonmutating_move_along() {
let empty = EmptyWithInheritedNonmutatingMoveAlong()
empty.move(along: .init())
}
}

class EmptyWrapper<T: Differentiable & AnyObject>: Differentiable {}
func testEmptyWrapper() {
assertConformsToAdditiveArithmetic(Empty.TangentVector.self)
assertConformsToAdditiveArithmetic(EmptyWrapper<Empty>.TangentVector.self)
}

// Test structs with `let` stored properties.
// Derived conformances fail because `mutating func move` requires all stored
// properties to be mutable.
class ImmutableStoredProperties: Differentiable {
class ImmutableStoredProperties<T: Differentiable & AnyObject>: Differentiable {
var okay: Float

// expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let nondiff: Int

// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let diff: Float

init() {
let letClass: Empty // No error on class-bound differentiable `let` with a non-mutating 'move(along:)'.

let letClassWithInheritedNonmutatingMoveAlong: EmptyWithInheritedNonmutatingMoveAlong

// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let letClassGeneric: T // Error due to lack of non-mutating 'move(along:)'.

let letClassWrappingGeneric: EmptyWrapper<T> // No error on class-bound differentiable `let` with a non-mutating 'move(along:)'.

init(letClassGeneric: T) {
okay = 0
nondiff = 0
diff = 0
letClass = Empty()
self.letClassGeneric = letClassGeneric
self.letClassWrappingGeneric = EmptyWrapper<T>()
}
}
func testImmutableStoredProperties() {
_ = ImmutableStoredProperties.TangentVector(okay: 1)
_ = ImmutableStoredProperties<Empty>.TangentVector(
okay: 1,
letClass: Empty.TangentVector(),
letClassWithInheritedNonmutatingMoveAlong: Empty.TangentVector(),
letClassWrappingGeneric: EmptyWrapper<Empty>.TangentVector())
}
class MutableStoredPropertiesWithInitialValue: Differentiable {
var x = Float(1)
var y = Double(1)
}
// Test class with both an empty constructor and memberwise initializer.
class AllMixedStoredPropertiesHaveInitialValue: Differentiable {
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let x = Float(1)
var y = Float(1)
// Memberwise initializer should be `init(y:)` since `x` is immutable.
Expand Down Expand Up @@ -550,7 +586,7 @@ struct Generic<T> {}
extension Generic: Differentiable where T: Differentiable {}

class WrappedProperties: Differentiable {
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable; add an explicit '@noDerivative' attribute}}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable or have a non-mutating 'move(along:)'; add an explicit '@noDerivative' attribute}}
@ImmutableWrapper var immutableInt: Generic<Int> = Generic()

// expected-warning @+1 {{stored property 'mutableInt' has no derivative because 'Generic<Int>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
Expand Down
49 changes: 44 additions & 5 deletions test/AutoDiff/Sema/DerivedConformances/struct_differentiable.swift
Expand Up @@ -11,6 +11,35 @@ func testEmpty() {
assertConformsToAdditiveArithmetic(Empty.TangentVector.self)
}

struct EmptyWithConcreteNonmutatingMoveAlong: Differentiable {
typealias TangentVector = Empty.TangentVector
var zeroTangentVectorInitializer: () -> TangentVector { { .init() } }
func move(along _: TangentVector) {}
static func proof_that_i_have_nonmutating_move_along() {
let empty = Self()
empty.move(along: .init())
}
}

protocol DifferentiableWithNonmutatingMoveAlong: Differentiable {}
extension DifferentiableWithNonmutatingMoveAlong {
func move(along _: TangentVector) {}
}

struct EmptyWithInheritedNonmutatingMoveAlong: DifferentiableWithNonmutatingMoveAlong {
typealias TangentVector = Empty.TangentVector
var zeroTangentVectorInitializer: () -> TangentVector { { .init() } }
static func proof_that_i_have_nonmutating_move_along() {
let empty = Self()
empty.move(along: .init())
}
}

class EmptyClass: Differentiable {}
func testEmptyClass() {
assertConformsToAdditiveArithmetic(EmptyClass.TangentVector.self)
}

// Test interaction with `AdditiveArithmetic` derived conformances.
// Previously, this crashed due to duplicate memberwise initializer synthesis.
struct EmptyAdditiveArithmetic: AdditiveArithmetic, Differentiable {}
Expand All @@ -21,22 +50,32 @@ struct EmptyAdditiveArithmetic: AdditiveArithmetic, Differentiable {}
struct ImmutableStoredProperties: Differentiable {
var okay: Float

// expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic'}} {{3-3=@noDerivative }}
// expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let nondiff: Int

// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic}} {{3-3=@noDerivative }}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let diff: Float

let nonmutatingMoveAlongStruct: EmptyWithConcreteNonmutatingMoveAlong

let inheritedNonmutatingMoveAlongStruct: EmptyWithInheritedNonmutatingMoveAlong

let diffClass: EmptyClass // No error on class-bound `let` with a non-mutating `move(along:)`.
}
func testImmutableStoredProperties() {
_ = ImmutableStoredProperties.TangentVector(okay: 1)
_ = ImmutableStoredProperties.TangentVector(
okay: 1,
nonmutatingMoveAlongStruct: Empty.TangentVector(),
inheritedNonmutatingMoveAlongStruct: Empty.TangentVector(),
diffClass: EmptyClass.TangentVector())
}
struct MutableStoredPropertiesWithInitialValue: Differentiable {
var x = Float(1)
var y = Double(1)
}
// Test struct with both an empty constructor and memberwise initializer.
struct AllMixedStoredPropertiesHaveInitialValue: Differentiable {
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
let x = Float(1)
var y = Float(1)
// Memberwise initializer should be `init(y:)` since `x` is immutable.
Expand Down Expand Up @@ -363,7 +402,7 @@ struct Generic<T> {}
extension Generic: Differentiable where T: Differentiable {}

struct WrappedProperties: Differentiable {
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable; add an explicit '@noDerivative' attribute}}
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable or have a non-mutating 'move(along:)'; add an explicit '@noDerivative' attribute}}
@ImmutableWrapper var immutableInt: Generic<Int>

// expected-warning @+1 {{stored property 'mutableInt' has no derivative because 'Generic<Int>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}
Expand Down
15 changes: 15 additions & 0 deletions test/AutoDiff/validation-test/class_differentiation.swift
Expand Up @@ -524,4 +524,19 @@ ClassTests.test("ClassProperties") {
gradient(at: Super(base: 2)) { foo in foo.squared })
}

ClassTests.test("LetProperties") {
final class Foo: Differentiable {
var x: Tracked<Float>
init(x: Tracked<Float>) { self.x = x }
}
final class Bar: Differentiable {
let x = Foo(x: 2)
}
let bar = Bar()
let grad = gradient(at: bar) { bar in (bar.x.x * bar.x.x).value }
expectEqual(Bar.TangentVector(x: .init(x: 6.0)), grad)
bar.move(along: grad)
expectEqual(8.0, bar.x.x)
}

runAllTests()

0 comments on commit 76d0648

Please sign in to comment.