Skip to content

Commit

Permalink
[AutoDiff] Add an error when @differentiable attributes do not match …
Browse files Browse the repository at this point in the history
…when conforming to protocols.
  • Loading branch information
pschuh committed Jan 26, 2019
1 parent d202231 commit c655e50
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 0 deletions.
13 changes: 13 additions & 0 deletions include/swift/AST/Attr.h
Expand Up @@ -1508,6 +1508,19 @@ class DifferentiableAttr final
static bool classof(const DeclAttribute *DA) {
return DA->getKind() == DAK_Differentiable;
}

bool parametersMatch(const DifferentiableAttr &other) const {
auto a = getParsedParameters();
auto b = other.getParsedParameters();
if (a.size() != b.size())
return false;

for (unsigned i = 0, n = b.size(); i < n; ++i) {
if (!a[i].isEqual(b[i]))
return false;
}
return true;
}
};

/// \brief Attributes that may be applied to declarations.
Expand Down
2 changes: 2 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Expand Up @@ -2449,6 +2449,8 @@ ERROR(broken_differentiable_requirement,none,
WARNING(differentiable_implicit_noderivative_fixit,none,
"stored property has no derivative because it does not conform to "
"'Differentiable'; add '@noDerivative' to make it explicit", ())
NOTE(protocol_witness_missing_differentiable_attr,none,
"candidate is missing attribute '%0'", (StringRef))

NOTE(codable_extraneous_codingkey_case_here,none,
"CodingKey case %0 does not match any stored properties", (Identifier))
Expand Down
27 changes: 27 additions & 0 deletions lib/Sema/TypeCheckProtocol.cpp
Expand Up @@ -481,6 +481,22 @@ swift::matchWitness(
cast<AbstractFunctionDecl>(witness)->hasThrows())
return RequirementMatch(witness, MatchKind::RethrowsConflict);

// SWIFT_ENABLE_TENSORFLOW
// Differentiation attributes must match completely or the generated
// functions will have the wrong signature.
{
auto *reqDifferentiationAttr =
reqAttrs.getAttribute<DifferentiableAttr>(/*AllowInvalid*/ true);
auto *witnessDifferentiationAttr =
witnessAttrs.getAttribute<DifferentiableAttr>(
/*AllowInvalid*/ true);
if (reqDifferentiationAttr &&
(!witnessDifferentiationAttr ||
!witnessDifferentiationAttr->parametersMatch(
*reqDifferentiationAttr)))
return RequirementMatch(witness, MatchKind::DifferentiableConflict);
}

// We want to decompose the parameters to handle them separately.
decomposeFunctionType = true;
} else if (auto *witnessASD = dyn_cast<AbstractStorageDecl>(witness)) {
Expand Down Expand Up @@ -2212,6 +2228,17 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
case MatchKind::NonObjC:
diags.diagnose(match.Witness, diag::protocol_witness_not_objc);
break;
// SWIFT_ENABLE_TENSORFLOW
case MatchKind::DifferentiableConflict:
std::string diffAttrReq;
{
llvm::raw_string_ostream stream(diffAttrReq);
req->getAttrs().getAttribute<DifferentiableAttr>()->print(stream, req);
}
diags.diagnose(match.Witness,
diag::protocol_witness_missing_differentiable_attr,
diffAttrReq);
break;
}
}

Expand Down
8 changes: 8 additions & 0 deletions lib/Sema/TypeCheckProtocol.h
Expand Up @@ -205,6 +205,10 @@ enum class MatchKind : uint8_t {

/// The witness is explicitly @nonobjc but the requirement is @objc.
NonObjC,

// SWIFT_ENABLE_TENSORFLOW
/// The @differentiable attribute does not match.
DifferentiableConflict,
};

/// Describes the kind of optional adjustment performed when
Expand Down Expand Up @@ -418,6 +422,8 @@ struct RequirementMatch {
case MatchKind::RethrowsConflict:
case MatchKind::ThrowsConflict:
case MatchKind::NonObjC:
// SWIFT_ENABLE_TENSORFLOW
case MatchKind::DifferentiableConflict:
return false;
}

Expand Down Expand Up @@ -446,6 +452,8 @@ struct RequirementMatch {
case MatchKind::RethrowsConflict:
case MatchKind::ThrowsConflict:
case MatchKind::NonObjC:
// SWIFT_ENABLE_TENSORFLOW
case MatchKind::DifferentiableConflict:
return false;
}

Expand Down
25 changes: 25 additions & 0 deletions test/AutoDiff/differentiable_attr_type_checking.swift
Expand Up @@ -404,3 +404,28 @@ func invalidRequirementConformance<Scalar>(x: Scalar) -> Scalar {
func invalidRequirementLayout<Scalar>(x: Scalar) -> Scalar {
return x
}


protocol DiffReq : Differentiable {
// expected-note @+2 {{protocol requires function 'f1'}}
@differentiable(wrt: (self, x))
func f1(_ x: Float) -> Float

// expected-note @+2 {{protocol requires function 'f2'}}
@differentiable(wrt: (self, x, y))
func f2(_ x: Float, _ y: Float) -> Float
}

// expected-error @+1 {{does not conform to protocol}}
struct ConformingWithErrors : DiffReq {
// expected-note @+1 {{@differentiable(wrt: (x, self))}}
func f1(_ x: Float) -> Float {
return x
}

// expected-note @+2 {{@differentiable(wrt: (x, y, self))}}
@differentiable(wrt: (self, x))
func f2(_ x: Float, _ y: Float) -> Float {
return x + y
}
}
1 change: 1 addition & 0 deletions test/AutoDiff/protocol_requirement_autodiff.swift
Expand Up @@ -51,6 +51,7 @@ struct Quadratic : DiffReq, Equatable {
self.c = c
}

@differentiable(wrt: (self, x))
func f(_ x: Float) -> Float {
return a * x * x + b * x + c
}
Expand Down
3 changes: 3 additions & 0 deletions test/AutoDiff/witness_table_silgen.swift
Expand Up @@ -28,6 +28,7 @@ struct S : Proto, VectorNumeric {
return (p, { dp in S(p: dp) })
}

@differentiable()
func function1(_ x: Float, _ y: Float) -> Float {
return x + y + p
}
Expand All @@ -48,6 +49,7 @@ struct S : Proto, VectorNumeric {
// CHECK: apply [[VJP1]]
// CHECK: } // end sil function 'AD__{{.*}}function1{{.*}}_vjp_SSU'

@differentiable(wrt: (self, x, y))
func function2(_ x: Float, _ y: Float) -> Float {
return x + y + p
}
Expand All @@ -68,6 +70,7 @@ struct S : Proto, VectorNumeric {
// CHECK: apply [[VJP2]]
// CHECK: } // end sil function 'AD__{{.*}}function2{{.*}}_vjp_SSS'

@differentiable(wrt: (y))
func function3(_ x: Float, _ y: Float) -> Float {
return x + y + p
}
Expand Down

0 comments on commit c655e50

Please sign in to comment.