Skip to content

Commit

Permalink
Merge ba51c76 into 9d47785
Browse files Browse the repository at this point in the history
  • Loading branch information
MaximeBouton committed Oct 7, 2019
2 parents 9d47785 + ba51c76 commit ad9cb6f
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 16 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
POMDPs = "0.7.3, 0.8"
julia = "1.0"

[extras]
POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4"
POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd"
Expand Down
3 changes: 1 addition & 2 deletions src/actions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ const ACTION_DIRS = (DSPos(0,1),
DSPos(0,-1),
DSPos(-1,0),
DSPos(0,0))

POMDPs.n_actions(pomdp::DroneSurveillancePOMDP) = N_ACTIONS

POMDPs.actions(pomdp::DroneSurveillancePOMDP) = 1:N_ACTIONS
POMDPs.actionindex(pomdp::DroneSurveillancePOMDP, a::Int64) = a

Expand Down
2 changes: 0 additions & 2 deletions src/observation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ const N_OBS_QUAD = 6

POMDPs.observations(pomdp::DroneSurveillancePOMDP{QuadCam}) = 1:N_OBS_QUAD
POMDPs.observations(pomdp::DroneSurveillancePOMDP{PerfectCam}) = 1:N_OBS_PERFECT
POMDPs.n_observations(pomdp::DroneSurveillancePOMDP{QuadCam}) = N_OBS_QUAD
POMDPs.n_observations(pomdp::DroneSurveillancePOMDP{PerfectCam}) = N_OBS_PERFECT
POMDPs.obsindex(pomdp::DroneSurveillancePOMDP, o::Int64) = o

function POMDPs.observation(pomdp::DroneSurveillancePOMDP{QuadCam}, a::Int64, s::DSState)
Expand Down
11 changes: 4 additions & 7 deletions src/states.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
POMDPs.n_states(pomdp::DroneSurveillancePOMDP) = (pomdp.size[1] * pomdp.size[2])^2 + 1

function POMDPs.stateindex(pomdp::DroneSurveillancePOMDP, s::DSState)
if isterminal(pomdp, s)
return n_states(pomdp)
return length(pomdp)
end
nx, ny = pomdp.size
LinearIndices((nx, ny, nx, ny))[s.quad[1], s.quad[2], s.agent[1], s.agent[2]]
end

function state_from_index(pomdp::DroneSurveillancePOMDP, si::Int64)
if si == n_states(pomdp)
if si == length(pomdp)
return pomdp.terminal_state
end
nx, ny = pomdp.size
Expand All @@ -21,17 +19,16 @@ end
# we define an iterator over it

POMDPs.states(pomdp::DroneSurveillancePOMDP) = pomdp
Base.length(pomdp::DroneSurveillancePOMDP) = (pomdp.size[1] * pomdp.size[2])^2 + 1

function Base.iterate(pomdp::DroneSurveillancePOMDP, i::Int64 = 1)
if i > n_states(pomdp)
if i > length(pomdp)
return nothing
end
s = state_from_index(pomdp, i)
return (s, i+1)
end

Base.length(pomdp::DroneSurveillancePOMDP) = n_states(pomdp)

function POMDPs.initialstate(pomdp::DroneSurveillancePOMDP, rng::AbstractRNG)
rand(rng, initialstate_distribution(pomdp))
end
Expand Down
4 changes: 2 additions & 2 deletions src/transition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ function POMDPs.transition(pomdp::DroneSurveillancePOMDP, s::DSState, a::Int64)
end

# move agent
new_states = MVector{n_actions(pomdp), DSState}(undef)
probs = @MVector(zeros(n_actions(pomdp)))
new_states = MVector{N_ACTIONS, DSState}(undef)
probs = @MVector(zeros(N_ACTIONS))
for (i, act) in enumerate(ACTION_DIRS)
new_agent = s.agent + act
if agent_inbounds(pomdp, new_agent)
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,20 @@ end
pomdp = DroneSurveillancePOMDP()
state_iterator = states(pomdp)
ss = ordered_states(pomdp)
@test length(ss) == n_states(pomdp)
@test length(ss) == length(pomdp)
@test test_state_indexing(pomdp, ss)
pomdp = DroneSurveillancePOMDP(size=(7, 10))
state_iterator = states(pomdp)
ss = ordered_states(pomdp)
@test length(ss) == n_states(pomdp)
@test length(ss) == length(pomdp)
@test test_state_indexing(pomdp, ss)
end

@testset "action space" begin
pomdp = DroneSurveillancePOMDP()
acts = actions(pomdp)
@test acts == ordered_actions(pomdp)
@test length(acts) == n_actions(pomdp)
@test length(acts) == length(actions(pomdp))
@test length(acts) == length(DroneSurveillance.ACTION_DIRS)
end

Expand Down

0 comments on commit ad9cb6f

Please sign in to comment.