# Setup

Game: https://gym.openai.com/envs/Acrobot-v1/

In [1]:
import tensorflow 
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam


import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as animation
mpl.rc('animation', html='jshtml')

import random
import numpy as np

import gym

In [2]:
try:
    import pyvirtualdisplay
    display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
except ImportError:
    pass

# Network

In [3]:
gamma = 0.95
alpha = 0.50
learning_rate = 0.01
epsilon = 0.999
epsilon_decay = 0.90

class DQN:

    def __init__(self, observation_space, action_space):
        
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        self.action_space = action_space
        self.observation_space = observation_space
        
        self.memory = []
        self.batch_size = 8

        self.model = Sequential()
        self.model.add(Dense(32, input_shape=(observation_space,), activation="selu", kernel_initializer='lecun_normal'))
        self.model.add(Dense(64, activation="selu", kernel_initializer='lecun_normal'))
        self.model.add(Dense(16, activation="selu", kernel_initializer='lecun_normal'))
        self.model.add(Dense(8, activation="selu", kernel_initializer='lecun_normal'))
        self.model.add(Dropout(0.1))
        self.model.add(BatchNormalization())
        self.model.add(Dense(self.action_space, activation="linear"))
        self.model.compile(loss="mse", optimizer=Adam(lr=learning_rate))
    
    def get_scores(self):
        return self.scores
    
    def get_rewards(self):
        return self.rewards
        
    def act(self, state):
        if np.random.rand() < self.epsilon:
            return random.randrange(self.action_space)
        q = self.model.predict(state)
        return np.argmax(q[0])

    def experience_replay(self):
        print(self.memory,self.batch_size)
        mem_sample = random.sample(self.memory, min(len(self.memory), self.batch_size))
        for state, action, reward, next_state, done in mem_sample:
            update_value = reward
            print(update_value)
            if not done:
                update_value = self.alpha * (reward + self.gamma * np.max(self.model.predict(next_state)[0]))
            q = self.model.predict(state)
            q[0][action] = update_value
            self.model.fit(np.array(state), q, verbose=0)
        self.epsilon *= epsilon_decay
    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

# Run & Train

In [4]:
stop = 0
frames = []

def run_game():
    env = gym.make("Acrobot-v1")
    env.seed(21)
    observation_space, action_space = env.observation_space.shape[0], env.action_space.n
    epoch = 0
    dqn = DQN(observation_space, action_space)

    for step in range(25):
        score = 0
        epoch += 1
        state = env.reset()
        state = np.reshape(state, [1, observation_space])
        while True:
            img = env.render(mode="rgb_array")
            score += 1
            action = dqn.act(state)
            next_state, reward, done, info = env.step(action)
            reward = reward if not done else -reward
            next_state = np.reshape(next_state, [1, observation_space])
            frames.append(img)
            dqn.experience_replay()
            state = next_state
            if score > 5:
                print ("Round: " + str(epoch) + " Score: " + str(score))
                break
            dqn.remember(state, action, reward, next_state, done)
    return dqn


In [5]:
dqn = run_game()

[] 8
[(array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), 1, -1.0, array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), False)] 8
-1.0
[(array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), 1, -1.0, array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), False), (array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), 2, -1.0, array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), False)] 8
-1.0
-1.0
[(array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), 1, -1.0, array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), False), (array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), 2, -1.0, array([[ 0.99990384,  0.01386727,  0.99645336,  0.0841

        -0.47259453]]), False)] 8
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
[(array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), 1, -1.0, array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), False), (array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), 2, -1.0, array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), False), (array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), 2, -1.0, array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), False), (array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), 0, -1.0, array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), False), (array([[ 0.99990997,  0.01341854,  0.98698371,  0.16082028,  0.22162779,
        -0.5922

         0.42161227]]), False)] 8
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
[(array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), 1, -1.0, array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), False), (array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), 2, -1.0, array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), False), (array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), 2, -1.0, array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), False), (array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), 0, -1.0, array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), False), (array([[ 0.99990997,  0.01341854,  0.98698371,  0.16082028,  0.22162779,
        -0.5922

        -0.80557986]]), False)] 8
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
[(array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), 1, -1.0, array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), False), (array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), 2, -1.0, array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), False), (array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), 2, -1.0, array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), False), (array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), 0, -1.0, array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), False), (array([[ 0.99990997,  0.01341854,  0.98698371,  0.16082028,  0.22162779,
        -0.5922

        -0.56951552]]), False)] 8
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
[(array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), 1, -1.0, array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), False), (array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), 2, -1.0, array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), False), (array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), 2, -1.0, array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), False), (array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), 0, -1.0, array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), False), (array([[ 0.99990997,  0.01341854,  0.98698371,  0.16082028,  0.22162779,
        -0.5922

        -0.78040032]]), False)] 8
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
[(array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), 1, -1.0, array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), False), (array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), 2, -1.0, array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), False), (array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), 2, -1.0, array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), False), (array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), 0, -1.0, array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), False), (array([[ 0.99990997,  0.01341854,  0.98698371,  0.16082028,  0.22162779,
        -0.5922

        -0.40233166]]), False)] 8
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
[(array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), 1, -1.0, array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), False), (array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), 2, -1.0, array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), False), (array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), 2, -1.0, array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), False), (array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), 0, -1.0, array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), False), (array([[ 0.99990997,  0.01341854,  0.98698371,  0.16082028,  0.22162779,
        -0.5922

        -0.45273089]]), False)] 8
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
[(array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), 1, -1.0, array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), False), (array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), 2, -1.0, array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), False), (array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), 2, -1.0, array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), False), (array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), 0, -1.0, array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), False), (array([[ 0.99990997,  0.01341854,  0.98698371,  0.16082028,  0.22162779,
        -0.5922

        -0.582238  ]]), False)] 8
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
[(array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), 1, -1.0, array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), False), (array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), 2, -1.0, array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), False), (array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), 2, -1.0, array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), False), (array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), 0, -1.0, array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), False), (array([[ 0.99990997,  0.01341854,  0.98698371,  0.16082028,  0.22162779,
        -0.5922

         0.18332289]]), False)] 8
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
[(array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), 1, -1.0, array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), False), (array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), 2, -1.0, array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), False), (array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), 2, -1.0, array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), False), (array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), 0, -1.0, array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), False), (array([[ 0.99990997,  0.01341854,  0.98698371,  0.16082028,  0.22162779,
        -0.5922

        -0.74379977]]), False)] 8
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
[(array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), 1, -1.0, array([[0.99998879, 0.00473563, 0.99878805, 0.04921818, 0.11483555,
        0.02413791]]), False), (array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), 2, -1.0, array([[ 0.99990384,  0.01386727,  0.99645336,  0.08414691, -0.02402284,
         0.32086768]]), False), (array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), 2, -1.0, array([[ 0.99999336, -0.00364447,  0.98511163,  0.17191587, -0.14446939,
         0.54589864]]), False), (array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), 0, -1.0, array([[ 0.99990297, -0.01393023,  0.9747388 ,  0.22334786,  0.04454926,
        -0.03025746]]), False), (array([[ 0.99990997,  0.01341854,  0.98698371,  0.16082028,  0.22162779,
        -0.5922

        -0.55652379]]), False)] 8
-1.0
-1.0
-1.0
-1.0
-1.0


KeyboardInterrupt: 

# Results

In [None]:
scores = dqn.get_scores()

In [None]:
plt.plot(scores)
plt.xlabel('Round')
plt.ylabel('Score')
plt.title('Total game score by round')
plt.show()

In [None]:
def update_scene(num, frames, patch):
    patch.set_data(frames[num])
    return patch,

def plot_animation(frames, repeat=False, interval=20):
    fig = plt.figure()
    patch = plt.imshow(frames[0])
    plt.axis('off')
    anim = animation.FuncAnimation(
        fig, update_scene, fargs=(frames, patch),
        frames=len(frames), repeat=repeat, interval=interval)
    plt.close()
    return anim

In [None]:
plot_animation(frames)