Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoDiff] Enable differentiation of generic functions. #22023

Merged
merged 3 commits into from Jan 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 6 additions & 2 deletions include/swift/AST/DiagnosticsSIL.def
Expand Up @@ -368,8 +368,9 @@ ERROR(autodiff_unsupported_type,none,
"differentiating '%0' is not supported yet", (Type))
ERROR(autodiff_function_not_differentiable,none,
"function is not differentiable", ())
NOTE(autodiff_function_generic_functions_unsupported,none,
"differentiating generic functions is not supported yet", ())
NOTE(autodiff_function_indirect_params_or_result_unsupported,none,
dan-zheng marked this conversation as resolved.
Show resolved Hide resolved
"differentiating functions with parameters or result of unknown size "
"is not supported yet", ())
NOTE(autodiff_external_nondifferentiable_function,none,
"cannot differentiate an external function that has not been marked "
"'@differentiable'", ())
Expand All @@ -386,6 +387,9 @@ NOTE(autodiff_protocol_member_subset_indices_not_differentiable,none,
NOTE(autodiff_function_subset_indices_not_differentiable,none,
"function is differentiable only with respect to a smaller subset of "
"arguments", ())
NOTE(autodiff_function_assoc_func_requirements_unmet,none,
"function call is not differentiate because generic requirements are not "
"met", ())
NOTE(autodiff_opaque_function_not_differentiable,none,
"opaque non-'@autodiff' function is not differentiable", ())
NOTE(autodiff_property_not_differentiable,none,
Expand Down
3 changes: 0 additions & 3 deletions lib/SIL/SILFunctionType.cpp
Expand Up @@ -289,9 +289,6 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
/*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None,
ParameterConvention::Direct_Guaranteed, tangentParams, {},
tangentResults, None, ctx);
SmallVector<SILResultInfo, 8> jvpResults(
curryLevels.back()->getResults().begin(),
curryLevels.back()->getResults().end());
break;
}
case AutoDiffAssociatedFunctionKind::VJP: {
Expand Down
398 changes: 254 additions & 144 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp

Large diffs are not rendered by default.

38 changes: 1 addition & 37 deletions lib/Sema/TypeCheckAttr.cpp
Expand Up @@ -2188,23 +2188,6 @@ static FuncDecl *resolveAutoDiffAssociatedFunction(
return candidate;
}

// SWIFT_ENABLE_TENSORFLOW
/// Require that the given type either not involve type parameters or be
/// a type parameter.
// TODO: Generalize function to take a `Diagnostic` and merge with
// `diagnoseIndirectGenericTypeParam`.
static bool diagnoseDifferentiableAttrIndirectGenericType(SourceLoc loc,
Type type,
TypeRepr *typeRepr) {
if (type->hasTypeParameter() && !type->is<GenericTypeParamType>()) {
type->getASTContext()
.Diags.diagnose(loc, diag::differentiable_attr_only_generic_param_req)
.highlight(typeRepr->getSourceRange());
return true;
}
return false;
}

void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
auto &ctx = TC.Context;
auto lookupConformance =
Expand Down Expand Up @@ -2303,11 +2286,10 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
RequirementRequest::visitRequirements(
WhereClauseOwner(original, attr), TypeResolutionStage::Structural,
[&](const Requirement &req, RequirementRepr *reqRepr) {
// Check additional constraints.
// TODO: refine constraints.
switch (req.getKind()) {
case RequirementKind::SameType:
case RequirementKind::Superclass:
case RequirementKind::Conformance:
break;

// Layout requirements are not supported.
Expand All @@ -2316,24 +2298,6 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
diag::differentiable_attr_unsupported_req_kind)
.highlight(reqRepr->getSourceRange());
return false;

// Conformance requirements are valid if:
// - The first type is a generic type parameter type.
// - The second type is a protocol type or protocol composition type.
case RequirementKind::Conformance:
if (diagnoseDifferentiableAttrIndirectGenericType(
attr->getLocation(), req.getFirstType(),
reqRepr->getSubjectRepr()))
return false;

if (!req.getSecondType()->is<ProtocolType>() &&
!req.getSecondType()->is<ProtocolCompositionType>()) {
TC.diagnose(attr->getLocation(),
diag::differentiable_attr_non_protocol_type_constraint_req)
.highlight(reqRepr->getSourceRange());
return false;
}
break;
}

// Add requirement to generic signature builder.
Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/autodiff_diagnostics.swift
Expand Up @@ -24,7 +24,7 @@ _ = gradient(at: 0, in: one_to_one_0) // okay!
// Generics
//===----------------------------------------------------------------------===//

// expected-note @+3 {{differentiating generic functions is not supported yet}}
// expected-note @+3 {{differentiating functions with parameters or result of unknown size is not supported yet}}
dan-zheng marked this conversation as resolved.
Show resolved Hide resolved
// expected-error @+2 {{function is not differentiable}}
@differentiable()
func generic<T: Differentiable & FloatingPoint>(_ x: T) -> T {
Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/differentiable_attr_type_checking.swift
Expand Up @@ -391,7 +391,7 @@ extension FloatingPoint {
}
}

// expected-error @+2 {{only conformances to protocol types are supported by @differentiable attribute}}
// expected-error @+2 {{type 'Scalar' constrained to non-protocol, non-class type 'Float'}}
// expected-error @+1 {{can only differentiate with respect to parameters that conform to 'Differentiable', but 'Scalar' does not conform to 'Differentiable'}}
@differentiable(where Scalar : Float)
func invalidRequirementConformance<Scalar>(x: Scalar) -> Scalar {
Expand Down
14 changes: 14 additions & 0 deletions test/AutoDiff/generics.swift
@@ -0,0 +1,14 @@
// RUN: %target-swift-frontend -emit-sil -verify %s

struct Tensor<T : VectorNumeric> : VectorNumeric, Differentiable {
var value: Float
init(_ value: Float) { self.value = value }
}

func generic<T : FloatingPoint & Differentiable>(_ x: Tensor<T>) -> Float {
return x.value + x.value
}
print(pullback(at: Tensor<Float>(1), in: generic))
print(pullback(at: Tensor<Float>(3), in: generic))

// TODO: add more tests.
7 changes: 7 additions & 0 deletions test/TensorFlowRuntime/tensor_autodiff_runtime.swift
Expand Up @@ -22,6 +22,13 @@ TensorADTests.testAllBackends("TestSimpleGrad") {
expectTrue(gradient(at: [[10], [20]], in: square) == [[20], [40]])
}

TensorADTests.testAllBackends("TestGenericGrad") {
func square<T : FloatingPoint & Differentiable>(_ x: Tensor<T>) -> Tensor<T> {
return x * x
}
expectEqual([0.2, 0.4, 0.6], gradient(at: Tensor([0.1, 0.2, 0.3]), in: square))
}

TensorADTests.testAllBackends("+") {
let f = { (a: Tensor<Float>, b: Tensor<Float>) in a + b }
expectTrue((Tensor(1), Tensor(1)) == gradient(at: Tensor(0), Tensor(0), in: f))
Expand Down