### The `torch.func` module
This notebook demonstrates how to use the `torch.func` module to compute derivatives of functions.

The `torch.func` module provides composable function transforms. The derivative transforms are `torch.func.grad`, `torch.func.jacrev`, and `torch.func.jacfwd`, and `torch.func.vmap` vectorizes any of them over a batch.

`torch.func.grad(f)` returns a function that computes the gradient of a scalar-valued function `f` with respect to its first argument (other arguments can be selected with `argnums`).

`torch.func.jacrev(f)` returns a function that computes the Jacobian of `f` with respect to its first argument using reverse-mode automatic differentiation.

`torch.func.jacfwd(f)` does the same using forward-mode automatic differentiation. Reverse mode is typically cheaper when `f` has fewer outputs than inputs, and forward mode when it has more outputs than inputs.
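A minimal sketch of the three transforms on a single input. The functions `g` and `h` are illustrative choices, not part of the original notebook; the point is that each transform takes a function and returns a new function, and that `jacrev` and `jacfwd` agree on the result.

```python
import torch
from torch.func import grad, jacfwd, jacrev


def g(x: torch.Tensor) -> torch.Tensor:
    # Illustrative scalar-valued function of a vector: sum of squares
    return (x**2).sum()


def h(x: torch.Tensor) -> torch.Tensor:
    # Illustrative vector-valued function R^2 -> R^3
    return torch.stack((x[0] * x[1], x[0] ** 2, torch.sin(x[1])))


x = torch.tensor([1.0, 2.0])

# grad(g) is itself a function; calling it returns dg/dx with the shape of x (here 2 * x)
print(grad(g)(x))

# jacrev and jacfwd compute the same (3, 2) Jacobian with different AD modes
J_rev = jacrev(h)(x)
J_fwd = jacfwd(h)(x)
print(J_rev.shape, torch.allclose(J_rev, J_fwd))  # torch.Size([3, 2]) True
```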





In [None]:
import torch

torch.set_default_dtype(torch.float64)

### Scalar function example


The next code block calculates derivatives for a batch of data using `torch.func`.

* **`f(x)`**: The function to differentiate, $f(x) = x^2 + \sin(x^2)$. Its analytic derivative $f'(x) = 2x + 2x\cos(x^2)$ is implemented in `df_dx` to check the result.
* **`vmap(grad(f))(x)`**: This is the key operation.
    * `grad(f)` computes the derivative at a single scalar input; `grad` requires `f` to return a scalar (0-dimensional) tensor.
    * `vmap` maps this over the first dimension of `x` (1000 points) as one batched computation, replacing a Python loop over the points.
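For comparison, here is a minimal sketch of the same computation with the classical `torch.autograd` API. It relies on each output `y[i]` depending only on `x[i]`, so backpropagating a vector of ones (the gradient of `y.sum()`) leaves $f'(x_i)$ in `x.grad[i]`. The variable names are illustrative.

```python
import torch


def f(x: torch.Tensor) -> torch.Tensor:
    return x**2 + torch.sin(x**2)


x = torch.linspace(0, 5, 1000, dtype=torch.float64, requires_grad=True)
y = f(x)

# Backpropagate d(sum_i y_i)/dx; since y_i depends only on x_i,
# x.grad[i] is exactly f'(x[i])
y.backward(gradient=torch.ones_like(y))
dy_dx_autograd = x.grad

# The torch.func version: grad of the scalar function, vectorized with vmap
dy_dx_func = torch.func.vmap(torch.func.grad(f))(x.detach())

print(torch.allclose(dy_dx_autograd, dy_dx_func))  # True
```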


In [None]:
def f(x: torch.Tensor) -> torch.Tensor:
    return x**2 + torch.sin(x**2)

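def df_dx(x: torch.Tensor) -> torch.Tensor:
    # Analytic derivative, used only to verify the torch.func result
    return 2*x + 2*x*torch.cos(x**2)

N = 1000
# torch.func.grad does not need requires_grad=True on its inputs
x = torch.linspace(0, 5, N)

# grad(f) differentiates f at one scalar; vmap maps it over all N points
dy_dx = torch.func.vmap(torch.func.grad(f))(x)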

print("Are derivatives correct?", torch.allclose(df_dx(x), dy_dx))

### Code Description: Jacobian Computation

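This cell computes the full Jacobian of a function mapping 2 inputs to 3 outputs ($f: \mathbb{R}^2 \to \mathbb{R}^3$) at each of $N$ sample points, using `torch.autograd.grad` for the backward passes and `torch.vmap` to batch them.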

* **`get_vjp(v)`**: Computes the vector-Jacobian product $v^\top J(x_i)$ for every sample $x_i$ with a single `torch.autograd.grad` call. This works because each output row $y_i$ depends only on its own input row $x_i$, so the gradient of $\sum_i v^\top y_i$ with respect to $x_i$ is exactly $v^\top J(x_i)$. With $v = e_k$, the result is row $k$ of every per-sample Jacobian.
* **`basis_vectors`**: The $3 \times 3$ identity matrix, whose rows $e_1, e_2, e_3$ select the output components $y_1, y_2, y_3$.
* **`vmap(get_vjp)(basis_vectors)`**:
    * Replaces a Python loop that would compute the gradient of $y_1$, then $y_2$, then $y_3$ sequentially.
    * `vmap` batches the three backward passes over the rows of `basis_vectors`, so all 3 output dimensions are handled in one call.
* **Result**: `Js_vmap` has shape `(3, N, 2)`; permuting its first two dimensions gives `Js` with shape `(N, 3, 2)`, where `Js[i]` is the Jacobian at sample `i`.
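To make concrete what `vmap` replaces, here is a minimal sketch of the explicit loop: one `torch.autograd.grad` call per output component, each with a one-hot `grad_outputs` broadcast to the whole batch. The names `e_k`, `rows`, and `Js_loop` are illustrative.

```python
import torch


def f(x: torch.Tensor) -> torch.Tensor:
    # Batched f: (N, 2) -> (N, 3)
    x1, x2 = x[:, 0], x[:, 1]
    return torch.stack((x1 * x2, torch.sin(x1 + x2**2), x1**2 - 3 * x2), dim=1)


N = 1000
x = (2 * torch.rand(N, 2, dtype=torch.float64) - 1).requires_grad_()
y = f(x)  # (N, 3)

rows = []
for k in range(3):
    # One-hot cotangent e_k broadcast to every sample: selects output y_k
    e_k = torch.zeros(3, dtype=x.dtype)
    e_k[k] = 1.0
    g = e_k.expand_as(y)
    # Row k of every per-sample Jacobian, shape (N, 2);
    # retain_graph=True because the graph of y is reused in the next iteration
    (row_k,) = torch.autograd.grad(y, x, grad_outputs=g, retain_graph=True)
    rows.append(row_k)

Js_loop = torch.stack(rows, dim=1)  # (N, 3, 2)
print(Js_loop.shape)  # torch.Size([1000, 3, 2])
```

`torch.autograd.grad` also accepts `is_grads_batched=True`, which treats the first dimension of `grad_outputs` as a batch dimension and offers another way to avoid this loop.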

In [None]:
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    x1 = x[:, 0]
    x2 = x[:, 1]

    y1 = x1 * x2
    y2 = torch.sin(x1 + x2**2)
    y3 = x1**2 - 3 * x2

    return torch.stack((y1, y2, y3), dim=1)


def df_dx(x: torch.Tensor) -> torch.Tensor:
    x1 = x[:, 0]
    x2 = x[:, 1]

    dy1_dx1 = x2
    dy1_dx2 = x1

    dy2_dx1 = torch.cos(x1 + x2**2)
    dy2_dx2 = 2 * x2 * torch.cos(x1 + x2**2)

    dy3_dx1 = 2 * x1
    dy3_dx2 = -3 * torch.ones_like(x2)

    J = torch.stack(
        (
            torch.stack((dy1_dx1, dy1_dx2), dim=1),
            torch.stack((dy2_dx1, dy2_dx2), dim=1),
            torch.stack((dy3_dx1, dy3_dx2), dim=1),
        ),
        dim=1
    )

    return J


N = 1000

x = (2 * torch.rand(N, 2) - 1).requires_grad_()

y = f(x)

# ---- Jacobian using vmap + autograd.grad ----

# 1. Function for a SINGLE basis vector v (shape: (3,)).
#    It does the work of one iteration of an explicit loop over output components.
def get_vjp(v):
    # Expand the vector v (3,) to the whole batch (N, 3)
    # e.g., turn [1, 0, 0] into [[1, 0, 0], [1, 0, 0], ...]
    g = v.unsqueeze(0).expand_as(y)

    # Row of the Jacobian selected by v, for every sample: shape (N, 2).
    # retain_graph=True keeps the graph of y so it can be backpropagated through again.
    return torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=g,
        retain_graph=True
    )[0]

# 2. Basis vectors for the 3 output dimensions
basis_vectors = torch.eye(3, device=x.device)

# 3. Apply vmap
#    Input: (3, 3) -> effectively iterates over rows
#    Output: Stack of 3 results. Each result is (N, 2).
#    Total Output Shape: (3, N, 2)
Js_vmap = torch.vmap(get_vjp)(basis_vectors)

# 4. Permute to match expected shape (N, 3, 2)
Js = Js_vmap.permute(1, 0, 2)


dfdx = df_dx(x)

print("Are Jacobians correct?", torch.allclose(dfdx, Js))
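A purely `torch.func` alternative to `get_vjp`, sketched here as an option rather than the notebook's method: `torch.func.vjp` returns `f(x)` together with a function `vjp_fn` that maps a cotangent of shape `(N, 3)` to the vector-Jacobian product, and `vmap` batches `vjp_fn` over the three basis cotangents. Here `x` does not need `requires_grad`. The name `cotangents` is illustrative.

```python
import torch
from torch.func import vjp, vmap


def f(x: torch.Tensor) -> torch.Tensor:
    # Batched f: (N, 2) -> (N, 3)
    x1, x2 = x[:, 0], x[:, 1]
    return torch.stack((x1 * x2, torch.sin(x1 + x2**2), x1**2 - 3 * x2), dim=1)


N = 1000
x = 2 * torch.rand(N, 2, dtype=torch.float64) - 1  # no requires_grad needed

# vjp returns f(x) and a function mapping an (N, 3) cotangent to a tuple of VJPs w.r.t. x
y, vjp_fn = vjp(f, x)

# One (N, 3) cotangent per output component: shape (3, N, 3)
cotangents = torch.eye(3, dtype=x.dtype).unsqueeze(1).expand(3, N, 3)

# vmap over the leading dimension of the cotangents; vjp_fn returns a 1-tuple
(Js_vmap,) = vmap(vjp_fn)(cotangents)  # (3, N, 2)
Js = Js_vmap.permute(1, 0, 2)          # (N, 3, 2)
```

This vjp-plus-vmap pattern is roughly what `torch.func.jacrev` does internally.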

### Code Description: Batch Jacobian

This script computes the Jacobian matrix for a function mapping 2 inputs to 3 outputs ($f: \mathbb{R}^2 \to \mathbb{R}^3$).

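* **`f(x)`**: Defines the function for a **single** sample: `x` has shape `(2,)` and the output has shape `(3,)`. Unlike the previous cell, it indexes `x[0]` and `x[1]` rather than columns of a batch.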
* **`jacrev(f)`**: Uses reverse-mode automatic differentiation to compute the Jacobian matrix for one input.
* **`vmap(jacrev(f))(x)`**: This composes the transformations:
    * It takes the single-sample Jacobian function.
    * It vectorizes it over the batch `x` (1000 samples), computing all Jacobians in parallel without a loop.
    * **Result**: A tensor of shape `(1000, 3, 2)` containing the partial derivatives for every sample.
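Because $f: \mathbb{R}^2 \to \mathbb{R}^3$ has more outputs than inputs, forward mode needs one pass per input (2) while reverse mode needs one per output (3), so `jacfwd` may be slightly cheaper here; for a function this small the difference is unlikely to matter. A minimal sketch comparing the two under `vmap`:

```python
import torch
from torch.func import jacfwd, jacrev, vmap


def f(x: torch.Tensor) -> torch.Tensor:
    # Single-sample f: (2,) -> (3,)
    x1, x2 = x[0], x[1]
    return torch.stack((x1 * x2, torch.sin(x1 + x2**2), x1**2 - 3 * x2))


x = 2 * torch.rand(1000, 2, dtype=torch.float64) - 1

J_rev = vmap(jacrev(f))(x)  # reverse mode: one VJP per output (3 here)
J_fwd = vmap(jacfwd(f))(x)  # forward mode: one JVP per input (2 here)

print(J_rev.shape, torch.allclose(J_rev, J_fwd))  # torch.Size([1000, 3, 2]) True
```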

In [None]:
def f(x: torch.Tensor) -> torch.Tensor:
    x1 = x[0]
    x2 = x[1]

    y1 = x1 * x2
    y2 = torch.sin(x1 + x2**2)
    y3 = x1**2 - 3 * x2

    return torch.stack((y1, y2, y3))


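# Batched analytic Jacobian (same as in the previous cell), used only for verification:
# takes x of shape (N, 2) and returns shape (N, 3, 2)
def df_dx(x: torch.Tensor) -> torch.Tensor: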
    x1 = x[:, 0]
    x2 = x[:, 1]

    dy1_dx1 = x2
    dy1_dx2 = x1

    dy2_dx1 = torch.cos(x1 + x2**2)
    dy2_dx2 = 2 * x2 * torch.cos(x1 + x2**2)

    dy3_dx1 = 2 * x1
    dy3_dx2 = -3 * torch.ones_like(x2)

    J = torch.stack(
        (
            torch.stack((dy1_dx1, dy1_dx2), dim=1),
            torch.stack((dy2_dx1, dy2_dx2), dim=1),
            torch.stack((dy3_dx1, dy3_dx2), dim=1),
        ),
        dim=1
    )

    return J


N = 1000

# torch.func transforms do not need requires_grad=True on their inputs
x = 2 * torch.rand(N, 2) - 1

# Compute the Jacobian using vmap + jacrev:
# jacrev(f) computes the (3, 2) Jacobian for one sample,
# vmap applies it to all N samples in one batched call.
Js = torch.func.vmap(torch.func.jacrev(f))(x)


dfdx = df_dx(x)

print("Shapes:", dfdx.shape, Js.shape)
print("Are Jacobians correct?", torch.allclose(dfdx, Js, atol=1e-5))