

> Import libraries to use



In [39]:
import numpy as np

>  # Introduction to numpy (Skip if you already are familiar)

>> Creating a 1D array

In [40]:
a = np.array([1,2,3,4])
print(a)

[1 2 3 4]


>> Creating a 2D array


In [41]:
a = np.array([[1,2],[3,4]])
print(a)

[[1 2]
 [3 4]]


>> Creating an array full of zeros


In [42]:
a = np.zeros(shape=(10))
print(a)
a = np.zeros(shape=(5,2))
print(a)

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]


>> Infinity in numpy

In [43]:
print(np.inf)

inf


>> Max and Argmax

In [44]:
a = np.array([2,1,4,3])
print(np.max(a))
print(np.argmax(a))

4
2


>> From list to Numpy

In [45]:
l = [1,2,3,4]
print(l)
print(np.asarray(l))

[1, 2, 3, 4]
[1 2 3 4]


>> Random in numpy

In [46]:
# Array of Random integers ranging from 1 to 10 (with any size you want)
a = np.random.randint(low=1, high=10, size=(5,2))
print(a)

# Array of random elements of a list with any size you want
a = np.random.choice([0,1,2], size=(2,))

[[8 6]
 [7 1]
 [2 9]
 [6 8]
 [1 3]]


>> Shapes in numpy

In [47]:
a = np.random.randint(low=1, high=5, size=(4,2))
print(a.shape)
print(a)

# Reshape a to a vector of shape = (8,1)
a = a.reshape((8,1))
print(a.shape)
print(a)

(4, 2)
[[3 1]
 [2 4]
 [1 1]
 [1 1]]
(8, 1)
[[3]
 [1]
 [2]
 [4]
 [1]
 [1]
 [1]
 [1]]


# Pre-defined utilities

In [48]:

int_to_char = {
    0 : 'u',
    1 : 'r',
    2 : 'd',
    3 : 'l'
}

policy_one_step_look_ahead = {
    0 : [-1,0],
    1 : [0,1],
    2 : [1,0],
    3 : [0,-1]
}

def policy_int_to_char(pi,n):

    pi_char = ['']

    for i in range(n):
        for j in range(n):

            if i == 0 and j == 0 or i == n-1 and j == n-1:

                continue

            pi_char.append(int_to_char[pi[i,j]])

    pi_char.append('')

    return np.asarray(pi_char).reshape(n,n)


n = 4 
terminal_states = [0, n * n - 1]

def get_transitions(n):
    transitions = {}
    for i in range(n):
        for j in range(n):
            state = i * n + j
            transitions[state] = {}
            
            if state in terminal_states:
                transitions[state] = {a: [(1.0, state)] for a in range(4)}
            else:
                transitions[state][0] = [(1.0, state - n if i > 0 else state)]     # UP
                transitions[state][1] = [(1.0, state + 1 if j < n - 1 else state)] # RIGHT
                transitions[state][2] = [(1.0, state + n if i < n - 1 else state)] # DOWN
                transitions[state][3] = [(1.0, state - 1 if j > 0 else state)]     # LEFT
    return transitions

transitions = get_transitions(n)

# 1- Policy evaluation

In [49]:

def policy_evaluation(n, pi, v, Gamma, threshhold):
    """
    This function should return the value function that follows the policy pi.
    Use the stopping criteria given in the problem statement.
    """
    max_iterations = 1000
    iteration = 0
    
    while iteration < max_iterations:
        delta = 0
        for i in range(n):
            for j in range(n):
                if (i, j) in terminal_states:
                    continue

                old_value = v[i, j]
                action = pi[i, j]
                di, dj = policy_one_step_look_ahead[action]
                ni, nj = i + di, j + dj


                if 0 <= ni < n and 0 <= nj < n:
                    new_value = -1 + Gamma * v[ni, nj]
                else:
                    new_value = -1 + Gamma * v[i, j]

                v[i, j] = new_value
                delta = max(delta, abs(new_value - old_value))

        if delta < threshhold:
            break

        iteration += 1

    if iteration >= max_iterations:
        print("Warning : Maximum iterations reached whithout convergence.")
        
    return v



# 2- Policy improvement

In [50]:
def policy_improvement(n,pi,v,Gamma):   
    """
    This function should return the new policy by acting in a greedy manner.
    The function should return as well a flag indicating if the output policy
    is the same as the input policy.

    Example:
      return new_pi, True if new_pi = pi for all states
      else return new_pi, False
    """
    policy_stable = True
    new_pi = pi.copy()
    for i in range(n):
        for j in range(n):
            if (i, j) in terminal_states:
                continue

            old_action = pi[i, j]
            action_values = []

            for action, (di, dj) in policy_one_step_look_ahead.items():
                ni, nj = i + di, j + dj
                if 0 <= ni < n and 0 <= nj < n:
                    next_value = v[ni, nj]
                else:
                    next_value = v[i, j]
                action_values.append(-1 + Gamma * next_value)

            best_action = np.argmax(action_values)
            new_pi[i, j] = best_action

            if best_action != old_action:
                policy_stable = False

    return new_pi, policy_stable

# 3- Policy Initialization

In [51]:
def policy_initialization(n):
  """
    This function should return the initial random policy for all states.
  """
  return np.random.choice(4, size=(n, n))

# 4- Policy Iteration algorithm

In [52]:
def policy_iteration(n,Gamma,threshhold):

    pi = policy_initialization(n=n)

    v = np.zeros(shape=(n,n))

    while True:

        v = policy_evaluation(n=n,v=v,pi=pi,threshhold=threshhold,Gamma=Gamma)

        pi , pi_stable = policy_improvement(n=n,pi=pi,v=v,Gamma=Gamma)

        if pi_stable:

            break

    return pi , v

# Main Code to Test

In [53]:
n = 4

Gamma = [0.8,0.9,1]

threshhold = 1e-4

for _gamma in Gamma:

    pi , v = policy_iteration(n=n,Gamma=_gamma,threshhold=threshhold)

    pi_char = policy_int_to_char(n=n,pi=pi)

    print()
    print("Gamma = ",_gamma)

    print()

    print(pi_char)

    print()
    print()

    print(v)



Gamma =  0.8

[['' 'u' 'u' 'u']
 ['u' 'u' 'u' 'u']
 ['r' 'r' 'r' 'r']
 ['u' 'u' 'u' '']]


[[-4.99992864 -4.99992864 -4.99992864 -4.99992864]
 [-4.99994291 -4.99994291 -4.99994291 -4.99994291]
 [-4.99992864 -4.99992864 -4.99992864 -4.99992864]
 [-4.99994291 -4.99994291 -4.99994291 -4.99994291]]

Gamma =  0.9

[['' 'l' 'l' 'l']
 ['u' 'u' 'u' 'u']
 ['u' 'u' 'u' 'u']
 ['u' 'u' 'u' '']]


[[-9.99963565 -9.99967208 -9.99970487 -9.99973439]
 [-9.99967208 -9.99970487 -9.99973439 -9.99976095]
 [-9.99970487 -9.99973439 -9.99976095 -9.99978485]
 [-9.99973439 -9.99976095 -9.99978485 -9.99980637]]
Avertissement : Maximum d'itérations atteint sans convergence.
Avertissement : Maximum d'itérations atteint sans convergence.
Avertissement : Maximum d'itérations atteint sans convergence.
Avertissement : Maximum d'itérations atteint sans convergence.

Gamma =  1

[['' 'u' 'u' 'l']
 ['u' 'u' 'u' 'u']
 ['u' 'u' 'u' 'u']
 ['u' 'u' 'u' '']]


[[-4000. -4000. -4000. -4001.]
 [-4001. -4001. -4001. -4002.]
 [