# This is a notebook for testing my implementation of LSPI

In [None]:
# Messing around with OpenAI Gym

In [88]:
import gym
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline



In [89]:
env = gym.make('CartPole-v0')
env.reset()

array([ 0.04122447, -0.04029895,  0.03812163, -0.00496799])

In [3]:
for _ in range(100):
    env.render()
    #env.step(env.action_space.sample()) # take a random action
    env.step(0)

No handlers could be found for logger "gym.envs.classic_control.cartpole"


In [2]:
env.reset()

array([ 0.04029155,  0.006834  , -0.01659412,  0.00514427])

In [29]:
env.action_space.sample()

0

In [7]:
env.observation_space

Box(4,)

In [24]:
env.step(env.action_space.sample())

(array([ 0.15534137,  0.40782107, -0.19749603, -0.83199806]), 1.0, False, {})

# Implementation

In [397]:
def LSPI(samples, basis_functions, gamma, epsilon, w):
    '''
    Compute the parameters of the policy, w, using the LSPI algorithm.
    
    Inputs:
    sample: list of tuples of the form (s,a,r,s')
    basis_functions: list of basis functions
    gamma: float, discount factor
    epsilon: float, convergence threshold
    w: intial policy parameter vector
    
    Outputs:
    w: the converged policy paramters
    '''
    
    w_prev = w
    while True:
   # for i in range(100):
        w = LSTDQ(samples, basis_functions, gamma, w_prev)
        
        if converged(w, w_prev, epsilon):
            break 
        else:
            w_prev = w
    
    return w

def converged(w, w_prev, epsilon, sigma = 0.1):
    '''
    Determines if the policy parameters have converged based
    on whether or not the norm of the difference of w
    is less than the threshold epsilon.
    
    Inputs:
    w: a policy parameter vector
    w_prev: the policy parameter vetor from a previous iteration.
    epsilon: float, convergence threshold
    '''
    return np.linalg.norm(w-w_prev) < epsilon

def LSTDQ(samples, basis_functions, gamma, w):
    '''
    Simple version of LSTDQ
    '''
    k = len(basis_functions)
#    A = np.zeros((k,k)), this might not have an inverse, use the next line instead
    A = np.identity(k) * 0.1
    b = np.zeros(k)
    
    sub_samples = samples[np.random.choice(len(samples), 100, replace=False)]
    
    for s, a, r, sp in sub_samples:
        phi = compute_phi(s,a, basis_functions)
        phi_p = compute_phi(sp, get_policy_action(sp, w,basis_functions), basis_functions)

        A = A + np.outer(phi, (phi - gamma*phi_p))
        b = b + phi*r
    
    
    w = np.dot(np.linalg.inv(A),b)
    return w
    
    

    
    
def LSTDQ_OPT(samples, basis_functions, gamma, w, sigma=0.1):
    '''
    Computes an approximation of the policy parameters based
    on the LSTDQ-OPT algorithm presented in the paper.
    
    Inputs:
    sample: list of tuples of the form (s,a,r,s')
    basis_functions: list of basis functions
    gamma: float, discount factor
    epsilon: float, convergence threshold
    w: intial policy parameter vector
    
     sigma: small positive float.
    '''
    pass
       

def compute_phi(s,a, basis_functions):
    '''
    Computes the vector ϕ(s,a) according to the basis function ϕ_1...ϕ_k
    
    Inputs:
    s: state
    a: action
    basis_functions: list of basis functions that operate on s and a
    
    Outputs:
    ϕ(s,a), a vector where each entry is the result of one of the basis functions.
    '''
    phi= np.array([bf(s,a) for bf in basis_functions])
    return phi
    
def get_policy_action(s, w, basis_functions):
    '''
    Given a parameterization for the policy,
    reconstruct the policy and querery it to get 
    the optimal action for state s. That is,
    the argmax over actions of ϕ(s,a).w
    
    Inputs:
    s: state
    w: policy parameters
    action_space: set of all possible actions
    
    Outputs:
    action a that the policy says is best
    '''
    a_max = None
    max_score = float("-inf")
    
    # TODO: don't hard code action space
    action_space = [0,1]
    
    # Search action space for most valuable action
    for a in action_space:
        #print "phi:", compute_phi(s,a, basis_functions)
        #print "w:",w
        score = np.dot(compute_phi(s,a, basis_functions), w)
       # print "Score:",score
        # update if we found something better
        if score > max_score:
            max_score = score
            a_max = a
    print "Best action:",a_max
    return a_max
    

def get_basis_functions():
    '''
    Define some basis functions and return them in a list
    '''
    bf1 = lambda s,a: 1
    bf2 = lambda s,a: np.exp( - np.linalg.norm(s)/2.0) # rbf, with reference to origin
    return [bf1, bf2]


def generate_samples(n_samples):
    samples = []

    for i in range(n_samples):
        s = env.reset()
        action_space = [0,1]
        for a in action_space:
            sp,r, _,_ = env.step(a)
            sample = (s,a,r, sp)
            samples.append(sample)

    return np.array(samples)
    

In [398]:
bfs = get_basis_functions()
samples = generate_samples(10000)
gamma = 0.1
k = len(bfs)
w = np.zeros(k)
print w

gamma = 0.1
epsilon = 0.0001
w_est = LSPI(samples, bfs, gamma, epsilon, w)
print w_est



[ 0.  0.]
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0


Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best actio

Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best actio

Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best actio

Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best action: 0
Best actio

In [361]:
print LSTDQ(samples, bfs, gamma, np.array([0,0]))

new w: [ 0.58558378  0.53628282]
[ 0.58558378  0.53628282]


In [396]:
print LSTDQ(samples, bfs, gamma, np.array([0.57800949 , 0.54309165]))

[ 0.57958179  0.54131968]


In [379]:
samples[0]

array([array([-0.00573749, -0.03771747, -0.03415394, -0.01942246]), 0, 1.0,
       array([-0.00649184, -0.2323334 , -0.03454239,  0.26229187])], dtype=object)