Skip to content

Conversation

@simeonschaub
Copy link
Member

closes #641

Testing locally, I am running into #624

@github-actions
Copy link
Contributor

github-actions bot commented Nov 19, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/pocl/compiler/compilation.jl b/src/pocl/compiler/compilation.jl
index fb9f9585..8831717a 100644
--- a/src/pocl/compiler/compilation.jl
+++ b/src/pocl/compiler/compilation.jl
@@ -21,11 +21,15 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
 
 GPUCompiler.kernel_state_type(::OpenCLCompilerJob) = KernelState
 
-function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
-                                          mod::LLVM.Module, entry::LLVM.Function)
-    entry = invoke(GPUCompiler.finish_module!,
-                   Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM.Module, LLVM.Function},
-                   job, mod, entry)
+function GPUCompiler.finish_module!(
+        @nospecialize(job::OpenCLCompilerJob),
+        mod::LLVM.Module, entry::LLVM.Function
+    )
+    entry = invoke(
+        GPUCompiler.finish_module!,
+        Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM.Module, LLVM.Function},
+        job, mod, entry
+    )
 
     # if this kernel uses our RNG, we should prime the shared state.
     # XXX: these transformations should really happen at the Julia IR level...
@@ -37,7 +41,7 @@ function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
 
         # create a deferred compilation job for `initialize_rng_state`
         src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
-        cfg = CompilerConfig(job.config; kernel=false, name=nothing)
+        cfg = CompilerConfig(job.config; kernel = false, name = nothing)
         job = CompilerJob(src, cfg, job.world)
         id = length(GPUCompiler.deferred_codegen_jobs) + 1
         GPUCompiler.deferred_codegen_jobs[id] = job
@@ -45,7 +49,7 @@ function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
         # generate IR for calls to `deferred_codegen` and the resulting function pointer
         top_bb = first(blocks(entry))
         bb = BasicBlock(top_bb, "initialize_rng")
-        @dispose builder=IRBuilder() begin
+        @dispose builder = IRBuilder() begin
             position!(builder, bb)
             subprogram = LLVM.subprogram(entry)
             if subprogram !== nothing
@@ -158,5 +162,5 @@ function link(@nospecialize(job::CompilerJob), compiled)
         error("Your device does not support SPIR-V, which is currently required for native execution.")
     end
     cl.build!(prog)
-    (; kernel=cl.Kernel(prog, compiled.entry), compiled.device_rng)
+    return (; kernel = cl.Kernel(prog, compiled.entry), compiled.device_rng)
 end
diff --git a/src/pocl/device/random.jl b/src/pocl/device/random.jl
index b70ce781..ff26a1e4 100644
--- a/src/pocl/device/random.jl
+++ b/src/pocl/device/random.jl
@@ -21,7 +21,7 @@ end
 function initialize_rng_state()
     subgroup_id = get_sub_group_id()
     @inbounds global_random_keys()[subgroup_id] = kernel_state().random_seed
-    @inbounds global_random_counters()[subgroup_id] = 0
+    return @inbounds global_random_counters()[subgroup_id] = 0
 end
 
 # generators
@@ -37,7 +37,7 @@ struct Philox2x32{R} <: RandomNumbers.AbstractRNG{UInt64} end
 @inline function Base.getproperty(rng::Philox2x32, field::Symbol)
     subgroup_id = get_sub_group_local_id()
 
-    if field === :key
+    return if field === :key
         @inbounds global_random_keys()[subgroup_id]
     elseif field === :ctr1
         @inbounds global_random_counters()[subgroup_id]
@@ -65,7 +65,7 @@ end
 Seed the on-device Philox2x32 generator with an UInt32 number.
 Should be called by at least one thread per warp.
 """
-function Random.seed!(rng::Philox2x32, seed::Integer, counter::Integer=UInt32(0))
+function Random.seed!(rng::Philox2x32, seed::Integer, counter::Integer = UInt32(0))
     rng.key = seed % UInt32
     rng.ctr1 = counter
     return
@@ -95,25 +95,57 @@ end
 
 Generate a byte of random data using the on-device Tausworthe generator.
 """
-function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
+function Random.rand(rng::Philox2x32{R}, ::Type{UInt64}) where {R}
     ctr1, ctr2, key = rng.ctr1, rng.ctr2, rng.key
 
-    if R > 0                               ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
-    if R > 1  key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
-    if R > 2  key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
-    if R > 3  key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
-    if R > 4  key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
-    if R > 5  key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
-    if R > 6  key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
-    if R > 7  key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
-    if R > 8  key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
-    if R > 9  key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
-    if R > 10 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
-    if R > 11 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
-    if R > 12 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
-    if R > 13 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
-    if R > 14 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
-    if R > 15 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 0
+        ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
+    if R > 1
+        key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
+    if R > 2
+        key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
+    if R > 3
+        key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
+    if R > 4
+        key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
+    if R > 5
+        key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
+    if R > 6
+        key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
+    if R > 7
+        key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
+    if R > 8
+        key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
+    if R > 9
+        key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
+    if R > 10
+        key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
+    if R > 11
+        key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
+    if R > 12
+        key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
+    if R > 13
+        key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
+    if R > 14
+        key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
+    if R > 15
+        key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
+    end
 
     # update the warp counter
     # NOTE: this performs the same update on every thread in the warp, but each warp writes
@@ -127,13 +159,12 @@ function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
 end
 
 
-
 # a hacky method of exposing constant tables as constant GPU memory
 
 function emit_constant_array(name::Symbol, data::AbstractArray{T}) where {T}
-    @dispose ctx=Context() begin
+    return @dispose ctx = Context() begin
         T_val = convert(LLVMType, T)
-        T_ptr = convert(LLVMType, LLVMPtr{T,AS.UniformConstant})
+        T_ptr = convert(LLVMType, LLVMPtr{T, AS.UniformConstant})
 
         # define function and get LLVM module
         llvm_f, _ = create_function(T_ptr)
@@ -149,7 +180,7 @@ function emit_constant_array(name::Symbol, data::AbstractArray{T}) where {T}
         alignment!(gv, 16)
 
         # generate IR
-        @dispose builder=IRBuilder() begin
+        @dispose builder = IRBuilder() begin
             entry = BasicBlock(llvm_f, "entry")
             position!(builder, entry)
 
@@ -160,17 +191,17 @@ function emit_constant_array(name::Symbol, data::AbstractArray{T}) where {T}
             ret!(builder, untyped_ptr)
         end
 
-        call_function(llvm_f, LLVMPtr{T,AS.UniformConstant})
+        call_function(llvm_f, LLVMPtr{T, AS.UniformConstant})
     end
 end
 
 for var in [:ki, :wi, :fi, :ke, :we, :fe]
     val = getfield(Random, var)
     gpu_var = Symbol("gpu_$var")
-    arr_typ = :(CLDeviceArray{$(eltype(val)),$(ndims(val)),AS.UniformConstant})
+    arr_typ = :(CLDeviceArray{$(eltype(val)), $(ndims(val)), AS.UniformConstant})
     @eval @inline @generated function $gpu_var()
         ptr = emit_constant_array($(QuoteNode(var)), $val)
-        Expr(:call, $arr_typ, $(size(val)), ptr)
+        return Expr(:call, $arr_typ, $(size(val)), ptr)
     end
 end
 
@@ -183,17 +214,17 @@ end
         r &= 0x000fffffffffffff
         rabs = Int64(r >> 1) # One bit for the sign
         idx = rabs & 0xFF
-        x = ifelse(r % Bool, -rabs, rabs)*gpu_wi()[idx+1]
-        rabs < gpu_ki()[idx+1] && return x # 99.3% of the time we return here 1st try
+        x = ifelse(r % Bool, -rabs, rabs) * gpu_wi()[idx + 1]
+        rabs < gpu_ki()[idx + 1] && return x # 99.3% of the time we return here 1st try
         # TODO: This code could be outlined once LLVM supports LDS access in recursively-called functions
         @inbounds if idx == 0
             while true
-                xx = -Random.ziggurat_nor_inv_r*log(Random.rand(rng))
+                xx = -Random.ziggurat_nor_inv_r * log(Random.rand(rng))
                 yy = -log(Random.rand(rng))
-                yy+yy > xx*xx &&
-                    return (rabs >> 8) % Bool ? -Random.ziggurat_nor_r-xx : Random.ziggurat_nor_r+xx
+                yy + yy > xx * xx &&
+                    return (rabs >> 8) % Bool ? -Random.ziggurat_nor_r - xx : Random.ziggurat_nor_r + xx
             end
-        elseif (gpu_fi()[idx] - gpu_fi()[idx+1])*Random.rand(rng) + gpu_fi()[idx+1] < exp(-0.5*x*x)
+        elseif (gpu_fi()[idx] - gpu_fi()[idx + 1]) * Random.rand(rng) + gpu_fi()[idx + 1] < exp(-0.5 * x * x)
             return x # return from the triangular area
         else
             @goto retry
@@ -213,12 +244,12 @@ end
     @inbounds begin
         ri &= 0x000fffffffffffff
         idx = ri & 0xFF
-        x = ri*gpu_we()[idx+1]
-        ri < gpu_ke()[idx+1] && return x # 98.9% of the time we return here 1st try
+        x = ri * gpu_we()[idx + 1]
+        ri < gpu_ke()[idx + 1] && return x # 98.9% of the time we return here 1st try
         # TODO: This code could be outlined once LLVM supports LDS access in recursively-called functions
         @inbounds if idx == 0
             return Random.ziggurat_exp_r - log(Random.rand(rng))
-        elseif (gpu_fe()[idx] - gpu_fe()[idx+1])*Random.rand(rng) + gpu_fe()[idx+1] < exp(-x)
+        elseif (gpu_fe()[idx] - gpu_fe()[idx + 1]) * Random.rand(rng) + gpu_fe()[idx + 1] < exp(-x)
             return x # return from the triangular area
         else
             @goto retry
@@ -230,5 +261,7 @@ end
     @invoke Random.randexp(rng::AbstractRNG, T::Type{<:AbstractFloat})
 end
 
-@device_override Random.Sampler(::Type{<:AbstractRNG}, r::AbstractUnitRange{T},
-                                ::Random.Repetition) where {T<:Union{Int64, UInt64}} = Random.SamplerRangeFast(r)
+@device_override Random.Sampler(
+    ::Type{<:AbstractRNG}, r::AbstractUnitRange{T},
+    ::Random.Repetition
+) where {T <: Union{Int64, UInt64}} = Random.SamplerRangeFast(r)
diff --git a/src/pocl/device/runtime.jl b/src/pocl/device/runtime.jl
index b6a1aa45..0a551375 100644
--- a/src/pocl/device/runtime.jl
+++ b/src/pocl/device/runtime.jl
@@ -60,7 +60,7 @@ end
 
 # run-time equivalent
 function additional_arg_value(state, name)
-    @dispose ctx=Context() begin
+    return @dispose ctx = Context() begin
         T_state = convert(LLVMType, state)
 
         # create function
@@ -72,7 +72,7 @@ function additional_arg_value(state, name)
         state_intr_ft = function_type(state_intr)
 
         # generate IR
-        @dispose builder=IRBuilder() begin
+        @dispose builder = IRBuilder() begin
             entry = BasicBlock(llvm_f, "entry")
             position!(builder, entry)
 
diff --git a/src/pocl/nanoOpenCL.jl b/src/pocl/nanoOpenCL.jl
index 8aeb08be..e349c2d2 100644
--- a/src/pocl/nanoOpenCL.jl
+++ b/src/pocl/nanoOpenCL.jl
@@ -1325,7 +1325,7 @@ function call(
             sizeof(svm_pointers), svm_pointers
         )
     end
-    return enqueue_kernel(k, global_size, local_size; global_work_offset, rng_state, nargs=length(args))
+    return enqueue_kernel(k, global_size, local_size; global_work_offset, rng_state, nargs = length(args))
 end
 
 # convert the argument values to match the kernel's signature (specified by the user)
diff --git a/test/random.jl b/test/random.jl
index b098de63..f3ca0dd3 100644
--- a/test/random.jl
+++ b/test/random.jl
@@ -3,7 +3,7 @@ using Random
 const n = 256
 
 function apply_seed(seed)
-    if seed === missing
+    return if seed === missing
         # should result in different numbers across launches
         Random.seed!()
         # XXX: this currently doesn't work, because of the definition in Base,
@@ -33,9 +33,9 @@ function random_testsuite(backend)
             a = KernelAbstractions.zeros(backend(), T, n)
             b = KernelAbstractions.zeros(backend(), T, n)
 
-            kernel(backend())(a, seed, ndrange=n, workgroupsize=n)
+            kernel(backend())(a, seed, ndrange = n, workgroupsize = n)
             KernelAbstractions.synchronize(backend())
-            kernel(backend())(b, seed, ndrange=n, workgroupsize=n)
+            kernel(backend())(b, seed, ndrange = n, workgroupsize = n)
             KernelAbstractions.synchronize(backend())
 
             if seed === nothing || seed === missing
@@ -57,7 +57,7 @@ function random_testsuite(backend)
             a = KernelAbstractions.zeros(backend(), T, n)
             b = KernelAbstractions.zeros(backend(), T, n)
 
-            kernel(backend())(a, b, seed, ndrange=n, workgroupsize=n)
+            kernel(backend())(a, b, seed, ndrange = n, workgroupsize = n)
             KernelAbstractions.synchronize(backend())
 
             @test Array(a) != Array(b)
@@ -77,10 +77,10 @@ function random_testsuite(backend)
                 end
 
                 tx, ty, tz, bx, by, bz = [dim == active_dim ? 3 : 1 for dim in 1:6]
-                gx, gy, gz = tx*bx, ty*by, tz*bz
+                gx, gy, gz = tx * bx, ty * by, tz * bz
                 a = KernelAbstractions.zeros(backend(), T, 3)
 
-                kernel(backend())(a, seed, ndrange=(gx, gy, gz), workgroupsize=(tx, ty, tz))
+                kernel(backend())(a, seed, ndrange = (gx, gy, gz), workgroupsize = (tx, ty, tz))
                 KernelAbstractions.synchronize(backend())
 
                 # NOTE: we don't just generate two numbers and compare them, instead generating a
@@ -101,9 +101,9 @@ function random_testsuite(backend)
         a = KernelAbstractions.zeros(backend(), T, n)
         b = KernelAbstractions.zeros(backend(), T, n)
 
-        kernel(backend())(a, seed, ndrange=n, workgroupsize=n)
+        kernel(backend())(a, seed, ndrange = n, workgroupsize = n)
         KernelAbstractions.synchronize(backend())
-        kernel(backend())(b, seed, ndrange=n, workgroupsize=n)
+        kernel(backend())(b, seed, ndrange = n, workgroupsize = n)
         KernelAbstractions.synchronize(backend())
 
         if seed === nothing || seed === missing
@@ -130,9 +130,9 @@ function random_testsuite(backend)
         a = KernelAbstractions.zeros(backend(), T, n)
         b = KernelAbstractions.zeros(backend(), T, n)
 
-        kernel(backend())(a, seed, ndrange=n, workgroupsize=n)
+        kernel(backend())(a, seed, ndrange = n, workgroupsize = n)
         KernelAbstractions.synchronize(backend())
-        kernel(backend())(b, seed, ndrange=n, workgroupsize=n)
+        kernel(backend())(b, seed, ndrange = n, workgroupsize = n)
         KernelAbstractions.synchronize(backend())
 
         if seed === nothing || seed === missing
@@ -142,7 +142,7 @@ function random_testsuite(backend)
         end
     end
 
-    @testset "rand(::AbstractRange{$T}), seed $seed" for T in (Int32, Int64, UInt32, UInt64), seed in (nothing, #=missing,=# 1234)
+    return @testset "rand(::AbstractRange{$T}), seed $seed" for T in (Int32, Int64, UInt32, UInt64), seed in (nothing, #=missing,=# 1234)
         @kernel function kernel(A::AbstractArray{T}, seed) where {T}
             apply_seed(seed)
             tid = @index(Global, Linear)
@@ -152,9 +152,9 @@ function random_testsuite(backend)
         a = KernelAbstractions.zeros(backend(), T, n)
         b = KernelAbstractions.zeros(backend(), T, n)
 
-        kernel(backend())(a, seed, ndrange=n, workgroupsize=n)
+        kernel(backend())(a, seed, ndrange = n, workgroupsize = n)
         KernelAbstractions.synchronize(backend())
-        kernel(backend())(b, seed, ndrange=n, workgroupsize=n)
+        kernel(backend())(b, seed, ndrange = n, workgroupsize = n)
         KernelAbstractions.synchronize(backend())
 
         if seed === nothing || seed === missing

closes JuliaGPU#641

Testing locally, I am running into JuliaGPU#624
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[pocl] Device side RNG

1 participant