In [2]:
import gym
import random
import numpy as np
import time
from gym.envs.registration import register
from IPython.display import clear_output

# Show the gym version this notebook is running against.
print("Gym:", gym.__version__)

Gym: 0.18.0


In [3]:
# Register a deterministic (non-slippery) 4x4 FrozenLake variant.
# Re-running this cell tries to register an id that already exists; only
# gym's own error is tolerated, so unrelated failures (bad arguments,
# interrupts) are no longer hidden by a bare `except`.
try:
    register(
        id='FrozenLakeNoSlip-v0',
        entry_point='gym.envs.toy_text:FrozenLakeEnv',
        kwargs={'map_name' : '4x4', 'is_slippery':False},
        max_episode_steps=100,
        reward_threshold=0.78, # optimum = .8196
    )
except gym.error.Error:
    # Id already registered (e.g. the cell was executed twice) -- keep it.
    pass

# Environment used below. Alternatives that were tried:
# env_name = "FrozenLake-v0"
# env_name = "CartPole-v0"
env_name = "FrozenLakeNoSlip-v0"
env = gym.make(env_name)
print("Observation space:", env.observation_space)
print("Action space:", env.action_space)

Observation space: Discrete(16)
Action space: Discrete(4)


In [4]:
class Agent():
    """Baseline agent that picks actions uniformly at random.

    For a discrete action space the action is an integer in
    ``range(action_size)``; otherwise it is an array drawn uniformly
    between the space's ``low`` and ``high`` bounds with the space's shape.
    """

    def __init__(self, env):
        """Read the action-space description from ``env``.

        Args:
            env: environment whose ``action_space`` attribute is inspected.
        """
        # isinstance (rather than an exact type comparison) also recognises
        # subclasses of Discrete, which would otherwise be sent down the
        # continuous branch and fail on the missing low/high attributes.
        self.is_discrete = isinstance(env.action_space, gym.spaces.discrete.Discrete)

        if self.is_discrete:
            self.action_size = env.action_space.n
            print("Action size:", env.action_space)
        else:
            self.action_low = env.action_space.low
            self.action_high = env.action_space.high
            self.action_shape = env.action_space.shape
            print("Action range:", self.action_low, self.action_high)

    def get_action(self, state):
        """Return a uniformly random action; ``state`` is not used.

        Args:
            state: current observation (ignored by this baseline).

        Returns:
            An int in ``range(self.action_size)`` for discrete spaces,
            otherwise an array of shape ``self.action_shape`` sampled
            between ``self.action_low`` and ``self.action_high``.
        """
        if self.is_discrete:
            action = random.choice(range(self.action_size))
        else:
            action = np.random.uniform(self.action_low,
                                       self.action_high,
                                       self.action_shape)

        return action


In [5]:
class QAgent(Agent):
    """Tabular Q-learning agent with epsilon-greedy exploration.

    Keeps a ``[state_size, action_size]`` table of action values and
    updates one entry per observed transition.
    """

    def __init__(self, env, discount_rate = 0.97, learning_rate = 0.01, eps_decay = 0.99):
        """Set up hyper-parameters and the Q-table.

        Args:
            env: environment; ``env.observation_space.n`` gives the number
                of states.
            discount_rate: factor applied to the best next-state value.
            learning_rate: step size of each Q-table update.
            eps_decay: factor the exploration rate is multiplied by every
                time ``train`` sees a finished episode (default 0.99).
        """
        super().__init__(env)
        self.state_size = env.observation_space.n
        print("State size:", self.state_size)

        self.eps = 1.0  # start fully exploratory
        self.eps_decay = eps_decay
        self.discount_rate = discount_rate
        self.learning_rate = learning_rate
        self.build_model()

    def build_model(self):
        """Create the Q-table filled with small random values in [0, 1e-4)."""
        self.q_table = 1e-4*np.random.random([self.state_size, self.action_size])

    def get_action(self, state):
        """Pick an action epsilon-greedily.

        Args:
            state: index of the current state (row of the Q-table).

        Returns:
            A random action with probability ``self.eps``, otherwise the
            action with the highest Q-value for ``state``.
        """
        q_state = self.q_table[state]
        action_greedy = np.argmax(q_state)
        action_random = super().get_action(state)
        return action_random if random.random() < self.eps else action_greedy

    def train(self, experience):
        """Apply one Q-learning update for a single transition.

        Args:
            experience: tuple ``(state, action, next_state, reward, done)``.
        """
        state, action, next_state, reward, done = experience

        # No future value is bootstrapped from a transition that ended
        # the episode.
        q_next = self.q_table[next_state]
        q_next = np.zeros([self.action_size]) if done else q_next
        q_target = reward + self.discount_rate * np.max(q_next)

        q_update = q_target - self.q_table[state,action]
        self.q_table[state,action] += self.learning_rate * q_update

        if done:
            # Explore a little less after every finished episode.
            self.eps = self.eps * self.eps_decay

agent = QAgent(env)

Action size: Discrete(4)
State size: 16


In [6]:
# Run 400 episodes, training on every transition and redrawing the
# environment and Q-table after each step.
total_reward = 0
streak = 0
for ep in range(400):
    state, done = env.reset(), False
    while True:
        action = agent.get_action(state)
        next_state, reward, done, info = env.step(action)
        agent.train((state, action, next_state, reward, done))
        state = next_state
        total_reward += reward

        # Live view: the previous frame is cleared once the next one is ready.
        print("s:", state, "a:", action)
        print(
            "Episode: {}, Total reward: {}, eps: {}, Continuous Wins: {}".format(
                ep, total_reward, agent.eps, streak
            )
        )
        env.render()
        print(agent.q_table)
        time.sleep(0.02)
        clear_output(wait=True)

        if done:
            break

    # Consecutive-win counter: grows when the episode's last reward was
    # positive, resets to zero otherwise.
    streak = streak + 1 if (done and reward > 0) else 0

s: 15 a: 2
Episode: 399, Total reward: 279.0, eps: 0.017950553275045134, Continuous Wins: 76
  (Right)
SFFF
FHFH
FFFH
HFF[41mG[0m
[[1.00441529e-04 5.92524382e-02 4.39493555e-05 3.07129301e-04]
 [3.84850635e-04 7.83897619e-05 2.90909847e-05 3.54118741e-05]
 [1.13552485e-05 9.33453145e-05 8.81890804e-05 3.46433185e-05]
 [7.89990156e-05 8.53004778e-05 1.92451445e-05 8.63781127e-05]
 [1.25525743e-03 1.38185872e-01 6.39349354e-05 2.68601197e-04]
 [6.18711400e-05 5.90049394e-05 6.01829444e-05 2.82805969e-05]
 [4.68059003e-05 4.27369785e-06 4.60602784e-05 3.60540809e-05]
 [3.57686878e-05 6.99469029e-05 4.58455417e-05 5.61341390e-05]
 [3.73283095e-04 3.99906875e-05 2.78802715e-01 2.78270929e-03]
 [3.30042251e-03 4.89855718e-01 6.52345444e-03 4.62398086e-05]
 [1.54494451e-04 1.25253312e-01 1.28391196e-05 8.38391222e-05]
 [4.84839738e-05 4.22876436e-06 7.48124009e-05 4.75521184e-05]
 [8.47689933e-06 5.46813687e-05 1.25010425e-05 6.82895756e-06]
 [5.39818013e-05 7.89624235e-03 7.40652769e-01 7.