In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import pickle
import itertools
from tqdm import tqdm
import pprint

plt.rcParams['mathtext.fontset'] = 'cm'
plt.rcParams['font.family'] = 'STIXGeneral'
%config InlineBackend.figure_formats = ['svg']

project_dir = os.path.expanduser('../')

In [2]:
γ = 1

nS = 750
nA = 25

nS_term = 2
S_survival = 750
S_death = 751

nS_total = nS + nS_term

In [3]:
with open(project_dir + 'data/env_model/gymP_shifted_plus.pkl', 'rb') as f:
    gymP_plus = pickle.load(f)
with open(project_dir + 'data/env_model/gymP_shifted_minus.pkl', 'rb') as f:
    gymP_minus = pickle.load(f)
with open(project_dir + 'data/env_model/gymP_shifted_mixed.pkl', 'rb') as f:
    gymP_mixed = pickle.load(f)

In [4]:
SA_mask = pd.read_pickle(project_dir + "data/behavior_policy/SA_mask_shifted.pkl")

if isinstance(SA_mask, pd.DataFrame):
    SA_mask = SA_mask.values

SA_mask.shape

(750, 25)

In [5]:
def value_iteration_masked(gymP, nS, nA, SA_mask, gamma, threshold=1e-15):
    Vs = []
    V = np.zeros(nS)
    Vs.append(V.copy())
    for _ in tqdm(itertools.count()):
        V_new = V.copy()
        for s in range(nS):
            ## V[s] = max {a} sum {s', r} P[s', r | s, a] * (r + gamma * V[s'])
            Q_s = np.zeros((nA))
            for a in range(nA):
                Q_s[a] = sum(p * (r + (0 if done else gamma * V[s_])) for p, s_, r, done in gymP[s][a])
            Q_s[~SA_mask[s]] = np.nan
            new_v = np.nanmax(Q_s)
            V_new[s] = new_v
        if np.isclose(np.linalg.norm(V_new - V), threshold):
            break
        V = V_new
        Vs.append(V_new)
    return V, {
        'V': V,
        'Vs': Vs,
    }

In [6]:
V_plus, info_plus = value_iteration_masked(gymP_plus, nS, nA, SA_mask, γ)
V_minus, info_minus = value_iteration_masked(gymP_minus, nS, nA, SA_mask, γ)
V_mixed, info_mixed = value_iteration_masked(gymP_mixed, nS, nA, SA_mask, γ)

133it [00:02, 47.45it/s]
112it [00:02, 44.63it/s]
134it [00:02, 46.05it/s]


In [7]:
# check prop 2: V_plus + V_minus = V_mixed
print(V_plus + V_minus)

[1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         0.99515736 1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         0.99997779 1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         0.92057761 1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 0.99999999 1.         1.         1.         1.         

In [41]:
print(V_mixed)

[1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         0.99515736 1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         0.99997779 1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         0.92057761 1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 0.99999999 1.         1.         1.         1.         

In [51]:
np.allclose(V_plus + V_minus, V_mixed, atol=1e-6, rtol=0)

True

In [43]:
# check for prop 4
print(V_plus - V_minus)

[1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         0.99999999 1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 0.99999999 1.         1.         1.         1.         

In [56]:
np.allclose(V_plus - V_minus, 1.0, atol=1e-7, rtol=0)

False

In [61]:
zero_state = np.where(V_plus - V_minus == 0)[0]
zero_state


array([330, 482])

In [63]:
print(V_plus[330])
print(V_minus[482])

0.0
0.0
