In [1]:
import chex
import jax
import jax.numpy as jnp
from jumanji.environments.logic.game_2048.env import Game2048
from jumanji.environments.logic.game_2048.types import Board
from jumanji.environments.logic.game_2048.utils import move_up, move_right, move_down, move_left

In [2]:
SEED = 0
SIZE = 4
key = jax.random.PRNGKey(SEED)
env = Game2048(SIZE)
state, timestep = env.reset(key)

In [3]:
def transform_board(board: Board, symmetry: int, axis: 0 | 1 = 0) -> Board:
    if abs(symmetry) >= 8:
        raise ValueError
    if axis != 0 and axis != 1:
        raise ValueError
    board = jnp.flip(board, axis) if symmetry >= 4 else board
    board = jnp.rot90(board, symmetry)
    board = jnp.flip(board, axis) if symmetry <= -4 else board
    return board

In [4]:
assert all((state.board == transform_board(transform_board(state.board, symmetry), -symmetry)).all() for symmetry in range(8))

In [5]:
def transform_actions(actions: chex.Array, symmetry: int, axis: 0 | 1 = 0) -> chex.Array:
    if abs(symmetry) >= 8:
        raise ValueError
    if axis != 0 and axis != 1:
        raise ValueError
    mask1 = jnp.array((axis, axis + 2))
    mask2 = jnp.array((axis + 2, axis))
    actions = actions.at[mask1].set(actions[mask2]) if symmetry <= -4 else actions
    actions = actions[(jnp.arange(4) - symmetry) % 4]
    actions = actions.at[mask1].set(actions[mask2]) if symmetry >= 4 else actions
    return actions

In [6]:
actions = jnp.arange(4)
assert all((actions == transform_actions(transform_actions(actions, symmetry, 0), -symmetry, 0)).all() for symmetry in range(8))
assert all((actions == transform_actions(transform_actions(actions, symmetry, 1), -symmetry, 1)).all() for symmetry in range(8))

In [16]:
jit_transform_actions = jax.jit(transform_actions, static_argnums=(1, 2))
actions = jnp.zeros(4).at[0].set(1)
for symmetry in range(8):
    print(jit_transform_actions(actions, symmetry))

[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 0. 1.]
