Add Tapir to AD tests #63

Open
yebai opened this issue Jun 6, 2024 · 25 comments
Labels
enhancement New feature or request

Comments

@yebai
Member

yebai commented Jun 6, 2024

No description provided.

@yebai yebai changed the title "Add Tapir to tests" to "Add Tapir to AD tests" Jun 6, 2024
@Red-Portal
Member

Closely related to this, I am wondering whether DifferentiationInterface.jl could entirely replace the in-house AD interface. My main concerns would be:

  • Does it support GPUs out of the box
  • Does it support structured gradients (no flattening)

@yebai
Member Author

yebai commented Jun 6, 2024

Does it support GPUs out of the box

yes

Does it support structured gradients (no flattening)

No, and it won't.

@Red-Portal
Member

Hi @yebai, I tried adding Tapir, but it seems to share some issues with ReverseDiff's compiled tapes: Tapir appears to re-compile the target function at every gradient call. I think this is due to the fact that we re-define the target function at every step (here), which is also why we currently can't use pre-compiled tapes in ReverseDiff. We need this to support the STL estimator, which requires stopping gradients, and not all AD frameworks provide a way to do that. I'm not sure how to deal with this.

@yebai
Member Author

yebai commented Jun 13, 2024

@willtebbutt might be able to help more.

@willtebbutt

@Red-Portal I've just taken a quick look at your link to where we re-define the function each time. I agree that we'll need to abstract that out and re-use it every time if we want to use Tapir.jl with any success.

Regarding stop gradients: I think we can probably do this in Tapir.jl, but it would be great if you could open an issue about it so that we can discuss further. I think it's going to involve doing something a little bit strange.

@Red-Portal
Member

Hi all, I gave this some thought. In the current state of things, it would be possible to avoid re-defining the function. However, this will cause issues with subsampling (once we get there in the near future), because the Turing model has to be updated at each step. Unless the Turing model recorded on the tape can be mutated externally (is this possible?), we will have to redefine the objective at every step.

@Red-Portal
Member

Red-Portal commented Jun 13, 2024

Okay the following works:

using DifferentiationInterface, ReverseDiff, LinearAlgebra, Random

struct A
    data
end
       
function main()
    rng = Random.default_rng()
    a   = A(ones(10))
    f(x) = dot(a.data,x)
    println(gradient(f, AutoReverseDiff(true), ones(10)))
    a.data[:] = zeros(10)
    println(gradient(f, AutoReverseDiff(true), ones(10)))
end

Tapir, however, doesn't, though I guess this is due to a different issue? The following code:

using DifferentiationInterface, Tapir, LinearAlgebra, Random

struct A
    data
end
       
function main()
    rng = Random.default_rng()
    a   = A(ones(10))
    f(x) = dot(a.data,x)
    println(gradient(f, AutoTapir(), ones(10)))
    a.data[:] = zeros(10)
    println(gradient(f, AutoTapir(), ones(10)))
end

yields:

julia> main()
[ Info: Compiling rule for Tuple{var"#f#9"{A}, Vector{Float64}} in safe mode. Disable for best performance.
ERROR: MethodError: Cannot `convert` an object of type 
  Core.OpaqueClosure{Tuple{Any},Tuple{Union{Tapir.ZeroRData, Tapir.RData{@NamedTuple{a::Tapir.RData{@NamedTuple{data}}}}},Tapir.NoRData}} to an object of type 
  Core.OpaqueClosure{Tuple{Any},Tuple{Tapir.RData{@NamedTuple{a::Tapir.RData{@NamedTuple{data}}}},Tapir.NoRData}}

Closest candidates are:
  convert(::Type{T}, ::T) where T
   @ Base Base.jl:84

Stacktrace:
 [1] Tapir.DerivedRule{…}(fwds_oc::Function, pb_oc::Function, isva::Val{…}, nargs::Val{…})
   @ Tapir ~/.julia/packages/Tapir/7eB9t/src/interpreter/s2s_reverse_mode_ad.jl:654
 [2] build_rrule(interp::Tapir.TapirInterpreter{…}, sig::Type{…}; safety_on::Bool, silence_safety_messages::Bool)
   @ Tapir ~/.julia/packages/Tapir/7eB9t/src/interpreter/s2s_reverse_mode_ad.jl:808
 [3] build_rrule
   @ ~/.julia/packages/Tapir/7eB9t/src/interpreter/s2s_reverse_mode_ad.jl:741 [inlined]
 [4] prepare_pullback(f::var"#f#9"{A}, backend::AutoTapir, x::Vector{Float64}, dy::Float64)
   @ DifferentiationInterfaceTapirExt ~/.julia/packages/DifferentiationInterface/lN3yP/ext/DifferentiationInterfaceTapirExt/onearg.jl:8
 [5] prepare_gradient(f::var"#f#9"{A}, backend::AutoTapir, x::Vector{Float64})
   @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/lN3yP/src/first_order/gradient.jl:59
 [6] gradient(f::var"#f#9"{A}, backend::AutoTapir, x::Vector{Float64})
   @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/lN3yP/src/first_order/gradient.jl:74
 [7] main()
   @ Main ./REPL[23]:9
 [8] top-level scope
   @ REPL[24]:1
Some type information was truncated. Use `show(err)` to see complete types.

@Red-Portal
Member

Red-Portal commented Jun 13, 2024

To be honest, I would prefer to keep the forward pass as immutable as possible, but that won't be possible for subsampling unless we redefine the target function at every step.

@willtebbutt

willtebbutt commented Jun 14, 2024

@Red-Portal this is a really interesting failure case (I was aware that this could happen, but hadn't figured out a concrete case in which it would). I'm going to create a unit test out of it and fix it in Tapir.

In the meantime, I find that making A parametric fixes it locally. Something like:

struct A{T}
    data::T
end
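For anyone wondering what the parametric version actually changes: with the original `struct A; data; end`, the field is declared as `Any`, so a closure capturing an `A` says nothing about the data's type, whereas with `A{T}` the field type is part of the struct's type. The sketch below just inspects that difference in plain Julia. The names `AUntyped` and `ATyped` are illustrative, and the idea that this abstract field type is what produces the `Union{Tapir.ZeroRData, ...}` in the error above is a plausible reading, not something confirmed in this thread.

```julia
# Illustrative only: compare an untyped field with a parametric one.
struct AUntyped
    data            # declared type is Any (abstract)
end

struct ATyped{T}
    data::T         # declared type is fixed by the type parameter
end

a_untyped = AUntyped(ones(10))
a_typed   = ATyped(ones(10))   # T is inferred as Vector{Float64}

# The untyped field's declared type is Any, which is not concrete.
println(fieldtype(AUntyped, :data))                     # Any
println(isconcretetype(fieldtype(AUntyped, :data)))     # false

# The parametric field's type is carried in the struct's own type.
println(fieldtype(typeof(a_typed), :data))              # Vector{Float64}
println(isconcretetype(fieldtype(typeof(a_typed), :data)))  # true

# Both structs are immutable, but the array they hold can still be mutated
# in place, as in the `a.data[:] = zeros(10)` line of the example above.
a_typed.data[:] = zeros(10)
```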

@Red-Portal
Member

Red-Portal commented Jun 14, 2024

@willtebbutt Could we just have an interface for differentiating f(x, data) with respect to x, receiving both x and data at each evaluation? It's such a common use case for anything data-related, and I think it would massively simplify the use of any precompilation-based AD.

@willtebbutt

You can definitely do this using Tapir.jl's value_and_gradient!! interface; the limitation in the use case above is DifferentiationInterface.jl rather than Tapir.jl. There's a discussion about this here.

@yebai
Member Author

yebai commented Jun 14, 2024

@willtebbutt, it might be easier if you could create a working example so @Red-Portal can adapt it.

@willtebbutt

willtebbutt commented Jun 14, 2024

Sure. Something like this:

using Tapir, LinearAlgebra

f(x, data) = dot(data, x)
x = randn(10)
data = randn(10)

rule = Tapir.build_rrule(f, x, data)
Tapir.value_and_gradient!!(rule, f, x, data)

should do the trick.

You should re-use rule each time you run Tapir.value_and_gradient!!.

Note that the inputs don't need to be the same size each time, only the same type; Tapir.jl is a little less restrictive than ReverseDiff in this regard.

@yebai
Member Author

yebai commented Jun 14, 2024

Just to clarify, Tapir is a source-to-source transformation-based AD. It is like Zygote but addresses many of its design limitations. So Tapir works well with data- and input-dependent control flow. It even works with global variables, as long as they are not mutated.

As @willtebbutt mentioned above, the only assumption is that the argument types stay the same as those used to build the rule. When the argument types change, a new rule should be built by calling Tapir.build_rrule.
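One way to follow this without rebuilding a rule on every step is to cache rules keyed on the argument types, so that a rule is built once per type signature and reused for inputs of any size. Below is a plain-Julia sketch of that caching logic only. `build_rule`, `RuleCache`, and `get_rule!` are hypothetical names; `build_rule` is a stand-in for an expensive constructor such as `Tapir.build_rrule` and merely counts how often it is called, so the sketch runs without Tapir installed.

```julia
# Hypothetical stand-in for an expensive rule constructor (e.g. Tapir.build_rrule).
# It only records how many times a rule was built.
const BUILD_COUNT = Ref(0)

function build_rule(f, args...)
    BUILD_COUNT[] += 1
    return (f, map(typeof, args))   # placeholder "rule"
end

# Cache of rules keyed on the function type plus the argument types.
struct RuleCache
    rules::Dict{Any,Any}
end
RuleCache() = RuleCache(Dict{Any,Any}())

# Return the cached rule for these argument types, building one only on a miss.
function get_rule!(cache::RuleCache, f, args...)
    key = (typeof(f), map(typeof, args)...)
    return get!(() -> build_rule(f, args...), cache.rules, key)
end

loss(x, data) = sum(x .* data)

cache = RuleCache()
get_rule!(cache, loss, randn(10), randn(10))                    # miss: builds a rule
get_rule!(cache, loss, randn(25), randn(25))                    # same types, new sizes: reused
get_rule!(cache, loss, randn(Float32, 10), randn(Float32, 10))  # new types: builds again
println(BUILD_COUNT[])   # 2
```

Keying on types rather than values mirrors the constraint described above: sizes and contents can change between steps, but a change in any argument's type needs a fresh rule.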

@Red-Portal
Member

Red-Portal commented Jun 14, 2024

Okay, this sounds much more promising. I'll come back to this once the rng issue is resolved. In the meantime I'll restructure things such that we don't have to redefine the target function. Thanks for the pointers!

@wsmoses

wsmoses commented Jun 18, 2024

Might as well add Enzyme too while you're at it. I don't foresee any issues with the things mentioned above, and it does support RNGs, abstract types, and non-differentiated data.

@Red-Portal
Member

Red-Portal commented Jun 18, 2024

@wsmoses Enzyme has been usable for a while since this PR, just not tested against. The problem is that Enzyme really doesn't play nicely with Distributions.jl, so it's pretty much unusable at the moment. Even the most basic models result in immediate segfaults. From my understanding, that is partially why Tapir.jl is being worked on (correct, @yebai @willtebbutt?).

@wsmoses

wsmoses commented Jun 18, 2024

What about Distributions.jl does it not play well with? I'm not aware of any related outstanding issues on Enzyme.jl.

@wsmoses

wsmoses commented Jun 18, 2024

But yeah if you have any problems please open issues and we'll work fast to get them resolved!

@wsmoses

wsmoses commented Jun 18, 2024

And while I can't speak to why the folks started work on the Taped AD tool, @yebai et al. have a project funded for the next three years to add Enzyme to Turing, so it is intended to be well supported (substantial speedups have already been shown, but more work is needed on the integration end). https://www.turing.ac.uk/research/research-projects/development-composable-parallelisable-and-user-friendly-inference-and

Again happy to quickly fix any issues that you see :)

@Red-Portal
Member

I'll take that as a reminder and try it again. Thanks!

@yebai
Member Author

yebai commented Jun 19, 2024

Yes, it would be good to test Enzyme, too.

Tapir is a less ambitious project than Enzyme, initially focusing on a rewrite of Zygote. It has strengths and weaknesses: it's written entirely in Julia, with a much smaller and more hackable code base, but its performance will lag slightly behind Enzyme's for the foreseeable future.

@Red-Portal Red-Portal added the enhancement New feature or request label Jun 21, 2024
@willtebbutt

@Red-Portal your examples (both the abstractly-typed data field and using TaskLocalRNGs) should now both work on version 0.2.23 of Tapir.jl (the latest release).

@Red-Portal
Member

Hi @willtebbutt, thanks for the fixes! It seems there are still some issues left, though, as can be seen in #68. I currently don't have enough time to dig deeper into this (i.e., distill the issues into smaller MWEs), but let me know if there is anything I can help fix right away.

@Red-Portal
Member

Update: I made a new PR (#71) because the previous one had a messed-up git commit history.


4 participants