diff --git a/README.md b/README.md index 284a79a..9bcb7d3 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ pois_rand(λ) # Using another RNG using RandomNumbers rng = Xorshifts.Xoroshiro128Plus() -pois_rand(λ,rng) +pois_rand(rng,λ) ``` ## Implementation @@ -29,46 +29,46 @@ using RandomNumbers, Distributions, BenchmarkTools, StaticArrays, labels = ["count_rand","ad_rand","pois_rand","Distributions.jl"] rng = Xorshifts.Xoroshiro128Plus() -function n_count(λ,rng,n) +function n_count(rng,λ,n) tmp = 0 for i in 1:n - tmp += PoissonRandom.count_rand(λ,rng) + tmp += PoissonRandom.count_rand(rng,λ) end end -function n_pois(λ,rng,n) +function n_pois(rng,λ,n) tmp = 0 for i in 1:n - tmp += pois_rand(λ,rng) + tmp += pois_rand(rng,λ) end end -function n_ad(λ,rng,n) +function n_ad(rng,λ,n) tmp = 0 for i in 1:n - tmp += PoissonRandom.ad_rand(λ,rng) + tmp += PoissonRandom.ad_rand(rng,λ) end end -function n_dist(λ,rng,n) +function n_dist(λ,n) tmp = 0 for i in 1:n tmp += rand(Poisson(λ)) end end -function time_λ(λ,rng,n) - t1 = @elapsed n_count(λ,rng,n) - t2 = @elapsed n_ad(λ,rng,n) - t3 = @elapsed n_pois(λ,rng,n) - t4 = @elapsed n_dist(λ,rng,n) +function time_λ(rng,λ,n) + t1 = @elapsed n_count(rng,λ,n) + t2 = @elapsed n_ad(rng,λ,n) + t3 = @elapsed n_pois(rng,λ,n) + t4 = @elapsed n_dist(λ,n) @SArray [t1,t2,t3,t4] end # Compile -time_λ(5,rng,5000000) +time_λ(rng,5,5000000) # Run with a bunch of λ -times = VectorOfArray([time_λ(n,rng,5000000) for n in 1:20])' +times = VectorOfArray([time_λ(rng,n,5000000) for n in 1:20])' plot(times,labels = labels, lw = 3) ``` diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index afd6c3d..d296060 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -4,7 +4,8 @@ using Random export pois_rand -function count_rand(λ,rng::AbstractRNG=Random.GLOBAL_RNG) +count_rand(λ) = count_rand(Random.GLOBAL_RNG, λ) +function count_rand(rng::AbstractRNG, λ) n = 0 c = randexp(rng) while c < λ @@ -22,7 +23,8 @@ end # # For μ sufficiently large, (i.e. >= 10.0) # -function ad_rand(λ,rng::AbstractRNG=Random.GLOBAL_RNG) +ad_rand(λ) = ad_rand(Random.GLOBAL_RNG, λ) +function ad_rand(rng::AbstractRNG, λ) s = sqrt(λ) d = 6.0*λ^2 L = floor(Int,λ-1.1484) @@ -139,12 +141,7 @@ function procf(λ, K::Int, s::Float64) return px,py,fx,fy end -function pois_rand(λ,rng::AbstractRNG=Random.GLOBAL_RNG) - if λ < 6 - return count_rand(λ,rng) - else - return ad_rand(λ,rng) - end -end +pois_rand(λ) = pois_rand(Random.GLOBAL_RNG, λ) +pois_rand(rng::AbstractRNG, λ) = λ < 6 ? count_rand(rng, λ) : ad_rand(rng, λ) end # module