<a href="https://colab.research.google.com/github/DunkleCat/a3c-tensorflow/blob/master/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Libraries


In [1]:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers

import gym
import numpy as np
from threading import Thread, Lock
from multiprocessing import cpu_count

tf.keras.backend.set_floatx('float64')

device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

Found GPU at: /device:GPU:0


# ActorCritic


## Actor

In [2]:
class Actor:

    def __init__(self, state_shape, action_shape):
      self.state_shape = np.ndarray(shape = (1,
                                             state_shape[0],
                                             state_shape[1],
                                             state_shape[2])).shape
      self.action_shape = action_shape
      self.model = create_model(self.state_shape, self.action_shape)
      self.opt = tf.keras.optimizers.Adam(0.0005)

    def get_action(self, state):
      tmp = np.ndarray(shape = (1,
                                1,
                                state.shape[0],
                                state.shape[1],
                                state.shape[2]))
      tmp[0][0] = state
      action_distribution = self.model.predict(tmp)
      action = np.random.choice(self.action_shape, p = action_distribution[0])
      return np.argmax(action_distribution == action)

    def train(self, states, actions, advantages):
      # print("Start Actor training with len {}".format(len(states)))
      # print(states.shape)
      # print(states[0].shape)
      tmp = np.ndarray(shape = (len(states),
                              1,
                              states[0].shape[0],
                              states[0].shape[1],
                              states[0].shape[2]))
      for k in range(len(states)):
        tmp[k][0] = states[k]

      def compute_loss(actions, action_dist, advantages):
        # Compute policy loss
        scc = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True)
        policy_loss = scc(actions, action_dist, sample_weight = tf.stop_gradient(advantages))
        # Compute entropy
        cc = tf.keras.losses.CategoricalCrossentropy(from_logits = True)
        entropy = 0.01 * cc(action_dist, action_dist)
        return policy_loss - entropy

      with tf.GradientTape() as tape:
        action_dist = self.model(tmp, training = True)
        loss = compute_loss(actions, action_dist, advantages)
      grads = tape.gradient(loss, self.model.trainable_variables)
      self.opt.apply_gradients(zip(grads, self.model.trainable_variables))
      return loss

## Critic

In [3]:
class Critic:
  def __init__(self, state_shape):
    self.state_shape = np.ndarray(shape = (1,
                                           state_shape[0],
                                           state_shape[1],
                                           state_shape[2])).shape
    self.model = create_model(self.state_shape, 1)
    self.opt = tf.keras.optimizers.Adam(0.001)

  def get_values(self, states):
    # print(states.shape)
    tmp = np.ndarray(shape = (len(states),
                              1,
                              states[0].shape[0],
                              states[0].shape[1],
                              states[0].shape[2]))
    for k in range(len(states)):
      tmp[k][0] = states[k]
    # print(tmp.shape)
    return self.model.predict(tmp)

  def get_value(self, state):
    tmp = np.ndarray(shape = (1,
                              1,
                              state.shape[0],
                              state.shape[1],
                              state.shape[2]))
    tmp[0][0] = state
    return self.model.predict(tmp)


  def train(self, states, td_targets):
    # print("Start Critic training with len {}".format(len(states)))
    # print(states.shape)
    # print(states[0].shape)
    tmp = np.ndarray(shape = (len(states),
                              1,
                              states[0].shape[0],
                              states[0].shape[1],
                              states[0].shape[2]))
    for k in range(len(states)):
      tmp[k][0] = states[k]
    # print("tmp.shape = {}".format(tmp.shape))

    def compute_loss(v_pred, v_targets):
      mse = tf.keras.losses.MeanSquaredError()
      return mse(td_targets, v_pred)

    with tf.GradientTape() as tape:
      v_pred = self.model(tmp, training=True)
      # print(v_pred.shape)
      # print(td_targets.shape)
      assert v_pred.shape == td_targets.shape
      loss = compute_loss(v_pred, tf.stop_gradient(td_targets))
    grads = tape.gradient(loss, self.model.trainable_variables)
    self.opt.apply_gradients(zip(grads, self.model.trainable_variables))
    return loss

## Model

In [4]:
def create_model(input_shape, output_shape):
  # Input Layer
  model = keras.Sequential([
                            layers.InputLayer(input_shape),
                            layers.ConvLSTM2D(64,4,2, 
                                              data_format='channels_last',
                                              dropout = 0.5,
                                              recurrent_dropout = 0.5,
                                              return_sequences = True),
                            layers.BatchNormalization(),
                            # layers.ConvLSTM2D(48,4,1, 
                            #                   data_format='channels_last',
                            #                   dropout = 0.5,
                            #                   recurrent_dropout = 0.5,
                            #                   return_sequences = True),
                            # layers.BatchNormalization(),
                            layers.ConvLSTM2D(32,4,2, 
                                              data_format='channels_last',
                                              dropout = 0.5,
                                              recurrent_dropout = 0.5,
                                              return_sequences = True),
                            layers.BatchNormalization(),
                            layers.ConvLSTM2D(1,1,1, 
                                              data_format='channels_last',
                                              dropout = 0.5,
                                              recurrent_dropout = 0.5,
                                              return_sequences = False),
                            layers.BatchNormalization(),
        
                            layers.AveragePooling2D(),
                            layers.Flatten(),
                            layers.Dropout(0.5),
                            layers.Dense(128)
  ])
        
  if output_shape is 1:
    model.add(layers.Dense(output_shape, activation = "linear"))
  else:
    model.add(layers.Dense(output_shape, activation = "softmax"))
    
  return model

# Agent

In [5]:
class Agent:
    def __init__(self, env_name):
        env = gym.make(env_name)
        self.env_name = env_name
        self.state_shape = env.observation_space.shape
        self.action_shape = env.action_space.n

        self.global_actor = Actor(self.state_shape, self.action_shape)
        self.global_critic = Critic(self.state_shape)
        # self.num_workers = cpu_count()
        self.num_workers = 1

    def train(self, max_episodes = 1000):
        workers = []

        for i in range(self.num_workers):
            env = gym.make(self.env_name)
            workers.append(WorkerAgent(env, 
                                       self.global_actor,
                                       self.global_critic, 
                                       max_episodes))

        for worker in workers:
            worker.start()

        for worker in workers:
            worker.join()

# WorkerAgent

In [6]:
CUR_EPISODE = 0

class WorkerAgent(Thread):
    def __init__(self, env, global_actor, global_critic, max_episodes):
        Thread.__init__(self)
        self.lock = Lock()
        self.env = env
        self.state_shape = env.observation_space.shape
        self.action_shape = env.action_space.n

        self.max_episodes = max_episodes
        self.global_actor = global_actor
        self.global_critic = global_critic
        self.actor = Actor(self.state_shape, self.action_shape)
        self.critic = Critic(self.state_shape)

        self.actor.model.set_weights(self.global_actor.model.get_weights())
        self.critic.model.set_weights(self.global_critic.model.get_weights())

    def train(self):

        def list_to_batch(list):
            batch = list[0]

            for elem in list[1:]:
                batch = np.append(batch, elem, axis = 0)
            return batch

        def n_step_td_target(rewards, next_v_value, done):
            td_targets = np.zeros_like(rewards)
            cumulative = 0 if not done else next_v_value

            for k in reversed(range(0, len(rewards))):
                cumulative = 0.99 * cumulative + rewards[k]
                td_targets[k] = cumulative

            return td_targets

        global CUR_EPISODE

        while self.max_episodes >= CUR_EPISODE:
            state_batch = []
            action_batch = []
            reward_batch = []
            episode_reward, episode_actor_loss, episode_critic_loss, done = 0, 0, 0, False

            state = self.env.reset()

            while not done:
                # self.env.render()
                action = self.actor.get_action(state)

                next_state, reward, done, _ = self.env.step(action) 

                # state = np.reshape(state, [1, self.state_shape])
                action = np.reshape(action, [1, 1])
                # next_state = np.reshape(next_state, [1, self.state_shape])
                reward = np.reshape(reward, [1, 1])

                state_batch.append(state)
                action_batch.append(action)
                reward_batch.append(reward)

                if len(state_batch) >= 5 or done:
                    states = np.array(state_batch)
                    # actions = list_to_batch(action_batch)
                    rewards = list_to_batch(reward_batch)
                    # states = state_batch
                    actions = np.array(action_batch)
                    # rewards = np.array(reward_batch)

                    next_v_value = self.critic.get_value(next_state)
                    td_targets = n_step_td_target(rewards, next_v_value, done)
                    advantages = td_targets - self.critic.get_values(states)

                    with self.lock:
                        actor_loss = self.global_actor.train(states, actions, advantages)
                        critic_loss = self.global_critic.train(states, td_targets)
                        
                        self.actor.model.set_weights(self.global_actor.model.get_weights())
                        self.critic.model.set_weights(self.global_critic.model.get_weights())

                        episode_actor_loss += actor_loss
                        episode_critic_loss += critic_loss
        
                    state_batch = []
                    action_batch = []
                    reward_batch = []

                episode_reward += reward
                state = next_state

            print("EP{}: Reward = {}, Actor Loss = {}, Critic Loss = {}".format(CUR_EPISODE, episode_reward, episode_actor_loss, episode_critic_loss))
            # wandb.log({'Reward': episode_reward})
            CUR_EPISODE += 1

    def run(self):
        self.train()

# Entrypoint

In [None]:
def getNumberOfWorkers():
    return cpu_count()

def getMaxEpisodes():
    return 1000

def environment():
    return "Pong-v0"

def main():
    env_name = "Pong-v0"
    agent = Agent(env_name)
    agent.train()

if __name__ == "__main__":
    main()

EP0: Reward = [[-21.]], Actor Loss = 25.70464836622123, Critic Loss = 190.54473880678415
EP1: Reward = [[-21.]], Actor Loss = 2.078670869753698, Critic Loss = 44.56169947050512
EP2: Reward = [[-21.]], Actor Loss = -27.527088342070343, Critic Loss = 34.52162306010723
EP3: Reward = [[-21.]], Actor Loss = -15.171518729516832, Critic Loss = 26.13019384117797
EP4: Reward = [[-21.]], Actor Loss = -26.70156973910518, Critic Loss = 24.463142310269177
EP5: Reward = [[-21.]], Actor Loss = -21.974221642185007, Critic Loss = 22.49769447534345
EP6: Reward = [[-21.]], Actor Loss = 24.044484234116972, Critic Loss = 18.790567808551714
EP7: Reward = [[-21.]], Actor Loss = -28.70336722143388, Critic Loss = 17.655923280864954
EP8: Reward = [[-21.]], Actor Loss = 54.03293051749467, Critic Loss = 16.712781139533035
EP9: Reward = [[-21.]], Actor Loss = 30.869076483903935, Critic Loss = 15.372003926895559
