In [None]:
import os
import threading
import multiprocessing
import numpy as np
import tensorflow as tf

from worker import Worker
from ac_network import AC_Network

In [None]:
# Environment setup. The two dimensions are handed to AC_Network and Worker
# further down; presumably they are the observation / action sizes of
# ENV_NAME (TODO confirm against the gym environment spec).
ENV_NAME = 'CartPole-v0'
STATE_DIM, ACTION_DIM = 4, 2

# Output directory given to env.monitor.start() when TEST_MODEL is enabled.
MONITOR_DIR = './results/' + ENV_NAME

In [None]:
# Run mode. LOAD_MODEL restores the latest checkpoint found in MODEL_DIR
# before training; TEST_MODEL also restores it, but runs a single worker under
# the gym monitor instead of spawning training threads.
LOAD_MODEL = TEST_MODEL = False
MODEL_DIR = './model/'

# Seed applied to NumPy and TensorFlow, and forwarded to every Worker.
RANDOM_SEED = 1234

# Optimisation settings: learning rate of the shared Adam optimiser, and the
# value passed to Worker.work (presumably the reward discount factor).
LEARNING_RATE, GAMMA = 0.0001, 0.99

In [None]:
# Build the shared graph (global network + one Worker per thread), then either
# evaluate a saved model with a single worker or train with all workers in
# parallel threads.
#
# `master_network` and `global_episodes` are plain module-level names; a
# `global` statement at module scope has no effect, so none is used here.

# Start from an empty default graph so re-running this cell does not pile up
# duplicate variables from a previous run.
tf.reset_default_graph()

if not os.path.exists(MODEL_DIR):
    os.makedirs(MODEL_DIR)

with tf.device("/cpu:0"):
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)

    # Non-trainable counter shared with every worker.
    global_episodes = tf.Variable(0, dtype=tf.int32, name='global_episodes', trainable=False)
    trainer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
    # The network under the 'global' scope is created before any worker.
    master_network = AC_Network(STATE_DIM, ACTION_DIM, 'global', None)

    # Evaluation uses exactly one worker; training uses one per CPU core.
    num_workers = 1 if TEST_MODEL else multiprocessing.cpu_count()

    workers = []
    for i in range(num_workers):
        workers.append(Worker(i, STATE_DIM, ACTION_DIM, trainer, MODEL_DIR, global_episodes,
                              ENV_NAME, RANDOM_SEED, TEST_MODEL))
    saver = tf.train.Saver(max_to_keep=5)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    if LOAD_MODEL or TEST_MODEL:
        print('Loading Model...')
        ckpt = tf.train.get_checkpoint_state(MODEL_DIR)
        # get_checkpoint_state returns None when MODEL_DIR holds no checkpoint;
        # fail with a clear message instead of an AttributeError on `ckpt`.
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise IOError('No checkpoint found in ' + MODEL_DIR)
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())

    if TEST_MODEL:
        env = workers[0].get_env()
        # NOTE(review): `env.monitor` looks like the legacy gym monitor API;
        # confirm it exists in the installed gym version.
        env.monitor.start(MONITOR_DIR, force=True)
        workers[0].work(GAMMA, sess, coord, saver)
    else:
        worker_threads = []
        for worker in workers:
            # Hand the bound method and its arguments straight to Thread.
            # A `lambda: worker.work(...)` would look `worker` up only when the
            # thread actually runs, so a late-starting thread could execute a
            # different worker than the one it was created for.
            t = threading.Thread(target=worker.work, args=(GAMMA, sess, coord, saver))
            t.start()
            worker_threads.append(t)
        # Block until every worker thread has finished.
        coord.join(worker_threads)