In [1]:
import gym
import numpy as np
from math import sqrt
import time
import random

# Globals

In [2]:

# (state,action) : [value,count]
state_action = {}
env = gym.make('CartPole-v1')
scale = 12
explore = 13

In [3]:
def get_discrete_obs(observation):
    t = (observation* scale).astype(int)
    return tuple(t)

def get_action(observation):
    t = random.random()

    action = env.action_space.sample()
    best_score = 0
    for k,v in state_action.items():
        if k[0] == observation and v[0] > best_score:
            action = k[1]
            best_score = v[0]

    if best_score > 0 and best_score/explore > t:
        return action
    else:
        return env.action_space.sample()
    

def get_episode():
    ## (state,action,reward)
    episode = []
    observation = env.reset()
    observation = get_discrete_obs(observation)
    
    while True:
        action = get_action(observation)
        new_observation, reward, done, _ = env.step(action)
        new_observation = get_discrete_obs(new_observation)
        episode.append([observation,action,reward])
        observation = new_observation
        
        if done:
            break
        
    return episode

def eval_episode(episode):
    episode.reverse()
    reward_so_far = 0
    
    for state,action,reward in episode:
#         print(state,action,reward)
        reward_so_far = reward_so_far + reward
        
        if (state,action) in state_action:
            value,count = state_action[(state,action)]
            new_val = (value * count + reward_so_far) / (count + 1)
            state_action[(state,action)] = [new_val,count + 1]
        
        else:
            state_action[(state,action)] = [reward_so_far,1]
            

            
    

In [4]:

def train():
    global exploration
    avg_len = 0
    max_len = 100
    for i in range(1,5000):
        ep = get_episode()
        eval_episode(ep)
        avg_len = avg_len + len(ep)
        
        if i%max_len == 0:
            print("Avg is",avg_len/max_len)
            avg_len = 0
        


In [None]:
train()

Avg is 48.1
Avg is 62.5
Avg is 66.83
Avg is 72.26
Avg is 80.79
Avg is 85.38
Avg is 86.3
Avg is 90.56
Avg is 89.3
Avg is 93.33
Avg is 101.76
Avg is 97.67
Avg is 104.91
Avg is 100.07
Avg is 100.09
Avg is 100.51
Avg is 100.46
Avg is 99.73
Avg is 101.2
Avg is 109.97
Avg is 112.98
Avg is 102.29
Avg is 121.43
Avg is 115.72
Avg is 115.53
Avg is 105.62
Avg is 107.78
Avg is 119.94
Avg is 108.93
Avg is 113.16
Avg is 111.09
Avg is 116.47
Avg is 118.25
Avg is 125.26
Avg is 117.33
Avg is 112.43
Avg is 119.91
Avg is 122.12
Avg is 116.72
Avg is 120.9
Avg is 123.87
Avg is 118.95
Avg is 119.86
Avg is 122.78
Avg is 126.74
Avg is 125.28
Avg is 120.04
Avg is 121.8
