In [None]:
import Pkg; Pkg.activate(".")

In [None]:
using ArgCheck
using Distributions
using HMMBase
using LinearAlgebra
using Random
using Test

In [None]:
import POMDPs: actionindex, actions, dimensions, discount, reward, stateindex, states, transition
import POMDPs: MDP
import POMDPModelTools: SparseCat

In [None]:
include("src/spaces.jl")
include("src/problem.jl")
# TODO: Tests

## Simulation

In [None]:
using Distributions
using DiscreteValueIteration
using POMDPs
using POMDPModelTools
using POMDPSimulators
using PyPlot

### JONS paper
#### 8.1 A first simple example

In [None]:
# TODO: Use DiscreteNonParametric instead of 0-variance Normal distn.
# Process 1: one-state chain whose single state emits 8 (zero-variance Normal).
p1 = HMM(ones(1,1), [Normal(8,0)])
# Process 2: two-state chain whose states emit 5 and 10 respectively.
p2 = HMM([0.99 0.01; 0.02 0.98], [Normal(5,0), Normal(10,0)])
# NOTE(review): the models are the first argument here, whereas the section 8.2 cell
# passes `[100, 100]` first and the models second — confirm the expected argument
# order against the MonitoringMDP constructor.
mdp = MonitoringMDP([p1,p2], [300, 300], [0.65, 0.65], 0.01);

In [None]:
# Time the conversion of `mdp` to a SparseTabularMDP. The first call in a fresh
# session also includes compilation time.
@time smdp = SparseTabularMDP(mdp);

In [None]:
# Same conversion as the timed cell, without the `@time` wrapper.
smdp = SparseTabularMDP(mdp);

In [None]:
# Rebuild the sparse MDP from the existing one, overriding its discount with 0.99.
smdp = SparseTabularMDP(smdp, discount = 0.99);

In [None]:
# Value iteration on the sparse MDP: at most 1000 iterations, Bellman residual
# tolerance (`belres`) of 1e-6, verbose progress output.
solver = SparseValueIterationSolver(max_iterations=1000, belres=1e-6, verbose=true)
# `res.policy` is used below as a per-state vector of action indices.
res = solve(solver, smdp);

In [None]:
"""
    belief_1d(mdp::MonitoringMDP, p::Int, k::Int)

Return a `Vector{Float64}` with one entry per state of `mdp`.

For each state, `(timesteps, laststate)` is read from the `p`-th component of
`getstate(state)`, and the entry is element `(laststate, k)` of `A^timesteps`,
where `A` is the `A` matrix of `mdp.models[p]`.
"""
function belief_1d(mdp::MonitoringMDP, p::Int, k::Int)
    allstates = states(mdp)
    probs = Vector{Float64}(undef, length(allstates))
    chain = mdp.models[p]
    for (idx, s) in enumerate(allstates)
        elapsed, last = getstate(s)[p]
        # Row `last`, column `k` of the `elapsed`-step matrix power.
        stepmatrix = chain.A^elapsed
        probs[idx] = stepmatrix[last, k]
    end
    return probs
end

In [None]:
# TODO: Plot value function

In [None]:
# One point per state, placed along a horizontal line at y = 1 and colored by the
# action index the solved policy assigns to that state.
belief = belief_1d(mdp, 2, 1)
fig, ax = subplots(figsize=(3, 1.0))
ax.scatter(belief, fill(1.0, length(belief)); c=res.policy, s=1.0)

In [None]:
"""
    ConstantPolicy(action::CartesianIndex)

Policy that ignores the state and always returns the same `action`.

The struct is parametrized on the dimension `N` of the index so that the field has
a concrete type; `ConstantPolicy(CartesianIndex(0, 1))` infers `N` automatically.
"""
struct ConstantPolicy{N} <: Policy
    action::CartesianIndex{N}
end

# The second argument (the current state) is intentionally unused.
POMDPs.action(policy::ConstantPolicy, _) = policy.action

In [None]:
"""
    MDPPolicy(mdp, policy)

Policy backed by a table of action indices.

# Fields
- `mdp`: the `MonitoringMDP` the table belongs to.
- `policy`: vector of action indices.
"""
struct MDPPolicy <: Policy
    mdp::MonitoringMDP
    policy::Vector{Int}
end

# Convenience constructor: take the action-index table from a solver result.
MDPPolicy(mdp::MonitoringMDP, policy::ValueIterationPolicy) = MDPPolicy(mdp, policy.policy)

"""
    POMDPs.action(policy::MDPPolicy, s)

Return the action stored in `policy` for state `s`.

The state is mapped to its index with `stateindex`, the action index is looked up
in `policy.policy`, and the index is converted back to an action via `actions`.
"""
function POMDPs.action(policy::MDPPolicy, s)
    # Use the MDP stored in the policy, not the global `mdp`: the global is rebound
    # later in the notebook, which would silently pair this table with another model.
    model = policy.mdp
    sidx = stateindex(model, s)
    aidx = policy.policy[sidx]
    return actions(model)[aidx]
end

In [None]:
# pol = ConstantPolicy(CartesianIndex(1,1))
# pol = MDPPolicy(mdp, res);

In [None]:
# rs = RolloutSimulator(max_steps=10)
# r = simulate(rs, mdp, pol, rand(mdp.states))

In [None]:
# Draw the initial state at random from `mdp.states`; the commented alternative
# below pins a fixed initial state instead.
s0 = rand(mdp.states);
# s0 = CartesianIndex(0, 1, 0, 1);

In [None]:
# Record up to 3000 steps per run; all three runs start from the same state `s0`.
hr = HistoryRecorder(max_steps=3000)
# Baseline: the constant action CartesianIndex(0,1) at every step.
h_always = simulate(hr, mdp, ConstantPolicy(CartesianIndex(0,1)), s0);
# Baseline: the constant action CartesianIndex(0,0) at every step.
h_never = simulate(hr, mdp, ConstantPolicy(CartesianIndex(0,0)), s0);
# Policy taken from the value-iteration result `res`.
h_mdp = simulate(hr, mdp, MDPPolicy(mdp, res), s0);

In [None]:
# Number of recorded steps in which the MDP policy chose CartesianIndex(0,1).
count(step -> step[:a] == CartesianIndex(0,1), h_mdp.hist)

In [None]:
# Mean per-step reward recorded for the MDP policy.
mean([step[:r] for step in h_mdp.hist])

In [None]:
# Cumulative reward per step for the three recorded runs, in the order
# h_always, h_never, h_mdp.
plot(cumsum([step[:r] for step in h_always.hist]))
plot(cumsum([step[:r] for step in h_never.hist]))
plot(cumsum([step[:r] for step in h_mdp.hist]))

Reference (POMDPExamples.jl notebook, "Defining a Heuristic Policy"): https://github.com/JuliaPOMDP/POMDPExamples.jl/blob/master/notebooks/Defining-a-Heuristic-Policy.ipynb

In [None]:
# x: for each state index `i`, entry (laststate, 1) of the `timesteps`-step power of
# `p2.A`, with (timesteps, laststate) read from the second component of the state.
# Built with `map`, so the result gets a concrete element type instead of the
# `Vector{Any}` produced by pushing into an untyped `[]`.
x = map(eachindex(res.policy)) do i
    timesteps, laststate = getstate(states(mdp)[i])[2]
    return (p2.A^timesteps)[laststate, 1]
end
# y: second component of the action the policy selects in each state.
y = [getaction(actions(mdp)[aidx])[2] for aidx in res.policy]
scatter(x, y)

In [None]:
# x: for each state index `i`, entry (laststate, 1) of the `timesteps`-step power of
# `p2.A`, with (timesteps, laststate) read from the second component of the state.
# Built with `map`, so the result gets a concrete element type instead of the
# `Vector{Any}` produced by pushing into an untyped `[]`.
x = map(eachindex(res.policy)) do i
    timesteps, laststate = getstate(states(mdp)[i])[2]
    return (p2.A^timesteps)[laststate, 1]
end
# y: second component of the action the policy selects in each state.
y = [getaction(actions(mdp)[aidx])[2] for aidx in res.policy]
scatter(x, y)

#### 8.2 Two Markov chains of two states each

In [None]:
# TODO: Use DiscreteNonParametric instead of 0-variance Normal distn.
# Process 1: symmetric two-state chain whose states emit 0.5 and 2.0.
p1 = HMM([0.7 0.3; 0.3 0.7], [Normal(0.5, 0), Normal(2.0, 0)])
# Process 2: symmetric two-state chain whose states emit 1.0 and 3.0.
p2 = HMM([0.9 0.1; 0.1 0.9], [Normal(1.0,0), Normal(3.0,0)])
# NOTE(review): `[100, 100]` is the first argument here, whereas the section 8.1 cell
# passes the models first and `[300, 300]` second — confirm the expected argument
# order against the MonitoringMDP constructor.
mdp = MonitoringMDP([100, 100], [p1, p2], [0.05, 0.15], 0.01);

In [None]:
# Convert the section 8.2 `mdp` to a SparseTabularMDP for the sparse solver.
smdp = SparseTabularMDP(mdp);

In [None]:
# Value iteration on the sparse MDP: at most 100 iterations (the section 8.1 cell
# uses 1000), Bellman residual tolerance (`belres`) of 1e-6, verbose output.
solver = SparseValueIterationSolver(max_iterations=100, belres=1e-6, verbose=true)
# `res.policy` is used below as a per-state vector of action indices.
res = solve(solver, smdp);

In [None]:
# x / y: for each state index `i`, entry (laststate, 1) of the (timesteps + 1)-step
# power of `p1.A` / `p2.A`, with (timesteps, laststate) read from the first / second
# component of the state. Built with `map`, so the results get concrete element
# types instead of the `Vector{Any}` produced by pushing into an untyped `[]`.
# NOTE(review): the section 8.1 cells use `A^timesteps` without the `+ 1` — confirm
# which exponent is intended.
x = map(eachindex(res.policy)) do i
    timesteps, laststate = getstate(states(mdp)[i])[1]
    return (p1.A^(timesteps + 1))[laststate, 1]
end
y = map(eachindex(res.policy)) do i
    timesteps, laststate = getstate(states(mdp)[i])[2]
    return (p2.A^(timesteps + 1))[laststate, 1]
end
# z: the action index chosen in each state, used as the point color.
z = collect(res.policy)
# scatter(x, y)

In [None]:
# One point per state at (x, y), colored by the action index in `z`.
scatter(x, y, c=z)
# Restrict both axes to [0, 1].
xlim(0,1)
ylim(0,1)

TODO: Implement the POMDPModelTools visualization interface for this problem; see https://juliapomdp.github.io/POMDPModelTools.jl/latest/visualization.html

## Simulation

https://juliapomdp.github.io/POMDPSimulators.jl/stable/parallel/#Parallel-1