<a href="https://colab.research.google.com/github/ShyamSundhar1411/My-ML-Notebooks/blob/master/Playgrounds/JAX_Playground.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction to JAX

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [2]:
key = random.PRNGKey(42)
x = random.normal(key,(10,))

In [3]:
x

Array([ 0.36900434, -0.46067542, -0.8650934 ,  1.2080883 ,  1.003065  ,
       -0.8708058 , -0.3984997 , -0.6670092 ,  0.33689356,  0.39822492],      dtype=float32)

## Matrix Multiplication



1. Key is the psuedo random seed that is generated using random
2. jnp by default uses gpu to execute
3. jnp also works on np but by default it runs operations on cpu for type np. Use of device_put is recommended for every np type.



In [5]:
mat1 = random.normal(key,(3000,3000),dtype=jnp.float32)

In [7]:
jnp.dot(mat1,mat1.T).block_until_ready()

Array([[ 3.0604648e+03,  6.3151691e+01, -3.6272667e+01, ...,
        -1.2364353e+01, -6.8656975e+01, -8.6898956e+01],
       [ 6.3151691e+01,  2.8897720e+03, -3.5676849e+01, ...,
         7.7677994e+00,  1.1219151e+00,  6.8914673e+01],
       [-3.6272667e+01, -3.5676849e+01,  2.9800063e+03, ...,
        -1.3242654e+01, -4.2341629e+01, -3.2430225e+01],
       ...,
       [-1.2364353e+01,  7.7677994e+00, -1.3242654e+01, ...,
         2.9366499e+03,  3.2153156e+01,  7.1246815e+00],
       [-6.8656975e+01,  1.1219151e+00, -4.2341629e+01, ...,
         3.2153156e+01,  2.9689722e+03, -1.5499933e+01],
       [-8.6898956e+01,  6.8914673e+01, -3.2430225e+01, ...,
         7.1246815e+00, -1.5499933e+01,  3.0071340e+03]], dtype=float32)

## JIT Compile

In [8]:
def selu(x, scale, alpha):
  return scale*jnp.where(x>0,x,alpha*(jnp.exp(x)-1))

In [10]:
selu_jit = jit(selu)

In [12]:
x = random.normal(key,shape = (3000,))

In [14]:
selu_jit(x,1.05,1.67)

Array([ 0.5873864 , -0.38615307,  2.5704193 , ..., -0.47609577,
       -0.45193538,  0.08148339], dtype=float32)