/
solver.jl
97 lines (64 loc) · 2.36 KB
/
solver.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# A test for solvers
# Maintained by @zsunberg
# Minimal simulator used by the solver smoke tests in this file.
# The `simulate` methods below read both fields: `rng` is passed to `rand` and to the
# `@gen` generative-model calls, and `max_steps` bounds the number of loop iterations.
mutable struct TestSimulator
# Random number generator handed to sampling calls during simulation.
rng::AbstractRNG
# Upper limit on the number of simulation steps taken by `simulate`.
max_steps::Int
end
"""
    simulate(sim::TestSimulator, pomdp::POMDP, policy::Policy, updater::Updater, initial_distribution)

Roll out `policy` on `pomdp` and return the accumulated discounted reward.

The starting state is sampled from `initial_distribution` with `sim.rng`, and the starting
belief is built from the same distribution with `initialize_belief`. The rollout stops when
the current state is terminal or after `sim.max_steps` steps, whichever comes first.
"""
function simulate(
    sim::TestSimulator,
    pomdp::POMDP,
    policy::Policy,
    updater::Updater,
    initial_distribution::Any,
)
    state = rand(sim.rng, initial_distribution)
    belief = initialize_belief(updater, initial_distribution)
    weight = 1.0        # discount weight applied to the current step's reward
    total = 0.0         # running discounted return
    t = 1
    while true
        # TODO also check for terminal observation
        if isterminal(pomdp, state) || t > sim.max_steps
            break
        end
        act = action(policy, belief)
        next_state, obs, rew = @gen(:sp, :o, :r)(pomdp, state, act, sim.rng)
        total += weight * rew
        belief = update(updater, belief, act, obs)
        weight *= discount(pomdp)
        state = next_state
        t += 1
    end
    return total
end
"""
    simulate(sim::TestSimulator, mdp::MDP, policy::Policy, s)

Roll out `policy` on `mdp` starting from state `s` and return the accumulated discounted
reward.

The rollout stops when the current state is terminal or after `sim.max_steps` steps,
whichever comes first. Transitions and rewards are sampled with `sim.rng`.
"""
function simulate(sim::TestSimulator, mdp::MDP, policy::Policy, s)
    weight = 1.0        # discount weight applied to the current step's reward
    total = 0.0         # running discounted return
    t = 1
    while true
        # TODO also check for terminal observation
        if isterminal(mdp, s) || t > sim.max_steps
            break
        end
        act = action(policy, s)
        next_state, rew = @gen(:sp, :r)(mdp, s, act, sim.rng)
        total += weight * rew
        weight *= discount(mdp)
        s = next_state
        t += 1
    end
    return total
end
"""
    test_solver(solver::Solver, problem::POMDP; max_steps=10, updater=nothing)
    test_solver(solver::Solver, problem::MDP; max_steps=10)

Use the solver to solve the specified problem, then run a simulation.
This is designed to illustrate how solvers are expected to function. All solvers should be able to complete this standard test with the simple models in the POMDPModels package.
Note that this does NOT test the optimality of the solution, but is only a smoke test to see if the solver interacts with POMDP models as expected.
To run this with a solver called YourSolver, run
```
using POMDPToolbox
using POMDPModels
solver = YourSolver(# initialize with parameters #)
test_solver(solver, BabyPOMDP())
```
"""
function test_solver(solver::Solver, problem::POMDP; max_steps=10, updater=nothing)
    # Smoke test for POMDP solvers: solve `problem`, then simulate the resulting policy.
    #
    # Keywords:
    #   max_steps - step limit given to the `TestSimulator`.
    #   updater   - belief updater to use; when `nothing`, `POMDPs.updater(policy)` is used.
    #
    # Returns the value produced by `simulate` for the constructed `TestSimulator`.
    policy = solve(solver, problem)
    # Use the identity check `===`: `==` may be overloaded by a user-supplied updater type
    # and is not guaranteed to return a `Bool` when compared with `nothing`.
    if updater === nothing
        # Qualified with `POMDPs.` because the keyword argument shadows the function name.
        updater = POMDPs.updater(policy)
    end
    # Fixed seed so repeated runs use the same random stream.
    sim = TestSimulator(MersenneTwister(1), max_steps)
    return simulate(sim, problem, policy, updater, initialstate(problem))
end
# MDP variant of the solver smoke test: solve `problem`, then simulate the resulting policy
# for at most `max_steps` steps and return the value produced by `simulate`.
function test_solver(solver::Solver, problem::MDP; max_steps=10)
    policy = solve(solver, problem)
    # Fixed seed for the simulation's random stream.
    sim = TestSimulator(MersenneTwister(1), max_steps)
    # The initial state is drawn with its own fixed-seed generator, separate from `sim.rng`.
    start_state = rand(MersenneTwister(0), initialstate(problem))
    return simulate(sim, problem, policy, start_state)
end