In [None]:
import jax

from jax import jit,lax,numpy as jnp

@jit
def discount_rewards(rewards, gamma):
    """
    Calculate discounted rewards (returns-to-go) for batched trajectories.

    Each column of `rewards` is treated as one trajectory. The value at
    time step ``t`` follows the backward recurrence
    ``G[t] = rewards[t] + gamma * G[t + 1]`` with ``G[T] = 0``, i.e.
    ``G[t] = sum_k gamma**k * rewards[t + k]``.

    Args:
      rewards: a 2D jnp.array with shape (T, batch); axis 0 is time.
      gamma: discount factor.

    Returns:
      A jnp.array of discounted rewards with the same shape as `rewards`.
    """

    def accumulate(running_return, reward):
        # One step of the backward recurrence; the updated value is both the
        # carry for the next (earlier) step and the output for this step.
        updated = reward + gamma * running_return
        return updated, updated

    def discount_single(trajectory):
        # reverse=True walks from the last step to the first while keeping
        # the stacked outputs aligned with the original time order, so no
        # manual [::-1] flips are needed.
        _, discounted = lax.scan(accumulate, 0.0, trajectory, reverse=True)
        return discounted

    # Map over the batch axis (axis 1) so each trajectory is scanned
    # independently; out_axes=1 restores the (T, batch) layout.
    return jax.vmap(discount_single, in_axes=1, out_axes=1)(rewards)
