<a href="https://colab.research.google.com/github/Leila828/Learning_JAX_for_deepLearning/blob/main/JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## JAX Basics

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

In [None]:
# Ten zeros built twice -- with NumPy and with jax.numpy -- to compare the two array types.
x, y = np.zeros(10), jnp.zeros(10)

In [None]:
# NumPy array: np.zeros defaults to float64, so the repr omits the dtype.
x

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [None]:
# JAX array: note the float32 dtype in the repr, JAX's default float width.
y

DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [None]:
# Benchmark input: a seeded 1000x1000 matrix generated directly as float32.
# Unless 64-bit mode is enabled, JAX works in 32-bit floats, so jnp.array would
# silently downcast a float64 matrix. Generating float32 up front makes np.dot
# and jnp.dot below multiply the same data at the same precision, and the seed
# makes reruns time identical inputs.
x = np.random.default_rng(0).random((1000, 1000), dtype=np.float32)
y = jnp.array(x)

In [None]:
%timeit -n 1 -r 1 np.dot(x,x)

1 loop, best of 1: 52.6 ms per loop


In [None]:
%timeit -n 1 -r 1 jnp.dot(y,y).block_until_ready()

1 loop, best of 1: 1.47 ms per loop


## Automatic differentiation with grad

In [None]:
from jax import grad

def f(x):
  """Quadratic 3x^2 + 2x + 5."""
  leading = 3*x**2
  return leading + 2*x + 5

def f_prime(x):
  """Hand-derived derivative of f: 6x + 2."""
  return 6 * x + 2

# Derivative of f obtained by automatic differentiation, evaluated at x = 1.0.
grad(f)(1.0)


DeviceArray(8., dtype=float32)

In [None]:
# Hand-derived derivative at x = 1.0; should match grad(f)(1.0) from the previous cell.
f_prime(1.0)

8.0

## XLA and JIT compilation (`jit`)

In [None]:
from jax import jit

# Seeded float32 input, so reruns time the same data and jnp.array does not
# have to silently downcast a float64 matrix.
x = np.random.default_rng(0).random((1000, 1000), dtype=np.float32)
y = jnp.array(x)

def f(x):
  """Apply the update x <- 0.5*x + 0.1*sin(x) ten times, elementwise.

  Called directly, every jnp operation is dispatched on its own. Under jit the
  Python loop runs once at trace time, and the resulting chain of operations
  is compiled by XLA as a single computation.
  """
  for _ in range(10):
    x = 0.5 * x + 0.1 * jnp.sin(x)
  return x

# Compiled version of f; tracing and XLA compilation happen on the first call.
g = jit(f)



# Time the un-jitted f (each jnp operation dispatched separately). block_until_ready()
# waits for JAX's asynchronous dispatch, so the measurement covers the real work.
%timeit -n 5 -r 5 f(y).block_until_ready()


5 loops, best of 5: 11.4 ms per loop


In [None]:
%timeit -n 5 -r 5 g(y).block_until_ready()


5 loops, best of 5: 309 µs per loop


## pmap

In [None]:
from jax import pmap

def f(x):
  """Elementwise sin(x) + x**2."""
  return jnp.sin(x) + x**2

# pmap maps f over the leading axis, one slice per local device, so that axis
# may not be longer than jax.local_device_count(). Sizing the input from the
# device count keeps this cell runnable on a single-GPU or CPU-only runtime.
xs = np.arange(jax.local_device_count())
# The parallel map must agree with applying f to the whole array at once.
assert np.allclose(pmap(f)(xs), f(xs))
pmap(f)(xs)

**Note:** Colab does not allow attaching multiple GPUs, so `pmap` cannot be tested across several devices here.

DeviceArray([0.       , 1.841471 , 4.9092975, 9.14112  ], dtype=float32)

In [None]:
from functools import partial

from jax import pmap
from jax.lax import psum

@partial(pmap, axis_name="i")
def normalize(x):
  """Divide each device's value by the sum over all devices on mapped axis "i".

  Under pmap each device holds one slice of the leading axis. psum(x, "i") sums
  those slices across the mapped axis, so (for a nonzero total) the outputs sum to 1.
  """
  return x / psum(x, "i")

# The mapped (leading) axis may not exceed the number of local devices, so size
# it from the device count. Starting at 1.0 rather than 0.0 avoids computing
# 0/0 on a single-device runtime.
n_devices = jax.local_device_count()
normalize(np.arange(1.0, n_devices + 1.0))

**Note:** Colab does not allow attaching multiple GPUs, so `pmap` cannot be tested across several devices here.

## vmap

In [None]:
from jax import vmap

def f(x):
  """Square x elementwise."""
  return jnp.square(x)

xs = jnp.arange(5)
# f is elementwise, so mapping it over the leading axis with vmap must give the
# same result as calling it on the whole array directly. Asserting this makes the
# comparison visible instead of silently discarding the direct result.
assert jnp.array_equal(f(xs), vmap(f)(xs))
vmap(f)(xs)


1 loop, best of 1: 930 µs per loop
1 loop, best of 1: 1.14 ms per loop


## Pseudo-random number generation

In [None]:
from jax import random

# JAX random functions are stateless: the key carries the generator state
# explicitly, so the same key always yields the same sample.
key = random.PRNGKey(seed=5)
random.uniform(key, shape=())

DeviceArray(0.6343405, dtype=float32)

## Profiler

In [None]:
import jax.profiler

def func1(x):
  """Tile x ten times along its last axis, then halve every element."""
  tiled = jnp.tile(x, 10)
  return tiled * 0.5

def func2(x):
  """Return (func1(x), tile(x, 10) + 1).

  Run eagerly, so each jnp.tile call produces its own array.
  """
  halved = func1(x)
  shifted = jnp.tile(x, 10) + 1
  return halved, shifted

x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
y, z = func2(x)

# Wait for the asynchronously dispatched work to finish before taking the
# memory snapshot.
z.block_until_ready()

# Write a pprof-compatible snapshot of device memory usage to memory.prof.
jax.profiler.save_device_memory_profile("memory.prof")


SyntaxError: ignored