In [1]:
import numpy as np
from GridWorld import GridWorld

In [3]:
class Error(Exception):
    """Root of the exception hierarchy defined in this notebook."""


class InputError(Error):
    """Raised when a caller hands the agent an input it cannot use.

    Attributes:
        msg: human-readable description of what was wrong with the input.
    """

    def __init__(self, msg):
        # Forward to Exception so ``args``/``str()`` carry the message as usual.
        super().__init__(msg)
        # Also expose the message under a named attribute for handlers.
        self.msg = msg


In [15]:
class PolicyIteration:
    """Tabular policy-iteration agent for a discrete environment.

    The environment must expose ``nS`` (number of states), ``nA`` (number of
    actions) and ``P``, where ``P[s][a]`` is a single transition tuple
    ``(next_state, reward, done, prob)``.

    Attributes:
        env: the environment being solved.
        ValueFunction: state values of the policy found by the last ``update``
            call (all zeros before that).
        discount_factor: discount applied to the value of the next state.
        theta: convergence threshold for policy evaluation; also used as the
            tie tolerance during policy improvement.
        policy: (nS, nA) array of action probabilities; uniform until
            ``update`` has run.
    """

    def __init__(self, env, discount_factor=0.99, theta=0.000001):
        self.env = env
        self.ValueFunction = np.zeros(self.env.nS)
        self.discount_factor = discount_factor
        self.theta = theta
        self.policy = np.ones([self.env.nS, self.env.nA]) / self.env.nA

    def evaluate_policy(self, policy):
        """Iteratively evaluate ``policy`` until the values stop changing.

        Values are updated in place, so later states in a sweep already see
        the refreshed values of earlier states.

        Args:
            policy: (nS, nA) array; ``policy[s][a]`` is the probability of
                taking action ``a`` in state ``s``.

        Returns:
            1-D array of length nS with the estimated state values.
        """
        # Start from an all-zero value function.
        V = np.zeros(self.env.nS)
        while True:
            delta = 0.0
            for s in range(self.env.nS):
                # Expected one-step backup under the policy's action distribution.
                v = float(np.dot(policy[s], self.next_step_rewards(s, V)))
                # Largest change seen in this sweep.
                delta = max(delta, abs(v - V[s]))
                V[s] = v
            # Stop once no state moved by more than theta.
            if delta < self.theta:
                return V

    def next_step_rewards(self, state, V):
        """Return the one-step lookahead value of every action in ``state``.

        Entry ``a`` is ``prob * (reward + discount_factor * V[next_state])``
        for the transition stored in ``env.P[state][a]``.
        """
        action_values = np.zeros(self.env.nA)
        for a in range(self.env.nA):
            # NOTE(review): ``done`` is not used, so transitions flagged as done
            # still bootstrap from V[next_state]; confirm this matches how
            # GridWorld encodes terminal states.
            next_state, reward, done, prob = self.env.P[state][a]
            action_values[a] = prob * (reward + self.discount_factor * V[int(next_state)])
        return action_values

    def update_policy(self, V, policy):
        """Make ``policy`` greedy with respect to ``V`` (policy improvement).

        ``policy`` is modified in place: every row becomes a one-hot vector for
        the chosen action. If the action a row currently favours is within
        ``theta`` of the best action's value, it is kept. This stops equally
        good actions from making ``update`` oscillate forever.

        Args:
            V: 1-D array of state values, as returned by ``evaluate_policy``.
            policy: (nS, nA) array of action probabilities.

        Returns:
            (policy, is_stable): the updated array, and ``True`` only if no row
            changed. Whole rows are compared, so a stochastic row (for example
            the initial uniform policy) is never reported as stable.
        """
        is_stable = True
        one_hot = np.eye(self.env.nA)
        for s in range(self.env.nS):
            action_values = self.next_step_rewards(s, V)
            best_action = int(np.argmax(action_values))
            current_action = int(np.argmax(policy[s]))
            # Tie-break in favour of the current action to guarantee termination.
            if action_values[current_action] >= action_values[best_action] - self.theta:
                best_action = current_action
            greedy_row = one_hot[best_action]
            # Compare full rows, not argmaxes: a uniform row has argmax 0 and
            # would otherwise look "unchanged" when action 0 is greedy.
            if not np.array_equal(policy[s], greedy_row):
                is_stable = False
            policy[s] = greedy_row
        return policy, is_stable

    def update(self):
        """Run policy iteration from a uniform random policy until it is stable.

        Returns:
            (policy, V): the final deterministic policy as an (nS, nA) one-hot
            array and the value function of that same policy. Both are also
            stored on the agent as ``self.policy`` and ``self.ValueFunction``.
        """
        # Start with a uniform random policy.
        policy = np.ones([self.env.nS, self.env.nA]) / self.env.nA
        while True:
            V = self.evaluate_policy(policy)
            policy, is_stable = self.update_policy(V, policy)
            if is_stable:
                # The policy did not change, so V is exactly its value function.
                self.policy = policy
                self.ValueFunction = V
                return policy, V

    def get_action(self, state):
        """Sample an action for ``state`` from the current policy.

        Args:
            state: either a ``(row, col)`` tuple, or a flat integer state index
                in ``[0, nS)``.

        Returns:
            The sampled action index.

        Raises:
            InputError: if ``state`` is neither form, or is out of range.
        """
        n_states, n_actions = self.policy.shape
        if isinstance(state, tuple) and len(state) == 2:
            # NOTE(review): assumes a square grid numbered row-major (matching
            # the notebook's ``V.reshape(8, 8)``); confirm against GridWorld.
            side = int(round(np.sqrt(n_states)))
            if side * side != n_states:
                raise InputError("Wrong State Value")
            try:
                index = int(np.ravel_multi_index(state, (side, side)))
            except (ValueError, TypeError):
                raise InputError("Wrong State Value") from None
        elif isinstance(state, (int, np.integer)) and 0 <= state < n_states:
            index = int(state)
        else:
            raise InputError("Wrong State Value")
        return np.random.choice(np.arange(n_actions), p=self.policy[index])

In [16]:
# Build a policy-iteration agent around a fresh GridWorld environment.
agent = PolicyIteration(env=GridWorld())

In [17]:
# Iterate evaluation/improvement until the policy stops changing.
(pol, V) = agent.update()

In [18]:
# Lay the flat state values out on the 8x8 grid for display.
V.reshape((8, 8))

array([[826.02716166, 835.38097235, 844.82926597, 854.37299688,
        864.01312911, 873.7506364 , 883.58650235, 893.52172046],
       [835.38097235, 844.82926597, 854.37299688, 864.01312911,
        873.7506364 , 883.58650235, 893.52172046, 903.5572943 ],
       [844.82926597, 854.37299688, 864.01312911, 873.7506364 ,
        883.58650235, 893.52172046, 903.5572943 , 913.69423757],
       [854.37299688, 864.01312911, 873.7506364 , 883.58650235,
        893.52172046, 903.5572943 , 913.69423757, 923.93357419],
       [864.01312911, 873.7506364 , 883.58650235, 893.52172046,
        903.5572943 , 913.69423757, 923.93357419, 934.27633845],
       [873.7506364 , 883.58650235, 893.52172046, 903.5572943 ,
        913.69423757, 923.93357419, 934.27633845, 944.72357507],
       [883.58650235, 893.52172046, 903.5572943 , 913.69423757,
        923.93357419, 934.27633845, 944.72357507, 955.27633932],
       [893.52172046, 903.5572943 , 913.69423757, 923.93357419,
        934.27633845, 944.7235750

In [2]:
# Inspect the reward table exposed by a fresh environment instance.
GridWorld().rewards

{'done': 0, 'invalid_move': -2, 'terminated': 20, 'valid_move': -1}