Skip to content
This repository has been archived by the owner on Aug 14, 2021. It is now read-only.

Commit

Permalink
Merge pull request #18 from JuliaPOMDP/0_6
Browse files Browse the repository at this point in the history
changes to 0.6
  • Loading branch information
ebalaban committed Jul 22, 2017
2 parents 888223e + fbe368b commit 7efbe08
Show file tree
Hide file tree
Showing 20 changed files with 276 additions and 283 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Expand Up @@ -2,7 +2,7 @@ language: julia
os:
- linux
julia:
- 0.5
- 0.6
notifications:
email: false
before_script:
Expand Down
2 changes: 1 addition & 1 deletion REQUIRE
@@ -1,4 +1,4 @@
julia 0.5
julia 0.6
POMDPs 0.4
Distributions
JSON
30 changes: 15 additions & 15 deletions src/DESPOT.jl
Expand Up @@ -13,21 +13,21 @@ import POMDPs:

include("history.jl")

abstract DESPOTBeliefUpdate{S,A,O}
abstract type DESPOTBeliefUpdate{S,A,O} end

typealias DESPOTReward Float64
const DESPOTReward = Float64

type DESPOTRandomNumber <: POMDPs.AbstractRNG
mutable struct DESPOTRandomNumber <: POMDPs.AbstractRNG
number::Float64
end

type DESPOTParticle{S}
mutable struct DESPOTParticle{S}
state::S
id::Int64
weight::Float64
end

type DESPOTBelief{S,A,O}
mutable struct DESPOTBelief{S,A,O}
particles::Vector{DESPOTParticle{S}}
history::History{A,O}
end
Expand All @@ -37,18 +37,18 @@ function rand!(rng::DESPOTRandomNumber, random_number::Array{Float64})
return nothing
end

type DESPOTDefaultRNG <: POMDPs.AbstractRNG
mutable struct DESPOTDefaultRNG <: POMDPs.AbstractRNG
seed::Array{UInt32,1}
rand_max::Int64
debug::Int64

function DESPOTDefaultRNG(seed::UInt32, rand_max::Int64, debug::Int64 = 0)
this = new()
this.seed = Cuint[seed]
this.rand_max = rand_max
this.debug = debug
return this
end
end
end

function rand!(rng::DESPOTDefaultRNG, random_number::Array{Float64})
Expand All @@ -67,7 +67,7 @@ include("nodes.jl")
include("utils.jl")
include("solver.jl")

type DESPOTPolicy{S,A,O,B} <: POMDPs.Policy
mutable struct DESPOTPolicy{S,A,O,B} <: POMDPs.Policy
solver::DESPOTSolver{S,A,O,B}
pomdp ::POMDPs.POMDP{S,A,O}
end
Expand All @@ -79,9 +79,9 @@ create_policy{S,A,O,B}(solver::DESPOTSolver{S,A,O,B}, pomdp::POMDPs.POMDP{S,A,O}
bounds{S,A,O,B}(bounds::B,
pomdp::POMDP{S,A,O},
particles::Vector{DESPOTParticle{S}},
config::DESPOTConfig) =
config::DESPOTConfig) =
error("no bounds() method found for $(typeof(bounds)) type")

init_bounds{S,A,O,B}(bounds::B,
pomdp::POMDPs.POMDP{S,A,O},
config::DESPOTConfig) = nothing
Expand All @@ -105,7 +105,7 @@ end
# for any kind of belief besides DESPOTBelief
function action{S,A,O}(p::DESPOTPolicy{S,A,O}, b)
N = p.solver.config.n_particles
pool = Array(DESPOTParticle{S}, N)
pool = Array{DESPOTParticle{S}}(N)
w = 1.0/N
for i in 1:N
pool[i] = DESPOTParticle{S}(rand(p.solver.rng, b), i-1, w)
Expand All @@ -130,10 +130,10 @@ end
@req actions(::P)
@req generate_sor(::P,::S,::A,::typeof(create_rng(solver.random_streams)))
@req reward(::P,::S,::A,::S)
@req discount(::P)
@req discount(::P)
@req isterminal(::P,::S)
as = actions(pomdp)
@req iterator(::typeof(as))
@req iterator(::typeof(as))
end

include("visualization.jl")
Expand Down Expand Up @@ -170,5 +170,5 @@ export
########## Visualization #######
blink,
DESPOTVisualizer

end #module
66 changes: 31 additions & 35 deletions src/beliefUpdate/beliefUpdateParticle.jl
@@ -1,7 +1,7 @@
import POMDPs: update, initialize_belief, updater
using DESPOT

type DESPOTBeliefUpdater{S,A,O} <: POMDPs.Updater
mutable struct DESPOTBeliefUpdater{S,A,O} <: POMDPs.Updater
pomdp::POMDP
num_updates::Int64
rng::AbstractRNG
Expand All @@ -10,15 +10,15 @@ type DESPOTBeliefUpdater{S,A,O} <: POMDPs.Updater
belief_update_seed::UInt32
particle_weight_threshold::Float64
eff_particle_fraction::Float64

#pre-allocated variables (TODO: add the rest at some point)
n_particles::Int64
next_state::S
observation::O
new_particle::DESPOTParticle{S}
n_sampled::Int64
obs_probability::Float64

#default constructor
function DESPOTBeliefUpdater(pomdp::POMDP{S,A,O};
seed::UInt32 = convert(UInt32, 42),
Expand All @@ -28,18 +28,18 @@ type DESPOTBeliefUpdater{S,A,O} <: POMDPs.Updater
eff_particle_fraction::Float64 = 0.05,
next_state::S = S(),
observation::O = O(),
rng::AbstractRNG=DESPOTDefaultRNG(convert(UInt32, seed $ (n_particles+1)), rand_max)
rng::AbstractRNG=DESPOTDefaultRNG(convert(UInt32, seed ⊻ (n_particles+1)), rand_max)
)
this = new()
this.pomdp = pomdp
this.num_updates = 0
this.belief_update_seed = seed $ (n_particles + 1)
this.num_updates = 0
this.belief_update_seed = seed ⊻ (n_particles + 1)
this.rng = rng
this.rand_max = rand_max
this.particle_weight_threshold = particle_weight_threshold
this.eff_particle_fraction = eff_particle_fraction
this.n_particles = n_particles

# init preallocated variables
this.next_state = next_state
this.observation = observation
Expand All @@ -51,34 +51,34 @@ type DESPOTBeliefUpdater{S,A,O} <: POMDPs.Updater
end

# Special create_belief version for DESPOTBeliefUpdater
create_belief{S,A,O}(bu::DESPOTBeliefUpdater{S,A,O}) =
DESPOTBelief{S,A,O}(Array(DESPOTParticle{S}, bu.n_particles), History{A,O}())
create_belief{S,A,O}(bu::DESPOTBeliefUpdater{S,A,O}) =
DESPOTBelief{S,A,O}(Array{DESPOTParticle{S}}(bu.n_particles), History{A,O}())

get_belief_update_seed(bu::DESPOTBeliefUpdater) = bu.seed $ (bu.n_particles + 1)
get_belief_update_seed(bu::DESPOTBeliefUpdater) = bu.seed ⊻ (bu.n_particles + 1)

reset_belief(bu::DESPOTBeliefUpdater) = bu.num_updates = 0

function initialize_belief{S,A,O}(bu::DESPOTBeliefUpdater{S,A,O},
state_distribution::ParticleDistribution{S})

new_belief = create_belief(bu)
n_particles = length(state_distribution.particles)

# convert to DESPOTParticle type
pool = Array(DESPOTParticle{S}, n_particles)
pool = Array{DESPOTParticle{S}}(n_particles)

for i in 1:n_particles
pool[i] = DESPOTParticle{S}(state_distribution.particles[i].state,
i, # id
state_distribution.particles[i].weight)
end

DESPOT.sample_particles!(new_belief.particles,
pool,
bu.n_particles,
bu.belief_update_seed,
bu.rand_max)

#shuffle!(new_belief.particles) #TODO: uncomment if higher randomness is required
return new_belief
end
Expand All @@ -87,7 +87,7 @@ end
function initialize_belief{S}(bu::DESPOTBeliefUpdater{S}, b)
new_belief = create_belief(bu)

pool = Array(DESPOTParticle{S}, bu.n_particles)
pool = Array{DESPOTParticle{S}}(bu.n_particles)
w = 1.0/bu.n_particles
for i in 1:bu.n_particles
pool[i] = DESPOTParticle{S}(rand(bu.rng, b), i, w)
Expand All @@ -109,7 +109,7 @@ updater{S,A,O}(p::DESPOTPolicy{S,A,O}) = DESPOTBeliefUpdater{S,A,O}(p.pomdp,
rng=p.solver.rng
)

function normalize!{S}(particles::Vector{DESPOTParticle{S}})
function normalize!{S}(particles::Vector{DESPOTParticle{S}})
prob_sum = 0.0
for p in particles
prob_sum += p.weight
Expand All @@ -123,12 +123,12 @@ function update{S,A,O}(bu::DESPOTBeliefUpdater{S,A,O},
current_belief::DESPOTBelief{S},
action::A,
obs::O)

updated_belief::DESPOTBelief{S} = create_belief(bu)
random_number = Array{Float64}(1)

if bu.n_particles != length(current_belief.particles)
err("belief size mismatch: belief updater - $(bu.n_particles) particles, belief - $(length(current_belief.particles))")
err("belief size mismatch: belief updater - $(bu.n_particles) particles, belief - $(length(current_belief.particles))")
end
updated_belief.particles = []

Expand All @@ -145,20 +145,20 @@ function update{S,A,O}(bu::DESPOTBeliefUpdater{S,A,O},
else
rng = bu.rng
end

bu.next_state = POMDPs.generate_s(bu.pomdp, p.state, action, rng)

#get observation distribution for (s,a,s') tuple
od = POMDPs.observation(bu.pomdp, p.state, action, bu.next_state)

bu.obs_probability = pdf(od, obs)

if bu.obs_probability > 0.0
bu.new_particle = DESPOTParticle(bu.next_state, p.id, p.weight * bu.obs_probability)
bu.new_particle = DESPOTParticle(bu.next_state, p.id, p.weight * bu.obs_probability)
push!(updated_belief.particles, bu.new_particle)
end
end

normalize!(updated_belief.particles)

if length(updated_belief.particles) == 0
Expand All @@ -173,7 +173,7 @@ function update{S,A,O}(bu::DESPOTBeliefUpdater{S,A,O},
#Pick a random particle from the current belief state as the initial state
rand!(resample_rng, random_number)
particle_number = ceil(bu.n_particles * random_number[1])

next_state = POMDPs.rand(resample_rng, states(bu.pomdp)) #generate a random state
od = POMDPs.observation(bu.pomdp,
current_belief.particles[particle_number].state,
Expand All @@ -191,7 +191,7 @@ function update{S,A,O}(bu::DESPOTBeliefUpdater{S,A,O},
end

# Remove all particles below the threshold weight
viable_particle_indices = Array(Int64,0)
viable_particle_indices = Array{Int64}(0)
for i in 1:length(updated_belief.particles)
if updated_belief.particles[i].weight >= bu.particle_weight_threshold
push!(viable_particle_indices, i)
Expand All @@ -213,21 +213,17 @@ function update{S,A,O}(bu::DESPOTBeliefUpdater{S,A,O},
num_eff_particles = 1./num_eff_particles
if (num_eff_particles < bu.n_particles * bu.eff_particle_fraction) ||
(length(updated_belief.particles) < bu.n_particles)
resampled_set = Array(DESPOTParticle{S}, bu.n_particles)
sample_particles!(resampled_set,
resampled_set = Array{DESPOTParticle{S}}(bu.n_particles)
sample_particles!(resampled_set,
updated_belief.particles,
bu.n_particles,
bu.belief_update_seed,
bu.rand_max)
updated_belief.particles = resampled_set
end

# Finally, update history
add(updated_belief.history, action, obs)

return updated_belief
end




8 changes: 4 additions & 4 deletions src/config.jl
@@ -1,4 +1,4 @@
type DESPOTConfig
mutable struct DESPOTConfig
# Maximum depth of the search tree
search_depth::Int64
# Random-number seed
Expand All @@ -21,10 +21,10 @@ type DESPOTConfig
# particle_weight_threshold::Float64
# eff_particle_fraction::Float64
tiny::Float64 # tiny number
max_trials::Int64
rand_max::Int64
max_trials::Int64
rand_max::Int64
debug::UInt8

# construct empty
function DESPOTConfig()
this = new()
Expand Down
6 changes: 3 additions & 3 deletions src/history.jl
@@ -1,13 +1,13 @@
type History{A,O}
mutable struct History{A,O}
actions::Vector{A}
observations::Vector{O}
History() = new(Array(A, 0), Array(O, 0))
History{A,O}() where {A,O} = new(Array{A}(0), Array{O}(0))
end

function add{A,O}(history::History{A,O},
action::A,
obs::O)

push!(history.actions, action)
push!(history.observations, obs)
end
Expand Down
4 changes: 2 additions & 2 deletions src/lowerBound/DESPOTDefaultLowerBound.jl
@@ -1,14 +1,14 @@
import DESPOT:
lower_bound

type DESPOTDefaultLowerBound
mutable struct DESPOTDefaultLowerBound
#placeholder for now
end

function init_bound{S,A,O}(ub::DESPOTDefaultLowerBound,
pomdp::POMDP{S,A,O},
config::DESPOTConfig)
error("Function init_bound for $(typeof(ub)) has not been implemented yet")
error("Function init_bound for $(typeof(ub)) has not been implemented yet")
end

function lower_bound{S,A,O}(lb::DESPOTDefaultLowerBound,
Expand Down

0 comments on commit 7efbe08

Please sign in to comment.