In [1]:
import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg import norm,inv
%matplotlib inline

In [2]:
# Problem constants. A and M match the four (M x M) transition matrices built
# below; lamda is the discount factor used by both solvers.
A = 4
M = 81
lamda = 0.9925

# States of interest, listed with 1-based numbering and shifted by one so the
# array can be used directly as 0-based indices.
target = np.array([
    3, 11, 12, 15, 16, 17, 20, 22, 23, 24, 26, 29, 30, 31, 34, 35,
    39, 43, 48, 52, 53, 56, 57, 58, 59, 60, 61, 62, 66, 70, 71,
]) - 1

In [3]:
def reconstruct(sparse, size=None):
    """Expand (row, column, value) triples into a flattened dense square matrix.

    Parameters
    ----------
    sparse : array-like
        Rows of ``(row, column, value)`` with 1-based row/column numbers.
        A single triple given as a 1-D sequence is accepted, as is an empty
        input (which yields an all-zero result).
    size : int, optional
        Side length of the square matrix. Defaults to the notebook-level
        constant ``M``.

    Returns
    -------
    numpy.ndarray
        1-D array of length ``size**2`` in row-major order; reshape it to
        ``(size, size)`` to get the matrix. Cells not listed stay zero.

    Raises
    ------
    IndexError
        If any row or column number lies outside ``1..size``.
    """
    n = M if size is None else size
    dense = np.zeros(n ** 2)

    # np.loadtxt returns a 1-D array for a one-line file; promote it to 2-D so
    # the column slicing below works for any number of entries.
    entries = np.atleast_2d(np.asarray(sparse, dtype=float))
    if entries.size == 0:
        return dense

    rows = entries[:, 0].astype(int)
    cols = entries[:, 1].astype(int)
    # Without this check a 0 index would silently wrap around to the end of
    # the array and an oversized column would spill into the next row.
    if rows.min() < 1 or rows.max() > n or cols.min() < 1 or cols.max() > n:
        raise IndexError(
            'row/column numbers must lie in 1..%d for a %dx%d matrix' % (n, n, n)
        )

    # Convert the 1-based (row, column) pair to a 0-based row-major offset.
    dense[(rows - 1) * n + (cols - 1)] = entries[:, 2]
    return dense

In [4]:
# Load the sparse transition tables, one text file per action, in action order.
prob_a1_sparse, prob_a2_sparse, prob_a3_sparse, prob_a4_sparse = (
    np.loadtxt(path)
    for path in ('prob_a1.txt', 'prob_a2.txt', 'prob_a3.txt', 'prob_a4.txt')
)
R = np.loadtxt('rewards.txt')

In [5]:
# Turn each action's sparse table into a dense (M, M) transition matrix,
# then stack the four matrices along a leading action axis.
prob_a1, prob_a2, prob_a3, prob_a4 = (
    reconstruct(entries).reshape(M, M)
    for entries in (prob_a1_sparse, prob_a2_sparse, prob_a3_sparse, prob_a4_sparse)
)
prob = np.stack((prob_a1, prob_a2, prob_a3, prob_a4))
print(prob.shape)

(4, 81, 81)


### a) policy iteration:

In [6]:
# Initialize the policy randomly: one action index in [0, A) for each of the
# M states. A privately seeded generator makes the starting policy (and so the
# iteration count) reproducible under Restart & Run All, without touching
# NumPy's global random state.
pi = np.random.RandomState(0).randint(A, size=M)

In [7]:
def value(P_pi, R, discount=None):
    """Evaluate a fixed policy exactly by solving its linear Bellman equation.

    Solves ``(I - discount * P_pi) V = R`` for ``V``.

    Parameters
    ----------
    P_pi : array-like, shape (n, n)
        Transition matrix under the policy; row ``s`` holds the probabilities
        of moving from state ``s`` to each next state.
    R : array-like, shape (n,)
        Reward of each state.
    discount : float, optional
        Discount factor. Defaults to the notebook-level constant ``lamda``.

    Returns
    -------
    numpy.ndarray, shape (n,)
        Value of every state under the policy.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``I - discount * P_pi`` is singular.
    """
    gamma = lamda if discount is None else discount
    P_pi = np.asarray(P_pi, dtype=float)
    n_states = P_pi.shape[0]
    # Solve the system directly instead of forming the explicit inverse and
    # multiplying: same answer, less work, better numerical behaviour.
    return np.linalg.solve(np.eye(n_states) - gamma * P_pi, R)

In [8]:
def choose(pi, prob):
    """Build the transition matrix induced by a deterministic policy.

    Parameters
    ----------
    pi : array-like of int, shape (n_states,)
        Action index selected in each state.
    prob : array-like, shape (n_actions, n_states, n_states)
        Per-action transition matrices.

    Returns
    -------
    numpy.ndarray, shape (n_states, n_states)
        Matrix whose row ``s`` is row ``s`` of ``prob[pi[s]]``.
    """
    prob = np.asarray(prob)
    pi = np.asarray(pi)
    states = np.arange(prob.shape[1])
    # Pairing pi[s] with s in one advanced-indexing expression picks, for
    # every state at once, that state's row from the chosen action's matrix.
    return prob[pi, states, :].astype(float)

In [9]:
def greedy(V_pi, transitions=None):
    """Return the policy that is greedy with respect to a value vector.

    Parameters
    ----------
    V_pi : array-like, shape (n_states,)
        Value of each state.
    transitions : array-like, shape (n_actions, n_states, n_states), optional
        Per-action transition matrices. Defaults to the notebook-level
        array ``prob``.

    Returns
    -------
    numpy.ndarray of int, shape (n_states,)
        For each state, the action with the largest expected next-state
        value; ties go to the lowest action index.
    """
    P = prob if transitions is None else np.asarray(transitions)
    V_pi = np.asarray(V_pi)
    # expected[a, s] = sum over s' of P[a, s, s'] * V_pi[s']. The array
    # dimensions come from P itself rather than hard-coded (4, 81).
    expected = P.dot(V_pi[:, np.newaxis])[..., 0]
    return np.argmax(expected, axis=0)

In [10]:
# Policy iteration: alternate exact evaluation of the current policy with a
# greedy improvement step until the policy stops changing.
#
# The stability test sits at the end of each pass, so at least one
# evaluate/improve step always runs. A sentinel "previous policy" of all zeros
# would be indistinguishable from the valid policy "action 0 everywhere" and
# could skip the loop entirely.
MAX_POLICY_ITERATIONS = 100  # safety cap so the cell always terminates
iteration_p = 0
while iteration_p < MAX_POLICY_ITERATIONS:
    iteration_p += 1
    pi_old = pi
    P_pi = choose(pi, prob)  # transition matrix under the current policy
    V_pi = value(P_pi, R)    # values of the current policy
    pi = greedy(V_pi)        # improved policy
    if np.array_equal(pi_old, pi):
        break
else:
    # Reached only when the cap is hit without the policy becoming stable.
    print('policy iteration did not converge within',
          MAX_POLICY_ITERATIONS, 'iterations')

In [11]:
# Actions (coded 0..3) that the final policy selects at the target states.
policy_iteration = pi[target]
print(policy_iteration)

[2 2 1 3 3 2 2 3 3 0 2 3 3 0 2 1 0 2 0 2 2 3 3 3 3 3 2 2 0 2 1]


In [12]:
# Evaluate the final policy: build its transition matrix, then solve for the
# value of every state under it.
P_best = choose(pi,prob)
V_best = value(P_best,R)

In [13]:
# Values of the final policy at the target states.
print(V_best[target])

[100.70098073 102.3752644  101.52364515 109.48993454 110.40903296
 111.33584663 103.23462342 106.77826755 107.67462643 108.57848712
 112.27044032 104.10121204 104.97507555 105.88853591 114.1632295
 113.21287932 103.78140737 115.12155727  90.9853796  116.08792959
 122.02491241  81.39949278  93.67165583  95.17285726 108.34261934
 109.58365072 123.64307021 123.1822391   81.39949278 125.24978944
 124.20738563]


In [14]:
# Full value vector of the final policy, one entry per state.
V_best

array([   0.        ,    0.        ,  100.70098073,    0.        ,
          0.        ,    0.        ,    0.        ,    0.        ,
          0.        ,    0.        ,  102.3752644 ,  101.52364515,
          0.        ,    0.        ,  109.48993454,  110.40903296,
        111.33584663,    0.        ,    0.        ,  103.23462342,
          0.        ,  106.77826755,  107.67462643,  108.57848712,
          0.        ,  112.27044032,    0.        ,    0.        ,
        104.10121204,  104.97507555,  105.88853591,    0.        ,
          0.        ,  114.1632295 ,  113.21287932,    0.        ,
          0.        ,    0.        ,  103.78140737,    0.        ,
          0.        ,    0.        ,  115.12155727,    0.        ,
          0.        ,    0.        , -133.33333333,   90.9853796 ,
       -133.33333333,    0.        , -133.33333333,  116.08792959,
        122.02491241,    0.        ,    0.        ,   81.39949278,
         93.67165583,   95.17285726,  108.34261934,  109.58365

### b) value iteration:

In [15]:
# Starting point for value iteration: the current estimate V is all zeros.
# The "previous" estimate V_old is filled with -1 purely so that it differs
# from V and the convergence check lets the first sweep run.
V = np.zeros(M)
V_old = np.full(M, -1.0)

In [16]:
# Value iteration: repeat the Bellman optimality backup
#     V(s) <- R(s) + lamda * max_a sum_s' P(s' | s, a) * V(s')
# until two successive estimates differ by less than VALUE_TOLERANCE in 2-norm.
VALUE_TOLERANCE = 0.00001
MAX_VALUE_ITERATIONS = 1000  # safety cap so the cell always terminates

# Take the array dimensions from prob itself instead of hard-coding (4, 81).
n_actions, n_states = prob.shape[:2]

iteration_v = 0
while norm(V_old - V) >= VALUE_TOLERANCE:
    if iteration_v >= MAX_VALUE_ITERATIONS:
        # Make it visible that the result below is not converged.
        print('value iteration did not converge within',
              MAX_VALUE_ITERATIONS, 'iterations')
        break
    iteration_v += 1
    V_old = V
    # temp[s, a] = expected next-state value of taking action a in state s.
    temp = (prob.dot(V_old[:, np.newaxis])).reshape(n_actions, n_states).T
    V = np.max(temp, axis=1) * lamda + R

In [21]:
# Extract the policy that is greedy with respect to the final estimate V.
# Using the shared helper avoids relying on the loop's leftover `temp`, which
# was computed from the previous estimate V_old and does not exist at all if
# the loop body never ran.
pi = greedy(V)

In [22]:
# Value-iteration policy at the target states, for comparison with the
# policy-iteration result printed earlier.
pi[target]

array([2, 2, 1, 3, 3, 2, 2, 3, 3, 0, 2, 3, 3, 0, 2, 1, 0, 2, 0, 2, 2, 3,
       3, 3, 3, 3, 2, 2, 0, 2, 1])

In [23]:
# Value-iteration estimates at the target states, for comparison with the
# exact policy-evaluation values printed earlier.
V[target]

array([100.70076778, 102.37505145, 101.5234322 , 109.48972157,
       110.40882   , 111.33563367, 103.23441047, 106.77805459,
       107.67441346, 108.57827415, 112.27022735, 104.1009991 ,
       104.97486261, 105.88832295, 114.16301654, 113.21266636,
       103.78119506, 115.1213443 ,  90.9851907 , 116.08771662,
       122.02468702,  81.39933699,  93.67147874,  95.17267934,
       108.34241889, 109.58344966, 123.64284538, 123.18201337,
        81.39933699, 125.24956367, 124.20715987])