# Warp PyTorch Tutorial: Custom Operators

In this example we will locate the minimum of the Rosenbrock function, a non-convex function that is often used to test optimization algorithms. The function is defined as:

$$
f(x,y) = (a-x)^2 + b(y-x^2)^2
$$

where $a = 1$ and $b = 100$. The function attains its global minimum value of $0$ at $(x, y) = (a, a^2) = (1, 1)$.

We will make use of PyTorch custom operators, which will allow us to safely incorporate Warp kernel launches in a PyTorch graph.

In [None]:
!pip install warp-lang torch matplotlib

In [None]:
# Warp provides the kernels that evaluate the function and its gradient;
# NumPy is used for the initial particle positions and the plotting grid.
import warp as wp
import numpy as np

# Requires PyTorch 2.4+ (torch.library.custom_op, register_fake and
# register_autograd are used to define the custom operators).
import torch

# Ask Warp to keep its console output quiet; set before wp.init() below.
wp.config.quiet = True

# Explicitly initializing Warp is not necessary but
# we do it here to ensure everything is good to go.
wp.init()

# Everything else is solely to visualize the results.
import IPython
import matplotlib
import matplotlib.animation
import matplotlib.pyplot

# Display FuncAnimation objects as interactive JavaScript/HTML ("jshtml")
# when they are shown in the notebook.
matplotlib.pyplot.rc("animation", html="jshtml")

# PyTorch Custom Operators

PyTorch custom operators allow you to wrap Python functions (in this case, Warp kernel launches) so that they behave like native PyTorch operators. See the [PyTorch docs](https://pytorch.org/tutorials/advanced/python_custom_ops.html#adding-training-support-for-crop) for more information. This is useful when PyTorch manages the computational graph but you want one or more of its nodes to be computed by Warp kernels. In the following example, Warp evaluates the Rosenbrock function in the forward pass and its adjoint (gradient) in the backward pass. PyTorch's Adam optimizer then updates the query points to find the function's minimum.

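Below, we define the Warp kernel `eval_rosenbrock` in the usual way. We wrap its forward launch in the custom PyTorch operator `warp_rosenbrock`, and its adjoint launch in a second operator, `warp_rosenbrock_backward`. Finally, `register_autograd` tells PyTorch to call the backward operator whenever gradients of `warp_rosenbrock` are needed.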

In [None]:
# Rosenbrock function f(x, y) = (a - x)^2 + b * (y - x^2)^2 with a = 1 and b = 100.
@wp.func
def rosenbrock(x: float, y: float):
    # The (a - x) term: offset from the optimum along x.
    dx = 1.0 - x
    # The (y - x^2) term: deviation from the parabola y = x^2.
    dy = y - x**2.0
    return dx**2.0 + 100.0 * dy**2.0


@wp.kernel
def eval_rosenbrock(
    xy: wp.array(dtype=wp.vec2),
    # outputs
    z: wp.array(dtype=float),
):
    # One thread per point: read the 2D point and store its Rosenbrock value.
    tid = wp.tid()
    point = xy[tid]
    px = point[0]
    py = point[1]
    z[tid] = rosenbrock(px, py)


@torch.library.custom_op("wp::warp_rosenbrock", mutates_args=())
def warp_rosenbrock(xy: torch.Tensor, num_particles: int) -> torch.Tensor:
    """Evaluate the Rosenbrock function at every point of ``xy`` with Warp.

    Args:
        xy: Points viewed by Warp as an array of ``wp.vec2``
            (presumably a float32 tensor of shape ``(num_particles, 2)``).
        num_particles: Number of points to evaluate (the kernel launch size).

    Returns:
        A float32 tensor with ``num_particles`` function values, on ``xy``'s device.
    """
    # requires_grad=False: the forward pass only reads ``xy``. Letting from_torch
    # follow ``xy.requires_grad`` would allocate ``xy.grad`` as a hidden side effect
    # of an op declared with ``mutates_args=()``.
    wp_xy = wp.from_torch(xy, dtype=wp.vec2, requires_grad=False)
    # Allocate the output and launch on the input's device (not Warp's default
    # device), so the op also works for tensors living on another device.
    wp_z = wp.zeros(num_particles, dtype=wp.float32, device=wp_xy.device)

    wp.launch(
        kernel=eval_rosenbrock,
        dim=num_particles,
        inputs=[wp_xy],
        outputs=[wp_z],
        device=wp_xy.device,
    )

    return wp.to_torch(wp_z)


@warp_rosenbrock.register_fake
def _(xy, num_particles):
    # The fake implementation must match the real output's shape, dtype AND device;
    # without ``device=`` it would always claim a CPU tensor.
    return torch.empty(num_particles, dtype=torch.float32, device=xy.device)


@torch.library.custom_op("wp::warp_rosenbrock_backward", mutates_args=())
def warp_rosenbrock_backward(
    xy: torch.Tensor, num_particles: int, z: torch.Tensor, adj_z: torch.Tensor
) -> torch.Tensor:
    """Compute the gradient of the Rosenbrock values with respect to ``xy``.

    Launches the adjoint of ``eval_rosenbrock``.

    Args:
        xy: The points that were passed to ``warp_rosenbrock``.
        num_particles: Number of points (the kernel launch size).
        z: The output of ``warp_rosenbrock``.
        adj_z: Incoming gradient with respect to ``z``.

    Returns:
        A new tensor shaped like ``xy`` holding the ``adj_z``-weighted gradient.
    """
    wp_xy = wp.from_torch(xy, dtype=wp.vec2, requires_grad=False)
    wp_z = wp.from_torch(z, requires_grad=False)
    # adj_z can be a non-contiguous view (e.g. an expanded tensor when the loss is
    # z.sum()); hand Warp a dense copy in that case (no-op if already contiguous).
    wp_adj_z = wp.from_torch(adj_z.contiguous(), requires_grad=False)
    # Fresh zero-initialized buffer for the adjoint to accumulate into. It must not
    # be ``xy.grad``: autograd itself accumulates the returned gradient into
    # ``xy.grad``, so writing there as well would count the gradient twice.
    wp_adj_xy = wp.zeros_like(wp_xy)

    wp.launch(
        kernel=eval_rosenbrock,
        dim=num_particles,
        inputs=[wp_xy],
        outputs=[wp_z],
        adj_inputs=[wp_adj_xy],
        adj_outputs=[wp_adj_z],
        adjoint=True,
        device=wp_xy.device,
    )

    return wp.to_torch(wp_adj_xy)


@warp_rosenbrock_backward.register_fake
def _(xy, num_particles, z, adj_z):
    return torch.empty_like(xy)


def backward(ctx, adj_z):
    """Autograd backward formula for ``warp_rosenbrock``.

    Returns one gradient per forward input: the Warp adjoint for ``xy`` and
    ``None`` for the integer ``num_particles``.
    """
    xy, z = ctx.saved_tensors
    # Only return the gradient; autograd accumulates it into ``xy.grad``.
    adj_xy = warp_rosenbrock_backward(xy, ctx.num_particles, z, adj_z)
    return adj_xy, None


def setup_context(ctx, inputs, output):
    """Save what ``backward`` needs from the forward call."""
    xy, num_particles = inputs
    # Tensors go through save_for_backward (version-checked, and no reference
    # cycle through the output); plain Python values are stored on ctx directly.
    ctx.save_for_backward(xy, output)
    ctx.num_particles = num_particles


warp_rosenbrock.register_autograd(backward, setup_context=setup_context)

# Setup

In [None]:
"""Initialization"""

# Each particle is one query point of the function; Adam moves them toward the minimum.
num_particles = 1500

# Reproducible standard-normal starting positions, created on the torch device
# that corresponds to Warp's current device so the Warp kernels run where the data is.
rng = np.random.default_rng(42)
xy = torch.tensor(
    rng.normal(size=(num_particles, 2)),
    dtype=torch.float32,
    device=wp.device_to_torch(wp.get_device()),
    requires_grad=True,
)

# Adam (learning rate 0.05) optimizes the particle positions directly.
opt = torch.optim.Adam(params=[xy], lr=0.05)


"""Plotting"""

# Plot domain (contains the optimum at (1, 1))
min_x, max_x = -2.0, 2.0
min_y, max_y = -2.0, 2.0

# Sample the domain on a regular 100 x 100 grid of points
x = np.linspace(min_x, max_x, 100)
y = np.linspace(min_y, max_y, 100)
X, Y = np.meshgrid(x, y)
XY = np.column_stack((X.flatten(), Y.flatten()))
N = len(XY)

XY = wp.array(XY, dtype=wp.vec2)
Z = wp.empty(N, dtype=wp.float32)

# Evaluate the function over the grid with the same Warp kernel the optimizer
# uses, then reshape back to the meshgrid layout (rows follow y, columns follow x).
wp.launch(eval_rosenbrock, dim=N, inputs=[XY], outputs=[Z])
Z = Z.numpy().reshape(X.shape)

# Plot the function as a heatmap. Use the Axes (object-oriented) API throughout,
# consistent with the ax.plot / ax.scatter calls below.
fig, ax = matplotlib.pyplot.subplots(figsize=(6, 6))
ax.imshow(
    Z,
    extent=[min_x, max_x, min_y, max_y],
    origin="lower",
    interpolation="bicubic",
    cmap="coolwarm",
)
# Overlay contour lines. X and Y already position the contours; contour ignores
# `extent` when X and Y are given, so it is not passed here.
ax.contour(
    X,
    Y,
    Z,
    levels=150,
    colors="k",
    alpha=0.5,
    linewidths=0.5,
)

# Plot optimum
ax.plot(1, 1, "*", color="r", markersize=10)

ax.set_title("Rosenbrock function")
ax.set_xlabel("x")
ax.set_ylabel("y")

# Marker for the particles' mean position (initially empty, updated by render())
(mean_marker,) = ax.plot([], [], "o", color="w", markersize=5)

# Scatter plot of the particles (initially empty, updated by render())
scatter_plot = ax.scatter([], [], c="k", s=2)

# Optimization with PyTorch

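Here we set up the optimization procedure. `step()` performs a single optimization iteration. It clears the gradients, evaluates the function, backpropagates, and applies an Adam update. Notice that in the `forward()` function we simply call our custom operator `warp_rosenbrock()`; PyTorch takes care of invoking the Warp backward operator.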

In [None]:
def forward():
    # Evaluate f at every particle through the custom operator, so autograd
    # records the Warp launch and can run its backward formula later.
    return warp_rosenbrock(xy, num_particles)

def step():
    # One Adam iteration: clear stale gradients, forward, backward, update xy.
    opt.zero_grad()
    values = forward()
    # Seed with ones so each particle receives the gradient of its own value.
    values.backward(gradient=torch.ones_like(values))
    opt.step()

# Visualization

In [None]:
# Function to update the scatter plot
def render():
    # Host copy of the particle positions (force=True detaches / copies as needed).
    positions = xy.numpy(force=True)
    centroid = positions.mean(axis=0)

    # The (N, 2) position array is already in the layout set_offsets expects.
    scatter_plot.set_offsets(positions)
    mean_marker.set_data([centroid[0]], [centroid[1]])

# Optimize then render
def step_and_render(frame, steps_per_frame=200):
    """FuncAnimation callback: run several optimizer steps, then redraw.

    Args:
        frame: Frame index supplied by FuncAnimation (not used).
        steps_per_frame: Number of ``step()`` calls between redraws. Defaults
            to 200, the value previously hard-coded here; pass a different
            value via FuncAnimation's ``fargs`` to change the pacing.
    """
    for _ in range(steps_per_frame):
        step()

    render()

# `import IPython` alone does not necessarily load the `IPython.display`
# submodule, so import it explicitly before using it below.
import IPython.display

# Create the animation and visualize in Matplotlib.
# init_func=render draws the initial particle positions. Without an init_func,
# FuncAnimation initializes by calling step_and_render with the first frame,
# which would run an extra, never-recorded batch of optimization steps.
plot_anim = matplotlib.animation.FuncAnimation(
    fig,
    step_and_render,
    frames=30,
    init_func=render,
    interval=100,
)

# Display the result
IPython.display.display(plot_anim)
# Close this specific figure (not whichever one is current) so the inline
# backend does not also render it as a static image.
matplotlib.pyplot.close(fig)