In [1]:
from typing import Set

import numpy as np

from src.formalisms.primitives import State, Action

In [2]:
from src.formalisms.policy import FinitePolicyForFixedCMDP
from src.solution_methods.linear_programming.cplex_dual_cmdp_solver import solve_CMDP_for_policy
from src.concrete_decision_processes.maze_cmdp import RoseMazeCMDP

cmdp = RoseMazeCMDP()
sigma: FinitePolicyForFixedCMDP = solve_CMDP_for_policy(cmdp, True, False)[0]

Entering cplex: view dual_mdp_result_20230307_182830.log for info
Exiting cplex


In [3]:
sigma.policy_matrix

array([[0.   , 0.   , 0.   , 1.   ],
       [1.   , 0.   , 0.   , 0.   ],
       [0.   , 0.686, 0.   , 0.314],
       [0.   , 0.   , 0.   , 1.   ],
       [0.   , 0.   , 0.   , 1.   ],
       [0.   , 0.   , 0.   , 1.   ],
       [1.   , 0.   , 0.   , 0.   ],
       [0.   , 1.   , 0.   , 0.   ],
       [1.   , 0.   , 0.   , 0.   ]])

In [4]:
def A(policy: FinitePolicyForFixedCMDP, s: State) -> Set[Action]:
    return set(a for a in policy.A if policy(s).get_probability(a) > 0.0)

In [5]:
A(sigma, cmdp.state_list[2])
sigma.policy_matrix[2, :], sigma.occupancy_measure_matrix[2, :], np.isclose(sigma.policy_matrix[2, :], 0)

(array([0.   , 0.686, 0.   , 0.314]),
 array([0.   , 0.686, 0.   , 0.314]),
 array([ True, False,  True, False]))

In [6]:
m = sum(len(A(sigma, s)) - 1 for s in cmdp.S)
m

1

In [11]:
def get_phi_1():
    phi_1_policy_matrix = sigma.policy_matrix.copy()
    is_stochastic_state_mask = ((phi_1_policy_matrix > 0).sum(axis=1) != 1)
    states_inds_to_be_split = np.where(is_stochastic_state_mask)
    for s_ind in states_inds_to_be_split:
        action_probs = phi_1_policy_matrix[s_ind, :]
        first_action_ind = np.where(action_probs > 0.0)[0]
        new_action_probs = np.zeros(cmdp.n_actions)
        new_action_probs[first_action_ind] = 1.0
        phi_1_policy_matrix[s_ind, :] = new_action_probs

    # Check row_stochastic
    assert np.allclose(phi_1_policy_matrix.sum(axis=1), 1.0)
    # Check deterministic
    assert ((phi_1_policy_matrix > 0).sum(axis=1) == 1).all()
    return FinitePolicyForFixedCMDP.fromPolicyMatrix(cmdp=cmdp, policy_matrix=phi_1_policy_matrix)

phi_1 = get_phi_1()

# Algorithm 1

## Inputs

In [13]:
sigma, phi_1,

(<src.formalisms.policy.FinitePolicyForFixedCMDP at 0x7fd650439b10>,
 <src.formalisms.policy.FinitePolicyForFixedCMDP at 0x7fd621490690>)

## Initiation

### Line 1

$q^{\sigma}_{\mu}(x)$

In [16]:
sigma.state_occupancy_measure_vector

array([0.2826    , 0.40507614, 1.        , 0.500094  , 0.55566   ,
       0.        , 0.4500846 , 0.6174    , 6.18908526])

$Q^{\sigma}_{\mu}(x, a)$

In [17]:
sigma.occupancy_measure_matrix

array([[0.        , 0.        , 0.        , 0.2826    ],
       [0.40507614, 0.        , 0.        , 0.        ],
       [0.        , 0.686     , 0.        , 0.314     ],
       [0.        , 0.        , 0.        , 0.500094  ],
       [0.        , 0.        , 0.        , 0.55566   ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.4500846 , 0.        , 0.        , 0.        ],
       [0.        , 0.6174    , 0.        , 0.        ],
       [6.18908526, 0.        , 0.        , 0.        ]])

### Line 2

$q(x) \leftarrow q^{\sigma}(x)$

In [25]:
q = {
    cmdp.state_list[s_ind]: sigma.state_occupancy_measure_vector[s_ind]
    for s_ind in range(cmdp.n_states)
}

$A^*(x) \leftarrow A^{\sigma}(x)$

In [26]:
 A_star = {
    s: A(sigma, s)
    for s in cmdp.S
}

$\phi = \phi^1$

$j \leftarrow 1$

In [27]:
j = 1

$U \leftarrow \{ x \in X \mid \   |A^*(x)| > 1, q(x)=0\}$

$V \leftarrow \{ x \in X \mid \   |A^*(x)| > 1, q(x)>0\}$

In [30]:
U = {
    s
    for s in cmdp.S
    if len(A_star[s]) > 1 and q[s] == 0
}
V = {
    s
    for s in cmdp.S
    if len(A_star[s]) > 1 and q[s] > 0
}
U, V

(set(), {XYState(x=0, y=0)})

$Q(x, a) \leftarrow Q^{\sigma}_{\mu}(x,a)$ for $x \in V$ and $a
\in A^*(x)$

In [33]:
Q = {
    (s, a): sigma.occupancy_measure_matrix[cmdp.state_to_ind_map[s], cmdp.action_to_ind_map[a]]
    for s in V
    for a in A_star[s]
}
Q

{(XYState(x=0, y=0), <IntAction(0)>): 0.6859999999999999,
 (XYState(x=0, y=0), <IntAction(1)>): 0.314}