# Intro to JAX, a.k.a. swapping `np.X` for `jnp.X`

## Lesson Goals:

By the end of this lesson, you will understand how to migrate code from `numpy` to `jax` and have a feel for how similar the two libraries are.

```python
import numpy as np
import jax.numpy as jnp

# Seed NumPy's global RNG so the exercises below are reproducible.
np.random.seed(42)
```

## What is JAX?

To put it simply, JAX is NumPy for hardware accelerators: the same array API runs on CPUs, GPUs, and TPUs. It offers more than that, though. It compiles code through a different backend (XLA), supports automatic differentiation, and provides composable function transformations such as `jax.jit` and `jax.grad`.
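As a rough sketch of what those extra capabilities look like in practice, `jax.grad` turns a function into one that computes its gradient, and `jax.jit` compiles a function with XLA. The function `f` below is a made-up illustration, not part of the lesson, and which backend actually runs it depends on your installation.

```python
import jax
import jax.numpy as jnp

# Illustrative function (not from the lesson): f(x) = sum(x_i^2), whose gradient is 2x.
def f(x):
    return jnp.sum(x ** 2)

x = jnp.arange(3.0)            # [0., 1., 2.]

grad_f = jax.grad(f)           # automatic differentiation
fast_f = jax.jit(f)            # compiled through XLA on the first call

print(grad_f(x))               # [0. 2. 4.]
print(fast_f(x))               # 5.0
print(jax.default_backend())   # e.g. cpu, gpu, or tpu, depending on your install
```

The NumPy-style body of `f` is unchanged; the transformations wrap it from the outside.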

From the JAX website:

> JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

That said, not every NumPy concept and idiom translates directly, and there are some ‼️sharp edges‼️ you should be aware of.
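Two sharp edges you are likely to hit early when porting NumPy code are sketched below; this is a minimal illustration, not an exhaustive list. JAX arrays are immutable, so in-place assignment is replaced by the functional `.at[...]` update syntax, and JAX has no global random state like `np.random.seed`, so randomness is driven by explicit PRNG keys.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy arrays are mutable, so in-place assignment works.
np_arr = np.zeros(3)
np_arr[0] = 1.0

# JAX arrays are immutable: the same assignment raises a TypeError.
jnp_arr = jnp.zeros(3)
try:
    jnp_arr[0] = 1.0
except TypeError as err:
    print("TypeError:", err)

# The functional alternative returns a new array and leaves the original unchanged.
updated = jnp_arr.at[0].set(1.0)
print(updated)   # [1. 0. 0.]
print(jnp_arr)   # [0. 0. 0.]

# There is no global seed: you create a key explicitly and split it
# to get fresh subkeys for each random draw.
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
sample = jax.random.uniform(subkey, shape=(3,))
print(sample.shape)  # (3,)
```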

## Sample Exercises

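The exercises below will help you get familiar with JAX and NumPy. Each one computes an expected result with NumPy; your job is to replace the `...` placeholder with the equivalent `jax.numpy` code. The solutions are mostly what you would expect from a drop-in replacement: swap `np.X` for `jnp.X`.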
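To show the drop-in pattern without giving away the exercise answers, here is a hypothetical `softmax` helper (not part of the exercises) written once per library; the only change is the module prefix. One detail worth knowing: by default `jnp.asarray` converts NumPy's `float64` data to `float32` unless 64-bit mode is enabled, which is one reason to compare results with a tolerance via `allclose` rather than exact equality.

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical helper: a numerically stable softmax written with NumPy.
def softmax_np(x):
    e = np.exp(x - np.max(x))
    return e / np.sum(e)

# The same function with every `np.` swapped for `jnp.`.
def softmax_jnp(x):
    e = jnp.exp(x - jnp.max(x))
    return e / jnp.sum(e)

x = np.array([1.0, 2.0, 3.0])
x_jnp = jnp.asarray(x)

print(x.dtype, x_jnp.dtype)  # float64 float32 (unless 64-bit mode is enabled)
print(np.allclose(softmax_np(x), softmax_jnp(x_jnp)))  # True
```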

```python
def dot_product():
    v = np.random.rand(10)
    M = np.random.rand(10, 5)
    expected_result = np.dot(v, M)

    jnp_v = jnp.asarray(v)
    jnp_M = jnp.asarray(M)
    actual_result = ...  # Your code here: use jnp_v and jnp_M

    assert jnp.allclose(expected_result, actual_result)
    print("dot_product passed")


def is_even_filter():
    to_filter_np = np.asarray([1, 2, 3, 5, 10, 20])
    # Boolean-mask indexing keeps only the even entries.
    expected_result = to_filter_np[to_filter_np % 2 == 0]

    to_filter_jnp = jnp.asarray(to_filter_np)
    actual_result = ...  # Your code here: use to_filter_jnp

    assert jnp.allclose(expected_result, actual_result)
    print("is_even_filter passed")


def top_n_of_norm_squared():
    M = np.random.rand(10, 5)
    TOP_N = 5

    # Row-wise L2 norms of M @ M.T, sorted in descending order, keeping the first TOP_N.
    expected_result = np.sort(np.linalg.norm(M @ M.T, axis=1))[::-1][:TOP_N]

    jnp_M = jnp.asarray(M)
    actual_result = ...  # Your code here: use jnp_M

    assert jnp.allclose(expected_result, actual_result)
    print("top_n_of_norm_squared passed")


def hadamard():
    M = np.random.rand(10, 5)
    # Hadamard product: element-wise multiplication.
    expected_result = M * M

    jnp_M = jnp.asarray(M)
    actual_result = ...  # Your code here: use jnp_M

    assert jnp.allclose(expected_result, actual_result)
    print("hadamard passed")


# Each check fails until you replace its `...` placeholder.
dot_product()
is_even_filter()
top_n_of_norm_squared()
hadamard()
```