Skip to content

Commit

Permalink
Merge d8dcb35 into 35829dd
Browse files Browse the repository at this point in the history
  • Loading branch information
findmyway committed Jan 13, 2019
2 parents 35829dd + d8dcb35 commit dea631f
Show file tree
Hide file tree
Showing 12 changed files with 645 additions and 27 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -4,4 +4,6 @@

### Added

- Add `SumTree` to efficiently sample priorities. ([#9](https://github.com/Ju-jl/Ju.jl/pull/9))
- Add `batch_sample` for `CircularSARDBuffer` to sample a random batch of turn infos. ([#9](https://github.com/Ju-jl/Ju.jl/pull/9))
- Add a new environment named `CartPoleEnv` for future experiments. ([#7](https://github.com/Ju-jl/Ju.jl/pull/7))
1 change: 1 addition & 0 deletions benchmarks/Project.toml
@@ -1,3 +1,4 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Ju = "449ae9ca-b987-11e8-3919-0764a06dfe61"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
107 changes: 107 additions & 0 deletions benchmarks/circular_array_buffer.jl
@@ -0,0 +1,107 @@
using Ju
using BenchmarkTools
using StatsBase:sample

buffer = CircularArrayBuffer{Array{Float64, 3}}(10, (80, 80, 1));

println("\n", repeat('=', 50))
println("\npush! into CircularArrayBuffer\n")
display(@benchmark push!($buffer, $(rand(80, 80, 1))))

println("\n", repeat('=', 50))
println("\ngetindex of CircularArrayBuffer\n")
display(@benchmark $buffer[$(rand(1:length(buffer)))])

println("\n", repeat('=', 50))
println("\nview element of CircularArrayBuffer\n")
display(@benchmark view($buffer, $(rand(1:length(buffer)))))

n = 16
println("\n", repeat('=', 50))
println("\nview $n elements of CircularArrayBuffer\n")
display(@benchmark view($buffer, $(sample(1:length(buffer), n))))

n = 32
println("\n", repeat('=', 50))
println("\nview $n elements of CircularArrayBuffer\n")
display(@benchmark view($buffer, $(sample(1:length(buffer), n))))

# ==================================================

# push! into CircularArrayBuffer

# BenchmarkTools.Trial:
# memory estimate: 0 bytes
# allocs estimate: 0
# --------------
# minimum time: 2.411 μs (0.00% GC)
# median time: 2.456 μs (0.00% GC)
# mean time: 2.464 μs (0.00% GC)
# maximum time: 5.133 μs (0.00% GC)
# --------------
# samples: 10000
# evals/sample: 9

# ==================================================

# getindex of CircularArrayBuffer

# BenchmarkTools.Trial:
# memory estimate: 50.08 KiB
# allocs estimate: 2
# --------------
# minimum time: 5.540 μs (0.00% GC)
# median time: 15.940 μs (0.00% GC)
# mean time: 16.852 μs (27.09% GC)
# maximum time: 8.680 ms (99.80% GC)
# --------------
# samples: 10000
# evals/sample: 5

# ==================================================

# view element of CircularArrayBuffer

# BenchmarkTools.Trial:
# memory estimate: 64 bytes
# allocs estimate: 1
# --------------
# minimum time: 12.913 ns (0.00% GC)
# median time: 14.916 ns (0.00% GC)
# mean time: 23.723 ns (30.27% GC)
# maximum time: 45.953 μs (99.93% GC)
# --------------
# samples: 10000
# evals/sample: 999

# ==================================================

# view 16 elements of CircularArrayBuffer

# BenchmarkTools.Trial:
# memory estimate: 400 bytes
# allocs estimate: 6
# --------------
# minimum time: 96.007 ns (0.00% GC)
# median time: 101.996 ns (0.00% GC)
# mean time: 130.636 ns (17.69% GC)
# maximum time: 45.819 μs (99.68% GC)
# --------------
# samples: 10000
# evals/sample: 952

# ==================================================

# view 32 elements of CircularArrayBuffer

# BenchmarkTools.Trial:
# memory estimate: 528 bytes
# allocs estimate: 6
# --------------
# minimum time: 109.849 ns (0.00% GC)
# median time: 123.340 ns (0.00% GC)
# mean time: 166.186 ns (18.24% GC)
# maximum time: 51.037 μs (99.33% GC)
# --------------
# samples: 10000
# evals/sample: 934
119 changes: 119 additions & 0 deletions benchmarks/circular_turn_buffer.jl
@@ -0,0 +1,119 @@
using Ju
using BenchmarkTools

state_size = (80, 80, 1)
actions = 1:10

buffer = CircularSARDBuffer(10^5; state_type=Array{Float64, 3}, state_size=state_size)

push!(buffer, rand(state_size...), rand(actions))

println("\n", repeat('=', 50))
println("\n push! buffer\n")
display(@benchmark push!($buffer, 1.0, false, $(rand(state_size...)), $(rand(actions))))

batch_size = 32
println("\n", repeat('=', 50))
println("\n batch_sample buffer (batch_size=$batch_size)\n")
display(@benchmark batch_sample($buffer, $batch_size))

batch_size = 64
println("\n", repeat('=', 50))
println("\n batch_sample buffer (batch_size=$batch_size)\n")
display(@benchmark batch_sample($buffer, $batch_size))

function sample_N_batches(buffer, N, batch_size)
for _ in 1:N
batch_sample(buffer, batch_size)
end
end

N, batch_size = 10, 32
println("\n", repeat('=', 50))
println("\n batch_sample buffer (batch_size=$batch_size) $N times\n")
display(@benchmark sample_N_batches($buffer, $N, $batch_size))

N, batch_size = 10, 64
println("\n", repeat('=', 50))
println("\n batch_sample buffer (batch_size=$batch_size) $N times\n")
display(@benchmark sample_N_batches($buffer, $N, $batch_size))

# ==================================================

# push! buffer

# BenchmarkTools.Trial:
# memory estimate: 0 bytes
# allocs estimate: 0
# --------------
# minimum time: 4.657 μs (0.00% GC)
# median time: 5.114 μs (0.00% GC)
# mean time: 5.138 μs (0.00% GC)
# maximum time: 9.500 μs (0.00% GC)
# --------------
# samples: 10000
# evals/sample: 7

# ==================================================

# batch_sample buffer (batch_size=32)

# BenchmarkTools.Trial:
# memory estimate: 3.13 KiB
# allocs estimate: 34
# --------------
# minimum time: 958.163 ns (0.00% GC)
# median time: 1.081 μs (0.00% GC)
# mean time: 1.239 μs (11.31% GC)
# maximum time: 25.488 μs (94.94% GC)
# --------------
# samples: 10000
# evals/sample: 43

# ==================================================

# batch_sample buffer (batch_size=64)

# BenchmarkTools.Trial:
# memory estimate: 5.09 KiB
# allocs estimate: 34
# --------------
# minimum time: 1.270 μs (0.00% GC)
# median time: 1.460 μs (0.00% GC)
# mean time: 1.709 μs (9.25% GC)
# maximum time: 80.030 μs (95.69% GC)
# --------------
# samples: 10000
# evals/sample: 10

# ==================================================

# batch_sample buffer (batch_size=32) 10 times

# BenchmarkTools.Trial:
# memory estimate: 30.47 KiB
# allocs estimate: 320
# --------------
# minimum time: 8.699 μs (0.00% GC)
# median time: 10.000 μs (0.00% GC)
# mean time: 13.832 μs (9.60% GC)
# maximum time: 1.497 ms (98.13% GC)
# --------------
# samples: 10000
# evals/sample: 1

# ==================================================

# batch_sample buffer (batch_size=64) 10 times

# BenchmarkTools.Trial:
# memory estimate: 50.16 KiB
# allocs estimate: 320
# --------------
# minimum time: 12.599 μs (0.00% GC)
# median time: 14.301 μs (0.00% GC)
# mean time: 16.655 μs (9.19% GC)
# maximum time: 741.201 μs (96.80% GC)
# --------------
# samples: 10000
# evals/sample: 1
68 changes: 68 additions & 0 deletions benchmarks/sum_tree.jl
@@ -0,0 +1,68 @@
using Ju
using BenchmarkTools

t = SumTree(10^5)

println("\n", repeat('=', 50))
println("\n push! priority into sum tree \n")
display(@benchmark push!($t, $(rand())))

batch_size = 32

println("\n", repeat('=', 50))
println("\n sample $batch_size \n")
display(@benchmark sample($t, $batch_size))

batch_size = 64

println("\n", repeat('=', 50))
println("\n sample $batch_size \n")
display(@benchmark sample($t, $batch_size))

# ==================================================

# push! priority into sum tree

# BenchmarkTools.Trial:
# memory estimate: 0 bytes
# allocs estimate: 0
# --------------
# minimum time: 29.025 ns (0.00% GC)
# median time: 29.199 ns (0.00% GC)
# mean time: 31.569 ns (0.00% GC)
# maximum time: 86.101 ns (0.00% GC)
# --------------
# samples: 10000
# evals/sample: 995

# ==================================================

# sample 32

# BenchmarkTools.Trial:
# memory estimate: 704 bytes
# allocs estimate: 3
# --------------
# minimum time: 4.863 μs (0.00% GC)
# median time: 5.300 μs (0.00% GC)
# mean time: 6.415 μs (14.46% GC)
# maximum time: 7.253 ms (99.87% GC)
# --------------
# samples: 10000
# evals/sample: 7

# ==================================================

# sample 64

# BenchmarkTools.Trial:
# memory estimate: 1.25 KiB
# allocs estimate: 3
# --------------
# minimum time: 9.076 μs (0.00% GC)
# median time: 10.344 μs (0.00% GC)
# mean time: 10.990 μs (0.00% GC)
# maximum time: 46.137 μs (0.00% GC)
# --------------
# samples: 10000
# evals/sample: 1
1 change: 1 addition & 0 deletions docs/src/utilities.md
Expand Up @@ -52,6 +52,7 @@ Pages = ["helper_functions.jl"]
## Others

```@docs
batch_sample
CircularArrayBuffer
Tiling
```
1 change: 0 additions & 1 deletion src/buffers/abstract_turn_buffer.jl
Expand Up @@ -38,7 +38,6 @@ const SARDSABuffer = AbstractTurnBuffer{SARDSA}

function isfull end
function capacity end

buffers(b::AbstractTurnBuffer) = getfield(b, :buffers)

size(b::AbstractTurnBuffer) = (length(b),)
Expand Down
4 changes: 3 additions & 1 deletion src/buffers/buffers.jl
@@ -1,8 +1,10 @@
export CircularArrayBuffer,
export CircularArrayBuffer, batch_sample,
SumTree,
CircularTurnBuffer, CircularSARDBuffer, CircularSARDSBuffer, CircularSARDSABuffer,
EpisodeTurnBuffer, EpisodeSARDBuffer, EpisodeSARDSBuffer, EpisodeSARDSABuffer

include("circular_array_buffer.jl")
include("sum_tree.jl")

include("circular_turn_buffer.jl")
include("episode_turn_buffer.jl")
13 changes: 11 additions & 2 deletions src/buffers/circular_array_buffer.jl
Expand Up @@ -54,9 +54,18 @@ size(cb::CircularArrayBuffer{E, T, N}) where {E, T, N} = (size(cb.buffer)[1:N-1]

for func in [:view, :getindex]
@eval @__MODULE__() begin
$func(cb::CircularArrayBuffer{E, T, 1}, i::Int) where {E, T} = $func(cb.buffer, _buffer_index(cb, i))
$func(cb::CircularArrayBuffer{E, T, 2}, i::Int) where {E, T} = $func(cb.buffer, :, _buffer_index(cb, i))
$func(cb::CircularArrayBuffer{E, T, 3}, i::Int) where {E, T} = $func(cb.buffer, :, :, _buffer_index(cb, i))
$func(cb::CircularArrayBuffer{E, T, 4}, i::Int) where {E, T} = $func(cb.buffer, :, :, :, _buffer_index(cb, i))
$func(cb::CircularArrayBuffer{E, T, N}, i::Int) where {E, T, N} = $func(cb.buffer, [(:) for _ in 1 : N-1]..., _buffer_index(cb, i))
$func(cb::CircularArrayBuffer{E, T, N}, i::UnitRange{Int}) where {E, T, N} = $func(cb.buffer, [(:) for _ in 1 : N-1]..., _buffer_index(cb, i))
$func(cb::CircularArrayBuffer{E, T, N}, I::Vector{Int}) where {E, T, N} = $func(cb.buffer, [(:) for _ in 1 : N-1]..., [_buffer_index(cb, i) for i in I])

$func(cb::CircularArrayBuffer{E, T, 1}, I::Vector{Int}) where {E, T} = $func(cb.buffer, map(i -> _buffer_index(cb, i), I))
$func(cb::CircularArrayBuffer{E, T, 2}, I::Vector{Int}) where {E, T} = $func(cb.buffer, :, map(i -> _buffer_index(cb, i), I))
$func(cb::CircularArrayBuffer{E, T, 3}, I::Vector{Int}) where {E, T} = $func(cb.buffer, :, :, map(i -> _buffer_index(cb, i), I))
$func(cb::CircularArrayBuffer{E, T, 4}, I::Vector{Int}) where {E, T} = $func(cb.buffer, :, :, :, map(i -> _buffer_index(cb, i), I))
$func(cb::CircularArrayBuffer{E, T, N}, I::Vector{Int}) where {E, T, N} = $func(cb.buffer, [(:) for _ in 1 : N-1]..., map(i -> _buffer_index(cb, i), I))
# $func(cb::CircularArrayBuffer{E, T, N}, i::UnitRange{Int}) where {E, T, N} = $func(cb.buffer, [(:) for _ in 1 : N-1]..., _buffer_index(cb, i)) # TODO: seems not useful?
end
end

Expand Down

0 comments on commit dea631f

Please sign in to comment.