In [2]:
import os
import jax
import jax.numpy as jnp
from jax import lax
from typing import Tuple, Dict, Any, NamedTuple, Optional
from functools import partial
from jax.sharding import Mesh, NamedSharding, PartitionSpec as PS
from dataclasses import dataclass
from jax import jit,lax,tree
import distrax

# Ask XLA for 8 virtual devices on the CPU *host platform* so sharding can be
# exercised without accelerators. This flag only affects the CPU backend: on the
# TPU runtime shown below, jax.devices() returns the TPU cores regardless.
# Per the JAX docs, XLA_FLAGS must be in place before the backend is first
# initialized (first jax.devices()/computation), so keep this cell early.
# Other user-set XLA_FLAGS are preserved and any earlier device-count flag is
# replaced, so re-running the cell does not accumulate duplicates.
_HOST_DEVICE_COUNT_FLAG = "--xla_force_host_platform_device_count="
os.environ["XLA_FLAGS"] = " ".join(
    [
        flag
        for flag in os.environ.get("XLA_FLAGS", "").split()
        if not flag.startswith(_HOST_DEVICE_COUNT_FLAG)
    ]
    + [_HOST_DEVICE_COUNT_FLAG + "8"]
)



In [2]:
# List the devices JAX sees; the mesh below is built from exactly these.
available_devices = jax.devices()
print(available_devices)

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]


In [5]:
# One-dimensional device mesh over every visible device, with a single axis
# named "batch". The trailing comma matters: ("batch") is just the string
# "batch", not a 1-tuple of axis names.
mesh = Mesh(jax.devices(), axis_names=("batch",))

def mesh_sharding(*names: str | None) -> NamedSharding:
    """Build a NamedSharding on the module-level ``mesh``.

    Each positional argument is the mesh-axis name (or ``None`` for "not
    split") for the corresponding array dimension.
    """
    partition = PS(*names)
    return NamedSharding(mesh, partition)

@dataclass(unsafe_hash=True)
class MeshRules:
    batch: tuple[str | None, ...]
    replicated: tuple[str | None, ...]
    buffer: tuple[str | None, ...]
    
    def __call__(self, name: str) -> NamedSharding:
        sharding_spec = getattr(self, name)
        return mesh_sharding(*sharding_spec)  

In [7]:
# Sharding rules used throughout the notebook:
#   batch      -> split the leading (batch) dimension over the "batch" mesh axis
#   replicated -> empty spec, i.e. no dimension is split (used for model params)
#   buffer     -> time-major buffers: dim 0 unsplit, dim 1 split over "batch"
mesh_rules = MeshRules(
    buffer=(None, "batch"),
    batch=("batch",),
    replicated=(),
)

In [5]:
# ENV RELATED COMPONENTS 

import jax
import jax.numpy as jnp
from typing import Tuple, Dict, Any, NamedTuple

# This is the state definition from the original environment.
# It's needed for type hints and for the methods to return the correct state structure.
class EnvState(NamedTuple):
    """Dynamic state of the batched Gomoku environment for a single step.

    Fields (B = batch size):
        boards: (B, board_size, board_size) float32 board tensor.
        current_players: (B,) int32 player to move (1 or -1).
        dones: (B,) bool flags for finished games.
        winners: (B,) int32 winner (1, -1, or 0 for draw/ongoing).
    """

    boards: jnp.ndarray
    current_players: jnp.ndarray
    dones: jnp.ndarray
    winners: jnp.ndarray


WIN_LENGTH: int = 5  # pieces in a row needed to win; default from the original environment

class DummyEnv:
    """
    A dummy version of Env with minimal logic, focusing on compatible shapes.
    Observations are player agnostic.

    Boards never change and actions are ignored; the only dynamics are that
    ``current_players`` alternates between 1 and -1 on every step. ``dones``
    keep their initial value (all False), so episodes never end on their own.
    """

    def __init__(self, B: int, board_size: int = 9, win_length: int = WIN_LENGTH):
        """
        Initializes the dummy environment configuration.

        Args:
            B: Batch size (number of games simulated in parallel).
            board_size: The size of the Gomoku board.
            win_length: The number of consecutive pieces needed to win (unused in dummy logic).
        """
        self.B = B
        self.board_size = board_size
        self.win_length = win_length  # Stored for compatibility, but not used by dummy logic

    def init_state(self) -> EnvState:
        """
        Creates an initial dummy EnvState: empty float32 boards, player 1 to
        move, no game done, no winner.

        Every leaf is constrained to the "batch" sharding rule, which splits
        its leading (batch) dimension across the mesh.
        """
        return lax.with_sharding_constraint(
            EnvState(
                boards=jnp.zeros((self.B, self.board_size, self.board_size), dtype=jnp.float32),
                current_players=jnp.ones((self.B,), dtype=jnp.int32),  # Player 1 starts
                dones=jnp.zeros((self.B,), dtype=jnp.bool_),
                winners=jnp.zeros((self.B,), dtype=jnp.int32),
            ),
            mesh_rules("batch"),
        )

    def reset(
        self
    ) -> Tuple[EnvState, jnp.ndarray, Dict[str, Any]]:
        """
        Resets environments to initial dummy states.

        Returns:
            (state, observations, info): observations are the boards, shape
            (B, board_size, board_size); info is an empty dict.
        """
        new_state = self.init_state()
        initial_observations = new_state.boards  # Shape: (B, board_size, board_size)
        info = {}
        return new_state, initial_observations, info

    def step(
        self, state: EnvState, actions: jnp.ndarray
    ) -> Tuple[EnvState, jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict[str, Any]]:
        """
        Takes a dummy step: flips the player to move (1 <-> -1) and returns zero
        rewards and the unchanged done flags. The board does not change and
        actions are ignored.

        Args:
            state: The current EnvState.
            actions: JAX array of actions (row, col). Shape (B, 2). (Ignored by dummy)

        Returns:
            A tuple (new_state, observations, rewards, dones, info).
        """
        # Dones are kept as-is (all False from init_state). To exercise
        # termination, sample them instead, e.g.
        # jax.random.bernoulli(key, 0.1, (self.B,)), using a key supplied by
        # the caller (EnvState carries no RNG).
        new_dones = state.dones

        # Alternate turns so the player-to-move follows a two-player turn order.
        new_state = state._replace(
            current_players=-state.current_players,
            dones=new_dones,
        )
        observations = new_state.boards  # Shape: (B, board_size, board_size)
        rewards = jnp.zeros((self.B,), dtype=jnp.float32)  # No rewards

        info = {}
        return new_state, observations, rewards, new_dones, info

    def initialize_trajectory_buffers(self, max_steps: int) -> Tuple[jnp.ndarray, ...]:
        """
        Creates and returns pre-allocated, zero-filled JAX arrays for storing
        trajectory data, in the order:
        (observations, actions, values, rewards, dones, log_probs, current_players).

        All buffers are time-major (max_steps, B, ...); ``values`` has one extra
        time slot (max_steps + 1) for a bootstrap value. Every buffer is
        constrained to the "buffer" sharding rule (batch dimension split).
        """
        obs_shape = self.observation_shape
        act_shape = self.action_shape

        observations = jnp.zeros((max_steps, self.B) + obs_shape, dtype=jnp.float32)
        actions = jnp.zeros((max_steps, self.B) + act_shape, dtype=jnp.int32)
        # values buffer is T+1 so GAE can read V(s_T)
        values = jnp.zeros((max_steps + 1, self.B), dtype=jnp.float32)
        rewards = jnp.zeros((max_steps, self.B), dtype=jnp.float32)
        dones = jnp.zeros((max_steps, self.B), dtype=jnp.bool_)
        log_probs = jnp.zeros((max_steps, self.B), dtype=jnp.float32)
        current_players_buffer = jnp.zeros((max_steps, self.B), dtype=jnp.int32)

        sharded_output = tree.map(
            lambda x: lax.with_sharding_constraint(x, mesh_rules("buffer")),
            (observations, actions, values, rewards, dones, log_probs, current_players_buffer),
        )
        return sharded_output

    @property
    def observation_shape(self) -> tuple:
        """Returns the shape of a single observation (board state)."""
        return (self.board_size, self.board_size)

    @property
    def action_shape(self) -> tuple:
        """Returns the shape of a single action (row, col)."""
        return (2,)  # (row, col)

    def get_action_mask(self, state: EnvState) -> jnp.ndarray:
        """
        Returns a dummy boolean mask of valid actions, shape (B, board_size, board_size).
        Every cell is valid for games that are not done; finished games get an
        all-False mask.
        """
        mask = jnp.ones((self.B, self.board_size, self.board_size), dtype=jnp.bool_)
        # Respect existing done flags
        return mask & (~state.dones[:, None, None])


In [6]:
class LoopState(NamedTuple):
    """Carry of the rollout ``lax.while_loop``: env state, trajectory buffers, counters.

    The buffers are time-major and written at index ``step_idx``.
    """

    state: EnvState  # current environment state
    obs: jnp.ndarray  # observation for the current step
    observations: jnp.ndarray  # buffer of recorded observations
    values: jnp.ndarray  # buffer of value estimates
    actions: jnp.ndarray  # buffer of sampled actions
    rewards: jnp.ndarray  # buffer of rewards
    dones: jnp.ndarray  # buffer of done flags
    logprobs: jnp.ndarray  # buffer of action log-probabilities
    current_players: jnp.ndarray  # buffer of the player that acted at each step
    step_idx: int | jax.Array  # Python int initially; a traced integer inside the loop
    rng: jax.Array  # PRNG key (jax.random.PRNGKey is a constructor, not a type)
    # Step index ``t`` at which ``done`` first became True for each batch
    # element (int32 max while the game is still running).
    termination_step_indices: jnp.ndarray


In [24]:

@partial(jit, static_argnames=["env", "actor_critic"])
def player_move(
    loop_state: LoopState, env, actor_critic, params
) -> LoopState:
    """Take one environment step with ``actor_critic`` and record it in the buffers.

    Evaluates the policy/value on the current observation, masks illegal moves,
    samples an action, writes obs/action/value/log-prob/player at ``step_idx``,
    steps ``env``, writes reward/done, and records first-termination indices.

    Args:
        loop_state: Rollout carry (env state, current obs, buffers, step index, RNG).
        env: Environment (static); must provide ``get_action_mask`` and ``step``.
        actor_critic: Flax module (static) whose ``apply`` returns ``(distribution, value)``.
        params: Parameters for ``actor_critic``.

    Returns:
        The updated LoopState with ``step_idx`` advanced by one.
    """
    current_state: EnvState = loop_state.state
    current_obs: jnp.ndarray = loop_state.obs
    current_player: jnp.ndarray = current_state.current_players
    step_idx = loop_state.step_idx
    rng = loop_state.rng

    # Get policy distribution and value from the model.
    pi_dist, value = actor_critic.apply(
        {"params": params}, current_obs, current_player
    )
    # with_sharding_constraint is functional: only its *result* carries the
    # constraint, so the outputs must be used (calling it and discarding the
    # return value has no effect).
    value = lax.with_sharding_constraint(value, mesh_rules("batch"))
    original_logits = lax.with_sharding_constraint(
        pi_dist.logits, mesh_rules("batch")
    )  # Shape (B, H*W)

    # Get action mask from the environment.
    action_mask = env.get_action_mask(current_state)  # (B, H, W)
    B, H, W = action_mask.shape
    flat_action_mask = action_mask.reshape(B, -1)  # (B, H*W)

    # Mask with the most negative finite value instead of -inf: a finished game
    # has no legal move, and an all -inf row makes log-softmax NaN, which would
    # poison the log-prob buffer. Finite masking keeps such rows finite while
    # illegal moves in partially-masked rows still get zero probability.
    mask_value = jnp.finfo(original_logits.dtype).min
    masked_logits = jnp.where(flat_action_mask, original_logits, mask_value)
    masked_pi_dist = distrax.Categorical(logits=masked_logits)

    # Sample action from the masked distribution.
    rng, subkey = jax.random.split(rng)
    flat_action = masked_pi_dist.sample(seed=subkey)  # Shape (B,)
    logprob = masked_pi_dist.log_prob(flat_action)  # Shape (B,)

    # Convert flat action back to (row, col) for the environment step.
    action_row = flat_action // W
    action_col = flat_action % W
    action = jnp.stack([action_row, action_col], axis=-1)  # Shape (B, 2)

    observations = loop_state.observations.at[step_idx].set(current_obs)
    actions = loop_state.actions.at[step_idx].set(action)
    values = loop_state.values.at[step_idx].set(value)
    logprobs = loop_state.logprobs.at[step_idx].set(logprob)
    current_players = loop_state.current_players.at[step_idx].set(current_player)

    next_state, next_obs, step_rewards, dones, info = env.step(current_state, action)

    rewards = loop_state.rewards.at[step_idx].set(step_rewards)
    dones_buffer = loop_state.dones.at[step_idx].set(dones)

    # Record step_idx for elements that become done now and had not terminated before
    # (int32 max is the "not terminated yet" sentinel).
    current_termination_indices = loop_state.termination_step_indices
    not_terminated_yet = current_termination_indices == jnp.iinfo(jnp.int32).max
    new_termination_indices = jnp.where(
        not_terminated_yet & dones,
        step_idx,
        current_termination_indices,
    )

    return loop_state._replace(
        state=next_state,
        obs=next_obs,
        observations=observations,
        actions=actions,
        values=values,
        rewards=rewards,
        dones=dones_buffer,
        logprobs=logprobs,
        current_players=current_players,
        step_idx=step_idx + 1,
        rng=rng,
        termination_step_indices=new_termination_indices,
    )

In [25]:

@partial(
    jit,
    static_argnames=["env", "black_actor_critic", "white_actor_critic", "buffer_size"],
)
def run_episode(
    env,
    black_actor_critic,
    black_params,
    white_actor_critic,
    white_params,
    rng,
    buffer_size
) -> Tuple[
    Dict[str, Any], EnvState, jax.Array
]:  # Return full buffers, final_state, rng
    """
    Collect trajectories for self-play with separate black and white models.
    Runs until all environments are done or ``buffer_size`` steps have been
    taken, whichever comes first. Buffers are allocated based on buffer_size.

    Args:
        env: Environment instance. Assumes player 1 is "black".
        black_actor_critic: ActorCritic model for black player (first player).
        black_params: Parameters for black player model.
        white_actor_critic: ActorCritic model for white player (second player).
        white_params: Parameters for white player model.
        rng: JAX RNG key.
        buffer_size: The size of the trajectory buffers to allocate (also the
            maximum number of steps executed).

    Returns:
        full_trajectory: Dict containing the full, un-sliced buffers (observations, actions,
                         values, rewards, dones, logprobs, current_players, valid_mask, T,
                         termination_step_indices). Arrays have length buffer_size, except
                         "values" (buffer_size + 1), whose entry T holds the bootstrap value
                         of the final observation. "valid_mask" is True only for executed
                         steps (t < T) up to and including each environment's termination step.
        final_state: The environment state after the final step taken in the loop.
        rng: updated RNG key.

    Note: This function can simulate the behavior of `run_selfplay` function
          by passing the same actor_critic model and params for both the black and white players.
    """
    initial_state, initial_obs, _ = env.reset()

    (
        observations,
        actions,
        values,
        rewards,
        dones_buffer,
        logprobs,
        current_players_buffer,
    ) = env.initialize_trajectory_buffers(buffer_size)
    B = initial_obs.shape[0]  # Infer batch size
    initial_termination_indices = lax.with_sharding_constraint(
        jnp.full((B,), jnp.iinfo(jnp.int32).max, dtype=jnp.int32),
        mesh_rules("batch"),
    )

    initial_loop_state = LoopState(
        state=initial_state,
        obs=initial_obs,
        observations=observations,
        actions=actions,
        values=values,
        rewards=rewards,
        dones=dones_buffer,
        logprobs=logprobs,
        current_players=current_players_buffer,
        step_idx=0,
        rng=rng,
        termination_step_indices=initial_termination_indices,
    )

    def cond_fn(l_state: LoopState) -> jax.Array:
        # Continue while some game is running AND buffer space remains. Without
        # the second bound, an env that never sets `done` loops forever and
        # writes past `buffer_size` would be silently dropped.
        any_running = ~jnp.all(l_state.state.dones)
        return any_running & (l_state.step_idx < buffer_size)

    def body_fn(l_state: LoopState) -> LoopState:
        # Black moves on even steps, white on odd steps. The enclosing function
        # is already jitted, so no inner jit is needed.
        is_black_turn = l_state.step_idx % 2 == 0
        return lax.cond(
            is_black_turn,
            lambda s: player_move(s, env, black_actor_critic, black_params),
            lambda s: player_move(s, env, white_actor_critic, white_params),
            l_state,
        )

    final_state = lax.while_loop(cond_fn, body_fn, initial_loop_state)
    T = final_state.step_idx  # Steps actually executed, <= buffer_size

    # Bootstrap value V(s_T) for GAE, from the model of the player to move at
    # step T (same parity rule as body_fn). The values buffer has buffer_size + 1
    # slots, so index T is always in range.
    def _value_of_final_obs(actor_critic, params):
        _, v = actor_critic.apply(
            {"params": params}, final_state.obs, final_state.state.current_players
        )
        return v

    bootstrap_value = lax.cond(
        T % 2 == 0,
        lambda: _value_of_final_obs(black_actor_critic, black_params),
        lambda: _value_of_final_obs(white_actor_critic, white_params),
    )
    final_values = final_state.values.at[T].set(bootstrap_value)

    term_indices = final_state.termination_step_indices  # Shape (B,)

    # A step is valid if it was executed (t < T) and is not past the element's
    # termination step ('<=' keeps the terminal step itself).
    step_indices = jnp.arange(buffer_size)[:, None]  # Shape (buffer_size, 1)
    valid_mask = (step_indices <= term_indices[None, :]) & (step_indices < T)  # (buffer_size, B)

    full_trajectory = {
        "observations": final_state.observations,  # Shape (buffer_size, B, ...)
        "actions": final_state.actions,  # Shape (buffer_size, B, ...)
        "values": final_values,  # Shape (buffer_size+1, B)
        "rewards": final_state.rewards,  # Shape (buffer_size, B)
        "dones": final_state.dones,  # Shape (buffer_size, B)
        "logprobs": final_state.logprobs,  # Shape (buffer_size, B)
        "current_players": final_state.current_players,  # Shape (buffer_size, B)
        "valid_mask": valid_mask,  # Shape (buffer_size, B)
        "T": T,  # Actual number of steps executed (can be less than buffer_size)
        "termination_step_indices": term_indices,
    }
    rng = final_state.rng

    return full_trajectory, final_state.state, rng


In [9]:
@jit
def calculate_gae(
    rewards: jnp.ndarray,
    values: jnp.ndarray,
    dones: jnp.ndarray,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Compute zero-sum (negamax) Generalized Advantage Estimation with a reverse lax.scan.

    This assumes each value estimate is from the perspective of the player who
    acts at that step and that consecutive steps alternate players. Under that
    assumption, the next state's value and the next advantage enter with a minus sign:

        delta_t = r_t - gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
        A_t     = delta_t - gamma * lambda * (1 - d_t) * A_{t+1}

    NOTE(review): the sign convention relies on strictly alternating turns;
    verify it against the training loop.

    Args:
        rewards: Rewards array, shape (T, B).
        values: Value estimates, shape (T+1, B). Include value of *terminal* state.
        dones: Done flags, shape (T, B). Dones resulting from the action at step t.
        gamma: Discount factor.
        gae_lambda: GAE lambda parameter.

    Returns:
        advantages: GAE advantages, shape (T, B).
        returns: GAE-based returns (advantages + values), shape (T, B).
    """
    T = rewards.shape[0]
    B = rewards.shape[1]
    # Shapes are static under jit, so these checks run once at trace time.
    assert (
        values.shape[0] == T + 1
    ), f"Values should have shape ({T+1}, B), but got {values.shape}"
    assert (
        values.shape[1] == B
    ), f"Values batch dimension mismatch: {values.shape[1]} vs {B}"
    assert (
        dones.shape[0] == T
    ), f"Dones time dimension mismatch: {dones.shape[0]} vs {T}"
    assert (
        dones.shape[1] == B
    ), f"Dones batch dimension mismatch: {dones.shape[1]} vs {B}"

    values_t = values[:-1]  # V(s_0)...V(s_{T-1}), shape (T, B)
    values_tp1 = values[1:]  # V(s_1)...V(s_T), shape (T, B)
    dones = dones.astype(jnp.float32)  # shape (T, B)

    # delta_t = r_t - gamma * V(s_{t+1}) * (1 - d_t) - V(s_t); the minus on
    # V(s_{t+1}) is there because the next value is from the opponent's view.
    deltas = rewards - gamma * values_tp1 * (1.0 - dones) - values_t  # Shape (T, B)

    def scan_fn(carry_gae_batch, step_data_batch):
        # carry_gae_batch: (B,) = A_{t+1}, from the opponent's perspective.
        delta_batch, done_batch = step_data_batch
        # Subtract the opponent's advantage because the game is zero-sum.
        gae_batch = (
            delta_batch - gamma * gae_lambda * (1.0 - done_batch) * carry_gae_batch
        )
        return gae_batch, gae_batch

    # Carry must match the per-step delta slice exactly in shape and dtype
    # (e.g. float64 under x64), or lax.scan rejects the carry. Built from the
    # shape rather than deltas[0] so an empty (T == 0) input still works.
    initial_carry = jnp.zeros(deltas.shape[1:], dtype=deltas.dtype)

    # reverse=True processes t = T-1 ... 0 but returns outputs in order 0..T-1.
    _, advantages = lax.scan(scan_fn, initial_carry, (deltas, dones), reverse=True)

    # R_t = A_t + V(s_t)
    returns = advantages + values_t  # Shape (T, B)

    return advantages, returns

@partial(jax.jit, out_shardings=mesh_rules("buffer"), static_argnames=("gamma", "gae_lambda"))
def compute_gae_targets(
    rewards: jnp.ndarray,
    values: jnp.ndarray,
    dones: jnp.ndarray,
    gamma: float,
    gae_lambda: float,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Calculates Generalized Advantage Estimation (GAE) and returns (targets for value function).
    The returned advantages are normalized to mean 0 / std 1.

    This is a plain module-level function; a ``@staticmethod`` decorator
    belongs only inside a class body.

    Args:
        rewards: Sequence of rewards, shape (T, B).
        values: Sequence of value estimates V(s_t), including V(s_T), shape (T+1, B).
        dones: Sequence of done flags, shape (T, B).
        gamma: Discount factor (static: changing it triggers recompilation).
        gae_lambda: GAE lambda parameter (static).

    Returns:
        tuple: (normalized_advantages, returns), both shape (T, B) and placed
               with the "buffer" sharding rule.
    """
    advantages_raw, returns = calculate_gae(
        rewards, values, dones, gamma, gae_lambda
    )

    # Statistics are taken over all T*B entries, including any padded steps
    # after an episode terminated (no mask is applied here).
    advantages_mean = advantages_raw.mean()
    advantages_std = advantages_raw.std() + 1e-8  # epsilon for numerical stability
    advantages_normalized = (advantages_raw - advantages_mean) / advantages_std

    return advantages_normalized, returns


In [10]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Tuple
import distrax

class SimpleActorCritic(nn.Module):
    """Placeholder actor-critic that returns shape-correct dummy outputs.

    Attributes:
        board_size: Side length of the square board; the policy ranges over
            ``board_size ** 2`` actions.
    """

    board_size: int

    @nn.compact
    def __call__(
        self, x: jnp.ndarray, current_players: jnp.ndarray
    ) -> Tuple[distrax.Categorical, jnp.ndarray]:
        """Return a uniform policy over every cell and a zero value estimate.

        All dimensions of ``x`` except the trailing two board axes are kept as
        leading dimensions of both outputs; ``current_players`` is ignored.
        """
        # Registered only so that `init` yields a non-empty 'params' collection;
        # it never enters the computation.
        self.param('dummy_param', nn.initializers.zeros, (1,))

        leading_dims = x.shape[:-2]
        n_actions = self.board_size ** 2

        # Constant logits -> every action equally likely before any masking.
        policy = distrax.Categorical(logits=jnp.ones((*leading_dims, n_actions)))
        return policy, jnp.zeros(leading_dims)

    def evaluate_actions(
        self, obs: jnp.ndarray, current_players: jnp.ndarray, actions: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Dummy evaluation: zero log-probs, zero entropy, and the module's value.

        ``actions`` is ignored; outputs have shape ``obs.shape[:-2]``.
        """
        _, value = self(obs, current_players)  # goes through the compact __call__
        leading_dims = obs.shape[:-2]
        return jnp.zeros(leading_dims), jnp.zeros(leading_dims), value


In [11]:
black_model = SimpleActorCritic(board_size=15)
white_model = SimpleActorCritic(board_size=15)

rng = jax.random.PRNGKey(0)
model_rng_b, model_rng_w = jax.random.split(rng)

# Batch-of-one placeholder inputs; `init` only needs them to trace shapes.
dummy_obs = jnp.zeros((1, 15, 15))
dummy_player = jnp.ones((1,), dtype=jnp.int32)

# Initialise each model from its own key, then replicate its params across the mesh.
replicated = mesh_rules("replicated")
black_params, white_params = (
    jax.device_put(model.init(key, dummy_obs, dummy_player)["params"], replicated)
    for model, key in ((black_model, model_rng_b), (white_model, model_rng_w))
)



In [13]:
board_size = 15
action_shape = (2,)
buffer_size = board_size * board_size  # one buffer slot per board cell
batch_size = 64

env = DummyEnv(batch_size, board_size, win_length=5)

# Use a seed distinct from the model-init cell. That cell split PRNGKey(0) into
# (model_rng_b, model_rng_w); re-creating PRNGKey(0) here would make
# player_move's first `jax.random.split(rng)` reproduce exactly those keys,
# correlating rollout sampling with parameter initialisation.
ROLLOUT_SEED = 1
rng = jax.random.PRNGKey(ROLLOUT_SEED)


In [26]:
rollout_outputs = run_episode(
    env=env,
    black_actor_critic=black_model,
    black_params=black_params,
    white_actor_critic=white_model,
    white_params=white_params,
    rng=rng,
    buffer_size=buffer_size,
)
# The third output is the advanced RNG key; the rollout key is consumed and not reused here.
full_trajectory, final_env_state, _ = rollout_outputs

In [15]:
observations_buffer = full_trajectory["observations"]
observations_buffer.shape

(225, 64, 15, 15)

In [16]:
rewards_buffer = full_trajectory["rewards"]
jax.debug.visualize_array_sharding(rewards_buffer)

In [9]:
def f(x):
    """Return ``x`` constrained to the "buffer" sharding rule (second axis split over "batch")."""
    buffer_sharding = mesh_rules("buffer")
    return lax.with_sharding_constraint(x, buffer_sharding)


# Demo: apply the "buffer" rule to a random matrix and show how it is laid out.
rng = jax.random.PRNGKey(0)
arr = f(jax.random.normal(rng, (1024, 1024)))
jax.debug.visualize_array_sharding(arr)
