<a href="https://colab.research.google.com/github/Peter-obi/JAX/blob/main/jax_vs_numpy_difference_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## JAX arrays are immutable

In [40]:
import jax
import jax.numpy as jnp
import numpy as np

In [22]:
a_jnp = jnp.array(range(10))
a_np = np.array(range(10))

In [23]:
a_jnp[8], a_np[8]

(Array(8, dtype=int64), np.int64(8))

In [24]:
a_np[8] = 100
a_np[8]

np.int64(100)

In [25]:
try:
  a_jnp[8] = 100
except TypeError as e:
  print(e)

JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html


In [26]:
a_jnp = a_jnp.at[8].set(100)
a_jnp[8]

Array(100, dtype=int64)
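Besides `set`, the `.at[]` property offers other functional updates such as `add`, `multiply`, and `min`. The following is a minimal sketch with illustrative values (not taken from the notebook). Every call returns a new array, and `x` itself never changes:

```python
import jax.numpy as jnp

x = jnp.arange(5)                 # [0, 1, 2, 3, 4]

y_set = x.at[1].set(10)           # replace x[1]:          [0, 10, 2, 3, 4]
y_add = x.at[1].add(10)           # like x[1] += 10:       [0, 11, 2, 3, 4]
y_mul = x.at[1:3].multiply(2)     # slices work too:       [0, 2, 4, 3, 4]
y_min = x.at[4].min(1)            # x[4] = min(x[4], 1):   [0, 1, 2, 3, 1]

# The original array is untouched: x is still [0, 1, 2, 3, 4]
```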

## Out-of-bounds indexing: remember to control for this!

In [27]:
a_jnp[42]  # by default, out-of-bounds reads are clamped to the last valid index

Array(9, dtype=int64)

In [28]:
a_jnp.at[42].get(mode='drop')  # for reads, 'drop' returns a default fill value (the minimum value for signed ints)

Array(-9223372036854775808, dtype=int64)

In [29]:
a_jnp.at[42].get(mode='fill', fill_value=-1)  # out-of-bounds reads return the given fill_value

Array(-1, dtype=int64)

In [30]:
a_jnp = a_jnp.at[42].set(100)  # updates drop out-of-bounds writes by default, so nothing changes
a_jnp

Array([  0,   1,   2,   3,   4,   5,   6,   7, 100,   9], dtype=int64)

In [31]:
a_jnp = a_jnp.at[42].set(100, mode='clip')  # clamps index 42 to the last valid index (9)
a_jnp

Array([  0,   1,   2,   3,   4,   5,   6,   7, 100, 100], dtype=int64)
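To see both rules at once, here is a short sketch (array and indices chosen only for illustration) that mixes valid and out-of-bounds entries in one index array. Reads clamp by default and updates drop by default. Note that `mode='clip'` on an update can silently pile several writes onto the last element:

```python
import jax.numpy as jnp

x = jnp.arange(10)
idx = jnp.array([2, 9, 42])  # 42 is out of bounds

# Reads: by default 42 is clamped to 9
clamped = x[idx]                                    # [2, 9, 9]
# Reads with mode='fill': out-of-bounds positions get fill_value
filled = x.at[idx].get(mode='fill', fill_value=-1)  # [2, 9, -1]

# Updates: by default the out-of-bounds update is dropped
dropped = x.at[idx].add(100)                        # [0, 1, 102, 3, 4, 5, 6, 7, 8, 109]
# With mode='clip', 42 becomes 9, so index 9 receives two additions
clipped = x.at[idx].add(100, mode='clip')           # [0, 1, 102, 3, 4, 5, 6, 7, 8, 209]
```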

## Floating-point calculations

In [32]:
# JAX defaults to 32-bit types; to use 64-bit types (e.g. float64), enable jax_enable_x64 at startup
from jax import config
config.update("jax_enable_x64", True)

In [33]:
x = jnp.array(range(10), dtype=jnp.float64)
x.dtype

dtype('float64')
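The flag also changes the default dtypes of arrays created without an explicit `dtype`. The following is a minimal sketch that assumes a fresh Python session. Without the flag, requesting `float64` gives a `float32` array and a warning instead:

```python
import jax

# Set the flag at startup, before creating any JAX arrays
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# With the flag on, unannotated floats and ints default to 64-bit types
print(jnp.array(1.0).dtype)  # float64
print(jnp.arange(3).dtype)   # int64
```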

In [34]:
xb16 = jnp.array(range(10), dtype=jnp.bfloat16)
xb16.dtype

dtype(bfloat16)

In [35]:
x16 = jnp.array(range(10), dtype=jnp.float16)
x16.dtype

dtype('float16')
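The two 16-bit formats trade precision against range differently. `bfloat16` keeps `float32`'s exponent range but has only 8 significand bits. `float16` has 11 significand bits but overflows above 65504. The short sketch below uses `jnp.finfo` and a couple of illustrative values to show this:

```python
import jax.numpy as jnp

# Machine epsilon (precision) and largest finite value (range) per type
for dt in (jnp.bfloat16, jnp.float16, jnp.float32):
    info = jnp.finfo(dt)
    print(info.dtype, "eps =", float(info.eps), "max =", float(info.max))

# Precision: 1.001 rounds to 1.0 in bfloat16 but not in float16
print(float(jnp.array(1.001, dtype=jnp.bfloat16)))    # 1.0
print(float(jnp.array(1.001, dtype=jnp.float16)))     # 1.0009765625

# Range: 70000 overflows float16 but stays finite (if coarse) in bfloat16
print(jnp.array(70000.0, dtype=jnp.float16))          # inf
print(float(jnp.array(70000.0, dtype=jnp.bfloat16)))  # 70144.0
```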

In [36]:
xb16 + x16  # bfloat16 + float16 promotes to float32

Array([ 0.,  2.,  4.,  6.,  8., 10., 12., 14., 16., 18.], dtype=float32)
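This result follows JAX's type-promotion lattice. Neither 16-bit float type can hold all of the other's values, so both are promoted to `float32`. `jnp.promote_types` queries that lattice directly without building arrays, as in this minimal sketch:

```python
import jax.numpy as jnp

# Query JAX's promotion rules for pairs of dtypes
print(jnp.promote_types(jnp.bfloat16, jnp.float16))  # float32
print(jnp.promote_types(jnp.bfloat16, jnp.float32))  # float32
print(jnp.promote_types(jnp.int8, jnp.uint8))        # int16
```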

## jax.lax primitives

In [45]:
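def horizontal_flip_func(image, key):
  # key is unused, but lax.switch passes the same operands to every branch
  return jnp.flip(image, axis=1)

def rotate_func(image, key, k=2, axes=(0, 1)):
  # k=2 (180 degrees) is an assumed default that keeps the image shape, as lax.switch requires
  return jnp.rot90(image, k=k, axes=axes)

def add_noise_func(image, key, noise_level=0.1):
  # Draw noise in the image's dtype so every branch returns the same dtype
  noise = noise_level * jax.random.normal(key, image.shape, dtype=image.dtype)
  return image + noise

def adjust_colors_func(image, key, adjust_factor=1.2):
  # adjust_factor=1.2 is an assumed default; key is unused
  return jnp.clip(image * adjust_factor, 0.0, 1.0)

In [46]:
augmentations = [add_noise_func, horizontal_flip_func, rotate_func, adjust_colors_func]

In [47]:
from jax import random

def random_augmentation(image, augmentations, rng_key):
  """Apply one randomly chosen augmentation to an image."""
  # Separate subkeys: one chooses the branch, the other is passed to it (e.g. for noise)
  choice_key, aug_key = random.split(rng_key)
  # Random integer in [0, len(augmentations)) that indexes the list of branches
  augmentation_index = random.randint(choice_key, shape=(), minval=0, maxval=len(augmentations))
  # lax.switch calls augmentations[augmentation_index](image, aug_key); all branches must return the same shape and dtype
  augmented_image = jax.lax.switch(augmentation_index, augmentations, image, aug_key)
  return augmented_image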
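The following usage sketch applies `random_augmentation` to a small batch. It assumes every branch takes `(image, key)`. The image sizes, `PRNGKey(0)`, and the defaults `k=2` and `adjust_factor=1.2` are illustrative assumptions. The function list is closed over rather than passed to `jax.jit` as an argument, since a list of Python functions is not an array. Under `jax.vmap`, the batched index turns `lax.switch` into a select, so all four branches are computed for every image:

```python
import jax
import jax.numpy as jnp
from jax import random

# Every branch shares the signature (image, key), as lax.switch requires
def horizontal_flip_func(image, key):
    return jnp.flip(image, axis=1)

def rotate_func(image, key, k=2, axes=(0, 1)):
    return jnp.rot90(image, k=k, axes=axes)  # k=2 keeps the shape

def add_noise_func(image, key, noise_level=0.1):
    return image + noise_level * random.normal(key, image.shape, dtype=image.dtype)

def adjust_colors_func(image, key, adjust_factor=1.2):
    return jnp.clip(image * adjust_factor, 0.0, 1.0)

augmentations = [add_noise_func, horizontal_flip_func, rotate_func, adjust_colors_func]

def random_augmentation(image, augmentations, rng_key):
    choice_key, aug_key = random.split(rng_key)
    index = random.randint(choice_key, shape=(), minval=0, maxval=len(augmentations))
    return jax.lax.switch(index, augmentations, image, aug_key)

# Close over the Python list; only the image and key are traced
augment = jax.jit(lambda img, key: random_augmentation(img, augmentations, key))
# Map over a batch of images, one PRNG key per image
batched_augment = jax.vmap(augment)

images = jnp.linspace(0.0, 1.0, 4 * 8 * 8 * 3, dtype=jnp.float32).reshape(4, 8, 8, 3)
keys = random.split(random.PRNGKey(0), images.shape[0])
out = batched_augment(images, keys)
print(out.shape, out.dtype)  # (4, 8, 8, 3) float32
```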

In [42]:
# Type promotion: jnp.add promotes mixed int and float inputs automatically
jnp.add(42, 42.0)

Array(84., dtype=float64, weak_type=True)
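The `weak_type=True` in the output marks a value that came only from Python scalars. Weakly typed values adopt the dtype of the array they meet instead of widening it, whereas NumPy scalars carry a fixed dtype. A minimal sketch:

```python
import jax.numpy as jnp
import numpy as np

x32 = jnp.ones(3, dtype=jnp.float32)

# A Python float is weakly typed: the result keeps the array's float32
print((x32 + 42.0).dtype)  # float32
# A NumPy float64 scalar is strongly typed and takes part in promotion
# (float64 if jax_enable_x64 is on; otherwise JAX truncates to float32)
print((x32 + np.float64(42.0)).dtype)
```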

In [43]:
try:
  jax.lax.add(42, 42.0)
except TypeError as e:
  print(e)

lax.add requires arguments to have the same dtypes, got int64, float64. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).


In [44]:
jax.lax.add(np.float64(42), 42.0)

Array(84., dtype=float64)
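When you call `lax` operations directly, you have to handle type promotion yourself. The following is a minimal sketch of one way to do that. It assumes you want the result dtype that `jnp.promote_types` would choose. First compute the common dtype, then cast both operands with `lax.convert_element_type`:

```python
import jax.numpy as jnp
from jax import lax

a = jnp.asarray(42)                       # integer array (int32, or int64 with x64 enabled)
b = jnp.asarray(42.0, dtype=jnp.float32)  # float32 array

# lax.add needs matching dtypes: find a common dtype, then cast both operands to it
common = jnp.promote_types(a.dtype, b.dtype)  # float32 under JAX's promotion rules
result = lax.add(lax.convert_element_type(a, common),
                 lax.convert_element_type(b, common))
print(result.dtype)  # float32
```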