New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AD] Serialize/type-check differentiable attribute trailing where clauses. #21675
[AD] Serialize/type-check differentiable attribute trailing where clauses. #21675
Conversation
…uses. - Store where clause `Requirement`s in `DifferentiableAttr` and `SILDifferentiableAttr`. - Implement serialization/deserialization logic for where clause requirements. - Implement type-checking for where clauses. - The primal/adjoint/JVP/VJP generic signature is expected to be the union of the original function generic signature and the where clause generic signature. Where clauses enable the definition of autodiff associated functions that are more constrained than the original function. This is useful for `<Scalar : Numeric>` operations, for example, which are differentiable only `where Scalar : FloatingPoint`. TODOs: - Round trip serialization doesn't quite work: where clauses don't show up in `test/Serialization/differentiable_attr.swift`. - Where clauses with member type constraints `e.g. <Self == Self.CotangentVector>` type-check but crash at serialization.
Revert changes to ParseSIL.cpp (namely making `convertRequirements` public).
c45d869
to
0720d41
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good but I didn’t look super carefully. The general direction makes sense. Let’s merge and progress towards JVPs-in-stdlib.
Previously, `DifferentiableAttr` manually implemented some "trailing objects" behavior, but the pointer arithmetic was incorrect. Now switched to `TrailingObjects`. Fix where clause serialization tests. All tests pass except those involves member type constraints.
0720d41
to
e42f50b
Compare
@swift-ci Please test tensorflow |
@swift-ci Please test tensorflow |
…uses. (apple#21675) - Store where clause `Requirement`s in `DifferentiableAttr` and `SILDifferentiableAttr`. - Reimplement `DifferentiableAttr` using `TrailingObjects`, fix tests. - Implement serialization/deserialization logic for where clause requirements. - Implement type-checking for where clauses. - The primal/adjoint/JVP/VJP generic signature is expected to be the union of the original function generic signature and the where clause generic signature. Where clauses enable the definition of autodiff associated functions that are more constrained than the original function. This is useful for `<Scalar : Numeric>` operations, for example, which are differentiable only `where Scalar : FloatingPoint`. Todos: - Fix where clauses with member type constraints `e.g. <Self == Self.CotangentVector>`, which type-check but crash at serialization. - Add `@differentiable` attribute where clauses to relevant functions in stdlib.
I haven't tracked down the root causes here, but these are some programs that crash the compiler (my guess is I'm doing something basic incorrectly; if you can't reproduce this it could also be an artifact of other patches I'm carrying): @differentiable(adjoint: adjointFoo
where T: Differentiable, T == T.CotangentVector)
func foo<T: Numeric>(_ x: T) -> T {
return x
}
func adjointFoo<T: Differentiable & Numeric>(
_ seed: T, _ originalValue: T, _ dy: T
) -> T where T == T.CotangentVector {
return dy
}
let x = 1.0
let pb = pullback(at: x, in: foo)
print(pb(1.0)) crashes in mandatory inlining:
and @differentiable(adjoint: adjointFoo
where T: Differentiable, T == T.CotangentVector)
func foo<T: Numeric>(_ x: Tensor<T>) -> Tensor<T> {
return x
}
func adjointFoo<T: Differentiable & Numeric>(
_ seed: Tensor<T>, _ originalValue: Tensor<T>, _ dy: Tensor<T>
) -> Tensor<T> where T == T.CotangentVector {
return dy
}
let x = Tensor<Float>(1.0)
let pb = pullback(at: x, in: foo)
print(pb(Tensor<Float>(1.0))) crashes in IRGen:
|
@jekbradbury does |
Requirement
s inDifferentiableAttr
andSILDifferentiableAttr
.original function generic signature and the where clause generic signature.
Where clauses enable the definition of autodiff associated functions that are more
constrained than the original function. This is useful for
<Scalar : Numeric>
operations, for example, which are differentiable only
where Scalar : FloatingPoint
.TODOs:
e.g. <Self == Self.CotangentVector>
,which type-check but crash at serialization.
where
clauses to@differentiable
attributes in stdlib. (I've started this.)