### `torch.autograd.functional.vjp`

This notebook demonstrates how to use the `torch.autograd.functional.vjp` function 
to compute the vector-Jacobian product (VJP) of a function.
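Here is a minimal sketch that makes the definition concrete. It uses an illustrative linear map $f(x) = Ax$; the matrix `A` and the vectors `x` and `v` are arbitrary values, not taken from the notebook. For a linear map the Jacobian is $A$ itself, so `vjp` should return $v^\top A$. The signature is `vjp(func, inputs, v=None, create_graph=False, strict=False)`, and it returns a pair `(func_output, vjp)`.

```python
import torch
from torch.autograd.functional import vjp

torch.set_default_dtype(torch.float64)

# Illustrative linear map f(x) = A @ x, so the Jacobian J = A (shape (2, 3)) everywhere
A = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

def f(x: torch.Tensor) -> torch.Tensor:
    return A @ x

x = torch.tensor([1.0, -1.0, 0.5])
v = torch.tensor([1.0, 2.0])   # must have the same shape as f(x)

out, v_J = vjp(f, x, v=v)      # returns (f(x), v^T J)

print(out)                     # values [0.5, 2.0]
print(v_J)                     # values [9.0, 12.0, 15.0], i.e. v @ A
print(torch.allclose(v_J, v @ A))
```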

In [1]:
import torch

torch.set_default_dtype(torch.float64)

### Element-wise scalar function example

This script computes derivatives with `torch.autograd.functional.vjp`, which returns the product $v^\top J$ of a vector $v$ with the Jacobian $J$ of the function.

* **`v = torch.ones_like(x)`**: Defines the vector $v$. In the chain rule, $v$ plays the role of the incoming (upstream) gradient. Because $f$ acts element-wise, $J$ is diagonal, so an all-ones $v$ recovers $f'(x_i)$ for every element.
* **`vjp(f, x, v=v)`**: Runs one forward and one backward pass and returns a tuple:
    1.  **`y`**: The function output ($f(x)$).
    2.  **`dy_dx`**: The vector-Jacobian product ($v^\top J$).
* **Result**: Since $v$ is all ones and $J$ is diagonal, `dy_dx` holds the element-wise derivatives for the whole batch, and the $N \times N$ Jacobian is never constructed.

In [7]:
def f(x: torch.Tensor) -> torch.Tensor:
    return x**2 + torch.sin(x**2)

def df_dx(x: torch.Tensor) -> torch.Tensor:
    return 2*x + 2*x*torch.cos(x**2)

N = 1000
x = torch.linspace(0, 5, N, requires_grad=True)

# v must have the same shape as f(x)
v = torch.ones_like(x)

# f(x) has shape (N,), so v is required (it may be omitted only when f returns a single-element tensor)
y, dy_dx = torch.autograd.functional.vjp(f, x, v=v)

print("Are derivatives correct?", torch.allclose(df_dx(x), dy_dx))


Are derivatives correct? True
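The next sketch checks the diagonal-Jacobian reasoning behind the all-ones $v$. It builds the full Jacobian with `torch.autograd.functional.jacobian` for a small batch (6 points, an illustrative choice) and compares it with the `vjp` result. Storing $J$ explicitly costs $O(N^2)$ memory, which is exactly what `vjp` avoids for large $N$.

```python
import torch
from torch.autograd.functional import jacobian, vjp

torch.set_default_dtype(torch.float64)

def f(x: torch.Tensor) -> torch.Tensor:
    # Element-wise: output i depends only on input i
    return x**2 + torch.sin(x**2)

x = torch.linspace(0, 5, 6)   # small N so the full Jacobian is cheap to build
v = torch.ones_like(x)

J = jacobian(f, x)            # full (6, 6) Jacobian
_, dy_dx = vjp(f, x, v=v)     # v^T J, computed without materialising J

print(torch.allclose(J, torch.diag(J.diagonal())))  # J is diagonal
print(torch.allclose(v @ J, dy_dx))                 # explicit v^T J matches vjp
print(torch.allclose(J.diagonal(), dy_dx))          # equals the element-wise f'(x_i)
```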


### Function of a vector input

This script calculates gradients for a function of a vector input, $f:\mathbb{R}^2 \to \mathbb{R}$, evaluated on a batch so that it returns one scalar per sample.

* **`f(x)`**: Takes a batch of inputs `(N, 2)` and returns a batch of scalar outputs `(N,)`.
* **`v = torch.ones(N)`**: The "incoming gradient" vector. It must match the output shape of `f(x)`. Because each output depends only on its own sample, an all-ones $v$ returns each sample's gradient without mixing samples.
* **`vjp(f, x, v=v)`**: Computes the Vector-Jacobian Product.
    * Since `f` outputs a vector (the batch results), we provide `v` to weight these outputs during backpropagation.
    * It calculates the derivative of the outputs with respect to the inputs `x`.
    * **Result**: `grad_x_vjp` (shape `(N, 2)`) contains the partial derivatives for every sample in the batch.

In [3]:
def f(x: torch.Tensor) -> torch.Tensor:
    return x[:, 0] * x[:, 1] + torch.sin(x[:, 0] * x[:, 1]**2)

def df_dx(x: torch.Tensor) -> torch.Tensor:
    return torch.stack(
        (
            x[:, 1] + x[:, 1]**2 * torch.cos(x[:, 0] * x[:, 1]**2),
            x[:, 0] + 2 * x[:, 0] * x[:, 1] * torch.cos(x[:, 0] * x[:, 1]**2),
        ),
        dim=1,
    )

N = 1000

x = 2 * torch.rand(N, 2) - 1
x.requires_grad = True

# v must have same shape as f(x), i.e. (N,)
v = torch.ones(N)

# vjp for vector output: pass v directly
y, grad_x_vjp = torch.autograd.functional.vjp(f, x, v=v)

dfdx = df_dx(x)

print("Are gradients correct?", torch.allclose(dfdx, grad_x_vjp))


Are gradients correct? True
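The role of $v$ as a per-output weight can be checked directly. This sketch assumes a small illustrative batch (`N = 4`) and arbitrary weights `w`. It builds the full `(N, N, 2)` Jacobian and shows that summing over outputs (an all-ones $v$) leaves only the diagonal blocks $\partial f_i/\partial x_i$. It also shows that a non-uniform $v$ scales each sample's gradient by its weight.

```python
import torch
from torch.autograd.functional import jacobian, vjp

torch.set_default_dtype(torch.float64)

def f(x: torch.Tensor) -> torch.Tensor:
    return x[:, 0] * x[:, 1] + torch.sin(x[:, 0] * x[:, 1]**2)

N = 4
x = 2 * torch.rand(N, 2) - 1

J = jacobian(f, x)             # shape (N, N, 2): J[i, j, :] = d f_i / d x_j
w = torch.arange(1.0, N + 1)   # arbitrary illustrative per-sample weights

_, g_ones = vjp(f, x, v=torch.ones(N))
_, g_w = vjp(f, x, v=w)

# Sample i depends only on row i of x, so only the blocks J[i, i, :] are nonzero
idx = torch.arange(N)
per_sample = J[idx, idx]       # (N, 2) diagonal blocks

print(torch.allclose(g_ones, per_sample))
print(torch.allclose(g_w, w[:, None] * per_sample))  # weights scale each sample's gradient
```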


### Function with multiple inputs

This script demonstrates how to compute gradients for a function that takes **multiple independent arguments** ($x$ and $y$).

* **`f(x, y)`**: A function accepting two separate input tensors.
* **`vjp(f, (x, y), v=v)`**: Computes the Vector-Jacobian Product for multiple inputs simultaneously.
    * **`inputs=(x, y)`**: By passing a tuple of inputs, `vjp` tracks gradients for both tensors.
    * **Result**: It returns `(out, (vjp_x, vjp_y))`: the function output plus a tuple with one VJP per input. Because the samples are independent, `vjp_x[i]` is $\partial f_i / \partial x_i$ and `vjp_y[i]` is $\partial f_i / \partial y_i$.

In [4]:
def f(x, y):
    return x*y + torch.sin(x*y**2)

def df_dx(x, y):
    return torch.stack(
        (
            y + y**2 * torch.cos(x*y**2),
            x + 2*x*y * torch.cos(x*y**2),
        ),
        dim=1,
    )

N = 1000

x = (2*torch.rand(N) - 1).requires_grad_()
y = (2*torch.rand(N) - 1).requires_grad_()

# v must match f(x,y) shape (N,)
v = torch.ones(N)

# vjp with multiple inputs: inputs=(x,y)
out, (vjp_x, vjp_y) = torch.autograd.functional.vjp(
    f,
    (x, y),
    v=v
)

J = df_dx(x, y)

print("dx correct?", torch.allclose(J[:, 0], vjp_x))
print("dy correct?", torch.allclose(J[:, 1], vjp_y))


dx correct? True
dy correct? True
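Under the hood, `vjp` essentially performs reverse-mode autograd. The same products can be obtained with `torch.autograd.grad`, where `grad_outputs` plays the role of $v$. The sketch below compares the two for the two-input function above, using a small illustrative `N`.

```python
import torch
from torch.autograd.functional import vjp

torch.set_default_dtype(torch.float64)

def f(x, y):
    return x * y + torch.sin(x * y**2)

N = 5
x = (2 * torch.rand(N) - 1).requires_grad_()
y = (2 * torch.rand(N) - 1).requires_grad_()
v = torch.ones(N)

# Functional API: returns (output, tuple with one VJP per input)
out, (vjp_x, vjp_y) = vjp(f, (x, y), v=v)

# Same quantity with torch.autograd.grad; grad_outputs plays the role of v
gx, gy = torch.autograd.grad(f(x, y), (x, y), grad_outputs=v)

print(torch.allclose(vjp_x, gx), torch.allclose(vjp_y, gy))
```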


### Vector-valued function: full Jacobian via a VJP loop

This script reconstructs the per-sample Jacobian of a vector-valued function ($\mathbb{R}^2 \to \mathbb{R}^3$, evaluated on a batch) by iterating through the output dimensions.

* **The Problem**: `vjp` computes a single vector (a weighted sum of gradients), not the full Jacobian matrix.
* **The Loop**: To get the full Jacobian, we must isolate each output component ($y_1, y_2, y_3$) individually.
* **`v[:, k] = 1.0`**: Inside the loop, we create a "one-hot" vector that acts as a selector. When passed to `vjp`, it makes PyTorch compute the gradient of *only* the $k$-th output component, for every sample at once.
* **Result**: Running `vjp` 3 times (one forward and one backward pass each) and storing the results in `Js[:, k, :]` fills row $k$ of every sample's $3 \times 2$ Jacobian, so the full `(N, 3, 2)` tensor is built row by row.

In [None]:
def f(x: torch.Tensor) -> torch.Tensor:
    x1 = x[:, 0]
    x2 = x[:, 1]

    y1 = x1 * x2
    y2 = torch.sin(x1 + x2**2)
    y3 = x1**2 - 3 * x2

    return torch.stack((y1, y2, y3), dim=1)   # (N, 3)


def df_dx(x: torch.Tensor) -> torch.Tensor:
    x1 = x[:, 0]
    x2 = x[:, 1]

    dy1_dx1 = x2
    dy1_dx2 = x1

    dy2_dx1 = torch.cos(x1 + x2**2)
    dy2_dx2 = 2 * x2 * torch.cos(x1 + x2**2)

    dy3_dx1 = 2 * x1
    dy3_dx2 = -3 * torch.ones_like(x2)

    J = torch.stack(
        (
            torch.stack((dy1_dx1, dy1_dx2), dim=1),
            torch.stack((dy2_dx1, dy2_dx2), dim=1),
            torch.stack((dy3_dx1, dy3_dx2), dim=1),
        ),
        dim=1
    )   # (N, 3, 2)

    return J


N = 1000

x = (2 * torch.rand(N, 2) - 1).requires_grad_()

y = f(x)

# ----------------------- vjp version -----------------------
Js = torch.zeros(N, 3, 2)

for k in range(3):     # loop over output dims
    v = torch.zeros_like(y)   # shape (N, 3)
    v[:, k] = 1.0             # pick k-th output component

    _, grad_x = torch.autograd.functional.vjp(
        f,
        x,
        v=v
    )

    Js[:, k, :] = grad_x      # grad_x is (N, 2)


dfdx = df_dx(x)

print("Are Jacobians correct?", torch.allclose(dfdx, Js))


Are Jacobians correct? True
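The loop above costs one forward and one backward pass per output component. An alternative approach, assuming PyTorch 2.x where `torch.func` is available, is `vmap(jacrev(...))`, which computes the per-sample `(3, 2)` Jacobians without a Python loop. Here `f_single` is a hypothetical per-sample rewrite of `f`, and the VJP loop is kept as the reference.

```python
import torch
from torch.func import jacrev, vmap   # assumes PyTorch 2.x

torch.set_default_dtype(torch.float64)

def f_single(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical per-sample version of f: (2,) -> (3,)
    x1, x2 = x[0], x[1]
    return torch.stack((x1 * x2, torch.sin(x1 + x2**2), x1**2 - 3 * x2))

def f(x: torch.Tensor) -> torch.Tensor:
    # Batched version, as in the cell above: (N, 2) -> (N, 3)
    x1, x2 = x[:, 0], x[:, 1]
    return torch.stack((x1 * x2, torch.sin(x1 + x2**2), x1**2 - 3 * x2), dim=1)

N = 1000
x = 2 * torch.rand(N, 2) - 1

# jacrev: (3, 2) Jacobian of one sample via reverse mode; vmap: map it over the batch
J_func = vmap(jacrev(f_single))(x)    # (N, 3, 2)

# Reference: the VJP loop, one backward pass per output component
J_loop = torch.zeros(N, 3, 2)
for k in range(3):
    v = torch.zeros(N, 3)
    v[:, k] = 1.0
    _, grad_x = torch.autograd.functional.vjp(f, x, v=v)
    J_loop[:, k, :] = grad_x

print(torch.allclose(J_func, J_loop))
```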


### Multi-input Jacobian via a VJP loop

This script reconstructs the full Jacobian matrix for a function taking **two independent inputs** ($x, y$) and producing **three outputs**.

* **The Loop**: To get the full Jacobian matrix, we must iterate through the 3 output dimensions ($y_1, y_2, y_3$).
* **`v[:, k] = 1.0`**: The "selector" vector. By setting only the $k$-th column to 1, we isolate the gradient calculation for that specific output equation.
* **`vjp(f, (x, y), v=v)`**:
    * Accepts a **tuple** of inputs `(x, y)`.
    * Returns a **tuple** of gradients `(gx, gy)`.
* **Reconstruction**: We manually fill the Jacobian tensor `Js` row by row:
    * `Js[:, k, 0] = gx`: Stores partial derivatives w.r.t. $x$.
    * `Js[:, k, 1] = gy`: Stores partial derivatives w.r.t. $y$.

In [6]:
def f(x, y):
    y1 = x * y
    y2 = torch.sin(x + y**2)
    y3 = x**2 - 3*y
    return torch.stack((y1, y2, y3), dim=1)   # (N, 3)


def df_dxdy(x, y):
    dy1_dx = y
    dy1_dy = x

    dy2_dx = torch.cos(x + y**2)
    dy2_dy = 2*y * torch.cos(x + y**2)

    dy3_dx = 2*x
    dy3_dy = -3 * torch.ones_like(y)

    J = torch.stack(
        (
            torch.stack((dy1_dx, dy1_dy), dim=1),
            torch.stack((dy2_dx, dy2_dy), dim=1),
            torch.stack((dy3_dx, dy3_dy), dim=1),
        ),
        dim=1
    )   # (N, 3, 2)
    return J


N = 1000
x = (2*torch.rand(N) - 1).requires_grad_()
y = (2*torch.rand(N) - 1).requires_grad_()

out = f(x, y)      # shape (N, 3)

Js = torch.zeros(N, 3, 2)

for k in range(3):
    # v has same shape as out: (N, 3)
    v = torch.zeros_like(out)
    v[:, k] = 1.0   # pick k-th output component

    # vjp with multiple inputs: returns (out, (vjp_wrt_x, vjp_wrt_y))
    _, (gx, gy) = torch.autograd.functional.vjp(
        f,
        (x, y),
        v=v,
    )

    Js[:, k, 0] = gx
    Js[:, k, 1] = gy

J_true = df_dxdy(x, y)

print("Correct?", torch.allclose(Js, J_true))


Correct? True
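For small batches, `torch.autograd.functional.jacobian` also accepts a tuple of inputs. It returns one Jacobian per input, each of shape `out.shape + input.shape`, which is `(N, 3, N)` here. The sketch below uses an illustrative `N = 8`, keeps only the entries where the output sample matches the input sample, and compares the result with the VJP loop. The full Jacobian grows as $O(N^2)$, so for large batches the loop is the more practical route.

```python
import torch
from torch.autograd.functional import jacobian, vjp

torch.set_default_dtype(torch.float64)

def f(x, y):
    return torch.stack((x * y, torch.sin(x + y**2), x**2 - 3 * y), dim=1)  # (N, 3)

N = 8  # keep N small: each full Jacobian below has N * 3 * N entries
x = 2 * torch.rand(N) - 1
y = 2 * torch.rand(N) - 1

# One Jacobian per input, each of shape (N, 3, N): J_x[i, k, j] = d out[i, k] / d x[j]
J_x, J_y = jacobian(f, (x, y))

# Sample i depends only on x[i], y[i]: take the diagonal over dims 0 and 2
gx_all = J_x.diagonal(dim1=0, dim2=2).T          # (N, 3)
gy_all = J_y.diagonal(dim1=0, dim2=2).T          # (N, 3)
Js_full = torch.stack((gx_all, gy_all), dim=2)   # (N, 3, 2)

# Reference: the VJP loop from the cell above
Js_loop = torch.zeros(N, 3, 2)
for k in range(3):
    v = torch.zeros(N, 3)
    v[:, k] = 1.0
    _, (gx, gy) = vjp(f, (x, y), v=v)
    Js_loop[:, k, 0] = gx
    Js_loop[:, k, 1] = gy

print(torch.allclose(Js_full, Js_loop))
```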
