Skip to content

Commit

Permalink
updated convert to comply with JuliaPOMDP/POMDPs.jl#85
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Aug 17, 2017
1 parent d3cc1ed commit 64ca404
Show file tree
Hide file tree
Showing 17 changed files with 33 additions and 70 deletions.
10 changes: 0 additions & 10 deletions src/CryingBabies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,6 @@ function generate_o(p::BabyPOMDP, s::Bool, rng::AbstractRNG)
return rand(rng, d)
end

# same for both state and observation
# Encode a Bool state/observation as a one-element Float64 vector
# (`false` -> [0.0], `true` -> [1.0]). `prob` only selects the method.
# A typed array literal is used instead of copying a scalar into an
# uninitialized `Array{Float64}(1)`; the result is the same and it does not
# rely on the deprecated length-only array constructor.
function Base.convert(::Type{Array{Float64}}, so::Bool, prob::BabyPOMDP)
    return Float64[so]
end

# Decode a Float64 vector back into a Bool by converting its first entry.
# `prob` is used for dispatch only.
Base.convert(::Type{Bool}, so::Vector{Float64}, prob::BabyPOMDP) = Bool(so[1])

# some example policies
# Field-less example policy type.
mutable struct Starve <: Policy end

# `action` ignores its second argument (of any type `B`) and always returns
# `false` (presumably "do not feed" -- confirm against the problem's action
# meaning). Declared with `where` syntax rather than the deprecated
# `action{B}(...)` form.
action(::Starve, ::B) where {B} = false
Expand Down
3 changes: 0 additions & 3 deletions src/Discrete.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,3 @@ function observation(prob::DiscretePOMDP, a::Int64, sp::Int64)
d.a = a
return d
end

# Vector encoding for discrete problems: an Int64 becomes a one-element
# Float64 vector, and the inverse reads the first entry back as an Int.
# `prob` is used for dispatch only.
function Base.convert(::Type{Array{Float64}}, s::Int64, prob::Union{DiscreteMDP,DiscretePOMDP})
    return Float64[s]
end

function Base.convert(::Type{Int}, s::Array{Float64}, prob::Union{DiscreteMDP,DiscretePOMDP})
    return Int(s[1])
end
8 changes: 4 additions & 4 deletions src/GridWorlds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,8 @@ end

# Return the `discount_factor` field stored on the GridWorld problem.
function discount(mdp::GridWorld)
    return mdp.discount_factor
end

# `Base.convert` between a GridWorldState and the Float64 vector
# [x, y, done]; the inverse passes the first three entries positionally to
# the GridWorldState constructor.
function Base.convert(::Type{Array{Float64}}, s::GridWorldState, mdp::GridWorld)
    return Float64[s.x, s.y, s.done]
end

function Base.convert(::Type{GridWorldState}, s::Vector{Float64}, mdp::GridWorld)
    return GridWorldState(s[1], s[2], s[3])
end
# State -> array: builds the Float64 vector [x, y, done].
# NOTE(review): the requested array type `A` is not consulted; the element
# type is always Float64 -- confirm this is intended.
function convert_s(::Type{A}, s::GridWorldState, mdp::GridWorld) where A<:AbstractArray
    return Float64[s.x, s.y, s.done]
end

# Array -> state: passes the first three entries positionally to the
# GridWorldState constructor.
function convert_s(::Type{GridWorldState}, s::AbstractArray, mdp::GridWorld)
    return GridWorldState(s[1], s[2], s[3])
end

function a2int(a::Symbol, mdp::GridWorld)
if a == :up
Expand Down Expand Up @@ -363,8 +363,8 @@ function int2a(a::Int, mdp::GridWorld)
end
end

# Action symbol -> one-element vector holding the integer code from `a2int`
# as a Float64.
function Base.convert(::Type{Array{Float64}}, a::Symbol, mdp::GridWorld)
    code = Float64(a2int(a, mdp))
    return [code]
end

# One-element vector -> action symbol, via the integer code in the first entry.
function Base.convert(::Type{Symbol}, a::Vector{Float64}, mdp::GridWorld)
    code = Int(a[1])
    return int2a(code, mdp)
end
# Action symbol -> one-element vector holding the integer code from `a2int`
# as a Float64. The requested array type `A` is used for dispatch only.
function convert_a(::Type{A}, a::Symbol, mdp::GridWorld) where A<:AbstractArray
    code = Float64(a2int(a, mdp))
    return [code]
end

# Array -> action symbol, via the integer code stored in the first entry.
function convert_a(::Type{Symbol}, a::A, mdp::GridWorld) where A<:AbstractArray
    code = Int(a[1])
    return int2a(code, mdp)
end

# Draw an initial state by sampling x from 1:size_x and then y from 1:size_y
# using the supplied RNG (x is drawn first, so the RNG stream order matters).
function initial_state(mdp::GridWorld, rng::AbstractRNG)
    x = rand(rng, 1:mdp.size_x)
    y = rand(rng, 1:mdp.size_y)
    return GridWorldState(x, y)
end

Expand Down
4 changes: 2 additions & 2 deletions src/InvertedPendulum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ function generate_s( ip::InvertedPendulum,
return sp
end

function Base.convert(::Type{Array{Float64}}, s::Tuple{Float64,Float64}, ip::InvertedPendulum)
function convert_s(::Type{A}, s::Tuple{Float64,Float64}, ip::InvertedPendulum) where A<:AbstractArray
# Copy the two tuple entries into a newly allocated 2-element Float64 array.
# NOTE(review): the requested array type `A` is not consulted here; the
# result is always built with `Array{Float64}(2)` -- confirm this is intended.
v = copy!(Array{Float64}(2), s)
return v
end

function Base.convert(::Type{Tuple{Float64,Float64}}, s::Vector{Float64}, ip::InvertedPendulum)
function convert_s(::Type{Tuple{Float64,Float64}}, s::A, ip::InvertedPendulum) where A<:AbstractArray
# Rebuild the tuple from the first two entries of `s`; `ip` is not used in
# the body and serves for dispatch only.
return (s[1], s[2])
end
10 changes: 2 additions & 8 deletions src/LightDark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,8 @@ function reward(p::LightDark1D, s::LightDark1DState, a::Int)
end


# State <-> vector: encoded as [status, y]; decoding converts the first
# entry to Int64 and passes the second through unchanged.
function Base.convert(::Type{Array{Float64}}, s::LightDark1DState, p::LightDark1D)
    return Float64[s.status, s.y]
end

function Base.convert(::Type{LightDark1DState}, s::Vector{Float64}, p::LightDark1D)
    status = Int64(s[1])
    return LightDark1DState(status, s[2])
end

# Observation <-> vector: a Float64 observation is wrapped in / read from a
# one-element vector.
function Base.convert(::Type{Array{Float64}}, o::Float64, p::LightDark1D)
    return Float64[o]
end

function Base.convert(::Type{Float64}, o::Vector{Float64}, p::LightDark1D)
    return o[1]
end

# Action <-> vector: an Int action is stored as a Float64 and converted back
# with `Int`.
function Base.convert(::Type{Array{Float64}}, a::Int, p::LightDark1D)
    return Float64[a]
end

function Base.convert(::Type{Int}, a::Vector{Float64}, p::LightDark1D)
    return Int(a[1])
end
# State -> array: builds [status, y] with the element type of the requested
# array type `A`.
function convert_s(::Type{A}, s::LightDark1DState, p::LightDark1D) where A<:AbstractArray
    return eltype(A)[s.status, s.y]
end

# Array -> state: the first entry is converted to Int64, the second is
# passed through unchanged.
function convert_s(::Type{LightDark1DState}, s::A, p::LightDark1D) where A<:AbstractArray
    status = Int64(s[1])
    return LightDark1DState(status, s[2])
end

# XXX this is specifically for MCVI
# it is also implemented in the MCVI tests
Expand Down
6 changes: 3 additions & 3 deletions src/MountainCar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ function generate_s( mc::MountainCar,
end


function Base.convert(::Type{Array{Float64}}, s::Tuple{Float64,Float64}, mc::MountainCar)
v = copy!(Array{Float64}(2), s)
function convert_s(::Type{A}, s::Tuple{Float64,Float64}, mc::MountainCar) where A<:AbstractArray
# Allocate the destination by calling the requested type as `A(2)` and copy
# the two tuple entries into it.
# NOTE(review): this requires `A` to be constructible from a single length
# argument -- confirm callers only pass such array types.
v = copy!(A(2), s)
return v
end
# Vector -> state tuple: both methods return the first two entries of `s`
# as a tuple; `mc` is used for dispatch only.
function Base.convert(::Type{Tuple{Float64,Float64}}, s::Vector{Float64}, mc::MountainCar)
    return (s[1], s[2])
end

function convert_s(::Type{Tuple{Float64,Float64}}, s::A, mc::MountainCar) where A<:AbstractArray
    return (s[1], s[2])
end

# Example policy -- works pretty well
# Field-less policy type; its behavior comes from methods that dispatch on
# it (those methods are not visible in this excerpt).
mutable struct Energize <: Policy end
Expand Down
13 changes: 1 addition & 12 deletions src/POMDPModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,7 @@ using StaticArrays
using AutoHashEquals
using StatsBase

import POMDPs: n_states, n_actions, n_observations # space sizes for discrete problems
import POMDPs: state_index, action_index, obs_index
import POMDPs: discount, states, actions, observations # model functions
import POMDPs: transition, observation, reward, isterminal, isterminal_obs # model functions
import POMDPs: rand, pdf # common distribution functions
import POMDPs: iterator, dimensions # space functions
import POMDPs: initial_state_distribution
import POMDPs: update, updater
import POMDPs: vec

# for example policies
import POMDPs: Policy, action
importall POMDPs

import Base.rand!
import Base.rand
Expand Down
3 changes: 0 additions & 3 deletions src/TMazes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,6 @@ function Base.convert(maze::TMaze, s::TMazeState)
return v
end

# Observation <-> vector for TMaze: an Int64 observation is wrapped in a
# one-element Float64 vector and read back from the first entry.
function Base.convert(::Type{Array{Float64}}, o::Int64, ::TMaze)
    return Float64[o]
end

function Base.convert(::Type{Int64}, o::Vector{Float64}, ::TMaze)
    return Int64(o[1])
end

mutable struct MazeBelief
last_obs::Int64
mem::Symbol # memory
Expand Down
4 changes: 0 additions & 4 deletions src/TigerPOMDPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,6 @@ function generate_o(p::TigerPOMDP, s::Bool, rng::AbstractRNG)
return rand(rng, d)
end

# same for both state and observation
# Bool -> one-element Float64 vector; `p` is used for dispatch only.
function Base.convert(::Type{Array{Float64}}, so::Bool, p::TigerPOMDP)
    return Float64[so]
end

# One-element Float64 vector -> Bool, read from the first entry.
function Base.convert(::Type{Bool}, so::Vector{Float64}, p::TigerPOMDP)
    return Bool(so[1])
end

# This doesn't seem to work well
# type TigerBeliefUpdater <: Updater
# pomdp::TigerPOMDP
Expand Down
4 changes: 2 additions & 2 deletions test/car.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ sim = RolloutSimulator(rng=MersenneTwister(1), max_steps=1000)
# Simulate from an initial state drawn with a seeded RNG; the test requires
# the returned value to be negative.
r = simulate(sim, problem, policy, initial_state(problem, MersenneTwister(2)))
@test r < 0.0

# Round-trip a state: tuple -> Float64 vector -> tuple.
sv = convert(Array{Float64}, (0.5, 0.25), problem)
sv = convert_s(Array{Float64}, (0.5, 0.25), problem)
@test sv == [0.5, 0.25]
s = convert(Tuple{Float64,Float64}, sv, problem)
s = convert_s(Tuple{Float64,Float64}, sv, problem)
@test s == (0.5, 0.25)

problem = MountainCar(discount=1.0, cost=-0.1, jackpot=100.0)
Expand Down
4 changes: 2 additions & 2 deletions test/crying.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ r = simulate(sim, problem, policy, updater(policy), ib)
# With state `true` and this seed, the generated observation must equal 1.
o = generate_o(problem, true, MersenneTwister(1))
@test o == 1
# test vec
# Round-trip `true` through its vector form and back.
# NOTE(review): the variables are named `ov`/`o`, but the state conversion
# `convert_s` is called -- confirm whether `convert_o` was intended.
ov = convert(Array{Float64}, true, problem)
ov = convert_s(Array{Float64}, true, problem)
@test ov == [1.]
o = convert(Bool, ov, problem)
o = convert_s(Bool, ov, problem)
@test o == true

probability_check(problem)
Expand Down
10 changes: 5 additions & 5 deletions test/gridworld.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ for i in 1:length(sim.action_hist)
end


# State encoding: (x, y, done) becomes [x, y, done] as Float64 values.
sv = convert(Array{Float64}, GridWorldState(1, 1, false), problem)
sv = convert_s(Array{Float64}, GridWorldState(1, 1, false), problem)
@test sv == [1.0, 1.0, 0.0]
sv = convert(Array{Float64}, GridWorldState(5, 3, false), problem)
sv = convert_s(Array{Float64}, GridWorldState(5, 3, false), problem)
@test sv == [5.0, 3.0, 0.0]
# Decoding the last vector must reproduce the state it was built from.
s = convert(GridWorldState, sv, problem)
s = convert_s(GridWorldState, sv, problem)
@test s == GridWorldState(5, 3, false)

# Action round-trip: `:up` is expected to encode as [0.0] and decode to `:up`.
av = convert(Array{Float64}, :up, problem)
av = convert_a(Array{Float64}, :up, problem)
@test av == [0.0]
a = convert(Symbol, av, problem)
a = convert_a(Symbol, av, problem)
@test a == :up

# Two separately constructed states with equal fields must compare equal.
@test GridWorldState(1,1,false) == GridWorldState(1,1,false)
Expand Down
4 changes: 2 additions & 2 deletions test/inverted.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ sim = RolloutSimulator(MersenneTwister(1))

# Run one simulation from a seeded initial state; the result is not checked.
simulate(sim, problem, policy, initial_state(problem, MersenneTwister(2)))

# Round-trip a state: tuple -> Float64 vector -> tuple.
sv = convert(Array{Float64}, (0.5, 0.25), problem)
sv = convert_s(Array{Float64}, (0.5, 0.25), problem)
@test sv == [0.5, 0.25]
s = convert(Tuple{Float64,Float64}, sv, problem)
s = convert_s(Tuple{Float64,Float64}, sv, problem)
@test s == (0.5, 0.25)
8 changes: 4 additions & 4 deletions test/lightdark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ obs = generate_o(p, nothing, nothing, s2, rng)
# The generated observation must lie within 1.1 of 6.0.
@test abs(obs-6.0) <= 1.1


# State round-trip: `s2` must encode as [0.0, 5.0] and decode back to `s2`.
sv = convert(Array{Float64}, s2, p)
sv = convert_s(Array{Float64}, s2, p)
@test sv == [0.0, 5.0]
s = convert(LightDark1DState, sv, p)
s = convert_s(LightDark1DState, sv, p)
@test s == s2

# Observation round-trip: Float64 -> one-element vector -> Float64.
ov = convert(Array{Float64}, obs, p)
ov = convert_o(Array{Float64}, obs, p)
@test ov == [obs]
o = convert(Float64, ov, p)
o = convert_o(Float64, ov, p)
@test o == obs
4 changes: 2 additions & 2 deletions test/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ sim = RolloutSimulator(rng=MersenneTwister(3), max_steps=100)
# Run one simulation from the initial state distribution (result unchecked),
# then run `probability_check` on the problem.
simulate(sim, pomdp, policy, updater(policy), initial_state_distribution(pomdp))
probability_check(pomdp)

# Observation round-trip: 1 -> [1.0] -> 1.
ov = convert(Array{Float64}, 1, pomdp)
ov = convert_o(Array{Float64}, 1, pomdp)
@test ov == [1.]
o = convert(Int, ov, pomdp)
o = convert_o(Int, ov, pomdp)
@test o == 1
4 changes: 2 additions & 2 deletions test/tiger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ simulate(sim, pomdp1, policy, updater(policy), initial_state_distribution(pomdp1
# With state `true` and this seed, the generated observation must equal 1.
o = generate_o(pomdp1, true, MersenneTwister(1))
@test o == 1
# test vec
# Observation round-trip: true -> [1.0] -> true.
ov = convert(Array{Float64}, true, pomdp1)
ov = convert_o(Array{Float64}, true, pomdp1)
@test ov == [1.]
o = convert(Bool, ov, pomdp1)
o = convert_o(Bool, ov, pomdp1)
@test o == true

probability_check(pomdp1)
Expand Down
4 changes: 2 additions & 2 deletions test/tmaze.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ test_obs(TMazeState(5, :south, false), 3) # corridor
test_obs(TMazeState(11, :south, false), 4) # junction
test_obs(TMazeState(11, :south, true), 5) # terminal

# Observation round-trip: 1 -> [1.0] -> 1.
ov = convert(Array{Float64}, 1, problem)
ov = convert_o(Array{Float64}, 1, problem)
@test ov == [1.]
o = convert(Int64, ov, problem)
o = convert_o(Int64, ov, problem)
@test o == 1


0 comments on commit 64ca404

Please sign in to comment.