In [1]:
from IPython.display import display, HTML
# Inject CSS that stretches the notebook's `.container` element to the full page width.
display(HTML(data="<style>.container { width:100% !important; }</style>"))

In [2]:
import gym
import numpy as np
import random

In [3]:
# Seed every RNG the agent touches (Python `random` for epsilon-greedy,
# NumPy's global RNG for weight init, the env dynamics and action sampling)
# so Restart & Run All reproduces the logged training curve.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

env = gym.make('CartPole-v1')
env.reset(seed=SEED)
env.action_space.seed(SEED);

In [4]:
class VanillaFeatureEncoder:
    """Identity feature map: an observation is used directly as its feature vector."""

    def __init__(self, env):
        # Retained only so `size` can be read off the env's observation space.
        self.env = env

    def encode(self, state):
        """Return `state` itself (no copy, no transformation)."""
        return state

    @property
    def size(self):
        """Feature dimension: the first entry of the observation-space shape."""
        obs_shape = self.env.observation_space.shape
        return obs_shape[0]

In [5]:
class QLearning_LVFA:
    """Q-learning agent with linear value-function approximation (LVFA).

    Q(s, a) is modelled as ``weights[a] @ phi(s)``, where ``phi(s)`` is
    ``feature_encoder.encode(s)``. ``weights`` has one row per discrete action
    and one column per feature.

    Parameters
    ----------
    env
        Environment exposing ``action_space.n``, ``action_space.sample()``,
        ``reset() -> (obs, info)`` and the five-tuple
        ``step(a) -> (obs, reward, terminated, truncated, info)``.
    feature_encoder_cls
        Class instantiated as ``feature_encoder_cls(env)``; the instance must
        provide ``encode(state)`` and a ``size`` attribute.
    alpha, alpha_decay
        Learning rate and its per-episode multiplicative decay.
    gamma
        Discount factor.
    epsilon, epsilon_decay
        Initial exploration rate and its per-episode multiplicative decay.
    epsilon_min
        Floor that ``epsilon`` is clipped to after each decay (defaults to the
        0.2 that used to be hard-coded).
    """

    def __init__(self, env, feature_encoder_cls, alpha=0.005, alpha_decay=0.9999,
                 gamma=0.9999, epsilon=1., epsilon_decay=0.99, epsilon_min=0.2):
        self.env = env
        self.feature_encoder = feature_encoder_cls(env)
        self.shape = (self.env.action_space.n, self.feature_encoder.size)
        # Uniform [0, 1) initialisation drawn from NumPy's global RNG.
        self.weights = np.random.random(self.shape)
        self.alpha = alpha
        self.alpha_decay = alpha_decay
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

    def Q(self, feats):
        """Return the action values for ``feats`` as an (n_actions, 1) column."""
        return self.weights @ np.asarray(feats).reshape(-1, 1)

    def update_transition(self, s, action, s_prime, reward, done):
        """Apply one semi-gradient Q-learning update to ``weights[action]``.

        ``done`` should be True only when ``s_prime`` is a true terminal state,
        in which case the target is just ``reward``. A time-limit truncation is
        not terminal: pass ``done=False`` so the target still bootstraps from
        ``s_prime``.
        """
        s_feats = np.asarray(self.feature_encoder.encode(s))
        s_prime_feats = self.feature_encoder.encode(s_prime)

        target = reward
        if not done:
            # Off-policy (Q-learning) target: bootstrap from the greedy action in s'.
            target += self.gamma * self.Q(s_prime_feats).max()

        # `[action, 0]` extracts a scalar from the (n_actions, 1) column.
        td_error = target - self.Q(s_feats)[action, 0]

        # For a linear Q, the gradient w.r.t. weights[action] is the feature vector.
        self.weights[action] += self.alpha * td_error * s_feats

    def update_alpha_epsilon(self):
        """Decay epsilon (floored at ``epsilon_min``) and alpha; called once per episode."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        self.alpha = self.alpha * self.alpha_decay

    def policy(self, state):
        """Greedy action: index of the largest Q-value for ``state``."""
        return self.Q(self.feature_encoder.encode(state)).argmax()

    def epsilon_greedy(self, state, epsilon=None):
        """With probability ``epsilon`` (default ``self.epsilon``) return a random
        action sampled from the action space, otherwise the greedy action."""
        if epsilon is None:
            epsilon = self.epsilon
        if random.random() < epsilon:
            return self.env.action_space.sample()
        return self.policy(state)

    def train(self, n_episodes=200, max_steps_per_episode=200, log_every=20):
        """Run epsilon-greedy Q-learning on ``self.env``.

        The weights are updated after every step; alpha and epsilon decay after
        every episode. Every ``log_every`` episodes (0 or None disables) this
        prints ``episode, mean greedy return, epsilon, alpha``.
        """
        for episode in range(n_episodes):
            s, _info = self.env.reset()
            for _step in range(max_steps_per_episode):
                action = self.epsilon_greedy(s)
                s_prime, reward, terminated, truncated, _info = self.env.step(action)
                # Only true termination removes the bootstrap term; truncation
                # just ends the rollout.
                self.update_transition(s, action, s_prime, reward, terminated)
                s = s_prime
                if terminated or truncated:
                    break

            self.update_alpha_epsilon()

            if log_every and episode % log_every == 0:
                print(episode, self.evaluate(), self.epsilon, self.alpha)

    def evaluate(self, env=None, n_episodes=10, max_steps_per_episode=200):
        """Return the mean undiscounted return of the greedy policy.

        Runs ``n_episodes`` episodes on ``env`` (default ``self.env``). Each
        episode stops at termination, truncation, or ``max_steps_per_episode``.
        """
        if env is None:
            env = self.env

        rewards = []
        for _episode in range(n_episodes):
            total_reward = 0
            s, _info = env.reset()
            for _step in range(max_steps_per_episode):
                s, reward, terminated, truncated, _info = env.step(self.policy(s))
                total_reward += reward
                if terminated or truncated:
                    break
            rewards.append(total_reward)

        return np.mean(rewards)


In [6]:
# Build the linear Q-learning agent on the raw CartPole observation and train it
# with the default schedule (200 episodes, at most 200 steps each).
agent = QLearning_LVFA(env=env, feature_encoder_cls=VanillaFeatureEncoder)
agent.train(n_episodes=200, max_steps_per_episode=200)

0 112.5 0.99 0.0049995000000000005
20 111.1 0.8097278682212583 0.004989510493352992
40 128.1 0.6622820409839835 0.004979540946750601
60 94.3 0.5416850759668536 0.004969591320310636
80 111.3 0.44304798162617254 0.004959661574230599
100 102.2 0.36237201786049694 0.004949751668787519
120 103.8 0.2963865873992079 0.0049398615643377955
140 95.8 0.24241664604458016 0.004929991221317044
160 94.0 0.2 0.004920140600239929
180 94.2 0.2 0.004910309661700014


In [7]:
# Mean undiscounted return of the greedy (argmax-Q) policy over 10 episodes.
agent.evaluate(n_episodes=10, max_steps_per_episode=200)

96.1

In [8]:
# Learned linear weights: one row per action, one column per feature.
agent.weights

array([[ 0.77759332,  1.57777479, -0.2845712 , -1.70878178],
       [ 1.35762954, -0.93265221,  0.66481105,  1.55850504]])