In [2]:
import numpy as np
import matplotlib.pyplot as plt

<img src="images/Sutton-p76-GridWorld.png" width=800 height=400>

In [3]:
class GridWorld:
    def __init__(self):
        # self.worldValues = np.random.rand(4,4)
        self.worldValues = np.zeros((4,4))
        self.worldValues[0,0] = 0
        self.worldValues[3,3] = 0
        self.currentState = int(np.random.rand()*8 + 2) ## The current state is randomly generated, but the world knows where you are
        self.actions = [0, 1, 2, 3]
        self.actionNames = {0:"up", 1:"down", 2:"left", 3:"right"}
        
    def step(self,action):
        self.currentState = GridWorld.move(self.currentState, action)
        if self.currentState == 0 or self.currentState == 15:
            reward = 1
        else:
            reward = -1
        return self.currentState, reward
    def get_value(self, state_no):
        x, y = GridWorld.state_coordinates(state_no)
        return self.worldValues[x,y]
    def set_value(self, state_no, value):
        x, y = GridWorld.state_coordinates(state_no)
        self.worldValues[x,y] = value
    
    @staticmethod
    def general_step(state, action):
        new_state = GridWorld.move(state, action)
        if new_state == 0 or new_state == 15:
            reward = 1
        else:
            reward = -1
        return new_state, reward
    @staticmethod
    def state_nunmber(x,y):
        position = x*4 + y
        return position
    @staticmethod
    def state_coordinates(number):
        y = number % 4
        x = number // 4
        return (x,y)
    @staticmethod
    def move(s,a):
        x, y = GridWorld.state_coordinates(s)
        if x == 0 and a == 0:
            x = 0
        elif x == 3 and a == 1:
            x = 3
        else:
            if a == 0:
                x = x - 1
            elif a == 1:
                x = x + 1
        
        if y == 0 and a == 2:
            y = 0
        elif y == 3 and a == 3:
            y = 3
        else:
            if a == 2:
                y = y - 1
            elif a == 3:
                y = y + 1
        return GridWorld.state_nunmber(x,y)

In [4]:
world = GridWorld()
world.worldValues

array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]])

In [5]:
world.currentState

7

In [6]:
#print(world.get_value(2))
#world.set_value(2,89)
world.get_value(2)

0.0

In [7]:
GridWorld.general_step(1,2)

(0, 1)

In [8]:
class Agent:
    def __init__(self):
        pass
    def action(self): ## Implementing the equiprobable policy
        return np.random.randint(4)
    def policy(self):
        return 0.25

<img src="images/Sutton-PolicyEvaluation.png" width=800 height=400>>

In [11]:
def areDeltasLargerThanTheta(d, t):
    for i in range(d.shape[0]):
        if abs(d[i]) > abs(t):
            return True
    return False
#def iterative_policy_evaluation():
if __name__ == "__main__":
    world = GridWorld()
    agent = Agent()
    theta = 0.01 ## Every element in delta will be compared to this value
    deltas = np.ones(15) # [1, 1 ...]
    print(areDeltasLargerThanTheta(deltas, theta))
    while_count = 0
    while areDeltasLargerThanTheta(deltas, theta) and while_count < 1000:
        current_delta = 0
        while_count+=1
        for state in range(1,15):
            v = world.get_value(state) ## v <- V(s)
            new_v = 0
            #print("\tProcessing state: ", state)
            for action in range(len(world.actions)): ## Generating the \sum \pi(a|s)
            #This is an undicounted update (no \gamma)
                new_s, r = GridWorld.general_step(state, action) ## p(s',r | s,a) 
                new_v += agent.policy() * (r + world.get_value(new_s)) ## Adding the
            world.set_value(state, new_v) ## V(s) <- Expression
            deltas[state] = max(deltas[state], abs(new_v - v))
        if(while_count < 5):
            print("")
            print("Main Loop No.:", while_count, "\n",world.worldValues)

True

Main Loop No.: 1 
 [[ 0.         -0.5        -1.125      -1.28125   ]
 [-0.5        -1.25       -1.59375    -1.71875   ]
 [-1.125      -1.59375    -1.796875   -1.37890625]
 [-1.28125    -1.71875    -1.37890625  0.        ]]

Main Loop No.: 2 
 [[ 0.         -1.21875    -2.3046875  -2.64648438]
 [-1.21875    -2.40625    -3.05664062 -3.20019531]
 [-2.3046875  -3.05664062 -3.21777344 -2.44921875]
 [-2.64648438 -3.20019531 -2.44921875  0.        ]]

Main Loop No.: 3 
 [[ 0.         -1.98242188 -3.49755859 -3.99768066]
 [-1.98242188 -3.51953125 -4.35876465 -4.50146484]
 [-3.49755859 -4.35876465 -4.4039917  -3.33866882]
 [-3.99768066 -4.50146484 -3.33866882  0.        ]]

Main Loop No.: 4 
 [[ 0.         -2.74987793 -4.65097046 -5.28694916]
 [-2.74987793 -4.55432129 -5.52768707 -5.66369247]
 [-4.65097046 -5.52768707 -5.43317795 -4.10888481]
 [-5.28694916 -5.66369247 -4.10888481  0.        ]]


In [12]:
world.worldValues

array([[  0., -12., -18., -20.],
       [-12., -16., -18., -18.],
       [-18., -18., -16., -12.],
       [-20., -18., -12.,   0.]])