In [2]:
import gym
import numpy as np
import matplotlib.pyplot as plt
import time
from gym.envs.registration import register
import os

In [3]:
env = gym.make('CartPole-v0')
env.reset()

array([ 0.01942701,  0.02362708, -0.04432664, -0.0091893 ])

In [4]:
for step in range(100):
    env.render();
    action = env.action_space.sample()
    observation, reward, done, infor = env.step(action)
    print(observation)
    time.sleep(0.02)

env.close();

[ 0.01989955 -0.17083206 -0.04451043  0.26918513]
[ 0.01648291 -0.36529148 -0.03912672  0.54750362]
[ 0.00917708 -0.55984253 -0.02817665  0.82760652]
[-0.00201977 -0.36434688 -0.01162452  0.5261966 ]
[-0.0093067  -0.1690633  -0.00110059  0.22987349]
[-0.01268797 -0.36416951  0.00349688  0.52220905]
[-0.01997136 -0.16909695  0.01394106  0.23063009]
[-0.0233533   0.02582304  0.01855366 -0.05762295]
[-0.02283684 -0.16955995  0.01740121  0.24085557]
[-0.02622804  0.02530916  0.02221832 -0.04628815]
[-0.02572185  0.22010558  0.02129255 -0.33187902]
[-0.02131974  0.02468713  0.01465497 -0.03255825]
[-0.020826    0.21959589  0.01400381 -0.32058158]
[-0.01643408  0.41451564  0.00759218 -0.60881555]
[-0.00814377  0.21928838 -0.00458413 -0.31375102]
[-0.003758    0.02423203 -0.01085916 -0.0225173 ]
[-0.00327336  0.21950801 -0.0113095  -0.31860654]
[ 0.0011168   0.02454895 -0.01768163 -0.02951157]
[ 0.00160778 -0.17031504 -0.01827186  0.25754059]
[-0.00179852 -0.36517142 -0.01312105  0.54440475]




[ 0.18825015  0.83566505 -0.3766717  -2.00256128]
[ 0.20496346  1.03230655 -0.41672293 -2.38498622]
[ 0.22560959  0.84252167 -0.46442265 -2.2436727 ]
[ 0.24246002  0.65371477 -0.50929611 -2.12214447]
[ 0.25553432  0.46582375 -0.551739   -2.01941987]
[ 0.26485079  0.66133166 -0.59212739 -2.42327154]
[ 0.27807742  0.47392543 -0.64059282 -2.35410908]
[ 0.28755593  0.28732151 -0.68767501 -2.30541213]
[ 0.29330236  0.1014228  -0.73378325 -2.27655381]
[ 0.29533082  0.2939726  -0.77931432 -2.68793547]
[ 0.30121027  0.10779551 -0.83307303 -2.69588791]
[ 0.30336618  0.2972187  -0.88699079 -3.10455889]
[ 0.30931055  0.11005983 -0.94908197 -3.15510437]
[ 0.31151175  0.29512909 -1.01218406 -3.55577615]
[ 0.31741433  0.10594678 -1.08329958 -3.65468288]
[ 0.31953327  0.2852499  -1.15639324 -4.04041688]
[ 0.32523827  0.46034417 -1.23720158 -4.41528267]
[ 0.33444515  0.2644833  -1.32550723 -4.59687533]
[ 0.33973482  0.06638144 -1.41744474 -4.80991545]
[ 0.34106245 -0.13452426 -1.51364304 -5.05443237]


In [5]:
def create_bins(num_bins_per_action = 10):
    bins_cart_position = np.linspace(-4.8, 4.8, num_bins_per_action)
    bins_cart_velocity = np.linspace(-5, 5, num_bins_per_action)
    bins_pole_angle = np.linspace(-0.418, 0.418, num_bins_per_action)
    bins_pole_velocity = np.linspace(-5, 5, num_bins_per_action)
    bins = np.array([bins_cart_position, bins_cart_velocity, bins_pole_angle, bins_pole_velocity])
    return bins

NUM_BINS = 10
BINS = create_bins(NUM_BINS)

def discretize_observation(observations, bins):
    binned_observations = []
    for i , observation in enumerate(observations):
        binned_observations.append(np.digitize(observation, bins[i]))
    
    return tuple(binned_observations)

#observations = env.reset()
#binned_observations = discretize_observation(observations, BINS)
#print(binned_observations)

q_table_shape  = (NUM_BINS, NUM_BINS, NUM_BINS, NUM_BINS, env.action_space.n)
q_table = np.zeros(q_table_shape)



def epsilon_greedy_action_selection(epsilon, q_table, discrete_state):
    # EXPLORATION
    if np.random.rand() <= epsilon:
        return env.action_space.sample()
    # EXPLOITATION
    else:
        return np.argmax(q_table[discrete_state])

EPOCHS = 20000 # number of episodes
ALPHA = 0.8 # Learning rate
GAMMA = 0.9 # Discount factor

def compute_next_q_value(old_q_value, reward, new_q_value):
    return old_q_value + ALPHA * (reward + GAMMA * new_q_value - old_q_value)

epsilon = 1
BURN_IN = 1
EPSILON_END =10000
EPSILON_REDUCE = 0.0001

def reduce_epsilon(epsilon, epoch):
    if BURN_IN <= epoch < EPSILON_END:
        return epsilon - EPSILON_REDUCE
    return epsilon

def fail(done, points, reward):
    if done and points < 150:
        reward = -200
    return reward

In [6]:
%matplotlib
epsilon =1.0
rewards = []
log_interval = 500
render_interval = 10000

fig = plt.figure()
ax = fig.add_subplot(111)
plt.ion()
fig.canvas.draw()
plt.show()



points_log = []
mean_points_log = []
epochs = []

for epoch in range(EPOCHS):
    initial_state = env.reset()
    discrete_state = discretize_observation(initial_state, BINS)
    done = False
    points = 0

    epochs.append(epoch)
    while not done:
        action = epsilon_greedy_action_selection(epsilon, q_table, discrete_state)
        next_state, reward, done, info = env.step(action)

        reward = fail(done, points, reward)

        next_state_discrete = discretize_observation(next_state, BINS)

        old_q_value = q_table[discrete_state + (action,)]
        next_optimal_q_value = np.max(q_table[next_state_discrete])

        next_q_value = compute_next_q_value(old_q_value, reward, next_optimal_q_value)
        q_table[discrete_state + (action,)] = next_q_value

        discrete_state = next_state_discrete
        points += 1

    epsilon = reduce_epsilon(epsilon, epoch)
    points_log.append(points)
    running_mean = round(np.mean(points_log[-30:]), 2)
    mean_points_log.append(running_mean)

    if epoch % log_interval == 0:
        
        ax.clear()
        ax.scatter(epochs, points_log)
        ax.plot(epochs, points_log)
        ax.plot(epochs, mean_points_log, label='Running Mean')
        plt.legend()
        fig.canvas.draw()
        plt.pause(0.01) 
        

env.close()



Using matplotlib backend: <object object at 0x00000217F2409A10>


In [7]:
plt.close()

In [32]:
env.close()

In [31]:
observation = env.reset()
rewards = 0
for step in range(1000):
    env.render()
    discrete_state = discretize_observation(observation, BINS)
    action = np.argmax(q_table[discrete_state])
    print(discrete_state)
    observation, reward, done, info = env.step(action)
    rewards += 1


env.close()


(5, 5, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 6)
(5, 5, 5, 5)
(5, 5, 5, 6)
(5, 5, 5, 5)
(5, 5, 5, 6)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 5, 5, 4)
(5, 5, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 6)
(5, 5, 5, 5)
(5, 5, 6, 6)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 5, 6, 4)
(5, 5, 6, 5)
(5, 5, 6, 4)
(5, 5, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 6)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 6, 6, 4)
(5, 5, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 6)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 6, 6, 4)
(5, 5, 6, 5)
(5, 6, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 6)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 6, 6, 5)
(5, 6, 6, 4)
(5, 6, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 6)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 6, 6, 5)
(5, 6, 6, 4)
(5, 6, 6, 5)
(5, 6, 5, 5)
(5, 6, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 5)
(5, 5, 5, 6)
(5, 5, 6, 5)
(5, 5, 6, 5)
(5, 6, 6, 5)
(5, 6, 6, 5)
(5, 6, 5, 4)

IndexError: index 10 is out of bounds for axis 2 with size 10