In [3]:
import gym
import numpy as np
import random
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

Using TensorFlow backend.


In [4]:
ENV_NAME = 'Acrobot-v1'

GAMMA = .5
LEARNING_RATE = .001

MEMORY_SIZE = 1000000
BATCH_SIZE = 20

EXPLORATION_MAX = 1
EXPLORATION_MIN = .01
EXPLORATION_DECAY = .99995

class DQNSolver:
    
    def __init__(self, observation_space, action_space):
        self.explore_rate = EXPLORATION_MAX
        
        self.action_space = action_space
        self.memory = deque(maxlen=MEMORY_SIZE)
        
        self.model = Sequential()
        self.model.add(Dense(24, input_shape=(observation_space,), activation="relu"))
        self.model.add(Dense(24, activation="relu"))
        self.model.add(Dense(self.action_space, activation="linear"))
        self.model.compile(loss="mse", optimizer=Adam(lr=LEARNING_RATE))
    
    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))
        
    def act(self, state):
        if np.random.rand() < self.explore_rate:
            return random.randrange(self.action_space)
        q_values = self.model.predict(state)
        return np.argmax(q_values)
        
    def experience_replay(self):
        if len(self.memory) < BATCH_SIZE:
            return
        batch = random.sample(self.memory, BATCH_SIZE)
        for state, action, reward, state_next, done in batch:
            q_update = reward
            if not done:
                q_update = (reward + GAMMA * np.argmax(self.model.predict(state_next)[0]))
            q_values = self.model.predict(state)
            q_values[0][action] = q_update
            self.model.fit(state, q_values, verbose=0)
        self.explore_rate *= EXPLORATION_DECAY
        self.explore_rate = max(self.explore_rate, EXPLORATION_MIN)
    

In [5]:
def acrobot():
    env = gym.make(ENV_NAME)
    observation_space = env.observation_space.shape[0]
    action_space = env.action_space.n
    dqn_solver = DQNSolver(observation_space, action_space)
    run = 0
    while run < 200:
        run += 1
        state = env.reset()
        state = np.reshape(state, [1, observation_space])
        step = 0
        progress = 0
        penalties = 0
        while True:
            step += 1
            env.render()
            action = dqn_solver.act(state)
            state_next, reward, done, info = env.step(action)
            if reward > 0: 
                progress += reward
            else:
                penalties += reward
            state_next = np.reshape(state_next, [1, observation_space])
            dqn_solver.remember(state, action, reward, state_next, done)
            state = state_next
            if done:
                print(f"Run: {run} exploration: {dqn_solver.explore_rate}")
                print(f"Progress: {progress}, Penalties: {penalties}")
                break
            dqn_solver.experience_replay()
    
acrobot()

W0717 13:51:09.970468 140046005827392 deprecation_wrapper.py:119] From /home/tomas/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0717 13:51:09.986757 140046005827392 deprecation_wrapper.py:119] From /home/tomas/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0717 13:51:09.988723 140046005827392 deprecation_wrapper.py:119] From /home/tomas/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

W0717 13:51:10.037181 140046005827392 deprecation_wrapper.py:119] From /home/tomas/anaconda3/lib/python3.6/site-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.

W0717 13:51:

Run: 1 exploration: 0.9762851239671341
Progress: 0, Penalties: -500.0
Run: 2 exploration: 0.9522275746171881
Progress: 0, Penalties: -500.0
Run: 3 exploration: 0.9287628497060423
Progress: 0, Penalties: -500.0
Run: 4 exploration: 0.9058763408955779
Progress: 0, Penalties: -500.0
Run: 5 exploration: 0.883553799825314
Progress: 0, Penalties: -500.0
Run: 6 exploration: 0.8617813292418679
Progress: 0, Penalties: -500.0
Run: 7 exploration: 0.8405453743469987
Progress: 0, Penalties: -500.0
Run: 8 exploration: 0.8198327143588462
Progress: 0, Penalties: -500.0
Run: 9 exploration: 0.7996304542811299
Progress: 0, Penalties: -500.0
Run: 10 exploration: 0.7799260168751616
Progress: 0, Penalties: -500.0
Run: 11 exploration: 0.7607071348296816
Progress: 0, Penalties: -500.0
Run: 12 exploration: 0.7419618431236511
Progress: 0, Penalties: -500.0
Run: 13 exploration: 0.7236784715772421
Progress: 0, Penalties: -500.0
Run: 14 exploration: 0.7058456375863705
Progress: 0, Penalties: -500.0
Run: 15 explorat

Run: 116 exploration: 0.056069009753439716
Progress: 0, Penalties: -500.0
Run: 117 exploration: 0.054687361159159405
Progress: 0, Penalties: -500.0
Run: 118 exploration: 0.05333975905234998
Progress: 0, Penalties: -500.0
Run: 119 exploration: 0.05202536446186214
Progress: 0, Penalties: -500.0
Run: 120 exploration: 0.05074335909041462
Progress: 0, Penalties: -500.0
Run: 121 exploration: 0.04949294480515025
Progress: 0, Penalties: -500.0
Run: 122 exploration: 0.04827334314074544
Progress: 0, Penalties: -500.0
Run: 123 exploration: 0.04708379481476431
Progress: 0, Penalties: -500.0
Run: 124 exploration: 0.04592355925495562
Progress: 0, Penalties: -500.0
Run: 125 exploration: 0.044791914138197234
Progress: 0, Penalties: -500.0
Run: 126 exploration: 0.0436881549408027
Progress: 0, Penalties: -500.0
Run: 127 exploration: 0.042611594499908576
Progress: 0, Penalties: -500.0
Run: 128 exploration: 0.041561562585670446
Progress: 0, Penalties: -500.0
Run: 129 exploration: 0.040537405484000505
Prog