In [1]:
import numpy as np
import matplotlib.pyplot as plt

<img src="images/Sutton-p76-GridWorld.png" width=800 height=400>

In [2]:
class GridWorld:
    """A 4x4 grid world whose two opposite corners are terminal states.

    States are numbered row-major from 0 to 15: state ``s`` sits at row
    ``s // 4`` and column ``s % 4``. The actions are 0=up, 1=down, 2=left
    and 3=right; a move that would leave the grid keeps the state unchanged.
    A transition that lands on a terminal state (0 or 15) yields reward +1,
    every other transition yields -1.

    NOTE(review): the +1 reward for reaching a terminal state should be
    confirmed against the textbook example this notebook illustrates, which
    may use -1 on every transition.
    """

    SIZE = 4                                # the grid is SIZE x SIZE
    TERMINAL_STATES = (0, SIZE * SIZE - 1)  # top-left and bottom-right corners

    def __init__(self):
        """Create a world with random state values and a random start state."""
        # Arbitrary initial value estimates in [0, 5); terminal values are 0.
        self.worldValues = np.random.rand(GridWorld.SIZE, GridWorld.SIZE) * 5
        for terminal in GridWorld.TERMINAL_STATES:
            self.set_value(terminal, 0)
        # The start state is drawn uniformly from every non-terminal state
        # (1..14); the upper bound of randint is exclusive.
        self.currentState = int(np.random.randint(1, GridWorld.SIZE * GridWorld.SIZE - 1))
        self.actions = [0, 1, 2, 3]
        self.actionNames = {0:"up", 1:"down", 2:"left", 3:"right"}

    def step(self, action):
        """Apply ``action`` to the world's current state.

        Updates ``self.currentState`` and returns ``(new_state, reward)``
        using the same transition and reward rule as ``general_step``.
        """
        self.currentState, reward = GridWorld.general_step(self.currentState, action)
        return self.currentState, reward

    def get_value(self, state_no):
        """Return the stored value estimate of state number ``state_no``."""
        x, y = GridWorld.state_coordinates(state_no)
        return self.worldValues[x, y]

    def set_value(self, state_no, value):
        """Overwrite the stored value estimate of state number ``state_no``."""
        x, y = GridWorld.state_coordinates(state_no)
        self.worldValues[x, y] = value

    @staticmethod
    def general_step(state, action):
        """Return ``(new_state, reward)`` for taking ``action`` in ``state``.

        Does not touch any world instance. The reward is 1 when the new
        state is terminal and -1 otherwise.
        """
        new_state = GridWorld.move(state, action)
        reward = 1 if new_state in GridWorld.TERMINAL_STATES else -1
        return new_state, reward

    @staticmethod
    def state_number(x, y):
        """Return the row-major state number of row ``x``, column ``y``."""
        return x * GridWorld.SIZE + y

    @staticmethod
    def state_nunmber(x, y):
        """Misspelled alias of ``state_number``, kept for existing callers."""
        return GridWorld.state_number(x, y)

    @staticmethod
    def state_coordinates(number):
        """Return the ``(row, column)`` pair of state ``number``."""
        return (number // GridWorld.SIZE, number % GridWorld.SIZE)

    @staticmethod
    def move(s, a):
        """Return the state reached from ``s`` by action ``a``.

        Moves that would cross a wall, and action codes other than 0-3,
        leave the state unchanged.
        """
        x, y = GridWorld.state_coordinates(s)
        last = GridWorld.SIZE - 1
        if a == 0 and x > 0:         # up
            x -= 1
        elif a == 1 and x < last:    # down
            x += 1
        elif a == 2 and y > 0:       # left
            y -= 1
        elif a == 3 and y < last:    # right
            y += 1
        return GridWorld.state_number(x, y)

In [3]:
# Build a fresh grid world and show its randomly initialised value table
# (the constructor sets the two terminal corners to 0).
world = GridWorld()
world.worldValues

array([[0.        , 3.70626643, 3.41878999, 1.53163632],
       [3.28896629, 0.74333504, 2.3919096 , 3.31015304],
       [2.82659672, 2.46336645, 2.52188586, 2.52230582],
       [2.81841048, 4.18407611, 2.01931692, 0.        ]])

In [4]:
# Sanity check: from state 1, action 2 (left) lands on terminal state 0, so the reward is 1.
GridWorld.general_step(1,2)

(0, 1)

In [5]:
class Agent:
    """Agent that follows the equiprobable random policy over four actions."""

    def __init__(self):
        """The agent keeps no state, so there is nothing to initialise."""

    def action(self):
        """Sample an action index uniformly from 0, 1, 2, 3."""
        return np.random.randint(0, 4)

    def policy(self):
        """Return the probability the policy assigns to any single action."""
        return 1 / 4

<img src="images/Sutton-PolicyEvaluation.png" width=800 height=400>

In [6]:
def areDeltasLargerThanTheta(d, t, while_counter):
    """Decide whether the evaluation loop should run another sweep.

    Returns True unconditionally while ``while_counter`` is below 2, which
    lets a ``while`` loop behave like a do..while. After that it returns
    True exactly when some entry of the 1-D array ``d`` exceeds ``t`` in
    absolute value, and False otherwise.
    """
    if while_counter < 2:
        return True
    return any(abs(d[idx]) > abs(t) for idx in range(d.shape[0]))
# Iterative policy evaluation: sweep the non-terminal states in place until
# no state value changes by more than theta during a sweep.
if __name__ == "__main__":
    world, agent = GridWorld(), Agent()
    theta = 1e-10          # every per-state delta is compared against this
    while_count = 0        # number of sweeps performed so far
    deltas = np.zeros(15)  # per-state change; entry 0 is never written
    while areDeltasLargerThanTheta(deltas, theta, while_count):
        while_count += 1
        deltas = np.zeros(15)  # start each sweep from zero change
        for state in range(1, 15):
            v = world.get_value(state)  # v <- V(s)
            new_v = 0
            # Undiscounted backup (no \gamma): \sum_a \pi(a|s) * (r + V(s'))
            for action in world.actions:
                new_s, r = GridWorld.general_step(state, action)  # s', r for (s, a)
                new_v = new_v + agent.policy() * (r + world.get_value(new_s))
            world.set_value(state, new_v)  # V(s) <- backed-up value
            deltas[state] = max(deltas[state], abs(new_v - v))  # \delta <- max(\delta, |v - V(s)|)

In [7]:
# Report how many sweeps the evaluation loop ran, then show the resulting value table.
print("Number of main loops:", while_count)
world.worldValues

Number of main loops: 273


array([[  0., -12., -18., -20.],
       [-12., -16., -18., -18.],
       [-18., -18., -16., -12.],
       [-20., -18., -12.,   0.]])

In [8]:
# Per-state value change recorded during the final sweep; index 0 is never
# written by the evaluation loop, so it stays 0.
deltas

array([0.00000000e+00, 6.01101391e-11, 8.64197602e-11, 9.51203560e-11,
       6.01119154e-11, 7.37632178e-11, 8.00923772e-11, 7.91757770e-11,
       8.64162075e-11, 8.00923772e-11, 6.75868250e-11, 5.04591924e-11,
       9.51168033e-11, 7.91757770e-11, 5.04591924e-11])