Conversation

@xaellison
Contributor

An implementation of quicksort to address #93.

The performance is solid; see src/sorting/usage.jl for quick performance tests. I intend to later add handling for lists with a large number of duplicates, which can currently stymie the partitioning method.

Hopefully the tests included can help clarify the inner workings. Be warned that this is likely not the most "Julian" code that's ever been written in Julia.

@maleadt
Member

maleadt commented Sep 17, 2020

Nice, thanks! I'll try to give this a proper look, but two quick questions:

  • not all devices can launch 1024 threads; is it possible to generalize that? (see the sketch below)
  • can you use this to implement Base.sort!?
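
For reference, a minimal sketch of picking the thread count from an occupancy query instead of hard-coding 1024; dummy_kernel is only a stand-in here, not the actual sorting kernel in this PR:

using CUDA

# stand-in kernel; the PR's sorting kernels would take its place
function dummy_kernel(xs)
    i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    if i <= length(xs)
        @inbounds xs[i] = i % Int32
    end
    return nothing
end

xs = CUDA.zeros(Int32, 10_000)
kernel = @cuda launch=false dummy_kernel(xs)
config = launch_configuration(kernel.fun)  # occupancy suggestion, capped by the device's limits
threads = min(length(xs), config.threads)
blocks = cld(length(xs), threads)
kernel(xs; threads=threads, blocks=blocks)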

@maleadt maleadt added the cuda array Stuff about CuArray. label Sep 17, 2020
@xaellison
Contributor Author

xaellison commented Sep 17, 2020

Interestingly, it seems to speed up when the block size is halved. The trade-offs of using 512 threads compared to 1024:

  • Bubble sort takes half as many iterations.
  • The second half of partitioning goes down to 1 block, which means it should need roughly twice as many steps. Later in the recursion, these can run in parallel streams.
  • The size of static shared memory (per block) is cut in half. Even though there are twice as many blocks, would they be easier to schedule because they're smaller?
  • [Edit] Also, each grid that "batch partitions" could waste up to n-1 threads for a grid with blocks of size n.

Using src/sorting/usage.jl:

julia> speed_test_random(Int32, Int(1e7), 1024)
BenchmarkTools.Trial:
  memory estimate:  1.69 KiB
  allocs estimate:  58
  --------------
  minimum time:     684.429 ms (0.00% GC)
  median time:      692.137 ms (0.00% GC)
  mean time:        691.690 ms (0.00% GC)
  maximum time:     698.537 ms (0.00% GC)
  --------------
  samples:          10
  evals/sample:     1

julia> speed_test_random(Int32, Int(1e7), 512)
BenchmarkTools.Trial:
  memory estimate:  1.69 KiB
  allocs estimate:  58
  --------------
  minimum time:     630.825 ms (0.00% GC)
  median time:      642.350 ms (0.00% GC)
  mean time:        641.250 ms (0.00% GC)
  maximum time:     645.002 ms (0.00% GC)
  --------------
  samples:          10
  evals/sample:     1

@maleadt
Member

maleadt commented Sep 17, 2020

Interestingly, it seems to speed up when the block size is halved.

I've noticed the same in JuliaGPU/GPUArrays.jl#301 (comment): maximizing occupancy doesn't always guarantee the best performance.

@maleadt maleadt added needs changes Changes are needed. needs tests Tests are requested. labels Sep 30, 2020
@maleadt
Member

maleadt commented Nov 16, 2020

This one looks related:

      From worker 2:	ERROR: LoadError: There was an error during testing
      From worker 2:	in expression starting at /var/lib/buildkite-agent/builds/3-cyclops-juliacomputing-io/julialang/cuda-dot-jl/examples/quicksort.jl:35
      From worker 2:	error in running finalizer: CUDA.CuError(code=CUDA.cudaError_enum(0x000002bc), meta=nothing)
julia> CUDA.cudaError_enum(0x000002bc)
CUDA_ERROR_ILLEGAL_ADDRESS::cudaError_enum = 0x000002bc

As well as:

Error in testset quicksort:
  LoadError: CUDA error: too many resources requested for launch (code 701, ERROR_LAUNCH_OUT_OF_RESOURCES)

Member

@maleadt maleadt left a comment


Starting to look very good, thanks! :-)

@xaellison
Contributor Author

@maleadt thanks for the feedback! I'll happily implement all the code-organization/testing suggestions in the next push. The ones closer to the actual sorting will take some thought.

@maleadt
Member

maleadt commented Nov 24, 2020

I guess something went wrong with the merge here? Better to rebase anyway.

@xaellison
Contributor Author

I guess something went wrong with the merge here? Better to rebase anyway.

That commit should be the result of a rebase. It looks like tests passed on build 303, and the failure was due to a missing NetworkOptions dependency.

@maleadt
Member

maleadt commented Nov 24, 2020

Well, something still went wrong because there's a bunch of commits of mine in here: https://github.com/JuliaGPU/CUDA.jl/pull/431/commits

@xaellison
Contributor Author

How do you feel about me closing this PR and opening a new one :)

@maleadt
Member

maleadt commented Nov 24, 2020

Just `git reset --soft origin/master`, then commit everything afresh and `push -f` it here.

@xiaodaigh

xiaodaigh commented Nov 25, 2020

Can we compare the speed to SortingAlgorithms.jl's RadixSort, performance-wise?

@xaellison
Contributor Author

xaellison commented Dec 3, 2020

The update for Julia 1.6 forced some major changes, which have made this work a lot better (as a standalone algorithm). I think this is up to date with master, so I'm not sure why CI is failing on a dependency issue.

That said, I am getting an error when I try to call quicksort! after including sorting.jl in src/CUDA.jl:

julia> quicksort!(c)
Internal error: encountered unexpected error during compilation of #launch_configuration#22:
LLVM.LLVMException(info="Symbols not found: [ __nv_llmax ]
")
handle_error at /home/ec2-user/.julia/packages/LLVM/dVU7J/src/core/context.jl:105

Because of this, I have not yet addressed all the code organization comments above.

@xaellison
Contributor Author

Can we compare the speed to SortingAlgorithms.jl's RadixSort, performance-wise?

Sure! This is from an AWS EC2 instance of type p3.2xlarge. Radix sort's performance varies significantly with element type, but Int32 gives a roughly average comparison. This also validates that the sort is correct.

julia> CUDA.limit(CUDA.LIMIT_DEV_RUNTIME_SYNC_DEPTH)
4

julia> versioninfo()
Julia Version 1.6.0-DEV.1597
Commit 059ea247b0 (2020-11-30 12:49 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.0 (ORCJIT, broadwell)
Environment:
  JULIA_CUDA_USE_BINARYBUILDER = false

julia> CUDA.versioninfo()
CUDA toolkit 11.0.3, local installation
CUDA driver 11.0.0
NVIDIA driver 450.80.2

Libraries: 
- CUBLAS: 11.2.0
- CURAND: 10.2.1
- CUFFT: 10.2.1
- CUSOLVER: 10.6.0
- CUSPARSE: 11.1.1
- CUPTI: 13.0.0
- NVML: 11.0.0+450.80.2
- CUDNN: 8.0.4 (for CUDA 11.0.0)
- CUTENSOR: missing

Toolchain:
- Julia: 1.6.0-DEV.1597
- LLVM: 11.0.0
- PTX ISA support: 3.2, 4.0, 4.1, 4.2, 4.3, 5.0, 6.0, 6.1, 6.3, 6.4, 6.5, 7.0
- Device support: sm_35, sm_37, sm_50, sm_52, sm_53, sm_60, sm_61, sm_62, sm_70, sm_72, sm_75, sm_80

Environment:
- JULIA_CUDA_USE_BINARYBUILDER: false

1 device:
  0: Tesla V100-SXM2-16GB (sm_70, 11.473 GiB / 15.782 GiB available)

julia> @benchmark quicksort!(c) setup=(a = rand(Int32, Int(1e7)); c=CuArray(a)) teardown=(@assert sort(a)==Array(c)) samples=10 evals=1 seconds=20
BenchmarkTools.Trial: 
  memory estimate:  2.23 KiB
  allocs estimate:  53
  --------------
  minimum time:     298.225 ms (0.00% GC)
  median time:      307.526 ms (0.00% GC)
  mean time:        307.692 ms (0.00% GC)
  maximum time:     319.580 ms (0.00% GC)
  --------------
  samples:          10
  evals/sample:     1

julia> using SortingAlgorithms

julia> @benchmark sort!(a, alg=RadixSort) setup=(a = rand(Int32, Int(1e7))) teardown=(@assert issorted(a)) samples=10 evals=1 seconds=20
BenchmarkTools.Trial: 
  memory estimate:  38.24 MiB
  allocs estimate:  10
  --------------
  minimum time:     473.968 ms (0.00% GC)
  median time:      484.322 ms (0.81% GC)
  mean time:        512.827 ms (6.67% GC)
  maximum time:     646.689 ms (24.90% GC)
  --------------
  samples:          10
  evals/sample:     1

julia> 

@maleadt
Member

maleadt commented Dec 7, 2020

I missed your comments -- are you still running into the __nv_llmax issue?

@xaellison
Contributor Author

I missed your comments -- are you still running into the __nv_llmax issue?

Putting all of sorting.jl into its own module fixed that.

@maleadt
Member

maleadt commented Dec 8, 2020

Disabling memcheck is only for incompatibilities, while this seems like a legitimate issue:



      From worker 2:	========= CUDA-MEMCHECK
      From worker 2:	========= Invalid __global__ write of size 8
      From worker 2:	=========     at 0x0000fdf0 in _Z30julia_qsort_async_kernel_3933713CuDeviceArrayI4Int8Li1ELi1EE5Int32S1_3ValILitrueEE
      From worker 2:	=========     by thread (0,0,0) in block (0,0,0)
      From worker 2:	=========     Address 0x7ffed6a08c28 is out of bounds

@xaellison
Contributor Author

I have now reproduced the memcheck error. It occurs when sorting a list with a very large number of duplicates. However, if I don't run the test through memcheck, the same test passes (not just runs, but runs correctly).

The recursion can go pretty deep in this case, and memcheck makes all launches blocking. If I prevent launching kernels at depth > 24, the error goes away. That's the maximum sync depth. The same error occurs if we run the following MWE through memcheck:

using CUDA

# each call launches a child kernel, nesting one level deeper per call
function kernel(n)
    if n > 0
        @cuda dynamic=true kernel(n - 1)
    end
    return nothing
end

# 24 levels of nested launches is fine
@cuda kernel(24)
synchronize()
@info "no problem"
@cuda kernel(24)
synchronize()
@info "no problem"
# 25 levels exceeds the maximum sync depth; this fails under memcheck:
@cuda kernel(25)
synchronize()

To prevent such launches from happening, I can make the kernel aware of when it has found a large section of identical values.

@maleadt
Member

maleadt commented Dec 9, 2020

If I prevent launching kernels at depth > 24, the error goes away. That's the maximum sync depth.

So it seems important to honor this limit.
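
For illustration, a minimal sketch of such a guard, building on the MWE above; the kernel name, the threaded-through depth argument, and the cap value are assumptions, not the actual implementation in this PR:

using CUDA

const MAX_DEPTH = 24  # the maximum sync depth observed above

function bounded_kernel(n, depth)
    if n > 0
        if depth < MAX_DEPTH
            # keep recursing via dynamic parallelism
            @cuda dynamic=true bounded_kernel(n - 1, depth + 1)
        end
        # at depth == MAX_DEPTH, a real implementation would finish the
        # remaining work here instead of launching another level
    end
    return nothing
end

@cuda bounded_kernel(100, 1)
synchronize()
@info "recursion capped, no error"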

@codecov

codecov bot commented Dec 19, 2020

Codecov Report

Merging #431 (7b6bbb8) into master (99f664c) will decrease coverage by 1.02%.
The diff coverage is 28.96%.


@@            Coverage Diff             @@
##           master     #431      +/-   ##
==========================================
- Coverage   78.79%   77.76%   -1.03%     
==========================================
  Files         116      117       +1     
  Lines        6890     7035     +145     
==========================================
+ Hits         5429     5471      +42     
- Misses       1461     1564     +103     
Impacted Files     Coverage Δ
src/CUDA.jl        100.00% <ø> (ø)
src/sorting.jl     28.96% <28.96%> (ø)


@maleadt
Member

maleadt commented Jan 15, 2021

I've looked at the most recent failure. There is one error and it only fails when all of the following are true:

1. Compute capability <= 6.x (Quadro P5000, P6000). Pool type doesn't affect anything.

2. Calling `sort` with `by != identity`; even `by=x->x` will fail.

3. Sorting a multidimensional array with `dims` (even if the size is effectively 1D, like `(1, 10000)`).

4. `test/sorting.jl` is launched via `test/runtests.jl`. If I manually `include` `test/sorting.jl` from the REPL, everything passes.

When I inspect the result of the sort, it is clear that it has been correctly partitioned once. That makes me suspect that synchronization has been affected. Is there any obvious reason why that might be the case?

Ah, we're still running into this apparently. Now it happened on an RTX2080, so it doesn't seem compute capability-related. The issue seems to require running under --check-bounds=yes, whereas running under cuda-memcheck fails to reproduce the issue. I thought it would be a failed sync due to exceeding the max sync depth, but adding error checks didn't reveal anything...

@maleadt
Member

maleadt commented Jan 15, 2021

Oh FFS, this is another case of JuliaGPU/CUDAnative.jl#4 😭 Forcing a trap instead of exit, https://github.com/JuliaGPU/GPUCompiler.jl/blob/1b436f83a7e8e89d9588afd52fd4dc168a8e5a2f/src/ptx.jl#L288-L294, fixes the issue, which I reduced to:

using CUDA, Test

function main()
    # loop until the sporadic failure shows up
    for i in 1:typemax(Int)
        @show i
        A = rand(1:10, (2, 100000))
        d_A = CuArray(A)
        B = sort(A; dims=2)
        d_B = sort(d_A; dims=2)
        for x in (B, Array(d_B))
            @test issorted(x[1,:])
            @test issorted(x[2,:])
        end
        @test B == Array(d_B)
    end
end

isinteractive() || main()

(but which only reproduces on select hardware/driver combinations, here an RTX 2080 Ti on driver 450.80.2)

Filed as NVIDIA bug #3231266

@maleadt
Member

maleadt commented Jan 19, 2021

All green! Let's merge this 🚀
@xaellison Thanks for your work here!
