This notebook uses value iteration to solve the MDP depicted in Figure A1 and produces the results of Table A1.

In [14]:
import numpy as np

In [None]:
class demo_env_min():
    """Three-state chain MDP solved by value iteration with a ``min`` backup.

    The chain is deterministic in its transitions: every action taken in
    ``state`` leads to ``state + 1``. Rewards are stochastic and are described
    by ``reward_probs[state][action]``, a dict mapping ``reward -> probability``:

    * state 0 has one action: reward +1 or -1, each with probability 0.5;
    * state 1 has two actions: action 0 yields reward 0 with certainty,
      action 1 yields +1 with probability ``1 - p_neg`` and -2 with
      probability ``p_neg``;
    * state 2 has no actions (it is never updated by ``train``).

    The backup used throughout is

        Q(s, a) = sum_r  P(r | s, a) * min(r, V(s + 1))
        V(s)    = max_a  Q(s, a)
    """

    def __init__(self, p_neg, start_val=0, rng=None) -> None:
        """Build the MDP and initialise the value table.

        Args:
            p_neg: Probability of the -2 reward for action 1 in state 1.
            start_val: Initial value assigned to every state (including the
                action-less state 2, whose value is never updated afterwards).
            rng: Optional random generator stored on the instance. When
                omitted, a fresh ``np.random.default_rng(0)`` is created for
                this instance. None of the methods of this class draw from it.
        """
        self.p_neg = p_neg
        self.states = [0, 1, 2]
        self.n_states = len(self.states)
        self.start_val = start_val
        # Create the generator per instance: a Generator used as a default
        # argument would be built once and shared (with its state) by every
        # instance constructed without an explicit ``rng``.
        self.rng = np.random.default_rng(0) if rng is None else rng
        self.reward_probs = [[{1:0.5, -1:0.5,}],
                             [{0:1}, {1:1-self.p_neg, -2:self.p_neg,}],
                            []]

        self.reset()

    def get_state_value(self, state):
        """Return the current value estimate V(state)."""
        return self.state_values[state]

    def reset(self):
        """Reset every state value to ``start_val``.

        The table is always a float array so that fractional backups are not
        truncated when ``start_val`` is an integer.
        """
        self.state_values = np.full(self.n_states, self.start_val, dtype=float)

    def get_state_action_value(self, state, action):
        """Return Q(state, action) under the current value estimates.

        Each possible reward is clipped from above by the value of the
        successor state (``state + 1``) before being weighted by its
        probability.

        Raises:
            IndexError: If ``action`` is not available in ``state`` (this
                includes any action in the action-less state 2).
        """
        action_dict = self.reward_probs[state][action]

        # The successor is the same for every reward outcome, so look its
        # value up once rather than once per outcome.
        next_value = self.get_state_value(state + 1)
        return sum(min(reward, next_value) * prob
                   for reward, prob in action_dict.items())

    def train(self, n_steps):
        """Run ``n_steps`` in-place sweeps of value iteration.

        States are updated in the order given by ``self.states`` and each
        update immediately sees the values written earlier in the same sweep.
        States without actions keep their current value.

        Returns:
            The instance's value array (the same object as
            ``self.state_values``, not a copy).
        """
        for _ in range(n_steps):
            for state in self.states:
                action_values = [
                    self.get_state_action_value(state, action)
                    for action in range(len(self.reward_probs[state]))
                ]
                if len(action_values) > 0:
                    self.state_values[state] = np.max(action_values)

        return self.state_values

In [16]:
# Instantiate the MDP with a 10% chance of the -2 reward and run 100
# value-iteration sweeps; the returned state values are the cell output.
env_min = demo_env_min(p_neg=0.1)
env_min.train(n_steps=100)

array([-0.5,  0. ,  0. ])

In [18]:
# Collect Q(s, a) for every (state, action) pair that exists in the MDP,
# in the same order as they are reported.
Q_values = {
    (state, action): env_min.get_state_action_value(state, action)
    for state, action in [(0, 0), (1, 0), (1, 1)]
}
print(Q_values)

{(0, 0): -0.5, (1, 0): 0, (1, 1): -0.2}
