In [48]:
# Toy fixtures: one entry in `actions` per row of the arrays built below.
actions = (0, 2, 0)
# (len(actions), 4) integer table counting up from 0, row by row.
q_vals = np.arange(4 * len(actions)).reshape(len(actions), 4)
# Column vector (len(actions), 1) holding 0, 10, 20, ...
values = 10 * np.arange(len(actions))[:, np.newaxis]
# Same contents as the initial q_vals, but an independent array.
n_visits = np.arange(4 * len(actions)).reshape(len(actions), 4)

In [61]:
# Boolean flag per row; rows flagged False are masked out in the next cell.
still_playing = np.asarray([True, False, True], dtype=bool)

In [62]:
# Print the index of the largest entry in each row of q_vals.
print(q_vals.argmax(axis=1))
# Same per-row argmax, but rows whose still_playing flag is False get -1 instead.
np.where(still_playing, q_vals.argmax(axis=1), -1)

[3 2 0]


array([ 3, -1,  0])

In [53]:
# Indexing the 4x4 identity with `actions` yields one one-hot row per action;
# non-zero entries select `values + q_vals`, every other entry keeps q_vals.
q_vals = np.where(
    np.eye(4)[list(actions)],
    values + q_vals,
    q_vals,
)
q_vals

array([[ 0,  1,  2,  3],
       [ 4,  5, 16,  7],
       [28,  9, 10, 11]])

In [9]:
import copy
from kaggle_environments import make as kaggle_make
import numpy as np
from scipy import special, stats
import torch
import torch.nn.functional as F

from hungry_geese.utils import ActionMasking
from hungry_geese.env import goose_env as ge
from hungry_geese.env.lightweight_env import LightweightEnv, make
from hungry_geese.mcts.basic_mcts import run_mcts

In [2]:
# `make` is the name imported from hungry_geese.env.lightweight_env in the
# imports cell (kaggle_environments' own `make` is aliased to `kaggle_make`).
env = make('hungry_geese')
# Reset for a 4-agent game; the trailing semicolon keeps the notebook from
# echoing the call's return value.
env.reset(num_agents=4);

In [21]:
# Scratch check: two float lengths, 0. and 3.
goose_lengths = 3 * np.arange(2, dtype=float)
# Softmax over the lengths divided by 3 (echoed as the cell output).
special.softmax(goose_lengths / 3.)

array([0.26894142, 0.73105858])

In [3]:
def action_mask_func(state):
    """Return the action mask that ActionMasking.LETHAL computes for `state`."""
    masking_mode = ActionMasking.LETHAL
    return masking_mode.get_action_mask(state)

def actor_func(state):
    """Uniform policy: a (len(state), 4) float array with every entry 0.25."""
    n_rows = len(state)
    return np.full((n_rows, 4), 0.25)
    
# Geese are evaluated based on the proportion of their length to the total length of all geese
# This value is then rescaled to account for geese that have already died
# Finally, the overall rankings and ranking estimates are rescaled via 2 * rank - 1
def critic_func(state):
    """Heuristic per-goose value estimate for a (possibly unfinished) game state.

    Args:
        state: sequence with one dict per agent. Each dict provides a 'reward'
            entry, and state[0]['observation']['geese'] holds one sized
            collection per goose; a length of 0 marks that goose as dead.

    Returns:
        Float array with one value per goose. Dead geese get their reward rank
        mapped linearly onto [-1, 1]. Live geese share the remaining "equity"
        in proportion to their length, so a live goose holding more than its
        even share of the total length can score above 1.

    Raises:
        RuntimeError: if the pre-scaling ranks do not sum to len(state) / 2.
    """
    n_geese = float(len(state))
    # Use the builtin float: the np.float alias was deprecated in NumPy 1.20
    # and removed in 1.24, where it raises AttributeError.
    goose_lengths = np.array(
        [len(goose) for goose in state[0]['observation']['geese']],
        dtype=float
    )
    dead_geese_mask = goose_lengths == 0
    # Zero-based ranks by reward (ties share the average rank), scaled to [0, 1]
    agent_rankings = stats.rankdata([agent['reward'] for agent in state], method='average') - 1.
    agent_rankings_rescaled = agent_rankings / (n_geese - 1.)

    total_length = goose_lengths.sum()
    if total_length > 0:
        goose_lengths_norm = goose_lengths / total_length
    else:
        # Every goose is dead: the length shares are never selected below, so
        # skip the 0/0 division (and its RuntimeWarning) entirely.
        goose_lengths_norm = np.zeros_like(goose_lengths)
    # Total rank mass is n_geese / 2; subtract what the dead geese (the k lowest
    # ranks 0 .. k-1, scaled) have already claimed.
    remaining_equity = n_geese / 2. - np.sum(np.arange(dead_geese_mask.sum()) / (n_geese - 1.))
    goose_lengths_norm_rescaled = goose_lengths_norm * remaining_equity

    final_ranks = np.where(
        dead_geese_mask,
        agent_rankings_rescaled,
        goose_lengths_norm_rescaled
    )
    if not np.isclose(final_ranks.sum(), n_geese / 2.):
        raise RuntimeError(f'Final ranks should sum to {n_geese / 2.}\n'
                           f'Final ranks: {final_ranks}\nDead geese mask: {dead_geese_mask}')
    return 2. * final_ranks - 1.
    
def terminal_value_func(state):
    """Map each agent's reward rank linearly onto [-1, 1].

    `state` is a sequence of per-agent dicts with a 'reward' entry. The lowest
    reward maps to -1 and the highest to 1; tied rewards share their average rank.
    """
    rewards = [agent['reward'] for agent in state]
    zero_based_ranks = stats.rankdata(rewards, method='average') - 1.
    top_rank = len(state) - 1.
    return 2. * zero_based_ranks / top_rank - 1.
    
def actor_critic_func(state):
    """Return the pair (actor_func(state), critic_func(state))."""
    policy = actor_func(state)
    value = critic_func(state)
    return policy, value

In [6]:
# Show the current board as text before searching.
print(env.render_ansi())

# Run the project's MCTS from the current env state using the helper functions
# defined in the cell above.
# NOTE(review): n_iter and max_time look like an iteration cap and a time
# budget (units not visible here) - confirm against basic_mcts.run_mcts.
search_tree = run_mcts(
    env=env,
    n_iter=10000,
    action_mask_func=action_mask_func,
    actor_critic_func=actor_critic_func,
    terminal_value_func=terminal_value_func,
    max_time=1.
)

+---+---+---+---+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+---+---+---+---+
|   |   |   |   | 0 |   |   |   | F |   |   |
+---+---+---+---+---+---+---+---+---+---+---+
|   |   |   |   |   |   | 2 | 3 |   |   |   |
+---+---+---+---+---+---+---+---+---+---+---+
|   |   | F |   |   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+---+---+---+---+
|   | 1 |   |   |   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+---+---+---+---+



In [8]:
# Echo the visit counts stored on the object returned by run_mcts.
# NOTE(review): presumably one count per agent or per action at the root - confirm against basic_mcts.
search_tree.n_visits

array([1381., 1381., 1381., 1381.])