New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AutoDiff] lift samefile derivative constriant #28790
Conversation
|
||
extension FunctionInModule1_InternalDerivatives { | ||
// TODO(TF-XXXX): This causes duplicate symbol linker errors. | ||
// TODO(TF-XXXX): Why is @usableFromInline necessary? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@usableFromInline
isn't necessary if we are registering an internal derivative, but is necessary if we are registering a public derivative.
43fac5f
to
e6b39f7
Compare
Ok, this is ready for review now. I'm running swift-apis and swift-models tests on this too because it does some pretty dramatic things. |
lib/TBDGen/TBDGen.cpp
Outdated
@@ -67,6 +67,8 @@ void TBDGenVisitor::addSymbol(StringRef name, SymbolKind kind) { | |||
if (StringSymbols && kind == SymbolKind::GlobalSymbol) { | |||
auto isNewValue = StringSymbols->insert(mangled).second; | |||
(void)isNewValue; | |||
if (!isNewValue) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove debug print statement
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
For a public original function, how do we currently differentiate (no pun) between an internal derivative and a public derivative? Is that by the |
lib/Sema/TypeCheckAttr.cpp
Outdated
@@ -2919,8 +2920,7 @@ static bool checkFunctionSignature( | |||
static IndexSubset *computeDifferentiationParameters( | |||
ArrayRef<ParsedAutoDiffParameter> parsedWrtParams, | |||
AbstractFunctionDecl *function, GenericEnvironment *derivativeGenEnv, | |||
StringRef attrName, SourceLoc attrLoc | |||
) { | |||
StringRef attrName, SourceLoc attrLoc, bool diagnoseErrors = true) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than adding an invasive diagnoseErrors
flag, can you try using DiagnosticTransaction
for diagnoseErrors = false
users instead?
DiagnosticTransaction
collects all diagnostics and provides a way to cancel them:
DiagnosticTransaction transaction(Ctx.Diags);
SWIFT_DEFER { transaction.abort(); };
See #28717 for an example.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll try this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
Here are some current working behaviors:
|
lib/Sema/TypeCheckAttr.cpp
Outdated
/// - Stores the attribute in the `ASTContext` list of derivative attributes. | ||
/// - Stores the derivative configuration in the original function's list of | ||
/// derivative configurations. | ||
static void typeCheckDerivativeAttr(ASTContext &Ctx, AttributeChecker *AC, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe AttributeChecker
is just a thin wrapper around an ASTContext
and we can call ASTContext::diagnose
instead. How about removing the AttributeChecker *
argument?
If it's possible to use DiagnosticTransaction
in callers to abort diagnostics instead of conditionally emitting diagnostics based on a flag, that would be preferable too!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good idea, done
e6b39f7
to
569b24c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've addressed all comments and uploaded a new commit.
I still want to add more tests, and I'll do that soon.
Actually, I think that testing and polishing this PR is going to take a lot more time and I would rather spend my last day before vacation doing some more immediately useful things. Therefore, I'm going to abandon this until I come back on Jan 2, 2020. I will do one thing though: This PR requires some changes in upstreamed TypeCheckAttr.cpp code, so I'll send a master PR that makes those changes. Hopefully these changes will get downstreamed into tensorflow through the regular merge processes before I resume work on this PR :) |
…rodiff-lift-samefile
@@ -3790,6 +3812,12 @@ static FuncDecl *findAutoDiffDerivativeFunction( | |||
return funcDecl; | |||
} | |||
|
|||
void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) { | |||
if (typeCheckDerivativeAttr(Ctx, D, attr)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are @derivative
attributes in the primary file already type-checked (e.g. parameter indices are resolved) after TypeChecker::typeCheckDerivativeAttrs
is called in TypeCheckSourceFileRequest::evaluate
?
If so, is AttributeChecker::visitDerivativeAttr
doing redundant work (e.g. recomputing parameter indices)?
If so, can AttributeChecker::visitDerivativeAttr
be changed to assert that attr
is already type-checked, and to avoid redundant work by only checking for duplicate @derivative
attributes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, there is redundant work happening. I can't deduplicate the work by doing exactly what you say because the second typechecking uses some of the intermediate calculations to make diagnostic messages (e.g. original function not found messages). So I need to either redo the computation or find some way of passing the intermediate calculations from the first time it typechecks to the second time it typechecks.
Here's an idea that could work:
static bool typeCheckDerivativeAttr(...) {
if (!attr->getOriginalFunction() || !attr->getParameterIndices()) {
// Do typechecking
// If successful, set original function and parameter indices
// If failure, diagnose and return true.
}
// Insert attr into Ctx.DerivativeAttrs
// Diagnose duplicates
}
That way, if the first round of typechecking is successful, then we don't duplicate the work on the second round. If the first round fails, we duplicate the work during the second round, but speed is less important in the failure case so this isn't terrible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can't deduplicate the work by doing exactly what you say because the second typechecking uses some of the intermediate calculations to make diagnostic messages (e.g. original function not found messages).
Aha, that makes sense.
Here's an idea that could work:
static bool typeCheckDerivativeAttr(...) { if (!attr->getOriginalFunction() || !attr->getParameterIndices()) { // Do typechecking // If successful, set original function and parameter indices // If failure, diagnose and return true. } // Insert attr into Ctx.DerivativeAttrs // Diagnose duplicates }
That way, if the first round of typechecking is successful, then we don't duplicate the work on the second round. If the first round fails, we duplicate the work during the second round, but speed is less important in the failure case so this isn't terrible.
This idea sounds great! Duplicate work only for bad attributes seems very acceptable.
Since cross-file derivative registration requires a lot of distinct fixes and tests (some of which are drafted in #28790), it would be nice to be able to develop it on `tensorflow` under a flag. This PR adds a flag and a very tiny testcase -- the only situation that I'm aware of that currently works without any of the fixes.
Closing because I'm going to do this as separate incremental pieces gated behind the flag (#28891) instead of one large PR. |
Lifts the samefile derivative constraint, and adds many testcases for cross-file and cross-module
@derivative
attrs.The tests revealed many broken things. I filed TODOs for some of the broken things, and I fixed others in this PR.
The things that are fixed in this PR are:
@differentiable
attribute, then the differentiability witness must have the linkage of the derivative function. (Otherwise, when you define a@derivative
of a function in a separate file or module, the differentiability witness definition gets external linkage because the original function has external linkage, and you're not supposed to give external linkage to definitions.)AFD->getDerivativeFunctionConfigurations()
, because a@derivative
of a function in a different module puts a configuration on the function in the different module, and the TBDGenVisitor doesn't visit functions in other modules!@derivative
attributes in the module (rather than just the@derivative
attributes in the primary file), so that the differentiation pass sees configurations arising from@derivative
s in other files. This approach is slightly inspired bybindExtensions
.