In [None]:
import jax
import jax.numpy as jnp

@jax.jit
def shuffle_last_two_dims(key, x):
    """Randomly permute the spatial (last two) positions of each sample.

    Every sample along axis 0 gets its own permutation of the ``h * w``
    flattened positions, drawn from a key split off ``key``. Within a sample
    the same permutation is applied to every channel, so values at a given
    position stay aligned across channels.

    Args:
        key: JAX PRNG key.
        x: Array of shape ``(n_batches, c, h, w)``. Earlier versions only
            worked for ``c == 1``; any channel count is now accepted, and the
            ``c == 1`` result is unchanged.

    Returns:
        Array with the same shape and dtype as ``x``.
    """
    n_batches, n_channels, h, w = x.shape
    n_pixels = h * w

    # One independent permutation of the flattened pixel positions per sample.
    sample_keys = jax.random.split(key, n_batches)
    pixel_positions = jnp.arange(n_pixels)
    perms = jax.vmap(
        lambda k: jax.random.permutation(k, pixel_positions, independent=True)
    )(sample_keys)  # shape (n_batches, n_pixels)

    # Flatten (h, w) so a single index vector can reorder all positions at once.
    x_flat = x.reshape(n_batches, n_channels, n_pixels)

    # Per sample, gather along the pixel axis; all channels share that sample's perm.
    x_shuffled = jax.vmap(lambda x_b, perm: x_b[:, perm])(x_flat, perms)

    # Restore the original (n_batches, c, h, w) layout.
    return x_shuffled.reshape(n_batches, n_channels, h, w)




In [None]:
# Simulating a training loop
def training_loop(n_steps, x, seed=42):
    """Simulate a training loop that reshuffles ``x`` at every step.

    A fresh subkey is split off a running PRNG key on each step, so every
    step sees a different shuffle while the whole run is reproducible for a
    given ``seed``.

    Args:
        n_steps: Number of simulated steps.
        x: Input passed to ``shuffle_last_two_dims`` on each step.
        seed: Seed for the base PRNG key (defaults to 42, the value that was
            previously hard-coded).
    """
    key = jax.random.PRNGKey(seed)

    for step in range(n_steps):
        # Keep `key` as the carry and consume only `subkey`, so no key is reused.
        key, subkey = jax.random.split(key)
        shuffled_x = shuffle_last_two_dims(subkey, x)
        # Placeholder for a real training step (forward pass, loss, backprop).
        print(f"Step {step} shuffled_x: \n", shuffled_x)
        
# Example usage: a toy batch of 2 single-channel 4x4 inputs, shuffled over 5 steps.
x = jnp.reshape(jnp.arange(2 * 1 * 4 * 4), (2, 1, 4, 4))
training_loop(n_steps=5, x=x)
