In [1]:
import gym
import numpy as np
from tqdm import trange

In [4]:
class Q:
    """Tabular Q-learning agent over a uniformly discretized observation space.

    Each observation dimension is split into 50 equal-width bins between the
    environment's ``observation_space.low`` and ``observation_space.high``.
    The Q-table therefore has shape ``q_table_shape + [n_actions]``.

    Attributes:
        env: The environment being learned. It must expose ``observation_space``
            (with ``low``/``high`` arrays), ``action_space.n``, ``reset``,
            ``step``, ``close`` and ``goal_position``.
        lr: Learning rate used in the Q-value update.
        discount: Discount factor applied to the best next-state Q-value.
        q_table_shape: Number of bins per observation dimension.
        q_table_category_size: Width of one bin in each observation dimension.
        q_table: The Q-table from the most recent ``train`` call, or ``None``
            if ``train`` has not been called yet.
    """

    def __init__(self, env, lr, discount):
        self.env = env
        self.lr = lr
        self.discount = discount
        # One axis of 50 bins per observation dimension.
        self.q_table_shape = [50] * len(env.observation_space.high)
        self.q_table_category_size = (
            env.observation_space.high - env.observation_space.low
        ) / self.q_table_shape
        self.q_table = None

    def get_discrete_state(self, state):
        """Map a continuous observation to a tuple of Q-table bin indices.

        Args:
            state: Observation array with one entry per observation dimension.

        Returns:
            A tuple of integer indices, one per observation dimension, each in
            ``[0, bins - 1]``. Indexing the Q-table with it yields the vector
            of Q-values for all actions in that state.
        """
        scaled = (state - self.env.observation_space.low) / self.q_table_category_size
        # `np.int` no longer exists in current NumPy; the builtin `int` is the
        # documented replacement.
        indices = scaled.astype(int)
        # An observation exactly at the upper bound would otherwise land in
        # bin 50, one past the end of the table axis.
        indices = np.clip(indices, 0, np.asarray(self.q_table_shape) - 1)
        return tuple(indices)

    def _reset_env(self):
        """Reset the environment and return only the initial observation."""
        result = self.env.reset()
        # Newer gym releases return (observation, info); older ones return the
        # observation alone.
        if isinstance(result, tuple):
            return result[0]
        return result

    def _step_env(self, action):
        """Step the environment and return ``(new_state, reward, done)``."""
        result = self.env.step(action)
        if len(result) == 5:
            # Newer gym releases split `done` into terminated / truncated.
            new_state, reward, terminated, truncated, _ = result
            return new_state, reward, terminated or truncated
        new_state, reward, done, _ = result
        return new_state, reward, done

    def train(self, episodes, epsilon, render_every):
        """Run epsilon-greedy Q-learning and store the result in ``self.q_table``.

        Epsilon is decreased linearly once per episode for episodes
        ``1 .. episodes // 2`` and is never allowed to drop below zero.
        The environment is closed when training finishes or is interrupted.

        Args:
            episodes: Number of episodes to run.
            epsilon: Initial probability of taking a random action.
            render_every: Currently unused; rendering is disabled.
        """
        EPSILON_START_DECAY = 1
        EPSILON_END_DECAY = episodes // 2
        decay_span = EPSILON_END_DECAY - EPSILON_START_DECAY
        # Guard against a zero or negative span for very small episode counts.
        EPSILON_DECAY_VALUE = epsilon / decay_span if decay_span > 0 else 0.0

        n_actions = self.env.action_space.n
        q_table = np.random.uniform(low=-2, high=0, size=(self.q_table_shape + [n_actions]))
        # Keep a reference on the instance so the learned values survive this
        # call, including when training is interrupted part-way.
        self.q_table = q_table

        try:
            for episode in (t := trange(episodes)):
                discrete_state = self.get_discrete_state(self._reset_env())
                done = False

                while not done:
                    if np.random.random() > epsilon:
                        # Exploit: best known action for the current state.
                        action = np.argmax(q_table[discrete_state])
                    else:
                        # Explore: uniformly random action.
                        action = np.random.randint(n_actions)

                    new_state, reward, done = self._step_env(action)
                    new_discrete_state = self.get_discrete_state(new_state)

                    if not done:
                        # Standard Q-learning update towards
                        # reward + discount * best next-state value.
                        max_future_q = np.max(q_table[new_discrete_state])
                        current_q = q_table[discrete_state + (action,)]
                        new_q = (1 - self.lr) * current_q + self.lr * (
                            reward + self.discount * max_future_q
                        )
                        q_table[discrete_state + (action,)] = new_q
                    elif new_state[0] >= self.env.goal_position:
                        # Episode ended at the goal: pin the value of the
                        # final state/action pair to 0.
                        q_table[discrete_state + (action,)] = 0
                        print('Goal reached')

                    discrete_state = new_discrete_state

                # Decay once per episode while inside the decay window. The
                # window is inclusive on both ends, so clamp to avoid ending
                # slightly below zero.
                if EPSILON_END_DECAY >= episode >= EPSILON_START_DECAY:
                    epsilon = max(0.0, epsilon - EPSILON_DECAY_VALUE)
                t.set_description(str(epsilon))
        finally:
            # Close the agent's own environment (not a notebook-level global),
            # even if training raised or was interrupted.
            self.env.close()

In [15]:
# Create the MountainCar environment and a Q-learning agent bound to it.
env = gym.make("MountainCar-v0")
model = Q(env=env, lr=0.1, discount=0.95)

In [16]:
# Train for 1000 episodes, starting with an exploration rate (epsilon) of 0.
model.train(episodes=1000, epsilon=0, render_every=10001)

0.0:   7%|█████▏                                                                   | 708/10000 [00:20<05:00, 30.94it/s]

Goal reached


0.0:   7%|█████▎                                                                   | 720/10000 [00:20<05:03, 30.57it/s]

Goal reached


0.0:   8%|█████▌                                                                   | 760/10000 [00:22<06:26, 23.90it/s]

Goal reached


0.0:   8%|█████▌                                                                   | 770/10000 [00:22<05:33, 27.64it/s]

Goal reached


0.0:   8%|██████▏                                                                  | 846/10000 [00:25<04:35, 33.27it/s]

Goal reached


0.0:   9%|██████▏                                                                  | 854/10000 [00:25<05:04, 30.02it/s]

Goal reached


0.0:   9%|██████▍                                                                  | 881/10000 [00:26<04:31, 33.54it/s]


KeyboardInterrupt: 