In [6]:
import numpy as np
import os
import tensorflow as tf

from keras.utils import to_categorical
from pommerman.agents import BaseAgent, SimpleAgent
from pommerman.configs import ffa_v0_env
from pommerman.constants import BOARD_SIZE
from pommerman.envs.v0 import Pomme

In [7]:
# Number of expert rollouts collected per Stimulator run
initial_rollouts = 400

# Output directory and file names for the recorded observations / labels
train_data_path = './dagger/train_data/'
train_data_obs = 'obs.npy'
train_data_labels = 'labels.npy'

# exist_ok=True avoids the check-then-create race and keeps the cell re-runnable
os.makedirs(train_data_path, exist_ok=True)

In [9]:
# Simple wrapper around policy function to have an act function
class Expert:
    """Adapter exposing a SimpleAgent through a single-argument ``act(obs)``.

    ``record_reward`` and ``add_log`` are accepted but do nothing, so an
    Expert can be passed wherever those hooks are called.
    """

    def __init__(self, config):
        # Wrapped policy; kept under a name-mangled attribute.
        self.__agent = SimpleAgent(config)

    def act(self, obs):
        """Return the wrapped agent's action for ``obs``.

        The wrapped agent's second argument (the action space) is passed
        as ``None``.
        """
        chosen_action = self.__agent.act(obs, None)
        return chosen_action

    def record_reward(self, reward):
        """Accept a reward report and ignore it."""
        return None

    def add_log(self, tag, value, step):
        """Accept a log entry and ignore it."""
        return None


class TensorforceAgent(BaseAgent):
    """Placeholder agent that occupies the training slot in the agent list.

    Its ``act`` never chooses an action; Stimulator obtains the training
    agent's action from the agent passed to ``stimulate`` instead.
    """

    def act(self, obs, action_space):
        """Return ``None``; this agent does not pick actions itself."""
        return None


# Environment wrapper
class Stimulator:
    """Environment wrapper that rolls out an agent and records its decisions.

    One of the four agent slots (``agent_pos``) is registered as the
    environment's training agent; the other three slots are filled with
    SimpleAgent opponents. ``stimulate`` then plays episodes in which the
    training slot is driven by an externally supplied agent and collects
    (observation, action) pairs for supervised training.
    """

    def __init__(self, env, config, agent_pos=0):
        """Attach to ``env`` and configure its agents.

        Args:
            env: environment instance to drive.
            config: mapping providing ``config["agent"]`` (a callable taking
                an agent id and game type) and ``config["game_type"]``.
            agent_pos: index (0-3) of the slot used as the training agent.
        """
        self.env = env
        self.init(config, agent_pos)
        # Running count of episodes played by this instance (used as log step).
        self.episode_number = 0

    def init(self, config, agent_pos):
        """Seed the environment and install the four agents."""
        self.env.seed(0)
        # Slot `agent_pos` gets the placeholder training agent; the other
        # three slots get SimpleAgent opponents.
        agents = []
        for agent_id in range(4):
            if agent_id == agent_pos:
                agents.append(TensorforceAgent(config["agent"](agent_id, config["game_type"])))
            else:
                agents.append(SimpleAgent(config["agent"](agent_id, config["game_type"])))
        self.env.set_agents(agents)
        self.env.set_training_agent(agents[agent_pos].agent_id)
        self.env.set_init_game_state(None)

    def stimulate(self, agent, num_rollouts, render=False, logging=False):
        """Play ``num_rollouts`` episodes and return the recorded dataset.

        Args:
            agent: object with ``act(obs)``, ``record_reward(returns)`` and,
                when ``logging`` is true, ``add_log(tag, value, step)``.
            num_rollouts: number of episodes to play.
            render: if true, render the environment before every step.
            logging: if true, report episode reward and length via
                ``agent.add_log``.

        Returns:
            Tuple ``(observations, labels)``: ``observations`` is a numpy
            array with one entry per recorded step, holding the observation
            the agent acted on; ``labels`` is the one-hot encoding of the
            action chosen for that observation, with
            ``env.action_space.n`` classes.
        """
        returns = []
        observations = []
        actions = []

        for i in range(num_rollouts):
            self.episode_number += 1
            obs = self.env.reset()
            done = False
            total_reward = 0.
            episode_steps = 0

            while not done:
                if render:
                    self.env.render()
                agent_obs = obs[self.env.training_agent]
                action = agent.act(agent_obs)
                # Record the observation the action was chosen FROM. Recording
                # the observation returned by step() would pair every action
                # with the following state and shift all labels by one step.
                observations.append(agent_obs)
                actions.append(action)
                # presumably env.act returns the other agents' actions, since
                # the training agent's action is spliced in at its own index
                all_actions = self.env.act(obs)
                all_actions.insert(self.env.training_agent, action)
                obs, reward, done, _ = self.env.step(all_actions)
                total_reward += reward[self.env.training_agent]
                episode_steps += 1
            print('rollout %i/%i return=%f' % (i + 1, num_rollouts, total_reward))
            if logging:
                agent.add_log('Episode reward', total_reward, self.episode_number)
                agent.add_log('Episode length', episode_steps, self.episode_number)
            returns.append(total_reward)
        print('Return summary: mean=%f, std=%f' % (np.mean(returns), np.std(returns)))
        agent.record_reward(returns)
        return (np.array(observations), to_categorical(actions, self.env.action_space.n))

In [10]:
# Build the environment and the expert policy that drives the training slot
config = ffa_v0_env()
env = Pomme(**config["env_kwargs"])
expert = Expert(config["agent"](0, config["game_type"]))

# Collect expert rollouts once per agent position (0..3), in order
rollouts_by_position = []
for agent_pos in range(4):
    stimulator = Stimulator(env, config, agent_pos)
    rollouts_by_position.append(
        stimulator.stimulate(expert, num_rollouts=initial_rollouts)
    )
training_data1, training_data2, training_data3, training_data4 = rollouts_by_position

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
rollout 1/400 return=-1.000000
rollout 2/400 return=1.000000
rollout 3/400 return=-1.000000
rollout 4/400 return=-1.000000
rollout 5/400 return=-1.000000
rollout 6/400 return=-1.000000
rollout 7/400 return=-1.000000
rollout 8/400 return=1.000000
rollout 9/400 return=1.000000
rollout 10/400 return=-1.000000
rollout 11/400 return=-1.000000
rollout 12/400 return=1.000000
rollout 13/400 return=-1.000000
rollout 14/400 return=-1.000000
rollout 15/400 return=-1.000000
rollout 16/400 return=-1.000000
rollout 17/400 return=1.000000
rollout 18/400 return=-1.000000
rollout 19/400 return=-1.000000
rollout 20/400 return=-1.000000
rollout 21/400 return=-1.000000
rollout 22/400 return=1.000000
rollout 23/400 return=1.000000
rollout 24/400 return=1.000000
rollout 25/400 return=-1.000000
rollout 26/400 return=-1.000000
rollout 27/400 return=-1.000000
rollout 28/400 return=-1.000000
rollout 29/40

rollout 252/400 return=1.000000
rollout 253/400 return=-1.000000
rollout 254/400 return=1.000000
rollout 255/400 return=-1.000000
rollout 256/400 return=-1.000000
rollout 257/400 return=-1.000000
rollout 258/400 return=-1.000000
rollout 259/400 return=-1.000000
rollout 260/400 return=1.000000
rollout 261/400 return=-1.000000
rollout 262/400 return=-1.000000
rollout 263/400 return=-1.000000
rollout 264/400 return=-1.000000
rollout 265/400 return=-1.000000
rollout 277/400 return=-1.000000
rollout 278/400 return=-1.000000
rollout 279/400 return=-1.000000
rollout 280/400 return=-1.000000
rollout 281/400 return=-1.000000
rollout 282/400 return=-1.000000
rollout 283/400 return=-1.000000
rollout 284/400 return=1.000000
rollout 285/400 return=-1.000000
rollout 286/400 return=-1.000000
rollout 287/400 return=-1.000000
rollout 288/400 return=-1.000000
rollout 289/400 return=1.000000
rollout 290/400 return=-1.000000
rollout 291/400 return=-1.000000
rollout 292/400 return=-1.000000
rollout 293/400

rollout 115/400 return=-1.000000
rollout 116/400 return=-1.000000
rollout 117/400 return=-1.000000
rollout 118/400 return=-1.000000
rollout 119/400 return=-1.000000
rollout 120/400 return=-1.000000
rollout 121/400 return=1.000000
rollout 122/400 return=-1.000000
rollout 123/400 return=1.000000
rollout 124/400 return=-1.000000
rollout 125/400 return=-1.000000
rollout 126/400 return=-1.000000
rollout 127/400 return=-1.000000
rollout 128/400 return=-1.000000
rollout 129/400 return=-1.000000
rollout 130/400 return=-1.000000
rollout 131/400 return=-1.000000
rollout 132/400 return=-1.000000
rollout 133/400 return=-1.000000
rollout 134/400 return=1.000000
rollout 135/400 return=-1.000000
rollout 136/400 return=-1.000000
rollout 137/400 return=-1.000000
rollout 138/400 return=-1.000000
rollout 139/400 return=-1.000000
rollout 140/400 return=-1.000000
rollout 141/400 return=1.000000
rollout 142/400 return=-1.000000
rollout 143/400 return=-1.000000
rollout 144/400 return=-1.000000
rollout 145/40

rollout 365/400 return=-1.000000
rollout 366/400 return=-1.000000
rollout 367/400 return=1.000000
rollout 368/400 return=1.000000
rollout 369/400 return=1.000000
rollout 370/400 return=-1.000000
rollout 371/400 return=-1.000000
rollout 372/400 return=-1.000000
rollout 373/400 return=-1.000000
rollout 374/400 return=1.000000
rollout 375/400 return=-1.000000
rollout 376/400 return=-1.000000
rollout 377/400 return=1.000000
rollout 378/400 return=-1.000000
rollout 379/400 return=-1.000000
rollout 380/400 return=-1.000000
rollout 381/400 return=-1.000000
rollout 382/400 return=1.000000
rollout 383/400 return=-1.000000
rollout 384/400 return=-1.000000
rollout 385/400 return=-1.000000
rollout 386/400 return=-1.000000
rollout 387/400 return=-1.000000
rollout 388/400 return=1.000000
rollout 389/400 return=1.000000
rollout 390/400 return=1.000000
rollout 391/400 return=1.000000
rollout 392/400 return=-1.000000
rollout 393/400 return=1.000000
rollout 394/400 return=-1.000000
rollout 395/400 retur

rollout 217/400 return=1.000000
rollout 218/400 return=-1.000000
rollout 219/400 return=-1.000000
rollout 220/400 return=-1.000000
rollout 221/400 return=-1.000000
rollout 222/400 return=-1.000000
rollout 223/400 return=-1.000000
rollout 224/400 return=-1.000000
rollout 225/400 return=-1.000000
rollout 226/400 return=-1.000000
rollout 227/400 return=-1.000000
rollout 228/400 return=-1.000000
rollout 229/400 return=-1.000000
rollout 230/400 return=-1.000000
rollout 231/400 return=-1.000000
rollout 232/400 return=-1.000000
rollout 233/400 return=1.000000
rollout 234/400 return=-1.000000
rollout 235/400 return=-1.000000
rollout 236/400 return=-1.000000
rollout 237/400 return=-1.000000
rollout 238/400 return=-1.000000
rollout 239/400 return=-1.000000
rollout 240/400 return=1.000000
rollout 241/400 return=1.000000
rollout 242/400 return=-1.000000
rollout 243/400 return=-1.000000
rollout 244/400 return=-1.000000
rollout 245/400 return=-1.000000
rollout 246/400 return=1.000000
rollout 247/400

rollout 68/400 return=-1.000000
rollout 69/400 return=1.000000
rollout 70/400 return=1.000000
rollout 71/400 return=-1.000000
rollout 72/400 return=-1.000000
rollout 73/400 return=-1.000000
rollout 74/400 return=-1.000000
rollout 75/400 return=-1.000000
rollout 76/400 return=-1.000000
rollout 77/400 return=-1.000000
rollout 78/400 return=1.000000
rollout 79/400 return=-1.000000
rollout 80/400 return=1.000000
rollout 81/400 return=-1.000000
rollout 82/400 return=-1.000000
rollout 83/400 return=-1.000000
rollout 84/400 return=1.000000
rollout 85/400 return=-1.000000
rollout 86/400 return=-1.000000
rollout 87/400 return=-1.000000
rollout 88/400 return=-1.000000
rollout 89/400 return=-1.000000
rollout 90/400 return=1.000000
rollout 91/400 return=-1.000000
rollout 92/400 return=-1.000000
rollout 93/400 return=1.000000
rollout 94/400 return=-1.000000
rollout 95/400 return=-1.000000
rollout 96/400 return=-1.000000
rollout 97/400 return=-1.000000
rollout 98/400 return=-1.000000
rollout 99/400 

rollout 319/400 return=-1.000000
rollout 320/400 return=-1.000000
rollout 321/400 return=-1.000000
rollout 322/400 return=-1.000000
rollout 323/400 return=-1.000000
rollout 324/400 return=1.000000
rollout 325/400 return=-1.000000
rollout 326/400 return=-1.000000
rollout 327/400 return=1.000000
rollout 328/400 return=-1.000000
rollout 329/400 return=-1.000000
rollout 330/400 return=1.000000
rollout 331/400 return=-1.000000
rollout 332/400 return=-1.000000
rollout 333/400 return=-1.000000
rollout 334/400 return=-1.000000
rollout 335/400 return=-1.000000
rollout 336/400 return=-1.000000
rollout 337/400 return=-1.000000
rollout 338/400 return=-1.000000
rollout 339/400 return=-1.000000
rollout 340/400 return=-1.000000
rollout 341/400 return=-1.000000
rollout 342/400 return=-1.000000
rollout 343/400 return=-1.000000
rollout 344/400 return=-1.000000
rollout 345/400 return=1.000000
rollout 346/400 return=-1.000000
rollout 347/400 return=-1.000000
rollout 348/400 return=-1.000000
rollout 349/40

In [20]:
# Merge the four per-position datasets: element 0 is observations, element 1 is labels
rollout_sets = (training_data1, training_data2, training_data3, training_data4)
training_data_obs = np.concatenate([data[0] for data in rollout_sets])
training_data_labels = np.concatenate([data[1] for data in rollout_sets])

In [19]:
# Shape of the merged observation array
np.shape(training_data_obs)

(1256511,)

In [21]:
# Join paths with os.path.join so the result does not depend on
# train_data_path ending in a path separator.
# NOTE(review): the observation array is 1-D, presumably dtype=object; if so,
# loading it later requires np.load(..., allow_pickle=True) — confirm.
np.save(os.path.join(train_data_path, train_data_obs), training_data_obs)
np.save(os.path.join(train_data_path, train_data_labels), training_data_labels)

In [28]:
# Share of recorded steps labelled with each action (per-class sum / total sum)
action_counts = training_data_labels.sum(axis=0)
action_counts / training_data_labels.sum()

array([0.36810502, 0.16168979, 0.14268638, 0.1558132 , 0.1535856 ,
       0.01812002], dtype=float32)