Skip to content

Commit

Permalink
[AutoDiff] Enable differentiation of generic functions.
Browse files Browse the repository at this point in the history
- Relax differentiability diagnostic for generic functions.
  - Previously, an error was emitted when attempting to differentiate any
    generic function. Now, diagnose only functions with indirect
    differentiation parameters/result.
- Propagate differentiation associated function generic signature throughout
  differentiation pass.
  - Change `PrimalGenCloner` to inherit `TypeSubstCloner`.
  - Make primal value structs inherit primal function's generic parameters
    and signature.
  - Calculate correct substitution map for `PrimalGenCloner::visitApplyInst`.
    Emit diagnostic when apply instruction's associated function (e.g. VJP)
    has generic requirements unmet by the primal generic environment.
  - Remap types in `AdjointEmitter`.
- Remove manually `@differentiable` attribute where clause conformance
  requirement checks.
  - `GenericSignatureBuilder` already performs checks so manual checks are
     unnecessary.
  • Loading branch information
dan-zheng committed Jan 21, 2019
1 parent 0cce4b8 commit e2155bf
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 187 deletions.
8 changes: 6 additions & 2 deletions include/swift/AST/DiagnosticsSIL.def
Expand Up @@ -368,8 +368,9 @@ ERROR(autodiff_unsupported_type,none,
"differentiating '%0' is not supported yet", (Type))
ERROR(autodiff_function_not_differentiable,none,
"function is not differentiable", ())
NOTE(autodiff_function_generic_functions_unsupported,none,
"differentiating generic functions is not supported yet", ())
NOTE(autodiff_function_indirect_params_or_result_unsupported,none,
"differentiating functions with parameters or result of unknown size "
"is not supported yet", ())
NOTE(autodiff_external_nondifferentiable_function,none,
"cannot differentiate an external function that has not been marked "
"'@differentiable'", ())
Expand All @@ -386,6 +387,9 @@ NOTE(autodiff_protocol_member_subset_indices_not_differentiable,none,
NOTE(autodiff_function_subset_indices_not_differentiable,none,
"function is differentiable only with respect to a smaller subset of "
"arguments", ())
NOTE(autodiff_function_assoc_func_requirements_unmet,none,
"function call is not differentiate because generic requirements are not "
"met", ())
NOTE(autodiff_opaque_function_not_differentiable,none,
"opaque non-'@autodiff' function is not differentiable", ())
NOTE(autodiff_property_not_differentiable,none,
Expand Down
3 changes: 0 additions & 3 deletions lib/SIL/SILFunctionType.cpp
Expand Up @@ -289,9 +289,6 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
/*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None,
ParameterConvention::Direct_Guaranteed, tangentParams, {},
tangentResults, None, ctx);
SmallVector<SILResultInfo, 8> jvpResults(
curryLevels.back()->getResults().begin(),
curryLevels.back()->getResults().end());
break;
}
case AutoDiffAssociatedFunctionKind::VJP: {
Expand Down

0 comments on commit e2155bf

Please sign in to comment.