In [2]:
import numpy as np
import jax.numpy as jnp

# Define the number of units per team and the range of target coordinates
num_units = 16
k = 5  # Define the range of target_x and target_y

# Create grid of target coordinates within the specified range
target_coords = jnp.mgrid[-k:k+1, -k:k+1].reshape(2, -1).T

# Assuming team_positions and opponent_positions are already defined:
# team_positions, opponent_positions shapes are both (2, 16, 2)
# Example initialization (random positions within a 24x24 grid):
team_positions = jnp.array(np.random.randint(0, 24, (2, num_units, 2)))
opponent_positions = jnp.array(np.random.randint(0, 24, (2, num_units, 2)))

# Calculate all possible attack positions for each team's units
attack_positions = team_positions[:, :, None, :] + target_coords[None, None, :, :]
# Shape of attack_positions: (2, 16, (2*k+1)^2, 2)

# Expand opponent positions for broadcasting comparison
opponent_positions_expanded = opponent_positions[:, None, :, None, :]
# Shape becomes (2, 1, 16, 16, 2) to compare with attack positions

# Calculate if any attack position matches any opponent position
hits = jnp.all(attack_positions[..., None, :] == opponent_positions_expanded, axis=-1)
# hits shape: (2, 16, (2*k+1)^2, 16), where the last dimension checks across all opponent positions

# Aggregate hits to form a mask where True indicates a valid attack target
valid_targets = jnp.any(hits, axis=-1)
# valid_targets shape: (2, 16, (2*k+1)^2), True if any attack hits an opponent

# Output the valid targets mask
print("Valid Targets Mask:", valid_targets)


TypeError: eq got incompatible shapes for broadcasting: (2, 16, 121, 1, 2), (2, 1, 16, 1, 2).

In [3]:
target_coords = jnp.mgrid[-k:k+1, -k:k+1].reshape(2, -1).T

In [5]:
target_coords.shape

(121, 2)

In [13]:
import jax.numpy as jnp

# Assuming you define the maximum attack range
max_attack_range = 5.0  # Maximum distance units can attack

# Assuming team_positions and opponent_positions are already defined:
# Example initialization (random positions within a 24x24 grid):
team_positions = jnp.array(np.random.randint(0, 24, (2, 16, 2)))
opponent_positions = jnp.array(np.random.randint(0, 24, (2, 16, 2)))

# Calculate squared distances between each team unit and each opponent unit
# Broadcasting to create matrices of positions for subtraction and squaring
diff = team_positions[:, :, None, :] - opponent_positions[:, None, :, :]
distances_squared = jnp.sum(diff**2, axis=-1)

# Check which distances are within the squared attack range
# This avoids the need for computing a square root, making it more efficient
valid_targets = distances_squared <= max_attack_range ** 2

# valid_targets shape: (2, 16, 16), where True indicates a valid attack target within range
# print("Valid Targets Mask:", valid_targets)


In [30]:
(abs(diff[0, 0, :, 0]) <= max_attack_range) & (abs(diff[0, 0, :, 1]) <= max_attack_range)

Array([False,  True,  True, False, False, False, False, False, False,
        True, False, False, False, False, False, False], dtype=bool)

In [31]:
diff[0, 0, :, :]

Array([[ -1,  -7],
       [  4,   2],
       [  3,   3],
       [ -7,   7],
       [  5,  -6],
       [  9,   3],
       [ -7,  -6],
       [  6,   5],
       [  8,  -5],
       [  2,   5],
       [  9, -10],
       [  5, -12],
       [  5, -10],
       [ -8,  -8],
       [ -9, -15],
       [  8,  -7]], dtype=int32)

In [40]:
MAX_SAP_RANGE = 8

In [43]:
team_positions = jnp.array(np.random.randint(0, 24, (1, 4, 2)))
opponent_positions = jnp.array(np.random.randint(0, 24, (1, 4, 2)))

In [44]:
team_positions

Array([[[16,  9],
        [12,  2],
        [ 0,  4],
        [ 1,  3]]], dtype=int32)

In [45]:
opponent_positions

Array([[[22,  6],
        [15, 10],
        [ 3,  5],
        [22, 15]]], dtype=int32)

In [46]:
adjusted_opponent_positions = opponent_positions + MAX_SAP_RANGE
adjusted_opponent_positions

Array([[[30, 14],
        [23, 18],
        [11, 13],
        [30, 23]]], dtype=int32)

In [161]:
diff = -team_positions[:, :, None, :] + adjusted_opponent_positions[:, None, :, :]
diff

Array([[[[14,  5],
         [ 7,  9],
         [-5,  4],
         [14, 14]],

        [[18, 12],
         [11, 16],
         [-1, 11],
         [18, 21]],

        [[30, 10],
         [23, 14],
         [11,  9],
         [30, 19]],

        [[29, 11],
         [22, 15],
         [10, 10],
         [29, 20]]]], dtype=int32)

In [162]:
diff = np.array(diff)
diff[(diff < 0)] = 20

In [163]:
diff[..., 0].shape

(1, 4, 4)

In [165]:
import jax.numpy as jnp
from jax import jit, vmap

bool_array = jnp.full((4, 17), False, dtype=bool)

# Function to set True for one row given indices
def set_true_row(bool_array, indices):
    # valid_indices = indices[indices >= 0]  # Filter out invalid indices
    return bool_array.at[indices].set(True)

# Vectorize the function across rows using vmap
def update_bool_array(bool_array, turn_ons):
    # vmap across the first axis (rows of turn_ons and bool_array)
    return vmap(set_true_row, in_axes=(0, 0), out_axes=0)(bool_array, turn_ons)

# Use JIT compilation for performance
update_bool_array_jit = jit(update_bool_array)




In [None]:
attack_x = update_bool_array_jit(bool_array, jnp.squeeze(diff[..., 0], axis=0))
print(attack_x)

[[False False False False False False False  True False False False False
  False False  True False False]
 [False False False False False False False False False False False  True
  False False False False False]
 [False False False False False False False False False False False  True
  False False False False False]
 [False False False False False False False False False False  True False
  False False False False False]]


In [168]:
attack_y = update_bool_array_jit(bool_array, jnp.squeeze(diff[..., 1], axis=0))
print(attack_y)

[[False False False False  True  True False False False  True False False
  False False  True False False]
 [False False False False False False False False False False False  True
   True False False False  True]
 [False False False False False False False False False  True  True False
  False False  True False False]
 [False False False False False False False False False False  True  True
  False False False  True False]]


Array([[[14,  7, -5, 14],
        [18, 11, -1, 18],
        [30, 23, 11, 30],
        [29, 22, 10, 29]]], dtype=int32)