In [1]:
import numpy as np

import jax
import jax.numpy as jnp

# Show which JAX release is installed so the saved outputs can be matched to it.
print(jax.__version__)

0.4.31


In [2]:
# Mutating variant: in-place only when the array type is mutable
def mult_in_place(arr, factor):
    """Multiply ``arr`` by ``factor`` using augmented assignment and return the result.

    * NumPy ndarray (mutable): ``*=`` writes into the caller's buffer, so the
      input array itself changes and the returned object is that same array.
      NumPy applies same-kind casting here, so e.g. an int array times a float
      factor raises instead of upcasting.
    * JAX Array (immutable): ``*=`` falls back to ``arr = arr * factor``, which
      only rebinds the local name; the caller's array is left unchanged and a
      new array is returned.
    """
    arr *= factor
    return arr

# Pure variant: never touches its input
def mult(arr, factor):
    """Return ``arr * factor`` as a new array; ``arr`` itself is left unchanged."""
    product = arr * factor
    return product

In [7]:
# NumPy arrays are mutable: `mult` leaves its input alone, while
# `mult_in_place` rewrites the caller's array through `*=`.
factor = 2
input_nparr = np.array([1, 2, 3])

# Out-of-place multiply: a fresh array comes back, the input is unchanged.
new_nparr = mult(input_nparr, factor)
print(f"new_nparr after mult: {new_nparr}")  # [2 4 6]
print(f"input_nparr after mult: {input_nparr}")  # [1 2 3]

# In-place multiply: `*=` writes into input_nparr's own buffer, so the
# returned array and input_nparr are the same object and both read [2 4 6].
new_nprr2 = mult_in_place(input_nparr, factor)
print(f"new_nprr2 after mult_in_place: {new_nprr2}")  # [2 4 6]
print(f"input_nparr after mult_in_place: {input_nparr}")  # [2 4 6]

new_nparr after mult: [2 4 6]
input_nparr after mult: [1 2 3]
new_nprr2 after mult_in_place: [2 4 6]
input_nparr after mult_in_place: [2 4 6]


In [8]:
# JAX arrays are immutable: neither helper can change input_jarr.
factor = 2
input_jarr = jnp.array([1, 2, 3])

# Out-of-place multiply: a new array is returned, the input is untouched.
new_jarr = mult(input_jarr, factor)
print(f"new_jarr after mult: {new_jarr}")  # [2 4 6]
print(f"input_jarr after mult: {input_jarr}")  # [1 2 3]

# "In-place" multiply: on a JAX array `*=` behaves like `arr = arr * factor`,
# rebinding only the local name inside mult_in_place. The caller's
# input_jarr therefore still reads [1 2 3].
new_jarr2 = mult_in_place(input_jarr, factor)
print(f"new_jarr2 after mult_in_place: {new_jarr2}")  # [2 4 6]
print(f"input_jarr after mult_in_place: {input_jarr}")  # [1 2 3]

new_jarr after mult: [2 4 6]
input_jarr after mult: [1 2 3]
new_jarr2 after mult_in_place: [2 4 6]
input_jarr after mult_in_place: [1 2 3]


### map
Applies a function to all elements of an iterable (list, array, etc.) and returns a new iterable with the results.

In [9]:
"""
a. square is the function to be applied to each element.
b. jnp.vectorize(square)(data) applies square to each element of data using vectorization
(effciently applying a function to an array).
c. squared_data contains the squares of each element in data.

Same as jax.vmap(square)(data).

vectorize: the inputs are broadcast according to numpy broadcasting rules. 
vmap: the function is mapped across a single specific axis of the inputs.

More info at: https://stackoverflow.com/questions/69099847/jax-vectorization-vmap-and-or-numpy-vectorize

"""

def square(x):
    return x ** 2

data = jnp.array([1, 2, 3, 4])
squared_data_vmap = jax.vmap(square)(data)
squared_data_vectorized = jnp.vectorize(square)(data)

print(f"squared_data_vmap: {squared_data_vmap}") # [1 4 9 16]
print(f"squared_data_vectorized: {squared_data_vectorized}") # [1 4 9 16]

squared_data_vmap: [ 1  4  9 16]
squared_data_vectorized: [ 1  4  9 16]


### filter
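Creates a new iterable containing only the elements of the original iterable that pass a test (defined by a function). JAX needs array shapes to be known at trace time (e.g. under `jax.jit`). The cells below therefore use `jnp.where`, which keeps the original shape and replaces failing elements with `0` instead of dropping them. Outside `jit`, boolean-mask indexing such as `numbers[numbers % 2 == 0]` performs a true filter.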

In [33]:
numbers = jnp.array([1, 2, 3, 4, 5])
print(numbers)

# Mask-and-replace rather than drop: odd entries become 0, so the result
# keeps the same shape as `numbers`.
is_even = numbers % 2 == 0
print(jnp.where(is_even, numbers, 0))

[1 2 3 4 5]
[0 2 0 4 0]


In [31]:
numbers2 = jnp.arange(10)
print(numbers2)

# Keep even entries and replace odd ones with 0 (shape is preserved).
# 0 is itself even, so the leading 0 in the result is a kept element,
# whereas every other 0 is a replaced odd value.
evens_or_zero = jnp.where(numbers2 % 2 == 0, numbers2, 0)
print(evens_or_zero)

[0 1 2 3 4 5 6 7 8 9]
[0 0 2 0 4 0 6 0 8 0]


### reduce
`functools.reduce` applies a two-argument function repeatedly to accumulate a single value from an iterable.
It takes the function, the iterable, and an optional initial value that seeds the accumulator (used below as `0`).

In [35]:
from functools import reduce

def add(x, y):
    """Return x + y: the binary step that reduce applies cumulatively."""
    return x + y

numbers3 = jnp.arange(10)
# reduce(add, numbers3, 0) starts with the initial value 0 and repeatedly
# applies add to the running total and the next element of numbers3,
# producing the total sum. This is a Python-level loop over the array's
# elements, shown for illustration; jnp.sum(numbers3) is the vectorized
# equivalent.
total_sum = reduce(add, numbers3, 0)
# Cross-check the hand-rolled fold against JAX's built-in reduction.
assert total_sum == jnp.sum(numbers3)
print(total_sum) # 45

45


### grad
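Takes a Python function built from JAX operations and returns a new function that computes the gradient of
the original function with respect to its first positional argument (other arguments can be selected with `argnums`). The original function must return a scalar, and the differentiated inputs must be floating-point (or complex).  
<br>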
• Scalar input, scalar output:  
If f takes a single scalar input and returns a scalar output, the gradient is a scalar value.  
<br>
• Vector input, scalar output:  
If f takes a vector as input and returns a scalar output, the gradient is a vector with the same shape as the input vector.  
Each element is the partial derivative of the output with respect to the corresponding input element.  
<br>
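• Vector input, vector output:  
If f returns a vector, `jax.grad` raises a `TypeError`, because gradients are only defined for scalar-output functions.  
Use `jax.jacobian` (or `jax.jacfwd` / `jax.jacrev`) instead. It returns a matrix whose row *i* holds the gradient of output element *i* with respect to each element of the input vector. For an elementwise function such as `x ** 2`, you can instead map the scalar `grad` over the inputs with `jax.vmap` or `jnp.vectorize`, as the next cell does.  
<br>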

In [44]:
from jax import grad


def f(x):
    """Return x ** 2; its derivative is 2 * x."""
    return x ** 2


# grad(f) builds a new function that evaluates df/dx at a single scalar point.
f_grad = grad(f)

# A batch of points at which to evaluate the derivative. They are floats
# because jax.grad rejects integer-valued inputs.
input_val = jnp.arange(10, dtype=jnp.float32)
print(f"input_val: {input_val}") # [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]

# For a vector input f returns a vector, not a scalar, so grad(f) cannot take
# input_val directly. Instead, map the scalar derivative over each element.
per_element_grad = jnp.vectorize(f_grad)
gradients = per_element_grad(input_val)
print(gradients) # [ 0.  2.  4.  6.  8. 10. 12. 14. 16. 18.]

input_val: [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
[ 0.  2.  4.  6.  8. 10. 12. 14. 16. 18.]


In [45]:
def linear_function(x):
    """Return 2 * x + 1, an affine function whose slope is the constant 2."""
    return 2 * x + 1


# `grad` comes from the earlier `from jax import grad` cell.
# The derivative of an affine function equals its slope at every point, so
# the particular value of input_val does not affect the result.
grad_fn = grad(linear_function)

input_val = jnp.array(4.0)
print(f"input_val: {input_val}") # 4.0
print(grad_fn(input_val)) # 2.0

input_val: 4.0
2.0
