In [106]:
from training.rollout import run_selfplay

In [110]:
import jax.numpy as jnp
# Mock ActorCritic Model
class MockActorCritic:
    def apply(self, params, obs):
        # Return dummy logits and value
        batch_size = obs.shape[0]
        board_size = obs.shape[1]
        num_actions = board_size * board_size
        # Simple logits (e.g., uniform or favoring one action)
        dummy_logits = jnp.zeros((batch_size, board_size, board_size))
        # Make one action slightly more likely to ensure selection is possible
        dummy_logits = dummy_logits.at[:, 0, 0].set(1.0)
        dummy_value = jnp.ones(batch_size) * 0.5 # Constant value
        return dummy_logits, dummy_value

    def sample_action(self, logits, rng_key):
        # Sample deterministically for testing (e.g., always pick the highest logit)
        # Or use the key for controlled randomness if needed
        batch_size = logits.shape[0]
        board_size = logits.shape[1]
        flat_logits = logits.reshape(batch_size, -1)
        flat_action_idx = jnp.argmax(flat_logits, axis=-1)
        # Convert flat index back to (row, col)
        action_row = flat_action_idx // board_size
        action_col = flat_action_idx % board_size
        action = jnp.stack([action_row, action_col], axis=-1)
        return action

    def evaluate_actions(self, params, states, actions):
         # Needed for PPO trainer, return dummy values consistent with shapes
        T, B = states.shape[0], states.shape[1] # Assuming states shape (T, B, H, W, C) or similar
        action_log_probs = jnp.zeros((T, B)) - 1.0 # Dummy log prob
        entropy = jnp.ones((T, B)) * 1.5 # Dummy entropy
        values = jnp.ones((T, B)) * 0.5 # Dummy value
        return action_log_probs, entropy, values


# Mock Environment Functions
def mock_reset_env(env_state):
    board_size = env_state['board_size']
    num_boards = env_state['num_boards']
    initial_obs = jnp.zeros((num_boards, board_size, board_size), dtype=jnp.float32)
    # Reset dones, board state etc.
    env_state['boards'] = jnp.zeros_like(env_state['boards'])
    env_state['dones'] = jnp.zeros(num_boards, dtype=jnp.bool_)
    env_state['current_player'] = jnp.ones(num_boards, dtype=jnp.int32) # Player 1 starts
    env_state['steps'] = 0
    return env_state, initial_obs

def mock_step_env(env_state, action):
    # Simple mock step: Game ends after `max_mock_steps`, fixed reward
    num_boards = env_state['num_boards']
    board_size = env_state['board_size']
    max_mock_steps = 5 # Let the mock game end quickly

    env_state['steps'] += 1
    dones = (env_state['steps'] >= max_mock_steps) | env_state['dones'] # Check if game should end
    rewards = jnp.where(dones & ~env_state['dones'], 1.0, 0.0) # Reward 1.0 only on the step it becomes done
    env_state['dones'] = dones
    # Flip player (doesn't really matter for this mock logic)
    env_state['current_player'] = 3 - env_state['current_player']
    # Dummy next observation
    next_obs = jnp.ones((num_boards, board_size, board_size)) * env_state['steps']

    return env_state, next_obs, rewards, dones

def mock_get_action_mask(env_state):
    # Allow all actions
    num_boards = env_state['num_boards']
    board_size = env_state['board_size']
    return jnp.ones((num_boards, board_size, board_size), dtype=jnp.bool_)


In [111]:
import jax.numpy as jnp

def test_run_selfplay(mock_env_state, mock_actor_critic_instance, mock_params, rng):
    # Use the patched environment functions implicitly
    max_mock_steps = 5 # Must match mock_step_env
    board_size = mock_env_state["board_size"]
    num_boards = mock_env_state["num_boards"]

    trajectory, final_rng = run_selfplay(
        mock_env_state, mock_actor_critic_instance, mock_params, rng
    )

    assert isinstance(trajectory, dict)
    assert "obs" in trajectory
    assert "actions" in trajectory
    assert "rewards" in trajectory
    assert "masks" in trajectory
    # assert "values" in trajectory # run_selfplay doesn't collect values currently
    assert "episode_length" in trajectory

    T = trajectory["episode_length"]
    # In mock, all envs finish at the same time
    assert T == max_mock_steps

    assert trajectory["obs"].shape == (T, num_boards, board_size, board_size)
    assert trajectory["actions"].shape == (T, num_boards, 2)
    assert trajectory["rewards"].shape == (T, num_boards)
    assert trajectory["masks"].shape == (T, num_boards)
    # assert trajectory["values"].shape == (T, num_boards)

    # Check mask is False on the last step for all boards
    assert jnp.all(~trajectory["masks"][-1, :])
    # Check mask is True on steps before last for all boards
    if T > 1:
         assert jnp.all(trajectory["masks"][-2, :])

    assert not jnp.array_equal(rng, final_rng) # RNG should be consumed


In [114]:
import pytest

@pytest.fixture
def mock_actor_critic_instance():
    return MockActorCritic()

@pytest.fixture
def mock_params():
    # Params can be empty for our mock model
    return FrozenDict({})

@pytest.fixture
def mock_env_state():
    board_size = 5
    num_boards = 2 # Test with batch size > 1
    return {
        "board_size": board_size,
        "num_boards": num_boards,
        "boards": jnp.zeros((num_boards, board_size, board_size), dtype=jnp.int32),
        "current_player": jnp.ones(num_boards, dtype=jnp.int32),
        "dones": jnp.zeros(num_boards, dtype=jnp.bool_),
        "steps": 0, # Custom field for mock step tracking
    }

@pytest.fixture
def rng():
    return jax.random.PRNGKey(42)


In [116]:

test_run_selfplay(mock_env_state(), mock_actor_critic_instance(), mock_params(), rng())

Failed: Fixture "mock_env_state" called directly. Fixtures are not meant to be called directly,
but are created automatically when test functions request them as parameters.
See https://docs.pytest.org/en/stable/explanation/fixtures.html for more information about fixtures, and
https://docs.pytest.org/en/stable/deprecations.html#calling-fixtures-directly about how to update your code.

In [5]:
import jax.numpy as jnp

x = jnp.zeros((4,5,3,15,15))
current_player = jnp.ones((4,5,3))
prefix_shape = x.shape[:-2] # e.g., (batch,) or (T, batch) or ()
board_shape = x.shape[-2:] # (board_size, board_size)

# Add channel dimension for board state
# Shape: (..., board_size, board_size, 1)
x_proc = jnp.expand_dims(x, axis=-1)

# Create player channel
# Ensure current_player has the correct prefix shape
player_array = jnp.broadcast_to(current_player, prefix_shape)
print(player_array.shape)
# Reshape player_array to (..., 1, 1, 1) for broadcasting to spatial dims
player_array_reshaped = player_array.reshape(prefix_shape + (1,) * (len(board_shape) + 1))
# Create the channel plane: (..., board_size, board_size, 1)
player_channel = jnp.ones_like(x_proc) * player_array_reshaped

# Concatenate board state and player channel
# Shape: (..., board_size, board_size, 2)
x_combined = jnp.concatenate([x_proc, player_channel], axis=-1)


(4, 5, 3)


In [6]:
print(x_combined.shape)

(4, 5, 3, 15, 15, 2)
