diff --git a/test/runtests.jl b/test/runtests.jl index 806fca1..91fd97d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,9 +29,9 @@ include("updater.jl") include("tree.jl") -@testset "Tiger POMDP" begin +@testset "Tiger POMDP (no binning)" begin pomdp = TigerPOMDP() - solver = SARSOPSolver(epsilon=0.5, precision=1e-3, verbose=false) + solver = SARSOPSolver(epsilon=0.5, precision=1e-3, verbose=false, use_binning=false) tree = SARSOPTree(pomdp) Γ = solve(solver, pomdp) iterations = 0 @@ -50,9 +50,9 @@ include("tree.jl") @test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.01 end -@testset "Baby POMDP" begin +@testset "Baby POMDP (no binning)" begin pomdp = BabyPOMDP() - solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-3, verbose=false) + solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-3, verbose=false, use_binning=false) tree = SARSOPTree(pomdp) Γ = solve(solver, pomdp) iterations = 0 @@ -71,9 +71,88 @@ end @test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.01 end -@testset "RockSample POMDP" begin +@testset "RockSample POMDP (no binning)" begin pomdp = RockSamplePOMDP() - solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-2, verbose=false) + solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-2, verbose=false, use_binning=false) + tree = SARSOPTree(pomdp) + Γ = solve(solver, pomdp) + iterations = 0 + while JSOP.root_diff(tree) > solver.precision + iterations += 1 + JSOP.sample!(solver, tree) + JSOP.backup!(tree) + JSOP.prune!(solver, tree) + end + # @test isapprox(tree.V_lower[1], -16.3; atol=1e-2) + @test JSOP.root_diff(tree) < solver.precision + + solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor=0.5, precision=1e-2, verbose=false) + policyCPP = solve(solverCPP, pomdp) + @test abs(value(policyCPP, initialstate(pomdp)) - tree.V_lower[1]) < 0.1 + @test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.1 +end + +@testset "Binning" begin + pomdp = BabyPOMDP() + + solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-8, max_time=3.0, verbose=false, use_binning=false) + Γ1 = solve(solver, pomdp) + + solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-8, max_time=3.0, verbose=false, use_binning=true) + Γ2 = solve(solver, pomdp) + + @test abs(value(Γ1, initialstate(pomdp)) - value(Γ2, initialstate(pomdp))) < 1e-7 + + Γ, info = solve_info(solver, pomdp) + @test !isempty(info.tree.bm.bin_levels[1][:bin_count]) + @test length(info.tree.bm.bin_levels) == 2 +end + +@testset "Tiger POMDP (with binning)" begin + pomdp = TigerPOMDP() + solver = SARSOPSolver(epsilon=0.5, precision=1e-3, verbose=false, use_binning=true) + tree = SARSOPTree(pomdp) + Γ = solve(solver, pomdp) + iterations = 0 + while JSOP.root_diff(tree) > solver.precision + iterations += 1 + JSOP.sample!(solver, tree) + JSOP.backup!(tree) + JSOP.prune!(solver, tree) + end + @test isapprox(tree.V_lower[1], 19.37; atol=1e-1) + @test JSOP.root_diff(tree) < solver.precision + + solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor=0.5, precision=1e-3, verbose=false) + policyCPP = solve(solverCPP, pomdp) + @test abs(value(policyCPP, initialstate(pomdp)) - tree.V_lower[1]) < 0.01 + @test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.01 +end + +@testset "Baby POMDP (with binning)" begin + pomdp = BabyPOMDP() + solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-3, verbose=false, use_binning=true) + tree = SARSOPTree(pomdp) + Γ = solve(solver, pomdp) + iterations = 0 + while JSOP.root_diff(tree) > solver.precision + iterations += 1 + JSOP.sample!(solver, tree) + JSOP.backup!(tree) + JSOP.prune!(solver, tree) + end + @test isapprox(tree.V_lower[1], -16.3; atol=1e-2) + @test JSOP.root_diff(tree) < solver.precision + + solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor=0.5, precision=1e-3, verbose=false) + policyCPP = solve(solverCPP, pomdp) + @test abs(value(policyCPP, initialstate(pomdp)) - tree.V_lower[1]) < 0.01 + @test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.01 +end + +@testset "RockSample POMDP (with binning)" begin + pomdp = RockSamplePOMDP() + solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-2, verbose=false, use_binning=true) tree = SARSOPTree(pomdp) Γ = solve(solver, pomdp) iterations = 0