In [18]:
using Curiosity, Statistics, Plots, FileIO, BSON
using Random, StatsBase, RollingFunctions
using GVFHordes

const MCU = Curiosity.MountainCarUtils

Curiosity.MountainCarUtils

In [19]:
# Cumulant
struct Reward end

(c::Reward)(;kwargs...) = kwargs[:r]

struct ConstantDiscount{F}
    val::F
end
   
(d::ConstantDiscount)(;kwargs...) = d.val

In [20]:
# Environment 
function construct_env()
    normalized = true
    env = MountainCar(0.0, 0.0, normalized)
end


construct_env (generic function with 1 method)

In [21]:
# Agent
function construct_agent(numtilings, numtiles, lu_str, α, λ, ϵ, γ)
    obs_size = 2
    fc = Curiosity.SparseTileCoder(numtilings, numtiles, obs_size)
    feature_size = size(fc)

    lu = if lu_str == "ESARSA"
        ESARSA(lambda=λ, opt=Curiosity.Descent(α))
    elseif lu_str == "SARSA"
        SARSA(lambda=λ, opt=Curiosity.Descent(α))
    elseif lu_str == "TB"
        TB(lambda=λ, opt=Curiosity.Descent(α))
    else
        throw(ArgumentError("$(lu_str) Not a valid behaviour learning update"))
    end
#     (update, num_features, num_actions, num_demons, w_init)
    learner = LinearQLearner(lu, feature_size, 3, 1,0)
    exploration = EpsilonGreedy(ϵ)
    cumulant = Reward()
    discount = ConstantDiscount(γ)
    
    b_gvf = make_behaviour_gvf(learner, discount, fc, exploration)
    b_demons = Horde([b_gvf])

    Curiosity.PolicyLearner(learner, 
                            fc, 
                            exploration, 
                            discount, 
                            cumulant, 
                            zeros(2), 
                            0,
                            b_demons)
end

construct_agent (generic function with 1 method)

In [22]:
function make_behaviour_gvf(behaviour_learner, γ, fc, exploration_strategy)
    function b_π(state_constructor, learner, exploration_strategy; kwargs...)
        s = state_constructor(kwargs[:state_t])
        preds = learner(s)
        return exploration_strategy(preds)[kwargs[:action_t]]
    end
    GVF_policy = GVFParamFuncs.FunctionalPolicy((;kwargs...) -> b_π(fc, behaviour_learner, exploration_strategy; kwargs...))
    BehaviourGVF = GVF(GVFParamFuncs.RewardCumulant(), GVFParamFuncs.ConstantDiscount(γ), GVF_policy)
end

make_behaviour_gvf (generic function with 1 method)

In [23]:
# Learn the policy
seed = 1029
Random.seed!(seed)
numtilings, numtiles = 8, 8
lu_str = "TB"
α = 0.1/numtilings
λ = 0.9
ϵ = 0.1
γ = 0.99

info = Dict(
    "seed"=>seed,
    "numtilings"=>numtilings,
    "numtiles"=>numtiles,
    "lu"=>"lu_str",
    "α"=>α,
    "λ"=>λ,
    "ϵ"=>ϵ,
    "γ"=>γ,
    "rew"=>"Env"
)


env = construct_env()

agent = construct_agent(numtilings, numtiles, lu_str, α, λ, ϵ, γ)

steps = Int[]
ret = Float64[]
max_num_steps = 100000
eps = 0
while sum(steps) < max_num_steps
    is_terminal = false

    max_episode_steps = min(max_num_steps - sum(steps), 1000)
    s = start!(env)
    a = start!(agent, s)
    stp = 0
    a = 0
    tr = 0.0
    while !is_terminal && stp <= max_episode_steps
        s, r, is_terminal = MinimalRLCore.step!(env, a)
        println(s)
        println("ERE")
        a = MinimalRLCore.step!(agent, s, r, is_terminal)
        tr += r
        stp += 1
    end

    push!(steps, stp)
    push!(ret, tr)
    
    eps += 1
end
    
Curiosity.save(agent, "policy.bson", info)

LoadError: MethodError: no method matching GVF(::GVFHordes.GVFParamFuncs.RewardCumulant, ::ConstantDiscount{Float64}, ::GVFHordes.GVFParamFuncs.FunctionalPolicy{var"#22#26"{var"#22#23#27"{QLearner{Matrix{Float64}, TB{Flux.Optimise.Descent, AccumulatingTraces}}, TileCoder{SparseArrays.SparseVector{Int64, Ti} where Ti<:Integer}, EpsilonGreedy}}})
[0mClosest candidates are:
[0m  GVF(::C, [91m::D[39m, ::P) where {C<:GVFHordes.GVFParamFuncs.AbstractCumulant, D<:GVFHordes.GVFParamFuncs.AbstractDiscount, P<:GVFHordes.GVFParamFuncs.AbstractPolicy} at /home/matthewmcleod/Documents/Masters/curiosity/src/GVFHordes/src/GVFHordes.jl:70

In [7]:
plot(rollmean(ret, 100))

LoadError: UndefVarError: ret not defined

In [34]:
# Save the policy