<a href="https://colab.research.google.com/github/Tushaam/JAX-tutorials/blob/main/JAX_Basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda]
  Downloading jax-0.5.3-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.5.3,>=0.5.3 (from jax[cuda])
  Downloading jaxlib-0.5.3-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting jax-cuda12-plugin<=0.5.3,>=0.5.3 (from jax-cuda12-plugin[with_cuda]<=0.5.3,>=0.5.3; extra == "cuda"->jax[cuda])
  Downloading jax_cuda12_plugin-0.5.3-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.4 kB)
Collecting jax-cuda12-pjrt==0.5.3 (from jax-cuda12-plugin<=0.5.3,>=0.5.3->jax-cuda12-plugin[with_cuda]<=0.5.3,>=0.5.3; extra == "cuda"->jax[cuda])
  Downloading jax_cuda12_pjrt-0.5.3-py3-none-manylinux2014_x86_64.whl.metadata (492 bytes)
Collecting nvidia-cuda-nvcc-cu12>=12.6.85 (from jax-cuda12-plugin[with_cuda]<=0.5.3,>=0.5.3; extra == "cuda"->jax[cuda])
  Downloading nvidia_cuda_nvcc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)

In [2]:
import jax
import jax.numpy as jnp

In [3]:
print(jax.devices())

[CudaDevice(id=0)]


# Let us first create a JAX array

In [4]:
x = jnp.array([5, 10, 15])


print("x =", x)

x = [ 5 10 15]
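A JAX array has a `dtype` and a `shape`, just like a NumPy array. Unlike a NumPy array, it is immutable. Here is a minimal sketch. It assumes JAX's default 32-bit mode, so Python integers become `int32`. The name `y` is only for illustration:

```python
import jax.numpy as jnp

x = jnp.array([5, 10, 15])
print(x.dtype, x.shape)  # int32 (3,)

# JAX arrays cannot be changed in place: x[0] = 1 would raise a TypeError.
# The functional .at[...].set(...) syntax returns a new, updated array instead.
y = x.at[0].set(1)
print(x)  # [ 5 10 15]  (unchanged)
print(y)  # [ 1 10 15]
```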


In [5]:
for value in range(2):
    x = value + x**2

print("x=",x)

x= [  626 10001 50626]


*So what did we do here?* We looped over `value` = 0 and 1 (that is, `range(2)`). In each iteration, we squared every element of `x` and added `value` to the result:

**First iteration:** x = 0 + [5×5, 10×10, 15×15], so now x = [25, 100, 225]

**Second iteration:** x = 1 + [25×25, 100×100, 225×225], so now x = [626, 10001, 50626]
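You can check the two iterations by unrolling the loop by hand and comparing the result with the same loop in plain NumPy. Here is a small sketch; `x0`, `step1`, `step2` and `x_numpy` are illustrative names that leave the notebook's `x` untouched:

```python
import numpy as np
import jax.numpy as jnp

x0 = jnp.array([5, 10, 15])
step1 = 0 + x0**2     # first iteration, value = 0
step2 = 1 + step1**2  # second iteration, value = 1
print(step1)  # [ 25 100 225]
print(step2)  # [  626 10001 50626]

# The same loop written with NumPy produces identical integers
x_numpy = np.array([5, 10, 15])
for value in range(2):
    x_numpy = value + x_numpy**2
print(np.array_equal(np.asarray(step2), x_numpy))  # True
```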

# Let us now compute the sine and cosine of each element of x

Note that `x` now holds `[626, 10001, 50626]`, the result of the loop above.

In [6]:
sin_x = jnp.sin(x)
cos_x = jnp.cos(x)

print("Sin(x) for the array x are:",sin_x)
print("Cos(x) for the array x are:",cos_x)

Sin(x) for the array x are: [-0.73323137 -0.9663353   0.6929788 ]
Cos(x) for the array x are: [-0.67997926 -0.2572861  -0.720958  ]
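`jnp.sin` and `jnp.cos` act elementwise. Because `x` holds integers after the loop, their results are promoted to a floating-point type. Here is a short sketch. It assumes JAX's default 32-bit mode, so the results are `float32`. The names `x_after_loop`, `sin_vals` and `identity` are illustrative:

```python
import jax.numpy as jnp

x_after_loop = jnp.array([626, 10001, 50626])  # the value x holds after the loop
sin_vals = jnp.sin(x_after_loop)

# Integer input, floating-point output (float32 unless 64-bit mode is enabled)
print(x_after_loop.dtype, sin_vals.dtype)  # int32 float32

# Elementwise: one output per input element
print(sin_vals.shape)  # (3,)

# Sanity check: sin^2 + cos^2 should equal 1 for every element, up to float32 rounding
identity = jnp.sin(x_after_loop) ** 2 + jnp.cos(x_after_loop) ** 2
print(jnp.allclose(identity, 1.0, atol=1e-4))
```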


Now let us come to the *magical* part of JAX

# **AUTOMATIC DIFFERENTIATION**

In [7]:
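def func(x):
    # A scalar-valued function: the sum over i of (x_i^2 + 3*x_i)
    return jnp.sum(x ** 2 + 3 * x)

# jax.grad(func) returns a new function that computes the gradient of func
# with respect to its first argument
grad_func = jax.grad(func)

x = jnp.array([5.0, 10.0, 15.0])  # float inputs: jax.grad does not accept integer arrays

grad = grad_func(x)
print("Gradient of sum(x^2 + 3x) is:", grad)

Gradient of sum(x^2 + 3x) is: [13. 23. 33.]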
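`func` sums x_i² + 3x_i over the elements, so each partial derivative is 2x_i + 3. That is exactly what `jax.grad` returned. Here is a quick sketch comparing the hand-derived gradient with the automatic one:

```python
import jax
import jax.numpy as jnp

def func(x):
    return jnp.sum(x ** 2 + 3 * x)

x = jnp.array([5.0, 10.0, 15.0])

analytic = 2 * x + 3           # derivative of x_i^2 + 3*x_i, worked out by hand
autodiff = jax.grad(func)(x)   # derivative computed by JAX

print(analytic)                          # [13. 23. 33.]
print(jnp.allclose(analytic, autodiff))  # True
```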


This automatic differentiation ability of JAX is especially valuable in large neural networks. Working out gradients by hand for every parameter there, let alone higher-order derivatives, would be tedious and error-prone.

Suppose we want a second-order derivative. We can simply apply `jax.grad()` again to the first gradient function, as long as that gradient is a scalar. 😀
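Here is a minimal sketch of second-order derivatives. For a scalar function of a scalar input, you can simply apply `jax.grad` twice. For `func` above, the gradient is a vector, and `jax.grad` only accepts scalar-output functions. The second derivatives therefore form a matrix, which `jax.hessian` computes. The function `f` is an illustrative example:

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 3        # f(x) = x^3

df = jax.grad(f)         # f'(x)  = 3x^2
d2f = jax.grad(df)       # f''(x) = 6x

print(df(3.0))   # 27.0
print(d2f(3.0))  # 18.0

def func(x):
    return jnp.sum(x ** 2 + 3 * x)

# jax.grad(jax.grad(func)) would raise an error, because jax.grad(func) returns
# a vector; jax.hessian returns the full matrix of second derivatives instead
hess = jax.hessian(func)(jnp.array([5.0, 10.0, 15.0]))
print(hess)  # 2.0 on the diagonal, 0.0 elsewhere
```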

# Let us now use another remarkable feature of JAX: **JIT** (just-in-time compilation)

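`jit` compiles a Python function with XLA the first time it is called for a given input shape and dtype. Later calls with matching inputs reuse the compiled code. This skips Python overhead and lets XLA fuse operations into fewer kernels. Let us time the same function without `jit` (JAX executing one operation at a time) and with `jit`.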
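Before timing anything, it helps to see *when* compilation happens. Python side effects run only while JAX traces the function, so they reveal each new compilation. In this sketch, `traced_function` and `trace_log` are illustrative names:

```python
import jax
import jax.numpy as jnp

trace_log = []  # illustrative list that records every trace

@jax.jit
def traced_function(x):
    # print and list.append run only while JAX traces the function,
    # not when the already-compiled version is executed
    print("Tracing for shape", x.shape)
    trace_log.append(x.shape)
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

traced_function(jnp.ones(3))   # traces and compiles for shape (3,)
traced_function(jnp.zeros(3))  # same shape and dtype: reuses the compiled code, no print
traced_function(jnp.ones(5))   # new shape: traces and compiles again

print(trace_log)  # [(3,), (5,)]
```

A call with an already-seen shape and dtype reuses the cached compiled code. This is why the first jitted call in a benchmark is slower than later ones.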

In [8]:
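import time
from jax import jit

def slow_function(x):
    # sin^2 + cos^2 is identically 1, but JAX still evaluates every operation
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


# jit returns a compiled version of the same function
fast_function = jit(slow_function)

# One million evenly spaced points between 0 and 1000
x = jnp.linspace(0, 1000, 1_000_000)

# Note: JAX dispatches work asynchronously, so these timings may stop the clock
# before the computation has actually finished (see .block_until_ready())
start = time.time()
slow_function(x)
print("Without jit:", time.time() - start)

start = time.time()
fast_function(x)  # first call traces and compiles the function
print("With jit (1st call - compile):", time.time() - start)

start = time.time()
fast_function(x)  # second call reuses the compiled code
print("With jit (2nd call):", time.time() - start)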


Without jit: 0.35568976402282715
With jit (1st call - compile): 0.22977805137634277
With jit (2nd call): 0.00046706199645996094


WOHOOOOOO!! Isn't it fascinating????

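The second call of the jitted function took about **99.87%** less time than the call without `jit` (0.00047 s vs. 0.356 s). Keep in mind that both measurements use JAX, not NumPy. The un-jitted run may also include first-use overhead. And because JAX dispatches work asynchronously, timings taken without `.block_until_ready()` can understate the real compute time.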
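Here is a sketch of a fairer comparison. It does three things:

1. It warms up both JAX versions so that compilation is excluded.
2. It waits for results with `.block_until_ready()`.
3. It adds an actual NumPy baseline.

The helpers `numpy_function` and `best_time` are illustrative names. Timings depend on your hardware, so no numbers are shown:

```python
import time
import numpy as np
import jax
import jax.numpy as jnp

def slow_function(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

def numpy_function(x):
    # Same computation written with plain NumPy, as a baseline
    return np.sin(x) ** 2 + np.cos(x) ** 2

fast_function = jax.jit(slow_function)

x = jnp.linspace(0, 1000, 1_000_000)
x_np = np.asarray(x)  # the same values as a NumPy array

def best_time(fn, arg, repeats=5):
    """Return the fastest of several timed calls, in seconds."""
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        out = fn(arg)
        # Wait for JAX's asynchronous computation to finish; NumPy results are already ready
        if isinstance(out, jax.Array):
            out.block_until_ready()
        times.append(time.perf_counter() - start)
    return min(times)

# Warm-up calls so that compilation time is excluded from the measurements
slow_function(x).block_until_ready()
fast_function(x).block_until_ready()

print("NumPy:          ", best_time(numpy_function, x_np))
print("JAX without jit:", best_time(slow_function, x))
print("JAX with jit:   ", best_time(fast_function, x))
```

Taking the minimum of several runs reduces noise from other processes running on the machine.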