Skip to content

Commit

Permalink
handle constrained actions
Browse files Browse the repository at this point in the history
  • Loading branch information
MaximeBouton committed Aug 9, 2019
1 parent aaec63e commit b70b062
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/sparse_tabular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ function transition_matrix_a_s_sp(mdp::Union{MDP, POMDP})
end
end
transmats_A_S_S2 = [sparse(transmat_row_A[a], transmat_col_A[a], transmat_data_A[a], n_states(mdp), n_states(mdp)) for a in 1:n_actions(mdp)]
@assert all(all(sum(transmats_A_S_S2[a], dims=2) .≈ ones(n_states(mdp))) for a in 1:n_actions(mdp)) "Transition probabilities must sum to 1"
# If an action is not valid from a state, every transition probability from that state is 0.0,
# so that row sums to 0 rather than 1 and the row-sum check below is disabled.
# @assert all(all(sum(transmats_A_S_S2[a], dims=2) .≈ ones(n_states(mdp))) for a in 1:n_actions(mdp)) "Transition probabilities must sum to 1"
return transmats_A_S_S2
end

Expand Down Expand Up @@ -251,6 +252,7 @@ POMDPs.n_actions(prob::SparseTabularProblem) = size(prob.T, 1)

"""
    states(p::SparseTabularProblem)

Return the state space of `p` as the integer range `1:n_states(p)`.
"""
function POMDPs.states(p::SparseTabularProblem)
    nstates = n_states(p)
    return 1:nstates
end
"""
    actions(p::SparseTabularProblem)

Return the full action space of `p` as the integer range `1:n_actions(p)`.
"""
function POMDPs.actions(p::SparseTabularProblem)
    nactions = n_actions(p)
    return 1:nactions
end
"""
    actions(p::SparseTabularProblem, s::Int64)

Return a vector with the actions of `p` that are valid from state `s`.

An action `a` is kept when row `s` of `transition_matrix(p, a)` sums to
(approximately) 1, i.e. when the transition distribution out of `s` under `a` is a
proper probability distribution. By the convention used when the transition matrices
are built, an action that is not valid from a state has an all-zero row for that
state, so such actions are filtered out.
"""
function POMDPs.actions(p::SparseTabularProblem, s::Int64)
    # Only the row of the queried state matters: summing the whole matrix would
    # discard an action everywhere as soon as it is invalid in a single state.
    return [a for a in actions(p) if sum(transition_matrix(p, a)[s, :]) ≈ 1]
end

"""
    stateindex(::SparseTabularProblem, s::Int64)

Return `s` unchanged: the integer passed as the state is returned as its index.
"""
function POMDPs.stateindex(::SparseTabularProblem, s::Int64)
    return s
end
"""
    actionindex(::SparseTabularProblem, a::Int64)

Return `a` unchanged: the integer passed as the action is returned as its index.
"""
function POMDPs.actionindex(::SparseTabularProblem, a::Int64)
    return a
end
Expand Down
2 changes: 2 additions & 0 deletions test/test_tabular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ smdp3 = SparseTabularMDP(smdp, discount = 0.5)
# Build a SimpleGridWorld and convert it to its sparse tabular representation.
gw = SimpleGridWorld()
sparsegw = SparseTabularMDP(gw)
# State index 101 is expected to be terminal in the converted problem.
@test isterminal(sparsegw, 101)
# The state-dependent action query must have a concretely inferred return type.
@inferred actions(sparsegw, 101)
# From state 101, the state-dependent query must return every action of the problem.
@test actions(sparsegw, 101) == collect(actions(sparsegw))

## POMDP

Expand Down

0 comments on commit b70b062

Please sign in to comment.