[AutoDiff] lift samefile derivative constraint #28790

Closed
marcrasi wants to merge 3 commits from marcrasi-retrodiff-lift-samefile

Conversation

marcrasi
Collaborator

@marcrasi marcrasi commented Dec 14, 2019

Lifts the same-file derivative constraint and adds many test cases for cross-file and cross-module @derivative attributes.

The tests revealed many broken things. I filed TODOs for some of them and fixed the others in this PR.

Fixed in this PR:

  • If there is no @differentiable attribute, the differentiability witness must take the linkage of the derivative function. Otherwise, when you define a @derivative of a function in a separate file or module, the differentiability witness definition gets external linkage because the original function has external linkage, and definitions are not supposed to have external linkage.
  • Derivative visibility is now checked to be less than or equal to the original function's visibility, because a derivative that is more visible than its original function doesn't make sense.
  • TBDGen had to be changed back to operate on attributes rather than on AFD->getDerivativeFunctionConfigurations(). A @derivative of a function in a different module puts a configuration on the function in that other module, and the TBDGenVisitor doesn't visit functions in other modules.
  • Type checking now eagerly checks all the @derivative attributes in the module (rather than just those in the primary file), so that the differentiation pass sees configurations arising from @derivative attributes in other files. This approach is loosely inspired by bindExtensions.
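
The visibility rule in the second bullet can be illustrated with a small, self-contained C++ sketch. This is a toy model rather than the code in this PR: the `Access` enum and `isDerivativeVisibilityValid` are hypothetical names invented for the illustration, and the real check presumably works on the compiler's own access-level representation.

```cpp
#include <cassert>

// Toy stand-in for Swift access levels, ordered from least to most visible.
// Hypothetical type for illustration; not the compiler's own representation.
enum class Access { Private = 0, FilePrivate, Internal, Public };

// Returns true when a derivative with `derivativeAccess` may be registered
// for an original function with `originalAccess`, i.e. when the derivative
// is no more visible than the original.
bool isDerivativeVisibilityValid(Access originalAccess,
                                 Access derivativeAccess) {
  return static_cast<int>(derivativeAccess) <=
         static_cast<int>(originalAccess);
}
```

Under this model, an internal derivative of a public function is accepted, while a public derivative of an internal function is rejected.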


extension FunctionInModule1_InternalDerivatives {
// TODO(TF-XXXX): This causes duplicate symbol linker errors.
// TODO(TF-XXXX): Why is @usableFromInline necessary?
Member

@usableFromInline isn't necessary if we are registering an internal derivative, but is necessary if we are registering a public derivative.

@marcrasi marcrasi force-pushed the marcrasi-retrodiff-lift-samefile branch 2 times, most recently from 43fac5f to e6b39f7, December 18, 2019 05:06
@marcrasi marcrasi changed the title from "[AutoDiff] draft of lifting samefile derivative constraint" to "[AutoDiff] lift samefile derivative constraint" Dec 18, 2019
@marcrasi marcrasi marked this pull request as ready for review December 18, 2019 05:16
@marcrasi
Collaborator Author

Ok, this is ready for review now.

I'm running swift-apis and swift-models tests on this too because it does some pretty dramatic things.

@@ -67,6 +67,8 @@ void TBDGenVisitor::addSymbol(StringRef name, SymbolKind kind) {
if (StringSymbols && kind == SymbolKind::GlobalSymbol) {
auto isNewValue = StringSymbols->insert(mangled).second;
(void)isNewValue;
if (!isNewValue)
Collaborator

remove debug print statement

Collaborator Author

done

lib/Sema/TypeCheckAttr.cpp (resolved)
@rxwei
Member

rxwei commented Dec 18, 2019

For a public original function, how do we currently differentiate (no pun) between an internal derivative and a public derivative? Is that by the @usableFromInline attribute?

@@ -2919,8 +2920,7 @@ static bool checkFunctionSignature(
static IndexSubset *computeDifferentiationParameters(
ArrayRef<ParsedAutoDiffParameter> parsedWrtParams,
AbstractFunctionDecl *function, GenericEnvironment *derivativeGenEnv,
- StringRef attrName, SourceLoc attrLoc
- ) {
+ StringRef attrName, SourceLoc attrLoc, bool diagnoseErrors = true) {
Collaborator

Rather than adding an invasive diagnoseErrors flag, can you try using DiagnosticTransaction for diagnoseErrors = false users instead?

DiagnosticTransaction collects all diagnostics and provides a way to cancel them:

        DiagnosticTransaction transaction(Ctx.Diags);
        SWIFT_DEFER { transaction.abort(); };

See #28717 for an example.
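
To make the suggestion concrete, here is a self-contained toy model of the transaction pattern in C++. It is a sketch under stated assumptions, not Swift's actual `DiagnosticEngine` or `DiagnosticTransaction`: `ToyDiagnostics`, `ToyTransaction`, `checkSomething`, and `checkSilently` are hypothetical names made up for this illustration. The checker always diagnoses, and the caller that wants silence wraps the call in a transaction and aborts it.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Toy diagnostic sink; a stand-in for illustration only.
struct ToyDiagnostics {
  std::vector<std::string> emitted;
  void diagnose(std::string message) { emitted.push_back(std::move(message)); }
};

// Toy transaction: records how many diagnostics existed when it was opened
// and discards anything emitted afterwards if abort() is called.
class ToyTransaction {
  ToyDiagnostics &diags;
  std::size_t startCount;
  bool open = true;

public:
  explicit ToyTransaction(ToyDiagnostics &d)
      : diags(d), startCount(d.emitted.size()) {}
  void abort() {
    if (open) {
      diags.emitted.resize(startCount);
      open = false;
    }
  }
  void commit() { open = false; }
  ~ToyTransaction() { commit(); }
};

// Hypothetical checker with no "diagnoseErrors" flag: it always diagnoses.
// Returns true on error.
bool checkSomething(ToyDiagnostics &diags, bool valid) {
  if (!valid) {
    diags.diagnose("original function not found");
    return true;
  }
  return false;
}

// Caller that wants the result but not the diagnostics.
bool checkSilently(ToyDiagnostics &diags, bool valid) {
  ToyTransaction transaction(diags);
  bool failed = checkSomething(diags, valid);
  transaction.abort();  // drop whatever the checker emitted
  return failed;
}
```

The design point is that the decision to suppress diagnostics moves to the caller, so the checker's signature does not need an extra flag.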

Collaborator Author

I'll try this.

Collaborator Author

done

@marcrasi
Collaborator Author

marcrasi commented Dec 18, 2019

> For a public original function, how do we currently differentiate (no pun) between an internal derivative and a public derivative? Is that by the @usableFromInline attribute?

Here are some current working behaviors:

// public derivative
@differentiable
public func f1(...) {...}
@derivative(of: f1)
internal func df1(...) {...}

// public derivative
public func f2(...) {...}
@derivative(of: f2)
public func df2(...) {...}

// internal derivative
public func f3(...) {...}
@derivative(of: f3)
internal func df3(...) {...}

/// - Stores the attribute in the `ASTContext` list of derivative attributes.
/// - Stores the derivative configuration in the original function's list of
/// derivative configurations.
static void typeCheckDerivativeAttr(ASTContext &Ctx, AttributeChecker *AC,
Collaborator

I believe AttributeChecker is just a thin wrapper around an ASTContext and we can call ASTContext::diagnose instead. How about removing the AttributeChecker * argument?

If it's possible to use DiagnosticTransaction in callers to abort diagnostics instead of conditionally emitting diagnostics based on a flag, that would be preferable too!

Collaborator Author

good idea, done

@marcrasi marcrasi force-pushed the marcrasi-retrodiff-lift-samefile branch from e6b39f7 to 569b24c, December 19, 2019 00:47
Collaborator Author

@marcrasi marcrasi left a comment

I've addressed all comments and uploaded a new commit.

I still want to add more tests, and I'll do that soon.

@marcrasi
Collaborator Author

Actually, I think that testing and polishing this PR is going to take a lot more time and I would rather spend my last day before vacation doing some more immediately useful things. Therefore, I'm going to abandon this until I come back on Jan 2, 2020.

I will do one thing though: This PR requires some changes in upstreamed TypeCheckAttr.cpp code, so I'll send a master PR that makes those changes. Hopefully these changes will get downstreamed into tensorflow through the regular merge processes before I resume work on this PR :)

@@ -3790,6 +3812,12 @@ static FuncDecl *findAutoDiffDerivativeFunction(
return funcDecl;
}

void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
if (typeCheckDerivativeAttr(Ctx, D, attr))
Collaborator

Are @derivative attributes in the primary file already type-checked (e.g. parameter indices are resolved) after TypeChecker::typeCheckDerivativeAttrs is called in TypeCheckSourceFileRequest::evaluate?

If so, is AttributeChecker::visitDerivativeAttr doing redundant work (e.g. recomputing parameter indices)?

If so, can AttributeChecker::visitDerivativeAttr be changed to assert that attr is already type-checked, and to avoid redundant work by only checking for duplicate @derivative attributes?

Collaborator Author

Yes, there is redundant work happening. I can't deduplicate the work by doing exactly what you say because the second typechecking uses some of the intermediate calculations to make diagnostic messages (e.g. original function not found messages). So I need to either redo the computation or find some way of passing the intermediate calculations from the first time it typechecks to the second time it typechecks.

Here's an idea that could work:

static bool typeCheckDerivativeAttr(...) {
  if (!attr->getOriginalFunction() || !attr->getParameterIndices()) {
    // Do typechecking
    // If successful, set original function and parameter indices
    // If failure, diagnose and return true.
  }

  // Insert attr into Ctx.DerivativeAttrs
  // Diagnose duplicates
}

That way, if the first round of typechecking is successful, then we don't duplicate the work on the second round. If the first round fails, we duplicate the work during the second round, but speed is less important in the failure case so this isn't terrible.
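
The control flow of this idea can be exercised with a self-contained toy model. Everything here is an assumption made for illustration: `ToyDerivativeAttr`, `ToyContext`, and `typeCheckToyDerivativeAttr` are hypothetical stand-ins, a simple name lookup replaces the real type checking, and counters replace real diagnostics.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>

// Toy model of a @derivative attribute; hypothetical, not the compiler's type.
struct ToyDerivativeAttr {
  std::string originalName;  // name written in the attribute
  bool resolved = false;     // set once type checking has succeeded
};

// Toy context holding lookup results, registrations, and counters.
struct ToyContext {
  std::set<std::string> knownFunctions;  // functions that lookup can find
  std::map<std::string, const ToyDerivativeAttr *> registered;
  int resolutionCount = 0;  // how many times the expensive step ran
  int diagnosticCount = 0;  // how many errors were reported
};

// Returns true on error, mirroring the convention in the sketch above.
bool typeCheckToyDerivativeAttr(ToyContext &ctx, ToyDerivativeAttr &attr) {
  if (!attr.resolved) {
    // Expensive step: only runs while the attribute is still unresolved.
    ++ctx.resolutionCount;
    if (ctx.knownFunctions.count(attr.originalName) == 0) {
      ++ctx.diagnosticCount;  // e.g. "original function not found"
      return true;
    }
    attr.resolved = true;
  }

  // Register the attribute; re-registering the same attribute is harmless,
  // but a different attribute for the same original is a duplicate.
  auto result = ctx.registered.emplace(attr.originalName, &attr);
  if (!result.second && result.first->second != &attr) {
    ++ctx.diagnosticCount;
    return true;
  }
  return false;
}
```

In this model, checking a valid attribute twice runs the expensive step once, while an invalid attribute repeats it on every call, which matches the trade-off described above.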

Collaborator

> I can't deduplicate the work by doing exactly what you say because the second typechecking uses some of the intermediate calculations to make diagnostic messages (e.g. original function not found messages).

Aha, that makes sense.

> Here's an idea that could work:
>
>     static bool typeCheckDerivativeAttr(...) {
>       if (!attr->getOriginalFunction() || !attr->getParameterIndices()) {
>         // Do typechecking
>         // If successful, set original function and parameter indices
>         // If failure, diagnose and return true.
>       }
>
>       // Insert attr into Ctx.DerivativeAttrs
>       // Diagnose duplicates
>     }
>
> That way, if the first round of typechecking is successful, then we don't duplicate the work on the second round. If the first round fails, we duplicate the work during the second round, but speed is less important in the failure case so this isn't terrible.

This idea sounds great! Duplicate work only for bad attributes seems very acceptable.

rxwei pushed a commit that referenced this pull request Dec 20, 2019
Since cross-file derivative registration requires a lot of distinct fixes and tests (some of which are drafted in #28790), it would be nice to be able to develop it on `tensorflow` under a flag.

This PR adds a flag and a very tiny test case, the only situation that I'm aware of that currently works without any of the fixes.
@marcrasi
Collaborator Author

marcrasi commented Jan 8, 2020

Closing because I'm going to do this as separate incremental pieces gated behind the flag (#28891) instead of one large PR.
