In [64]:
import numpy as np
import gym
import sys
import itertools
from sklearn.kernel_approximation import RBFSampler
import sklearn.pipeline
import sklearn.preprocessing

if "../" not in sys.path:
  sys.path.append("../") 

from lib import plotting

In [65]:
env = gym.make('MountainCar-v0')

In [99]:
class RadialBasisFeaturizer():
    def __init__(self,env):
        observation_exs = np.array([env.observation_space.sample() for _ in range(10000)])
        self.scaler = sklearn.preprocessing.StandardScaler()
        #Standardize features by removing the mean and scaling to unit variance
        self.scaler.fit(observation_exs)
        
        self.featurizer = sklearn.pipeline.FeatureUnion([
            ("rbf1", RBFSampler(gamma=5.0, n_components=100)),
            ("rbf2", RBFSampler(gamma=2.0, n_components=100)),
            ("rbf3", RBFSampler(gamma=1.0, n_components=100)),
            ("rbf4", RBFSampler(gamma=0.5, n_components=100))
        ])
        
        self.featurizer.fit(self.scaler.transform(observation_exs))
        self.num_weights = 400
    def featurize(self,observation):
        scaled = self.scaler.transform([observation])
        return self.featurizer.transform(scaled)

In [111]:
class LinearActionValueEstimator():
    def __init__(self,featurizer,nA):
        self.featurizer = featurizer
        self.nA = nA
        self.weights = np.zeros([nA,self.featurizer.num_weights])
    def featurize_state(self,state):
        return self.featurizer.featurize(state)[0]
    def predict(self,state,a=None):
        state_features = self.featurize_state(state)
        if a is not None:
            return state_features.dot(self.weights[a])
        else:
            return np.array([state_features.dot(a) for a in range(self.nA)])
    def update(self, s, a, y):
            """
            Updates the estimator parameters for a given state and action towards
            the target y.
            """
            alpha = .01
            state_featurized = self.featurize_state(s)
            self.weights[a] += alpha * (y-self.predict(s,a)) * state_featurized
            
            
            
            

In [112]:
def make_epsilon_greedy_policy(estimator, epsilon, nA):
    def policy_fn(observation):
        A = np.ones(nA, dtype=float) * epsilon / nA
        q_values = estimator.predict(observation)
        best_action = np.argmax(q_values)
        A[best_action] += (1.0 - epsilon)
        return A
    return policy_fn

In [113]:
rbf_featurizer = RadialBasisFeaturizer(env)
fa = LinearActionValueEstimator(rbf_featurizer,env.action_space.n)

In [96]:
def sarsa(env,estimator,num_episodes,discount_factor=1.0,epsilon=.1,epsilon_decay=1.0):
    
    # Keeps track of useful statistics
    stats = plotting.EpisodeStats(
        episode_lengths=np.zeros(num_episodes),
        episode_rewards=np.zeros(num_episodes))    
    
    for episode in range(num_episodes):
        policy = make_epsilon_greedy_policy(estimator,epsilon,env.action_space.n)
        
        state = env.reset()
        
        for t in itertools.count():
            action_probs = policy(state)
            action = np.random.choice(np.arrange(len(action_probs)),p=action_probs)
            
            next_state,reward,done,_ = env.step(action)
            
            episode_rewards[episode] += reward
            episode_lengths[episode] = t
            
            q_values_next = estimator.predict(next_state)            
                        
            action_probs = policy(next_state)
            
            next_action = np.random.choice(np.arrange(len(action_probs)),p=action_probs)
            
            target = reward + discount_factor * q_values_next[next_action]
            
            estimator.update(state,action,target)
            
            print("\rStep {} @ Episode {}/{} ({})".format(t, episode + 1, num_episodes, last_reward), end="")
                
            if done:
                break
            
            state = next_state
        
            

In [98]:
stats = sarsa(env,fa,100)

IndexError: index 1059 is out of bounds for axis 0 with size 3