Skip to content

Commit

Permalink
added convert tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Omastto1 committed Apr 28, 2021
1 parent a45d908 commit 467ab05
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 5 deletions.
5 changes: 3 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "PointBasedValueIteration"
uuid = "835c131e-675f-4498-8e2c-c054c75556e1"
authors = ["Dominik Straub <straub@psychologie.tu-darmstadt.de>"]
version = "0.1.0"
authors = ["Dominik Straub <straub@psychologie.tu-darmstadt.de> and Tomáš Omasta <tomom@email.cz>"]
version = "0.2.0"

[deps]
BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4"
Expand All @@ -15,6 +15,7 @@ POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"

[compat]
BeliefUpdaters = "0.2"
FiniteHorizonPOMDPs = "0.3"
POMDPLinter = "0.1"
POMDPModelTools = "0.3.2"
POMDPPolicies = "0.3, 0.4"
Expand Down
2 changes: 1 addition & 1 deletion src/PointBasedValueIteration.jl
Expand Up @@ -10,7 +10,7 @@ using Distributions
using FiniteHorizonPOMDPs

import POMDPs: Solver, solve
import Base: ==, hash
import Base: ==, hash, convert
import FiniteHorizonPOMDPs: InStageDistribution, FixedHorizonPOMDPWrapper


Expand Down
11 changes: 9 additions & 2 deletions src/pbvi.jl
Expand Up @@ -35,12 +35,19 @@ end
==(a::AlphaVec, b::AlphaVec) = (a.alpha,a.action) == (b.alpha, b.action)
Base.hash(a::AlphaVec, h::UInt) = hash(a.alpha, hash(a.action, h))

convert(::Type{Array{Float64, 1}}, d::BoolDistribution, pomdp) = [d.p, 1 - d.p]
convert(::Type{Array{Float64, 1}}, d::BoolDistribution, pomdp) = [1 - d.p, d.p]
convert(::Type{Array{Float64, 1}}, d::DiscreteUniform, pomdp) = [pdf(d, stateindex(pomdp, s)) for s in states(pomdp)]
convert(::Type{Array{Float64, 1}}, d::SparseCat, pomdp) = d.probs

convert(::Type{Array{Float64, 1}}, d::InStageDistribution{DiscreteUniform}, m::FixedHorizonPOMDPWrapper) = vec([pdf(d, s) for s in states(m)])
convert(::Type{Array{Float64, 1}}, d::InStageDistribution{BoolDistribution}, m::FixedHorizonPOMDPWrapper) = [[d.d.p[1], 1 - d.d.p[1]]..., zeros(length(states(m)) - 2)...]

function convert(::Type{Array{Float64, 1}}, d::InStageDistribution{BoolDistribution}, m::FixedHorizonPOMDPWrapper)
if stage(d) == 1
append!([1 - d.d.p[1], d.d.p[1]], zeros(length(states(m)) - 2))
else
append!(append!(zeros((stage(d) - 1) * length(stage_states(m, 1))), [1 - d.d.p[1], d.d.p[1]]), zeros((horizon(m) - stage(d) + 1) * length(stage_states(m, 1))))
end
end


function _argmax(f, X)
Expand Down
39 changes: 39 additions & 0 deletions test/runtests.jl
Expand Up @@ -5,9 +5,48 @@ using SARSOP
using BeliefUpdaters
using POMDPModelTools: Deterministic
using POMDPSimulators: RolloutSimulator
using FiniteHorizonPOMDPs

using PointBasedValueIteration

@testset "Convert test" begin
@testset "Infinite Horizon POMDP tests" begin
tigerPOMDP = TigerPOMDP()
babyPOMDP = BabyPOMDP()
minihallwayPOMDP = MiniHallway()

@test convert(Array{Float64, 1}, initialstate(tigerPOMDP), tigerPOMDP) == [0.5, 0.5]
@test convert(Array{Float64, 1}, initialstate(babyPOMDP), babyPOMDP) == [1., 0.]
@test convert(Array{Float64, 1}, initialstate(minihallwayPOMDP), minihallwayPOMDP) == append!(fill(1/12, 12), zeros(1))
end

@testset "Finite Horizon POMDP tests" begin
@testset "Finite Horizon POMDP initial state convert tests" begin
tigerPOMDP = fixhorizon(TigerPOMDP(), 1)
babyPOMDP = fixhorizon(BabyPOMDP(), 1)
minihallwayPOMDP = fixhorizon(MiniHallway(), 1)

@test convert(Array{Float64, 1}, initialstate(tigerPOMDP), tigerPOMDP) == [0.5, 0.5, 0., 0.]
@test convert(Array{Float64, 1}, initialstate(babyPOMDP), babyPOMDP) == [1., 0., 0., 0.]
@test convert(Array{Float64, 1}, initialstate(minihallwayPOMDP), minihallwayPOMDP) == append!(fill(1/12, 12), zeros(14))
end

@testset "Finite Horizon POMDP other than initial stage distribution tests" begin
tigerPOMDP = fixhorizon(TigerPOMDP(), 2)
babyPOMDP = fixhorizon(BabyPOMDP(), 2)
minihallwayPOMDP = fixhorizon(MiniHallway(), 2)

tigerbelief = FiniteHorizonPOMDPs.InStageDistribution(FiniteHorizonPOMDPs.distribution(initialstate(tigerPOMDP)), 2)
babybelief = FiniteHorizonPOMDPs.InStageDistribution(FiniteHorizonPOMDPs.distribution(initialstate(babyPOMDP)), 2)
minihallwaybelief = FiniteHorizonPOMDPs.InStageDistribution(FiniteHorizonPOMDPs.distribution(initialstate(minihallwayPOMDP)), 2)

@test convert(Array{Float64, 1}, tigerbelief, tigerPOMDP) == [0., 0., 0.5, 0.5, 0., 0.]
@test convert(Array{Float64, 1}, babybelief, babyPOMDP) == [0., 0., 1., 0., 0., 0.]
@test convert(Array{Float64, 1}, minihallwaybelief, minihallwayPOMDP) == append!(append!(zeros(13), fill(1/12, 12)), zeros(14))
end
end
end

@testset "Comparison with SARSOP" begin
pomdps = [TigerPOMDP(), BabyPOMDP(), MiniHallway()]

Expand Down

0 comments on commit 467ab05

Please sign in to comment.