# Introduction to JAX

This notebook summarizes, with examples, some of the excellent JAX documentation, starting from the Quickstart guide: https://docs.jax.dev/en/latest/quickstart.html

## Quickstart
JAX is a library for array-oriented numerical computation (à la NumPy), with automatic differentiation and JIT compilation to enable high-performance machine learning research.

The main features of JAX can be summarized as: 
- `jax.numpy`: JAX provides a unified *NumPy-like* interface to computations that run on *CPU, GPU, or TPU, in local or distributed settings*.

- `jax.jit`: JAX features built-in *Just-In-Time (JIT)* compilation via *OpenXLA*, an open-source machine learning compiler ecosystem.

- `jax.grad`: JAX can efficiently evaluate gradients of functions via its *automatic differentiation* transformations.

- `jax.vmap`: JAX functions can be *automatically vectorized* to efficiently map them over arrays representing batches of inputs.
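To make the four features above concrete, here is a minimal sketch that uses each of them on one small function. The function `sum_of_squares` is a hypothetical example chosen for illustration, not part of JAX; the calls `jnp.arange`, `jax.jit`, `jax.grad`, and `jax.vmap` are the standard JAX entry points named in the list.

```python
import jax
import jax.numpy as jnp

# jax.numpy: NumPy-like arrays placed on the default device (CPU, GPU, or TPU).
# With default settings, jnp.arange(5.0) produces float32 values 0..4.
x = jnp.arange(5.0)


# Hypothetical example function written only with jax.numpy operations
def sum_of_squares(v):
    return jnp.sum(v ** 2)


# jax.jit: trace the function once per input shape/dtype and compile it with XLA
fast_sum_of_squares = jax.jit(sum_of_squares)

# jax.grad: gradient of a scalar-valued function; analytically d/dv sum(v^2) = 2v
grad_sum_of_squares = jax.grad(sum_of_squares)

# jax.vmap: map the function over the leading (batch) axis without a Python loop
batch = jnp.stack([x, 2 * x])               # shape (2, 5)
batched = jax.vmap(sum_of_squares)(batch)   # one scalar per row, shape (2,)

print(fast_sum_of_squares(x))   # 30.0
print(grad_sum_of_squares(x))   # [0. 2. 4. 6. 8.]
print(batched)                  # [ 30. 120.]
```

Note that `jax.jit`, `jax.grad`, and `jax.vmap` all take a function and return a new function, so they can be composed, for example `jax.jit(jax.vmap(jax.grad(sum_of_squares)))`.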

## Installation

### Make new conda environment
We install JAX in a fresh `conda` environment.

As a side note, [`mamba`](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html) is a faster reimplementation of `conda` that uses `conda-forge` as its default channel (this is the channel you should always be using for scientific work).

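With `conda`, update it, add `conda-forge` as the highest-priority channel, and enforce strict channel priority (`mamba` already uses `conda-forge` this way by default):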
```bash
conda update conda
conda config --add channels conda-forge
conda config --set channel_priority strict
```
Create and activate the environment:
```bash
conda create -n jax_intro python=3.12
conda activate jax_intro
```

### Install JAX

JAX is split into two packages: `jax`, which is pure Python and cross-platform, and `jaxlib`, which contains the compiled binaries and requires different builds for different operating systems and accelerators. The `pip` commands below pull in a matching `jaxlib` automatically.

-- CPU-only (Linux/macOS/Windows):
```bash
pip install -U jax
```

-- GPU (NVIDIA, CUDA 12):
```bash
pip install -U "jax[cuda12]"
```

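-- GPU (Apple silicon, M-series chips): (Experimental)

Follow Apple's instructions at https://developer.apple.com/metal/jax/, which use the `jax-metal` plugin (check that page for the `jax` versions it supports):
```bash
pip install -U jax-metal
```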
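After installing, a quick sanity check confirms which packages and devices JAX sees. This is a minimal sketch assuming the `jax_intro` environment is active; `jax.__version__`, `jax.default_backend()`, and `jax.devices()` are standard JAX calls, while `importlib.metadata` is from the Python standard library. The exact output depends on your hardware and install.

```python
import importlib.metadata

import jax

# Versions of the pure-Python front end (jax) and the compiled backend (jaxlib)
print("jax   ", jax.__version__)
print("jaxlib", importlib.metadata.version("jaxlib"))

# Platform JAX selected by default, e.g. "cpu" or "gpu"
print("backend:", jax.default_backend())

# All devices JAX can use; a CPU-only install typically lists a single CPU device
print(jax.devices())
```

If a GPU install reports only a CPU backend, JAX did not find a working accelerator and has fallen back to the CPU.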

