
rand: seed kernels from the host. #2035

Merged: 7 commits merged into master from tb/rand_seed on Aug 17, 2023

Conversation

maleadt (Member) commented on Aug 15, 2023

We should not (and cannot) rely on updates to shared memory being visible across kernel launches. In some cases they aren't, which, combined with pre-initialized shared memory (so we never hit the make_seed path), leads to identical values for the key and counters, and thus to identical random numbers being generated.

Avoid this by passing a seed from the host. Even if updates to the counters don't survive into the next kernel, the key will now be different, so new numbers get generated.

Fixes #2008
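
To illustrate the fix at a high level, here is a rough, hypothetical sketch (not CUDA.jl's actual device RNG code; the key derivation below is just a stand-in for Philox keying). The host generates a fresh seed for every launch and passes it as a plain kernel argument, so the key differs between launches regardless of what happens to survive in shared memory:

using CUDA

# Sketch: the host picks a fresh seed per launch and passes it as a kernel
# argument, so the RNG key changes even if shared-memory counters from a
# previous launch are still visible.
function rand_kernel(xs, seed::UInt32)
    i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    # stand-in for Philox keying: mix the host seed with the thread index
    key = seed ⊻ (UInt32(i) * 0x9e3779b9)
    @inbounds xs[i] = Float32(key) / Float32(typemax(UInt32))
    return
end

xs = CUDA.zeros(Float32, 256)
host_seed = rand(UInt32)   # generated on the host, changes every launch
@cuda threads=256 rand_kernel(xs, host_seed)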

maleadt added the bugfix label ("This gets something working again.") on Aug 15, 2023
maleadt (Member, Author) commented on Aug 16, 2023

Initializing the RNG every time doesn't work, so I'm now having the compiler insert the initialization sequence at the start of the kernel via deferred compilation. Not great, but it works.

I wonder if we could generalize this pattern into something reusable. Maybe shared memory with an initialization guarantee, or some kind of shared local storage.
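
For reference, a minimal sketch of what such an initialization guarantee could look like (a hypothetical pattern, not an existing CUDA.jl feature): every kernel (re)initializes its block-shared state itself, instead of assuming values survived from a previous launch:

using CUDA

# Hypothetical sketch of "shared memory with an initialization guarantee":
# one thread per block writes the initial value on every launch, and a barrier
# makes it visible to the rest of the block before anyone reads it.
function kernel_with_guarded_state(xs, seed::UInt32)
    state = CuStaticSharedArray(UInt32, 1)
    if threadIdx().x == 1
        state[1] = seed        # (re)initialize once per block, on every launch
    end
    sync_threads()             # all threads now see an initialized value

    i = threadIdx().x
    @inbounds xs[i] = state[1] + UInt32(i)
    return
end

xs = CUDA.zeros(UInt32, 32)
@cuda threads=32 kernel_with_guarded_state(xs, rand(UInt32))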

The remaining issue is a miscompilation with our quicksort implementation (which uses dynamic parallelism) when I add a field to the kernel state... Haven't figured that out yet.

maleadt (Member, Author) commented on Aug 16, 2023

The remaining issue is a miscompilation with our quicksort implementation (which uses dynamic parallelism) when I add a field to the kernel state... Haven't figured that out yet.

Reproducer:

using CUDA

@inline flex_lt(a, b, eq) = (eq && a == b) || isless(a, b)

function cumsum!(sums)
    shift = 1

    while shift < length(sums)
        to_add = 0
        if threadIdx().x - shift > 0
            to_add = sums[threadIdx().x - shift]
        end

        sync_threads()
        if threadIdx().x - shift > 0
            sums[threadIdx().x] += to_add
        end

        sync_threads()
        shift *= 2
    end
end

@inline function batch_partition(values, pivot, swap, sums, lo, hi, parity)
    sync_threads()
    blockIdx_yz = (blockIdx().z - 1) * gridDim().y + blockIdx().y
    idx0 = lo + (blockIdx_yz - 1) * blockDim().x + threadIdx().x
    if idx0 <= hi
        val = values[idx0]
        comparison = flex_lt(pivot, val, parity)
    end

    if idx0 <= hi
         sums[threadIdx().x] = 1 & comparison
    else
         sums[threadIdx().x] = 1
    end
    sync_threads()

    cumsum!(sums)

    if idx0 <= hi
        dest_idx = comparison ? blockDim().x - sums[end] + sums[threadIdx().x] : threadIdx().x - sums[threadIdx().x]
        if dest_idx <= length(swap)
            swap[dest_idx] = val
        end
    end
    sync_threads()

    if idx0 <= hi
         values[idx0] = swap[threadIdx().x]
    end
    sync_threads()
end

function partition_batches_kernel(values::AbstractArray{T}, pivot, lo, hi, parity) where {T}
    sums = CuDynamicSharedArray(Int, blockDim().x)
    swap = CuDynamicSharedArray(T, blockDim().x, sizeof(sums))
    batch_partition(values, pivot, swap, sums, lo, hi, parity)
    return
end

function find_partition(array, pivot, lo, hi, parity)
    low = lo + 1
    high = hi
    while low <= high
        mid = (low + high) ÷ 2
        if flex_lt(pivot, array[mid], parity)
            high = mid - 1
        else
            low = mid + 1
        end
    end
    return low - 1
end

@inline function consolidate_batch_partition(vals::AbstractArray{T}, pivot, lo, L, b_sums,
                                             parity) where {T}
    sync_threads()
    @inline N_b() = cld(L , blockDim().x)
    @inline batch(k) = threadIdx().x + k * blockDim().x

    my_iter = 0
    a = 0
    b = 0

    for batch_i in 1:N_b()
        if batch_i % blockDim().x == 1
            if batch(my_iter) <= N_b()
                seek_lo = lo + (batch(my_iter) - 1) * blockDim().x
                seek_hi = lo + min(L, batch(my_iter) * blockDim().x)
                b_sums[threadIdx().x] =
                    seek_hi - find_partition(vals, pivot, seek_lo, seek_hi, parity)
            end
            my_iter += 1
        end

        function n_eff()
            if batch_i != N_b() || L % blockDim().x == 0
                blockDim().x
            else
                L % blockDim().x
            end
        end

        sync_threads()
        d = b_sums[batch_i - (my_iter - 1) * blockDim().x]
        c = n_eff() - d
        to_move = min(b, c)
        sync_threads()
        if threadIdx().x <= to_move
            swap = vals[lo + a + threadIdx().x]
        end
        sync_threads()
        if threadIdx().x <= to_move
            vals[lo + a + threadIdx().x] = vals[lo + a + b + c - to_move + threadIdx().x]
        end
        sync_threads()
        if threadIdx().x <= to_move
            vals[lo + a + b + c - to_move + threadIdx().x] = swap
        end
        sync_threads()
        a += c
        b += d
    end

    sync_threads()
    return lo + a
end

function bitonic_median(vals :: AbstractArray{T}, swap, lo, L, stride) where {T}
    sync_threads()
    bitonic_lt(i1, i2) = flex_lt(swap[i1 + 1], swap[i2 + 1], false)

    swap[threadIdx().x] = vals[lo + threadIdx().x * stride]
    sync_threads()

    log_blockDim = begin
        out = 0
        k = blockDim().x
        while k > 1
            k = k >> 1
            out += 1
        end
        out
    end

    log_k = 1
    while log_k <= log_blockDim
        k = 1 << log_k
        j = k ÷ 2

        while j > 0
            i = threadIdx().x - 1
            l = xor(i, j)
            to_swap = (i & k) == 0 && bitonic_lt(l, i) || (i & k) != 0 && bitonic_lt(i, l)
            to_swap = to_swap == (i < l)

            if to_swap
                old_val = swap[l + 1]
            end
            sync_threads()
            if to_swap
                swap[i + 1] = old_val
            end
            sync_threads()
            j = j ÷ 2
        end
        log_k += 1
    end
    sync_threads()
    return swap[blockDim().x ÷ 2]
end

@inline function bubble_sort(vals, swap, lo, L, stride)
    sync_threads()
    L = min(blockDim().x, L)
    if threadIdx().x <= L
        swap[threadIdx().x] = vals[lo + threadIdx().x * stride]
    end
    sync_threads()
    for level in 0:L
        # get left/right neighbor depending on even/odd level
        buddy = threadIdx().x - 1 + 2 * (1 & (threadIdx().x % 2 != level % 2))
        if 1 <= buddy <= L && threadIdx().x <= L
                buddy_val = swap[buddy]
        end
        sync_threads()
        if 1 <= buddy <= L && threadIdx().x <= L
            is_left = threadIdx().x < buddy
            # flex_lt needs to handle equivalence in opposite ways for the
            # two threads in each swap pair. Otherwise, if two distinct
            # elements compare as equal, one will overwrite the other
            if is_left != flex_lt(swap[threadIdx().x], buddy_val, is_left)
                swap[threadIdx().x] = buddy_val
            end
        end
        sync_threads()
    end
    if threadIdx().x <= L
        vals[lo + threadIdx().x * stride] = swap[threadIdx().x]
    end
    sync_threads()
end

@inline function call_batch_partition(vals::AbstractArray{T}, pivot, swap, b_sums, lo, hi,
                                      parity, sync::Val{true}) where {T}
    L = hi - lo
    if threadIdx().x == 1
        blocks_y = cld(L, blockDim().x)

        # TODO: add wrappers
        device = Ref{Cint}()
        CUDA.cudaGetDevice(device)
        max_blocks_y = Ref{Cint}()
        CUDA.cudaDeviceGetAttribute(max_blocks_y, CUDA.cudaDevAttrMaxGridDimY, device[])

        blocks_z, blocks_y = fldmod1(blocks_y, max_blocks_y[])

        @cuda(blocks=(1,blocks_y,blocks_z), threads=blockDim().x, dynamic=true,
              shmem=blockDim().x*(sizeof(Int)+sizeof(T)),
              partition_batches_kernel(vals, pivot, lo, hi, parity))
        device_synchronize()
    end
end

@inline function call_batch_partition(vals::AbstractArray{T}, pivot, swap, b_sums, lo, hi,
                                      parity, sync::Val{false}) where {T}
    while lo <= hi
        batch_partition(vals, pivot, swap, b_sums, lo, min(hi, lo + blockDim().x), parity)
        lo += blockDim().x
    end
end

function partial_range_overlap(lo, hi, partial :: Nothing)
    true
end

function partial_range_overlap(lo, hi, partial_k)
    return !(lo > last(partial_k) || hi < first(partial_k))
end

function qsort_kernel(vals::AbstractArray{T,N}, lo, hi, parity, sync::Val{S}, sync_depth,
                      prev_pivot, ::Val{dims}, partial=nothing, stuck=-1) where {T, N, S, dims}
    b_sums = CuDynamicSharedArray(Int, blockDim().x)
    swap = CuDynamicSharedArray(T, blockDim().x, sizeof(b_sums))
    shmem = sizeof(b_sums) + sizeof(swap)
    L = hi - lo

    slice = if N == 1
        vals
    else
        otherdims = ntuple(i -> i == dims ? 1 : size(vals, i), N)
        other = CartesianIndices(otherdims)[blockIdx().x]

        slicedims = map(Base.Slice, axes(vals))
        idxs = ntuple(i->i==dims ? slicedims[i] : other[i], N)
        view(vals, idxs...)
    end

    if L <= blockDim().x
        bubble_sort(slice, swap, lo, L, 1)
        return
    end

    pivot = bitonic_median(slice, swap, lo, L, L ÷ blockDim().x)

    call_batch_partition(slice, pivot, swap, b_sums, lo, hi, parity, sync)

    partition = consolidate_batch_partition(slice, pivot, lo, L, b_sums, parity)

    if threadIdx().x == 1
        stuck = (pivot == prev_pivot && partition == lo || partition == hi) ? stuck + 1 : 0

        if stuck < 2 && partition > lo && partial_range_overlap(lo, partition, partial)
            s = CuDeviceStream()
            if S && sync_depth > 1
                @cuda(threads=blockDim().x, dynamic=true, stream=s, shmem=shmem,
                      qsort_kernel(slice, lo, partition, !parity, Val(true), sync_depth - 1,
                      pivot, Val(1), partial, stuck))
            else
                @cuda(threads=blockDim().x, dynamic=true, stream=s, shmem=shmem,
                      qsort_kernel(slice, lo, partition, !parity, Val(false), sync_depth - 1,
                      pivot, Val(1), partial, stuck))
            end
            CUDA.unsafe_destroy!(s)
        end

        if stuck < 2 && partition < hi && partial_range_overlap(partition, hi, partial)
            s = CuDeviceStream()
            if S && sync_depth > 1
                @cuda(threads=blockDim().x, dynamic=true, stream=s, shmem=shmem,
                      qsort_kernel(slice, partition, hi, !parity, Val(true), sync_depth - 1,
                      pivot, Val(1), partial, stuck))
            else
                @cuda(threads=blockDim().x, dynamic=true, stream=s, shmem=shmem,
                      qsort_kernel(slice, partition, hi, !parity, Val(false), sync_depth - 1,
                      pivot, Val(1), partial, stuck))
            end
            CUDA.unsafe_destroy!(s)
        end
    end

    return
end

function quicksort!(c::AbstractArray{T,N}; dims::Int) where {T,N}
    max_depth = CUDA.limit(CUDA.LIMIT_DEV_RUNTIME_SYNC_DEPTH)
    len = size(c, dims)

    1 <= dims <= N || throw(ArgumentError("dimension out of range"))
    otherdims = ntuple(i -> i == dims ? 1 : size(c, i), N)

    my_sort_args = (c, 0, len, true, Val(N==1 && max_depth > 1),
                    max_depth, nothing, Val(dims))

    kernel = @cuda launch=false qsort_kernel(my_sort_args...)

    get_shmem(threads) = threads * (sizeof(Int) + sizeof(T))
    config = launch_configuration(kernel.fun, shmem=threads->get_shmem(threads))
    threads = prevpow(2, config.threads)

    kernel(my_sort_args...;
           blocks=prod(otherdims), threads=threads, shmem=get_shmem(threads))

    return c
end

function main()
    run(`clear`)

    x = CuArray(rand(Float32, 1000))
    quicksort!(x; dims=1)

    y = CuArray(rand(Int32, (1, 1, 1)))
    quicksort!(y; dims=3)

    z = CuArray(rand(Float32, 2^25))
    quicksort!(z; dims=1)

    try
        # this errors with ERROR_ILLEGAL_ADDRESS
        CUDA.cuCtxSynchronize()
    catch err
        # ... show the error
        Base.showerror(stderr, err)
        println(stderr)
    finally
        # ... and then quickly abort (as finalizers otherwise generate lots of output anyway)
        ccall(:kill, Cint, (Cint,Cint), getpid(), 9)
    end
end

isinteractive() || main()

For some reason, extending the kernel state makes this example, which uses dynamic parallelism, fail. It also requires either --check-bounds=yes or removing all @inbounds (as done above), which makes sense, as those constructs use the kernel state for signal reporting.

Running this under compute-sanitizer hangs the tool... so I'm doing a brute-force reduction now.

maleadt merged commit 9796d5a into master on Aug 17, 2023 (1 check was pending).
maleadt deleted the tb/rand_seed branch on Aug 17, 2023.
Linked issue closed by this pull request: rand in kernel works in a deterministic way (#2008).