In [1]:
import numpy as np
from GridWorld import GridWorld

In [2]:
class Error(Exception):
    """Base class for exceptions in this module."""
    pass

class InputError(Error):
    """Exception raised for errors in the input."""

    def __init__(self, msg):
        self.msg = msg


In [3]:
class ValueIteration():
    """Class of a general discrete agent"""
    
    def __init__(self, env, discount_factor=0.99, theta=0.1):
        self.env = env
        self.ValueFunction = np.zeros(self.env.nS)
        self.discount_factor = discount_factor
        self.theta = theta
        self.policy = np.ones([self.env.nS, self.env.nA])/self.env.nA

    def update(self):
        V = np.zeros(self.env.nS)
        policy = np.zeros([self.env.nS, self.env.nA])

        # Implement!
        while True:
            max_diff = 0
            for s in range(self.env.nS):
                max_v = -999999
                max_a = 0
                for a in range(self.env.nA):
                    v = 0;
                    next_state, reward, done, prob = self.env.P[s][a]
                    v += prob * (reward + self.discount_factor * V[(int)(next_state)])
                    if v > max_v:
                        max_a = a
                    max_v = max(v, max_v)

                delta = max(max_diff, np.abs(V[s]-max_v))
                V[s] = max_v
                policy[s] = np.eye(self.env.nA)[max_a]
            if delta < self.theta:
                self.policy = policy
                self.ValueFunction = V
                return policy, V
    
    def get_action(self, state):
        if type(state) is tuple and len(state) == 2:
            return np.random.choice(np.arange(len(self.policy[state])), p=self.policy[state])
        raise InputError("Wrong State Value")

In [4]:
agent = ValueIteration(GridWorld())

In [5]:
pol, V = agent.update()

In [6]:
V.reshape(8, 8)

array([[ 16.46159041,  19.74645163,  23.36182899,  27.34789333,
         31.74893427,  36.61385225,  41.99670193,  47.95729273],
       [ 19.74645163,  23.36182899,  27.34789333,  31.74893427,
         36.61385225,  41.99670193,  47.95729273,  54.56185346],
       [ 23.36182899,  27.34789333,  31.74893427,  36.61385225,
         41.99670193,  47.95729273,  54.56185346,  61.88376811],
       [ 27.34789333,  31.74893427,  36.61385225,  41.99670193,
         47.95729273,  54.56185346,  61.88376811,  70.0043913 ],
       [ 23.613104  ,  27.57404084,  31.95246703,  47.95729273,
         54.56185346,  61.88376811,  70.0043913 ,  79.01395217],
       [ 20.2517936 ,  23.81663676,  20.43497308,  54.56185346,
         61.88376811,  70.0043913 ,  79.01395217,  89.01255696],
       [ 17.22661424,  20.43497308,  17.39147577,  61.88376811,
         70.0043913 ,  79.01395217,  89.01255696, 100.11130126],
       [ 14.50395282,  17.39147577,  14.6523282 ,  70.0043913 ,
         79.01395217,  89.0125569

In [17]:
GridWorld().rewards

{'done': 0, 'invalid_move': -2, 'terminated': 20, 'valid_move': -1}