Skip to content

Commit

Permalink
[Serialization] Implement serialization for @differentiable attribute. (
Browse files Browse the repository at this point in the history
#17155)

Implement (de)serialization for all components of `@differentiable` attribute
except the trailing where clause (which needs to be type-checked).

This is a necessary step for the `#adjoint` expression to look up
`@differentiable` attributes declared on functions in other modules correctly.

Addresses SR-7977.
  • Loading branch information
dan-zheng committed Jun 13, 2018
1 parent 3f68d50 commit d2bf0dc
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 7 deletions.
4 changes: 1 addition & 3 deletions include/swift/AST/Attr.def
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,8 @@ SIMPLE_DECL_ATTR(_forbidSerializingReference, ForbidSerializingReference,
77)

// SWIFT_ENABLE_TENSORFLOW
// FIXME: Make it serialized
DECL_ATTR(differentiable, Differentiable,
OnFunc | LongAttribute | NotSerialized,
/* Not serialized */ 78)
OnFunc | LongAttribute, 78)

SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
OnFunc | OnConstructor, /* Not serialized */ 79)
Expand Down
13 changes: 11 additions & 2 deletions include/swift/Serialization/ModuleFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -1447,8 +1447,6 @@ namespace decls_block {
= BCRecordLayout<RestatedObjCConformance_DECL_ATTR>;
using ClangImporterSynthesizedTypeDeclAttrLayout
= BCRecordLayout<ClangImporterSynthesizedType_DECL_ATTR>;
// SWIFT_ENABLE_TENSORFLOW
using DifferentiableDeclAttrLayout = BCRecordLayout<Differentiable_DECL_ATTR>;

using InlineDeclAttrLayout = BCRecordLayout<
Inline_DECL_ATTR,
Expand Down Expand Up @@ -1505,6 +1503,17 @@ namespace decls_block {
BCFixed<1> // specialization kind
>;

// SWIFT_ENABLE_TENSORFLOW
using DifferentiableDeclAttrLayout = BCRecordLayout<
Differentiable_DECL_ATTR,
BCFixed<1>, // Differentiation mode ('forward' or 'reverse').
IdentifierIDField, // Primal name.
DeclIDField, // Primal function declaration.
IdentifierIDField, // Adjoint name.
DeclIDField, // Adjoint function declaration.
BCArray<BCFixed<32>> // Differentiation parameters.
>;

#define SIMPLE_DECL_ATTR(X, CLASS, ...) \
using CLASS##DeclAttrLayout = BCRecordLayout< \
CLASS##_DECL_ATTR, \
Expand Down
46 changes: 46 additions & 0 deletions lib/Serialization/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2506,6 +2506,52 @@ ModuleFile::getDeclCheckedImpl(DeclID DID, Optional<DeclContext *> ForcedContext
break;
}

// SWIFT_ENABLE_TENSORFLOW
case decls_block::Differentiable_DECL_ATTR: {
AutoDiffMode autodiffMode = AutoDiffMode::Reverse;
unsigned autodiffModeValue;
uint64_t primalNameId;
DeclID primalDeclId;
uint64_t adjointNameId;
DeclID adjointDeclId;
ArrayRef<uint64_t> paramValues;

serialization::decls_block::DifferentiableDeclAttrLayout::readRecord(
scratch, autodiffModeValue, primalNameId, primalDeclId, adjointNameId,
adjointDeclId, paramValues);
autodiffMode = autodiffModeValue
? AutoDiffMode::Reverse
: AutoDiffMode::Forward;

using FuncSpecifier = DifferentiableAttr::FunctionSpecifier;
Optional<FuncSpecifier> primal;
FuncDecl *primalDecl = nullptr;
if (primalNameId != 0 && primalDeclId != 0) {
primal = { getIdentifier(primalNameId), DeclNameLoc() };
primalDecl = cast<FuncDecl>(getDecl(primalDeclId));
}
FuncSpecifier adjoint = { getIdentifier(adjointNameId), DeclNameLoc() };
FuncDecl *adjointDecl = cast<FuncDecl>(getDecl(adjointDeclId));

SmallVector<AutoDiffParameter, 4> parameters;
SourceLoc loc;
for (auto paramValue : paramValues) {
auto parameter = paramValue & 0x01
? AutoDiffParameter::getSelfParameter(loc)
: AutoDiffParameter::getIndexParameter(loc, paramValue >> 1);
parameters.push_back(parameter);
}
// TODO: Deserialize trailing where clause.
auto diffAttr =
DifferentiableAttr::create(ctx, loc, SourceRange(), autodiffMode,
loc, parameters, primal, adjoint,
/*TrailingWhereClause*/ nullptr);
diffAttr->setPrimalFunction(primalDecl);
diffAttr->setAdjointFunction(adjointDecl);
Attr = diffAttr;
break;
}

#define SIMPLE_DECL_ATTR(NAME, CLASS, ...) \
case decls_block::CLASS##_DECL_ATTR: { \
bool isImplicit; \
Expand Down
39 changes: 37 additions & 2 deletions lib/Serialization/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2174,8 +2174,6 @@ void Serializer::writeDeclAttribute(const DeclAttribute *DA) {
case DAK_ObjCRuntimeName:
case DAK_RestatedObjCConformance:
case DAK_ClangImporterSynthesizedType:
// SWIFT_ENABLE_TENSORFLOW
case DAK_Differentiable:
llvm_unreachable("cannot serialize attribute");

case DAK_Count:
Expand Down Expand Up @@ -2333,6 +2331,43 @@ void Serializer::writeDeclAttribute(const DeclAttribute *DA) {
writeGenericRequirements(SA->getRequirements(), DeclTypeAbbrCodes);
return;
}

// SWIFT_ENABLE_TENSORFLOW
case DAK_Differentiable: {
auto abbrCode = DeclTypeAbbrCodes[DifferentiableDeclAttrLayout::Code];
auto attr = cast<DifferentiableAttr>(DA);

IdentifierID primalName = 0;
DeclID primalRef = 0;
if (auto primal = attr->getPrimal()) {
primalName = addDeclBaseNameRef(primal->Name.getBaseName());
primalRef = addDeclRef(attr->getPrimalFunction());
}
auto adjointName = addDeclBaseNameRef(attr->getAdjoint().Name.getBaseName());
auto adjointRef = addDeclRef(attr->getAdjointFunction());

SmallVector<uint32_t, 4> parameters;
for (auto param : attr->getParameters()) {
switch (param.getKind()) {
// The self parameter is uniquely identified by 0x01.
case AutoDiffParameter::Kind::Self:
parameters.push_back(1);
break;
// Index parameters are left-shifted by 1.
case AutoDiffParameter::Kind::Index:
parameters.push_back(param.getIndex() << 1);
break;
}
}

DifferentiableDeclAttrLayout::emitRecord(
Out, ScratchRecord, abbrCode, (unsigned) attr->getMode(), primalName,
primalRef, adjointName, adjointRef, parameters);
// TODO: Serialize trailing where clause.
// Type-checking where clause should be done first (mimicking the
// @_specialize attribute).
return;
}
}
}

Expand Down
68 changes: 68 additions & 0 deletions test/Serialization/differentiable_attr.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// SWIFT_ENABLE_TENSORFLOW
// TODO: Handle trailing where clause in @differentiable attribute.

// RUN: %empty-directory(%t)
// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
// RUN: %target-sil-opt -disable-sil-linking -enable-sil-verify-all %t/differentiable_attr.swiftmodule -o - | %FileCheck %s

struct CheckpointsFoo {}
func pfoo(_ x: Float) -> (checkpoints: CheckpointsFoo, originalValue: Float) {
return (CheckpointsFoo(), x * x)
}
func dfoo_checkpointed(_ x: Float, checkpoints: CheckpointsFoo, originalValue: Float, seed: Float) -> Float {
return 2 * x
}
// CHECK-DAG: @differentiable(reverse, primal: pfoo, adjoint: dfoo_checkpointed)
// CHECK-DAG: func foo_checkpointed(_ x: Float) -> Float
@differentiable(reverse, primal: pfoo(_:), adjoint: dfoo_checkpointed(_:checkpoints:originalValue:seed:))
func foo_checkpointed(_ x: Float) -> Float {
return x * x
}

struct S<T> {
struct Checkpoints {
let s: S
}
func primal(x: Float) -> (Checkpoints, Float) {
return (Checkpoints(s: self), x)
}
func adjoint_checkpointed(x: Float, _: Checkpoints, _: Float, _: Float) -> S {
return self
}

// CHECK-DAG: @differentiable(reverse, (self), primal: primal, adjoint: adjoint_checkpointed)
// CHECK-DAG: func original(x: Float) -> Float
@differentiable(reverse, withRespectTo: (self), primal: primal, adjoint: adjoint_checkpointed)
func original(x: Float) -> Float {
return x
}
}

func pbaz1<T>(_ x: T, _ y: T) -> ((T, T), T) {
return ((y, y), x)
}
func dbaz1_checkpointed<T>(_ x: T, _ y: T, primal: (T, T), originalValue: T, seed: T) -> (T, T) {
return (y, x)
}
// CHECK-DAG: @differentiable(reverse, primal: pbaz1, adjoint: dbaz1_checkpointed)
// CHECK-DAG: func baz1_checkpointed<T>(_ x: T, _ y: T) -> T
@differentiable(reverse, primal: pbaz1(_:_:), adjoint: dbaz1_checkpointed(_:_:primal:originalValue:seed:))
func baz1_checkpointed<T>(_ x: T, _ y: T) -> T {
return x
}

struct CheckpointsFP<T : FloatingPoint> {
let meow: T
}
func pbaz2<T : FloatingPoint>(_ x: T, _ y: T) -> (CheckpointsFP<T>, T) {
return (CheckpointsFP(meow: 1), x + y)
}
func dbaz2_checkpointed<T : FloatingPoint>(_ x: T, _ y: T, primal: CheckpointsFP<T>, originalValue: T, seed: T) -> (T, T) {
return (1, 1)
}
// CHECK-DAG: @differentiable(reverse, primal: pbaz2, adjoint: dbaz2_checkpointed)
// CHECK-DAG: func baz2_checkpointed<T>(_ x: T, _ y: T) -> T where T : FloatingPoint
@differentiable(reverse, primal: pbaz2(_:_:), adjoint: dbaz2_checkpointed(_:_:primal:originalValue:seed:))
func baz2_checkpointed<T : FloatingPoint>(_ x: T, _ y: T) -> T {
return x
}

0 comments on commit d2bf0dc

Please sign in to comment.