In [3]:
# import common packages
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


In [4]:
# Environment description for the gridworld, with states addressed as (row, column).
#
# rewards: maps a state to the reward GridWorld.move() returns when the agent
# ends up in that state. These two states have no entry in `actions`, so
# GridWorld.is_terminal() reports them as terminal.
rewards={(0,3):1,(1,3):-1}
# actions: maps each non-terminal state to the moves that actually change the
# position there. 'U'/'D' decrease/increase the row, 'L'/'R' decrease/increase
# the column (see GridWorld.get_next_state). Any other move leaves the agent
# in place.
# (1,1) appears in neither dict, so it is never part of GridWorld.all_states
# -- presumably a wall cell; TODO confirm.
# NOTE(review): (2,3) only allows 'L'; 'U' (which would lead to the -1 state
# at (1,3)) is not listed -- confirm this is intentional.
# The insertion order of this dict matters: play_game() iterates over it when
# drawing the initial random policy.
actions={
    (2,0):['U','R'],
    (1,0):['U','D'],
    (0,0):['R','D'],
    (2,1):['R','L'],
    (0,1):['R','L'],
    (2,2):['U','R','L'],
    (1,2):['U','D','R'],
    (0,2):['R','L','D'],
    (2,3):['L'],
}
# The block below is a bare string literal, not an assignment: it does not
# define a `policy` name, and as the last expression of the cell it is only
# echoed back as the cell's output.
'''
policy={
    (2,0):'U',
    (1,0):'U',
    (0,0):'R',
    (0,1):'R',
    (0,2):'R',
    (1,2):'U',
    (2,1):'R',
    (2,2):'U',
    (2,3):'L'
}
'''

"\npolicy={\n    (2,0):'U',\n    (1,0):'U',\n    (0,0):'R',\n    (0,1):'R',\n    (0,2):'R',\n    (1,2):'U',\n    (2,1):'R',\n    (2,2):'U',\n    (2,3):'L'\n}\n"

In [5]:
# Define the GridWorld environment class.
class GridWorld():
    """Grid environment that tracks a single agent position ``(i, j)``.

    ``i`` is the row index and ``j`` the column index. 'U'/'D' decrease/increase
    the row and 'L'/'R' decrease/increase the column. Rewards and the per-state
    lists of allowed actions are supplied through ``set_rewards_actions``,
    which must be called before any method that reads them.
    """

    def __init__(self, rows, columns, start_position):
        """Store the grid dimensions and place the agent at ``start_position``.

        Args:
            rows: number of rows in the grid.
            columns: number of columns in the grid.
            start_position: ``(row, column)`` pair for the initial position.
        """
        self.rows = rows
        self.columns = columns
        self.i = start_position[0]
        self.j = start_position[1]

    def set_rewards_actions(self, rewards, actions):
        """Attach the reward table and the allowed-action table.

        Args:
            rewards: dict mapping a state to the reward returned by ``move``
                when the agent ends up in that state.
            actions: dict mapping each non-terminal state to the list of
                actions that change the position there.

        Side effects:
            Builds ``self.all_states`` as the union of the keys of both dicts
            and prints it.
        """
        self.rewards = rewards
        self.actions = actions
        self.all_states = set(self.actions.keys()) | set(self.rewards.keys())
        print (self.all_states)

    def set_state(self, s):
        """Move the agent directly to state ``s`` (a ``(row, column)`` pair)."""
        self.i = s[0]
        self.j = s[1]

    def current_state(self):
        """Return the agent position as an ``(i, j)`` tuple."""
        return self.i,self.j

    def get_next_state(self, s, a):
        """Return the state reached by taking action ``a`` in state ``s``.

        The agent position is not modified. If ``a`` is not listed for ``s``
        the state is returned unchanged.

        Raises:
            KeyError: if ``s`` has no entry in ``self.actions`` (for example a
                terminal state).
        """
        i, j = s[0], s[1]
        if a in self.actions[(i,j)]:
            if a == 'U':
                i -= 1
            elif a == 'R':
                j += 1
            elif a == 'L':
                j -= 1
            else:
                i += 1
        return i,j

    def undo_move(self, action):
        """Step the agent back across ``action``, reversing a previous ``move``.

        The action is validated against the cell the agent would have come
        from, not the cell it is in now. The legality of a move is decided by
        its origin, so checking the current cell fails after entering a
        terminal state (no entry in ``self.actions``) or a cell that does not
        list the same action.

        Only call this after a ``move(action)`` that actually changed the
        position: without history, a move that was ignored cannot be
        distinguished from one that was executed.
        """
        # Cell the agent occupied before taking `action`.
        if action == 'U':
            previous = (self.i + 1, self.j)
        elif action == 'R':
            previous = (self.i, self.j - 1)
        elif action == 'L':
            previous = (self.i, self.j + 1)
        else:
            previous = (self.i - 1, self.j)
        # Cells without an action list (terminal cells, cells outside the
        # grid) can never have been the origin of a move.
        if action in self.actions.get(previous, ()):
            self.i, self.j = previous
        # should never happen
        assert (self.current_state() in self.all_states)

    def move(self, action):
        """Apply ``action`` to the agent and return the resulting reward.

        If ``action`` is not listed for the current state the agent stays in
        place. The reward is looked up for the state the agent is in after the
        call and defaults to 0.

        Raises:
            KeyError: if the current state has no entry in ``self.actions``.
        """
        if action in self.actions[(self.i,self.j)]:
            if action == 'U':
                self.i -= 1
            elif action == 'R':
                self.j += 1
            elif action == 'L':
                self.j -= 1
            else:
                self.i += 1
        return self.rewards.get((self.i,self.j),0)

    def is_terminal (self, s):
        """Return True when ``s`` has no entry in ``self.actions``."""
        return s not in self.actions

    def game_over(self):
        """Return True when the agent is in a state with no available actions.

        Agrees with ``is_terminal`` applied to the current state.
        """
        return (self.i,self.j) not in self.actions

In [23]:
# Convergence threshold: evaluate_deterministic_policy stops sweeping once the
# largest value change seen in a sweep drops below this.
SMALL_ENOUGH = 1e-3

def print_values(V,g):
    """Print the value table ``V`` as a grid, one text row per grid row.

    Args:
        V: dict mapping ``(row, column)`` to a number; missing cells show as 0.
        g: object exposing ``rows`` and ``columns``.

    Each row is preceded by a dashed rule. Non-negative values get a leading
    space so they line up with negative ones.
    """
    rule = "---------------------------"
    for row in range(g.rows):
        print(rule)
        cells = []
        for col in range(g.columns):
            value = V.get((row, col), 0)
            template = " %.2f|" if value >= 0 else "%.2f|"
            cells.append(template % value)
        print("".join(cells))

def print_policy(P,g):
    """Print the policy ``P`` as a grid, one text row per grid row.

    Args:
        P: dict mapping ``(row, column)`` to an action label; cells without an
            entry are shown as a blank.
        g: object exposing ``rows`` and ``columns``.
    """
    rule = "---------------------------"
    for row in range(g.rows):
        print(rule)
        labels = [P.get((row, col), ' ') for col in range(g.columns)]
        print("".join(" %s |" % label for label in labels))

# Every action label the agent may attempt, in the order they are enumerated.
ACTION_SPACE = ('U', 'D', 'L', 'R')

def init_transition_probs(grid, tr_probs, exp_rewards):
    """Fill in the deterministic transition and reward tables for ``grid``.

    For every non-terminal cell, scanned row by row, and every action in
    ``ACTION_SPACE``, the successor reported by ``grid.get_next_state`` is
    recorded with probability 1. When that successor has an entry in
    ``grid.rewards`` the reward is stored under the same key.

    Args:
        grid: object exposing ``rows``, ``columns``, ``rewards``,
            ``is_terminal(s)`` and ``get_next_state(s, a)``.
        tr_probs: dict updated in place, keyed by ``(s, a, s2)``.
        exp_rewards: dict updated in place, keyed by ``(s, a, s2)``.

    Returns:
        The same ``(tr_probs, exp_rewards)`` objects that were passed in.
    """
    cells = ((row, col) for row in range(grid.rows) for col in range(grid.columns))
    for state in cells:
        if grid.is_terminal(state):
            continue
        for action in ACTION_SPACE:
            landing = grid.get_next_state(state, action)
            key = (state, action, landing)
            tr_probs[key] = 1
            if landing in grid.rewards:
                exp_rewards[key] = grid.rewards[landing]
    return tr_probs, exp_rewards

# Discount factor applied to the value of the successor state.
gamma = 0.9

def evaluate_deterministic_policy(grid, cur_policy, tr_probs, exp_rewards):
    """Iteratively compute the state values of a deterministic policy.

    Values start at 0 and are updated in place, state by state, until the
    largest change in a full sweep is below ``SMALL_ENOUGH``. After each sweep
    the sweep index, the largest change and the value grid are printed.

    Args:
        grid: object exposing ``all_states``, ``is_terminal(s)`` and the
            ``rows``/``columns`` read by ``print_values``.
        cur_policy: dict mapping a state to the single action taken there.
        tr_probs: dict of transition probabilities keyed by ``(s, a, s2)``.
        exp_rewards: dict of rewards keyed by ``(s, a, s2)``.

    Returns:
        Dict mapping every state in ``grid.all_states`` to its value; terminal
        states keep the initial 0.
    """
    V = {state: 0 for state in grid.all_states}

    sweep = 0
    converged = False
    while not converged:
        biggest_change = 0
        for s in grid.all_states:
            if grid.is_terminal(s):
                continue
            previous = V[s]
            chosen = cur_policy.get(s)
            backup = 0
            for a in ACTION_SPACE:
                # The policy is deterministic: only the chosen action has weight.
                action_prob = 1 if chosen == a else 0
                for s2 in grid.all_states:
                    r = exp_rewards.get((s, a, s2), 0)
                    backup += action_prob * tr_probs.get((s, a, s2), 0) * (r + gamma * V[s2])

            V[s] = backup
            biggest_change = max(biggest_change, np.abs(previous - backup))

        print("iter: ", sweep, "biggest_change: ", biggest_change)
        print_values(V, grid)
        sweep += 1

        converged = biggest_change < SMALL_ENOUGH
    return V
    
def play_game():
    """Run policy iteration on a 3x4 GridWorld and print the progress.

    Builds the grid from the module-level ``rewards`` and ``actions``, prints
    the transition and reward tables, draws a random initial policy over
    ``ACTION_SPACE`` and then alternates policy evaluation with greedy policy
    improvement until no state changes its action. The policy is printed after
    every improvement round that changed it. The number of such rounds and the
    final value grid are printed at the end. Returns None.
    """
    g = GridWorld(3, 4, (2, 0))
    g.set_rewards_actions(rewards, actions)

    transition_probs, expected_rewards = init_transition_probs(g, {}, {})
    print(transition_probs)
    print(expected_rewards)

    # Initialize a random policy: one draw per state, in the order of g.actions.
    policy = {s: np.random.choice(ACTION_SPACE) for s in g.actions.keys()}
    print_policy(policy, g)

    def action_value(s, a, values):
        # One-step lookahead value of taking `a` in `s` under `values`.
        total = 0
        for s2 in g.all_states:
            r = expected_rewards.get((s, a, s2), 0)
            total += transition_probs.get((s, a, s2), 0) * (r + gamma * values[s2])
        return total

    it_main = 0
    while True:
        cur_v = evaluate_deterministic_policy(g, policy, transition_probs, expected_rewards)

        # Policy improvement step: pick the first action with the highest value.
        policy_changed = False
        for s in g.actions.keys():
            if g.is_terminal(s):
                continue
            previous_a = policy[s]
            best_a, best_value = None, float('-inf')
            for a in ACTION_SPACE:
                candidate = action_value(s, a, cur_v)
                if candidate > best_value:
                    best_value, best_a = candidate, a

            policy[s] = best_a
            if previous_a != best_a:
                policy_changed = True

        if not policy_changed:
            break

        it_main += 1
        print_policy(policy, g)

    print("iter: ", it_main)
    print_values(cur_v, g)


In [24]:
# Run policy iteration and print its progress. The initial policy comes from
# np.random.choice and no seed is set in the cells shown here, so the
# intermediate output can differ between runs.
play_game()

{(0, 1), (1, 2), (0, 0), (1, 3), (2, 1), (2, 0), (2, 3), (2, 2), (1, 0), (0, 2), (0, 3)}
{((0, 0), 'U', (0, 0)): 1, ((0, 0), 'D', (1, 0)): 1, ((0, 0), 'L', (0, 0)): 1, ((0, 0), 'R', (0, 1)): 1, ((0, 1), 'U', (0, 1)): 1, ((0, 1), 'D', (0, 1)): 1, ((0, 1), 'L', (0, 0)): 1, ((0, 1), 'R', (0, 2)): 1, ((0, 2), 'U', (0, 2)): 1, ((0, 2), 'D', (1, 2)): 1, ((0, 2), 'L', (0, 1)): 1, ((0, 2), 'R', (0, 3)): 1, ((1, 0), 'U', (0, 0)): 1, ((1, 0), 'D', (2, 0)): 1, ((1, 0), 'L', (1, 0)): 1, ((1, 0), 'R', (1, 0)): 1, ((1, 2), 'U', (0, 2)): 1, ((1, 2), 'D', (2, 2)): 1, ((1, 2), 'L', (1, 2)): 1, ((1, 2), 'R', (1, 3)): 1, ((2, 0), 'U', (1, 0)): 1, ((2, 0), 'D', (2, 0)): 1, ((2, 0), 'L', (2, 0)): 1, ((2, 0), 'R', (2, 1)): 1, ((2, 1), 'U', (2, 1)): 1, ((2, 1), 'D', (2, 1)): 1, ((2, 1), 'L', (2, 0)): 1, ((2, 1), 'R', (2, 2)): 1, ((2, 2), 'U', (1, 2)): 1, ((2, 2), 'D', (2, 2)): 1, ((2, 2), 'L', (2, 1)): 1, ((2, 2), 'R', (2, 3)): 1, ((2, 3), 'U', (2, 3)): 1, ((2, 3), 'D', (2, 3)): 1, ((2, 3), 'L', (2, 2)): 1, 