In [2]:
import numpy as np
import jax.numpy as jnp

# Define the number of units per team and the range of target coordinates
num_units = 16
k = 5  # Define the range of target_x and target_y

# Create grid of target coordinates within the specified range
target_coords = jnp.mgrid[-k:k+1, -k:k+1].reshape(2, -1).T

# Assuming team_positions and opponent_positions are already defined:
# team_positions, opponent_positions shapes are both (2, 16, 2)
# Example initialization (random positions within a 24x24 grid):
team_positions = jnp.array(np.random.randint(0, 24, (2, num_units, 2)))
opponent_positions = jnp.array(np.random.randint(0, 24, (2, num_units, 2)))

# Calculate all possible attack positions for each team's units
attack_positions = team_positions[:, :, None, :] + target_coords[None, None, :, :]
# Shape of attack_positions: (2, 16, (2*k+1)^2, 2)

# Expand opponent positions for broadcasting comparison
opponent_positions_expanded = opponent_positions[:, None, :, None, :]
# Shape becomes (2, 1, 16, 16, 2) to compare with attack positions

# Calculate if any attack position matches any opponent position
hits = jnp.all(attack_positions[..., None, :] == opponent_positions_expanded, axis=-1)
# hits shape: (2, 16, (2*k+1)^2, 16), where the last dimension checks across all opponent positions

# Aggregate hits to form a mask where True indicates a valid attack target
valid_targets = jnp.any(hits, axis=-1)
# valid_targets shape: (2, 16, (2*k+1)^2), True if any attack hits an opponent

# Output the valid targets mask
print("Valid Targets Mask:", valid_targets)


TypeError: eq got incompatible shapes for broadcasting: (2, 16, 121, 1, 2), (2, 1, 16, 1, 2).

In [3]:
target_coords = jnp.mgrid[-k:k+1, -k:k+1].reshape(2, -1).T

In [5]:
target_coords.shape

(121, 2)

In [13]:
import jax.numpy as jnp

# Assuming you define the maximum attack range
max_attack_range = 5.0  # Maximum distance units can attack

# Assuming team_positions and opponent_positions are already defined:
# Example initialization (random positions within a 24x24 grid):
team_positions = jnp.array(np.random.randint(0, 24, (2, 16, 2)))
opponent_positions = jnp.array(np.random.randint(0, 24, (2, 16, 2)))

# Calculate squared distances between each team unit and each opponent unit
# Broadcasting to create matrices of positions for subtraction and squaring
diff = team_positions[:, :, None, :] - opponent_positions[:, None, :, :]
distances_squared = jnp.sum(diff**2, axis=-1)

# Check which distances are within the squared attack range
# This avoids the need for computing a square root, making it more efficient
valid_targets = distances_squared <= max_attack_range ** 2

# valid_targets shape: (2, 16, 16), where True indicates a valid attack target within range
# print("Valid Targets Mask:", valid_targets)


In [30]:
(abs(diff[0, 0, :, 0]) <= max_attack_range) & (abs(diff[0, 0, :, 1]) <= max_attack_range)

Array([False,  True,  True, False, False, False, False, False, False,
        True, False, False, False, False, False, False], dtype=bool)

In [31]:
diff[0, 0, :, :]

Array([[ -1,  -7],
       [  4,   2],
       [  3,   3],
       [ -7,   7],
       [  5,  -6],
       [  9,   3],
       [ -7,  -6],
       [  6,   5],
       [  8,  -5],
       [  2,   5],
       [  9, -10],
       [  5, -12],
       [  5, -10],
       [ -8,  -8],
       [ -9, -15],
       [  8,  -7]], dtype=int32)

In [251]:
MAX_SAP_RANGE = 4

In [316]:
team_positions = jnp.array([[[100, 100],
        [15, 18],
        [12, 11],
        [ -2,  -2]]])

In [317]:
opponent_positions = jnp.array([[[5,  5],
        [ 7, 23],
        [ 1, 18],
        [13, 16],
        [12, 16],
        [ 3,  1]]])

In [318]:
adjusted_opponent_positions = opponent_positions + MAX_SAP_RANGE
adjusted_opponent_positions

Array([[[ 9,  9],
        [11, 27],
        [ 5, 22],
        [17, 20],
        [16, 20],
        [ 7,  5]]], dtype=int32)

In [319]:
diff = -team_positions[:, :, None, :] + adjusted_opponent_positions[:, None, :, :]
diff

Array([[[[-91, -91],
         [-89, -73],
         [-95, -78],
         [-83, -80],
         [-84, -80],
         [-93, -95]],

        [[ -6,  -9],
         [ -4,   9],
         [-10,   4],
         [  2,   2],
         [  1,   2],
         [ -8, -13]],

        [[ -3,  -2],
         [ -1,  16],
         [ -7,  11],
         [  5,   9],
         [  4,   9],
         [ -5,  -6]],

        [[ 11,  11],
         [ 13,  29],
         [  7,  24],
         [ 19,  22],
         [ 18,  22],
         [  9,   7]]]], dtype=int32)

In [320]:
diff = jnp.where(diff < 0, 20, diff)

In [321]:
diff[..., 0].shape

(1, 4, 6)

In [322]:
import jax.numpy as jnp
from jax import jit, vmap


# Function to set True for one row given indices
def set_true_row(bool_array, indices):
    # valid_indices = indices[indices >= 0]  # Filter out invalid indices
    return bool_array.at[indices].set(True)

# Vectorize the function across rows using vmap
def update_bool_array(bool_array, turn_ons):
    # vmap across the first axis (rows of turn_ons and bool_array)
    return vmap(set_true_row, in_axes=(0, 0), out_axes=0)(bool_array, turn_ons)

# Use JIT compilation for performance
update_bool_array_jit = jit(update_bool_array)

In [323]:
diff[..., 0].shape

(1, 4, 6)

In [324]:
jnp.squeeze(diff[..., 0], axis=0).shape

(4, 6)

In [325]:
bool_array = jnp.full((4, (2 * MAX_SAP_RANGE) + 1), False, dtype=bool)

In [326]:
attack_x = update_bool_array_jit(bool_array, jnp.squeeze(diff[..., 0], axis=0))
print(attack_x)

[[False False False False False False False False False]
 [False  True  True False False False False False False]
 [False False False False  True  True False False False]
 [False False False False False False False  True False]]


In [327]:
attack_y = update_bool_array_jit(bool_array, jnp.squeeze(diff[..., 1], axis=0))
print(attack_y)

[[False False False False False False False False False]
 [False False  True False  True False False False False]
 [False False False False False False False False False]
 [False False False False False False False  True False]]


In [342]:
(attack_x.sum(axis=1) > 0) & (attack_y.sum(axis=1) > 0) & ((jnp.squeeze(diff, axis=0).sum(-1) < (4 * MAX_SAP_RANGE)).sum(-1) > 0)

Array([False,  True, False, False], dtype=bool)

In [340]:
(jnp.squeeze(diff, axis=0).sum(-1) < (4 * MAX_SAP_RANGE)).sum(-1)

Array([0, 2, 2, 0], dtype=int32)

In [407]:
import jax
import jax.numpy as jnp
from jax import vmap

# Example logits3 array and ranges
rng = jax.random.PRNGKey(42)
# logits3 = jax.random.normal(rng, (2, 16, 17))  # Dummy data with the shape (2, 16, 17)
logits3 = jnp.ones((2, 16, 17))
ranges = jnp.array([5, 3])  # Example ranges a = 5, b = 3

# Function to set part of the array to -inf up to a certain index
def set_to_neg_inf(upto_idx, logits):
    # print(upto_idx)
    # Using JAX's numpy indexing to set values
    # logits = logits.at[:, :upto_idx].set(-jnp.inf)
    # logits = logits.at[:, -upto_idx:].set(-jnp.inf)
    return logits.at[:, : upto_idx].set(-jnp.inf)

# Using vmap to apply this function across the first dimension (different batches)
# The `in_axes` argument is set to (0, 0) to indicate that both logits and upto_idx are
# mapped over the first dimension.
mask_logits = vmap(set_to_neg_inf)(ranges, logits3)

print(mask_logits)


IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
  val = Array([5, 3], dtype=int32)
  batch_dim = 0, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

In [412]:
import jax
import jax.numpy as jnp

def mask_logits_slice(logits_slice, cutoff):
    # logits_slice: (16, 17)
    # cutoff: integer value like a or b
    cols = logits_slice.shape[1]
    # Create a mask: True where column index < cutoff
    mask = jnp.arange(cols) < cutoff
    # Apply mask to set values to -inf where mask is True
    return jnp.where(mask[None, :], -jnp.inf, logits_slice)

ranges = jnp.array([1, 4])
masked_logits3 = jax.vmap(mask_logits_slice, in_axes=(0, 0))(logits3, ranges)


In [413]:
masked_logits3

Array([[[-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1.,   1.],
        [-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1.,   1.],
        [-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1.,   1.],
        [-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1.,   1.],
        [-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1.,   1.],
        [-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1.,   1.],
        [-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1.,   1.],
        [-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1.,   1.],
        [-inf,   1.,   1.,   1.,

In [420]:
import jax
import jax.numpy as jnp

def mask_logits_slice(logits_slice, cutoff):
    # logits_slice: (16, 17)
    # cutoff: integer value like a or b
    cols = logits_slice.shape[1]
    # Create a mask: True where column index < cutoff
    mask = jnp.arange(cols) < cutoff
    # Apply mask to set values to -inf where mask is True
    logits_slice = jnp.where(mask[None, :], -jnp.inf, logits_slice)
    
    mask2 = jnp.arange(cols) > (16 - cutoff)
    logits_slice = jnp.where(mask2[None, :], -jnp.inf, logits_slice)
    return logits_slice

ranges = jnp.array([1, 4])
masked_logits3 = jax.vmap(mask_logits_slice, in_axes=(0, 0))(logits3, ranges)
masked_logits3

Array([[[-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1., -inf],
        [-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1., -inf],
        [-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1., -inf],
        [-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1., -inf],
        [-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1., -inf],
        [-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1., -inf],
        [-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1., -inf],
        [-inf,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.,   1.,   1.,   1.,   1.,   1., -inf],
        [-inf,   1.,   1.,   1.,

In [21]:
import jax
import jax.numpy as jnp

def transform_observation(obs):
    # Horizontal flip
    flipped = jnp.flip(obs, axis=1)
    
    # Rotate 90 degrees clockwise after flip
    rotated = jnp.rot90(flipped, k=-1, axes=(0, 1))
    
    return rotated

# Assuming obs is your 24x24x9 observation tensor
rng = jax.random.PRNGKey(42)
obs = jax.random.randint(rng, (5, 5), 0, 3)
transformed_obs = transform_observation(obs)


In [22]:
obs

Array([[1, 0, 1, 1, 2],
       [1, 1, 0, 2, 1],
       [1, 2, 0, 2, 1],
       [0, 1, 0, 1, 2],
       [2, 2, 0, 2, 2]], dtype=int32)

In [23]:
transformed_obs

Array([[2, 2, 1, 1, 2],
       [2, 1, 2, 2, 1],
       [0, 0, 0, 0, 1],
       [2, 1, 2, 1, 0],
       [2, 0, 1, 1, 1]], dtype=int32)

In [96]:
def transform_coordinates(x, y, map_size=24):
    # Adjust for horizontal flip
    new_x, new_y = map_size - 1 - x, y
    # Adjust for 90-degree rotation clockwise
    final_x, final_y = map_size - 1 - new_y, new_x
    
    return final_x, final_y

# Example usage:
x, y = 12,20 # Example starting position
transformed_x, transformed_y = transform_coordinates(x, y, 24)
print(f"Transformed coordinates: ({transformed_x}, {transformed_y})")


Transformed coordinates: (3, 11)


In [78]:
import jax.numpy as jnp

def transform_coordinates(unit_positions):
    # Constants
    MAP_WIDTH = 24
    MAP_HEIGHT = 24
    
    # Adjust for horizontal flip: (x, y) -> (MAP_WIDTH - 1 - x, y)
    flipped_positions = jnp.stack([MAP_WIDTH - 1 - unit_positions[:,:,0], unit_positions[:,:,1]], axis=-1)
    
    # Adjust for 90-degree rotation clockwise: (MAP_WIDTH - 1 - x, y) -> (y, MAP_WIDTH - 1 - x)
    rotated_positions = jnp.stack([MAP_HEIGHT - 1 - flipped_positions[:,:,1], flipped_positions[:,:,0]], axis=-1)
    
    return rotated_positions

# Assuming unit_positions_team is your 4x16x2 tensor
unit_positions_team = jnp.zeros((4, 16, 2))  # Example tensor for demonstration
transformed_positions = transform_coordinates(unit_positions_team)


In [None]:
input_positions = jnp.array([[[0, 0], [23, 23], [12, 0], [0, 23]],
                                [[23, 0], [0, 23], [12, 23], [23, 0]],
                                [[11, 11], [12, 12], [13, 13], [14, 14]],
                                [[0, 12], [23, 12], [12, 20], [12, 0]]])

# Expected output after horizontal flip and 90-degree rotation clockwise
expected_output = jnp.array([[[23, 23], [0, 0], [23, 11], [0, 23]],
                                [[23, 0], [0, 23], [0, 11], [23, 0]],
                                [[12, 12], [11, 11], [10, 10], [9, 9]],
                                [[11, 23], [11, 0], [3, 11], [23, 12]]])

# Running the transformation function
transformed_positions = transform_coordinates(input_positions)

# Assert that the transformed positions match expected output
jnp.array_equal(transformed_positions, expected_output)

Array(False, dtype=bool)

In [80]:
transformed_positions

Array([[[23, 23],
        [ 0,  0],
        [23, 11],
        [ 0, 23]],

       [[23,  0],
        [ 0, 23],
        [ 0, 11],
        [23,  0]],

       [[12, 12],
        [11, 11],
        [10, 10],
        [ 9,  9]],

       [[11, 23],
        [11,  0],
        [-1, 11],
        [23, 11]]], dtype=int32)