In [1]:
%load_ext autoreload
%autoreload 2

import gym
import time
import random
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
%matplotlib inline

from tf_dqn_agent import Agent

In [2]:
env = gym.make('LunarLander-v2')
# env.seed(42)
print('State shape: ', env.observation_space.shape)
print('Number of actions: ', env.action_space.n)

State shape:  (8,)
Number of actions:  4


In [3]:
state = env.reset()
print(state)

[-0.00499964  1.4194578  -0.506422    0.37943238  0.00580009  0.11471219
  0.          0.        ]


In [20]:
agent = Agent(state_size=8, action_size=4)

# watch an untrained agent
state = env.reset()
for j in range(200):
    state = np.reshape(state, [1, 8])
    action = agent.act(state)
    env.render()
    time.sleep(0.01)
    state, reward, done, _ = env.step(action)
    if done:
        break 
        
env.close()

In [26]:
episodes = 5000
scores_window = deque(maxlen=100)
scores = []
for e in range(episodes + 1):
    state = env.reset()
    state = np.reshape(state, [1, 8])
    score = 0
    for t in range(1000):
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action)
        next_state = np.reshape(next_state, [1, 8])
        agent.step(state, action, reward, next_state, done)
        state = next_state
        score += reward
        if done:
            print("episode: {}/{}, score: {}, after time: {}".format(e, episodes, score, t))
            break
        
        if ((t+1)% 4) == 0:
            if agent.memory.__len__()>=64:
                agent.replay()
    if agent.epsilon > agent.epsilon_min:
        agent.epsilon *= agent.epsilon_decay
    scores.append(score)
    scores_window.append(score)
    if e+1 % 50 == 0:
        mean_score = np.mean(scores_window)
        print('\rEpisode {}\tAverage Score: {:.2f}'.format(e, mean_score))
        if mean_score>=-50.0:
            agent.qnetwork.save_weights("weights_{}.h5".format(e))
            break
    
    if e %200 == 0:
        agent.qnetwork.save_weights("weights_{}.h5".format(e))
        

fig = plt.figure()
ax = fig.add_subplot(111)
plt.plot(np.arange(len(scores)), scores)
plt.ylabel('Score')
plt.xlabel('Episode #')
plt.show()

episode: 0/5000, score: -78.67214089406838, after time: 149
episode: 1/5000, score: -66.46713050919372, after time: 999
episode: 2/5000, score: -48.59214115331608, after time: 82
episode: 3/5000, score: -41.67755062600905, after time: 84
episode: 4/5000, score: -32.91917036823874, after time: 92
episode: 5/5000, score: -299.1264334419036, after time: 206
episode: 6/5000, score: -232.5040359155132, after time: 658
episode: 7/5000, score: -139.85427234693515, after time: 136
episode: 8/5000, score: -61.94395226820352, after time: 155
episode: 9/5000, score: -22.008313815620895, after time: 286
episode: 10/5000, score: -51.79911894871987, after time: 172
episode: 11/5000, score: -6.243241683934897, after time: 167
episode: 12/5000, score: 8.603503515483068, after time: 999
episode: 13/5000, score: 119.51037093428216, after time: 999
episode: 14/5000, score: 79.21356983386528, after time: 999
episode: 15/5000, score: 60.748459427626464, after time: 999
episode: 16/5000, score: -110.5872136

episode: 135/5000, score: -257.00312017139413, after time: 117
episode: 136/5000, score: -242.87883583590127, after time: 119
episode: 137/5000, score: -279.31467401063276, after time: 652
episode: 138/5000, score: 70.44119291066778, after time: 999
episode: 139/5000, score: -72.42140052707201, after time: 438
episode: 140/5000, score: -6.698041469066425, after time: 168
episode: 141/5000, score: -268.25866973696714, after time: 140
episode: 142/5000, score: -69.12416184040836, after time: 81
episode: 143/5000, score: -192.06504296692162, after time: 129
episode: 144/5000, score: -159.75950240724336, after time: 106
episode: 145/5000, score: -23.50677811770433, after time: 761
episode: 146/5000, score: -68.41211210100222, after time: 408
episode: 147/5000, score: -71.81509739979896, after time: 175
episode: 148/5000, score: -266.4253426657855, after time: 224
episode: 149/5000, score: -87.51676697696075, after time: 246
episode: 150/5000, score: -366.3585067569401, after time: 244
epis

episode: 268/5000, score: 125.86665894225716, after time: 999
episode: 269/5000, score: -120.85772730998865, after time: 999
episode: 270/5000, score: -65.91157622816812, after time: 307
episode: 271/5000, score: -12.85577588874986, after time: 515
episode: 272/5000, score: 147.5629922708764, after time: 999
episode: 273/5000, score: -58.885608652663734, after time: 277
episode: 274/5000, score: -36.46066298213826, after time: 89
episode: 275/5000, score: -72.71220398656776, after time: 74
episode: 276/5000, score: -40.259575826099606, after time: 74
episode: 277/5000, score: -200.86777239444262, after time: 110
episode: 278/5000, score: -43.38839708336019, after time: 89
episode: 279/5000, score: -207.30458366900905, after time: 349
episode: 280/5000, score: -594.714286364837, after time: 577
episode: 281/5000, score: -267.3133586155896, after time: 391
episode: 282/5000, score: 29.350246693466573, after time: 502
episode: 283/5000, score: 209.8852030661699, after time: 596
episode: 2

KeyboardInterrupt: 

In [27]:
state = env.reset()
for j in range(300):
    state = np.reshape(state, [1, 8])
    action = agent.act(state)
    env.render()
    time.sleep(0.01)
    state, reward, done, _ = env.step(action)
    if done:
        break 
        
env.close()

In [23]:
agent.model.save_weights('weights.h5')