DictPolicy and special Q-learning based on key-value storage #459

NeroBlackstone commented Feb 8, 2023

Suppose we have a generative MDP with discrete states and discrete actions, but the state and action spaces are hard to enumerate, and we still want to solve it with a traditional tabular RL algorithm.
To support this, I implemented a DictPolicy that stores state-action values in a dictionary. (Users do need to define Base.isequal() and Base.hash() for their state and action types; a sketch is shown after the code below.)

DictPolicy.jl :

# Assumed imports when this is used as a standalone file:
using POMDPs
using POMDPTools     # showpolicy
import POMDPs: action

# A policy that stores state-action values in a dictionary keyed by (state, action) tuples
struct DictPolicy{P<:Union{POMDP,MDP}, T<:AbstractDict{Tuple,Float64}} <: Policy
    mdp::P
    value_dict::T
end

# Returns the action that the policy deems best for the current state
function action(p::DictPolicy, s)
    available_actions = actions(p.mdp, s)
    max_action = nothing
    max_action_value = -Inf
    for a in available_actions
        if haskey(p.value_dict, (s,a))
            action_value = p.value_dict[(s,a)]
            if action_value > max_action_value
                max_action = a
                max_action_value = action_value
            end
        else
            # initialize unseen state-action pairs with a value of zero
            p.value_dict[(s,a)] = 0.0
        end
    end
    if max_action === nothing
        # no stored values for this state yet: fall back to the first available action
        max_action = first(available_actions)
    end
    return max_action
end

# returns the values of each action at state s in a dict
function actionvalues(p::DictPolicy, s)::Dict
    available_actions = actions(p.mdp, s)
    action_dict = Dict()
    for a in available_actions
        # unseen state-action pairs are reported with a value of zero
        action_dict[a] = get(p.value_dict, (s,a), 0.0)
    end
    return action_dict
end

# Pretty-printing in the same style as the other policies (showpolicy is provided by POMDPTools)
function Base.show(io::IO, mime::MIME"text/plain", p::DictPolicy{M}) where M <: MDP
    summary(io, p)
    println(io, ':')
    ds = get(io, :displaysize, displaysize(io))
    ioc = IOContext(io, :displaysize=>(first(ds)-1, last(ds)))
    showpolicy(io, mime, p.mdp, p)
end
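
As noted above, custom state and action types need to behave as dictionary keys. Here is a minimal sketch of what that might look like, using a hypothetical GridState type purely for illustration (immutable structs with simple fields already get this behavior by default; explicit definitions are mainly needed for mutable types or fields without structural equality):

# Hypothetical user-defined state type, used only for illustration
struct GridState
    x::Int
    y::Int
end

# Hash and equality so that (state, action) tuples work as Dict keys
Base.hash(s::GridState, h::UInt) = hash(s.y, hash(s.x, h))
Base.isequal(a::GridState, b::GridState) = a.x == b.x && a.y == b.y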

With this policy we can write a special Q-learning variant based on key-value storage, so we don't need to enumerate the state and action spaces in the MDP definition. (Most of the code is copied from TabularTDLearning.jl; only the way Q-values are stored and read is changed.)

dict_q_learning.jl :

# Assumed imports when this is used as a standalone file
# (also assumes the DictPolicy defined above is in scope):
using POMDPs
using POMDPTools     # RolloutSimulator, ExplorationPolicy
using Parameters     # @with_kw
using Random
import POMDPs: solve

@with_kw mutable struct QLearningSolver{E<:ExplorationPolicy} <: Solver
   n_episodes::Int64 = 100
   max_episode_length::Int64 = 100
   learning_rate::Float64 = 0.001
   exploration_policy::E
   Q_vals::Union{Nothing, Dict{Tuple,Float64}} = nothing
   eval_every::Int64 = 10
   n_eval_traj::Int64 = 20
   rng::AbstractRNG = Random.GLOBAL_RNG
   verbose::Bool = true
end

function solve(solver::QLearningSolver, mdp::MDP)
    rng = solver.rng
    if solver.Q_vals === nothing
        Q = Dict{Tuple,Float64}()
    else
        Q = solver.Q_vals
    end
    exploration_policy = solver.exploration_policy
    sim = RolloutSimulator(rng=rng, max_steps=solver.max_episode_length)

    on_policy = DictPolicy(mdp, Q)
    k = 0 # global step counter used by the exploration policy's schedule
    for i = 1:solver.n_episodes
        s = rand(rng, initialstate(mdp))
        t = 0
        while !isterminal(mdp, s) && t < solver.max_episode_length
            a = action(exploration_policy, on_policy, k, s)
            k += 1
            sp, r = @gen(:sp, :r)(mdp, s, a, rng)
            # maximum stored value over all actions at sp; unseen pairs count as 0
            max_sp_prediction = 0.0
            for (key, q) in Q
                if sp == key[1] && q > max_sp_prediction
                    max_sp_prediction = q
                end
            end
            # current estimate for (s, a); insert a 0 entry if the pair has not been seen yet
            current_s_prediction = get!(Q, (s,a), 0.0)
            Q[(s,a)] += solver.learning_rate * (r + discount(mdp) * max_sp_prediction - current_s_prediction)
            s = sp
            t += 1
        end
        if i % solver.eval_every == 0
            r_tot = 0.0
            for traj in 1:solver.n_eval_traj
                r_tot += simulate(sim, mdp, on_policy, rand(rng, initialstate(mdp)))
            end
            solver.verbose ? println("On Iteration $i, Returns: $(r_tot/solver.n_eval_traj)") : nothing
        end
    end
    return on_policy
end
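
For reference, here is a minimal usage sketch, assuming a user-defined generative MDP bound to mdp and the EpsGreedyPolicy exploration policy from POMDPTools (the hyperparameter values are arbitrary, not recommendations):

using POMDPs, POMDPTools

# mdp is assumed to be a user-defined generative MDP
solver = QLearningSolver(exploration_policy = EpsGreedyPolicy(mdp, 0.1),
                         learning_rate = 0.1,
                         n_episodes = 5000)
policy = solve(solver, mdp)                      # returns a DictPolicy
a = action(policy, rand(initialstate(mdp)))      # greedy action at an initial state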

What is your point of view? Do you have any advice?
Thank you for taking the time to read my issue.
If you think this is useful, I can open a PR and add some tests.
It's also fine if you think it isn't meaningful or general enough; I originally wrote it just to solve my own MDP.

zsunberg (Member) commented

@NeroBlackstone sorry that we never responded to this! This is actually something that people often want to do. If you're still interested in contributing it, I think we can integrate it in with a few small adjustments. Let me know if you're interested in doing that.

NeroBlackstone commented May 14, 2023

Hi, thanks for your comment.
I'm ready to contribute this feature.

I will do these things:

  1. Add DictPolicy and some test code in POMDPs.jl.
  2. Once DictPolicy is merged, I will contribute a Q-Learning solver and Prioritized Sweeping using this policy in TabularTDLearning.jl.
  3. Add a vanilla Prioritized Sweeping to TabularTDLearning.jl.

I will open a PR for the first step soon. If there are problems with the code, please point them out.

Thank you very much again.
