<a href="https://colab.research.google.com/github/Zhi-704/ERL/blob/master/DQN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install jumanji

In [None]:
import jumanji
import jax.numpy as jnp
import jax
import random
from collections import namedtuple, deque
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras import Sequential
from tensorflow.python.keras.layers import Dense
import numpy as np
import matplotlib.pyplot as plt
from time import sleep

In [None]:
# Main framework taken from https://github.com/ultronify/dqn-from-scratch-with-tf2/blob/master/

class ReplayBuffer:
  '''
  Fixed-capacity FIFO store of gameplay transitions
  (state, next_state, reward, action, done) used for experience replay.
  Once `capacity` transitions are held, storing a new one evicts the oldest.
  '''

  def __init__(self, capacity=1000, batch_size=128):
    '''
    PARAM -
    capacity: maximum number of transitions kept (oldest evicted first)
    batch_size: default number of transitions returned by sample()
    '''
    self.capacity = capacity
    self.batch_size = batch_size
    # deque(maxlen=...) drops the leftmost (oldest) item automatically when full,
    # so no manual eviction is needed in store().
    self.memory = deque(maxlen=capacity)

  def __len__(self):
    return len(self.memory)

  def store(self, state, next_state, reward, action, done):
    '''
    Records a single step of gameplay experience.
    PARAM -
    state: current game state (observation)
    next_state: game state after taking action
    reward: reward obtained by taking action in the current state
    action: action taken at the current state
    done: boolean indicating whether the game finished after taking action
    RETURNS - N/A
    '''
    self.memory.append((state, next_state, reward, action, done))

  def sample(self, batch_size=None):
    '''
    Samples a batch of gameplay experiences (without replacement) for training.
    PARAM -
    batch_size: number of transitions to draw; defaults to self.batch_size and is
                capped at the number of stored transitions
    RETURNS - tuple of 5 numpy arrays:
              (state_batch, next_state_batch, reward_batch, action_batch, done_batch)
              whose i-th entries all belong to the same transition
    '''
    if batch_size is None:
      batch_size = self.batch_size
    batch_size = min(batch_size, len(self.memory))
    sample_batch = random.sample(self.memory, batch_size)
    # Transpose list-of-transitions into one column per field; an empty sample
    # still yields five empty arrays.
    columns = list(zip(*sample_batch)) or [()] * 5
    return tuple(np.array(column) for column in columns)

In [None]:
class DqnAgent:
  '''
  DQN agent holding an online Q network (`q_net`, trained by `train`) and a
  target Q network (`target_q_net`, used for bootstrap targets and refreshed
  from `q_net` by `update_network`).
  '''

  def __init__(self, input_dim=105, num_actions=40):
    '''
    PARAM -
    input_dim: length of the observation vector produced by convert_state
               (105 presumably matches the num_rows=5 Tetris config used below — confirm
               if the environment shape changes)
    num_actions: number of entries in the flattened action mask
    '''
    self.q_net = self.build_dqn_model(input_dim, num_actions)
    self.target_q_net = self.build_dqn_model(input_dim, num_actions)
    # Start the target network as an exact copy so early bootstrap targets are not
    # computed from an unrelated random initialisation.
    self.update_network()

  @staticmethod
  def build_dqn_model(input_dim=105, num_actions=40):
    '''
    Builds a fully connected network that predicts one Q value per action.
    PARAM -
    input_dim: size of the flattened observation vector
    num_actions: number of outputs (one per flattened action-mask entry)
    RETURNS - compiled Keras Q network (MSE loss, Adam optimizer)
    '''
    q_net = Sequential()
    # Two ReLU hidden layers; he_uniform initialisation suits ReLU activations.
    q_net.add(Dense(128, input_dim=input_dim, activation='relu', kernel_initializer='he_uniform'))
    q_net.add(Dense(64, activation='relu', kernel_initializer='he_uniform'))
    # Linear output: Q values are unbounded regression targets.
    q_net.add(Dense(num_actions, activation='linear', kernel_initializer='he_uniform'))
    q_net.compile(loss='mse', optimizer='adam')
    return q_net

  def convert_state(self, state):
    '''
    Converts an environment state into the flat observation vector fed to the network.
    PARAM -
    state: current game state (must expose grid_padded and tetromino_index)
    RETURNS - 1-D numpy array: flattened grid_padded followed by flattened tetromino_index
    '''
    grid = np.ravel(np.asarray(state.grid_padded))
    tetromino = np.ravel(np.asarray(state.tetromino_index))
    return np.concatenate([grid, tetromino])

  @staticmethod
  def _masked_argmax(q_values, action_mask):
    '''Return the flat index of the largest Q value among entries where action_mask is True.'''
    q = np.asarray(q_values, dtype=np.float64).reshape(-1)
    mask = np.asarray(action_mask, dtype=bool).reshape(-1)
    # -inf (rather than "replace zeros") so a valid action whose Q value is exactly
    # 0 is never mistaken for an invalid one.
    return int(np.argmax(np.where(mask, q, -np.inf)))

  def policy(self, state):
    '''
    Greedy action selection: the valid action with the highest predicted Q value.
    No exploration is performed here.
    PARAM -
    state: current game state
    RETURNS - (action, action_index) where action is [rotation, col_index] in the
              action mask's 2-D layout and action_index is the flat index into the
              flattened mask; (False, False) when no action is valid
    '''
    if len(self.get_valid_actions(state)) == 0:
      return False, False
    state_input = tf.convert_to_tensor(self.convert_state(state)[None, :], dtype=tf.float32)
    action_q = self.q_net(state_input)
    mask = np.array(state.action_mask)
    action_index = self._masked_argmax(action_q, mask)
    rotation, col_index = np.unravel_index(action_index, mask.shape)
    return [rotation, col_index], action_index

  def get_valid_actions(self, state):
    '''
    RETURNS - 1-D array of flat indices where state.action_mask is True
    '''
    return np.flatnonzero(np.array(state.action_mask))

  def random_action(self, state):
    '''
    Picks a uniformly random valid action.
    RETURNS - [rotation, col_index] in the action mask's 2-D layout, or False when
              no action is valid
    '''
    mask_shape = np.array(state.action_mask).shape
    valid_indices = self.get_valid_actions(state)
    if len(valid_indices) == 0:
      return False
    random_index = np.random.choice(valid_indices)
    rotation, col_index = np.unravel_index(random_index, mask_shape)
    return [rotation, col_index]

  def update_network(self):
    '''
    Copies the online network's weights into the target network.
    '''
    self.target_q_net.set_weights(self.q_net.get_weights())

  def train(self, batch, gamma=0.95):
    '''
    Performs one fit pass of q_net on a batch of transitions using the Bellman target
    reward + gamma * max_a' Q_target(next_state, a') (bootstrap term dropped when done).
    PARAM -
    batch: (state_batch, next_state_batch, reward_batch, action_batch, done_batch)
           as returned by ReplayBuffer.sample
    gamma: discount factor
    RETURNS - list of training losses from Keras' History object
    '''
    state_batch, next_state_batch, reward_batch, action_batch, done_batch = batch
    states = np.asarray(state_batch, dtype=np.float32)
    next_states = np.asarray(next_state_batch, dtype=np.float32)

    current_q = self.q_net(states).numpy()
    max_next_q = np.amax(self.target_q_net(next_states).numpy(), axis=1)
    not_done = 1.0 - np.asarray(done_batch, dtype=np.float64)
    td_targets = np.asarray(reward_batch, dtype=np.float64) + gamma * max_next_q * not_done

    # Only the taken action's Q value changes; the others keep their current
    # prediction so they contribute zero error.
    target_q = np.copy(current_q)
    rows = np.arange(target_q.shape[0])
    target_q[rows, np.asarray(action_batch, dtype=np.int64)] = td_targets

    training_his = self.q_net.fit(x=states, y=target_q)
    return training_his.history['loss']

In [None]:
def evaluate_training(env, agent, episodes_to_play=6):
  '''
  Evaluates the DQN agent's greedy policy and returns its average episode reward.
  PARAM -
  env: game environment (jumanji)
  agent: DQN agent
  episodes_to_play: number of evaluation episodes (each uses its own PRNG key)
  RETURNS: average reward across episodes (0.0 when episodes_to_play <= 0)
  '''
  # jit once per call instead of re-wrapping env.step on every step.
  reset_fn = jax.jit(env.reset)
  step_fn = jax.jit(env.step)
  total_reward = 0.0
  for episode in range(episodes_to_play):
    # A distinct key per episode: with one shared key every greedy episode would be
    # identical and the average would carry no extra information.
    state, timestep = reset_fn(jax.random.PRNGKey(episode))
    episode_reward = 0.0
    # Stop when the environment reports the episode end (e.g. time limit) or when
    # the agent has no valid move left.
    while not timestep.last():
      action, _ = agent.policy(state)
      if action is False:
        break
      state, timestep = step_fn(state, action)
      episode_reward += float(state.reward)
    total_reward += episode_reward
  if episodes_to_play <= 0:
    return 0.0
  return total_reward / episodes_to_play

def collect_experiences(env, agent, buffer, seed=None, epsilon=0.1, render=True):
  '''
  Plays one episode with an epsilon-greedy version of the agent's policy and stores
  every transition in buffer.
  PARAM -
  env: game environment (jumanji)
  agent: DQN agent (uses policy, get_valid_actions, convert_state)
  buffer: replay buffer receiving (state, next_state, reward, action_index, done)
  seed: PRNG seed for env.reset; a fresh random seed is drawn when None so that
        successive episodes see different games
  epsilon: probability of taking a uniformly random valid action instead of the
           greedy one (exploration)
  render: whether to call env.render after every step
  RETURNS: None
  '''
  if seed is None:
    seed = random.randrange(2 ** 31)
  reset_fn = jax.jit(env.reset)
  step_fn = jax.jit(env.step)
  state, timestep = reset_fn(jax.random.PRNGKey(seed))

  while not timestep.last():
    valid_indices = agent.get_valid_actions(state)
    if len(valid_indices) == 0:
      break
    if random.random() < epsilon:
      action_index = int(random.choice(list(valid_indices)))
      rotation, col_index = np.unravel_index(action_index, np.shape(state.action_mask))
      action = [rotation, col_index]
    else:
      action, action_index = agent.policy(state)

    next_state, next_timestep = step_fn(state, action)
    # Record the real end-of-episode flag so the terminal transition is not
    # bootstrapped from a next state that does not continue the game.
    done = bool(next_timestep.last())
    buffer.store(agent.convert_state(state), agent.convert_state(next_state),
                 next_state.reward, action_index, done)
    state, timestep = next_state, next_timestep

    if render:
      env.render(state)

def train_model(max_episodes=10, update_every=2):
  '''
  Trains a DQN agent to play Tetris.
  PARAM -
  max_episodes: number of collect/train/evaluate iterations
  update_every: copy q_net into target_q_net every `update_every` episodes (must be > 0)
  RETURNS: None
  '''
  agent = DqnAgent()
  buffer = ReplayBuffer()
  # Instantiate the Tetris environment from the registry. The agent's default
  # network input size presumably assumes this num_rows=5 board — confirm if changed.
  env = jumanji.make('Tetris-v0', num_rows=5, time_limit=1000)

  # Make the target network identical to the online one before the first update
  # so early TD targets do not come from an unrelated random network.
  agent.update_network()

  for episode_cnt in range(max_episodes):
    collect_experiences(env, agent, buffer)
    if len(buffer) == 0:
      # The episode ended before any transition was stored; nothing to train on.
      continue
    gameplay_batch = buffer.sample()
    loss = agent.train(gameplay_batch)
    print('So far the loss is {0}'.format(loss))
    avg_reward = evaluate_training(env, agent)
    print('So far the performance is {0}'.format(avg_reward))
    if episode_cnt % update_every == 0:
      agent.update_network()

In [None]:
# Run the full training loop with its default settings, then report completion.
train_model()
print('No problems')