Skip to content

Commit

Permalink
[AutoDiff] Rename "associated function" to "derivative function". (#2…
Browse files Browse the repository at this point in the history
…7603)

`assocFn` -> `derivativeFn`
`AssocFn` -> `DerivativeFn`
`assocFunc` -> `derivativeFunc`
`AssocFunc` -> `DerivativeFunc`
`associatedFunction` -> `derivativeFunction`
`AssociatedFunction` -> `DerivativeFunction`
`autodiff associated function` -> `derivative function`
`autodiff-associated function` -> `derivative function`
`AD associated function` -> `derivative function`
`associated differentiation function` -> `derivative function`

This is a follow-up to #27597.

Resolves [TF-882](https://bugs.swift.org/browse/TF-882).
  • Loading branch information
rxwei committed Oct 10, 2019
1 parent 0aee08a commit 431cc43
Show file tree
Hide file tree
Showing 44 changed files with 510 additions and 512 deletions.
23 changes: 11 additions & 12 deletions docs/SIL.rst
Expand Up @@ -5609,35 +5609,34 @@ differentiable_function
sil-differentiable-function-parameter-indices?
sil-differentiable-function-order?
sil-value ':' sil-type
sil-differentiable-function-associated-functions-clause?
sil-differentiable-function-derivative-functions-clause?
sil-differentiable-function-parameter-indices ::=
'[' 'wrt' [0-9]+ (',', [0-9]+)* ']'
sil-differentiable-function-order ::= '[' 'order' [0-9]+ ']'
sil-differentiable-associated-functions-clause ::=
'with' sil-differentiable-associated-function-list
(',' sil-differentiable-associated-function-list)*
sil-differentiable-function-associated-function-list ::=
sil-differentiable-derivative-functions-clause ::=
'with' sil-differentiable-derivative-function-list
(',' sil-differentiable-derivative-function-list)*
sil-differentiable-function-derivative-function-list ::=
'{' sil-value ',' sil-value '}'

differentiable_function [wrt 0] %0 : $(T) -> T \
with {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}

Bundles a function with its associated differentiation functions into a
``@differentiable`` function. There are two associated functions:
a Jacobian-vector products (JVP) function and a vector-Jacobian products (VJP)
function.
Bundles a function with its derivative functions into a ``@differentiable``
function. There are two derivative functions: a Jacobian-vector products (JVP)
function and a vector-Jacobian products (VJP) function.

``[wrt ...]`` specifies parameter indices that the original function is
differentiable with respect to. When not specified, it defaults to all
parameters.

A ``with`` clause specifies the differentiation functions associated
with the original function. When a ``with`` clause is not specified, the first
operand will be differentiated to produce associated functions, and a ``with``
operand will be differentiated to produce derivative functions, and a ``with``
clause will be added to the instruction.

In raw SIL, it is optional to provide an associated function ``with`` clause.
In raw SIL, it is optional to provide a derivative function ``with`` clause.
In canonical SIL, a ``with`` clause is mandatory.


Expand All @@ -5660,7 +5659,7 @@ differentiable_function_extract
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T

Extracts the original function or an associated function from the given
Extracts the original function or a derivative function from the given
``@differentiable`` function. It must be provided with an extractee:
``[original]``, ``[jvp]`` or ``[vjp]``.

Expand Down
8 changes: 4 additions & 4 deletions include/swift/AST/ASTMangler.h
Expand Up @@ -155,12 +155,12 @@ class ASTMangler : public Mangler {
ModuleDecl *Module);

// SWIFT_ENABLE_TENSORFLOW
// Mangle the autodiff associated function (JVP/VJP) with the given:
// Mangle the derivative function (JVP/VJP) with the given:
// - Mangled original function name.
// - Associated function kind.
// - Derivative function kind.
// - Parameter/result indices.
std::string mangleAutoDiffAssociatedFunctionHelper(
StringRef name, AutoDiffAssociatedFunctionKind kind,
std::string mangleAutoDiffDerivativeFunctionHelper(
StringRef name, AutoDiffDerivativeFunctionKind kind,
const SILAutoDiffIndices &indices);

// SWIFT_ENABLE_TENSORFLOW
Expand Down
6 changes: 3 additions & 3 deletions include/swift/AST/Attr.h
Expand Up @@ -1555,7 +1555,7 @@ class DifferentiableAttr final
AutoDiffIndexSubset *ParameterIndices = nullptr;
/// The trailing where clause (optional).
TrailingWhereClause *WhereClause = nullptr;
/// The generic signature for autodiff associated functions. Resolved by the
/// The generic signature for autodiff derivative functions. Resolved by the
/// type checker based on the original function's generic signature and the
/// attribute's where clause requirements. This is set only if the attribute
/// has a where clause.
Expand Down Expand Up @@ -1650,10 +1650,10 @@ class DifferentiableAttr final

// Print the attribute to the given stream.
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
// If `omitAssociatedFunctions` is true, omit printing associated functions.
// If `omitDerivativeFunctions` is true, omit printing derivative functions.
void print(llvm::raw_ostream &OS, const Decl *D,
bool omitWrtClause = false,
bool omitAssociatedFunctions = false) const;
bool omitDerivativeFunctions = false) const;

static bool classof(const DeclAttribute *DA) {
return DA->getKind() == DAK_Differentiable;
Expand Down
40 changes: 20 additions & 20 deletions include/swift/AST/AutoDiff.h
Expand Up @@ -431,48 +431,48 @@ struct AutoDiffLinearMapKind {
operator innerty() const { return rawValue; }
};

/// The kind of an associated function.
struct AutoDiffAssociatedFunctionKind {
/// The kind of a derivative function.
struct AutoDiffDerivativeFunctionKind {
enum innerty : uint8_t {
// The Jacobian-vector products function.
JVP = 0,
// The vector-Jacobian products function.
VJP = 1
} rawValue;

AutoDiffAssociatedFunctionKind() = default;
AutoDiffAssociatedFunctionKind(innerty rawValue) : rawValue(rawValue) {}
AutoDiffAssociatedFunctionKind(AutoDiffLinearMapKind linMapKind)
AutoDiffDerivativeFunctionKind() = default;
AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {}
AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind)
: rawValue(static_cast<innerty>(linMapKind.rawValue)) {}
explicit AutoDiffAssociatedFunctionKind(StringRef string);
explicit AutoDiffDerivativeFunctionKind(StringRef string);
operator innerty() const { return rawValue; }
AutoDiffLinearMapKind getLinearMapKind() {
return (AutoDiffLinearMapKind::innerty)rawValue;
}
};

/// In conjunction with the original function declaration, identifies an
/// autodiff associated function.
/// autodiff derivative function.
///
/// Is uniquely allocated within an ASTContext so that it can be hashed and
/// compared by opaque pointer value.
class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode {
const AutoDiffAssociatedFunctionKind kind;
class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
const AutoDiffDerivativeFunctionKind kind;
AutoDiffIndexSubset *const parameterIndices;

AutoDiffAssociatedFunctionIdentifier(
AutoDiffAssociatedFunctionKind kind,
AutoDiffDerivativeFunctionIdentifier(
AutoDiffDerivativeFunctionKind kind,
AutoDiffIndexSubset *parameterIndices) :
kind(kind), parameterIndices(parameterIndices) {}

public:
AutoDiffAssociatedFunctionKind getKind() const { return kind; }
AutoDiffDerivativeFunctionKind getKind() const { return kind; }
AutoDiffIndexSubset *getParameterIndices() const {
return parameterIndices;
}

static AutoDiffAssociatedFunctionIdentifier *get(
AutoDiffAssociatedFunctionKind kind,
static AutoDiffDerivativeFunctionIdentifier *get(
AutoDiffDerivativeFunctionKind kind,
AutoDiffIndexSubset *parameterIndices, ASTContext &C);

void Profile(llvm::FoldingSetNodeID &ID) {
Expand Down Expand Up @@ -520,15 +520,15 @@ AutoDiffIndexSubset *getLoweredParameterIndices(AutoDiffIndexSubset *indices,
/// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2_order1`.
/// Returns true if the function name is parsed successfully.
bool getBuiltinAutoDiffApplyConfig(StringRef operationName,
AutoDiffAssociatedFunctionKind &kind,
AutoDiffDerivativeFunctionKind &kind,
unsigned &arity, bool &rethrows);

/// Computes the correct linkage for an associated function given the linkage of
/// Computes the correct linkage for a derivative function given the linkage of
/// the original function. If the original linkage is not external and
/// `isAssocFnExported` is true, use the original function's linkage. Otherwise,
/// return hidden linkage.
SILLinkage getAutoDiffAssociatedFunctionLinkage(SILLinkage originalLinkage,
bool isAssocFnExported);
/// `isDerivativeFnExported` is true, use the original function's linkage.
/// Otherwise, return hidden linkage.
SILLinkage getAutoDiffDerivativeFunctionLinkage(SILLinkage originalLinkage,
bool isDerivativeFnExported);

} // end namespace autodiff

Expand Down
8 changes: 4 additions & 4 deletions include/swift/AST/DiagnosticsParse.def
Expand Up @@ -1597,15 +1597,15 @@ ERROR(sil_inst_autodiff_attr_expected_rsquare,PointsToFirstBadToken,
ERROR(sil_inst_autodiff_expected_parameter_index,PointsToFirstBadToken,
"expected the index of a parameter to differentiate with respect to", ())
ERROR(sil_inst_autodiff_operand_list_expected_lbrace,PointsToFirstBadToken,
"expected '{' to start an associated function list", ())
"expected '{' to start a derivative function list", ())
ERROR(sil_inst_autodiff_operand_list_expected_comma,PointsToFirstBadToken,
"expected ',' between operands in an associated function list", ())
"expected ',' between operands in a derivative function list", ())
ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken,
"expected '}' to start an associated function list", ())
"expected '}' to start a derivative function list", ())
ERROR(sil_inst_autodiff_num_operand_list_order_mismatch,PointsToFirstBadToken,
"the number of operand lists does not match the order", ())
ERROR(sil_inst_autodiff_expected_associated_function_kind_attr,PointsToFirstBadToken,
"expected an associated function kind attribute, e.g. '[jvp]'", ())
"expected a derivative function kind attribute, e.g. '[jvp]'", ())
ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken,
"expected an operand of a function type", ())

Expand Down
11 changes: 5 additions & 6 deletions include/swift/AST/DiagnosticsSema.def
Expand Up @@ -2859,8 +2859,7 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
ERROR(differentiable_attr_void_result,none,
"cannot differentiate void function %0", (DeclName))
ERROR(differentiable_attr_associated_function_protocol,none,
"cannot specify associated differentiation function on protocol "
"requirement", ())
"cannot specify derivative functions on protocol requirements", ())
ERROR(differentiable_attr_overload_not_found,none,
"%0 does not have expected type %1", (DeclName, Type))
ERROR(differentiable_attr_no_currying,none,
Expand All @@ -2874,17 +2873,17 @@ NOTE(differentiable_attr_duplicate_note,none,
ERROR(differentiable_attr_function_not_same_type_context,none,
"%0 is not defined in the current type context", (DeclName))
ERROR(differentiable_attr_specified_not_function,none,
"%0 is not a function to be used as associated differentiation function",
"%0 is not a function to be used as derivative function",
(DeclName))
ERROR(differentiable_attr_class_derivative_not_final,none,
"class member derivative must be final", ())
ERROR(differentiable_attr_ambiguous_function_identifier,none,
"ambiguous or overloaded identifier %0 cannot be used in '@differentiable' "
"attribute", (DeclName))
ERROR(differentiable_attr_invalid_access,none,
"associated differentiation function %0 is required to either be public "
"or @usableFromInline because the original function %1 is public or "
"@usableFromInline", (DeclName, DeclName))
"derivative function %0 is required to either be public or "
"'@usableFromInline' because the original function %1 is public or "
"'@usableFromInline'", (DeclName, DeclName))
ERROR(differentiable_attr_result_not_differentiable,none,
"can only differentiate functions with results that conform to "
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
Expand Down
18 changes: 9 additions & 9 deletions include/swift/AST/Types.h
Expand Up @@ -3100,10 +3100,10 @@ class AnyFunctionType : public TypeBase {

// SWIFT_ENABLE_TENSORFLOW
/// Given `indices` and `kind`, calculates the type of the corresponding
/// autodiff associated function.
/// autodiff derivative function.
///
/// By default, if the original type has a self parameter list and parameter
/// indices include self, the computed associated function type will return a
/// indices include self, the computed derivative function type will return a
/// linear map taking/returning self's tangent *last* instead of first, for
/// consistency with SIL.
///
Expand All @@ -3114,18 +3114,18 @@ class AnyFunctionType : public TypeBase {
/// \note The original function type (`self`) need not be `@differentiable`.
/// The resulting function will preserve all `ExtInfo` of the original
/// function, including `@differentiable`.
AnyFunctionType *getAutoDiffAssociatedFunctionType(
AnyFunctionType *getAutoDiffDerivativeFunctionType(
AutoDiffIndexSubset *indices, unsigned resultIndex,
AutoDiffAssociatedFunctionKind kind,
AutoDiffDerivativeFunctionKind kind,
LookupConformanceFn lookupConformance,
GenericSignature *whereClauseGenericSignature = nullptr,
bool makeSelfParamFirst = false);

/// Given the type of an autodiff associated function, returns the
/// Given the type of an autodiff derivative function, returns the
/// corresponding original function type.
AnyFunctionType *getAutoDiffOriginalFunctionType();

/// Given the type of a transposing associated function, returns the
/// Given the type of a transposing derivative function, returns the
/// corresponding original function type.
AnyFunctionType *
getTransposeOriginalFunctionType(TransposingAttr *attr,
Expand Down Expand Up @@ -4222,11 +4222,11 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,

/// Returns the type of a differentiation function that is associated with
/// a function of this type.
CanSILFunctionType getAutoDiffAssociatedFunctionType(
CanSILFunctionType getAutoDiffDerivativeFunctionType(
AutoDiffIndexSubset *parameterIndices, unsigned resultIndex,
AutoDiffAssociatedFunctionKind kind, Lowering::TypeConverter &TC,
AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC,
LookupConformanceFn lookupConformance,
CanGenericSignature associatedFunctionGenericSignature = nullptr);
CanGenericSignature derivativeFunctionGenericSignature = nullptr);

/// Returns a bit vector that specifices which parameters you can
/// differentiate with respect to for this differentiable function type. (e.g.
Expand Down
2 changes: 1 addition & 1 deletion include/swift/SIL/SILCloner.h
Expand Up @@ -973,7 +973,7 @@ void SILCloner<ImplClass>::visitDifferentiableFunctionInst(
Optional<std::pair<SILValue, SILValue>> derivativeFns = None;
if (Inst->hasDerivativeFunctions())
derivativeFns = std::make_pair(getOpValue(Inst->getJVPFunction()),
getOpValue(Inst->getVJPFunction()));
getOpValue(Inst->getVJPFunction()));
recordClonedInstruction(
Inst, getBuilder().createDifferentiableFunction(
getOpLocation(Inst->getLoc()), Inst->getParameterIndices(),
Expand Down

0 comments on commit 431cc43

Please sign in to comment.