
Teach SILGen to reabstract @autodiff functions #21367

Merged
merged 7 commits on Dec 18, 2018

Conversation

@marcrasi
Collaborator

marcrasi commented Dec 17, 2018

This PR teaches SILGen to reabstract @autodiff functions by splitting them into their components, reabstracting all the components, and then putting them back together again. cc @slavapestov for your thoughts on this approach.
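
For example, a call like the following forces such a reabstraction (a sketch only, assuming this branch's @autodiff type attribute; apply and double are made-up names, not code from this PR):

// A generic higher-order function: the @autodiff parameter uses the
// generic abstraction pattern (T) -> R.
func apply<T: Differentiable, R: Differentiable>(
  _ f: @autodiff (T) -> R, to x: T
) -> R {
  return f(x)
}

// Passing a concrete @autodiff (Float) -> Float where @autodiff (T) -> R is
// expected forces SILGen to reabstract: the bundle is split into its original
// and associated functions, each piece is thunked to the generic abstraction
// pattern, and the thunked pieces are rebundled.
let double: @autodiff (Float) -> Float = { 2 * $0 }
_ = apply(double, to: 3)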

I pulled in some additional fixes from @rxwei:

  1. rxwei/swift@372748e: Update autodiff-associated function type calculation to work with abstraction thunks.
  2. rxwei/swift@fe47c71: Add a conversion rule for convert_function where the converted function has @noescape.
  3. rxwei/swift@ed8e7d0: Retain arguments before partial_apply, fixing a crasher.

I also switched the order of autodiff_function and function_conversion_expr in CSApply, to work around the problem where we can't convert escaping @autodiff functions to noescape.

As a result of these changes, differential operators work, so I enabled the tests.

@rxwei
Member

rxwei commented Dec 17, 2018

Awesome! I had just been experimenting with this (reversing the function conversion and peering through convert_escape_to_noescape in the AD pass) on my local branch. It's great that you've finished it!

@@ -7659,6 +7659,9 @@ class AutoDiffFunctionExtractInst :
Extractee(AutoDiffAssociatedFunctionKind kind);
explicit Extractee(StringRef name);
operator innerty() const { return rawValue; }

llvm::Optional<AutoDiffAssociatedFunctionKind>
Member

No need for llvm::, I think.

Member

I think calling it getExtracteeAsAssociatedFunction() would provide more clarity.

Collaborator Author

I'm going to split this out into a separate PR: it's a nice refactoring, but it's no longer important for this PR once I pull in your associated function type calculation fix.

@@ -688,32 +688,53 @@ AutoDiffFunctionExtractInst::Extractee::Extractee(StringRef string) {
rawValue = *result;
}

llvm::Optional<AutoDiffAssociatedFunctionKind>
Member

Suggested change
llvm::Optional<AutoDiffAssociatedFunctionKind>
Optional<AutoDiffAssociatedFunctionKind>

      expr = cs.cacheType(new (tc.Context)
                              AutoDiffFunctionExpr(expr, toFunc));
    }
    else if (fromEI.isDifferentiable() && !toEI.isDifferentiable()) {
Member

If we have an @escaping @autodiff function and we need to convert it to a non-escaping, non-differentiable function, won't this emit a (autodiff_function_extract_original_expr (convert_function_expr f)), which would then trigger the reference-counting crasher again? In that case, I think autodiff_function_extract_original should happen before function_conversion_expr.
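
A case that would hit this path (a sketch; takesPlain and g are made-up names):

// The parameter of takesPlain is non-escaping and non-differentiable.
func takesPlain(_ f: (Float) -> Float) -> Float { return f(1) }

let g: @autodiff (Float) -> Float = { $0 * $0 }
// g must shed both differentiability and escapingness at this call site;
// extracting the original function before the conversion avoids converting
// the whole @autodiff bundle to @noescape.
_ = takesPlain(g)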

@@ -1668,6 +1670,15 @@ reapplyFunctionConversion(SILValue newFunc, SILValue oldFunc,
        loc, innerNewFunc, pai->getSubstitutionMap(), newArgs,
        ParameterConvention::Direct_Guaranteed);
  }
  // convert_escape_to_noescape
  if (auto *cetn = dyn_cast<ConvertEscapeToNoEscapeInst>(oldConvertedFunc)) {
    auto innerNewFunc = reapplyFunctionConversion(newFunc, oldFunc, cetn->getOperand(), builder, loc, substituteOperand);
Member

80 cols

// If the associated function comes from a reabstraction thunk, then it is
// impossible to determine the type of the associated function from the
// type of the original function, because we also need to know the
// abstraction pattern that the reabstraction made. So we currently
Member

I see what's going on here:

The autodiff-associated function type calculation produces the following type:

$@noescape @callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (Float) -> Float)

But reabstraction makes the closure result's parameters and results indirect:

$@noescape @callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)

Currently, SILFunctionType::getAutoDiffAssociatedFunctionType uses getFormalResultInfo and getFormalParameterInfo to determine the parameter and result conventions of the differential/pullback. If we instead mirror the conventions used in the original function type, then associated function types can still be calculated correctly even when they are reabstracted, I think.
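
Concretely (my sketch, using the types above): mirroring would give the pullback the conventions of the reabstracted original, i.e.

$@callee_guaranteed (@in_guaranteed Float) -> @out Float

which matches the @owned closure result in the reabstracted bundle type above.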

Member

I updated SILFunctionType::getAutoDiffAssociatedFunctionType in 372748e and it's passing all tests. Feel free to pull that in.

@rxwei
Member

rxwei commented Dec 17, 2018

I made a bunch of additional fixes based on this PR; please feel free to pull those in.

  1. rxwei@372748e: Update autodiff-associated function type calculation to work with abstraction thunks.

  2. rxwei@fe47c71: Add a conversion rule for convert_function where the converted function has @noescape.

  3. rxwei@ed8e7d0: Retain arguments before partial_apply, fixing a crasher.

With these fixes, generic functional differential operators are working end-to-end on tensors.

import Swift
import TensorFlow

func valueWithPullback<T, R>(
  at x: T, in f: @autodiff (T) -> R
) -> (R, (R.CotangentVector) -> T.CotangentVector)
  where T : Differentiable, R : Differentiable {
  return Builtin.autodiffApplyVJP(f, x)
}

let (y, pullback) = valueWithPullback(at: Tensor<Float>(1)) { x in
  return sin(x)
}
print(y) // 0.84147096
print(pullback(Tensor(1))) // 0.5403023

@marcrasi
Collaborator Author

I addressed your comments, pulled in all your fixes, and updated the PR description to describe the pieces you added.

@marcrasi
Collaborator Author

@swift-ci please test tensorflow

1 similar comment

@slavapestov
Member

slavapestov left a comment

The approach looks fine. A couple of nitpicks:

// If this is differentiable, then we need to apply a thunk to all the
// components. We extract them, apply a thunk to them, and then combine them
// back into a bundle.
if (sourceType->isDifferentiable()) {
Member

Maybe we could factor this out into a separate method?

SGF, loc, original, inputOrigTypeNotDiff, inputSubstTypeNotDiff,
outputOrigTypeNotDiff, outputSubstTypeNotDiff, expectedTLNotDiff);

// TODO: Use parameter indices specified in the funciton type.
Member

Typo: funciton

@marcrasi
Collaborator Author

@swift-ci please test tensorflow

1 similar comment

@marcrasi
Collaborator Author

@swift-ci please test tensorflow

3 similar comments

@marcrasi
Collaborator Author

@swift-ci please test tensorflow linux

@rxwei
Member

rxwei commented Dec 18, 2018

Merging this so I can rebase my next PR.
