## Experiment 2: Simple Pokémon battle (a.k.a. advanced rock-paper-scissors)
### Train an agent to choose the best move, given the types of its moves and its opponent's type(s)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import csv
import ast
import math

In [None]:
# One-off preprocessing that writes pokemon-data-clean.csv and move-data-clean.csv.
# NOTE: both snippets below are wrapped in triple-quoted strings, so running this
# cell does nothing. Remove the surrounding quotes to regenerate the cleaned files.
#
# Snippet 1: read pokemon-data.csv with a space-delimited csv reader, re-join each
# row on ' ' and split it on ';' (the file is ';'-separated), use the first row as
# column names, then drop the 'Tier' and 'Next Evolution(s)' columns.
""" pokedf = None
with open('pokemon-data.csv') as csvfile:
    reader = csv.reader(csvfile, delimiter=' ', quotechar='|')
    pokemon = []
    colNames = []
    for row in reader:
        pokemon.append(' '.join(row).split(';'))
    colNames = pokemon[0]
    pokemon.pop(0)
    pokedf = pd.DataFrame(pokemon, columns=colNames)

pokedf.drop(['Tier', 'Next Evolution(s)'], axis=1, inplace=True)
pokedf.to_csv('pokemon-data-clean.csv', index=False) """

# Snippet 2: drop the 'Index' and 'Generation' columns from move-data.csv.
""" movedf = pd.read_csv('move-data.csv')
movedf.drop(['Index', 'Generation'], axis=1, inplace=True)
movedf.to_csv('move-data-clean.csv', index=False)
 """

In [None]:
def cleanup(arr):
    """Normalize move names and keep only damaging moves with known power and accuracy.

    Each name has its dashes and apostrophes stripped (e.g. "Double-Edge" ->
    "DoubleEdge") so it matches the normalized 'Name' column of the global
    ``movedf``. A move is dropped when its 'Category' is 'Status' or when its
    'Power' or 'Accuracy' value is missing.

    Args:
        arr: iterable of raw move names.

    Returns:
        List of the normalized names of the kept moves, in input order.

    Raises:
        IndexError: if a normalized name has no row in ``movedf``.
    """
    newArr = []
    for move in arr:
        newName = move.replace('-', '').replace("'", '')  # Remove dashes and apostrophes
        # Fetch the move's row once instead of re-filtering the whole table per column.
        row = movedf.loc[movedf['Name'] == newName].iloc[0]
        if row['Category'] == 'Status':  # Remove status moves
            continue
        # Remove moves with no power or accuracy. pd.isna also accepts None
        # (np.isnan raised TypeError on it) as well as NaN.
        if pd.isna(row['Power']) or pd.isna(row['Accuracy']):
            continue
        newArr.append(newName)
    return newArr

def standardize(row, total=600):
    """Rescale a Pokemon's six base stats so they sum to (at most) ``total``.

    Every stat is multiplied by ``total / (sum of the six stats)`` and floored.
    The row is modified in place and also returned, so the function can be used
    with ``DataFrame.apply(standardize, axis=1)``.

    Args:
        row: Series or mapping with the keys 'HP', 'Attack', 'Defense',
            'Special Attack', 'Special Defense' and 'Speed'.
        total: target stat total (default 600, the previously hard-coded value).

    Returns:
        ``row`` with the six stats replaced by their rescaled integer values.

    Raises:
        ValueError: if the six stats sum to zero, since no scale factor exists.
    """
    statNames = ('HP', 'Attack', 'Defense', 'Special Attack', 'Special Defense', 'Speed')
    # Compute the factor from the original values before any of them is overwritten.
    statTotal = sum(row[stat] for stat in statNames)
    if statTotal == 0:
        raise ValueError('cannot standardize a Pokemon whose base stats sum to zero')
    factor = total / statTotal
    for stat in statNames:
        row[stat] = math.floor(factor * row[stat])
    return row

movedf = pd.read_csv('move-data-clean.csv')
# Strip dashes and apostrophes so names match the normalized move lists (e.g. Double-Edge -> DoubleEdge).
movedf['Name'] = movedf['Name'].apply(lambda x: (x.replace('-', '')).replace("'", ''))
# Moves without a power/accuracy are stored as the text 'None'. Older pandas keeps that
# text, but pandas >= 2.0 already parses 'None' as NaN, where int(x) raised ValueError.
# to_numeric(errors='coerce') turns both forms into NaN and numeric text into numbers.
movedf['Power'] = pd.to_numeric(movedf['Power'], errors='coerce')
movedf['Accuracy'] = pd.to_numeric(movedf['Accuracy'], errors='coerce')

pokedf = pd.read_csv('pokemon-data-clean.csv')
for col in ['Types', 'Abilities', 'Moves']:  # Turn the types, abilities, and moves from a string to a list
    pokedf[col] = pokedf[col].apply(ast.literal_eval)
pokedf = pokedf.apply(standardize, axis=1)  # Rescale each Pokemon's base stats
pokedf['Moves'] = pokedf['Moves'].apply(cleanup)  # Keep only damaging moves with known power/accuracy

# Type-effectiveness table, indexed as typedf.loc[moveType, defendingType] by the battle code.
typedf = pd.read_csv('type-data-clean.csv', index_col=0)

allTypes = ['Normal', 'Fire', 'Water', 'Electric', 'Grass', 'Ice', 'Fighting', 'Poison', 'Ground', 'Flying', 'Psychic', 'Bug', 'Rock', 'Ghost', 'Dragon', 'Dark', 'Steel', 'Fairy']
allCategories = ['Physical', 'Special']
allCategories = ['Physical', 'Special']

In [None]:
# Verify the processed Pokemon dataframe (types/abilities/moves parsed into lists,
# stats standardized, moves cleaned) is formatted correctly.
print(pokedf)

In [None]:
from __future__ import absolute_import, division, print_function
import abc
import tensorflow as tf
import numpy as np
from sklearn.feature_extraction import DictVectorizer
import random
import math

from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts

from tf_agents.agents.dqn import dqn_agent
from tf_agents.networks import q_network
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import tf_py_environment
from tf_agents.environments import py_environment
#from tf_agents.environments import trajectory
from tf_agents.environments import wrappers
#from tf_agents.metrics import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common
from tf_agents.metrics import py_metrics
from tf_agents.metrics import tf_metrics
from tf_agents.drivers import py_driver
from tf_agents.drivers import dynamic_episode_driver

tf.compat.v1.enable_v2_behavior()

In [None]:
# Battle helper functions (Ignoring abilities and pp)
# Pokemon: [types, hp, attack, defense, special attack, special defense, speed, move1type, move1category, ..., move1accuracy, move2type, ..., move4accuracy]
# Moves: [type, category, pp, power, accuracy]
# one-hot encode: types, movetype, category 
""" 
print(movedf.loc[0].values)
print(typedf.loc['Water', 'Fire']) """
def calcDamage(poke1, poke2, move):
    move = poke1[24+23*move:24+23*(move+1)]
    attack = poke1[19] if move[18] == 'Physical' else poke1[21]
    defense = poke2[20] if move[18] == 'Physical' else poke2[22]
    modifier = 2
    types = onehotToType(poke2[:18])
    for type in types:
        modifier *= typedf.loc[onehotToType(move[:18]), type]
    damage = ((2/5 + 2) * move[20] * attack / defense / 50 + 2) * modifier # * (move[21] > np.random.randint(1, 101))
    return math.ceil(damage)

def typeToOneHot(types, vocabulary=None):
    """Encode a list of names as a multi-hot list over ``vocabulary``.

    Args:
        types: iterable of names, each of which must appear in ``vocabulary``.
        vocabulary: ordered list of all possible names. Defaults to ``allTypes``;
            any other list (e.g. ``allCategories``) can be passed instead.

    Returns:
        List of ints with one entry per vocabulary item: 1 if the item is in
        ``types``, else 0.

    Raises:
        ValueError: if a name in ``types`` is not in ``vocabulary``.
    """
    if vocabulary is None:
        vocabulary = allTypes
    onehot = [0] * len(vocabulary)
    for typeName in types:  # avoid shadowing the builtin `type`
        onehot[vocabulary.index(typeName)] = 1
    return onehot

def onehotToType(onehot, vocabulary=None):
    """Decode a (multi-)hot vector back into the names it marks.

    Args:
        onehot: sequence/array of 0/1 values aligned with ``vocabulary``.
        vocabulary: ordered list of all possible names; defaults to ``allTypes``.

    Returns:
        numpy array of the vocabulary names whose position in ``onehot`` equals 1,
        in vocabulary order (empty if none are set).
    """
    if vocabulary is None:
        vocabulary = allTypes
    flags = np.array(onehot)
    names = np.array(vocabulary)
    return names[np.where(flags == 1)]

def randPokemon(poolSize=20, numMoves=4):
    """Draw a random Pokemon and give it a random subset of its moves.

    One row is sampled from the first ``poolSize`` rows of ``pokedf`` and
    returned as a numpy object array. Position 2 (abilities) is deleted, which
    leaves the move list at position 8. That list is replaced by ``numMoves``
    moves sampled without replacement.

    Args:
        poolSize: number of leading ``pokedf`` rows to sample from (default 20,
            the previously hard-coded value).
        numMoves: number of moves to keep (default 4; PokemonEnv's action and
            observation specs assume 4 moves).

    Returns:
        numpy object array describing one Pokemon.

    Raises:
        ValueError: (from ``random.sample``) if the Pokemon has fewer than
            ``numMoves`` moves.
    """
    candidates = pokedf[:poolSize]
    poke = candidates.sample(1).values[0]
    poke = np.delete(poke, 2)  # Remove abilities
    poke[8] = random.sample(poke[8], numMoves)  # Keep numMoves random moves
    return poke


def moveToVector(move):
    """Encode a move name as a numeric vector.

    The result concatenates, for the first row of ``movedf`` whose 'Name' equals
    ``move``: the type one-hot (via ``typeToOneHot``), the category one-hot
    ([1, 0] for 'Physical', [0, 1] otherwise), then PP, Power and Accuracy.
    """
    rows = movedf[movedf['Name'] == move]

    def _first(column):
        # Value of `column` in the first matching row (IndexError if the move is unknown).
        return rows[column].values[0]

    typePart = typeToOneHot([_first('Type')])
    categoryPart = [0, 1]
    if _first('Category') == 'Physical':
        categoryPart = [1, 0]
    numericPart = [_first(column) for column in ('PP', 'Power', 'Accuracy')]
    return np.concatenate((typePart, categoryPart, numericPart))


def pokeToVector(poke):
    """Flatten a Pokemon array (as built by randPokemon) into one int32 vector.

    Layout: the type one-hot of ``poke[1]``, the six values ``poke[2:8]``, then
    ``moveToVector`` of each move name in ``poke[8]``, in order.
    """
    pieces = [typeToOneHot(poke[1]), poke[2:8]]
    pieces.extend(moveToVector(moveName) for moveName in poke[8])
    return np.concatenate(pieces).astype(np.int32)

class PokemonEnv(py_environment.PyEnvironment):
    """Single-decision battle: pick the most type-effective move against a random opponent.

    Each episode samples two random Pokemon (``poke1`` is the agent's, ``poke2`` the
    opponent). The observation is 90 int32 values: the opponent's type one-hot (18)
    followed by the type one-hot of each of the agent's four moves (4 x 18). The
    action is the index (0-3) of the move to use. The reward is 100 multiplied by the
    ``typedf`` effectiveness of that move against every opponent type. Every episode
    ends after one step.
    """

    def __init__(self):
        super().__init__()
        # Action: a single int 0-3 selecting one of poke1's moves.
        self._action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=3, name='action')
        # Observation: [poke2 type (18), poke1 move1..move4 types (4 x 18)], all one-hot.
        # The shape must match the flat (90,) state actually returned; the previous
        # (1, 90) spec made validate_py_environment reject every time step.
        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(90,), dtype=np.int32, minimum=0, maximum=1, name='observation')
        self._new_battle()

    def _new_battle(self):
        """Sample both Pokemon and rebuild the state (shared by __init__ and _reset)."""
        self.round = 0  # NOTE(review): not read anywhere in this class
        self.poke1 = randPokemon()
        self.poke2 = randPokemon()

        self.poke1Vec = pokeToVector(self.poke1)
        self.poke2Vec = pokeToVector(self.poke2)

        # First 18 entries of each move vector are the move's type one-hot.
        moveTypes = [moveToVector(move)[:18] for move in self.poke1[8]]
        # State: poke2 type, move1 type, move2 type, ..., move4 type
        self._state = np.concatenate([self.poke2Vec[:18]] + moveTypes).astype(np.int32)
        self._episode_ended = False

    def action_spec(self):
        return self._action_spec

    def observation_spec(self):
        return self._observation_spec

    def _reset(self):
        self._new_battle()
        return ts.restart(self._state)

    def _step(self, action):
        if self._episode_ended:
            # The previous step ended the episode; start a new battle instead.
            return self.reset()

        action = int(action)
        if not 0 <= action <= 3:
            # Previously an invalid action fell through and hit an unbound `reward`.
            raise ValueError('Invalid action {}; expected an integer in [0, 3]'.format(action))

        pokeTypes = onehotToType(self._state[:18])
        start = 18 + 18 * action
        moveType = onehotToType(self._state[start:start + 18])[0]
        reward = 100
        for pokeType in pokeTypes:
            reward *= typedf.loc[moveType, pokeType]

        # Every battle is a single decision, so the episode always terminates here.
        self._episode_ended = True
        return ts.termination(self._state, reward)

In [None]:
# Sanity-check PokemonEnv against its declared specs with TF-Agents' validator (5 episodes).
env = PokemonEnv()
utils.validate_py_environment(env, episodes=5)

In [None]:
# Separate training and evaluation environments. TimeLimit caps episodes at 100 steps,
# although PokemonEnv already ends every episode after a single step.
train_py_env = wrappers.TimeLimit(PokemonEnv(), duration=100)
eval_py_env = wrappers.TimeLimit(PokemonEnv(), duration=100)

# Wrap the Python environments so TF-Agents policies and drivers can use them as TF environments.
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

In [None]:
def compute_avg_return(environment, policy, num_episodes=10):
    """Return the average undiscounted episode return of ``policy`` in ``environment``.

    Args:
        environment: a TF environment (e.g. ``TFPyEnvironment``) whose rewards are
            tensors. Assumed to have batch size 1, since element ``[0]`` of the
            averaged tensor is returned.
        policy: policy whose ``action(time_step)`` result has an ``action`` field.
        num_episodes: number of episodes to run; must be positive.

    Returns:
        The average return per episode as a numpy scalar.
    """
    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0

        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward

        # Add each episode's return once; accumulating inside the loop above
        # re-added the running partial sum after every step.
        total_return += episode_return

    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]

num_iterations = 75000  # @param

# NOTE(review): initial_collect_steps is not used below; no initial random-policy
# collection seeds the replay buffer before training starts.
initial_collect_steps = 1000  # @param
collect_steps_per_iteration = 1  # @param
replay_buffer_capacity = 100000  # @param

fc_layer_params = (100,)  # One hidden fully connected layer with 100 units

batch_size = 128  # @param
learning_rate = 1e-5  # @param
log_interval = 200  # @param

num_eval_episodes = 2  # @param
eval_interval = 1000  # @param

# Q-network mapping an observation to one Q-value per action (move index).
q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

train_step_counter = tf.compat.v2.Variable(0)

tf_agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    train_step_counter=train_step_counter)

tf_agent.initialize()

eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=tf_agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_capacity)

replay_observer = [replay_buffer.add_batch]

# Sample batches of adjacent 2-step trajectories, the shape DqnAgent.train consumes.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3,
    sample_batch_size=batch_size,
    num_steps=2).prefetch(3)

iterator = iter(dataset)

train_metrics = [
    tf_metrics.NumberOfEpisodes(),
    tf_metrics.EnvironmentSteps(),
    tf_metrics.AverageReturnMetric(),
    tf_metrics.AverageEpisodeLengthMetric(),
]

# Each driver.run() collects `collect_steps_per_iteration` steps with the collect policy,
# feeding both the replay buffer and the training metrics.
driver = dynamic_step_driver.DynamicStepDriver(
    train_env,
    collect_policy,
    observers=replay_observer + train_metrics,
    num_steps=collect_steps_per_iteration)

In [None]:
# Training loop: each iteration collects experience with the collect policy, then runs one
# training step on a batch sampled from the replay buffer.
episode_len = []  # Average episode length, recorded every `log_interval` train steps and plotted below

# Initial collection call without an explicit starting time step / policy state.
final_time_step, policy_state = driver.run()

for i in range(num_iterations):
    # Resume collection from the last time step; observers add the steps to the buffer and metrics.
    final_time_step, _ = driver.run(final_time_step, policy_state)

    # Sample a batch of 2-step trajectories and train on it.
    experience, _ = next(iterator)
    train_loss = tf_agent.train(experience=experience)
    step = tf_agent.train_step_counter.numpy()

    if step % log_interval == 0:
        print('step = {0}: loss = {1}'.format(step, train_loss.loss))
        # train_metrics[3] is the AverageEpisodeLengthMetric.
        episode_len.append(train_metrics[3].result().numpy())
        print('Average episode length: {}'.format(train_metrics[3].result().numpy()))

    if step % eval_interval == 0:
        # Evaluate the greedy policy on the separate evaluation environment.
        avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
        print('step = {0}: Average Return = {1}'.format(step, avg_return))
plt.plot(episode_len)
plt.show()

In [None]:
from tf_agents.policies.policy_saver import PolicySaver

# Export the trained policy to 'policy_move', the directory the next cell loads from.
# `agent` was never defined; the trained DQN agent is `tf_agent`.
# NOTE(review): collect_policy is the exploring (epsilon-greedy) policy; tf_agent.policy is the
# greedy one and is probably what should be exported to "choose the best move" - confirm intent.
my_policy = tf_agent.collect_policy
saver = PolicySaver(my_policy, batch_size=1)
saver.save('policy_move')

In [None]:
# Load the exported policy and query it on a fresh random battle.
saved_policy = tf.compat.v2.saved_model.load('policy_move')
# The policy is exported for batched input (batch_size=1), so wrap the Python
# environment to get batched tensors instead of an unbatched numpy time step.
pokeEnv = tf_py_environment.TFPyEnvironment(PokemonEnv())
timestep = pokeEnv.reset()
print(timestep)
print(saved_policy.action(timestep))