In [14]:
import numpy as np

In [15]:
class GridWorld:
    def __init__(self) :
        self.row = 5
        self.column = 5
        self.i = 4
        self.j = 0
        self.config()

    def config(self):
        reward = {}

        for p in range(self.row):
            for q in range(self.column):
                if p==0 and q==4:
                    reward[(p,q)] = 1
                elif p==1 and q==4:
                    reward[(p,q)] = -1
                elif p==2 and q==1:
                    reward[(p,q)] = 0
                elif p==3 and q==3:
                    reward[(p,q)] = 0
                else :
                    reward[(p,q)] = 0

        action = {(0, 0): {'D','R'},
        (0, 1):{'L','D','R'},
        (0, 2): {'L','D','R'},
        (0, 3): {'L','D','R'},
        (1, 0): {'U','R','D'},
        (1, 1): {'U','D','R','L'},
        (1, 2): {'U','D','R','L'},
        (1, 3): {'U','D','R','L'},
        (2, 0): {'U','R','D'},
        (2, 1): {'U','D','R','L'},
        (2, 2): {'U','D','R','L'},
        (2, 3): {'U','D','R','L'},
        (2, 4): {'D','U','L'},
        (3, 0): {'U','R','D'},
        (3, 1): {'U','D','R','L'},
        (3, 2): {'U','D','R','L'},
        (3, 3): {'U','D','R','L'},
        (3, 4): {'D','L','U'},
        (4, 0): {'U','R'},
        (4, 1): {'U','R','L'},
        (4, 2): {'U','R','L'},
        (4, 3): {'U','R','L'},
        (4, 4): {'U','L'},}

        self.set(reward,action)

    def set(self,reward, action):
        self.reward = reward
        self.action = action

    def get_nextState(self,si,sj,action):
        if action == "U":
            si -= 1
        elif action == "D":
            si += 1
        elif action == "R":
            sj += 1
        elif action == "D":
            sj -= 1
        return si,sj

    def get_reward(self,si,sj):
        return self.reward[(si,sj)]

    def move(self,action):
        self.i,self.j = self.get_nextState(self.i,self.j,action)
        return self.get_reward(self.i,self.j)
    
    def terminate(self):
        if (self.i,self.j) in self.action:
            return False
        return True
    
    def current_state(self):
        return self.i,self.j
    
    def is_terminal(self,i,j):
        if (i,j) in self.action:
            return False
        return True

In [16]:
g = GridWorld()

In [17]:
def set_policy():
        policy = {
        (0, 0): 'D',
        (0, 1):'D',
        (0, 2): 'D',
        (0, 3): 'D',
        (1, 0): 'D',
        (1, 1): 'L',
        (1, 2): 'L',
        (1, 3): 'L',
        (2, 0): 'U',
        (2, 2): 'U',
        (2, 3): 'U',
        (2, 4): 'U',
        (3, 0): 'D',
        (3, 1): 'R',
        (3, 2): 'R',
        (3, 4): 'R',
        (4, 0): 'R',
        (4, 1): 'R',
        (4, 2): 'L',
        (4, 3): 'R',
        (4, 4): 'U',
        (0, 4): 'D',
        (1,4) : 'U'}
        return policy

policy = set_policy()

In [18]:
#transition and reward function of S' / s2
transition = {}
reward_s2 = {}

for i in range(g.row):
    for j in range(g.column):
        if(i,j) in g.action:
            for action in g.action[(i,j)]:
                s2i,s2j = g.get_nextState(i,j,action)
                transition[(i,j,action,s2i,s2j)]=1
                reward_s2[(i,j,action,s2i,s2j)] = g.get_reward(s2i,s2j)



In [19]:
#print policy

def print_policy(policy):
    print("POLICY")
    for i in range(g.row):
        print ("---"*7)
        print("|",end=" ")
        for j in range(g.column):
            if (i,j) in policy:
                print(policy[(i,j)],"|",end=" ")
            else:
                print(" ","|",end=" ")
        print("")
    print ("---"*7)

print_policy(policy)

POLICY
---------------------
| D | D | D | D | D | 
---------------------
| D | L | L | L | U | 
---------------------
| U |   | U | U | U | 
---------------------
| D | R | R |   | R | 
---------------------
| R | R | L | R | U | 
---------------------


In [20]:
#print value
def plot_value(V):
    print("Value")
    for i in range(g.row):
        print ("---"*7)
        print("|",end=" ")
        for j in range(g.column):
            print("%.2f|" %V[(i,j)],end=" ")
        print("")
    print ("---"*7)

In [40]:
#initiate Value
V = {}

for p in range(g.row):
    for q in range(g.column):
            V[(p,q)] = 0

delta = 0.001
gamma = 0.9
plot_value(V)

Value
---------------------
| 0.00| 0.00| 0.00| 0.00| 0.00| 
---------------------
| 0.00| 0.00| 0.00| 0.00| 0.00| 
---------------------
| 0.00| 0.00| 0.00| 0.00| 0.00| 
---------------------
| 0.00| 0.00| 0.00| 0.00| 0.00| 
---------------------
| 0.00| 0.00| 0.00| 0.00| 0.00| 
---------------------


In [47]:
# Run the model

def policy_evaluation(policy,V):
    iter = 0
    while True:
        diff = 0
        for si in range(g.row):
            for sj in range(g.column):
                if not g.is_terminal(si,sj):
                    temp = 0
                    for action in g.action[(si,sj)]:
                        s2i,s2j = g.get_nextState(si,sj,action)
                        pi = 0
                        if (si,sj) in policy:
                            if policy[(si,sj)]==action:
                                pi = 1
                        temp += pi * transition[(si,sj,action,s2i,s2j)]*(reward_s2[(si,sj,action,s2i,s2j)]+ gamma * V[(s2i,s2j)])
                    
                    diff = max(diff,np.abs(V[(si,sj)]-temp))
                    V[(si,sj)]=temp
        iter = iter+1
        # print("Itereation :", iter, "      Diff :",diff)
        # plot_value(V)
        if diff < delta:
            break

    return V
    

In [48]:
# Run the model

def policy_improvement(policy,V):
    temp_policy = {}
    for si in range(g.row):
        for sj in range(g.column):
            if not g.is_terminal(si,sj):
                temp = {}
                for action in g.action[(si,sj)]:
                    temp_v = 0
                    for s2i in range(g.row):
                        for s2j in range(g.column):
                            trs = transition[(si,sj,action,s2i,s2j)] if (si,sj,action,s2i,s2j) in transition else 0
                            rew = reward_s2[(si,sj,action,s2i,s2j)] if (si,sj,action,s2i,s2j) in reward_s2 else 0
                            temp_v += trs *(rew+ gamma * V[(s2i,s2j)])
                    temp[action] = temp_v
            val = max(temp.values())
            res = [key for key in temp if temp[key] == val] 
            temp_policy[(si,sj)] = res[0]
    policy = temp_policy
        

    return policy
    

In [49]:
policy= set_policy()
iter = 0
while True:
    iter += 1
    print("Iteration : ",iter)
    V = policy_evaluation(policy,V)
    plot_value(V)
    new_policy = policy_improvement(policy,V)
    if new_policy==policy:
        break
    policy = new_policy
    print_policy(policy)

print_policy(policy)


Iteration :  1
Value
---------------------
| 0.00| 0.01| 0.01| 0.01| 0.00| 
---------------------
| 0.00| 0.01| 0.01| 0.01| 0.00| 
---------------------
| 0.00| 0.00| 0.01| 0.01| -1.00| 
---------------------
| 0.01| 0.00| 0.00| 0.00| 0.00| 
---------------------
| 0.01| 0.01| 0.01| 0.00| 0.00| 
---------------------
POLICY
---------------------
| R | R | R | R | R | 
---------------------
| R | R | R | L | L | 
---------------------
| D | U | U | U | D | 
---------------------
| D | D | U | U | L | 
---------------------
| U | L | L | L | L | 
---------------------
Iteration :  2
Value
---------------------
| 0.73| 0.81| 0.90| 1.00| 0.00| 
---------------------
| 0.01| 0.01| 0.01| 0.01| 0.00| 
---------------------
| 0.00| 0.00| 0.00| 0.00| 0.00| 
---------------------
| 0.00| 0.00| 0.00| 0.00| 0.00| 
---------------------
| 0.00| 0.00| 0.00| 0.00| 0.00| 
---------------------
POLICY
---------------------
| R | R | R | R | R | 
---------------------
| U | U | U | U | U | 
------------