# JAX tutorial

## Introduction to JAX
   - Overview of JAX
   - Comparison with NumPy and other libraries like TensorFlow and PyTorch
   - Setting up the environment

---

### Overview of JAX

JAX is a high-performance numerical computation library designed for machine learning research and development. It builds on the foundations of NumPy to offer advanced features such as automatic differentiation and the ability to run on GPU and TPU hardware. JAX supports just-in-time compilation and can efficiently handle large-scale, complex numerical computations which are often required in deep learning.

In [1]:
import jax.numpy as jnp

# Using JAX to perform a simple array operation
x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = jnp.power(x, 2)
print(y)

[ 1.  4.  9. 16.]


### Comparison with NumPy and Other Libraries Like TensorFlow and PyTorch

**NumPy**: JAX extends NumPy by adding automatic differentiation and the option to run on accelerators like GPUs and TPUs. While it mimics the NumPy API, making it familiar and easy to pick up, JAX introduces functional programming paradigms and immutability of arrays.
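The immutability mentioned above is easiest to see side by side with NumPy. The following is a minimal sketch; the variable names are illustrative only. It shows that item assignment fails on a JAX array and that the `.at[...].set(...)` syntax returns a new array instead.

```python
import numpy as np
import jax.numpy as jnp

# NumPy arrays can be modified in place.
np_x = np.array([1.0, 2.0, 3.0])
np_x[0] = 10.0

# JAX arrays are immutable: item assignment raises a TypeError.
jax_x = jnp.array([1.0, 2.0, 3.0])
try:
    jax_x[0] = 10.0
    in_place_failed = False
except TypeError:
    in_place_failed = True

# The functional alternative returns a new array and leaves jax_x untouched.
jax_y = jax_x.at[0].set(10.0)
print("In-place update failed:", in_place_failed)
print("Original:", jax_x, "Updated copy:", jax_y)
```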

**TensorFlow and PyTorch**:
- **TensorFlow**: Like TensorFlow, JAX offers auto-differentiation and GPU/TPU support. However, TensorFlow 1.x was built around static computational graphs (TensorFlow 2 runs eagerly by default and uses `tf.function` for graph compilation), whereas JAX traces ordinary Python functions and uses JIT compilation through XLA to optimize them.
- **PyTorch**: PyTorch is known for its dynamic (define-by-run) computation graphs and user-friendly interface. JAX also lets you write ordinary Python, but it focuses on pure functions and composable transformations (like `grad` and `jit`). This functional approach can make code more predictable and easier to optimize, especially in research.
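To make "pure functions and transformations" concrete, here is a small sketch of how transformations compose. The function `f` is a made-up example; `jax.grad` and `jax.jit` are applied to it like ordinary higher-order functions.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A pure function: the result depends only on the input.
    return jnp.sum(x ** 2)

# Transformations compose: differentiate f, then compile the result.
df = jax.jit(jax.grad(f))

x = jnp.array([1.0, 2.0, 3.0])
g = df(x)
print("Gradient of sum(x^2):", g)  # analytically, this is 2 * x
```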

## Basic Concepts in JAX
   - `jax.numpy` and `jax.random`: Basic operations
   - Understanding `jax.jit` for Just-In-Time compilation
   - Using `jax.grad` for automatic differentiation
   - `jax.vmap` for vectorization
   - `jax.pmap` for parallelization across devices

In [2]:
# `jax.numpy` and `jax.random`: Basic Operations
# JAX provides `jax.numpy` (often imported as `jnp`), a GPU- and TPU-compatible version of NumPy. It's designed to be used interchangeably with NumPy functions but with the added benefit of JAX's features like auto-differentiation and JIT compilation.
import jax.numpy as jnp

# Basic array operations using jax.numpy
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.sin(x) + jnp.cos(x)
print("Result of jnp operations:", y)

# `jax.random` is used for random number generation, which is slightly different due to the need for explicit control over the random state:
from jax import random

# Generating a random matrix
key = random.PRNGKey(0)
matrix = random.normal(key, (3, 3))
print("Random matrix:\n", matrix)

Result of jnp operations: [ 1.3817732   0.49315056 -0.8488725 ]
Random matrix:
 [[-0.3721109   0.26423115 -0.18252768]
 [-0.7368197   0.44973662 -0.1521442 ]
 [-0.67135346 -0.5908641   0.73168886]]
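Explicit control over the random state means that the same key always produces the same numbers. A short sketch of the usual pattern follows (the variable names are illustrative): split the key with `random.split` whenever independent draws are needed.

```python
from jax import random

key = random.PRNGKey(0)

# Reusing a key reproduces exactly the same samples.
a = random.normal(key, (3,))
b = random.normal(key, (3,))

# Splitting yields new keys whose samples are independent of each other.
key, subkey1, subkey2 = random.split(key, 3)
c = random.normal(subkey1, (3,))
d = random.normal(subkey2, (3,))

print("Same key, same samples:", bool((a == b).all()))
print("Different subkeys, same samples:", bool((c == d).all()))
```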


In [5]:
### Understanding `jax.jit` for Just-In-Time Compilation
# The `jax.jit` function compiles a function with XLA (Accelerated Linear Algebra) so that repeated calls run more efficiently on CPU, GPU, or TPU. It's particularly useful for functions made of many array operations that XLA can fuse and optimize.
from jax import jit

@jit
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

# JIT-compiled sigmoid function
x = jnp.linspace(-10, 10, 100)
y = sigmoid(x)
print("JIT-compiled sigmoid results:", y)

JIT-compiled sigmoid results: [4.53978719e-05 5.55606748e-05 6.79983132e-05 8.32199730e-05
 1.01848804e-04 1.24647180e-04 1.52547960e-04 1.86692982e-04
 2.28478792e-04 2.79614789e-04 3.42191313e-04 4.18766722e-04
 5.12469152e-04 6.27125031e-04 7.67413818e-04 9.39054938e-04
 1.14904228e-03 1.40591967e-03 1.72012544e-03 2.10440415e-03
 2.57431110e-03 3.14881280e-03 3.85103165e-03 4.70911246e-03
 5.75728714e-03 7.03711621e-03 8.59898701e-03 1.05038397e-02
 1.28252096e-02 1.56514868e-02 1.90885402e-02 2.32625399e-02
 2.83228736e-02 3.44451889e-02 4.18339297e-02 5.07243499e-02
 6.13831095e-02 7.41067529e-02 8.92170370e-02 1.07052118e-01
 1.27951682e-01 1.52235821e-01 1.80176586e-01 2.11963370e-01
 2.47663721e-01 2.87185848e-01 3.30246359e-01 3.76354516e-01
 4.24816966e-01 4.74768937e-01 5.25231123e-01 5.75183094e-01
 6.23645484e-01 6.69753551e-01 7.12814093e-01 7.52336144e-01
 7.88036704e-01 8.19823325e-01 8.47764194e-01 8.72048259e-01
 8.92947793e-01 9.10782874e-01 9.25893307e-01 9.3861693
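One way to see what `jit` does is to record when the Python body of the function actually runs. In this sketch, `trace_log` is a hypothetical helper list added purely for illustration. The body runs while JAX traces the function for a given input shape and dtype, and later calls with the same shape and dtype reuse the compiled code.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records each time the function is traced

@jax.jit
def sigmoid(x):
    trace_log.append(x.shape)  # Python side effect: runs only during tracing
    return 1 / (1 + jnp.exp(-x))

sigmoid(jnp.zeros(100, dtype=jnp.float32))  # first call: trace and compile
sigmoid(jnp.ones(100, dtype=jnp.float32))   # same shape and dtype: compiled code is reused
traces_after_reuse = len(trace_log)

sigmoid(jnp.zeros(5, dtype=jnp.float32))    # new shape: traced and compiled again
traces_after_new_shape = len(trace_log)

print("Traces after reuse:", traces_after_reuse)
print("Traces after new shape:", traces_after_new_shape)
```

This is also why side effects such as `print` inside a jitted function appear to run only once.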

In [4]:
# Using `jax.grad` for Automatic Differentiation
# JAX's `jax.grad` is used to compute the gradient of a function. This is essential for tasks like optimizing training in machine learning.
from jax import grad

def loss_fn(w):
    """ A simple loss function: (w - 5)^2 """
    return (w - 5) ** 2

# Compute the gradient of the loss function
grad_loss = grad(loss_fn)
print("Gradient at w = 10:", grad_loss(10.0))

Gradient at w = 10: 10.0
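The gradient above can be used directly for optimization. As a small sketch, assuming a step size of 0.1 chosen only for illustration, `jax.value_and_grad` returns the loss and its gradient in one call, and repeated gradient steps move `w` toward the minimum at 5.

```python
from jax import value_and_grad

def loss_fn(w):
    """ A simple loss function: (w - 5)^2 """
    return (w - 5) ** 2

loss_and_grad = value_and_grad(loss_fn)

w = 10.0
step_size = 0.1  # illustrative value, not tuned
for _ in range(50):
    loss, g = loss_and_grad(w)
    w = w - step_size * g  # plain gradient descent update

print("w after 50 steps:", w)
print("Last recorded loss:", loss)
```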


In [6]:
# `jax.vmap` for Vectorization
# `jax.vmap` automatically vectorizes a function, enabling it to efficiently operate on batches of inputs without manual batching.
from jax import vmap

def square(x):
    return x * x

# Automatically vectorize the 'square' function
x = jnp.array([1, 2, 3, 4, 5])
squared = vmap(square)(x)
print("Vectorized squaring results:", squared)

Vectorized squaring results: [ 1  4  9 16 25]
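`vmap` becomes more useful when a function is written for a single example and only some arguments carry a batch axis. The sketch below uses a hypothetical `weighted_sum` function; `in_axes=(None, 0)` tells `vmap` to share `w` and map over the leading axis of `batch`.

```python
import jax.numpy as jnp
from jax import vmap

def weighted_sum(w, x):
    # Written for a single example: w and x are both 1-D vectors.
    return jnp.dot(w, x)

w = jnp.array([1.0, 2.0, 3.0])
batch = jnp.arange(12.0).reshape(4, 3)  # 4 examples with 3 features each

# in_axes=(None, 0): do not map over w, map over axis 0 of batch.
batched_weighted_sum = vmap(weighted_sum, in_axes=(None, 0))
out = batched_weighted_sum(w, batch)
print("Per-example results:", out)
```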


In [7]:
# `jax.pmap` for Parallelization Across Devices
# `jax.pmap` (parallel map) is used for parallel computation across multiple devices like GPUs and TPUs, making it possible to scale computations efficiently.
from jax import pmap

@pmap
def cube(x):
    return x ** 3

# pmap maps the leading axis of x (size 2 here) across devices, so this needs at least 2 devices
x = jnp.arange(6).reshape((2, 3))
cubed = cube(x)  # raises a ValueError if only a single device is available
print("Parallel cube results:", cubed)

ValueError: compiling computation that requires 2 logical devices, but only 1 XLA devices are available (num_replicas=2)
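The error above occurs because the leading axis of the input is larger than the number of available devices. One way to avoid it, sketched here under the assumption that sizing the input to the hardware is acceptable, is to query `jax.local_device_count()` and build the leading axis from it. On a single-device machine this runs with a leading axis of size 1.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

@jax.pmap
def cube(x):
    return x ** 3

# One row of 3 values per available device.
x = jnp.arange(n_devices * 3).reshape((n_devices, 3))
cubed = cube(x)
print("Devices used:", n_devices)
print("Parallel cube results:", cubed)
```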


## Linear Algebra in JAX
   - Implementing basic linear layers
   - Activation functions

Linear algebra is foundational to many operations in machine learning and data science. JAX provides efficient and flexible tools for performing these operations on modern hardware. This section will cover implementing basic linear layers and activation functions in JAX.

### Implementing Basic Linear Layers

A linear layer in a neural network performs a linear transformation \( y = Wx + b \), where \( W \) is the weight matrix, \( x \) is the input vector or matrix, and \( b \) is the bias vector. Here’s how you can implement a simple linear layer in JAX:

In [7]:
import jax.numpy as jnp
from jax import random

def init_linear_layer(input_dim, output_dim, key):
    """ Initialize weights and biases for a linear layer. """
    W_key, b_key = random.split(key)
    W = random.normal(W_key, (output_dim, input_dim)) * (2 / input_dim)**0.5
    b = random.normal(b_key, (output_dim,)) * (2 / input_dim)**0.5
    return W, b

def linear_layer(x, W, b):
    """ Apply a linear transformation to the input data x. """
    return jnp.dot(W, x) + b

# Example usage
key = random.PRNGKey(0)
input_dim, output_dim = 4, 3
W, b = init_linear_layer(input_dim, output_dim, key)
x = jnp.array([1, 2, 3, 4])
y = linear_layer(x, W, b)
print("Output of linear layer:", y)

Output of linear layer: [ 2.625999   1.5546751 -1.608583 ]


Because these are pure functions of their inputs, they can be wrapped in `jax.jit` for compilation and `jax.vmap` for batching, which lets them run efficiently on CPUs, GPUs, and TPUs.
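As a sketch of how JIT compilation and vectorization apply to the linear layer, the single-example `linear_layer` can be batched with `vmap` and compiled with `jit`. The shapes and the batch size of 8 are arbitrary choices for illustration.

```python
import jax.numpy as jnp
from jax import jit, random, vmap

def linear_layer(x, W, b):
    """ Apply a linear transformation to a single input vector x. """
    return jnp.dot(W, x) + b

key = random.PRNGKey(0)
W_key, x_key = random.split(key)
W = random.normal(W_key, (3, 4))  # (output_dim, input_dim)
b = jnp.zeros(3)
X = random.normal(x_key, (8, 4))  # batch of 8 input vectors

# Map over the batch axis of X only; W and b are shared across examples.
batched_linear = jit(vmap(linear_layer, in_axes=(0, None, None)))
Y = batched_linear(X, W, b)
print("Batched output shape:", Y.shape)
```

The result matches the usual matrix form `X @ W.T + b`, so `vmap` is mainly a convenience here: the layer can be written for one example and batched afterwards.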

### Activation Functions

Activation functions are crucial in neural networks as they introduce non-linear properties to the system, allowing for learning complex patterns. Common activation functions include ReLU, sigmoid, and tanh. Here’s how to implement these in JAX:

In [21]:
def relu(x):
    """ Rectified Linear Unit activation function """
    return jnp.maximum(0, x)

def sigmoid(x):
    """ Sigmoid activation function """
    return 1 / (1 + jnp.exp(-x))

def tanh(x):
    """ Hyperbolic Tangent activation function """
    return jnp.tanh(x)

# Applying activation functions
x = jnp.linspace(-3, 3, 7)
print("ReLU:", relu(x))
print("Sigmoid:", sigmoid(x))
print("Tanh:", tanh(x))

ReLU: [0. 0. 0. 0. 1. 2. 3.]
Sigmoid: [0.04742587 0.11920292 0.26894143 0.5        0.7310586  0.880797
 0.95257413]
Tanh: [-0.9950547 -0.9640276 -0.7615942  0.         0.7615942  0.9640276
  0.9950547]
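Since these activations are ordinary JAX functions, their derivatives come for free. The following sketch checks the automatic derivative of the hand-written sigmoid against the known formula s(x)(1 - s(x)), and compares the hand-written versions with the built-ins in `jax.nn`.

```python
import jax
import jax.numpy as jnp

def relu(x):
    return jnp.maximum(0, x)

def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

x = jnp.linspace(-3, 3, 7)

# grad works on scalar-valued functions, so vmap applies it elementwise.
dsigmoid = jax.vmap(jax.grad(sigmoid))(x)
analytic = sigmoid(x) * (1 - sigmoid(x))

print("Autodiff matches formula:", bool(jnp.allclose(dsigmoid, analytic, atol=1e-6)))
print("Sigmoid matches jax.nn:", bool(jnp.allclose(sigmoid(x), jax.nn.sigmoid(x), atol=1e-6)))
print("ReLU matches jax.nn:", bool(jnp.allclose(relu(x), jax.nn.relu(x))))
```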


## Building Neural Networks from Scratch
   - Implementing common layers manually (Convolutional, Recurrent, etc.)
   - Creating loss functions
   - Training loops

---

Building neural networks from scratch can deepen your understanding of their inner workings. This section will demonstrate how to implement common types of layers, define loss functions, and construct training loops using JAX.

### Implementing Common Layers Manually

In [22]:
## Convolutional Layers
# Convolutional layers are fundamental in processing spatial data such as images. Here's how to implement a basic convolutional layer in JAX:
from jax import lax
import jax.numpy as jnp

def conv_layer(x, W, stride, padding):
    """ Apply a 2D convolution to an NHWC input with an HWIO kernel """
    return lax.conv_general_dilated(
        x,  # input, shape (batch, height, width, channels)
        W,  # kernel, shape (height, width, in_channels, out_channels)
        window_strides=(stride, stride),  # stride
        padding=padding,  # padding
        # The default layout is NCHW/OIHW, so the NHWC/HWIO layout used here must be stated explicitly
        dimension_numbers=('NHWC', 'HWIO', 'NHWC')
    )

# Example usage for convolutional layer
input_feature_map = jnp.ones((1, 5, 5, 1))  # Shape: (batch_size, height, width, channels)
kernel = jnp.ones((3, 3, 1, 1))  # Shape: (height, width, in_channels, out_channels)
output_feature_map = conv_layer(input_feature_map, kernel, stride=1, padding='SAME')
print("Output shape:", output_feature_map.shape)
# Print the single batch element and channel as a 2D map
print("Output of convolutional layer:\n", output_feature_map[0, :, :, 0])

Output shape: (1, 5, 5, 1)
Output of convolutional layer:
 [[4. 6. 6. 6. 4.]
 [6. 9. 9. 9. 6.]
 [6. 9. 9. 9. 6.]
 [6. 9. 9. 9. 6.]
 [4. 6. 6. 6. 4.]]

In [23]:
## Recurrent Layers
# Recurrent layers are essential for processing sequential data. Implementing a simple recurrent neural network (RNN) layer manually can be done as follows:
def rnn_cell(h_prev, x_t, W_h, W_x, b):
    """ A simple RNN cell """
    return jnp.tanh(jnp.dot(W_h, h_prev) + jnp.dot(W_x, x_t) + b)

# Example usage for RNN cell
h_prev = jnp.zeros(10)  # previous hidden state
x_t = jnp.ones(5)  # current input
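# `random` and `key` come from the earlier cells. The same key is reused for both draws to keep the example short;
# in practice, split it first (e.g. key_h, key_x = random.split(key)) so the two matrices are independent.
W_h = random.normal(key, (10, 10))  # weights for previous hidden state
W_x = random.normal(key, (10, 5))  # weights for current input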
b = jnp.zeros(10)  # bias
h_next = rnn_cell(h_prev, x_t, W_h, W_x, b)
print("Next hidden state:", h_next)

Next hidden state: [-0.9753993  -0.49872476 -0.8934176  -0.40455788 -0.7315557  -0.90595424
  0.869794    0.997634   -0.37609205 -0.99607754]
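A single cell handles one time step. To process a whole sequence, the cell can be applied repeatedly with `lax.scan`, which carries the hidden state from step to step. This is a minimal sketch; the sequence length of 7 and the 0.1 weight scale are arbitrary assumptions made for the example.

```python
import jax.numpy as jnp
from jax import lax, random

def rnn_cell(h_prev, x_t, W_h, W_x, b):
    """ A simple RNN cell """
    return jnp.tanh(jnp.dot(W_h, h_prev) + jnp.dot(W_x, x_t) + b)

# Separate keys for each random draw.
key_h, key_x, key_seq = random.split(random.PRNGKey(0), 3)
W_h = random.normal(key_h, (10, 10)) * 0.1
W_x = random.normal(key_x, (10, 5)) * 0.1
b = jnp.zeros(10)
xs = random.normal(key_seq, (7, 5))  # sequence of 7 inputs, 5 features each

def step(h, x_t):
    h_next = rnn_cell(h, x_t, W_h, W_x, b)
    return h_next, h_next  # (carry for the next step, output for this step)

h_final, hs = lax.scan(step, jnp.zeros(10), xs)
print("All hidden states:", hs.shape)
print("Final hidden state:", h_final.shape)
```

Compared with a Python `for` loop, `lax.scan` keeps the loop inside the compiled computation, which tends to matter for long sequences under `jit`.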


In [24]:
## Creating Loss Functions
# Loss functions evaluate how well your model performs. Here is an example of a common loss function used in regression tasks:
def mse_loss(y_pred, y_true):
    """ Mean Squared Error loss function """
    return jnp.mean((y_pred - y_true) ** 2)

# Example usage of MSE loss
y_pred = jnp.array([0.5, 2.0, 2.5])
y_true = jnp.array([0, 2, 3])
loss = mse_loss(y_pred, y_true)
print("MSE Loss:", loss)

MSE Loss: 0.16666667


In [25]:
## Training Loops
# Training loops involve repeatedly adjusting model parameters to minimize the loss. Here's a basic example of a training loop using JAX:
from jax import grad, jit

def loss_fn(params, data, targets):
    """ MSE loss of a linear model as a function of its parameters """
    w, b = params
    y_pred = jnp.dot(data, w) + b
    return mse_loss(y_pred, targets)

# Differentiate with respect to params (the first argument), not the predictions
grad_fn = jit(grad(loss_fn))

@jit
def update_params(params, grads, learning_rate):
    """ Update parameters using gradient descent """
    return [p - learning_rate * g for p, g in zip(params, grads)]

# Dummy data and parameters
params = [jnp.array([1.0, -0.5]), jnp.array([0.0])]  # example parameters: weights and bias
data = jnp.array([[1.0, 2.0], [2.0, 3.0]])  # example data
targets = jnp.array([1.5, 2.5])  # example targets

# Simulate a training loop
learning_rate = 0.01
for epoch in range(100):
    grads = grad_fn(params, data, targets)  # same list structure as params
    params = update_params(params, grads, learning_rate)

print("Final loss:", loss_fn(params, data, targets))
print("Updated parameters:", params)
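For models with more parameters, a common pattern is to keep them in a nested structure (a pytree) and update every leaf with `jax.tree_util.tree_map`. The sketch below is one possible arrangement rather than a prescribed one: the dictionary keys `"w"` and `"b"` and the `train_step` helper are names chosen for illustration, and the data are the same dummy values as above.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, data, targets):
    preds = jnp.dot(data, params["w"]) + params["b"]
    return jnp.mean((preds - targets) ** 2)

@jax.jit
def train_step(params, data, targets, learning_rate):
    # Gradients have the same dictionary structure as params.
    loss, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, loss

params = {"w": jnp.array([1.0, -0.5]), "b": jnp.array([0.0])}
data = jnp.array([[1.0, 2.0], [2.0, 3.0]])
targets = jnp.array([1.5, 2.5])

losses = []
for epoch in range(100):
    params, loss = train_step(params, data, targets, 0.01)
    losses.append(float(loss))

print("First loss:", losses[0])
print("Last loss:", losses[-1])
```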


## Advanced Neural Network Architectures
   - Convolutional Neural Networks (CNNs) for image processing
   - Recurrent Neural Networks (RNNs) and Long Short-Term Memory networks (LSTMs) for sequence processing
   - Transformers for NLP tasks

## Optimization and Training Techniques
   - Gradient descent and its variants
   - Techniques for improving training stability and performance
   - Learning rate schedules

## Model Evaluation and Testing
   - Validation and cross-validation
   - Overfitting and regularization techniques
   - Metrics for performance evaluation

## Utilities and Helpers in JAX
   - Using `jax.tree_util` to handle complex data structures
   - Debugging tools in JAX

## Libraries and Frameworks for JAX
   - Introduction to Flax and Haiku: Higher-level abstractions for neural networks
   - Using Optax for gradient processing and optimization

## Quantization and Model Optimization
    - Understanding quantization in deep learning
    - Implementing quantization in JAX for model deployment

## Parallel and Distributed Computing
    - Using `jax.pmap` for data parallelism
    - Strategies for distributing computation across multiple GPUs or TPUs
    - Best practices for parallel and distributed training

## Advanced Topics and Applications
    - Mixed precision training for efficiency
    - Integrating JAX with other Python libraries
    - Real-world applications and case studies

## JAX for Research
    - How to leverage JAX for experimental and cutting-edge research
    - Custom gradients and extending JAX with custom operations

## Performance Tuning and Optimization
    - Profiling JAX applications
    - Techniques for maximizing performance on GPUs and TPUs

# END OF THE LINE