Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AD] Serialize/type-check differentiable attribute trailing where clauses. #21675

Merged

Conversation

dan-zheng
Copy link
Collaborator

@dan-zheng dan-zheng commented Jan 7, 2019

  • Store where clause Requirements in DifferentiableAttr and SILDifferentiableAttr.
  • Implement serialization/deserialization logic for where clause requirements.
  • Implement type-checking for where clauses.
    • The primal/adjoint/JVP/VJP generic signature is expected to be the union of the
      original function generic signature and the where clause generic signature.

Where clauses enable the definition of autodiff associated functions that are more
constrained than the original function. This is useful for <Scalar : Numeric>
operations, for example, which are differentiable only where Scalar : FloatingPoint.

TODOs:

  • Fix where clauses with member type constraints e.g. <Self == Self.CotangentVector>,
    which type-check but crash at serialization.
  • Add where clauses to @differentiable attributes in stdlib. (I've started this.)

…uses.

- Store where clause `Requirement`s in `DifferentiableAttr` and `SILDifferentiableAttr`.
- Implement serialization/deserialization logic for where clause requirements.
- Implement type-checking for where clauses.
  - The primal/adjoint/JVP/VJP generic signature is expected to be the union of the
    original function generic signature and the where clause generic signature.

Where clauses enable the definition of autodiff associated functions that are more
constrained than the original function. This is useful for `<Scalar : Numeric>`
operations, for example, which are differentiable only `where Scalar : FloatingPoint`.

TODOs:
- Round trip serialization doesn't quite work: where clauses don't show up in
  `test/Serialization/differentiable_attr.swift`.
- Where clauses with member type constraints `e.g. <Self == Self.CotangentVector>`
  type-check but crash at serialization.
@dan-zheng dan-zheng added the tensorflow This is for "tensorflow" branch PRs. label Jan 7, 2019
@dan-zheng dan-zheng requested a review from rxwei January 7, 2019 10:33
Revert changes to ParseSIL.cpp (namely making `convertRequirements` public).
@dan-zheng dan-zheng force-pushed the differentiable-attr-where-clause branch from c45d869 to 0720d41 Compare January 7, 2019 11:13
Copy link
Member

@rxwei rxwei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good but I didn’t look super carefully. The general direction makes sense. Let’s merge and progress towards JVPs-in-stdlib.

lib/AST/TypeCheckRequests.cpp Outdated Show resolved Hide resolved
lib/AST/TypeCheckRequests.cpp Show resolved Hide resolved
Previously, `DifferentiableAttr` manually implemented some "trailing objects"
behavior, but the pointer arithmetic was incorrect. Now switched to
`TrailingObjects`.

Fix where clause serialization tests. All tests pass except those involves
member type constraints.
@dan-zheng dan-zheng force-pushed the differentiable-attr-where-clause branch from 0720d41 to e42f50b Compare January 7, 2019 16:11
@dan-zheng
Copy link
Collaborator Author

@swift-ci Please test tensorflow

@dan-zheng
Copy link
Collaborator Author

@swift-ci Please test tensorflow

@dan-zheng dan-zheng merged commit 8152b78 into apple:tensorflow Jan 7, 2019
jekbradbury pushed a commit to jekbradbury/swift that referenced this pull request Jan 8, 2019
…uses. (apple#21675)

- Store where clause `Requirement`s in `DifferentiableAttr` and `SILDifferentiableAttr`.
- Reimplement `DifferentiableAttr` using `TrailingObjects`, fix tests.
- Implement serialization/deserialization logic for where clause requirements.
- Implement type-checking for where clauses.
  - The primal/adjoint/JVP/VJP generic signature is expected to be the union of the
    original function generic signature and the where clause generic signature.

Where clauses enable the definition of autodiff associated functions that are more
constrained than the original function. This is useful for `<Scalar : Numeric>`
operations, for example, which are differentiable only `where Scalar : FloatingPoint`.

Todos:
- Fix where clauses with member type constraints `e.g. <Self == Self.CotangentVector>`,
  which type-check but crash at serialization.
- Add `@differentiable` attribute where clauses to relevant functions in stdlib.
@jekbradbury
Copy link
Collaborator

I haven't tracked down the root causes here, but these are some programs that crash the compiler (my guess is I'm doing something basic incorrectly; if you can't reproduce this it could also be an artifact of other patches I'm carrying):

@differentiable(adjoint: adjointFoo
                where T: Differentiable, T == T.CotangentVector)
func foo<T: Numeric>(_ x: T) -> T {
  return x
}

func adjointFoo<T: Differentiable & Numeric>(
    _ seed: T, _ originalValue: T, _ dy: T
) -> T where T == T.CotangentVector {
  return dy
}

let x = 1.0
let pb = pullback(at: x, in: foo)
print(pb(1.0))

crashes in mandatory inlining:

Cannot construct Inlined loc from the given location.
UNREACHABLE executed at /usr/local/google/home/jekbradbury/swift-sources/swift/lib/SIL/SILLocation.cpp:220!
Stack dump:
0.      [...]
1.      While running pass #60 SILModuleTransform "MandatoryInlining".

and

@differentiable(adjoint: adjointFoo
                where T: Differentiable, T == T.CotangentVector)
func foo<T: Numeric>(_ x: Tensor<T>) -> Tensor<T> {
  return x
}

func adjointFoo<T: Differentiable & Numeric>(
    _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ dy: Tensor<T>
) -> Tensor<T> where T == T.CotangentVector {
  return dy
}

let x = Tensor<Float>(1.0)
let pb = pullback(at: x, in: foo)
print(pb(Tensor<Float>(1.0)))

crashes in IRGen:

swift: /usr/local/google/home/jekbradbury/swift-sources/llvm/include/llvm/ADT/Optional.h:176: T *llvm::Optional<swift::ProtocolConformanceRef>::getPointer() [T = swift::ProtocolConformanceRef]: Assertion `Storage.hasVal' failed.
Stack dump:
0.      [...]
1.      While emitting IR SIL function "@main".

@rxwei
Copy link
Member

rxwei commented Jan 8, 2019

@jekbradbury does vjp: work?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tensorflow This is for "tensorflow" branch PRs.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants