In [None]:
import numpy as np
from random import random, choice

from matplotlib import cm
from time import sleep
from colosseumrl.envs.tron import TronGridEnvironment, TronRender, TronRllibEnvironment

import gym
from gym import Env
from gym.spaces import Dict, Discrete, Box

import ray
from ray import tune
from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG
from ray.rllib.agents.dqn import DQNTrainer, DEFAULT_CONFIG

from ray.rllib.models.preprocessors import Preprocessor
from ray.rllib.models import ModelCatalog

# Seed both RNGs this notebook imports: NumPy's global generator and the
# stdlib `random` module (whose `random`/`choice` functions are imported above).
# The module name `random` is shadowed by that import, so `seed` is pulled in
# under its own alias.
from random import seed as seed_python_rng

SEED = 1517
np.random.seed(SEED)
seed_python_rng(SEED)

## Getting Experimental
It would be most beneficial to have a fully dynamic training process where different agents train against each other and learn to communicate with and outsmart each other. Unfortunately, this process is highly unstable: because every agent keeps changing, we no longer have a "static" environment on which to train an agent.

## This is not stable and won't train well!

In [None]:
class TronExtractBoard(Preprocessor):
    """Extract just the board from the game state and simplify it for the network.

    The output is a ``(board_size + 2, board_size + 2, 2)`` array:

    * channel 0 -- the board with every value greater than 1 replaced by -1,
      so that all enemies look the same;
    * channel 1 -- zeros everywhere except at the head positions, which hold
      ``1 + direction``.

    Both channels are surrounded by a one-cell border of -1 marking the wall.
    """

    def _init_shape(self, obs_space, options):
        """Return the output shape, derived from the ``'board'`` sub-space.

        The shape is read from the ``obs_space`` argument rather than from a
        notebook-level ``env`` variable, so the class does not depend on a
        global that is only defined in a later cell.

        Args:
            obs_space: Observation space with a ``'board'`` entry whose first
                dimension is the board size.
            options: Unused.

        Returns:
            ``(board_size + 2, board_size + 2, 2)``; the ``+ 2`` accounts for
            the wall padding added in ``_transform``.
        """
        # NOTE(review): some RLlib versions appear to wrap the space and keep
        # the original Dict under `original_space`; unwrap it when present.
        space = getattr(obs_space, "original_space", obs_space)
        board_size = space['board'].shape[0]
        return (board_size + 2, board_size + 2, 2)

    def transform(self, observation):
        """Work out which player is observing, then build the network input.

        Args:
            observation: Mapping with ``'board'``, ``'heads'`` and
                ``'directions'`` entries.

        Returns:
            The array produced by ``_transform`` for the detected player.

        Raises:
            KeyError: If the probed board cell holds a value other than
                1, 2, 3 or 4.
        """
        # Pretty hacky way to get the current player number
        # Requires having exactly 4 players
        board = observation['board']
        hor_offset = board.shape[0] // 2 + 2
        # The value found at this fixed cell identifies the observing player.
        top_player = board[1, hor_offset]
        player_number = {1: 0, 4: 1, 3: 2, 2: 3}[top_player]

        return self._transform(observation, player_number)

    def _transform(self, observation, rotate: int = 0):
        """Build the two-channel (board, heads) array.

        Args:
            observation: Mapping with ``'board'``, ``'heads'`` (flat indices
                into the board) and ``'directions'`` entries.
            rotate: Number of quarter turns passed to ``np.rot90``; the head
                directions are shifted by the same amount modulo 4.

        Returns:
            Array of shape ``(rows + 2, cols + 2, 2)``.
        """
        # Work on a copy so the caller's observation is left untouched.
        board = observation['board'].copy()

        # Make all enemies look the same
        board[board > 1] = -1

        # Mark where all of the player heads are
        heads = np.zeros_like(board)

        if (rotate != 0):
            # Directions are re-expressed relative to the rotated frame before
            # the arrays themselves are turned.
            heads.ravel()[observation['heads']] += 1 + ((observation['directions'] - rotate) % 4)

            board = np.rot90(board, k=rotate)
            heads = np.rot90(heads, k=rotate)

        else:
            heads.ravel()[observation['heads']] += 1 + observation['directions']

        # Pad the outsides so that we know where the wall is
        board = np.pad(board, 1, 'constant', constant_values=-1)
        heads = np.pad(heads, 1, 'constant', constant_values=-1)

        # Combine together: add a trailing channel axis and stack on it
        board = np.expand_dims(board, -1)
        heads = np.expand_dims(heads, -1)

        return np.concatenate([board, heads], axis=-1)

In [None]:
# Initialize training environment
# ignore_reinit_error keeps this cell re-runnable in a live kernel: a second
# ray.init() call would otherwise raise because Ray is already running.
ray.init(ignore_reinit_error=True)

def environment_creater(params=None):
    """Build a Tron RLlib environment.

    Args:
        params: Optional mapping of environment settings. Recognised keys are
            ``"board_size"`` (default 13) and ``"num_players"`` (default 4);
            any other keys are ignored. ``None`` or an empty mapping gives the
            defaults.

    Returns:
        A new ``TronRllibEnvironment``.
    """
    settings = dict(params or {})
    return TronRllibEnvironment(
        board_size=settings.get("board_size", 13),
        num_players=settings.get("num_players", 4),
    )

# Register the environment factory and the board preprocessor under the
# names the trainer configuration refers to.
ModelCatalog.register_custom_preprocessor("tron_prep", TronExtractBoard)
tune.register_env("tron_multi_player", environment_creater)

# Local instance, used below to read the observation/action spaces.
env = environment_creater()

# Configure Deep Q Learning for multi-agent training:
# start from a copy of the defaults and override the settings below in one go.
config = DEFAULT_CONFIG.copy()
config.update({
    # Rollout workers and sampling
    'num_workers': 4,
    'num_envs_per_worker': 1,
    "timesteps_per_iteration": 128,
    'compress_observations': False,
    # Replay buffer and optimisation
    'buffer_size': 10_000,
    'train_batch_size': 256,
    'n_step': 2,
    'target_network_update_freq': 256,
    # Exploration schedule
    'schedule_max_timesteps': 100_000,
    'exploration_fraction': 0.9,
})

# Every policy is built from the same network settings as before
agent_config = {
    "model": {
        "vf_share_layers": True,
        "conv_filters": [(64, 5, 2), (128, 3, 2), (256, 3, 2)],
        "fcnet_hiddens": [128],
        "custom_preprocessor": 'tron_prep'
    }
}


def agent_to_policy(agent_id):
    """Return the policy name for an agent: the string form of its id."""
    return str(agent_id)


# This time there is one separate policy per player,
# and all of them train AT THE SAME TIME
policies = {}
for player in range(env.env.num_players):
    policies[str(player)] = (None, env.observation_space, env.action_space, agent_config)

config['multiagent'] = {
    "policy_mapping_fn": agent_to_policy,
    "policies": policies,
}
       
# Build the DQN trainer on the environment name used elsewhere in this notebook.
trainer = DQNTrainer(config, "tron_multi_player")

num_epoch = 300
test_epochs = 1  # not used in this cell
for epoch in range(num_epoch):
    # The prefix has no trailing newline, so flush it explicitly; otherwise a
    # line-buffered stdout would hold it back until train() returns.
    print("Training iteration: {}".format(epoch), end='', flush=True)
    res = trainer.train()
    # 'policy_reward_mean' is read from the training result and printed as-is.
    print(f", Average reward: {res['policy_reward_mean']}")