In this notebook, we implement Section III.C of [BRAJ23] (the PDF can be found [here](https://adrienbanse.github.io/assets/pdf/cdc23_extended.pdf)).

[BRAJ23] Adrien Banse, Licio Romao, Alessandro Abate, Raphaël M. Jungers, Data-driven abstractions via adaptive refinements and a Kantorovich metric [extended version]

In [65]:
# We first include the implemented algorithms
include(joinpath(@__DIR__, "../src/KantorovichAbstraction.jl"))
K = KantorovichAbstraction



Main.KantorovichAbstraction

# Example 1: Adaptive refinement of a dynamical system

Step 1: We first define the dynamical system

In [66]:
# Transfer function of the dynamical system: maps a state x to its successor.
# The cases are tested in order and the first one that applies decides the result:
#   x[1] >= 1    -> x is returned as is (fixed point)
#   x[2] <= 0.5  -> x[1] is halved and shifted by 0.5, x[2] is shifted by 0.5
#   x[1] >= 0.5  -> x[1] is shifted by -0.5
#   x[2] >= 0.75 -> x[1] is mapped to 2 x[1] + 1 and x[2] to 4 (x[2] - 0.75)
#   otherwise    -> x is returned as is (fixed point)
# In the two fixed-point cases the input vector itself is returned, not a copy.
function F(x::Vector{Float64})
    x[1] >= 1. && return x
    x[2] <= .5 && return [.5 * x[1]  + .5, x[2] + .5]
    x[1] >= .5 && return [x[1] - .5, x[2]]
    x[2] >= .75 && return [2. * x[1] + 1., 4. * (x[2] - .75)]
    return x
end

# Output function: a state is labelled 0 when its first coordinate is at least 1, and 1 otherwise
function H(x::Vector{Float64})
    if x[1] >= 1.
        return 0
    else
        return 1
    end
end

# We define the other characteristics of the dynamical system passed to K.DynamicalSystem (see src/system.jl):
# the dimension of the state (F and H read x[1] and x[2]) and the set of values that H can return
dim = 2
outputs = [0, 1]

# And an initial (state, output) pair; the output is computed from the state so that both are consistent
state = [0., 0.]
output = H(state)

1

In [67]:
# We also have to define the oracle for the dynamical system
# The oracle is a function that returns an abstraction as defined in [BRAJ23, Definition 6] given a partitioning
# In [BRAJ23], we assume that one has access to such an oracle following [BRAJ23, Assumption 2]

# We first define two methods of the function query_memory_system
# The first one (one argument) returns the measure of a given partition
# The second one (two arguments) returns the proportion of a partition that jumps to another partition
# They respectively give μ_w and P_{w1, w2} in [BRAJ23, Definition 6]
# Flips a binary symbol: 0 is mapped to 1, any other value is mapped to 0
not(c::Int) = Int(c == 0)

# Measure of the partition identified by the word `init` (a sequence of outputs in {0, 1}).
# The empty word has measure 1. A symbol outside {0, 1} reached during the lookup raises a KeyError.
function query_memory_system(init::Vector{Int})::Float64
    # Nested lookup table. At each level a Float64 leaf is multiplied by the mass accumulated
    # so far to give the answer, while a nested Dict means that the word has to be read further.
    tree = Dict(
        0 => 1 / 2,
        1 => Dict(
            0 => 1 / 8,
            1 => Dict(
                0 => 1 / 7,
                1 => Dict(0 => 1 / 3, 1 => 2 / 3),
            ),
        ),
    )
    n = length(init)

    # A word in which a 0 is immediately followed by a 1 has measure zero
    for i = 1:(n - 1)
        if init[i] == 0 && init[i + 1] == 1
            return 0.
        end
    end

    # Words of more than four symbols starting with four 1s: only the all-ones word has mass
    if n > 4 && all(==(1), init[1:4])
        return all(==(1), init) ? 1 / 4 : 0.
    end

    # Walk down the table along the word, accumulating the mass left by the sibling branches
    node = tree
    acc = 1.
    for c = init
        entry = node[c]
        if entry isa Dict
            acc = acc * (1 - node[not(c)])
            node = entry
        else
            return entry * acc
        end
    end
    return acc
end

# Proportion of the partition `from` that jumps to the partition `to`.
# It is 0 if either partition has measure zero or if the words are incompatible, 1 if the
# future of `from` already determines `to`, and a ratio of measures otherwise.
function query_memory_system(from::Vector{Int}, to::Vector{Int})::Float64
    μ_from = query_memory_system(from)
    μ_to = query_memory_system(to)
    (μ_from == 0. || μ_to == 0.) && return 0.

    # Word observed from `from` after one step, compared with `to` on their common length
    tail = from[2:end]
    k = min(length(tail), length(to))
    tail[1:k] == to[1:k] || return 0.

    # `to` is no longer than the remaining word of `from`: the jump is certain
    length(tail) >= length(to) && return 1.

    # Otherwise only the part of `from` whose continuation is `to` makes the jump
    return query_memory_system(vcat([from[1]], to)) / μ_from
end

# Finally, the oracle itself: given a partitioning W (one word per partition), it builds the
# corresponding Markov chain with
#   S: one K.PartitionState per word,
#   μ: the measure of each state,
#   L: the label of each state (first symbol of its word),
#   P: the transition probability for every ordered pair of states
function oracle(W::Vector{Vector{Int}})
    S = [K.PartitionState(w) for w in W]
    labels = [0, 1]
    μ = Dict{K.PartitionState, Float64}()
    L = Dict{K.PartitionState, Int}()
    P = Dict{Tuple{K.PartitionState, K.PartitionState}, Float64}()

    # Per-state quantities
    for s in S
        μ[s] = query_memory_system(K.id(s))
        L[s] = K.id(s)[1]
    end

    # Pairwise quantities, for each source state and each target state
    for s_from in S, s_to in S
        P[(s_from, s_to)] = query_memory_system(K.id(s_from), K.id(s_to))
    end

    return K.MarkovChain{K.PartitionState}(S, P, μ, labels, L)
end

oracle (generic function with 1 method)

In [68]:
# We are now able to define the dynamical system from the transfer function F, the output function H,
# the dimension, the initial state, the set of outputs, the initial output and the oracle defined above
system = K.DynamicalSystem(F, H, dim, state, outputs, output, oracle)

Main.KantorovichAbstraction.DynamicalSystem{typeof(F), typeof(H), typeof(oracle)}(F, H, 2, [0.0, 0.0], [0, 1], 1, oracle)

Step 2: Use the REFINE algorithm as defined in [BRAJ23, Algorithm 2] to create a data-driven adaptive refinement abstraction

In [72]:
# We now use our function refine defined in src/abstraction.jl to build the abstraction of the system
# N is set to the largest Int, which plays the role of N = ∞
# NOTE(review): the exact roles of N and ε are those of K.refine / [BRAJ23, Algorithm 2] — confirm there
N = typemax(Int) 
ε = 1e-5

# The returned value is discarded; with verbose = true the successive abstractions are printed
_ = K.refine(system, N, ε; verbose = true)

(k = 0) Current abstraction:
Markov Chain with states Main.KantorovichAbstraction.PartitionState[[0], [1]]


(k = 1) Current abstraction (chosen with d = 0.0015200741576384842):
Markov Chain with states Main.KantorovichAbstraction.PartitionState[[0], [1, 0], [1, 1]]


(k = 2) Current abstraction (chosen with d = 0.005912370617181118):
Markov Chain with states Main.KantorovichAbstraction.PartitionState[[0], [1, 0], [1, 1, 0], [1, 1, 1]]


(k = 3) Current abstraction (chosen with d = 0.003906247549900491):
Markov Chain with states Main.KantorovichAbstraction.PartitionState[[0], [1, 0], [1, 1, 0], [1, 1, 1, 0], [1, 1, 1, 1]]


We see that with $N = \infty$, the algorithm stops at $k = 3$, and therefore the last abstraction has the same behaviour as the dynamical system following [BRAJ23, Proposition 1]. \  
We recover the values in [BRAJ23, Table I]

# Example 2: Application to controller design

In this example, we solve MDPs as explained in [BRAJ23, Section III.C] for Example 2.

In [85]:
using Random, Distributions
using POMDPs, QuickPOMDPs, POMDPModelTools, POMDPSimulators, QMDP, DiscreteValueIteration

In [124]:
# Ingredients shared by all the MDPs below:
# - the set of available actions,
# - the reward, which is 1 in state 0 and 0 in any other state (the action is ignored),
# - the discount factor

actions = [0, 1/4, 1/2]
reward = (s, a) -> s == 0 ? 1. : 0.
discount = .95

0.95

In [125]:
# Solves, by value iteration, the MDP defined by
# - states: the set of states,
# - initial_probs: the initial probabilities, where initial_probs[i] is the probability of states[i],
# - transition: a function transition(s, a) returning the distribution of the next state
#   when the action a is applied in the state s.
# The actions, the reward and the discount factor are the global `actions`, `reward` and `discount`.
# Returns the policy computed by the solver. If verbose is true, the value of each state and the
# expected value under the initial distribution are also printed.
# See POMDPs.jl doc for more information

function solve_MDP_abstraction(
    states, 
    initial_probs, 
    transition;
    verbose = false
)
    m = QuickMDP(
        states = states,
        actions = actions,
        initialstate = SparseCat(states, initial_probs),
        discount = discount,
        transition = transition,
        reward = reward
    )
    solver = ValueIterationSolver(max_iterations=10000)
    policy = solve(solver, m)

    if verbose
        println("-- Result --")
        # We iterate over the `states` argument itself: it is index-aligned with `initial_probs`
        # (this is the pairing used for the initial distribution above)
        w = 0.0
        for (s, prob) in zip(states, initial_probs)
            v = value(policy, s)
            println("V($s) = $v")
            w += prob * v
        end
        println("Expected value under the initial distribution = $w")
    end

    return policy
end

solve_MDP_abstraction (generic function with 1 method)

We will now define the transitions of the MDP corresponding to the result of the REFINE algorithm above. \
For each partitioning, we found the transition function on paper, and explicitly write it here. \
For each partitioning, we find the optimal policy thanks to our function to solve MDPs above.

In [93]:
# MDP of the abstraction at k = 0: two states, with transitions that do not depend on the action
states_0 = [0, 1]
initial_probs = [1/2, 1/2]
function _transition(s, a)
    # One next-state distribution per state; an unknown state raises a KeyError
    table = Dict(
        0 => Deterministic(0),
        1 => SparseCat([1, 0], [7/8, 1/8]),
    )
    return table[s]
end
p_0 = solve_MDP_abstraction(states_0, initial_probs, _transition)

ValueIterationPolicy:
 0 -> 0.0
 1 -> 0.0

In [94]:
# MDP of the abstraction at k = 1: one transition table for the action 0,
# and another one shared by all the other actions
states_1 = [0, 10, 11]
initial_probs = [1/2, 1/16, 7/16]
function _transition(s, a)
    table = if a == 0
        Dict(
            0 => Deterministic(0),
            10 => Deterministic(0),
            11 => SparseCat([11, 10], [6/7, 1/7]),
        )
    else
        Dict(
            0 => Deterministic(0),
            10 => Deterministic(11),
            11 => SparseCat([11, 10, 0], [5/7, 1/7, 1/7]),
        )
    end
    return table[s]
end
p_1 = solve_MDP_abstraction(states_1, initial_probs, _transition)

ValueIterationPolicy:
 0 -> 0.0
 10 -> 0.0
 11 -> 0.25

In [95]:
# MDP of the abstraction at k = 2: one transition table per action
states_2 = [0, 10, 110, 111]
initial_probs = [1/2, 1/16, 1/16, 3/8]
function _transition(s, a)
    table = if a == 0
        Dict(
            0 => Deterministic(0),
            10 => Deterministic(0),
            110 => Deterministic(10),
            111 => SparseCat([111, 110], [2/3, 1/3]),
        )
    elseif a == 1/4
        Dict(
            0 => Deterministic(0),
            10 => Deterministic(111),
            110 => Deterministic(111),
            111 => SparseCat([111, 110, 10, 0], [1/3, 1/3, 1/6, 1/6]),
        )
    elseif a == 1/2
        Dict(
            0 => Deterministic(0),
            10 => Deterministic(110),
            110 => Deterministic(110),
            111 => SparseCat([111, 10, 0], [2/3, 1/6, 1/6]),
        )
    else
        nothing
    end
    # An action outside {0, 1/4, 1/2} has no table: nothing is returned in that case
    return table === nothing ? nothing : table[s]
end
p_2 = solve_MDP_abstraction(states_2, initial_probs, _transition)

ValueIterationPolicy:
 0 -> 0.0
 10 -> 0.0
 110 -> 0.0
 111 -> 0.25

In [96]:
# MDP of the abstraction at k = 3: one transition table per action
states_3 = [0, 10, 110, 1110, 1111]
initial_probs = [1/2, 1/16, 1/16, 1/8, 2/8]
function _transition(s, a)
    table = if a == 0
        Dict(
            0 => Deterministic(0),
            10 => Deterministic(0),
            110 => Deterministic(10),
            1110 => Deterministic(110),
            1111 => Deterministic(1111),
        )
    elseif a == 1/4
        Dict(
            0 => Deterministic(0),
            10 => Deterministic(1111),
            110 => Deterministic(1111),
            1110 => Deterministic(1111),
            1111 => SparseCat([110, 10, 0], [1/2, 1/4, 1/4]),
        )
    elseif a == 1/2
        Dict(
            0 => Deterministic(0),
            10 => Deterministic(110),
            110 => Deterministic(110),
            1110 => SparseCat([10, 0], [1/2, 1/2]),
            1111 => Deterministic(1111),
        )
    else
        nothing
    end
    # An action outside {0, 1/4, 1/2} has no table: nothing is returned in that case
    return table === nothing ? nothing : table[s]
end
p_3 = solve_MDP_abstraction(states_3, initial_probs, _transition)

ValueIterationPolicy:
 0 -> 0.0
 10 -> 0.0
 110 -> 0.0
 1110 -> 0.5
 1111 -> 0.25

We will now define a controlled dynamical system corresponding to the optimal policies found above. \
For this, we first need two utility functions.

In [126]:
# Membership test: returns true if the state x belongs to the partition identified by the word "state",
# i.e. if the outputs observed along the trajectory of sys starting at x are exactly the symbols of "state".
# The input x is not modified: the trajectory is computed from a copy.

function is_in_partition(state::Vector{Int}, x::Vector{Float64}, sys::K.DynamicalSystem)
    point = copy(x)
    for expected in state
        # The first mismatch between the observed and the expected output settles the question
        (K.H(sys))(point) == expected || return false
        point = (K.F(sys))(point)
    end
    return true
end

is_in_partition (generic function with 2 methods)

In [127]:
# Given a controlled system, a discount factor, a reward function, a sample length L and a number of samples N,
# this function gives the corresponding approximated expected reward as in [BRAJ23, Equation (11)].
# Each of the N samples starts from a state drawn uniformly in [0, 2] × [0, 1], follows the controlled
# system for L outputs, and accumulates the discounted rewards of these outputs (the reward function is
# called with `nothing` as its action). The average over the N samples is returned.
# The state and the output of cont_system.system are overwritten by this function.

function approximate_reward(
    cont_system::K.ControlledDynamicalSystem, 
    discount::Float64, 
    reward::F, 
    L::Int, 
    N::Int
) where F <: Function
    # Local accumulator (no global state is created or modified), started as a Float64
    r_tot = 0.0
    for n = 1:N
        # Draw a fresh initial state and make the stored output consistent with it
        cont_system.system.state = [rand(Distributions.Uniform(0, 2)), rand(Distributions.Uniform(0, 1))]
        cont_system.system.output = (K.H(cont_system))(K.state(cont_system))
        # Undiscounted reward of the initial output, then the discounted rewards of the next L - 1 outputs
        r = reward(K.output(cont_system), nothing)
        for l = 1:(L - 1)
            K.next!(cont_system)
            r += discount^l * reward(K.output(cont_system), nothing)
        end
        r_tot += r
    end
    return r_tot / N
end

approximate_reward (generic function with 1 method)

Now, for each partitioning found by REFINE, we are able to define a corresponding controlled dynamical system (see [BRAJ23, Equation (12)]), \
and approximate the corresponding expected reward.

In [123]:
# For each abstraction (policy p, states of the corresponding partitioning), build the controller
# induced by the policy, close the loop with the system, and estimate the expected reward
for (i, (p, states)) in enumerate(zip([p_0, p_1, p_2, p_3], [states_0, states_1, states_2, states_3]))
    # Controller: looks for the first partition that contains x and shifts x[2] (modulo 1) by the
    # action that the policy prescribes for this partition; x[1] is left unchanged.
    # If no partition contains x, nothing is returned.
    function control(x::Vector{Float64})
        for s in states
            # The state 110, for instance, encodes the word [1, 1, 0]
            word = [parse(Int64, digit) for digit in string(s, base=10)]
            if !is_in_partition(word, x, system)
                continue
            end
            shifted = (x[2] + action(p, s)) % 1
            return [x[1], shifted]
        end
        return nothing
    end
    cont_system = K.ControlledDynamicalSystem(system, control)
    r = approximate_reward(cont_system, discount, reward, 1000, 5000)
    println("(k = $(i-1)) Expected reward is $r")
end

(k = 0) Expected reward is 14.341482500000053


(k = 1) Expected reward is 18.904444650000364


(k = 2) Expected reward is 19.03577750000061


(k = 3) Expected reward is 19.072816000000536


We recover the values in [BRAJ23, Table II] (up to randomness). \
One can see that the expected reward increases.