# Setup

In [50]:
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
pio.renderers.default='notebook'
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
init_notebook_mode(connected=True)
import plotly.figure_factory as ff


In [79]:
##### PLOTTING FUNCTIONS #####
def policy_to_text(policy):
    # policy is a 2D numpy array of the same shape as values. It contains the policy for each state s, pi(s)
    policy_text = []
    policy_dict = {0:'stay', 1:'up', 2:'right', 3:'down', 4:'left'}
    for row in range(policy.shape[0]):
        this_row = []
        for col in range(policy.shape[1]):
            this_row.append(policy_dict[policy[row,col]])
        policy_text.append(this_row)
    return policy_text

def plot_grid_world(values, policy):
    # values is a 2D numpy array with the value for each cell in the grid
    # policy is a 2D numpy array of the same shape as values. It contains the policy for each state s, pi(s)

    # labels
    x = ['0', '1', '2', '3', '4']
    y = ['0', '1', '2', '3', '4']
    
    policy_text = policy_to_text(policy)
#     print(policy_text)

    fig = ff.create_annotated_heatmap(values, x=x, y=y, annotation_text=policy_text, colorscale='Viridis')
    fig.show()

In [80]:
values = np.random.rand(5,5)

policy = np.random.randint(5, size=(5,5))


plot_grid_world(values, policy)

# The System

In [134]:
n = 5 ## grid size

# STATES
S = []
for row in range(0,n):
    for col in range(0,n):
        S.append((col,row))
blocked_states = [(1,3),(2,3),(1,1),(2,1)]
        
print("States")
print(S)
print("Number of states: ", len(S))
print("\n")



# ACTIONS
# A = [0, 1, 2, 3, 4] # stay, up, right, down, left
n_a = 5
A = [(0,0), (0,1), (1,0), (0,-1), (-1,0)] # stay, up, right, down, left
print("Actions")
print(A)
print("Number of actions: ", len(A))
print("\n")

# TRANSITION PROBABILITIES
P = np.zeros((n**2, len(A), n**2))
for state in range(n**2):
    for action in range(len(A)):
#         print("state")
#         print(S[state])
#         print("action")
#         print(A[action])
        next_state_candidate = (S[state][0] + A[action][0], S[state][1] + A[action][1])
        if next_state_candidate in S and not next_state_candidate in blocked_states:
            next_state_index = S.index(next_state_candidate)
            P[state,action,next_state_index] = 1
#             print("next_state_index")
#             print(next_state_index)
#             print("next_state")
#             print(S[next_state_index])        
#         print("\n")

print("Transition Probabilities")
print("Shape of the transition probability matrix: ", P.shape)
print("Number of transition probabilities: ", P.shape[0]* P.shape[1]* P.shape[2])




States
[(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (0, 2), (1, 2), (2, 2), (3, 2), (4, 2), (0, 3), (1, 3), (2, 3), (3, 3), (4, 3), (0, 4), (1, 4), (2, 4), (3, 4), (4, 4)]
Number of states:  25


Actions
[(0, 0), (0, 1), (1, 0), (0, -1), (-1, 0)]
Number of actions:  5


Transition Probabilities
Shape of the transition probability matrix:  (25, 5, 25)
Number of transition probabilities:  3125


# The Task

In [135]:
# Discount Factor
gamma = 0.8

##### REWARD MATRIX #####
R = np.zeros((n**2, len(A), n**2)) # initialize as a matrix with the same size as P
Rw = -10
Rs = 10
Rd = 1
# Associate negative reward with being on Westwood Blvd (Rw) in any next state
R[:,:,4] = Rw
R[:,:,9] = Rw
R[:,:,14] = Rw
R[:,:,19] = Rw
R[:,:,24] = Rw
# Associate positive reward with ENTERING the ice cream shops
Rs_state = 12
Rd_state = 2
for state in range(n**2):
    # assign to all state jumps that land in the ice cream shop as long as the current state is not already in the ice cream shop
    if state != Rs_state:
        R[state,:,Rs_state] = Rs
    if state != Rd_state:
        R[state,:,Rd_state] = Rd
        

# Policy Iteration

In [173]:
def policy_iteration(S, A, P, R, gamma):
    # S is the list of n_s states
    # A is the list of n_a actions
    # P is the 3D numpy array of transition probabilities of shape (n_s, n_a, n_s), in the order (state, action, next_state)
    # R is the reward associated with each (state, action, next_state) triplet. It is the same shape as P
    # gamma is the discount factor (0,1]
    
    policy = np.random.randint(5, size=(5,5)) # initalize the policy
    values = np.zeros((5,5))
    iters = 0
    
    not_converged = True
    while(not_converged):
        # policy evaluation
        Pdd = np.zeros((n**2, n**2))
        Rdd = np.zeros((n**2, n**2))
        for s in range(Pdd.shape[0]):
            for next_s in range(Pdd.shape[1]):
                Pdd[s, next_s] = P[s, policy[s%n,int(s/n)], next_s]
                Rdd[s, next_s] = R[s, policy[s%n,int(s/n)], next_s]
        D = np.diagonal(np.dot(Pdd, np.transpose(Rdd)))
#         print("D.shape")
#         print(D.shape)
        V = np.dot(np.linalg.inv(np.eye(Pdd.shape[0]) - np.dot(gamma, Pdd)), D)
#         print(V.shape)
        
        # policy refinement
        old_policy = np.copy(policy)
        Q = np.zeros((n**2, n_a))
        for s in range(n**2):
            for a in range(n_a):
                for next_s in range(n**2):
#                     print(Q[s,a])
                    Q[s,a] += P[s,a,next_s]*(R[s,a,next_s] + gamma*V[next_s])
        for s in range(Pdd.shape[0]):
            policy[s%n,int(s/n)] = np.argmax(Q[s,:])
            values[s%n,int(s/n)] = Q[s,np.argmax(Q[s,:])]
        iters +=1
        
        # check convergence, break if converged
        if (np.allclose(policy, old_policy)):
            not_converged = False
    
    return policy, values, iters
    

In [174]:
p_calc, vals, iterations = policy_iteration(S,A,P,R,gamma)
print(iterations)
print("Policy")
for i in range(len(p_calc)):
    print(p_calc[4-i])
print("Values")
# for i in range(len(vals)):
#     print(vals[4-i])
#     print(vals)
print(vals)

6
Policy
[4 4 4 4 4]
[1 1 4 3 3]
[2 1 2 3 2]
[2 1 2 3 4]
[1 1 2 3 3]
Values
[[14.22222222 17.77777778 22.22222222 17.77777778 14.22222222]
 [12.37777778 22.22222222 27.77777778 22.22222222 11.37777778]
 [14.22222222 27.77777778 22.22222222 27.77777778 14.22222222]
 [17.77777778 22.22222222 27.77777778 22.22222222 17.77777778]
 [14.22222222 17.77777778 22.22222222 17.77777778 14.22222222]]


In [175]:
plot_grid_world(vals, p_calc)