# Getting started with JAX

**Learning Objectives:**
* Practice defining and performing basic operations on JAX arrays.
* Understand JAX's functional programming paradigm (immutability).
* Use JAX's automatic differentiation capability (`jax.grad`).
* Learn how to train a linear regression from scratch with JAX.

This notebook will cover basic JAX operations, automatic differentiation, and training a linear regression.

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from matplotlib import pyplot as plt

## Operations on JAX Arrays

### JAX Arrays (Constants)

JAX arrays are immutable and are similar to `tf.constant` in TensorFlow. This means that once created, their values cannot be changed.

In [None]:
x = jnp.array([2, 3, 4])
print(x)
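To make the immutability concrete, the short sketch below first tries an in-place assignment, which JAX rejects, and then uses the `.at[...].set(...)` syntax, which returns an updated copy and leaves the original untouched. The names `x_immutable` and `x_updated` are introduced here purely for illustration.

```python
import jax.numpy as jnp

x_immutable = jnp.array([2, 3, 4])

# In-place item assignment is not supported on JAX arrays.
try:
    x_immutable[0] = 10
    assignment_failed = False
except TypeError:
    assignment_failed = True

# .at[...].set(...) returns a new array; the original is unchanged.
x_updated = x_immutable.at[0].set(10)

print(f"In-place assignment raised TypeError: {assignment_failed}")
print(f"Original: {x_immutable}")
print(f"Updated copy: {x_updated}")
```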

### Point-wise operations

JAX offers a comprehensive suite of point-wise operations, similar to what you'd find in NumPy or TensorFlow.

**Exercise:** Create two JAX arrays `a = jnp.array([5, 3, 8])` and `b = jnp.array([3, -1, 2])`. Then, compute:
1. The sum of `a` and `b` using `jnp.add` and `+`.
2. The product of `a` and `b` using `jnp.multiply` and `*`.
3. The exponential of `a` using `jnp.exp`.

In [None]:
a = jnp.array([5, 3, 8])
b = jnp.array([3, -1, 2])

sum_add = jnp.add(a, b)
sum_plus = a + b
print(f"Sum using jnp.add: {sum_add}")
print(f"Sum using +: {sum_plus}")

In [None]:

prod_multiply = jnp.multiply(a, b)
prod_star = a * b
print(f"Product using jnp.multiply: {prod_multiply}")
print(f"Product using *: {prod_star}")

In [None]:
a = jnp.array([5, 3, 8])

exp_a = jnp.exp(a)
print(f"Exponential of a: {exp_a}")

### NumPy Interoperability

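JAX operations accept NumPy arrays and Python scalars as inputs. Python lists and tuples are accepted by constructors such as `jnp.array`, but most `jax.numpy` functions (for example `jnp.sum`) raise a `TypeError` when given one directly, so convert them first. Conversely, JAX arrays can be converted to NumPy arrays with `np.array()`.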

In [None]:
# numpy arrays
a_np = np.array([1, 2])
b_np = np.array([3, 4])
# jax sum
print(f"Sum of numpy arrays: {jnp.add(a_np, b_np)}")

# jax arrays
a_jax = jnp.array([1, 2])
b_jax = jnp.array([3, 4])
# jax sum
print(f"Sum of jax arrays: {jnp.add(a_jax, b_jax)}")

In [None]:
# Convert JAX array to NumPy array
a_jax_to_np = np.array(a_jax)
print(f"JAX array converted to NumPy: {a_jax_to_np}, type: {type(a_jax_to_np)}")
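A nuance on input types that is worth checking in your own environment: Python scalars are accepted directly, whereas plain lists are typically rejected by `jax.numpy` functions and need an explicit `jnp.array` conversion. The sketch below illustrates this with `jnp.sum`; the variable names are illustrative only.

```python
import jax.numpy as jnp
import numpy as np

values_np = np.array([1, 2])

# A Python scalar is accepted and broadcast, as in NumPy.
shifted = jnp.add(values_np, 1)
print(f"NumPy array plus scalar: {shifted}")

# A plain list is rejected by most jax.numpy functions.
try:
    jnp.sum([1, 2, 3])
    list_rejected = False
except TypeError:
    list_rejected = True
print(f"jnp.sum on a list raised TypeError: {list_rejected}")

# Converting explicitly with jnp.array works.
list_total = jnp.sum(jnp.array([1, 2, 3]))
print(f"Sum after conversion: {list_total}")
```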

## Linear Regression

Now let's use JAX operations to implement linear regression. Later in the course, you'll see abstracted ways to do this using high-level libraries like Equinox or Flax.

### Toy Dataset

We'll model the following function: $y = 2x + 10$

In [None]:
X_train = jnp.array(range(10), dtype=jnp.float32)
Y_train = 2 * X_train + 10
print(f"X_train: {X_train}")
print(f"Y_train: {Y_train}")

Let's also create a test dataset to evaluate our models:

In [None]:
X_test = jnp.array(range(10, 20), dtype=jnp.float32)
Y_test = 2 * X_test + 10
print(f"X_test: {X_test}")
print(f"Y_test: {Y_test}")

### Loss Function

A common baseline model is to predict the mean of the training target values. Let's calculate the Mean Squared Error (MSE) for this baseline on the test set.

In [None]:
# The baseline is fit on the training targets only, then evaluated on the test set.
y_mean = Y_train.mean()


def predict_mean(X):
    return jnp.full_like(X, y_mean)


Y_hat_baseline = predict_mean(X_test)
baseline_errors = (Y_hat_baseline - Y_test) ** 2
baseline_loss = jnp.mean(baseline_errors)
print(f"Baseline MSE Loss (predicting mean): {baseline_loss}")

Now, if $\hat{Y}$ represents the vector containing our model's predictions when we use a linear regression model $\hat{Y} = w_0X + w_1$, we can write a loss function taking as arguments the model parameters:

In [None]:
def loss_mse(params, X, Y):
    w0, w1 = params
    Y_hat = w0 * X + w1
    errors = (Y_hat - Y) ** 2
    return jnp.mean(errors)

### Gradient Function

To use gradient descent, we need to take the partial derivatives of the loss function with respect to each of the weights. With JAX's automatic differentiation capability (`jax.grad`), we don't have to compute them manually! `jax.grad` transforms a function into a new function that computes its gradient. The `argnums` parameter specifies with respect to which argument(s) the gradient should be computed.

In [None]:
# We want gradients with respect to params (argument 0 of loss_mse).
grad_fn = jax.grad(loss_mse, argnums=0)


def compute_gradients(params, X, Y):
    # Same signature as loss_mse: (params, X, Y).
    return grad_fn(params, X, Y)
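As a small illustration of how `argnums` selects the argument to differentiate, the sketch below uses a hypothetical toy function `scaled_square(x, y) = x**2 * y`, which is not part of the notebook's model. Passing a tuple to `argnums` returns one gradient per selected argument.

```python
import jax


def scaled_square(x, y):
    # Hypothetical toy function: f(x, y) = x^2 * y
    return x**2 * y


# Derivative with respect to x (argument 0): 2 * x * y
df_dx = jax.grad(scaled_square, argnums=0)(3.0, 2.0)

# Derivative with respect to y (argument 1): x^2
df_dy = jax.grad(scaled_square, argnums=1)(3.0, 2.0)

# A tuple of indices returns a tuple of gradients, in the same order.
both = jax.grad(scaled_square, argnums=(0, 1))(3.0, 2.0)

print(f"df/dx at (3, 2): {df_dx}")
print(f"df/dy at (3, 2): {df_dy}")
print(f"Both at once: {both}")
```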

In [None]:
initial_params = [0.0, 0.0]  # w0, w1 as a list or tuple
dw0, dw1 = compute_gradients(initial_params, X_train, Y_train)

print(f"Initial d_w0: {dw0}")
print(f"Initial d_w1: {dw1}")
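One way to build confidence in these values is to compare them with hand-derived partial derivatives of the MSE loss. The sketch below assumes the same loss and toy data as above; `manual_gradients`, `X_check`, and `Y_check` are helper names introduced here for illustration.

```python
import jax
import jax.numpy as jnp


def loss_mse(params, X, Y):
    w0, w1 = params
    return jnp.mean((w0 * X + w1 - Y) ** 2)


def manual_gradients(params, X, Y):
    # Hand-derived partial derivatives of the MSE loss:
    #   dL/dw0 = mean(2 * (Y_hat - Y) * X)
    #   dL/dw1 = mean(2 * (Y_hat - Y))
    w0, w1 = params
    residuals = w0 * X + w1 - Y
    return jnp.mean(2 * residuals * X), jnp.mean(2 * residuals)


X_check = jnp.arange(10, dtype=jnp.float32)
Y_check = 2 * X_check + 10

auto_dw0, auto_dw1 = jax.grad(loss_mse)([0.0, 0.0], X_check, Y_check)
manual_dw0, manual_dw1 = manual_gradients([0.0, 0.0], X_check, Y_check)

print(f"jax.grad: dw0={auto_dw0}, dw1={auto_dw1}")
print(f"manual:   dw0={manual_dw0}, dw1={manual_dw1}")
```

Both approaches should agree up to floating-point rounding.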

### Training Loop

Here is a very simple training loop. Note that, for simplicity, we ignore best practices such as batching and random weight initialization.

**Exercise:** Complete the `for` loop below to train a linear regression model.
1.  Use `compute_gradients` to compute the gradients of the loss with respect to `w0` and `w1`.
2.  Update `w0` and `w1` using the computed gradients and the `LEARNING_RATE`. Remember JAX arrays are immutable, so you'll create new arrays for the updated parameters.
3.  For every 100th step, compute and print the `loss` using the `loss_mse` function.

In [None]:
STEPS = 1000
LEARNING_RATE = 0.02
MSG = "STEP {step} - loss: {loss}, w0: {w0}, w1: {w1}\n"

# Initialize parameters
params = [0.0, 0.0]  # w0, w1

for step in range(1, STEPS + 1):
    grad_w0, grad_w1 = compute_gradients(params, X_train, Y_train)

    # JAX arrays are immutable, so build new parameter values
    # instead of updating the old ones in place.
    new_w0 = params[0] - LEARNING_RATE * grad_w0
    new_w1 = params[1] - LEARNING_RATE * grad_w1
    params = [new_w0, new_w1]

    if step % 100 == 0:
        current_loss = loss_mse(params, X_train, Y_train)
        print(
            MSG.format(step=step, loss=current_loss, w0=params[0], w1=params[1])
        )

print(f"Final parameters: w0={params[0]}, w1={params[1]}")
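As an optional variation, not the approach used in this notebook, the same update can be wrapped in a single compiled step using `jax.jit` and `jax.value_and_grad`, which returns the loss and its gradient in one call. This is a minimal sketch under the same data and learning rate; `train_step`, `X_demo`, `Y_demo`, and `demo_params` are names introduced here for illustration.

```python
import jax
import jax.numpy as jnp


def loss_mse(params, X, Y):
    w0, w1 = params
    return jnp.mean((w0 * X + w1 - Y) ** 2)


@jax.jit
def train_step(params, X, Y, learning_rate):
    # Loss and gradients from a single evaluation of the loss function.
    loss, grads = jax.value_and_grad(loss_mse)(params, X, Y)
    # Apply the gradient descent update to every parameter.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, loss


X_demo = jnp.arange(10, dtype=jnp.float32)
Y_demo = 2 * X_demo + 10

demo_params = (jnp.float32(0.0), jnp.float32(0.0))
for _ in range(1000):
    demo_params, demo_loss = train_step(demo_params, X_demo, Y_demo, 0.02)

print(f"w0={demo_params[0]}, w1={demo_params[1]}, last loss={demo_loss}")
```

Note that `demo_loss` is the loss measured before the final update, since `value_and_grad` evaluates the loss at the parameters passed in.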

Now let's compare the test loss for this linear regression to the test loss from the baseline model.

In [None]:
final_w0, final_w1 = params
test_loss = loss_mse(params, X_test, Y_test)
print(f"Test MSE Loss (linear regression): {test_loss}")
print(f"Baseline MSE Loss (predicting mean): {baseline_loss}")

This is indeed much better!

## Bonus

Try modelling a non-linear function such as: $y=xe^{-x^2}$

In [None]:
X = jnp.array(np.linspace(0, 2, 1000), dtype=jnp.float32)
Y = X * jnp.exp(-(X**2))

In [None]:
%matplotlib inline
plt.plot(X, Y)
plt.title("Non-linear function: y = x * exp(-x^2)")
plt.xlabel("x")
plt.ylabel("y")
plt.show()

To model this with a linear model, we need to engineer features. Let's create a function `make_features`.

In [None]:
def make_features(X):
    f1 = jnp.ones_like(X)  # Bias feature
    f2 = X
    f3 = X**2
    f4 = X**3
    f5 = jnp.sqrt(X)
    f6 = jnp.exp(X)
    # Stack them column-wise
    return jnp.stack([f1, f2, f3, f4, f5, f6], axis=1)

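The MSE computation stays the same, but the earlier `loss_mse` assumed exactly two scalar parameters. We therefore define a prediction function that multiplies the feature matrix by a weight vector, then redefine `loss_mse` and `compute_gradients` to use it.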

In [None]:
def predict(W, X):
    return jnp.dot(X, W)
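To see how the feature matrix and the weight vector fit together, the sketch below builds features for three sample points and checks the resulting shapes. It repeats a compact version of `make_features` so that it runs on its own; `X_small`, `features_small`, and `W_small` are illustrative names.

```python
import jax.numpy as jnp


def make_features(X):
    # Same six features as above: bias, x, x^2, x^3, sqrt(x), exp(x)
    return jnp.stack(
        [jnp.ones_like(X), X, X**2, X**3, jnp.sqrt(X), jnp.exp(X)], axis=1
    )


X_small = jnp.array([0.0, 1.0, 2.0])
features_small = make_features(X_small)

# One weight per feature column.
W_small = jnp.zeros(features_small.shape[1])

# (n_samples, n_features) dot (n_features,) gives one prediction per sample.
predictions_small = jnp.dot(features_small, W_small)

print(f"Feature matrix shape: {features_small.shape}")
print(f"Weight vector shape: {W_small.shape}")
print(f"Predictions shape: {predictions_small.shape}")
```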

In [None]:
def loss_mse(W, X, Y_true):
    Y_hat = predict(W, X)
    errors = (Y_hat - Y_true) ** 2
    return jnp.mean(errors)

In [None]:
def compute_gradients(params_w, X_features, Y_true):
    return jax.grad(loss_mse, argnums=0)(params_w, X_features, Y_true)

Now, let's train our linear model on these engineered features.

In [None]:
STEPS = 2000
LEARNING_RATE = 0.02

Xf = make_features(X)
n_features = Xf.shape[1]

W = jnp.zeros(n_features)

for step in range(1, STEPS + 1):
    grads = compute_gradients(W, Xf, Y)
    W = W - LEARNING_RATE * grads

    if step % 100 == 0:
        current_loss = loss_mse(W, Xf, Y)
        print(f"Step: {step}, Loss: {current_loss}")

plt.plot(X, Y, label="Actual")
plt.plot(X, predict(W, Xf), label="Predicted")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.show()

Copyright 2025 Google Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License