In [None]:
import Pkg; Pkg.activate("..")

In [None]:
using ArgCheck
using Distributions
using HMMBase
using IterTools
using LinearAlgebra
using Random
using Test

In [None]:
import Base: IdentityUnitRange
import POMDPs: actionindex, actions, dimensions, discount, reward, stateindex, states, transition
import POMDPs: MDP
import POMDPModelTools: SparseCat

In [None]:
"""
    vproduct(args...)

Return the Cartesian product of `args` (built with `Iterators.product`) passed through
IterTools' `ivec`, so the combinations can be consumed as a flat sequence.
"""
function vproduct(args...)
    grid = Iterators.product(args...)
    return ivec(grid)
end
"""
    splatmap(f, args...)

Apply `map` with a wrapper that splats each element into `f`, i.e. an element `x` is
evaluated as `f(x...)`.

NOTE(review): the wrapper accepts a single argument, so only one collection can be
passed in `args` in practice.
"""
function splatmap(f, args...)
    splatted = x -> f(x...)
    return map(splatted, args...)
end;

https://github.com/JuliaPOMDP/POMDPExamples.jl/blob/master/notebooks/Defining-a-Heuristic-Policy.ipynb

## Spaces
[Spaces and Distributions](https://juliapomdp.github.io/POMDPs.jl/stable/interfaces/#space-interface-1)

https://discourse.julialang.org/t/linearindices-for-non-1-based-indices/26906/2

Est-ce qu'on mesure quand on atteint τmax ? (Do we measure when τmax is reached?)

In [None]:
# Concatenate any number of tuples into one flat tuple.
# Adapted from https://discourse.julialang.org/t/efficient-tuple-concatenation/5398/9
@inline tuplejoin(single) = single
@inline tuplejoin(head, next) = (head..., next...)
@inline function tuplejoin(head, next, rest...)
    # Peel off the first tuple and recurse on the remaining ones.
    return (head..., tuplejoin(next, rest...)...)
end;

In [None]:
# Actions and states are both addressed by Cartesian indices; their spaces are
# `CartesianIndices` whose axes are identity ranges (so axes may start at 0).
const Action{P} = CartesianIndex{P}
const State{P} = CartesianIndex{P}
const ActionSpace{P} = CartesianIndices{P,NTuple{P,IdentityUnitRange{UnitRange{Int}}}}
const StateSpace{P} = CartesianIndices{P,NTuple{P,IdentityUnitRange{UnitRange{Int}}}};

In [None]:
"""
    action_space(npaths)

Build the action space for `npaths` paths: a `CartesianIndices` with one axis per path,
each axis being the identity range `0:1`.

Fails the `@argcheck` unless `npaths >= 1`.
"""
function action_space(npaths)
    @argcheck npaths >= 1
    binary_axes = Tuple(IdentityUnitRange(0:1) for _ in 1:npaths)
    return CartesianIndices(binary_axes)
end

"""
    state_space(τmax, nstates)

Build the joint state space as a `CartesianIndices`. Each path `k` contributes two
consecutive axes: `0:τmax[k]` followed by `1:nstates[k]`, both wrapped in
`IdentityUnitRange` so the first axis keeps its 0-based indices.

Fails an `@argcheck` when the inputs differ in length, are empty, or when any
`τmax[k] < 0` or `nstates[k] < 1`.
"""
function state_space(τmax, nstates)
    @argcheck length(τmax) == length(nstates)
    # Mirror `action_space`'s `npaths >= 1` guard: an empty space is never meaningful
    # and previously surfaced as an unrelated MethodError from `tuplejoin()`.
    @argcheck length(τmax) >= 1
    @argcheck all(τmax .>= 0) && all(nstates .>= 1)
    # Flatten the per-path (0:τ, 1:n) pairs into a single tuple of axes.
    joint_axes = Tuple(
        IdentityUnitRange(r) for (τ, n) in zip(τmax, nstates) for r in (0:τ, 1:n)
    )
    return CartesianIndices(joint_axes)
end

In [None]:
"""
    getaction(a::Action)

Decode action `a` into a tuple of `Bool`s, one per component, `true` where the
component equals 1 (`transition` treats `true` as "measure").
"""
getaction(a::Action) = map(==(1), Tuple(a))
"""
    getstate(s::State)

Split the components of state `s` into consecutive chunks of two and collect them.
Callers destructure each chunk as `(timesteps, laststate)`.
"""
function getstate(s::State)
    components = Tuple(s)
    return collect(Iterators.partition(components, 2))
end

## MDP

In [None]:
"""
    MonitoringMDP{PP,P} <: MDP{State{PP},Action{P}}

MDP with `PP`-dimensional states and `P`-dimensional actions. The outer constructor
`MonitoringMDP(τmax, models, costs, discount)` uses `P = length(models)` and `PP = 2P`
(two state components per path).
"""
struct MonitoringMDP{PP,P} <: MDP{State{PP},Action{P}}
    τmax::Vector{Int} # Per-path upper bound on the `timesteps` state component
    models::Vector{HMM} # One HMM per path
    costs::Vector{Float64} # Per-path costs, dotted with the action in `reward`
    discount::Float64 # Discount factor returned by `discount(mdp)`
    # Internal fields
    actions::ActionSpace{P} # Action space
    states::StateSpace{PP} # State space
end

"""
    MonitoringMDP(τmax, models, costs, discount)

Construct a `MonitoringMDP` with one path per entry of `models`, deriving the action
and state spaces from the number of paths, `τmax` and the first dimension of each
model's `size`.

Fails an `@argcheck` unless `τmax`, `models` and `costs` have the same length and
`0 <= discount < 1`.
"""
function MonitoringMDP(τmax, models, costs, discount)
    @argcheck length(τmax) == length(models) == length(costs)
    @argcheck 0 <= discount < 1
    npaths = length(models)
    aspace = action_space(npaths)
    sspace = state_space(τmax, map(model -> size(model, 1), models))
    # Two state components (timesteps, laststate) per path, one action component.
    return MonitoringMDP{2 * npaths,npaths}(τmax, models, costs, discount, aspace, sspace)
end

In [None]:
# Linear index of action `a` within the MDP's action space.
function actionindex(mdp::MonitoringMDP, a::Action)
    lookup = LinearIndices(mdp.actions)
    return lookup[a]
end

# Linear index of state `s` within the MDP's state space.
function stateindex(mdp::MonitoringMDP, s::State)
    lookup = LinearIndices(mdp.states)
    return lookup[s]
end

In [None]:
# Accessors for the POMDPs functions imported above: each returns the stored field.
function actions(mdp::MonitoringMDP)
    return mdp.actions
end

function states(mdp::MonitoringMDP)
    return mdp.states
end

function discount(mdp::MonitoringMDP)
    return mdp.discount
end
# actionindex(mdp::MonitoringMDP, a::Action) = index(mdp.actions, a)
# stateindex(mdp::MonitoringMDP, s::State) = index(mdp.states, s)

### Transitions

In [None]:
"""
    transition(τmax::Int, model::HMM, b, a)

Single-path transition. `b` destructures as `(timesteps, laststate)` and `a` is the
Boolean measurement decision. Returns `(probas, states)`: the probabilities and the
matching successor `(timesteps, laststate)` tuples.

- Measuring (`a == true`): successors are `(0, i)` for every state `i`, weighted by
  row `laststate` of `model.A^(timesteps + 1)`.
- Not measuring: a single successor with probability 1 whose `timesteps` is
  incremented, except that it stays put once it has reached `τmax`.

Fails the `@argcheck` when `timesteps > τmax`.
"""
function transition(τmax::Int, model::HMM, b, a)
    timesteps, laststate = b
    @argcheck timesteps <= τmax

    if a
        # Measure: the counter resets and the observed state is drawn from the
        # (timesteps + 1)-step transition row.
        row = (model.A^(timesteps + 1))[laststate, :]
        return row, [(0, i) for i in 1:length(row)]
    end

    # Don't measure: deterministic move, with the counter saturating at τmax.
    elapsed = timesteps == τmax ? timesteps : timesteps + 1
    return [1.0], [(elapsed, laststate)]
end

In [None]:
"""
    transition(mdp::MonitoringMDP, s::State, a::Action)

Distribution over successor states from state `s` under action `a`. Every path is
advanced on its own with the single-path `transition`; the joint outcomes are the
Cartesian product of the per-path outcomes, with probabilities multiplied together.
Returns a `SparseCat` over the joint states.
"""
function transition(mdp::MonitoringMDP, s::State, a::Action)
    # One (probas, states) pair per path.
    perpath = [
        transition(τmax, model, belief, measure) for
        (τmax, model, belief, measure) in
        zip(mdp.τmax, mdp.models, getstate(s), getaction(a))
    ]

    # Combine paths: multiply probabilities, concatenate per-path state tuples.
    jointprobas = splatmap(*, vproduct(first.(perpath)...))
    jointstates = map(CartesianIndex, splatmap(tuplejoin, vproduct(last.(perpath)...)))

    return SparseCat(jointstates, jointprobas)
end

### Rewards

In [None]:
"""
    reward(mdp::MonitoringMDP, ::State, a::Action, sp::State)

Reward for taking action `a` and landing in `sp` (the origin state is ignored):
the negated sum of the cost `dot(mdp.costs, Tuple(a))` and the smallest per-path
expected delay in `sp`.

A path's expected delay is the mean of each emission distribution `model.B[i]`
weighted by row `laststate` of `model.A^timesteps`.
"""
function reward(mdp::MonitoringMDP, ::State, a::Action, sp::State)
    cost = dot(mdp.costs, Tuple(a))
    beliefs = getstate(sp)

    # Expected emission mean of one path given its (timesteps, laststate) belief.
    function expected_delay((model, belief))
        timesteps, laststate = belief
        probas = (model.A^timesteps)[laststate, :]
        return sum(i -> mean(model.B[i]) * probas[i], 1:length(probas))
    end

    delay = minimum(expected_delay, zip(mdp.models, beliefs))

    return -cost - delay
end

In [None]:
# TODO: Tests

## Simulation

In [None]:
using Distributions
using DiscreteValueIteration
using POMDPModelTools
using PyPlot

### JONS paper
#### 8.1 A first simple example

In [None]:
# Example 8.1: path 1 is a single-state chain, path 2 a two-state chain.
# The zero-variance Normals give each state a fixed emission value.
# TODO: Use DiscreteNonParametric instead of zero-variance Normal distributions.
p1 = HMM(ones(1,1), [Normal(8,0)])
p2 = HMM([0.99 0.01; 0.02 0.98], [Normal(5,0), Normal(10,0)])
# Arguments: τmax per path, models, cost per path, discount factor.
# NOTE(review): a discount of 0.01 weights future rewards very lightly; confirm intended.
mdp = MonitoringMDP([300, 300], [p1, p2], [0.65, 0.65], 0.01);

In [None]:
# Wrap `mdp` in a SparseTabularMDP, which is what the solver cell below consumes.
smdp = SparseTabularMDP(mdp);

In [None]:
# Sparse value iteration, capped at 100 iterations, with `belres` set to 1e-6.
solver = SparseValueIterationSolver(max_iterations=100, belres=1e-6, verbose=true)
# `res.policy` is read by the plotting cells that follow.
res = solve(solver, smdp);

In [None]:
# For every state, plot the decision taken on path 2 against the entry
# (p2.A^timesteps)[laststate, 1] computed from that state's path-2 components.
x, y = [], []
for (stateidx, actionidx) in enumerate(res.policy)
    decision = getaction(actions(mdp)[actionidx])
    elapsed, lastseen = getstate(states(mdp)[stateidx])[2]
    belief = (p2.A^elapsed)[lastseen, 1]
    push!(x, belief)
    push!(y, decision[2])
end
scatter(x, y)

In [None]:
# NOTE(review): this cell repeats the preceding plotting cell verbatim — confirm whether
# a different path or quantity was meant to be plotted here.
# For every state, plot the decision taken on path 2 against the entry
# (p2.A^timesteps)[laststate, 1] computed from that state's path-2 components.
x, y = [], []
for (stateidx, actionidx) in enumerate(res.policy)
    decision = getaction(actions(mdp)[actionidx])
    elapsed, lastseen = getstate(states(mdp)[stateidx])[2]
    belief = (p2.A^elapsed)[lastseen, 1]
    push!(x, belief)
    push!(y, decision[2])
end
scatter(x, y)

#### 8.2 Two Markov chains of two states each

In [None]:
# Example 8.2: two paths, each a two-state chain with fixed (zero-variance) emissions.
# TODO: Use DiscreteNonParametric instead of zero-variance Normal distributions.
p1 = HMM([0.7 0.3; 0.3 0.7], [Normal(0.5, 0), Normal(2.0, 0)])
p2 = HMM([0.9 0.1; 0.1 0.9], [Normal(1.0,0), Normal(3.0,0)])
# Arguments: τmax per path, models, cost per path, discount factor.
mdp = MonitoringMDP([100, 100], [p1, p2], [0.05, 0.15], 0.01);

In [None]:
# Wrap the 8.2 `mdp` in a SparseTabularMDP for the solver cell below.
smdp = SparseTabularMDP(mdp);

In [None]:
# Same solver settings as in example 8.1.
solver = SparseValueIterationSolver(max_iterations=100, belres=1e-6, verbose=true)
# `res.policy` is read by the plotting cells that follow.
res = solve(solver, smdp);

In [None]:
# Collect, for every state: x / y = entry [laststate, 1] of A^(timesteps + 1) for
# path 1 / path 2, and z = the policy's action index for that state.
x, y, z = [], [], []
for (stateidx, actionidx) in enumerate(res.policy)
    beliefs = getstate(states(mdp)[stateidx])
    for (model, belief, coords) in zip((p1, p2), beliefs, (x, y))
        elapsed, lastseen = belief
        push!(coords, (model.A^(elapsed + 1))[lastseen, 1])
    end
    push!(z, actionidx)
end

In [None]:
# Scatter the collected x / y values, coloured by z, on the unit square.
scatter(x, y, c=z)
xlim(0,1)
ylim(0,1)

TODO: Implement https://juliapomdp.github.io/POMDPModelTools.jl/latest/visualization.html

## Simulation

https://juliapomdp.github.io/POMDPSimulators.jl/stable/parallel/#Parallel-1