In [8]:
import gym
import numpy as np
import math
import random

In [9]:
# CartPole task: keep a pole balanced on a cart by moving the cart left or right.
ENV_ID = 'CartPole-v0'
env = gym.make(ENV_ID)

In [10]:
# 2 discrete actions: move the cart left or move the cart right
print(env.action_space.n)

2


In [11]:
# Each observation has 4 continuous values: cart position, cart velocity,
# pole angle, and pole rotation rate (angular velocity)
print(env.observation_space)

Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)


In [12]:
# Lower bounds of the 4 observation values; note the two velocity entries
# are effectively unbounded (about -3.4e38 in the output below)
print(env.observation_space.low)

[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38]


In [13]:
# Upper bounds of the 4 observation values; again the two velocity entries
# are effectively unbounded (about 3.4e38 in the output below)
print(env.observation_space.high)

[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38]


In [15]:
# Discretize the continuous observation space into buckets so tabular Q-learning
# can be applied to a small, bounded state space. One entry per observation dimension:
#   index 0 - cart position:          1 bucket  -> effectively ignored
#   index 1 - cart velocity:          1 bucket  -> effectively ignored
#   index 2 - pole angle:             6 buckets
#   index 3 - pole angular velocity:  3 buckets
# Collapsing the two cart dimensions to a single bucket shrinks the Q-table to
# 1*1*6*3 = 18 states, which makes learning much faster.
NUM_BUCKETS = (1, 1, 6, 3)

In [16]:
# Number of discrete cart actions (move left / move right); used as the last Q-table axis
NUM_ACTIONS = env.action_space.n

In [17]:
# Start from the environment's own (low, high) pair for every observation dimension;
# the unbounded velocity entries are overridden with finite ranges afterwards.
STATE_BOUNDS = [(lo, hi) for lo, hi in zip(env.observation_space.low, env.observation_space.high)]

In [18]:
# Cart velocity bounds: the observation space leaves this dimension effectively
# unbounded (about ±3.4e38), so clip it to a finite range before bucketing.
STATE_BOUNDS[1] = [-0.5, 0.5]

In [19]:
# Pole angular velocity is also effectively unbounded in the observation space;
# clip it to a symmetric range of 50 degrees, expressed in radians.
ANGULAR_VELOCITY_LIMIT = math.radians(50)
STATE_BOUNDS[3] = [-ANGULAR_VELOCITY_LIMIT, ANGULAR_VELOCITY_LIMIT]

In [20]:
# Final per-dimension (low, high) bounds used by state_to_bucket
print(STATE_BOUNDS)

[(-4.8, 4.8), [-0.5, 0.5], (-0.41887903, 0.41887903), [-0.8726646259971648, 0.8726646259971648]]


In [21]:
# One Q-value per (discretized state, action) pair, all starting at zero:
# (1*1*6*3) states * 2 actions
q_table = np.zeros((*NUM_BUCKETS, NUM_ACTIONS))

In [22]:
# Expect NUM_BUCKETS followed by NUM_ACTIONS
print(q_table.shape)

(1, 1, 6, 3, 2)


In [23]:
# Q-table before training: all zeros
print(q_table)

[[[[[0. 0.]
    [0. 0.]
    [0. 0.]]

   [[0. 0.]
    [0. 0.]
    [0. 0.]]

   [[0. 0.]
    [0. 0.]
    [0. 0.]]

   [[0. 0.]
    [0. 0.]
    [0. 0.]]

   [[0. 0.]
    [0. 0.]
    [0. 0.]]

   [[0. 0.]
    [0. 0.]
    [0. 0.]]]]]


In [24]:
# Floor for the exploration rate, so the agent never stops exploring entirely
EXPLORE_RATE_MIN = 0.01

In [25]:
# Floor for the learning rate (alpha) in the Q-learning update
LEARNING_RATE_MIN = 0.1

In [26]:
def get_explore_rate(t):
    """Return the epsilon-greedy exploration rate for episode ``t``.

    The rate stays at 1 while ``t + 1 <= 25``. After that it decays
    logarithmically, dropping by 1.0 for every tenfold increase in ``t + 1``,
    and it never goes below ``EXPLORE_RATE_MIN``.
    """
    decayed = 1.0 - math.log10((t + 1) / 25)
    capped = min(1, decayed)
    return max(EXPLORE_RATE_MIN, capped)

In [27]:
def get_learning_rate(t):
    """Return the Q-learning step size (alpha) for episode ``t``.

    Uses the same logarithmic decay as the exploration rate, but capped at
    0.5 from above and floored at ``LEARNING_RATE_MIN``. Larger early steps
    learn quickly; the slow decay then stabilises the table.
    """
    decayed = 1.0 - math.log10((t + 1) / 25)
    capped = min(0.5, decayed)
    return max(LEARNING_RATE_MIN, capped)

In [28]:
def select_action(state, explore_rate):
    """Pick an action for the discretized ``state`` with an epsilon-greedy policy.

    With probability ``explore_rate`` a random action is drawn via
    ``env.action_space.sample()``. Otherwise the action with the highest
    Q-value in ``q_table[state]`` is returned; ``np.argmax`` resolves ties
    in favour of the lowest index.
    """
    explore = random.random() < explore_rate
    if explore:
        # explore: try a random action
        return env.action_space.sample()
    # exploit: greedy action under the current Q-table
    return np.argmax(q_table[state])

In [34]:
def state_to_bucket(state, state_bounds=None, num_buckets=None):
    """Map a continuous observation to a tuple of discrete bucket indices.

    Each dimension ``i`` is clipped to ``state_bounds[i]`` and linearly mapped
    onto ``num_buckets[i]`` buckets, rounding to the nearest index. Because of
    the rounding, the first and last bucket of a dimension cover only half as
    much of the range as the interior buckets.

    Args:
        state: sequence of continuous observation values, one per dimension.
        state_bounds: per-dimension ``(low, high)`` pairs; defaults to the
            module-level ``STATE_BOUNDS``.
        num_buckets: per-dimension bucket counts; defaults to the module-level
            ``NUM_BUCKETS``.

    Returns:
        A tuple of ints, usable directly as an index into ``q_table``.
    """
    # Resolve defaults at call time so later edits to the globals are honoured.
    if state_bounds is None:
        state_bounds = STATE_BOUNDS
    if num_buckets is None:
        num_buckets = NUM_BUCKETS

    bucket_indices = []
    for i, value in enumerate(state):
        low, high = state_bounds[i]
        n = num_buckets[i]

        if value <= low:
            # at or beyond the lower bound -> smallest bucket
            bucket_index = 0
        elif value >= high:
            # at or beyond the upper bound -> largest bucket
            bucket_index = n - 1
        else:
            # Linear map of (low, high) onto [0, n - 1]:
            # scaling * value - offset == (n - 1) * (value - low) / width
            width = high - low
            offset = (n - 1) * low / width
            scaling = (n - 1) / width
            bucket_index = int(round(scaling * value - offset))

        bucket_indices.append(bucket_index)

    return tuple(bucket_indices)

In [47]:
def simulate(num_episodes=1000, max_t=250, streak_to_end=120, solved_t=199,
             discount_factor=0.99, render=True):
    """Train the global ``q_table`` on ``env`` with tabular Q-learning.

    Each episode resets the environment, then repeatedly picks an
    epsilon-greedy action, steps the environment and applies the update
    ``Q[s, a] += alpha * (r + gamma * max_a' Q[s', a'] - Q[s, a])``.

    An episode that ends (``done``) at step ``t >= solved_t`` extends the
    current success streak; ending earlier resets the streak. Training stops
    once the streak exceeds ``streak_to_end``, or after ``num_episodes``.

    Args:
        num_episodes: maximum number of training episodes.
        max_t: maximum number of steps per episode.
        streak_to_end: stop once more than this many consecutive episodes succeed.
        solved_t: 0-based step an episode must reach before ``done`` to count
            as a success (199 means the 200th step).
        discount_factor: gamma in the Q-learning update.
        render: call ``env.render()`` on every step when true.
    """
    # NOTE(review): assumes the older gym API where reset() returns only the
    # observation and step() returns a 4-tuple; confirm against the installed gym.
    num_streaks = 0

    for episode in range(num_episodes):
        # Rates decay once per episode (previously recomputed on every step).
        explore_rate = get_explore_rate(episode)
        learning_rate = get_learning_rate(episode)

        observ = env.reset()
        state_0 = state_to_bucket(observ)

        for t in range(max_t):
            if render:
                env.render()

            action = select_action(state_0, explore_rate)
            observ, reward, done, _ = env.step(action)
            state = state_to_bucket(observ)

            # Q-learning update towards the bootstrapped target.
            best_q = np.amax(q_table[state])
            q_table[state_0 + (action,)] += learning_rate * (
                reward + discount_factor * best_q - q_table[state_0 + (action,)]
            )
            state_0 = state

            if done:
                if t >= solved_t:
                    num_streaks += 1
                else:
                    num_streaks = 0
                break

        # Stop training entirely (not just this episode) once the pole is
        # balanced consistently; previously this only ended the step loop.
        if num_streaks > streak_to_end:
            break

In [48]:
# Train the agent; it renders the environment on every step, so this can take a while
simulate()

In [49]:
# Close the environment now that training is finished
env.close()

In [50]:
# Learned Q-values, indexed [cart position][cart velocity][pole angle][pole angular velocity][action].
# NOTE(review): in the output below, the outermost pole-angle buckets (0 and 5) are still all zero,
# i.e. apparently never visited. Presumably the episode terminates before the pole reaches those
# angles; TODO confirm against CartPole's termination threshold and consider narrowing STATE_BOUNDS[2].
print(q_table)

[[[[[ 0.          0.        ]
    [ 0.          0.        ]
    [ 0.          0.        ]]

   [[75.6347248  72.01311701]
    [61.91321938 77.25345304]
    [19.54764926  0.        ]]

   [[99.95313656 99.47635972]
    [99.96840756 99.89861471]
    [99.75726939 99.95456561]]

   [[99.95466042 99.79294757]
    [99.89639805 99.96873941]
    [99.34346476 99.9534751 ]]

   [[62.46445062 81.64068294]
    [91.22575092 89.73081802]
    [91.61159149 91.11241394]]

   [[ 0.          0.        ]
    [ 0.          0.        ]
    [ 0.          0.        ]]]]]
