diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index 5f5ddf1f223f9..4bcf9da169a14 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -1508,6 +1508,19 @@ class DifferentiableAttr final static bool classof(const DeclAttribute *DA) { return DA->getKind() == DAK_Differentiable; } + + bool parametersMatch(const DifferentiableAttr &other) const { + auto a = getParsedParameters(); + auto b = other.getParsedParameters(); + if (a.size() != b.size()) + return false; + + for (unsigned i = 0, n = b.size(); i < n; ++i) { + if (!a[i].isEqual(b[i])) + return false; + } + return true; + } }; /// \brief Attributes that may be applied to declarations. diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 938e2cab67ca4..9f126b458eb3b 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2449,6 +2449,8 @@ ERROR(broken_differentiable_requirement,none, WARNING(differentiable_implicit_noderivative_fixit,none, "stored property has no derivative because it does not conform to " "'Differentiable'; add '@noDerivative' to make it explicit", ()) +NOTE(protocol_witness_missing_differentiable_attr,none, + "candidate is missing attribute '%0'", (StringRef)) NOTE(codable_extraneous_codingkey_case_here,none, "CodingKey case %0 does not match any stored properties", (Identifier)) diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index e348306d8978a..10ee1f3cd90d3 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -481,6 +481,22 @@ swift::matchWitness( cast(witness)->hasThrows()) return RequirementMatch(witness, MatchKind::RethrowsConflict); + // SWIFT_ENABLE_TENSORFLOW + // Differentiation attributes must match completely or the generated + // functions will have the wrong signature. + { + auto *reqDifferentiationAttr = + reqAttrs.getAttribute(/*AllowInvalid*/ true); + auto *witnessDifferentiationAttr = + witnessAttrs.getAttribute( + /*AllowInvalid*/ true); + if (reqDifferentiationAttr && + (!witnessDifferentiationAttr || + !witnessDifferentiationAttr->parametersMatch( + *reqDifferentiationAttr))) + return RequirementMatch(witness, MatchKind::DifferentiableConflict); + } + // We want to decompose the parameters to handle them separately. decomposeFunctionType = true; } else if (auto *witnessASD = dyn_cast(witness)) { @@ -2212,6 +2228,17 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance, case MatchKind::NonObjC: diags.diagnose(match.Witness, diag::protocol_witness_not_objc); break; + // SWIFT_ENABLE_TENSORFLOW + case MatchKind::DifferentiableConflict: + std::string diffAttrReq; + { + llvm::raw_string_ostream stream(diffAttrReq); + req->getAttrs().getAttribute()->print(stream, req); + } + diags.diagnose(match.Witness, + diag::protocol_witness_missing_differentiable_attr, + diffAttrReq); + break; } } diff --git a/lib/Sema/TypeCheckProtocol.h b/lib/Sema/TypeCheckProtocol.h index 15dbb19e3f616..6ea894a6507cf 100644 --- a/lib/Sema/TypeCheckProtocol.h +++ b/lib/Sema/TypeCheckProtocol.h @@ -205,6 +205,10 @@ enum class MatchKind : uint8_t { /// The witness is explicitly @nonobjc but the requirement is @objc. NonObjC, + + // SWIFT_ENABLE_TENSORFLOW + /// The @differentiable attribute does not match. + DifferentiableConflict, }; /// Describes the kind of optional adjustment performed when @@ -418,6 +422,8 @@ struct RequirementMatch { case MatchKind::RethrowsConflict: case MatchKind::ThrowsConflict: case MatchKind::NonObjC: + // SWIFT_ENABLE_TENSORFLOW + case MatchKind::DifferentiableConflict: return false; } @@ -446,6 +452,8 @@ struct RequirementMatch { case MatchKind::RethrowsConflict: case MatchKind::ThrowsConflict: case MatchKind::NonObjC: + // SWIFT_ENABLE_TENSORFLOW + case MatchKind::DifferentiableConflict: return false; } diff --git a/test/AutoDiff/differentiable_attr_type_checking.swift b/test/AutoDiff/differentiable_attr_type_checking.swift index bd4056e6987ab..ca7d0c497e276 100644 --- a/test/AutoDiff/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/differentiable_attr_type_checking.swift @@ -404,3 +404,28 @@ func invalidRequirementConformance(x: Scalar) -> Scalar { func invalidRequirementLayout(x: Scalar) -> Scalar { return x } + + +protocol DiffReq : Differentiable { + // expected-note @+2 {{protocol requires function 'f1'}} + @differentiable(wrt: (self, x)) + func f1(_ x: Float) -> Float + + // expected-note @+2 {{protocol requires function 'f2'}} + @differentiable(wrt: (self, x, y)) + func f2(_ x: Float, _ y: Float) -> Float +} + +// expected-error @+1 {{does not conform to protocol}} +struct ConformingWithErrors : DiffReq { + // expected-note @+1 {{@differentiable(wrt: (x, self))}} + func f1(_ x: Float) -> Float { + return x + } + + // expected-note @+2 {{@differentiable(wrt: (x, y, self))}} + @differentiable(wrt: (self, x)) + func f2(_ x: Float, _ y: Float) -> Float { + return x + y + } +} diff --git a/test/AutoDiff/protocol_requirement_autodiff.swift b/test/AutoDiff/protocol_requirement_autodiff.swift index 9d87ecf81095f..8adeed8948639 100644 --- a/test/AutoDiff/protocol_requirement_autodiff.swift +++ b/test/AutoDiff/protocol_requirement_autodiff.swift @@ -51,6 +51,7 @@ struct Quadratic : DiffReq, Equatable { self.c = c } + @differentiable(wrt: (self, x)) func f(_ x: Float) -> Float { return a * x * x + b * x + c } diff --git a/test/AutoDiff/witness_table_silgen.swift b/test/AutoDiff/witness_table_silgen.swift index 5f32228de8751..6a2a3d4e7c7f9 100644 --- a/test/AutoDiff/witness_table_silgen.swift +++ b/test/AutoDiff/witness_table_silgen.swift @@ -28,6 +28,7 @@ struct S : Proto, VectorNumeric { return (p, { dp in S(p: dp) }) } + @differentiable() func function1(_ x: Float, _ y: Float) -> Float { return x + y + p } @@ -48,6 +49,7 @@ struct S : Proto, VectorNumeric { // CHECK: apply [[VJP1]] // CHECK: } // end sil function 'AD__{{.*}}function1{{.*}}_vjp_SSU' + @differentiable(wrt: (self, x, y)) func function2(_ x: Float, _ y: Float) -> Float { return x + y + p } @@ -68,6 +70,7 @@ struct S : Proto, VectorNumeric { // CHECK: apply [[VJP2]] // CHECK: } // end sil function 'AD__{{.*}}function2{{.*}}_vjp_SSS' + @differentiable(wrt: (y)) func function3(_ x: Float, _ y: Float) -> Float { return x + y + p }