In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from timeit import timeit

In [2]:
# array
x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(jnp.mean(x)) # 2.5

2.5


In [3]:
# autograd
def f(x):
  return x**2 + 2*x + 1
grad_f = jax.grad(f) # 2*x + 2
print(grad_f(1.0)) # 4.0
value_and_gradient_f = jax.value_and_grad(f)
value, grad = value_and_gradient_f(1.0) # 4.0, 4.0
print(value, grad)

4.0
4.0 4.0


In [4]:
# partial differentia
def g(x, y):
  return x**2 + 2*x*y + 2*y**2
grad_gx = jax.grad(g, argnums=0) # 2*x + 2*y
print(grad_gx(2.0, 1.0)) # 6.0
grad_gy = jax.grad(g, argnums=1) # 2*x + 4*y
print(grad_gy(2.0, 1.0)) # 8.0

6.0
8.0


In [5]:
# Example: reinforcement learning model and Q-Learning algorithms
# Define a simple policy network function
def policy(params, obs):
    """A simple linear policy function: action = params * obs."""
    return params * obs

# Define the policy loss function
def policy_loss_fn(policy_params, log_alpha, obs):
    """
    Compute a policy loss based on a simplified function:
    policy_loss = mean(-action_value - exp(log_alpha) * entropy)
    """
    # Get action using the current policy
    new_action = policy(policy_params, obs)

    # Calculate Q values (dummy values here for simplicity)
    q1_value = 2 * new_action  # Simulated Q value 1
    q2_value = 3 * new_action  # Simulated Q value 2

    # Take the minimum of Q values to estimate the expected value
    q_mean = jnp.minimum(q1_value, q2_value)

    # Entropy (simulated as negative of absolute action)
    entropy = -jnp.abs(new_action)

    # Calculate the policy loss
    policy_loss = jnp.mean(-q_mean - jnp.exp(log_alpha) * entropy)

    return policy_loss

policy_params = jnp.array([0.5])  # Initial policy parameter
log_alpha = jnp.array(0.1)  # Log of the entropy coefficient
obs = jnp.array([1.0, 2.0, 3.0])  # Example observations

value_and_grad_fn = jax.value_and_grad(policy_loss_fn, argnums=(0, 1))  # (0, 1) indicates to compute gradients for both policy_params and log_alpha

# Calculate loss and gradients
policy_loss, (policy_grads, log_alpha_grads) = value_and_grad_fn(policy_params, log_alpha, obs)
print(f"Policy Loss: {policy_loss}")
print(f"Policy Params Gradient: {policy_grads}")
print(f"Log Alpha Gradient: {log_alpha_grads}")


Policy Loss: -0.8948290348052979
Policy Params Gradient: [-1.7896582]
Log Alpha Gradient: 1.1051709651947021


In [7]:
# vmap
x = jnp.array([1.0, 2.0, 3.0, 4.0])

def f(x):
  return x**2 + 2*x + 1

vmap_f = jax.vmap(f)
print(vmap_f(x)) # [4.0, 9.0, 16.0, 25.0]

def h(x, y):
  return x + y

vmap_h = jax.vmap(h, in_axes=(0, None))
y = 5.0
print(vmap_h(x, y)) # [6.0, 7.0, 8.0, 9.0]

[ 4.  9. 16. 25.]
[6. 7. 8. 9.]


In [17]:
# for
# function signature jax.lax.scan
# output, final_state = jax.lax.scan(f, init, xs)
# init: Initial state (or carry) of the loop.
# xs: Sequence (e.g., array) to iterate over.
# output: Results of all iterations.
# final_state: The final state of the carry variable.

# calculate cumsum
x = jnp.array([1, 2, 3, 4, 5])

def cumulative_sum(carry, x):
    new_carry = carry + x  # Update the carry by adding the current element
    y = new_carry          # The output is the updated carry (cumulative sum)
    return new_carry, y    # Return the new carry and the current result


init = 0

final_carry, cumulative_sums = jax.lax.scan(cumulative_sum, init, x)

# Print the results
print(f"Input sequence: {x}")
print(f"Cumulative sums: {cumulative_sums}") # intermediate results
print(f"Final carry (total sum): {final_carry}")

Input sequence: [1 2 3 4 5]
Cumulative sums: [ 1  3  6 10 15]
Final carry (total sum): 15


In [22]:
# if
def g(x):
  return jax.lax.cond(x>0, lambda x: x**2, lambda x: x+1, x)
print(g(-1.0))
print(g(2.0))

0.0
4.0


In [28]:
# seed
key = jax.random.PRNGKey(0) # special JAX array with shape (2,)
action_policy_key, q_key = jax.random.split(key, 2)
# Generate random values using each sub-key
random_action = jax.random.uniform(action_policy_key, shape=(3,))
random_q_values = jax.random.normal(q_key, shape=(3,))

print("Random Action Values:", random_action)
print("Random Q-Values:", random_q_values)

Random Action Values: [0.5552479  0.69474125 0.29765356]
Random Q-Values: [ 1.1378784  -1.2209548  -0.59153634]


In [29]:
# jax.jit Compiled with XLA
def poly_function(x):
    return jnp.sin(x)**2 + jnp.cos(x)**2 + x

# JIT-compiled version of the function
poly_function_jit = jax.jit(poly_function)

x = jnp.linspace(0, 100, 100_000)

non_jit_time = timeit(lambda: poly_function(x).block_until_ready(), number=100)
print(f"Average Time without JIT: {non_jit_time:.6f} seconds")

jit_time = timeit(lambda: poly_function_jit(x).block_until_ready(), number=100)
print(f"Average Time with JIT: {jit_time:.6f} seconds")

# Calculate speedup
speedup = non_jit_time / jit_time
print(f"Speedup using JIT: {speedup:.2f}x faster")

Average Time without JIT: 0.773352 seconds
Average Time with JIT: 0.196852 seconds
Speedup using JIT: 3.93x faster
