Merge pull request #18 from JuliaPOMDP/policy-evaluation
Policy evaluation
zsunberg committed Jun 4, 2019
2 parents 3f756c4 + 368bf4a commit 715451c
Showing 8 changed files with 112 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -4,7 +4,7 @@
[![Coverage Status](https://coveralls.io/repos/github/JuliaPOMDP/POMDPModelTools.jl/badge.svg?branch=master)](https://coveralls.io/github/JuliaPOMDP/POMDPModelTools.jl?branch=master)
[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://JuliaPOMDP.github.io/POMDPModelTools.jl/latest)

Support tools for writing [POMDPs.jl](github.com/JuliaPOMDP/POMDPs.jl) models and solvers.
Support tools for writing and working with [POMDPs.jl](github.com/JuliaPOMDP/POMDPs.jl) models and solvers.

Please read the documentation here for a list of tools: [https://JuliaPOMDP.github.io/POMDPModelTools.jl/latest](https://JuliaPOMDP.github.io/POMDPModelTools.jl/latest)

7 changes: 7 additions & 0 deletions docs/src/policy_evaluation.md
@@ -0,0 +1,7 @@
# Policy Evaluation

The [`evaluate`](@ref) function provides a policy evaluation tool for MDPs:

```@docs
evaluate
```
7 changes: 6 additions & 1 deletion src/POMDPModelTools.jl
@@ -2,12 +2,13 @@ module POMDPModelTools

using POMDPs
using Random
using LinearAlgebra
using SparseArrays

import POMDPs: actions, n_actions, actionindex
import POMDPs: states, n_states, stateindex
import POMDPs: observations, n_observations, obsindex
import POMDPs: sampletype, generate_sr, initialstate, isterminal, discount
# import POMDPs: Updater, update, initialize_belief, pdf, mode, updater
import POMDPs: implemented
import Distributions: pdf, mode, mean, support
import Random: rand, rand!
@@ -74,4 +75,8 @@ include("distributions/deterministic.jl")
# convenient implementations
include("convenient_implementations.jl")

export
    evaluate
include("policy_evaluation.jl")

end # module
83 changes: 83 additions & 0 deletions src/policy_evaluation.jl
@@ -0,0 +1,83 @@
"""
Value function for a policy on an MDP.
If `v` is a `DiscreteValueFunction`, access the value for a state with `v(s)`
"""
struct DiscreteValueFunction{M<:MDP} <: Function
    m::M
    values::Vector{Float64}
end

(v::DiscreteValueFunction)(s) = v.values[stateindex(v.m, s)]

"""
evaluate(m::MDP, p::Policy)
evaluate(m::MDP, p::Policy; rewardfunction=POMDPs.reward)
Calculate the value for a policy on an MDP using the approach in equation 4.2.2 of Kochenderfer, *Decision Making Under Uncertainty*, 2015.
Returns a DiscreteValueFunction, which maps states to values.
# Example
```
using POMDPModelTools, POMDPPolicies, POMDPModels
m = SimpleGridWorld()
u = evaluate(m, FunctionPolicy(x->:left))
u([1,1]) # value of always moving left starting at state [1,1]
```
"""
function evaluate(m::MDP, p::Policy; rewardfunction=POMDPs.reward)
    t = policy_transition_matrix(m, p)
    r = policy_reward_vector(m, p, rewardfunction=rewardfunction)
    u = (I - discount(m)*t) \ r
    return DiscreteValueFunction(m, u)
end

"""
policy_transition_matrix(m::Union{MDP, POMDP}, p::Policy)
Create an |S|x|S| sparse transition matrix for a given policy.
The row corresponds to the current state and column to the next state. Corresponds to ``T^π`` in equation (4.7) in Kochenderfer, *Decision Making Under Uncertainty*, 2015.
"""
function policy_transition_matrix(m::Union{MDP,POMDP}, p::Policy)
    ns = n_states(m)
    rows = Int[]
    cols = Int[]
    probs = Float64[]

    for s in states(m)
        if !isterminal(m, s) # if terminal, the transition probabilities are all just zero
            si = stateindex(m, s)
            a = action(p, s)
            td = transition(m, s, a)
            for (sp, p) in weighted_iterator(td)
                if p > 0.0
                    spi = stateindex(m, sp)
                    push!(rows, si)
                    push!(cols, spi)
                    push!(probs, p)
                end
            end
        end
    end

    return sparse(rows, cols, probs, ns, ns)
end
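# Illustrative sanity check (not part of the committed file): every row of the matrix
# corresponding to a non-terminal state should sum to 1, and rows for terminal states
# are all zeros by construction, e.g.
#   T = policy_transition_matrix(m, p); all(x -> x ≈ 1.0 || x ≈ 0.0, sum(T, dims=2))
# Also note that the inner loop variable `p` shadows the `Policy` argument; this is
# harmless because the policy is only queried via `action(p, s)` before that loop.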

function policy_reward_vector(m::Union{MDP,POMDP}, p::Policy; rewardfunction=POMDPs.reward)
    r = zeros(n_states(m))
    for s in states(m)
        if !isterminal(m, s) # if terminal, the expected reward is just zero
            si = stateindex(m, s)
            a = action(p, s)
            td = transition(m, s, a)
            for (sp, p) in weighted_iterator(td)
                if p > 0.0
                    r[si] += p*rewardfunction(m, s, a, sp)
                end
            end
        end
    end
    return r
end
1 change: 1 addition & 0 deletions test/REQUIRE
@@ -1,2 +1,3 @@
POMDPModels
POMDPSimulators
POMDPPolicies
5 changes: 5 additions & 0 deletions test/runtests.jl
@@ -5,6 +5,7 @@ using Random
using Test
using Pkg
using POMDPSimulators
using POMDPPolicies

@testset "ordered" begin
include("test_ordered_spaces.jl")
@@ -53,3 +54,7 @@ end
@testset "vis" begin
include("test_visualization.jl")
end

@testset "evaluation" begin
include("test_evaluation.jl")
end
8 changes: 8 additions & 0 deletions test/test_evaluation.jl
@@ -0,0 +1,8 @@
let
    m = SimpleGridWorld(rewards=Dict([3,1]=>10.0), tprob=1.0)
    u = evaluate(m, FunctionPolicy(x->:right))
    @test u([2,1]) == 9.5
    @test u([3,1]) == 10.0
    @test u([4,1]) == 0.0
    @test u([3,2]) == 0.0
end
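For reference (not part of the diff): assuming `SimpleGridWorld`'s default discount factor of 0.95 and that the reward is collected when acting in the reward cell, always moving right from `[2,1]` reaches `[3,1]` in one deterministic step (`tprob=1.0`), so its value is 0.95 × 10 = 9.5; `[3,1]` itself is worth the full 10, while `[4,1]` and `[3,2]` can never reach the reward cell by moving right, so their values are 0.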
2 changes: 1 addition & 1 deletion test/test_info.jl
@@ -22,7 +22,7 @@ POMDPs.solve(solver::RandomSolver, problem::P) where {P<:Union{POMDP,MDP}} = Ran
let
rng = MersenneTwister(7)

mdp = GridWorld()
mdp = LegacyGridWorld()
s = initialstate(mdp, rng)
a = rand(rng, actions(mdp))
@inferred generate_sri(mdp, s, a, rng)