In [4]:
import numpy as np
import random
import copy
from scipy.stats import bernoulli

In [5]:
class Environment(object):
    '''General RL environment'''

    def __init__(self):
        pass

    def reset(self):
        pass

    def advance(self, action):
        '''
        Moves one step in the environment.
        Args:
            action
        Returns:
            reward - double - reward
            newState - int - new state
            pContinue - 0/1 - flag for end of the episode
        '''
        return 0, 0, 0

In [24]:
def make_riverSwim(epLen=20, nState=5):
    '''
    Makes the benchmark RiverSwim MDP.
    Args:
        NULL - works for default implementation
    Returns:
        riverSwim - Tabular MDP environment '''
    nAction = 2
    R_true = {}
    P_true = {}

    for s in range(nState):
        for a in range(nAction):
            R_true[s, a] = (0, 0)
            P_true[s, a] = np.zeros(nState)

    # Rewards
    R_true[0, 0] = (5/1000, 0)
    R_true[nState - 1, 1] = (1, 0)

    # Transitions
    for s in range(nState):
        P_true[s, 0][max(0, s-1)] = 1.

    for s in range(1, nState - 1):
        P_true[s, 1][min(nState - 1, s + 1)] = 0.3
        P_true[s, 1][s] = 0.6
        P_true[s, 1][max(0, s-1)] = 0.1

    P_true[0, 1][0] = 0.3
    P_true[0, 1][1] = 0.7
    P_true[nState - 1, 1][nState - 1] = 0.9
    P_true[nState - 1, 1][nState - 2] = 0.1

    riverSwim = TabularMDP(nState, nAction, epLen)
    riverSwim.R = R_true
    riverSwim.P = P_true
    riverSwim.reset()

    return riverSwim

class TabularMDP(Environment):
    '''
    Tabular MDP
    R - dict by (s,a) - each R[s,a] = (meanReward, sdReward)
    P - dict by (s,a) - each P[s,a] = transition vector size S
    '''

    def __init__(self, nState, nAction, epLen):
        '''
        Initialize a tabular episodic MDP
        Args:
            nState  - int - number of states
            nAction - int - number of actions
            epLen   - int - episode length
        Returns:
            Environment object
        '''

        self.nState = nState
        self.nAction = nAction
        self.epLen = epLen

        self.timestep = 0
        self.state = 0

        # Now initialize R and P
        self.R = {}
        self.P = {}
        for state in range(nState):
            for action in range(nAction):
                self.R[state, action] = (1, 1)
                self.P[state, action] = np.ones(nState) / nState
                
    def reset(self):
        "Resets the Environment"
        self.timestep = 0
        self.state = 0
        
    def advance(self,action):
        '''
        Move one step in the environment
        Args:
        action - int - chosen action
        Returns:
        reward - double - reward
        newState - int - new state
        episodeEnd - 0/1 - flag for end of the episode
        '''
        if self.R[self.state, action][1] < 1e-9:
            # Hack for no noise
            reward = self.R[self.state, action][0]
        else:
            reward = np.random.normal(loc=self.R[self.state, action][0],
                                      scale=self.R[self.state, action][1])
        #print(self.state, action, self.P[self.state, action])
        newState = np.random.choice(self.nState, p=self.P[self.state, action])
        
        # Update the environment
        self.state = newState
        self.timestep += 1

        episodeEnd = 0
        if self.timestep == self.epLen:
            episodeEnd = 1
            #newState = None
            self.reset()

        return reward, newState, episodeEnd
    
    def argmax(self,b):
        print(b)
        return np.random.choice(np.where(b == b.max())[0])

In [34]:
def proj(x, lo, hi):
    '''Projects the value of x into the [lo,hi] interval'''
    return max(min(x,hi),lo)

In [35]:
class UCRL_VTR(object):
    '''
    Algorithm 1 as described in the paper Model-Based RL with
    Value-Target Regression
    The algorithm assumes that the rewards are in the [0,1] interval.
    '''
    def __init__(self,env,K):
        self.env = env
        self.K = K
        # Here the dimension (self.d) for the Tabular setting is |S x A x S| as stated in Appendix B
        self.d = env.nState * env.nAction * env.nState 
        # In the tabular setting the basis models is just the dxd identity matrix, see Appendix B
        self.P_basis = np.identity(self.d)
        #Our Q-values are initialized as a 2d numpy array, will eventually convert to a dictionary
        self.Q = [np.zeros((env.nState,env.nAction)) for i in range(env.epLen)]
        #Our State Value function is initialized as a 1d numpy error, will eventually convert to a dictionary
        self.V = [np.zeros(env.nState) for i in range(env.epLen+1)] # self.V[env.epLen] stays zero
        #The index of each (s,a,s') tuple, see Appendix B
        self.sigma = {}
        self.createSigma()
        #See Step 2, of algorithm 1
#         self.M = env.epLen**2*self.d*np.identity(self.d)
        self.M = np.identity(self.d)
        #See Step 2
        self.w = np.zeros(self.d)
        #See Step 2
        self.theta = np.matmul(np.linalg.inv(self.M),self.w)
        #See Step 3
        self.delta = 1/self.K
        #C_theta >= the 2-norm of theta_star, see Assumption 1
        self.C_theta = 3.0
#         #Initialize the predicted value of the basis models, see equation 3
#         self.X = np.zeros((env.epLen,self.d))

    def feature_vector(self,s,a,h):
        '''
        Returning sum_{s'} V[h+1][s'] P_dot(s'|s,a),
        with V stored in self.
        Inputs:
            s - the state
            a - the action
            h - the current timestep within the episode
        '''
        sums = np.zeros(self.d)
        for ss in range(env.nState):
            sums += self.V[h+1][ss] * self.P_basis[self.sigma[(s,a,ss)]]
        return sums
            
    def update_Q(self,s,a,k,h):
        '''
        A function that updates both Q and V, Q is updated according to equation 4 and 
        V is updated according to equation 2
        Inputs:
            s - the state
            a - the action
            k - the current episode
            h - the current timestep within the episode
        Currently, does not properly compute the Q-values but it does seem to learn theta_star
        '''
        #Here env.R[(s,a)][0] is the true reward from the environment
        # Alex's code: X = self.X[h,:] 
        # Suggested code:
        X = self.feature_vector(s,a,h)
        self.Q[h][s,a] = proj(env.R[(s,a)][0] + np.dot(X,self.theta) + np.sqrt(self.Beta(k)) \
            * np.sqrt(np.dot(np.dot(np.transpose(X),np.linalg.inv(self.M)),X)), 0, env.epLen )
        self.V[h][s] = max(self.Q[h][s,:])
    
    def update_Qend(self,k):
        '''
        A function that updates both Q and V at the end of each episode, see step 16 of algorithm 1
        Inputs:
            k - the current episode
        '''
        #step 16
        for h in range(env.epLen-1,-1,-1):
            for s in range(env.nState):
                for a in range(env.nAction):
                    #Here env.R[(s,a)][0] is the true reward from the environment
                    # Alex's code: X = self.X[h,:] 
                    # Suggested code:
                    self.update_Q(s,a,k,h)
                self.V[h][s] = max(self.Q[h][s,:])
    
    def update_stat(self,s,a,s_,h):
        '''
        A function that performs steps 9-13 of algorithm 1
        Inputs:
            s - the current state
            a - the action
            s_ - the next state
            k - the current episode
            h - the timestep within episode when s was visited (starting at zero)
        '''
        #Step 10
#         self.X[h,:] = self.feature_vector(s,a,h) # do not need to store this
        X = self.feature_vector(s,a,h)
        #Step 11
        y = self.V[h+1][s_]
#         if s_ != None:
#             y = self.V[h+1][s_]
#         else:
#             y = 0.0
        #Step 12
        self.M = self.M + np.outer(X,X)
        #Step 13
        self.w = self.w + y*X
    
    def update_param(self):
        '''
        Updates our approximation of theta_star at the end of each episode, see 
        Step 15 of algorithm1
        '''
        #Step 15
        print(self.M)
        self.theta = np.matmul(np.linalg.inv(self.M),self.w)
        
    def act(self,s,h):
        '''
        Returns the greedy action with respect to Q_{h,k}(s,a) for a \in A
        see step 8 of algorithm 1
        Inputs:
            s - the current state
            h - the current timestep within the episode
        '''
        #step 8
        return env.argmax(self.Q[h][s,:])
        # return bernoulli.rvs(0.9) #A random policy for testing
        
    def createSigma(self):
        '''
        A simple function that creates sigma according to Appendix B.
        Here sigma is a dictionary who inputs is a tuple (s,a,s') and stores
        the interger index to be used in our basis model P.
        '''
        i = 0
        for s in range(env.nState):
            for a in range(env.nAction):
                for s_ in range(env.nState):
                    self.sigma[(s,a,s_)] = int(i)
                    i += 1
    
    def Beta(self,k):
        '''
        A function that return Beta_k according to Algorithm 1, step 3
        '''
        #Step 3
        return 16*pow(self.C_theta,2)*pow(env.epLen,2)*self.d*np.log(1+env.epLen*k) \
            *np.log(pow(k+1,2)*env.epLen/self.delta)*np.log(pow(k+1,2)*env.epLen/self.delta)
        

In [36]:
env = make_riverSwim(epLen = 2, nState = 2)
K = 2000
agent = UCRL_VTR(env,K)
count = np.zeros((env.nState,env.nState))
for k in range(1,K+1):
    env.reset()
    done = 0
    while done != 1:
        s = env.state
        h = env.timestep
        a = agent.act(s,h)
        r,s_,done = env.advance(a)
        count[s,s_] += 1
        agent.update_stat(s,a,s_,h)
    agent.update_param()
    agent.update_Qend(k)

[0. 0.]
[0. 0.]
[[1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1.]]
[2. 2.]
[0.005 0.   ]
[[1.000025 0.005    0.       0.       0.       0.       0.       0.      ]
 [0.005    2.       0.       0.       0.       0.       0.       0.      ]
 [0.       0.       1.       0.       0.       0.       0.       0.      ]
 [0.       0.       0.       1.       0.       0.       0.       0.      ]
 [0.       0.       0.       0.       1.       0.       0.       0.      ]
 [0.       0.       0.       0.       0.       1.       0.       0.      ]
 [0.       0.       0.       0.       0.       0.       1.       0.      ]
 [0.       0.       0.       0.       0.       0.       0.       1.      ]]
[2. 2.]
[0. 1.]
[[1.000025 0.005    0.       0.       0.       0.       0.       0.      ]
 [0.005    2.       0.       0.       0.       0.    

   1.      ]]
[2. 2.]
[0.005 0.   ]
[[ 1.00095   0.19      0.        0.        0.        0.        0.
   0.      ]
 [ 0.19     39.        0.        0.        0.        0.        0.
   0.      ]
 [ 0.        0.        1.001175  0.235     0.        0.        0.
   0.      ]
 [ 0.        0.        0.235    48.        0.        0.        0.
   0.      ]
 [ 0.        0.        0.        0.        1.        0.        0.
   0.      ]
 [ 0.        0.        0.        0.        0.        1.        0.
   0.      ]
 [ 0.        0.        0.        0.        0.        0.        1.
   0.      ]
 [ 0.        0.        0.        0.        0.        0.        0.
   1.      ]]
[2. 2.]
[0.005 0.   ]
[[ 1.00095  0.19     0.       0.       0.       0.       0.       0.     ]
 [ 0.19    39.       0.       0.       0.       0.       0.       0.     ]
 [ 0.       0.       1.0012   0.24     0.       0.       0.       0.     ]
 [ 0.       0.       0.24    49.       0.       0.       0.       0.     ]
 [ 0.    

   1.      ]]
[2. 2.]
[0. 1.]
[[ 1.00175  0.35     0.       0.       0.       0.       0.       0.     ]
 [ 0.35    71.       0.       0.       0.       0.       0.       0.     ]
 [ 0.       0.       1.00225  0.45     0.       0.       0.       0.     ]
 [ 0.       0.       0.45    91.       0.       0.       0.       0.     ]
 [ 0.       0.       0.       0.       1.       0.       0.       0.     ]
 [ 0.       0.       0.       0.       0.       1.       0.       0.     ]
 [ 0.       0.       0.       0.       0.       0.       1.       0.     ]
 [ 0.       0.       0.       0.       0.       0.       0.       1.     ]]
[2. 2.]
[0.005 0.   ]
[[ 1.001775  0.355     0.        0.        0.        0.        0.
   0.      ]
 [ 0.355    72.        0.        0.        0.        0.        0.
   0.      ]
 [ 0.        0.        1.00225   0.45      0.        0.        0.
   0.      ]
 [ 0.        0.        0.45     91.        0.        0.        0.
   0.      ]
 [ 0.        0.        0.      

    0.         1.      ]]
[2. 2.]
[0.005 0.   ]
[[  1.00265    0.53       0.         0.         0.         0.
    0.         0.      ]
 [  0.53     107.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.003025   0.605      0.         0.
    0.         0.      ]
 [  0.         0.         0.605    122.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0. 1.]
[[  1.00265   0.53      0.        0.        0.        0.        0.
    0.     ]
 [  0.53    107.        0.        0.        0.        0.        0.
    0.     ]
 [  0.        0.        1.00305   0.61      0.        0.        0.
    0.     ]

[2. 2.]
[0. 1.]
[[  1.0034     0.68       0.         0.         0.         0.
    0.         0.      ]
 [  0.68     137.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.003775   0.755      0.         0.
    0.         0.      ]
 [  0.         0.         0.755    152.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0.005 0.   ]
[[  1.003425   0.685      0.         0.         0.         0.
    0.         0.      ]
 [  0.685    138.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.003775   0.755      0.         0.
    0.         0.      ]
 [  

[2. 2.]
[0.005 0.   ]
[[  1.00415   0.83      0.        0.        0.        0.        0.
    0.     ]
 [  0.83    167.        0.        0.        0.        0.        0.
    0.     ]
 [  0.        0.        1.00455   0.91      0.        0.        0.
    0.     ]
 [  0.        0.        0.91    183.        0.        0.        0.
    0.     ]
 [  0.        0.        0.        0.        1.        0.        0.
    0.     ]
 [  0.        0.        0.        0.        0.        1.        0.
    0.     ]
 [  0.        0.        0.        0.        0.        0.        1.
    0.     ]
 [  0.        0.        0.        0.        0.        0.        0.
    1.     ]]
[2. 2.]
[0.005 0.   ]
[[  1.004175   0.835      0.         0.         0.         0.
    0.         0.      ]
 [  0.835    168.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.00455    0.91       0.         0.
    0.         0.      ]
 [  0.         0.         0.91     183.         0.    

[0.005 0.   ]
[[  1.004925   0.985      0.         0.         0.         0.
    0.         0.      ]
 [  0.985    198.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.005275   1.055      0.         0.
    0.         0.      ]
 [  0.         0.         1.055    212.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0. 1.]
[[  1.004925   0.985      0.         0.         0.         0.
    0.         0.      ]
 [  0.985    198.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.0053     1.06       0.         0.
    0.         0.      ]
 [  0.      

    0.         1.      ]]
[2. 2.]
[0.005 0.   ]
[[  1.0059     1.18       0.         0.         0.         0.
    0.         0.      ]
 [  1.18     237.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.005825   1.165      0.         0.
    0.         0.      ]
 [  0.         0.         1.165    234.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0. 1.]
[[  1.0059    1.18      0.        0.        0.        0.        0.
    0.     ]
 [  1.18    237.        0.        0.        0.        0.        0.
    0.     ]
 [  0.        0.        1.00585   1.17      0.        0.        0.
    0.     ]

    1.     ]]
[2. 2.]
[0.005 0.   ]
[[  1.006725   1.345      0.         0.         0.         0.
    0.         0.      ]
 [  1.345    270.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.00665    1.33       0.         0.
    0.         0.      ]
 [  0.         0.         1.33     267.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0.005 0.   ]
[[  1.00675   1.35      0.        0.        0.        0.        0.
    0.     ]
 [  1.35    271.        0.        0.        0.        0.        0.
    0.     ]
 [  0.        0.        1.00665   1.33      0.        0.        0.
    0.     ]
 [  0

    1.     ]]
[2. 2.]
[0. 1.]
[[  1.00775    1.55       0.         0.         0.         0.
    0.         0.      ]
 [  1.55     311.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.007325   1.465      0.         0.
    0.         0.      ]
 [  0.         0.         1.465    294.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0.005 0.   ]
[[  1.007775   1.555      0.         0.         0.         0.
    0.         0.      ]
 [  1.555    312.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.007325   1.465      0.         0.
    0.         

[2. 2.]
[0. 1.]
[[  1.00865    1.73       0.         0.         0.         0.
    0.         0.      ]
 [  1.73     347.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.007975   1.595      0.         0.
    0.         0.      ]
 [  0.         0.         1.595    320.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0. 1.]
[[  1.00865   1.73      0.        0.        0.        0.        0.
    0.     ]
 [  1.73    347.        0.        0.        0.        0.        0.
    0.     ]
 [  0.        0.        1.008     1.6       0.        0.        0.
    0.     ]
 [  0.        0.        1.6    

[2. 2.]
[0.005 0.   ]
[[  1.009425   1.885      0.         0.         0.         0.
    0.         0.      ]
 [  1.885    378.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.00875    1.75       0.         0.
    0.         0.      ]
 [  0.         0.         1.75     351.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0. 1.]
[[  1.009425   1.885      0.         0.         0.         0.
    0.         0.      ]
 [  1.885    378.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.008775   1.755      0.         0.
    0.         0.      ]
 [  

    0.         1.      ]]
[2. 2.]
[0.005 0.   ]
[[  1.00995    1.99       0.         0.         0.         0.
    0.         0.      ]
 [  1.99     399.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.009975   1.995      0.         0.
    0.         0.      ]
 [  0.         0.         1.995    400.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0. 1.]
[[  1.00995   1.99      0.        0.        0.        0.        0.
    0.     ]
 [  1.99    399.        0.        0.        0.        0.        0.
    0.     ]
 [  0.        0.        1.01      2.        0.        0.        0.
    0.     ]

    0.         1.      ]]
[2. 2.]
[0.005 0.   ]
[[  1.010925   2.185      0.         0.         0.         0.
    0.         0.      ]
 [  2.185    438.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.010925   2.185      0.         0.
    0.         0.      ]
 [  0.         0.         2.185    438.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0. 1.]
[[  1.010925   2.185      0.         0.         0.         0.
    0.         0.      ]
 [  2.185    438.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.01095    2.19       0.         0.
   

    1.     ]]
[2. 2.]
[0.005 0.   ]
[[  1.01195    2.39       0.         0.         0.         0.
    0.         0.      ]
 [  2.39     479.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.011675   2.335      0.         0.
    0.         0.      ]
 [  0.         0.         2.335    468.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0.005 0.   ]
[[  1.011975   2.395      0.         0.         0.         0.
    0.         0.      ]
 [  2.395    480.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.011675   2.335      0.         0.
    0.   

    1.     ]]
[2. 2.]
[0.005 0.   ]
[[  1.012775   2.555      0.         0.         0.         0.
    0.         0.      ]
 [  2.555    512.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.01255    2.51       0.         0.
    0.         0.      ]
 [  0.         0.         2.51     503.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0. 1.]
[[  1.012775   2.555      0.         0.         0.         0.
    0.         0.      ]
 [  2.555    512.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.012575   2.515      0.         0.
    0.         

    0.         1.      ]]
[2. 2.]
[0.005 0.   ]
[[  1.01355    2.71       0.         0.         0.         0.
    0.         0.      ]
 [  2.71     543.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.013575   2.715      0.         0.
    0.         0.      ]
 [  0.         0.         2.715    544.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0. 1.]
[[  1.01355   2.71      0.        0.        0.        0.        0.
    0.     ]
 [  2.71    543.        0.        0.        0.        0.        0.
    0.     ]
 [  0.        0.        1.0136    2.72      0.        0.        0.
    0.     ]

    0.         1.      ]]
[2. 2.]
[0. 1.]
[[  1.014675   2.935      0.         0.         0.         0.
    0.         0.      ]
 [  2.935    588.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.014375   2.875      0.         0.
    0.         0.      ]
 [  0.         0.         2.875    576.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0. 1.]
[[  1.014675   2.935      0.         0.         0.         0.
    0.         0.      ]
 [  2.935    588.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.0144     2.88       0.         0.
    0.   

[2. 2.]
[0.005 0.   ]
[[  1.01535    3.07       0.         0.         0.         0.
    0.         0.      ]
 [  3.07     615.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.015475   3.095      0.         0.
    0.         0.      ]
 [  0.         0.         3.095    620.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0.005 0.   ]
[[  1.015375   3.075      0.         0.         0.         0.
    0.         0.      ]
 [  3.075    616.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.015475   3.095      0.         0.
    0.         0.      

[2. 2.]
[0.005 0.   ]
[[  1.016225   3.245      0.         0.         0.         0.
    0.         0.      ]
 [  3.245    650.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.016275   3.255      0.         0.
    0.         0.      ]
 [  0.         0.         3.255    652.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0.005 0.   ]
[[  1.01625    3.25       0.         0.         0.         0.
    0.         0.      ]
 [  3.25     651.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.016275   3.255      0.         0.
    0.         0.      

[2. 2.]
[0.005 0.   ]
[[  1.01715    3.43       0.         0.         0.         0.
    0.         0.      ]
 [  3.43     687.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.017225   3.445      0.         0.
    0.         0.      ]
 [  0.         0.         3.445    690.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0.005 0.   ]
[[  1.017175   3.435      0.         0.         0.         0.
    0.         0.      ]
 [  3.435    688.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.017225   3.445      0.         0.
    0.         0.      

[2. 2.]
[0. 1.]
[[  1.01825   3.65      0.        0.        0.        0.        0.
    0.     ]
 [  3.65    731.        0.        0.        0.        0.        0.
    0.     ]
 [  0.        0.        1.01785   3.57      0.        0.        0.
    0.     ]
 [  0.        0.        3.57    715.        0.        0.        0.
    0.     ]
 [  0.        0.        0.        0.        1.        0.        0.
    0.     ]
 [  0.        0.        0.        0.        0.        1.        0.
    0.     ]
 [  0.        0.        0.        0.        0.        0.        1.
    0.     ]
 [  0.        0.        0.        0.        0.        0.        0.
    1.     ]]
[2. 2.]
[0.005 0.   ]
[[  1.018275   3.655      0.         0.         0.         0.
    0.         0.      ]
 [  3.655    732.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.01785    3.57       0.         0.
    0.         0.      ]
 [  0.         0.         3.57     715.         0.         0

[2. 2.]
[0.005 0.   ]
[[  1.019325   3.865      0.         0.         0.         0.
    0.         0.      ]
 [  3.865    774.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.018525   3.705      0.         0.
    0.         0.      ]
 [  0.         0.         3.705    742.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0.005 0.   ]
[[  1.019325   3.865      0.         0.         0.         0.
    0.         0.      ]
 [  3.865    774.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.01855    3.71       0.         0.
    0.         0.      

    0.         1.      ]]
[2. 2.]
[0.005 0.   ]
[[  1.0202   4.04     0.       0.       0.       0.       0.       0.    ]
 [  4.04   809.       0.       0.       0.       0.       0.       0.    ]
 [  0.       0.       1.0194   3.88     0.       0.       0.       0.    ]
 [  0.       0.       3.88   777.       0.       0.       0.       0.    ]
 [  0.       0.       0.       0.       1.       0.       0.       0.    ]
 [  0.       0.       0.       0.       0.       1.       0.       0.    ]
 [  0.       0.       0.       0.       0.       0.       1.       0.    ]
 [  0.       0.       0.       0.       0.       0.       0.       1.    ]]
[2. 2.]
[0. 1.]
[[  1.0202     4.04       0.         0.         0.         0.
    0.         0.      ]
 [  4.04     809.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.019425   3.885      0.         0.
    0.         0.      ]
 [  0.         0.         3.885    778.         0.         0.
    0.      

[2. 2.]
[0. 1.]
[[  1.02105    4.21       0.         0.         0.         0.
    0.         0.      ]
 [  4.21     843.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.020275   4.055      0.         0.
    0.         0.      ]
 [  0.         0.         4.055    812.         0.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         1.         0.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         1.
    0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    1.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         1.      ]]
[2. 2.]
[0.005 0.   ]
[[  1.021075   4.215      0.         0.         0.         0.
    0.         0.      ]
 [  4.215    844.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.020275   4.055      0.         0.
    0.         0.      ]
 [  

    0.         1.      ]]
[2. 2.]
[0. 1.]
[[  1.02205   4.41      0.        0.        0.        0.        0.
    0.     ]
 [  4.41    883.        0.        0.        0.        0.        0.
    0.     ]
 [  0.        0.        1.021     4.2       0.        0.        0.
    0.     ]
 [  0.        0.        4.2     841.        0.        0.        0.
    0.     ]
 [  0.        0.        0.        0.        1.        0.        0.
    0.     ]
 [  0.        0.        0.        0.        0.        1.        0.
    0.     ]
 [  0.        0.        0.        0.        0.        0.        1.
    0.     ]
 [  0.        0.        0.        0.        0.        0.        0.
    1.     ]]
[2. 2.]
[0. 1.]
[[  1.02205    4.41       0.         0.         0.         0.
    0.         0.      ]
 [  4.41     883.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.021025   4.205      0.         0.
    0.         0.      ]
 [  0.         0.         4.205    842. 

    0.         1.      ]]
[2. 2.]
[0. 1.]
[[  1.0227    4.54      0.        0.        0.        0.        0.
    0.     ]
 [  4.54    909.        0.        0.        0.        0.        0.
    0.     ]
 [  0.        0.        1.02195   4.39      0.        0.        0.
    0.     ]
 [  0.        0.        4.39    879.        0.        0.        0.
    0.     ]
 [  0.        0.        0.        0.        1.        0.        0.
    0.     ]
 [  0.        0.        0.        0.        0.        1.        0.
    0.     ]
 [  0.        0.        0.        0.        0.        0.        1.
    0.     ]
 [  0.        0.        0.        0.        0.        0.        0.
    1.     ]]
[2. 2.]
[0.005 0.   ]
[[  1.022725   4.545      0.         0.         0.         0.
    0.         0.      ]
 [  4.545    910.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.02195    4.39       0.         0.
    0.         0.      ]
 [  0.         0.         4.39    

    0.         1.      ]]
[2. 2.]
[0. 1.]
[[  1.02355   4.71      0.        0.        0.        0.        0.
    0.     ]
 [  4.71    943.        0.        0.        0.        0.        0.
    0.     ]
 [  0.        0.        1.02275   4.55      0.        0.        0.
    0.     ]
 [  0.        0.        4.55    911.        0.        0.        0.
    0.     ]
 [  0.        0.        0.        0.        1.        0.        0.
    0.     ]
 [  0.        0.        0.        0.        0.        1.        0.
    0.     ]
 [  0.        0.        0.        0.        0.        0.        1.
    0.     ]
 [  0.        0.        0.        0.        0.        0.        0.
    1.     ]]
[2. 2.]
[0. 1.]
[[  1.02355    4.71       0.         0.         0.         0.
    0.         0.      ]
 [  4.71     943.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.022775   4.555      0.         0.
    0.         0.      ]
 [  0.         0.         4.555    912. 

[2. 2.]
[0.005 0.   ]
[[  1.02445   4.89      0.        0.        0.        0.        0.
    0.     ]
 [  4.89    979.        0.        0.        0.        0.        0.
    0.     ]
 [  0.        0.        1.0234    4.68      0.        0.        0.
    0.     ]
 [  0.        0.        4.68    937.        0.        0.        0.
    0.     ]
 [  0.        0.        0.        0.        1.        0.        0.
    0.     ]
 [  0.        0.        0.        0.        0.        1.        0.
    0.     ]
 [  0.        0.        0.        0.        0.        0.        1.
    0.     ]
 [  0.        0.        0.        0.        0.        0.        0.
    1.     ]]
[2. 2.]
[0. 1.]
[[  1.02445    4.89       0.         0.         0.         0.
    0.         0.      ]
 [  4.89     979.         0.         0.         0.         0.
    0.         0.      ]
 [  0.         0.         1.023425   4.685      0.         0.
    0.         0.      ]
 [  0.         0.         4.685    938.         0.         0

[2. 2.]
[0. 1.]
[[1.025100e+00 5.020000e+00 0.000000e+00 0.000000e+00 0.000000e+00
  0.000000e+00 0.000000e+00 0.000000e+00]
 [5.020000e+00 1.005000e+03 0.000000e+00 0.000000e+00 0.000000e+00
  0.000000e+00 0.000000e+00 0.000000e+00]
 [0.000000e+00 0.000000e+00 1.024225e+00 4.845000e+00 0.000000e+00
  0.000000e+00 0.000000e+00 0.000000e+00]
 [0.000000e+00 0.000000e+00 4.845000e+00 9.700000e+02 0.000000e+00
  0.000000e+00 0.000000e+00 0.000000e+00]
 [0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00
  0.000000e+00 0.000000e+00 0.000000e+00]
 [0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
  1.000000e+00 0.000000e+00 0.000000e+00]
 [0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
  0.000000e+00 1.000000e+00 0.000000e+00]
 [0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
  0.000000e+00 0.000000e+00 1.000000e+00]]
[2. 2.]
[0. 1.]
[[1.02510e+00 5.02000e+00 0.00000e+00 0.00000e+00 0.00000e+00 0.00000e+00
  0.00000e+00 0.00000

In [37]:
true_p = []
for values in env.P.values():
    for value in values:
        true_p.append(value)
print('The 2-norm of (P_true - theta_star) is:',np.linalg.norm(true_p-agent.theta))

The 2-norm of (P_true - theta_star) is: 1.705353151850923


In [38]:
agent.theta

array([2.49748664e-05, 4.99497329e-03, 3.60019041e-03, 7.20038083e-01,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00])

In [39]:
for z in agent.Q:
    print(z)

[[2. 2.]
 [2. 2.]]
[[0.005 0.   ]
 [0.    1.   ]]
