
Conversation

maleadt commented Jan 18, 2021

This is the next big step (after JuliaGPU/CUDAnative.jl#609 and #395) in marrying Julia's task-based parallelism with the CUDA APIs. By managing our own default stream, and using that instead of CUDA's default stream objects, we can ensure that all operations (kernel launches, library calls, etc) happen using the stream that's active for the current task. That should make it much easier to isolate independent operations. For example:

julia> using CUDA, LinearAlgebra

# operation that both executes a kernel and calls a library function
julia> run(a,b,c) = (mul!(c, a, b); broadcast!(sin, c, c))

julia> function main(N=1024)
           a = CUDA.rand(N,N)
           b = CUDA.rand(N,N)
           c = CUDA.rand(N,N)
           NVTX.@range "warmup" run(a,b,c)
           synchronize()
           NVTX.@range "global" x = run(a,b,c)
           NVTX.@range "local" y = CUDA.stream!(CuStream()) do
               run(a,b,c)
           end
           synchronize()
           x == y
       end

[profiler trace: the "warmup", "global" and "local" ranges, with the global- and local-stream work executing concurrently]

As you can see, the operations in the global stream are independent from the ones executed in a local stream context 🎉 Here, that results in overlapping execution, which is great for performance.

Remains to be done/decided:

  • should synchronize now default to synchronizing just the current stream, or should it remain a device-wide sync? (see the sketch below)
  • we can further improve this by using CUDA 11.2's async allocation features.
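
For reference, a rough sketch of the distinction (synchronize(::CuStream) already exists; the open question is only what the no-argument form should mean):

using CUDA

s = CuStream()
CUDA.stream!(s) do
    @cuda identity(nothing)   # work is queued on the task's current stream, i.e. s
end

synchronize(s)   # waits only for the work submitted to s
synchronize()    # currently a device-wide sync that waits for all streams;
                 # the question above is whether it should only sync the current task's stream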

maleadt added the enhancement, cuda kernels, and cuda libraries labels on Jan 18, 2021

maleadt commented Jan 18, 2021

Changing the stream on a task now updates the active task-bound handles of libraries like CUBLAS and CUDNN. That means switching streams isn't entirely free; so it might be costly to write code that performs a single operation like that. That doesn't seem very realistic though, so I don't think it's worth additional complexity (like lazily changing the handle's stream).
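
To illustrate the cost being discussed, a minimal sketch using the stream! API from this PR (the overhead is paid on entering and leaving stream!, because the task-bound library handles get rebound to the new stream):

using CUDA, LinearAlgebra

a, b, c = CUDA.rand(256,256), CUDA.rand(256,256), CUDA.rand(256,256)

CUDA.stream!(CuStream()) do
    mul!(c, a, b)   # a single CUBLAS call; the handle switch may dominate for work this small
end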


codecov bot commented Jan 18, 2021

Codecov Report

Merging #662 (b6b4afd) into master (d719a6f) will increase coverage by 0.28%.
The diff coverage is 92.10%.


@@            Coverage Diff             @@
##           master     #662      +/-   ##
==========================================
+ Coverage   77.79%   78.08%   +0.28%     
==========================================
  Files         117      118       +1     
  Lines        7035     7132      +97     
==========================================
+ Hits         5473     5569      +96     
- Misses       1562     1563       +1     
Impacted Files Coverage Δ
lib/cutensor/wrappers.jl 94.31% <ø> (ø)
src/precompile.jl 0.00% <0.00%> (ø)
lib/cudadrv/memory.jl 82.06% <70.37%> (+0.84%) ⬆️
lib/cudadrv/stream.jl 90.24% <80.00%> (+0.77%) ⬆️
lib/cutensor/CUTENSOR.jl 96.00% <88.88%> (-4.00%) ⬇️
lib/cusolver/CUSOLVER.jl 96.15% <91.30%> (+6.41%) ⬆️
lib/curand/CURAND.jl 96.29% <92.00%> (-3.71%) ⬇️
lib/cudnn/CUDNN.jl 66.03% <92.30%> (+8.25%) ⬆️
lib/cusparse/CUSPARSE.jl 96.66% <92.30%> (+5.75%) ⬆️
lib/cublas/CUBLAS.jl 83.33% <96.00%> (+3.05%) ⬆️
... and 30 more

Legend:
Δ = absolute <relative> (impact), ø = not affected, ? = missing data


maleadt commented Jan 18, 2021

Current overhead:

julia> @benchmark @cuda identity(nothing)
BenchmarkTools.Trial: 
  memory estimate:  400 bytes
  allocs estimate:  10
  --------------
  minimum time:     1.339 μs (0.00% GC)
  median time:      1.383 μs (0.00% GC)
  mean time:        1.403 μs (0.00% GC)
  maximum time:     4.159 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10

julia> @benchmark CUDA.stream!(()->@cuda(identity(nothing)), s) setup=(s=CUDA.stream())
BenchmarkTools.Trial: 
  memory estimate:  1.36 KiB
  allocs estimate:  53
  --------------
  minimum time:     2.517 μs (0.00% GC)
  median time:      2.606 μs (0.00% GC)
  mean time:        2.756 μs (2.62% GC)
  maximum time:     367.082 μs (98.34% GC)
  --------------
  samples:          10000
  evals/sample:     9


maleadt commented Jan 18, 2021

julia> @benchmark @cuda identity(nothing)
BenchmarkTools.Trial: 
  memory estimate:  400 bytes
  allocs estimate:  10
  --------------
  minimum time:     1.300 μs (0.00% GC)
  median time:      1.343 μs (0.00% GC)
  mean time:        1.364 μs (0.00% GC)
  maximum time:     3.986 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10

julia> @benchmark CUDA.stream!(()->@cuda(identity(nothing)), s) setup=(s=CUDA.stream())
BenchmarkTools.Trial: 
  memory estimate:  400 bytes
  allocs estimate:  10
  --------------
  minimum time:     1.421 μs (0.00% GC)
  median time:      1.477 μs (0.00% GC)
  mean time:        1.493 μs (0.00% GC)
  maximum time:     4.049 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10

julia> # let's initialize some libraries

julia> CUBLAS.handle(); CURAND.default_rng(); CUSPARSE.handle(); CUSOLVER.dense_handle(); CUSOLVER.sparse_handle();

julia> @benchmark CUDA.stream!(()->@cuda(identity(nothing)), s) setup=(s=CUDA.stream())
BenchmarkTools.Trial: 
  memory estimate:  400 bytes
  allocs estimate:  10
  --------------
  minimum time:     1.452 μs (0.00% GC)
  median time:      1.531 μs (0.00% GC)
  mean time:        1.543 μs (0.00% GC)
  maximum time:     8.312 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10


vchuravy commented Jan 18, 2021

Changing the stream on a task now updates the active task-bound handles of libraries like CUBLAS and CUDNN. That means switching streams isn't entirely free; so it might be costly to write code that performs a single operation like that. That doesn't seem very realistic though, so I don't think it's worth additional complexity (like lazily changing the handle's stream).

If we keep the stream as a keyword argument to low-level functions, I agree. For KA purposes I need to either run a short sequence of ops on the same stream and then restore the previous one, or pass an explicit stream object to async_copy, record, wait, and launch.


maleadt commented Jan 18, 2021

for KA purposes I need to run a short sequence of ops on the same stream and then restore the previous one

Is the switching overhead low enough now? We could add some stream kwargs, but then again, where do we draw the line? I imagine I'd have to add back most of the ones removed in this PR. I could imagine adding it back for kernel launches though, as @cuda dynamic=true still has the stream argument anyway. For other functions it's easier to use the lower-level API (crafting a kernel launch like that yourself is pretty tricky).
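
To make the trade-off concrete, a hedged sketch of the two styles (the stream keyword on a host-side @cuda launch is hypothetical here, i.e. the thing that would need adding back; the stream!-based form exists in this PR):

s = CuStream()

# per-launch keyword (hypothetical for host-side launches)
@cuda stream=s identity(nothing)

# switching the task's current stream around the launch
CUDA.stream!(s) do
    @cuda identity(nothing)
end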

vchuravy left a comment

I like this direction a whole lot! I would push for making most memory operations use the async variant by default (iirc the non-async are device level syncs) and then emit a sync on the stream for async=false, i.e. have async represent the semantics w.r.t. the host. KA uses a busy-wait https://github.com/JuliaGPU/KernelAbstractions.jl/blob/c9c5daa579653877cd9f7893633d4fc224881cce/src/backends/cuda.jl#L63-L72 but that might not be worth it.

I will open a corresponding PR to KA to test these changes there.

cc: @alir @simonbyrne @kpamnany @jakebolewski


maleadt commented Jan 18, 2021

I would push for making most memory operations use the async variant by default (iirc the non-async are device level syncs) and then emit a sync on the stream for async=false, i.e. have async represent the semantics w.r.t. the host.

That's interesting. That way, the stream argument actually means something in the async=false case (with regular CUDA APIs, you can't specify the stream to sync). That would remove most of the ugliness that the stream kwarg has, so maybe we should keep it then.
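
As a rough sketch of those semantics (to_gpu! is a hypothetical helper, not a CUDA.jl API): the copy is always enqueued asynchronously on the given stream, and async=false merely adds a synchronize on that same stream, so the flag only describes behaviour with respect to the host.

using CUDA

function to_gpu!(dst::CuArray, src::Array; stream::CuStream=CUDA.stream(), async::Bool=false)
    CUDA.stream!(stream) do
        copyto!(dst, src)          # enqueued on `stream`
    end
    async || synchronize(stream)   # async=false: block the host, but only on this stream
    return dst
end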


maleadt commented Jan 21, 2021

Update: I retained stream arguments to most functions for KA.jl's purposes. At the same time, the performance overhead of querying and switching streams has been greatly reduced, so the new API could be used as well. I've also pushed a commit that assigns each task its own non-blocking stream, so tasks will overlap their computations automatically (I would have expected this to break something, but apparently we don't accidentally rely on default stream semantics). Memory operations now always use the async APIs and synchronize only the current stream (not the whole device), but still default to synchronous behavior.
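
A minimal sketch of what the per-task streams mean in practice (assuming, as described above, that CUDA.stream() returns the stream bound to the calling task):

using CUDA

t1 = @async CUDA.stream()   # this task gets its own non-blocking stream
t2 = @async CUDA.stream()   # a different task gets a different stream

fetch(t1) == fetch(t2)      # false: work submitted from the two tasks can overlap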

maleadt changed the title from "Manage our own per-task stream." to "Automatic task-based concurrency using local streams" on Jan 21, 2021
maleadt marked this pull request as ready for review on January 21, 2021 15:23

maleadt commented Jan 22, 2021

Well, this is promising. By handling blocking ourselves now during synchronize, and yielding back to Julia's task scheduler, we can naively sync from tasks without global effects:

using CUDA, LinearAlgebra

# dummy calculation (that does not allocate or otherwise synchronize the GPU)
function run(a,b,c)
    NVTX.@range "mul!" mul!(c, a, b)  # uses CUBLAS, so needs a library handle
    NVTX.@range "broadcast!" broadcast!(sin, c, c)
end

# one "iteration", performing the above calculation twice in two tasks
# and comparing the output.
function iteration(a,b,c)
    x,y = missing, missing
    NVTX.@range "iteration" @sync begin
        @async begin
            x = NVTX.@range "run 1" run(a,b,c)
            synchronize()
        end
        @async begin
            y = NVTX.@range "run 2" run(a,b,c)
            synchronize()
        end
    end
    # no need to synchronize here, as both tasks have been synchronized already
    x == y
end

function main(N=1024)
    a = CUDA.rand(N,N)
    b = CUDA.rand(N,N)
    c = CUDA.rand(N,N)
    synchronize() # to make sure we can use this data from other tasks
    NVTX.@range "warmup" iteration(a,b,c)
    GC.gc(true) # we want to collect and cache the library handles used during warmup
    NVTX.@range "main" iteration(a,b,c)
end

With some improvements in caching library handles (turns out creating them is expensive, so we can't naively do so each time a new task performs its first library call), we get very nice concurrent execution from the above example (which doesn't even use unreasonably-large inputs to hide latency):

[profiler trace: the two tasks' "run 1" and "run 2" ranges overlapping, with yield() calls marked at the bottom]

The tiny marks at the bottom are calls to yield(). As you can see at the top of the trace, execution of the independent mul! and broadcast! calls overlaps.
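
For context, a sketch of the kind of non-blocking synchronization that produces those yield() marks. Both the function name and the stream-query call are assumptions for illustration, not necessarily what this PR implements: the idea is simply to poll instead of blocking in the driver, handing control back to Julia's scheduler so other tasks can submit their own GPU work.

using CUDA

function nonblocking_synchronize(s::CuStream=CUDA.stream())
    # assumes a Bool-returning stream query (wrapping cuStreamQuery)
    while !CUDA.isdone(s)
        yield()   # let other Julia tasks run in the meantime
    end
    return
end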

Commit notes:

  • We use explicit per-task streams now, so we no longer rely on CUDA's default stream semantics.
  • Async H2D copies require pinned memory; if the source isn't pinned, the copy simply executes synchronously.
  • We handle synchronization blocking ourselves, so that we can yield to other tasks while waiting.
  • Library handles are cached, so newly-created tasks don't have to spend time creating them (which often requires memory allocations and/or global synchronization) during their first library call.
  • Stream switching now takes 500ns instead of 80ns, but this simplifies handle management (any active handle immediately uses the new stream, reducing the risk of a mismatch there) and avoids needless stream switches when switching tasks. The latter seems likely to happen much more frequently, so it makes sense to pay the overhead when switching rather than when querying. Furthermore, for KA.jl many APIs take stream arguments again, so it won't have to switch streams globally anyway.
  • Now that we don't check for errors on every API call, it's possible an exception doesn't get caught by CUDA.check_exceptions but surfaces as a CUDA error instead.
  • This state management is too tricky.