In [1]:
import numpy as np # used for arrays
import gym # pull the environment
import time # to get the time
import math # needed for calculations

In [2]:
# Create the CartPole-v1 environment used throughout this notebook.
env = gym.make('CartPole-v1')

# Size of the discrete action space (the captured output shows 2); this is
# also the length of the action axis of the Q table built below.
print(env.action_space.n)

2


In [3]:
LEARNING_RATE = 0.1   # weight of the new estimate in the Q-learning update
DISCOUNT = 0.95       # weight of the estimated future value in the update target
RUNS = 1000           # number of training episodes
SHOW_EVERY = 200      # render one episode every SHOW_EVERY runs
UPDATE_EVERY = 100    # record/print aggregate metrics every UPDATE_EVERY runs

# Exploration settings
epsilon = 1  # not a constant, going to be decayed
START_EPSILON_DECAYING = 1
END_EPSILON_DECAYING = RUNS // 2
# The training loop decays epsilon on every run in the *inclusive* range
# [START_EPSILON_DECAYING, END_EPSILON_DECAYING], i.e. (END - START + 1) times.
# Dividing by that count makes epsilon reach 0 at the end of the range instead
# of overshooting to a negative value.
epsilon_decay_value = epsilon / (END_EPSILON_DECAYING - START_EPSILON_DECAYING + 1)

In [4]:
# Removed stray IPython introspection (`?env.env`): it only opened the help
# pager while exploring, and is not valid Python outside IPython.

In [5]:
# Create bins and Q table
def create_bins_and_q_table(numBins=20, environment=None):
    """Build per-dimension bin edges and a randomly initialised Q table.

    Args:
        numBins: number of bin edges per observation dimension; also the
            length of each state axis of the Q table. Defaults to 20.
        environment: gym-style env used to size the table (reads
            ``observation_space.high`` and ``action_space.n``). Defaults to
            the notebook-global ``env``.

    Returns:
        (bins, obsSpaceSize, qTable):
            bins: list of ``obsSpaceSize`` arrays, each holding ``numBins``
                evenly spaced edges from ``np.linspace``.
            obsSpaceSize: number of observation dimensions.
            qTable: array of shape ``(numBins,) * obsSpaceSize + (n_actions,)``
                with values drawn uniformly from [-2, 0).

    Raises:
        ValueError: if the env's observation size does not match the number
            of hard-coded bin ranges below.
    """
    if environment is None:
        environment = env

    # Observation bounds reported by the env:
    #   high: [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38]
    #   low:  [-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38]
    # Two dimensions are effectively unbounded, so finite ranges are
    # hand-picked here instead of being read from observation_space.
    # TODO: derive these from observation_space, clipping infinite bounds.
    ranges = [(-4.8, 4.8),
              (-4, 4),
              (-.418, .418),
              (-4, 4)]

    obsSpaceSize = len(environment.observation_space.high)
    if obsSpaceSize != len(ranges):
        raise ValueError(
            f"expected {len(ranges)} observation dimensions, got {obsSpaceSize}")

    bins = [np.linspace(low, high, numBins) for low, high in ranges]
    qTable = np.random.uniform(
        low=-2, high=0,
        size=([numBins] * obsSpaceSize + [environment.action_space.n]))

    return bins, obsSpaceSize, qTable

In [6]:
def get_discrete_state(state, bins, obsSpaceSize=None):
    """Map a continuous observation to a tuple of Q-table indices.

    Each component ``state[i]`` is bucketed with ``np.digitize`` against the
    edges in ``bins[i]``. The result is shifted and clipped so that it always
    indexes an axis of length ``len(bins[i])``.

    Args:
        state: sequence of observation values, one per dimension.
        bins: sequence of 1-D arrays of increasing bin edges, one per dimension.
        obsSpaceSize: number of leading dimensions to discretize; defaults to
            ``len(bins)``.

    Returns:
        tuple of ints, the i-th in ``[0, len(bins[i]) - 1]``.
    """
    if obsSpaceSize is None:
        obsSpaceSize = len(bins)
    stateIndex = []
    for i in range(obsSpaceSize):
        # np.digitize returns 0..len(edges); subtracting 1 gives -1..len-1.
        # Clip so values below the first edge land in bin 0 instead of
        # wrapping (via negative indexing) to the last bin, where they would
        # share a Q-table cell with values above the last edge.
        index = np.digitize(state[i], bins[i]) - 1
        stateIndex.append(int(np.clip(index, 0, len(bins[i]) - 1)))
    return tuple(stateIndex)

In [7]:
# Discretisation edges, observation dimensionality and the initial Q table.
bins, obsSpaceSize, qTable = create_bins_and_q_table()

# Step count of every run, in order.
previousCnt = []
# Aggregates appended every UPDATE_EVERY runs, used for the graph.
metrics = {key: [] for key in ('ep', 'avg', 'min', 'max')}

In [8]:
# CartPole-v1 episodes are truncated by a time limit (500 steps), not the
# 200 steps of CartPole-v0, so read the limit from the env spec.
maxEpisodeSteps = env.spec.max_episode_steps
FAIL_PENALTY = -375  # update target when the episode ends before the limit

for run in range(RUNS):
    discreteState = get_discrete_state(env.reset(), bins, obsSpaceSize)
    done = False  # has the environment finished?
    cnt = 0  # how many steps the cart has taken this run

    while not done:
        if run % SHOW_EVERY == 0:
            env.render()  # comment this out to train without rendering
        cnt += 1
        # Epsilon-greedy: exploit the Q table with probability 1 - epsilon,
        # otherwise take a random action.
        if np.random.random() > epsilon:
            action = np.argmax(qTable[discreteState])
        else:
            action = np.random.randint(0, env.action_space.n)
        newState, reward, done, _ = env.step(action)  # perform action on environment

        newDiscreteState = get_discrete_state(newState, bins, obsSpaceSize)
        currentQ = qTable[discreteState + (action, )]  # old value

        if done and cnt < maxEpisodeSteps:
            # Pole fell over / cart went out of bounds: a true terminal state,
            # so penalize it and do not bootstrap from the next state.
            target = FAIL_PENALTY
        else:
            # Otherwise (including reaching the time limit, which is only a
            # truncation) bootstrap from the best estimated future value.
            maxFutureQ = np.max(qTable[newDiscreteState])
            target = reward + DISCOUNT * maxFutureQ

        # Q-learning update
        newQ = (1 - LEARNING_RATE) * currentQ + LEARNING_RATE * target
        qTable[discreteState + (action, )] = newQ  # Update qTable with new Q value

        discreteState = newDiscreteState
    previousCnt.append(cnt)

    # Decaying is being done every run if run number is within decaying range;
    # never let epsilon drop below 0.
    if END_EPSILON_DECAYING >= run >= START_EPSILON_DECAYING:
        epsilon = max(0.0, epsilon - epsilon_decay_value)
    # Add new metrics for graph
    if run % UPDATE_EVERY == 0:
        latestRuns = previousCnt[-UPDATE_EVERY:]
        averageCnt = sum(latestRuns) / len(latestRuns)
        metrics['ep'].append(run)
        metrics['avg'].append(averageCnt)
        metrics['min'].append(min(latestRuns))
        metrics['max'].append(max(latestRuns))
        print("Run:", run, "Average:", averageCnt, "Min:", min(latestRuns), "Max:", max(latestRuns))



Run: 0 Average: 18.0 Min: 18 Max: 18
Run: 100 Average: 22.67 Min: 9 Max: 61
Run: 200 Average: 34.39 Min: 9 Max: 104
Run: 300 Average: 41.01 Min: 12 Max: 109
Run: 400 Average: 58.88 Min: 16 Max: 126
Run: 500 Average: 83.65 Min: 27 Max: 130
Run: 600 Average: 101.09 Min: 56 Max: 135
Run: 700 Average: 110.89 Min: 68 Max: 179
Run: 800 Average: 110.76 Min: 59 Max: 192
Run: 900 Average: 113.7 Min: 61 Max: 149
