In [1]:
import jax
import numpy as np

print(jax.__version__)

0.4.30


In [2]:
print(jax.devices())
print(jax.device_count())
print(jax.local_devices())
print(jax.default_device())

[cuda(id=0)]
1
[cuda(id=0)]
<contextlib._GeneratorContextManager object at 0x7f25b0414a70>


In [6]:
ar1 = [x for x in range(10)]
print(f"ar1: {ar1}")
npr1 = np.array(ar1)
print(f"npr1: {npr1}")

ar1: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
npr1: [0 1 2 3 4 5 6 7 8 9]


In [8]:
print(f"ar1 * 2: {ar1 * 2}")
print(f"npr1 * 2: {npr1 * 2}")

ar1 * 2: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
npr1 * 2: [ 0  2  4  6  8 10 12 14 16 18]


In [14]:
ar2 = [x for x in range(10) if x % 2 == 0]
print(f"ar2: {ar2}")
npr2 = np.array(ar2)
print(f"npr2: {npr2}")
print(f"ar1 + ar2: {ar1 + ar2}")
# print(f"npr1 + npr2: {npr1 + npr2}") # ValueError: operands could not be broadcast together with shapes (10,) (5,)
npr3 = np.array([x for x in range(10) if x % 2 > 0])
print(f"npr3: {npr3}")
print(f"npr1.shape: {npr1.shape}, npr2.shape: {npr2.shape}, npr3.shape: {npr3.shape}")
print(f"npr2 + npr3: {npr2 + npr3}")

ar2: [0, 2, 4, 6, 8]
npr2: [0 2 4 6 8]
ar1 + ar2: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 2, 4, 6, 8]
npr3: [1 3 5 7 9]
npr1.shape: (10,), npr2.shape: (5,), npr3.shape: (5,)
npr2 + npr3: [ 1  5  9 13 17]


In [26]:
npr = np.array([x for x in range(10) if x % 2 == 0])
print(f"npr: {npr}, npr.shape: {npr.shape}")
npr_row = np.array([npr])
npr_column = np.array([npr]).T
print(f"npr_row: {npr_row}, npr_row.shape: {npr_row.shape}")
print(f"npr_column: {npr_column}, npr_column.shape: {npr_column.shape}")
print(npr_row + npr_column)
print(npr_row * npr_column)

npr: [0 2 4 6 8], npr.shape: (5,)
npr_row: [[0 2 4 6 8]], npr_row.shape: (1, 5)
npr_column: [[0]
 [2]
 [4]
 [6]
 [8]], npr_column.shape: (5, 1)
[[ 0  2  4  6  8]
 [ 2  4  6  8 10]
 [ 4  6  8 10 12]
 [ 6  8 10 12 14]
 [ 8 10 12 14 16]]
[[ 0  0  0  0  0]
 [ 0  4  8 12 16]
 [ 0  8 16 24 32]
 [ 0 12 24 36 48]
 [ 0 16 32 48 64]]


In [31]:
npa1 = np.array([x for x in range(10)])
print(f"npa1: {npa1}")
print(f"npa1[2:6]: {npa1[2:6]}")

npa1: [0 1 2 3 4 5 6 7 8 9]
npa1[2:6]: [2 3 4 5]


In [45]:
print(f"npr_row[0][1:3]: {npr_row[0][1:3]}")
print(f"npr_column[1:4][0:3]: {npr_column[1:4][0:3]}")

npr_row[0][1:3]: [2 4]
npr_column[1:4][0:3]: [[2]
 [4]
 [6]]


In [48]:
npa2 = np.arange(12)
print(f"npa2: {npa2}")
print(f"npa2.reshape(3,4): \n{npa2.reshape(3,4)}")

npa2: [ 0  1  2  3  4  5  6  7  8  9 10 11]
npa2.reshape(3,4): 
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]


In [3]:
import jax.numpy as jnp
arr = jnp.array([1, 2, 3])
print(arr)
# element-wise operations
sqrs = arr * arr
print(sqrs)
sqrs2 = jnp.square(arr)
print(sqrs2)

[1 2 3]
[1 4 9]
[1 4 9]


### JAX functions and Automatic Differentiation

In [4]:
def square(x):
  """ Square a number """
  return x ** 2

n = 4.
# Calculate the square of n with autodiff!
grad_square = jax.grad(square) # Get the gradient function
square_value = square(n)
square_grad = grad_square(n) # Calculate the gradient at x = 5.0 - must be float or array for autodiff
print(f"Square of 5: {square_value}, Gradient at 5: {square_grad}")

Square of 5: 16.0, Gradient at 5: 8.0


In [5]:
# https://www.youtube.com/watch?v=2uk_pvndOMw

from prettytable import PrettyTable

def f(x):
  return x**4 + 3*x**3 - 36*x**2 - 68*x + 240

n_array = [-7., -6.5, -6., -5.118, -0.9]

t = PrettyTable(["x", "f(x)", "f'(x)"])
for n in n_array:
  grad_f = jax.grad(f)
  t.add_row([n, f(n), grad_f(n)])
print(t)

+--------+--------------------+---------------+
|   x    |        f(x)        |     f'(x)     |
+--------+--------------------+---------------+
|  -7.0  |       324.0        |     -495.0    |
|  -6.5  |      122.1875      |    -318.25    |
|  -6.0  |        0.0         |     -176.0    |
| -5.118 | -71.01711857822397 | -0.0006713867 |
|  -0.9  |      270.5091      |   1.1739953   |
+--------+--------------------+---------------+


### Vectorized operations, calculations are performed on entire arrays simultaneously 

In [6]:
x = jnp.arange(10) # Array of numbers 0 to 9
y = jnp.ones(10)*2 # array of 10 elements, all set to 2
# vectorized addition
z = x + y
print(x)
print(y)
print(z)

[0 1 2 3 4 5 6 7 8 9]
[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
[ 2.  3.  4.  5.  6.  7.  8.  9. 10. 11.]
