# Import

In [1]:
# IMPORT FOR BASIC UTILITIES

import sys
import io
import gc
import os
import pickle
from tqdm import tqdm
from matplotlib import pyplot as plt
import seaborn as sns
from PIL import Image
import tensorflow as tf
from tensorflow import keras
import numpy as np
import gym
from Agents.DDDQN.DDDQNAgent import DDDQNAgent
#from Agents.TD3.TD3Agent import TD3Agent
from Utilities.TicTacToe import TicTacToeEnv
#from Utilities.ConnectFour import ConnectFourEnv
#from Utilities.Santorini import SantoriniEnv
from Utilities.Wrappers import OpponentWrapper
from Utilities.TrainWizard import TurnGameTrainWizard


  # Deprecated (original author's note); not referenced by any cell visible in
  # this notebook. The two rows have different lengths (3 and 5), so
  # dtype=object is required: without it NumPy >= 1.24 raises ValueError for
  # ragged nested sequences. The result is a 1-D object array holding two lists.
  ACTION_SPACE = np.array([[2, 8, 8], [2, 3, 3, 3, 3]], dtype=object)


In [None]:
# TEST CONFIGURATION

algorithm = 'DDDQN'
environment = 'TicTacToe'
representation = 'Tabular'
opponent = 'Random'
agent_turn = 'First'

# Label of this experiment; it is also the name of the results sub-directory.
config_name = '_'.join([algorithm, environment, representation, opponent, agent_turn])

# Build the output paths with os.path.join so they use the host's separator
# instead of hard-coded Windows backslashes. The trailing '' component keeps a
# trailing separator on every path, exactly as the original strings had.
data_path = os.path.join('..', 'Results', config_name, '')
gif_path = os.path.join(data_path, 'GIFs', '')
network_path = os.path.join(data_path, 'NetworkParameters', '')

# Create the directories from Python rather than shelling out to `mkdir`:
# this works on every OS, creates missing parents, and is a no-op when the
# directories already exist (so the cell can be re-run safely).
for _directory in (data_path, gif_path, network_path):
    os.makedirs(_directory, exist_ok=True)

# Environment

In [None]:
# GAME PARAMETERS AND NETWORK STRUCTURE

# Pass the `opponent` selected in the configuration cell instead of a second,
# hard-coded 'Random', so the environment cannot drift from `config_name`.
env = OpponentWrapper(TicTacToeEnv(representation, agent_turn == 'First'), opponent)


def _dense_layer(units, activation='relu'):
    """Return one Dense-layer spec in the {'name', 'params'} layout used below.

    A fresh HeNormal initializer instance is created on every call, matching
    the original hand-written dictionaries (one instance per layer).
    """
    return {'name': 'Dense',
            'params': {
                'units': units,
                'activation': activation,
                'kernel_initializer': tf.keras.initializers.HeNormal()
            }}


def _flatten_layer():
    """Return a Flatten-layer spec with no parameters."""
    return {'name': 'Flatten', 'params': {}}


# The integer keys (0..5) are kept exactly as in the original dictionaries.
network_dict_base = {0: _dense_layer(64),
                     1: _dense_layer(32)}

# NOTE(review): the advantage and value dictionaries both end in a ReLU Dense
# layer. If these are the final layers of the network, their outputs are
# clamped to >= 0 -- confirm against DDDQNAgent whether a linear activation
# is intended here.
network_dict_advantage = {2: _flatten_layer(),
                          3: _dense_layer(env.action_space.n)}
network_dict_value = {4: _flatten_layer(),
                      5: _dense_layer(1)}

In [None]:
# AGENT

# Every constructor argument is collected in one dictionary first, so the
# whole configuration can be inspected before the agent is built.
# Both network arguments receive their own list holding the same three layer
# dictionaries (base, advantage, value).
agent_settings = dict(
    environment=env,
    # Network definitions and update schedule
    q_net_dict=[network_dict_base, network_dict_advantage, network_dict_value],
    q_target_net_dict=[network_dict_base, network_dict_advantage, network_dict_value],
    double_q=True,
    dueling_q=True,
    q_net_update=4,
    q_target_net_update=10000,
    discount_factor=0.99,
    # Optimizers are passed as classes, losses as instances
    q_net_optimizer=keras.optimizers.Adam,
    q_target_net_optimizer=keras.optimizers.Adam,
    q_net_learning_rate=1e-5,
    q_target_net_learning_rate=1e-5,
    q_net_loss=keras.losses.Huber(),
    q_target_net_loss=keras.losses.Huber(),
    # Run length and memory settings
    num_episodes=100000,
    memory_size=8192,
    memory_alpha=0.7,
    memory_beta=0.4,
    # Epsilon settings
    max_epsilon=1.0,
    min_epsilon=0.001,
    epsilon_A=0.35,
    epsilon_B=0.25,
    epsilon_C=0.1,
    batch_size=32,
    # Directory created in the configuration cell
    checkpoint_dir=network_path,
)

agent = DDDQNAgent(**agent_settings)

In [None]:
# WIZARD

# The turn order and opponent are taken from the configuration cell instead of
# being hard-coded again, so the wizard always agrees with the environment
# (built from the same two values) and with the `config_name` results folder.
wizard = TurnGameTrainWizard(environment=env,
                     agent=agent,
                     objective_score=1,
                     running_average_length=100,
                     evaluation_steps=1000,
                     evaluation_games=100,
                     agent_turn=(agent_turn == 'First'),
                     # NOTE(review): left as the original literal; presumably
                     # this should also follow the configured turn -- confirm
                     # against TurnGameTrainWizard.
                     agent_turn_test=True,
                     opponent=opponent,
                     path=data_path)

# Training

In [None]:
# Run the training procedure of the wizard built in the previous cell.
# NOTE(review): presumably long-running (the agent is configured with
# num_episodes=100000) -- consider timing this cell with %%time.
wizard.train()

# Plots

In [None]:
# Default figure size for every plot that follows.
figure_size = (16, 9)
sns.set(rc={'figure.figsize': figure_size})

# Flatten the evaluation history into rows of
# (history key, first field, second field): one row per element of the first
# item stored under each key.
plot_rows = []
for eval_key, eval_value in wizard.eval_reward_history.items():
    for sample in eval_value[0]:
        plot_rows.append((eval_key, sample[0], sample[1]))
data = np.array(plot_rows)

In [None]:
# Plot column 1 of `data` against column 0 (the evaluation-history key).
# x and y are passed by keyword: seaborn >= 0.12 accepts only `data`
# positionally, so the positional form raises TypeError there.
# The trailing ';' hides the Axes repr in the cell output.
sns.lineplot(x=data[:, 0], y=data[:, 1]);

In [None]:
# Plot column 2 of `data` against column 0 (the evaluation-history key).
# x and y are passed by keyword: seaborn >= 0.12 accepts only `data`
# positionally, so the positional form raises TypeError there.
# The trailing ';' hides the Axes repr in the cell output.
sns.lineplot(x=data[:, 0], y=data[:, 2]);