In [24]:
import jax
import jax.numpy as jnp

win_length = 5

ones = jnp.ones((1, win_length), dtype=jnp.int8)  # shape: (1, win_length)
zeros = jnp.zeros((win_length-1, win_length), dtype=jnp.int8)  # shape: (win_length-1, win_length)

# Create horizontal kernel
horizontal = jnp.expand_dims(jnp.vstack([ones, zeros]), axis=(0,1))  # shape: (1, 1, win_length, win_length)
vertical = jnp.expand_dims(jnp.vstack([ones, zeros]).T, axis=(0,1))  # shape: (1, 1, win_length, win_length)
diagonal = jnp.expand_dims(jnp.eye(win_length, dtype=jnp.int8), axis=(0, 1))  # shape: (1, 1, win_length, win_length)
anti_diagonal = jnp.expand_dims(jnp.fliplr(jnp.eye(win_length, dtype=jnp.int8)), axis=(0, 1))  # shape: (1, 1, win_length, win_length)

# Stack all kernels


In [47]:
kernels = jnp.concatenate([horizontal, vertical, diagonal, anti_diagonal], axis=0)

def check_winner(board, kernels):
    """
    Performs convolution between the board and all kernels to check for a winning state.

    Args:
        board (jnp.ndarray): The current Gomoku board state with shape (N, H, W, C).
        kernels (jnp.ndarray): The stacked kernels with shape (H, W, C_in, C_out).

    Returns:
        bool: True if a winning state is detected, False otherwise.
    """
    # Perform convolution
    conv_result = jax.lax.conv_general_dilated(
        lhs=board,
        rhs=kernels,
        window_strides=(1, 1),
        padding='VALID',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC')
    )
    print(conv_result.shape)
    
    conv_result = conv_result.transpose(0,3,1,2)
    # Check if any convolution result equals 5 or -5
    winner = jnp.any((conv_result == 5) | (conv_result == -5))
    
    return bool(winner)

# Initialize the board
board = jnp.array([
    [0, 0, 0, 1, 1],
    [0, 1, 1, 1, 0],
    [1, 0, 1, 1, 0],
    [0, 1, 1, 1, 0],
    [1, 0, 0, 0, 1]
],dtype=jnp.int8)  # Shape: (5, 5)
board = jnp.pad(board,pad_width=((0,5),(0,5)),mode='constant',constant_values=0)

# Reshape the board to match the expected dimensions (N, H, W, C)
board = board.reshape(1, 10, 10, 1)  # Shape: (1, 5, 5, 1)

# Reshape kernels to match the expected dimensions (H, W, C_in, C_out)
kernels = kernels.transpose(2, 3,1,0)  # Shape: (5, 5, 1, 4)
print(check_winner(board, kernels))
board=board.squeeze(axis=3)
kernels=kernels.squeeze(axis=2)



(1, 6, 6, 4)
True


In [35]:
print(board.shape)
print(kernels.shape)

(1, 10, 10)
(5, 5, 4)


In [36]:
def get_valid_actions(board):
    return board == 0

print(get_valid_actions(board))

87


In [46]:
from jax import random
arr = random.normal(random.PRNGKey(0), shape=(10,10))

print(arr[None,:,None,:,None].shape)

(1, 10, 1, 10, 1)


In [48]:

class GomokuEnv:
    def __init__(self, board_size=15, num_envs=1, device=jax.devices('cpu')[0]):
        """
        Initializes the Gomoku environment.

        Args:
            board_size (int): Size of the game board (default is 15).
            num_envs (int): Number of parallel environments (default is 1).
            device: JAX device to run the computations on.
        """
        self.board_size = board_size
        self.num_envs = num_envs
        self.device = device
        self.win_length = 5
        self.board = jnp.zeros((num_envs,board_size+self.win_length-1, board_size+self.win_length-1), dtype=jnp.int8, device=device) #shape: (num_envs, board_size, board_size)
        #extra 0s for convolution

        self.current_player = jnp.ones((num_envs,), dtype=jnp.int8, device=device) #shape: (num_envs,) 1 for black, -1 for white
        self.game_over = jnp.zeros((num_envs,), dtype=bool, device=device) #shape: (num_envs,)

        # Initialize kernels for win detection
        self.kernels = self._create_kernels()

    def to(self,device=jax.devices('cpu')[0]):
        """
        Moves the environment to a different device.

        Args:
            device: JAX device to move the environment to.

        Returns:
            The environment object itself.
        """
        self.device = device
        self.board = self.board.to(device)
        self.current_player = self.current_player.to(device)
        self.game_over = self.game_over.to(device)
        self.kernels = self.kernels.to(device)
        return self

    def _create_kernels(self):
        """
        Creates kernels for horizontal, vertical, diagonal, and anti-diagonal win detection.

        Returns:
            jnp.ndarray: A tensor of shape (4, 1, 5, 5) containing the kernels.
        """
        ones = jnp.ones((1, self.win_length), dtype=jnp.int8)  # shape: (1, win_length)
        zeros = jnp.zeros((self.win_length-1, self.win_length), dtype=jnp.int8)  # shape: (win_length-1, win_length)

        horizontal = jnp.expand_dims(jnp.vstack([ones, zeros]), axis=(0,1))  # shape: (1, 1, win_length, win_length)
        vertical = jnp.expand_dims(jnp.vstack([ones, zeros]).T, axis=(0,1))  # shape: (1, 1, win_length, win_length)
        diagonal = jnp.expand_dims(jnp.eye(self.win_length, dtype=jnp.int8), axis=(0, 1))  # shape: (1, 1, win_length, win_length)
        anti_diagonal = jnp.expand_dims(jnp.fliplr(jnp.eye(self.win_length, dtype=jnp.int8)), axis=(0, 1))  # shape: (1, 1, win_length, win_length)

        # Stack all kernels
        kernels = jnp.concatenate([horizontal, vertical, diagonal, anti_diagonal], axis=0)[jnp.newaxis,:,:,:].transpose(2,3,0,1)
        return kernels  # Shape: (win_length, win_length, 1, 4)

    def reset(self, env_indices=None):
        """
        Resets the environment(s).

        Args:
            env_indices (array-like, optional): Specific environments to reset. 
                                                If None, all environments are reset.

        Returns:
            Tuple of (board, current_player, game_over) after reset.
        """
        if env_indices is None:
            self.board = jnp.zeros((self.num_envs, self.board_size, self.board_size), dtype=jnp.int8, device=self.device)
            self.current_player = jnp.ones((self.num_envs,), dtype=jnp.int8, device=self.device)
            self.game_over = jnp.zeros((self.num_envs,), dtype=bool, device=self.device)
        else:
            self.board = self.board.at[env_indices].set(
                jnp.zeros((len(env_indices), self.board_size, self.board_size), dtype=jnp.int8, device=self.device)
            )
            self.current_player = self.current_player.at[env_indices].set(
                jnp.ones((len(env_indices),), dtype=jnp.int8, device=self.device)
            )
            self.game_over = self.game_over.at[env_indices].set(
                jnp.zeros((len(env_indices),), dtype=bool, device=self.device)
            )
        
        return self.get_state()

    def step(self, actions: jnp.ndarray):
        """
        Performs a step in multiple environments by applying the given actions.

        Args:
            actions (jnp.ndarray): Array of shape (num_envs, 2) where each row is (row, col) indices 
                                   for the corresponding environment's action.

        Returns:
            Tuple of (board, current_player, game_over) after the actions.
        """
        env_indices = jnp.arange(self.num_envs)
        rows, cols = actions[:, 0], actions[:, 1]

        # Ensure the actions are within the board
        valid_bounds = (0 <= rows) & (rows < self.board_size) & (0 <= cols) & (cols < self.board_size)

        # Check if the positions are already taken
        positions = self.board[env_indices, rows, cols]
        valid_positions = positions == 0

        # Combine validity checks
        valid_actions = valid_bounds & valid_positions & ~self.game_over

        if not jnp.all(valid_actions):
            invalid_envs = jnp.where(~valid_actions)[0]
            raise ValueError(f"Invalid actions in environments: {invalid_envs}")

        # Update the board
        self.board = self.board.at[env_indices, rows, cols].set(self.current_player)

        # Check for win or draw using convolution
        winners, dones = self.check_game_over()

        # Update game_over based on winners and draws
        self.game_over = jnp.logical_or(self.game_over, dones | winners)

        # Switch players
        self.current_player = self.current_player.at[env_indices].set(-self.current_player)

        return self.get_state()

    def check_game_over(self):
        """
        Checks if the latest actions resulted in a win or draw.

        Returns:
            Tuple of (winners, dones) where:
                - winners is a boolean array indicating if the current move caused a win.
                - dones is a boolean array indicating if the game is a draw.
        """
        # Prepare the board for convolution
        # Shape: (num_envs, 1, board_size, board_size)
        player_boards = (self.board * self.current_player[:,jnp.newaxis,jnp.newaxis]).astype(jnp.int8)[:,:,:,jnp.newaxis]

        # Perform convolution
        conv_output = lax.conv_general_dilated(
            player_boards,
            self.kernels,
            window_strides=(1, 1),
            padding='VALID',
            dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
            feature_group_count=1
        )  # Shape: (num_envs, board_size - 4, board_size - 4,4)

        win_condition = conv_output == self.win_length
        winners = jnp.any(win_condition, axis=(1, 2, 3))  # Shape: (num_envs,)

        # Check for draw (no empty spaces left)
        empty_spaces = jnp.any(self.board == 0, axis=(1, 2))
        dones = ~empty_spaces

        return winners, dones

    def get_state(self):
        """
        Retrieves the current state of the environment(s).

        Returns:
            Tuple of (board, current_player, game_over).
        """
        return self.board, self.current_player, self.game_over


In [None]:
game = GomokuEnv(num_envs=20)

for i in range(100):
    action_mask = game.get_action_mask()
    

In [7]:
from jax import lax
import jax.numpy as jnp
WIN_LENGTH = 5
def _create_kernels():
    """
    Creates the kernels to check win conditions.
    """
    ones = jnp.ones((1, WIN_LENGTH), dtype=jnp.float32)  # shape: (1, win_length)
    zeros = jnp.zeros((WIN_LENGTH-1, WIN_LENGTH), dtype=jnp.float32)  # shape: (win_length-1, win_length)

    horizontal = jnp.expand_dims(jnp.vstack([ones, zeros]), axis=(0,1))  # shape: (1, 1, win_length, win_length)
    vertical = jnp.expand_dims(jnp.vstack([ones, zeros]).T, axis=(0,1))  # shape: (1, 1, win_length, win_length)
    diagonal = jnp.expand_dims(jnp.eye(WIN_LENGTH, dtype=jnp.float32), axis=(0, 1))  # shape: (1, 1, win_length, win_length)
    anti_diagonal = jnp.expand_dims(jnp.fliplr(jnp.eye(WIN_LENGTH, dtype=jnp.float32)), axis=(0, 1))  # shape: (1, 1, win_length, win_length)

    # Stack all kernels
    kernels = jnp.concatenate([horizontal, vertical, diagonal, anti_diagonal], axis=0).transpose(2,3,1,0)
    return kernels  # Shape: (win_length, win_length, 1, 4)


In [42]:
from jax import lax
WIN_LENGTH = 5


def _check_win(board,kernels):
    """
    Checks if placing a stone at (row, col) leads to a win for the given player.
    A win is defined as having 5 or more consecutive stones in any direction.
    """
    # Prepare the board for convolution
    # Shape: (num_envs, board_size, board_size, 1)

    # Perform convolution
    conv_output = lax.conv_general_dilated(
        board,
        kernels,
        window_strides=(1, 1),
        padding='SAME',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
        feature_group_count=1
    )  # Shape: (active_boards_shape, board_size, board_size,4)

    win_condition = conv_output == WIN_LENGTH
    print(win_condition.shape)
    winners = jnp.any(win_condition)
    return winners.item()


In [43]:
kernels = _create_kernels()
print(kernels.shape)

(5, 5, 1, 4)


In [44]:
import jax
board = jax.random.randint(key=jax.random.PRNGKey(0),shape=(10,10),minval=0,maxval=2).astype(jnp.float32)
boards = board[jnp.newaxis,:,:,jnp.newaxis]
_check_win(boards,kernels)

(1, 10, 10, 4)


True

In [5]:
import jax.numpy as jnp

arr =  jnp.array([[True,True],[True,True]])

if jnp.all(arr):
    print("all true")
else:
    print("not all true")

all true


In [8]:
import jax
jax.tree.map(lambda x, y: [x] + y, [5, 6,8], [[7, 9], [1, 2],[0]])

[[5, 7, 9], [6, 1, 2], [8, 0]]

In [23]:
import jax
import jax.numpy as jnp
from jax import lax

WIN_LENGTH = 5

def _create_kernels():
    ones = jnp.ones((1, WIN_LENGTH), dtype=jnp.float32)
    zeros = jnp.zeros((WIN_LENGTH - 1, WIN_LENGTH), dtype=jnp.float32)

    horizontal = jnp.expand_dims(jnp.vstack([ones, zeros]), axis=(0, 1))
    vertical = jnp.expand_dims(jnp.vstack([ones, zeros]).T, axis=(0, 1))
    diagonal = jnp.expand_dims(jnp.eye(WIN_LENGTH, dtype=jnp.float32), axis=(0, 1))
    anti_diagonal = jnp.expand_dims(jnp.fliplr(jnp.eye(WIN_LENGTH, dtype=jnp.float32)), axis=(0, 1))

    kernels = jnp.concatenate([horizontal, vertical, diagonal, anti_diagonal], axis=0).transpose(2, 3, 1, 0)
    return kernels

def _check_win(arr,kernels):
    """
    Checks if the current board state has a winning condition.
    Uses convolution with predefined kernels.
    Returns a JAX boolean (without using .item()).
    """
    player_boards = (arr * -1)[jnp.newaxis, :, :, jnp.newaxis]
    padding = ((WIN_LENGTH - 1, WIN_LENGTH - 1), (WIN_LENGTH - 1, WIN_LENGTH - 1))

    conv_output = lax.conv_general_dilated(
        player_boards,
        kernels,
        window_strides=(1, 1),
        padding=padding,
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
    )
    win_condition = conv_output == WIN_LENGTH
    win = jnp.any(win_condition)
    return win  # Do not call .item(), just return the traced value


arr = jnp.array([[-1.,  0.,  0.,  0.,  0. -1. -1. -1.  1.],
 [-1., -1., -1.,  0.,  1.,  1., -1.,  1.,  0.],
 [ 0.,  1.,  0.,  0., -1., -1., -1.,  1.,  0.],
 [-1.,  1., -1., -1.,  1., -1., -1.,  1.,  1.],
 [-1.,  1.,  1., -1., -1.,  1.,  1.,  0.,  0.],
 [ 1.  0.  1. -1.  1. -1.  1. -1.  0.]
 [-1.  1.  1.  1. -1.  1. -1.  1.  0.]
 [ 0.  1.  0.  1. -1. -1.  1.  1.  1.]
 [ 1.  0.  1. -1. -1. -1.  0.,  1.,  0.]])

arr = arr.reshape(9,9)

kernels = _create_kernels()


_check_win(arr,kernels)














Array(True, dtype=bool)