Skip to content

Commit

Permalink
Added simulations to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Omastto1 committed Apr 11, 2021
1 parent d6ab696 commit 650fba5
Showing 1 changed file with 50 additions and 18 deletions.
68 changes: 50 additions & 18 deletions test/runtests.jl
Expand Up @@ -3,6 +3,8 @@ using POMDPModels
using POMDPs
using SARSOP
using BeliefUpdaters
using POMDPModelTools: Deterministic
using POMDPSimulators: RolloutSimulator

using PointBasedValueIteration

Expand All @@ -16,26 +18,56 @@ using PointBasedValueIteration
sarsop = SARSOPSolver(verbose=false)
sarsop_policy = solve(sarsop, pomdp)

B = []
if typeof(pomdp) == MiniHallway
B = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], [0.083337, 0.083333, 0.083333, 0.083333, 0.083333, 0.083333, 0.083333, 0.083333, 0.083333, 0.083333, 0.083333, 0.083333, 0.0],
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
else
for _ in 1:100
r = rand(length(states(pomdp)))
push!(B, DiscreteBelief(pomdp, r/sum(r)))
@testset "$(typeof(pomdp)) Value function comparison" begin
B = []
if typeof(pomdp) == MiniHallway
B = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], [0.083337, 0.083333, 0.083333, 0.083333, 0.083333, 0.083333, 0.083333, 0.083333, 0.083333, 0.083333, 0.083333, 0.083333, 0.0],
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
else
for _ in 1:100
r = rand(length(states(pomdp)))
push!(B, DiscreteBelief(pomdp, r/sum(r)))
end
end

pbvi_vals = [value(policy, b) for b in B]
sarsop_vals = [value(sarsop_policy, b) for b in B]
@test isapprox(sarsop_vals, pbvi_vals, rtol=0.1)
end

pbvi_vals = [value(policy, b) for b in B]
sarsop_vals = [value(sarsop_policy, b) for b in B]
@test isapprox(sarsop_vals, pbvi_vals, rtol=0.1)
@testset "$(typeof(pomdp)) Simulation results comparison" begin
no_simulations = typeof(pomdp) == MiniHallway ? 1 : 10_000
for s in states(pomdp)
# println(s)
# @show value(policy, Deterministic(s))
# @show value(sarsop_policy, Deterministic(s))
#
# @show action(policy, Deterministic(s))
# @show action(sarsop_policy, Deterministic(s))
#
# @show mean([simulate(RolloutSimulator(max_steps = 100), pomdp, policy, updater(policy), Deterministic(s)) for i in 1:no_simulations])
# @show mean([simulate(RolloutSimulator(max_steps = 100), pomdp, sarsop_policy, updater(sarsop_policy), Deterministic(s)) for i in 1:no_simulations])

if s == 5 && typeof(pomdp) == MiniHallway
@test_broken isapprox(value(policy, Deterministic(s)), value(sarsop_policy, Deterministic(s)), rtol=0.1)
@test_broken isapprox( mean([simulate(RolloutSimulator(max_steps = 100), pomdp, policy, updater(policy), Deterministic(s)) for i in 1:no_simulations]),
mean([simulate(RolloutSimulator(max_steps = 100), pomdp, sarsop_policy, updater(sarsop_policy), Deterministic(s)) for i in 1:no_simulations]),
rtol=0.1)
else
@test isapprox(value(policy, Deterministic(s)), value(sarsop_policy, Deterministic(s)), rtol=0.1)
@test isapprox( mean([simulate(RolloutSimulator(max_steps = 100), pomdp, policy, updater(policy), Deterministic(s)) for i in 1:no_simulations]),
mean([simulate(RolloutSimulator(max_steps = 100), pomdp, sarsop_policy, updater(sarsop_policy), Deterministic(s)) for i in 1:no_simulations]),
rtol=0.1)
end
end
# end
end
end
end

0 comments on commit 650fba5

Please sign in to comment.