Skip to content

Commit

Permalink
[AutoDiff] Diagnose unsupported forward-mode control flow. (#27684)
Browse files Browse the repository at this point in the history
Diagnose unsupported forward-mode control flow instead of crashing.
  • Loading branch information
dan-zheng authored and rxwei committed Oct 15, 2019
1 parent 740b63e commit bb67311
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/swift/AST/DiagnosticsSIL.def
Expand Up @@ -513,6 +513,8 @@ NOTE(autodiff_class_member_not_supported,none,
NOTE(autodiff_cannot_param_subset_thunk_partially_applied_orig_fn,none,
"cannot convert a direct method reference to a '@differentiable' "
"function; use an explicit closure instead", ())
NOTE(autodiff_jvp_control_flow_not_supported,none,
"forward-mode differentiation does not yet support control flow", ())
NOTE(autodiff_control_flow_not_supported,none,
"cannot differentiate unsupported control flow", ())
// TODO(TF-645): Remove when differentiation supports `ref_element_addr`.
Expand Down
9 changes: 9 additions & 0 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Expand Up @@ -8030,6 +8030,15 @@ bool ADContext::processDifferentiableAttribute(
// generation because generated JVP may not match semantics of custom VJP.
// Instead, create an empty JVP.
if (RunJVPGeneration && !vjp) {
// JVP and differential generation do not currently support functions with
// multiple basic blocks.
if (original->getBlocks().size() > 1) {
emitNondifferentiabilityError(
original->getLocation().getSourceLoc(), invoker,
diag::autodiff_jvp_control_flow_not_supported);
return true;
}

JVPEmitter emitter(*this, original, attr, jvp, invoker);
if (emitter.run())
return true;
Expand Down
15 changes: 15 additions & 0 deletions test/AutoDiff/forward_mode_diagnostics.swift
Expand Up @@ -94,3 +94,18 @@ func nondiff(_ f: @differentiable (Float, @nondiff Float) -> Float) -> Float {
// expected-error @+1 {{function is not differentiable}}
return derivative(at: 2, 3) { (x, y) in f(x * x, y) }
}

//===----------------------------------------------------------------------===//
// Control flow
//===----------------------------------------------------------------------===//

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+2 {{when differentiating this function definition}}
// expected-note @+1 {{forward-mode differentiation does not yet support control flow}}
func cond(_ x: Float) -> Float {
if x > 0 {
return x * x
}
return x + x
}

0 comments on commit bb67311

Please sign in to comment.