In [1]:
import pandas as pd
import numpy as np
from numpy import genfromtxt
from tqdm import tqdm

In [2]:
def load_data():
    transitions = genfromtxt("../data/icu-sepsis-mixed/transitionFunction.csv", delimiter=",")
    survival_reward = genfromtxt("../data/icu-sepsis-mixed/rewardFunction-survival.csv", delimiter=",")
    death_reward = genfromtxt("../data/icu-sepsis-mixed/rewardFunction-death.csv", delimiter=",")
    mixed_reward = genfromtxt("../data/icu-sepsis-mixed/rewardFunction-mixed.csv", delimiter=",")
    return transitions, survival_reward, death_reward, mixed_reward

In [3]:
P, R_surv, R_death, R_mixed = load_data()

### Value Iter

In [23]:
def value_iter(P: np.ndarray, R: np.ndarray, gamma: float = 1.0, epsilon: float = 1e-8 ) -> None:
    '''
    value iteration that supports three reward variants. 
    P (transition probability): np.ndarray, (17900, 716): (i, j) corresponds to (s, a, s'). each row corresponds to a (s, a) pair, column the next state
    R_ (reward function): np.ndarray, (716, ): each column corresponds to the reward received upon transitioning into that state
    gamma: discount factor
    epsilon: threshold for convergence
    '''
    num_states = 716
    num_actions = 25
    absorbing_state = num_states - 1

    P_sa_s = P.reshape((num_states, num_actions, num_states))
    V = np.zeros(num_states)

    iteration = 0
    while True:
        V_prev = V.copy()
        for s in range(num_states):
            Q_sa = np.zeros(num_actions)
            for a in range(num_actions):
                Q_sa[a] = np.sum(P_sa_s[s, a, :] * (R + gamma * V_prev))
            V[s] = np.max(Q_sa)

        V[absorbing_state] = R[absorbing_state]

        delta = np.max(np.abs(V - V_prev))
        iteration += 1
        if delta < epsilon:
            break

    print(f"Converged after {iteration} iterations.")
    return V

In [49]:
print("V_plus: ")
V_plus = value_iter(P, R_surv)
print("V_minus: ")
V_minus = value_iter(P, R_death)
print("V_mixed: ")
V_mixed = value_iter(P, R_mixed)

V_plus: 
Converged after 207 iterations.
V_minus: 
Converged after 189 iterations.
V_mixed: 
Converged after 209 iterations.


In [None]:
# for proposition 2: Value sum in MDPs
print("prop 2: ", np.isclose((V_plus + V_minus), V_mixed, atol=1e-8).all())

# for proposition 4: State-value difference in MDPs with shared policy
print("prop 4: ", np.isclose((V_plus - V_minus)[:-3], 1.00, atol=1e-10).all())  # excluding last 3 states



prop 2:  True
prop 4:  True


In [22]:
print(V_plus - V_minus)

[0.99994973 0.99997615 0.99996647 0.99995705 0.99994313 0.99998346
 0.99996055 0.99995062 0.99996452 0.999953   0.99995664 0.99996155
 0.99995297 0.99995509 0.99995523 0.99996857 0.99997851 0.99994906
 0.99995093 0.99995264 0.99994901 0.99996813 0.99994921 0.99995133
 0.99997882 0.99999268 0.99999094 0.99997259 0.99995422 0.9999536
 0.9999661  0.99994413 0.99994637 0.99997828 0.99997619 0.99995949
 0.999972   0.9999549  0.99994656 0.99995415 0.9999545  0.99995704
 0.99995289 0.9999528  0.99996717 0.99995568 0.99995386 0.99995763
 0.99996896 0.99994284 0.99997239 0.99994147 0.99996445 0.99995111
 0.99996058 0.99994809 0.99997365 0.99995132 0.99995773 0.99995005
 0.99997658 0.99996265 0.99998325 0.99998044 0.99999212 0.99996319
 0.99995501 0.99995481 0.99995465 0.99994847 0.99996201 0.99995122
 0.99995369 0.99994859 0.99995485 0.99997515 0.99996932 0.99995248
 0.99997191 0.99995314 0.99996735 0.99976769 0.99997702 0.99996452
 0.99995038 0.99995161 0.99995518 0.99998585 0.99996183 0.99995

In [26]:
print(V_mixed)

[ 0.84140838  0.74081142  0.88756427  0.84436511  0.7267429   0.67954962
  0.87637531  0.82475764  0.72931897  0.78869794  0.78961477  0.86630073
  0.81861909  0.83746816  0.86988583  0.73701492  0.84454379  0.84800426
  0.84747086  0.85806061  0.81614914  0.72598715  0.79788807  0.84471455
  0.80816777  0.59206531  0.71463869 -0.43513215  0.84446651  0.86962332
  0.80299758  0.8173377   0.86108958  0.7962864   0.74853268  0.83751825
  0.71683508  0.85588754  0.8033727   0.85542034  0.80876633  0.78977834
  0.83005413  0.83965704  0.60732517  0.86367052  0.84303867  0.82740828
  0.80869089  0.781342    0.75126125  0.67884891  0.7817188   0.84345973
  0.88015144  0.88485694  0.65538633  0.8459466   0.85973049  0.82888455
  0.80113122  0.8731575   0.77912223 -0.60474163  0.68388609  0.85453255
  0.82929251  0.85623739  0.86335103  0.85469559  0.85997858  0.83487424
  0.79303238  0.80995288  0.81466461  0.76699937  0.75972446  0.84847636
  0.68132351  0.87032858  0.74868892 -0.08450118  0

In [30]:
print(np.isclose((V_plus + V_minus), V_mixed, atol=1e-8))

[ True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  T

### Policy Eval and Policy Iter

In [None]:
def policy_eval(P: np.ndarray, R: np.ndarray, policy: np.ndarray, gamma: float = 1.0, epsilon: float = 1e-4):
    S, A, _ = P.shape
    V = np.zeros(S)
    inner_iter = 0
    while True:
        V_prev = V.copy()
        for s in range(S):
            a = policy[s]
            V[s] = np.sum(P[s,a,:] * (R + gamma * V_prev))
        inner_iter += 1
        if inner_iter % 20 == 0:
            delta = np.max(np.abs(V - V_prev))
            print(f"   [Policy Eval] Iter {inner_iter} | delta={delta:.6f}")
        if np.max(np.abs(V - V_prev)) < epsilon:
            break
    return V


In [59]:
def policy_iter(P: np.ndarray, R: np.ndarray, gamma: float = 1.0, epsilon: float = 1e-4):
    '''
    P: (S, A, S)
    R: (S,)
    Returns:
        policy: (S,)
        V: (S,)
    '''
    S, A, _ = P.shape
    policy = np.zeros(S, dtype=int)
    V = np.zeros(S)
    
    iteration = 0
    while True:
        # 1. Policy Evaluation
        V = policy_eval(P, R, policy, gamma, epsilon)
        
        # 2. Policy Improvement
        policy_stable = True
        for s in range(S):
            old_action = policy[s]
            Q_sa = np.zeros(A)
            for a in range(A):
                Q_sa[a] = np.sum(P[s,a,:] * (R + gamma * V))
            policy[s] = np.argmax(Q_sa)
            if old_action != policy[s]:
                policy_stable = False

        iteration += 1
        
        if iteration % 5 == 0:
            print(f"Policy Iteration: iteration {iteration}")


        if policy_stable:
            break

    print(f"Policy iteration converged after {iteration} iterations.")
    return policy, V


In [None]:
P_sa_s = P.reshape((716, 25, 716))
optimal_policy, optimal_V = policy_iter(P_sa_s, R_mixed, gamma=0.8, epsilon=1e-4)


Policy Iteration: iteration 5
Policy Iteration: iteration 10
Policy Iteration: iteration 15
Policy Iteration: iteration 20
Policy Iteration: iteration 25
Policy Iteration: iteration 30
Policy Iteration: iteration 35
Policy Iteration: iteration 40
Policy Iteration: iteration 45
Policy Iteration: iteration 50
Policy Iteration: iteration 55
Policy Iteration: iteration 60
Policy Iteration: iteration 65
Policy Iteration: iteration 70
Policy Iteration: iteration 75
Policy Iteration: iteration 80
Policy Iteration: iteration 85
Policy Iteration: iteration 90
Policy Iteration: iteration 95
Policy Iteration: iteration 100
Policy Iteration: iteration 105
Policy Iteration: iteration 110
Policy Iteration: iteration 115
Policy Iteration: iteration 120
Policy Iteration: iteration 125
Policy Iteration: iteration 130
Policy Iteration: iteration 135
Policy Iteration: iteration 140
Policy Iteration: iteration 145
Policy Iteration: iteration 150
Policy Iteration: iteration 155
Policy Iteration: iteration 