Skip to content

Commit

Permalink
[AutoDiff] Support forward mode differentiation of functions with `in…
Browse files Browse the repository at this point in the history
…out` parameters (#33584)

Adds forward mode support for `apply` instruction with `inout` arguments.

Example of supported code:
```
func add(_ x: inout Float, _ y: inout Float) -> Float {
  var result = x
  result += y
  return result
}
print(differential(at: 1, 1, in: add)(1, 1)) // prints "2"
```
  • Loading branch information
efremale committed Aug 22, 2020
1 parent 7c508a0 commit 83b2ebe
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 77 deletions.
67 changes: 34 additions & 33 deletions lib/SILOptimizer/Differentiation/JVPCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,18 +455,6 @@ class JVPCloner::Implementation final
return;
}

// Diagnose functions with active inout arguments.
// TODO(TF-129): Support `inout` argument differentiation.
for (auto inoutArg : ai->getInoutArguments()) {
if (activityInfo.isActive(inoutArg, getIndices())) {
context.emitNondifferentiabilityError(
ai, invoker,
diag::autodiff_cannot_differentiate_through_inout_arguments);
errorOccurred = true;
return;
}
}

auto loc = ai->getLoc();
auto &builder = getBuilder();
auto origCallee = getOpValue(ai->getCallee());
Expand Down Expand Up @@ -1241,6 +1229,10 @@ class JVPCloner::Implementation final
SmallVector<SILValue, 8> differentialAllResults;
collectAllActualResultsInTypeOrder(
differentialCall, differentialDirectResults, differentialAllResults);
for (auto inoutArg : ai->getInoutArguments())
origAllResults.push_back(inoutArg);
for (auto inoutArg : differentialCall->getInoutArguments())
differentialAllResults.push_back(inoutArg);
assert(applyIndices.results->getNumIndices() ==
differentialAllResults.size());

Expand Down Expand Up @@ -1484,11 +1476,14 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
auto origIndResults = original->getIndirectResults();
auto diffIndResults = differential.getIndirectResults();
#ifndef NDEBUG
unsigned numInoutParameters = llvm::count_if(
original->getLoweredFunctionType()->getParameters(),
[](SILParameterInfo paramInfo) { return paramInfo.isIndirectInOut(); });
assert(origIndResults.size() + numInoutParameters == diffIndResults.size());
unsigned numNonWrtInoutParameters = llvm::count_if(
range(original->getLoweredFunctionType()->getNumParameters()),
[&] (unsigned i) {
auto &paramInfo = original->getLoweredFunctionType()->getParameters()[i];
return paramInfo.isIndirectInOut() && !getIndices().parameters->contains(i);
});
#endif
assert(origIndResults.size() + numNonWrtInoutParameters == diffIndResults.size());
for (auto &origBB : *original)
for (auto i : indices(origIndResults))
setTangentBuffer(&origBB, origIndResults[i], diffIndResults[i]);
Expand Down Expand Up @@ -1521,23 +1516,10 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
auto origParams = origTy->getParameters();
auto indices = witness->getSILAutoDiffIndices();

// Add differential results.
Optional<SILParameterInfo> inoutDiffParam = None;
for (auto origParam : origTy->getParameters()) {
if (!origParam.isIndirectInOut())
continue;
inoutDiffParam = origParam;
}

if (inoutDiffParam) {
dfResults.push_back(
SILResultInfo(inoutDiffParam->getInterfaceType()
->getAutoDiffTangentSpace(lookupConformance)
->getType()
->getCanonicalType(witnessCanGenSig),
ResultConvention::Indirect));
} else {
for (auto resultIndex : indices.results->getIndices()) {

for (auto resultIndex : indices.results->getIndices()) {
if (resultIndex < origTy->getNumResults()) {
// Handle formal original result.
auto origResult = origTy->getResults()[resultIndex];
origResult = origResult.getWithInterfaceType(
origResult.getInterfaceType()->getCanonicalType(witnessCanGenSig));
Expand All @@ -1548,6 +1530,25 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
->getCanonicalType(witnessCanGenSig),
origResult.getConvention()));
}
else {
// Handle original `inout` parameter.
auto inoutParamIndex = resultIndex - origTy->getNumResults();
auto inoutParamIt = std::next(
origTy->getIndirectMutatingParameters().begin(), inoutParamIndex);
auto paramIndex =
std::distance(origTy->getParameters().begin(), &*inoutParamIt);
// If the original `inout` parameter is a differentiability parameter, then
// it already has a corresponding differential parameter. Skip adding a
// corresponding differential result.
if (indices.parameters->contains(paramIndex))
continue;
auto inoutParam = origTy->getParameters()[paramIndex];
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
lookupConformance);
assert(paramTan && "Parameter type does not have a tangent space?");
dfResults.push_back(
{paramTan->getCanonicalType(), ResultConvention::Indirect});
}
}

// Add differential parameters for the requested wrt parameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,9 @@ extension ${Self} {
static func _jvpMultiplyAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
value: Void, differential: (inout ${Self}, ${Self}) -> Void
) {
let oldLhs = lhs
lhs *= rhs
return ((), { $0 *= $1 })
return ((), { $0 = $0 * rhs + oldLhs * $1 })
}

${Availability(bits)}
Expand Down Expand Up @@ -251,8 +252,9 @@ extension ${Self} {
static func _jvpDivideAssign(_ lhs: inout ${Self}, _ rhs: ${Self}) -> (
value: Void, differential: (inout ${Self}, ${Self}) -> Void
) {
let oldLhs = lhs
lhs /= rhs
return ((), { $0 /= $1 })
return ((), { $0 = ($0 * rhs - oldLhs * $1) / (rhs * rhs) })
}
}

Expand Down
63 changes: 21 additions & 42 deletions test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-swift-frontend -enable-experimental-forward-mode-differentiation -emit-sil -verify %s
// RUN: %target-swift-frontend -emit-sil -enable-experimental-forward-mode-differentiation -verify %s

// Test forward-mode differentiation transform diagnostics.

Expand Down Expand Up @@ -46,8 +46,6 @@ func nonVariedResult(_ x: Float) -> Float {
// Multiple results
//===----------------------------------------------------------------------===//

// TODO(TF-983): Support differentiation of multiple results.
/*
func multipleResults(_ x: Float) -> (Float, Float) {
return (x, x)
}
Expand All @@ -56,28 +54,21 @@ func usesMultipleResults(_ x: Float) -> Float {
let tuple = multipleResults(x)
return tuple.0 + tuple.1
}
*/

//===----------------------------------------------------------------------===//
// `inout` parameter differentiation
//===----------------------------------------------------------------------===//

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutParamNonactiveInitialResult(_ x: Float) -> Float {
var result: Float = 1
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
result += x
return result
}

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutParamTuple(_ x: Float) -> Float {
var tuple = (x, x)
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
tuple.0 *= x
return x * tuple.0
}
Expand All @@ -94,49 +85,37 @@ func activeInoutParamControlFlow(_ array: [Float]) -> Float {
return result
}

struct Mut: Differentiable {}
extension Mut {
@differentiable(wrt: x)
mutating func mutatingMethod(_ x: Mut) {}
}

// FIXME(TF-984): Forward-mode crash due to unset tangent buffer.
/*
@differentiable(wrt: x)
func nonActiveInoutParam(_ nonactive: inout Mut, _ x: Mut) -> Mut {
return nonactive.mutatingMethod(x)
struct X: Differentiable {
var x : Float
@differentiable(wrt: y)
mutating func mutate(_ y: X) { self.x = y.x }
}
*/
// FIXME(TF-984): Forward-mode crash due to unset tangent buffer.
/*
@differentiable(wrt: x)
func activeInoutParamMutatingMethod(_ x: Mut) -> Mut {
var result = x
result = result.mutatingMethod(result)
return result
@differentiable
func activeMutatingMethod(_ x: Float) -> Float {
let x1 = X.init(x: x)
var x2 = X.init(x: 0)
x2.mutate(x1)
return x1.x
}
*/

// FIXME(TF-984): Forward-mode crash due to unset tangent buffer.
/*
@differentiable(wrt: x)
func activeInoutParamMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) -> Mut {
var result = nonactive
result = result.mutatingMethod(x)
return result

struct Mut: Differentiable {}
extension Mut {
@differentiable(wrt: x)
mutating func mutatingMethod(_ x: Mut) {}
}
*/

// FIXME(TF-984): Forward-mode crash due to unset tangent buffer.
/*
@differentiable(wrt: x)
func activeInoutParamMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) -> Mut {
var result = (nonactive, x)
let result2 = result.0.mutatingMethod(result.0)
return result2
func activeInoutParamMutatingMethod(_ x: Mut) -> Mut {
var result = x
result.mutatingMethod(result)
return result
}
*/

//===----------------------------------------------------------------------===//
// Subset parameter differentiation thunks
Expand Down

0 comments on commit 83b2ebe

Please sign in to comment.