# Reinforcement Learning Tutorial for POMDPs.jl

Inspired by CS234.

In [1]:
using POMDPs

## Install Reinforcement Learning algorithms

In [2]:
# POMDPs.add("TabularTDLearning")

In [3]:
using TabularTDLearning, POMDPToolbox

## Problem overview

### Wall-E exploration


Consider a 1D gridworld with a different reward at each end. Wall-E must find the plant, but he might greedily join his lover instead...

![wall-e-mdp](initial_state.png)

The environment can be modeled as follows:

- 10 states: 1,2,3,4,5,6,7,8,9,10
- 2 actions: left, right
- one episode lasts until a reward is found
- there is a reward of +1 in state 1 (Eve) and +2 in state 10 (the plant)
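- Wall-E starts at a configurable state (`start`, which defaults to 5 in the code below; the Q-learning run uses `start=2`)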


In [4]:
mutable struct MarsExp <: MDP{Int64, Symbol}
    r_left::Float64
    r_right::Float64
    start::Int64
    γ::Float64
    MarsExp(;r_left::Float64 = 1., r_right::Float64 = 2.,start::Int64 = 5, γ::Float64 = 0.9) = new(r_left, r_right, start, γ)
end

In [5]:
# Note: this call fails because it passes the solver type rather than a solver instance.
@requirements_info QLearningSolver MarsExp()

LoadError: MethodError: no method matching requirements_info(::Type{TabularTDLearning.QLearningSolver}, ::MarsExp)
Closest candidates are:
  requirements_info(::Union{POMDPs.Simulator, POMDPs.Solver}, ::Union{POMDPs.MDP, POMDPs.POMDP}, ::Any...) at C:\Users\Maxime\.julia\v0.6\POMDPs\src\requirements_interface.jl:140

In [6]:
function POMDPs.states(mdp::MarsExp)
    return 1:1:10
end
POMDPs.state_index(mdp::MarsExp, s::Int64) = s
POMDPs.n_states(mdp::MarsExp) = 10

In [7]:
function POMDPs.actions(mdp::MarsExp)
    return [:left, :right]
end
POMDPs.action_index(mdp::MarsExp, a::Symbol) = a == :left ? 1 : 2
POMDPs.n_actions(mdp::MarsExp) = 2

In [8]:
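function POMDPs.generate_s(mdp::MarsExp, s::Int64, a::Symbol, rng::AbstractRNG)
    # Deterministic move, clamped to the grid boundaries; rng is unused.
    if a == :left
        return max(1, s-1)
    elseif a == :right
        return min(10, s+1)
    else
        error("unknown action: $a")
    end
end
function POMDPs.reward(mdp::MarsExp, s::Int64, a::Symbol, sp::Int64)
    # The reward depends only on the state that is reached.
    if sp == 1
        return mdp.r_left
    elseif sp == 10
        return mdp.r_right
    else
        return 0.0
    end
end
function POMDPs.initial_state(mdp::MarsExp, rng::AbstractRNG)
    # The start state is fixed, so rng is unused.
    return mdp.start
end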

In [9]:
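function POMDPs.isterminal(mdp::MarsExp, s::Int64)
    return s == 1 || s == 10
end
# The requirements check below lists discount(::MarsExp); this definition returns the γ field.
POMDPs.discount(mdp::MarsExp) = mdp.γ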

## Solve with Q-learning

First, check that the MDP implements everything the solver requires:

In [10]:
@requirements_info QLearningSolver(MarsExp()) MarsExp()


INFO: POMDPs.jl requirements for solve(::QLearningSolver, ::Union{POMDPs.MDP,POMDPs.POMDP}) and dependencies. ([✔] = implemented correctly; [X] = missing)

For solve(::QLearningSolver, ::Union{POMDPs.MDP,POMDPs.POMDP}):
  [✔] initial_state(::MarsExp, ::AbstractRNG)
  [✔] generate_sr(::MarsExp, ::Int64, ::Symbol, ::AbstractRNG)
  [✔] state_index(::MarsExp, ::Int64)
  [✔] action_index(::MarsExp, ::Symbol)
  [✔] discount(::MarsExp)



true

Then initialize the problem and the solver with the desired hyperparameters:

In [11]:
mdp = MarsExp(start=2)
solver = QLearningSolver(mdp, learning_rate=0.1, n_episodes=400, max_episode_length=50, eval_every=50, n_eval_traj=100, 
                         exp_policy=EpsGreedyPolicy(mdp, 0.9));
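
For intuition about what these hyperparameters control, here is a dependency-free sketch of ε-greedy tabular Q-learning on the same 10-state chain. It is an illustration under stated assumptions, not the code inside TabularTDLearning: the names `chain_step` and `q_learning` are hypothetical, actions are encoded as 1 (left) and 2 (right), and the keyword defaults simply mirror the values passed to the solver above.

```julia
# Hypothetical helper: one deterministic move on the chain (1 = left, 2 = right).
function chain_step(s, a)
    sp = a == 1 ? max(1, s - 1) : min(10, s + 1)
    r = sp == 1 ? 1.0 : (sp == 10 ? 2.0 : 0.0)
    return sp, r
end

# Sketch of ε-greedy tabular Q-learning; not TabularTDLearning's implementation.
function q_learning(; start=2, γ=0.9, α=0.1, ε=0.9, n_episodes=400, max_steps=50)
    Q = zeros(10, 2)                      # rows = states, columns = actions
    for _ in 1:n_episodes
        s = start
        for _ in 1:max_steps
            # With probability ε act randomly, otherwise act greedily (ties go left).
            a = rand() < ε ? rand(1:2) : (Q[s, 1] >= Q[s, 2] ? 1 : 2)
            sp, r = chain_step(s, a)
            # Terminal states contribute no future value.
            terminal = (sp == 1 || sp == 10)
            target = terminal ? r : r + γ * max(Q[sp, 1], Q[sp, 2])
            # Move the estimate a fraction α toward the one-step target.
            Q[s, a] += α * (target - Q[s, a])
            s = sp
            terminal && break
        end
    end
    return Q
end
```

Calling `Q = q_learning()` returns a 10×2 table; the rows for the two terminal states stay at zero because no action is ever taken from them.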

We are now ready to solve for the optimal policy!

In [12]:
policy = solve(solver, mdp)

On Iteration 50, Returns: 1.0
On Iteration 100, Returns: 1.0
On Iteration 150, Returns: 1.0
On Iteration 200, Returns: 1.0
On Iteration 250, Returns: 1.0
On Iteration 300, Returns: 1.0
On Iteration 350, Returns: 1.0
On Iteration 400, Returns: 1.0


POMDPToolbox.ValuePolicy{Any}(MarsExp(1.0, 2.0, 2, 0.9), [0.0 0.0; 1.0 1.0; … ; 0.797659 1.43514; 0.0 0.0], Any[:left, :right])
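
The printed matrix appears to be the learned Q-table, with one row per state and one column per action in the order `[:left, :right]`. As a small sketch of how a greedy action can be read off one row (the helper `greedy_action` is hypothetical, and breaking ties toward the first column is an assumption of this sketch):

```julia
# Hypothetical helper: pick the action whose Q-value is largest in a two-entry row.
# Ties go to the first action; that tie-breaking rule is an assumption of this sketch.
greedy_action(qrow, acts) = acts[qrow[1] >= qrow[2] ? 1 : 2]

acts = [:left, :right]
# Second-to-last row of the matrix printed above (state 9).
a9 = greedy_action([0.797659, 1.43514], acts)
println("greedy action in state 9: ", a9)
```

For this row the right column is larger, so the greedy choice in state 9 is `:right`.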

## Simulate

In [13]:
include("render.jl")

In [14]:
rng = MersenneTwister(3)
hr = HistoryRecorder(max_steps=1000)  # records states, actions and rewards

hist = simulate(hr, mdp, policy, initial_state(mdp, rng))

POMDPToolbox.MDPHistory{Int64,Symbol}([2, 1], Symbol[:left], [1.0], 1.0, Nullable{Exception}(), Nullable{Any}())

## Some interesting experiments

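If you reduce exploration (a smaller ε in `EpsGreedyPolicy`), Wall-E may rarely or never reach the far reward during training, and the learned policy can be suboptimal.

If the discount factor γ is too low, distant rewards are worth little, so you end up with a greedier policy that heads for the nearest reward.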
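
To make the effect of the discount factor concrete, the sketch below compares the discounted return of walking straight left with that of walking straight right. It assumes the reward is collected on the last of `n` transitions and is therefore discounted by γ^(n-1); the helper names are hypothetical and not part of POMDPs.jl.

```julia
# Hypothetical helpers: discounted return of heading straight to one end of the chain.
# The reward arrives on the last of n_steps transitions, hence the factor γ^(n_steps - 1).
discounted_return(r, n_steps, γ) = r * γ^(n_steps - 1)

left_return(s, γ; r_left=1.0) = discounted_return(r_left, s - 1, γ)
right_return(s, γ; r_right=2.0) = discounted_return(r_right, 10 - s, γ)

for (s, γ) in [(5, 0.9), (5, 0.4), (2, 0.9)]
    println("start=", s, " γ=", γ,
            "  left=", left_return(s, γ),
            "  right=", right_return(s, γ))
end
```

Under this convention, starting from state 5 the plant is worth more with γ = 0.9 (about 1.31 versus 0.73) but less with γ = 0.4 (about 0.051 versus 0.064). Starting from state 2 with γ = 0.9, going left (1.0) slightly beats going right (about 0.96), which is consistent with the return of 1.0 reported during training.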