Skip to content

Commit

Permalink
Implement kernel random state using CuArray.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Mar 22, 2021
1 parent 9157693 commit 84d94d2
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 14 deletions.
7 changes: 1 addition & 6 deletions src/compiler/execution.jl
Expand Up @@ -205,15 +205,10 @@ mutable struct HostKernel{F,TT} <: AbstractKernel{F,TT}
mod::CuModule
fun::CuFunction

random_state::Union{Nothing,Missing,Mem.DeviceBuffer}
random_state::Union{Nothing,Missing,CuVector{UInt32}}

function HostKernel{F,TT}(ctx::CuContext, mod::CuModule, fun::CuFunction, random_state) where {F,TT}
kernel = new{F,TT}(ctx, mod, fun, random_state)
finalizer(kernel) do k
if k.random_state isa Mem.DeviceBuffer
@context! skip_destroyed=true k.ctx Mem.free(k.random_state; stream_ordered=false)
end
end
end
end

Expand Down
11 changes: 3 additions & 8 deletions src/device/intrinsics/random.jl
Expand Up @@ -17,17 +17,12 @@ struct ThreadLocalXorshift32 <: RandomNumbers.AbstractRNG{UInt32}
end

function init_random_state!(kernel, len)
required_size = sizeof(UInt32) * len

if kernel.random_state === missing
kernel.random_state = Mem.alloc(Mem.Device, required_size; async=true)
elseif sizeof(kernel.random_state) < required_size
Mem.free(kernel.random_state, async=true)
kernel.random_state = Mem.alloc(Mem.Device, required_size; async=true)
if kernel.random_state === missing || length(kernel.random_state) < len
kernel.random_state = CuVector{UInt32}(undef, len)
end

random_state_ptr = CuGlobal{Ptr{Cvoid}}(kernel.mod, "global_random_state")
random_state_ptr[] = reinterpret(Ptr{Cvoid}, convert(CuPtr{Cvoid}, kernel.random_state))
random_state_ptr[] = reinterpret(Ptr{Cvoid}, pointer(kernel.random_state))
end

@eval @inline function global_random_state()
Expand Down

0 comments on commit 84d94d2

Please sign in to comment.