In [3]:
import jax
import jax.numpy as jnp

def extend_substring(tensor: jnp.ndarray, i: int) -> jnp.ndarray:
    """
    Extends a substring of consecutive integers (starting from 0) along the 
    last dimension of a tensor up to index i.

    Args:
        tensor: The input JAX array.
        i: The index up to which the substring should be extended.

    Returns:
        A new JAX array with the substring extended.
    """

    # Find the start index of the substring (the first occurrence of 0)
    start_index = jnp.where(tensor == 0, size=1, fill_value=tensor.shape[-1])[0][0]

    # Create the extended substring using jnp.arange
    extended_substring = jnp.arange(0, jnp.maximum(0, i - start_index + 1))

    # Create a mask to identify the region to be updated
    mask = jnp.arange(tensor.shape[-1]) >= start_index
    mask = mask & (jnp.arange(tensor.shape[-1]) <= i)

    # Update the tensor using the mask and the extended substring
    # We need to handle cases where the extended_substring is shorter than the mask
    # due to i being less than start_index. In such cases, we pad extended_substring.
    
    padded_extended_substring = jnp.pad(
        extended_substring,
        (0, jnp.maximum(0, jnp.sum(mask) - extended_substring.shape[0])),
        mode='constant'
    )

    # Use where to choose either from padded_extended_substring or the original tensor
    updated_tensor = jnp.where(
        mask,
        padded_extended_substring,
        tensor
    )
    return updated_tensor

!JAX_PLATFORMS=cpu
# Example Usage:
tensor = jnp.array([
    [10, 11, 12, 0, 1, 2, 3, 17, 18, 19],
    [20, 21, 0, 1, 2, 3, 4, 5, 28, 29],
    [30, 31, 32, 33, 34, 0, 1, 2, 38, 39]
])

# Extend up to index 7
i = 7
extended_tensor = extend_substring(tensor, i)
print(f"Original Tensor:\n{tensor}\n")
print(f"Extended Tensor (up to index {i}):\n{extended_tensor}")

# Extend up to index 2 (substring should not be extended)
i = 2
extended_tensor = extend_substring(tensor, i)
print(f"Original Tensor:\n{tensor}\n")
print(f"Extended Tensor (up to index {i}):\n{extended_tensor}")

# Extend up to index 5 (different for each row)
i = 5
extended_tensor = extend_substring(tensor, i)
print(f"Original Tensor:\n{tensor}\n")
print(f"Extended Tensor (up to index {i}):\n{extended_tensor}")

  pid, fd = os.forkpty()


RuntimeError: Unable to initialize backend 'tpu': ABORTED: The TPU is already in use by process with pid 402304. Not attempting to load libtpu.so in this process. (set JAX_PLATFORMS='' to automatically choose an available backend)