User-facing API like vmap #117

Open
jessebett opened this issue Jul 29, 2020 · 11 comments

Comments

@jessebett

@ChrisRackauckas suggests that this package provides many of the utilities needed to make broadcasting over specified axes efficient, as can be seen in DiffEqGPU.jl.

Can we discuss a user-facing API so we can compare directly against JAX's vmap?

For instance, if I have a function

f(x::Scalar, y::Vector, A::Array) = linalg...

How can I efficiently broadcast over collections of inputs stored in containers with axes, such as multidimensional arrays ("tensors")?

# Broadcast over rows of second argument
vmap(f, in_axes=(nothing, 1, nothing))(scalar, array, array)

# Broadcast over axes for all arguments
vmap(f, in_axes=(1, 1, 3))(vector, array, tensor)

Further, is it possible to provide these as defaults for something like eachslice so that broadcasting Just Works?

f.(scalar, eachrow(array), array)
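
For reference, a plain-broadcast baseline for the in_axes = (nothing, 1, nothing) case above (a sketch using the same placeholder names; note that Ref(array) keeps the whole-array argument from being broadcast elementwise):

# Map f over the rows of array, passing scalar and the whole array unchanged.
results = f.(scalar, eachrow(array), Ref(array))
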
@vchuravy
Member

using KernelAbstractions  # provides @kernel, @index, CPU(), etc.

function vmap!(f, out, args...)
    # Elementwise kernel: read the I-th element of each argument and apply f.
    @kernel function kernelF(f::F, out, args...) where F
        I = @index(Global, Linear)
        elems = ntuple(Val(length(args))) do i
            Base.@_inline_meta
            @inbounds args[i][I]
        end
        @inbounds out[I] = f(elems...)
    end

    # Dispatch on the output array type to pick the device
    # (the GPU branch is left commented out here).
    if out isa Array
        device = CPU()
    # elseif out isa CuArray
    #     device = CUDADevice()
    end

    # Instantiate with a workgroup size of 256 and launch over size(out).
    kernel = kernelF(device, (256,))
    event = kernel(f, out, args..., ndrange=size(out))
    wait(event)
    return out
end
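
For context, a minimal usage sketch of this vmap! on CPU arrays (the inputs below are made up for illustration):

x, y = rand(Float32, 1024), rand(Float32, 1024)
out = similar(x)
vmap!(+, out, x, y)    # elementwise x .+ y via a single kernel launch
out ≈ x .+ y           # true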

@chriselrod

What would be a good API for fusing matmuls and dots with vmap!, like in this post?

@mcabbott

mcabbott commented Jul 30, 2020

This allows vmap!(sqrt, cu(rand(3)), cu(rand(1:10,3).^2)), but things like vmap!(sum, rand(3), [rand(Int8,4) for _ in 1:3]) will only work on the CPU, right?

I never thought to try, but kernels which take slices of larger arrays seem to work fine on the CPU. Somehow the linked DiffEqGPU.jl source is able to do this on the GPU too; what's the trick? Edit: mcabbott/Tullio.jl#20

@darsnack

I posted my long-form comment on Discourse, but here are the parts relevant to this discussion:

  • I think that to get the performance of fused BLAS operations, the user has to specify f appropriately for vmap?
  • Unlike jax.vmap, this doesn't have any "depth," right? It can't do anything about what's inside f (see the sketch after this list). Ultimately, I think this would require a custom compiler pass or a significant change to function-calling behavior to pass the axis info all the way through the call stack (bad idea imo).
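
To make the "depth" point concrete, here is a hedged sketch (g, W, and X are made-up names): mapping over slices calls g once per column, i.e. many small matrix-vector products, whereas a transform with depth would rewrite the product inside g so the whole batch becomes a single matrix-matrix multiply.

W = rand(Float32, 8, 16)
X = rand(Float32, 16, 100)     # a batch of 100 vectors stored as columns

g(x) = sum(W * x)

# "No depth": apply g slice by slice — 100 small mat-vec products.
shallow = map(g, eachcol(X))

# What a deeper transform could do: lift the batch into the op inside g,
# turning the 100 mat-vecs into one mat-mat product.
deep = vec(sum(W * X, dims=1))

shallow ≈ deep                 # true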

@ChrisRackauckas
Member

KA kernels are all about changing what's inside of f to build an SPMD function? I'm not sure what you mean by "depth" in this context.

@darsnack

Something like this, where "depth" refers to the depth in the call stack. Does KA do this? That's awesome!

@ToucheSir

I believe "depth" here refers to linear-algebra op/kernel fusion and (where applicable) reordering/lifting of operations.

@vchuravy
Member

vchuravy commented Jul 31, 2020 via email

@ToucheSir

ToucheSir commented Aug 25, 2020

Apologies for reviving this thread with yet more questions, but what would be the most appropriate place to define such a DSL (if indeed one exists at all)? In Python land one would likely pick up XLA or TVM, but such monolithic frameworks seem like a poor fit given that all of CUDA, GPUCompiler, KernelAbstractions and Dagger(GPU) exist.

@vchuravy
Member

I think the design space is still quite open. Tullio.jl is something like that for fans of Einstein notation. I have my own playground where I explore infrastructure ideas. I might also be convinced that it is a value add for KA, but in general we have orthogonal packages in Julia.
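
For a flavour of the Tullio.jl direction mentioned above (a sketch with arbitrary names): a batched dot product written in Einstein notation, where the repeated index i is summed and b indexes the batch.

using Tullio

A = rand(100, 16)              # 100 vectors stored as rows
v = rand(16)

# y[b] = sum over i of A[b, i] * v[i] — the batch loop over b is implicit.
@tullio y[b] := A[b, i] * v[i]

y ≈ A * v                      # true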

@aterenin

I wanted to write a short message to say that there is definitely user demand for some flavor of vmap.

There are at least two reasons vmap is interesting.

  1. Improved performance. For instance, fusing matrix-vector multiplications into a single matrix-matrix multiplication. For more general cases where no such fusion is possible, one can still run the loop in parallel by, say, launching multiple CUDA kernels simultaneously rather than one at a time (a sketch of this fallback follows this list).
  2. Improved syntax, code readability, and user experience. I have found that vmap significantly reduces the amount of unreadable batching code and boilerplate one needs to write, and in fact this is one of the main reasons to have vmap in the first place. In this sense, vmap could be considered to perform a similar function to TensorCast.jl, which more or less provides different ways of expressing existing functionality rather than adding new functionality.
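
As a minimal CPU-side sketch of the non-fused fallback in point 1 (names are placeholders; on the GPU one would instead launch kernels concurrently, e.g. on separate streams, rather than use threads):

# Run the batch loop in parallel when the per-element work cannot be fused.
function batched_apply(f, xs)
    out = Vector{typeof(f(first(xs)))}(undef, length(xs))
    Threads.@threads for i in eachindex(xs)
        out[i] = f(xs[i])
    end
    return out
end

batched_apply(x -> sum(abs2, x), [rand(32) for _ in 1:100])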

I have not seen point 2 discussed much myself, and would like to add that I believe there is great value here from the users' perspective, particularly for those who are either newer to the language or aren't interested in getting into too many details. JAX's main difficulty, from my perspective, is the significant quantity of boilerplate the surrounding ecosystem generates (think repeating Haiku's hk.get_parameter hundreds of times, and similar). Zygote/Flux do much better in general, but not in all cases. vmap is one of the things JAX gets very right, and I think the Julia ecosystem would benefit from a well-thought-out analogue.
