# Data Exploration & Analysis

In [None]:
import os
import numpy as np
import jax
from jaxmarl import make

from safetensors.numpy import load_file
from ah2ac2.datasets.dataloader import HanabiLiveGamesDataloader

In [None]:
directory_path = './data'

# Load every `.safetensors` file found directly in `directory_path`.
# `os.listdir` returns entries in an arbitrary, platform-dependent order, so the
# names are sorted to make loading (and the summary printed below) reproducible.
restored_games = []
restored_games_names = []
for filename in sorted(os.listdir(directory_path)):
    if not filename.endswith(".safetensors"):
        continue

    file_path = os.path.join(directory_path, filename)
    # Skip directories that happen to carry the `.safetensors` suffix.
    if os.path.isdir(file_path):
        continue

    restored_games.append(load_file(file_path))
    restored_games_names.append(filename)

In [None]:
from ah2ac2.datasets.dataset import HanabiLiveGamesDataset
def _summarize(values):
    """Return a 'Min | Max | Avg | Median | Std' summary line for a NumPy array."""
    return (
        f'Min={values.min()} | Max={values.max()} | Avg={values.mean()} '
        f'| Median={np.median(values)} | Std={values.std()}'
    )


# For each loaded file, print the player count, the number of games, and
# summary statistics of the `scores` and `num_actions` (game length) arrays.
banner = "###########################################"
for game_data, name in zip(restored_games, restored_games_names):
    print(banner)
    print(name)
    print(f'Players: {game_data["num_players"]}')
    print(f'Games: {len(game_data["game_ids"])}')

    print("Scores:")
    print(_summarize(game_data["scores"]))

    print("Game Lengths:")
    print(_summarize(game_data["num_actions"]))
    print(banner + "\n")

# Dataset & Dataloader

We provide an implementation for loading the data into a dataset, so the game data can be accessed conveniently. We also provide a data loader that can be used to train models.

In [None]:
train_dataset_path = './data/2_player_games_train_1k.safetensors'
val_dataset_path = './data/2_player_games_val.safetensors'

# Passing `color_shuffle_key` enables color permutations for the training set.
# The validation set is loaded without a key.
train_dataset = HanabiLiveGamesDataset(
    file=train_dataset_path,
    color_shuffle_key=jax.random.PRNGKey(0),
)
val_dataset = HanabiLiveGamesDataset(file=val_dataset_path)

# Individual games are accessed by index. List the public attributes of one
# game. A single '_' prefix check also excludes dunder ('__') names.
first_game = train_dataset[0]
game_attributes = [attr for attr in dir(first_game) if not attr.startswith('_')]
print(f"Game Attributes: {game_attributes}")

# How many games are in each split?
print(f"Total Number of Games in the 2-Player Train Dataset: {len(train_dataset)}")
print(f"Total Number of Games in the 2-Player Val Dataset: {len(val_dataset)}")

The data loader supports iterating over batches and shuffling the games as they are loaded. The cell below demonstrates how to use it. To shuffle the dataset, just pass a key; data loading is then deterministic and reproducible.

In [None]:
from typing import NamedTuple
from jaxmarl.environments.hanabi import hanabi_game
import chex
import jax.numpy as jnp
from jaxmarl.environments.hanabi.hanabi import HanabiEnv

# Derive the training loader's shuffle key from a fixed seed so that the
# order in which batches are produced is reproducible across runs.
seed_key = jax.random.PRNGKey(0)
rng, _rng = jax.random.split(seed_key)

batch_size = 8
train_loader = HanabiLiveGamesDataloader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle_key=_rng,
)

# No batch size given: iterating this loader yields all games at once.
val_loader = HanabiLiveGamesDataloader(val_dataset)


The next cell shows how to unroll the games in JaxMARL using `vmap`. See also `bc.py`, which contains the entire training loop for a BC policy.

In [None]:
class Transition(NamedTuple):
    """Carry threaded through `jax.lax.scan` while replaying a recorded game.

    Fields:
        current_timestep: number of environment steps applied so far. It is
            counted explicitly instead of relying on the `turn` in
            `env_state`, because the game might reset.
        env_state: environment state after the most recent step.
        reached_terminal: whether the episode has ended at or before the
            current step.
    """

    current_timestep: int
    env_state: hanabi_game.State
    reached_terminal: jnp.bool_

def batchify(x, agent_list):
    """Stack the per-agent entries of ``x`` into one 2-D array.

    Args:
        x: mapping from agent name to an array-like value.
        agent_list: agent names; their order sets the row order.

    Returns:
        An array of shape ``(len(agent_list), -1)``. Row ``i`` is
        ``x[agent_list[i]]`` flattened.
    """
    num_agents = len(agent_list)
    per_agent = [x[agent] for agent in agent_list]
    stacked = jnp.stack(per_agent)
    return stacked.reshape((num_agents, -1))

def make_play(num_players):
    """Build a function that replays one recorded Hanabi game in JaxMARL.

    Args:
        num_players: number of agents. It is converted with ``int`` before
            being passed to ``make("hanabi", ...)``.

    Returns:
        ``play(rng, deck, actions)``. It resets the environment from ``deck``
        via ``env.reset_from_deck_of_pairs``, then applies ``actions`` one step
        at a time with ``jax.lax.scan``. Entry ``i`` of each step's action
        vector is assigned to ``env.agents[i]``. ``play`` returns the
        ``jax.lax.scan`` result ``(final Transition, None)``.
    """
    env: HanabiEnv = make("hanabi", num_agents=int(num_players))

    def play(
        rng: chex.PRNGKey,
        deck: chex.Array,
        actions: chex.Array,
    ):
        # The starting state is built from the recorded deck; the initial
        # observation is not needed.
        _, start_state = env.reset_from_deck_of_pairs(deck)

        def _step(carry: Transition, joint_action: jax.Array):
            # Map the i-th entry of this step's joint action to the i-th agent.
            per_agent_action = {
                agent: joint_action[idx] for idx, agent in enumerate(env.agents)
            }

            # The same `rng` is reused every step. Per the original note, the
            # step is not stochastic, so the key does not matter.
            # NOTE(review): the environment keeps being stepped after
            # `reached_terminal` becomes True. Presumably the trailing actions
            # are padding and harmless. Confirm against the dataset format.
            _, next_state, _, dones, _ = env.step_env(
                rng,
                carry.env_state,
                per_agent_action,
            )

            # Once any step reports the episode as done, the flag stays set.
            done_so_far = jnp.logical_or(dones["__all__"], carry.reached_terminal)
            next_carry = Transition(
                current_timestep=carry.current_timestep + 1,
                env_state=next_state,
                reached_terminal=done_so_far,
            )
            return next_carry, None

        start = Transition(
            current_timestep=0,
            env_state=start_state,
            reached_terminal=False,
        )
        return jax.lax.scan(_step, start, actions)

    return play


# `jax.vmap` maps `play` over the leading (per-game) axis of every argument.
# `jax.jit` compiles the vectorized replay, so later batches with the same
# shapes reuse the compiled computation instead of re-executing op by op.
play_game_vjit = jax.jit(
    jax.vmap(make_play(train_loader.dataset.num_players), in_axes=0)
)

# `play` is not stochastic, so one fixed base key suffices. It is created
# once, outside the loop.
base_key = jax.random.PRNGKey(0)
for game_batch in train_loader:
    batch_actions = game_batch.actions
    batch_decks = game_batch.decks

    # Split the base key into one key per game id in the batch.
    play_game_keys = jax.random.split(base_key, game_batch.game_ids.size)
    final_transition, _ = play_game_vjit(play_game_keys, batch_decks, batch_actions)