diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 9374d59db3e3f..9d36c35bc7c80 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2823,15 +2823,16 @@ WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none, /*nominalCanDeriveAdditiveArithmetic*/ bool)) WARNING(differentiable_immutable_wrapper_implicit_noderivative_fixit,none, "synthesis of the 'Differentiable.move(along:)' requirement for %1 " - "requires 'wrappedValue' in property wrapper %0 to be mutable; " - "add an explicit '@noDerivative' attribute" + "requires 'wrappedValue' in property wrapper %0 to be mutable or have a " + "non-mutating 'move(along:)'; add an explicit '@noDerivative' attribute" "%select{|, or conform %1 to 'AdditiveArithmetic'}2", (/*wrapperType*/ Identifier, /*nominalName*/ Identifier, /*nominalCanDeriveAdditiveArithmetic*/ bool)) WARNING(differentiable_let_property_implicit_noderivative_fixit,none, "synthesis of the 'Differentiable.move(along:)' requirement for %0 " "requires all stored properties not marked with `@noDerivative` to be " - "mutable; use 'var' instead, or add an explicit '@noDerivative' attribute" + "mutable or have a non-mutating 'move(along:)'; use 'var' instead, or " + "add an explicit '@noDerivative' attribute " "%select{|, or conform %0 to 'AdditiveArithmetic'}1", (/*nominalName*/ Identifier, /*nominalCanDeriveAdditiveArithmetic*/ bool)) diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index 451dc2c65efad..5ab7ffc81c1ef 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -32,12 +32,37 @@ using namespace swift; +/// Return true if `move(along:)` can be invoked on the given `Differentiable`- +/// conforming property. +/// +/// If the given property is a `var`, return true because `move(along:)` can be +/// invoked regardless. Otherwise, return true if and only if the property's +/// type's 'Differentiable.move(along:)' witness is non-mutating. +static bool canInvokeMoveAlongOnProperty( + VarDecl *vd, ProtocolConformanceRef diffableConformance) { + assert(diffableConformance && "Property must conform to 'Differentiable'"); + // `var` always supports `move(along:)` since it is mutable. + if (vd->getIntroducer() == VarDecl::Introducer::Var) + return true; + // When the property is a `let`, the only case that would be supported is when + // it has a `move(along:)` protocol requirement witness that is non-mutating. + auto interfaceType = vd->getInterfaceType(); + auto &C = vd->getASTContext(); + auto witness = diffableConformance.getWitnessByName( + interfaceType, DeclName(C, C.Id_move, {C.Id_along})); + if (!witness) + return false; + auto *decl = cast(witness.getDecl()); + return decl->isNonMutating(); +} + /// Get the stored properties of a nominal type that are relevant for /// differentiation, except the ones tagged `@noDerivative`. static void -getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC, - SmallVectorImpl &result, - bool includeLetProperties = false) { +getStoredPropertiesForDifferentiation( + NominalTypeDecl *nominal, DeclContext *DC, + SmallVectorImpl &result, + bool includeLetPropertiesWithNonmutatingMoveAlong = false) { auto &C = nominal->getASTContext(); auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); for (auto *vd : nominal->getStoredProperties()) { @@ -53,15 +78,18 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC, // Skip stored properties with `@noDerivative` attribute. if (vd->getAttrs().hasAttribute()) continue; - // Skip `let` stored properties if requested. - // `mutating func move(along:)` cannot be synthesized to update `let` - // properties. - if (!includeLetProperties && vd->isLet()) - continue; if (vd->getInterfaceType()->hasError()) continue; auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType()); - if (!TypeChecker::conformsToProtocol(varType, diffableProto, nominal)) + auto conformance = TypeChecker::conformsToProtocol( + varType, diffableProto, nominal); + if (!conformance) + continue; + // Skip `let` stored properties with a mutating `move(along:)` if requested. + // `mutating func move(along:)` cannot be synthesized to update `let` + // properties. + if (!includeLetPropertiesWithNonmutatingMoveAlong && + !canInvokeMoveAlongOnProperty(vd, conformance)) continue; result.push_back(vd); } @@ -782,18 +810,18 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context, continue; // Check whether to diagnose stored property. auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType()); - bool conformsToDifferentiable = - !TypeChecker::conformsToProtocol(varType, diffableProto, nominal) - .isInvalid(); + auto diffableConformance = + TypeChecker::conformsToProtocol(varType, diffableProto, nominal); // If stored property should not be diagnosed, continue. - if (conformsToDifferentiable && !vd->isLet()) + if (diffableConformance && + canInvokeMoveAlongOnProperty(vd, diffableConformance)) continue; // Otherwise, add an implicit `@noDerivative` attribute. vd->getAttrs().add(new (Context) NoDerivativeAttr(/*Implicit*/ true)); auto loc = vd->getAttributeInsertionLoc(/*forModifier*/ false); assert(loc.isValid() && "Expected valid source location"); // Diagnose properties that do not conform to `Differentiable`. - if (!conformsToDifferentiable) { + if (!diffableConformance) { Context.Diags .diagnose( loc, diff --git a/test/AutoDiff/Sema/DerivedConformances/class_differentiable.swift b/test/AutoDiff/Sema/DerivedConformances/class_differentiable.swift index f4b50fb19fa21..228c782375eba 100644 --- a/test/AutoDiff/Sema/DerivedConformances/class_differentiable.swift +++ b/test/AutoDiff/Sema/DerivedConformances/class_differentiable.swift @@ -29,26 +29,62 @@ func testEmpty() { assertConformsToAdditiveArithmetic(Empty.TangentVector.self) } +protocol DifferentiableWithNonmutatingMoveAlong: Differentiable {} +extension DifferentiableWithNonmutatingMoveAlong { + func move(along _: TangentVector) {} +} + +class EmptyWithInheritedNonmutatingMoveAlong: DifferentiableWithNonmutatingMoveAlong { + typealias TangentVector = Empty.TangentVector + var zeroTangentVectorInitializer: () -> TangentVector { { .init() } } + static func proof_that_i_have_nonmutating_move_along() { + let empty = EmptyWithInheritedNonmutatingMoveAlong() + empty.move(along: .init()) + } +} + +class EmptyWrapper: Differentiable {} +func testEmptyWrapper() { + assertConformsToAdditiveArithmetic(Empty.TangentVector.self) + assertConformsToAdditiveArithmetic(EmptyWrapper.TangentVector.self) +} + // Test structs with `let` stored properties. // Derived conformances fail because `mutating func move` requires all stored // properties to be mutable. -class ImmutableStoredProperties: Differentiable { +class ImmutableStoredProperties: Differentiable { var okay: Float // expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} let nondiff: Int - // expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} + // expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} let diff: Float - init() { + let letClass: Empty // No error on class-bound differentiable `let` with a non-mutating 'move(along:)'. + + let letClassWithInheritedNonmutatingMoveAlong: EmptyWithInheritedNonmutatingMoveAlong + + // expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} + let letClassGeneric: T // Error due to lack of non-mutating 'move(along:)'. + + let letClassWrappingGeneric: EmptyWrapper // No error on class-bound differentiable `let` with a non-mutating 'move(along:)'. + + init(letClassGeneric: T) { okay = 0 nondiff = 0 diff = 0 + letClass = Empty() + self.letClassGeneric = letClassGeneric + self.letClassWrappingGeneric = EmptyWrapper() } } func testImmutableStoredProperties() { - _ = ImmutableStoredProperties.TangentVector(okay: 1) + _ = ImmutableStoredProperties.TangentVector( + okay: 1, + letClass: Empty.TangentVector(), + letClassWithInheritedNonmutatingMoveAlong: Empty.TangentVector(), + letClassWrappingGeneric: EmptyWrapper.TangentVector()) } class MutableStoredPropertiesWithInitialValue: Differentiable { var x = Float(1) @@ -56,7 +92,7 @@ class MutableStoredPropertiesWithInitialValue: Differentiable { } // Test class with both an empty constructor and memberwise initializer. class AllMixedStoredPropertiesHaveInitialValue: Differentiable { - // expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} + // expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} let x = Float(1) var y = Float(1) // Memberwise initializer should be `init(y:)` since `x` is immutable. @@ -550,7 +586,7 @@ struct Generic {} extension Generic: Differentiable where T: Differentiable {} class WrappedProperties: Differentiable { - // expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable; add an explicit '@noDerivative' attribute}} + // expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable or have a non-mutating 'move(along:)'; add an explicit '@noDerivative' attribute}} @ImmutableWrapper var immutableInt: Generic = Generic() // expected-warning @+1 {{stored property 'mutableInt' has no derivative because 'Generic' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} diff --git a/test/AutoDiff/Sema/DerivedConformances/struct_differentiable.swift b/test/AutoDiff/Sema/DerivedConformances/struct_differentiable.swift index e75b2730232f5..d8ec6e56c588f 100644 --- a/test/AutoDiff/Sema/DerivedConformances/struct_differentiable.swift +++ b/test/AutoDiff/Sema/DerivedConformances/struct_differentiable.swift @@ -11,6 +11,35 @@ func testEmpty() { assertConformsToAdditiveArithmetic(Empty.TangentVector.self) } +struct EmptyWithConcreteNonmutatingMoveAlong: Differentiable { + typealias TangentVector = Empty.TangentVector + var zeroTangentVectorInitializer: () -> TangentVector { { .init() } } + func move(along _: TangentVector) {} + static func proof_that_i_have_nonmutating_move_along() { + let empty = Self() + empty.move(along: .init()) + } +} + +protocol DifferentiableWithNonmutatingMoveAlong: Differentiable {} +extension DifferentiableWithNonmutatingMoveAlong { + func move(along _: TangentVector) {} +} + +struct EmptyWithInheritedNonmutatingMoveAlong: DifferentiableWithNonmutatingMoveAlong { + typealias TangentVector = Empty.TangentVector + var zeroTangentVectorInitializer: () -> TangentVector { { .init() } } + static func proof_that_i_have_nonmutating_move_along() { + let empty = Self() + empty.move(along: .init()) + } +} + +class EmptyClass: Differentiable {} +func testEmptyClass() { + assertConformsToAdditiveArithmetic(EmptyClass.TangentVector.self) +} + // Test interaction with `AdditiveArithmetic` derived conformances. // Previously, this crashed due to duplicate memberwise initializer synthesis. struct EmptyAdditiveArithmetic: AdditiveArithmetic, Differentiable {} @@ -21,14 +50,24 @@ struct EmptyAdditiveArithmetic: AdditiveArithmetic, Differentiable {} struct ImmutableStoredProperties: Differentiable { var okay: Float - // expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic'}} {{3-3=@noDerivative }} + // expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} let nondiff: Int - // expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic}} {{3-3=@noDerivative }} + // expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} let diff: Float + + let nonmutatingMoveAlongStruct: EmptyWithConcreteNonmutatingMoveAlong + + let inheritedNonmutatingMoveAlongStruct: EmptyWithInheritedNonmutatingMoveAlong + + let diffClass: EmptyClass // No error on class-bound `let` with a non-mutating `move(along:)`. } func testImmutableStoredProperties() { - _ = ImmutableStoredProperties.TangentVector(okay: 1) + _ = ImmutableStoredProperties.TangentVector( + okay: 1, + nonmutatingMoveAlongStruct: Empty.TangentVector(), + inheritedNonmutatingMoveAlongStruct: Empty.TangentVector(), + diffClass: EmptyClass.TangentVector()) } struct MutableStoredPropertiesWithInitialValue: Differentiable { var x = Float(1) @@ -36,7 +75,7 @@ struct MutableStoredPropertiesWithInitialValue: Differentiable { } // Test struct with both an empty constructor and memberwise initializer. struct AllMixedStoredPropertiesHaveInitialValue: Differentiable { - // expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} + // expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} let x = Float(1) var y = Float(1) // Memberwise initializer should be `init(y:)` since `x` is immutable. @@ -363,7 +402,7 @@ struct Generic {} extension Generic: Differentiable where T: Differentiable {} struct WrappedProperties: Differentiable { - // expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable; add an explicit '@noDerivative' attribute}} + // expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable or have a non-mutating 'move(along:)'; add an explicit '@noDerivative' attribute}} @ImmutableWrapper var immutableInt: Generic // expected-warning @+1 {{stored property 'mutableInt' has no derivative because 'Generic' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} diff --git a/test/AutoDiff/validation-test/class_differentiation.swift b/test/AutoDiff/validation-test/class_differentiation.swift index 2ae7619486214..ee76c21ba1d87 100644 --- a/test/AutoDiff/validation-test/class_differentiation.swift +++ b/test/AutoDiff/validation-test/class_differentiation.swift @@ -524,4 +524,19 @@ ClassTests.test("ClassProperties") { gradient(at: Super(base: 2)) { foo in foo.squared }) } +ClassTests.test("LetProperties") { + final class Foo: Differentiable { + var x: Tracked + init(x: Tracked) { self.x = x } + } + final class Bar: Differentiable { + let x = Foo(x: 2) + } + let bar = Bar() + let grad = gradient(at: bar) { bar in (bar.x.x * bar.x.x).value } + expectEqual(Bar.TangentVector(x: .init(x: 6.0)), grad) + bar.move(along: grad) + expectEqual(8.0, bar.x.x) +} + runAllTests()