In [1]:
import jax
print(jax.__version__)

0.4.30


In [2]:
print(jax.devices())
print(jax.device_count())
print(jax.local_devices())
print(jax.default_device())

[cuda(id=0)]
1
[cuda(id=0)]
<contextlib._GeneratorContextManager object at 0x7f01e35754c0>


In [3]:
import jax.numpy as jnp
arr = jnp.array([1, 2, 3])
print(arr)
# element-wise operations
sqrs = arr * arr
print(sqrs)
sqrs2 = jnp.square(arr)
print(sqrs2)

[1 2 3]
[1 4 9]
[1 4 9]


### JAX functions and Automatic Differentiation

In [4]:
def square(x):
  """ Square a number """
  return x ** 2

n = 4.
# Calculate the square of n with autodiff!
grad_square = jax.grad(square) # Get the gradient function
square_value = square(n)
square_grad = grad_square(n) # Calculate the gradient at x = 5.0 - must be float or array for autodiff
print(f"Square of 5: {square_value}, Gradient at 5: {square_grad}")

Square of 5: 16.0, Gradient at 5: 8.0


In [5]:
# https://www.youtube.com/watch?v=2uk_pvndOMw

from prettytable import PrettyTable

def f(x):
  return x**4 + 3*x**3 - 36*x**2 - 68*x + 240

n_array = [-7., -6.5, -6., -5.118, -0.9]

t = PrettyTable(["x", "f(x)", "f'(x)"])
for n in n_array:
  grad_f = jax.grad(f)
  t.add_row([n, f(n), grad_f(n)])
print(t)

+--------+--------------------+---------------+
|   x    |        f(x)        |     f'(x)     |
+--------+--------------------+---------------+
|  -7.0  |       324.0        |     -495.0    |
|  -6.5  |      122.1875      |    -318.25    |
|  -6.0  |        0.0         |     -176.0    |
| -5.118 | -71.01711857822397 | -0.0006713867 |
|  -0.9  |      270.5091      |   1.1739953   |
+--------+--------------------+---------------+


### Vectorized operations, calculations are performed on entire arrays simultaneously 

In [6]:
x = jnp.arange(10) # Array of numbers 0 to 9
y = jnp.ones(10)*2 # array of 10 elements, all set to 2
# vectorized addition
z = x + y
print(x)
print(y)
print(z)

[0 1 2 3 4 5 6 7 8 9]
[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
[ 2.  3.  4.  5.  6.  7.  8.  9. 10. 11.]
