Add support for Warp Backend for Gradient-Based Optimization #163

Open
Medyan-Naser wants to merge 6 commits into Autodesk:main from Medyan-Naser:feature/warp-gradient-161

@Medyan-Naser Medyan-Naser commented May 29, 2026

Contributing Guidelines

Description

This PR enables automatic differentiation for the Warp backend, allowing gradient-based optimization through multi-step LBM simulations. All operators have been modified to be compatible with Warp's tape-based reverse-mode AD while preserving accuracy.

closes #161

Final iteration (iteration 149) of the differentiable example using Warp and JAX:

[Image: JAX result, iteration_00149]
[Image: Warp result, iteration_00149]

Convergence:

[Image: JAX convergence plot]
[Image: Warp convergence plot]

Type of change

Core Operator Modifications

Four operators were modified to remove non-differentiable control flow that breaks Warp's autodiff:

1. xlb/operator/macroscopic/zero_moment.py - Density Computation

  • Removed: Neumaier compensated summation (contains if-else branches)
  • Replaced with: plain summation
  • Validation: summation error ~10⁻¹², well below the LBM precision of ~10⁻⁵ (safe margin)
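
To make the trade-off concrete, here is a minimal pure-Python sketch of the two summation styles. It is not the XLB kernel code: the function names are hypothetical, Python floats are float64 rather than the float32 used in the PR, and the nine sample values are simply the standard D2Q9 lattice weights used as stand-in populations. The point it shows is that the compensated version needs a data-dependent branch per element, while the plain loop is branch-free.

```python
import math


def neumaier_sum(values):
    """Compensated (Neumaier) summation; branches on the operands' magnitudes."""
    total = 0.0
    comp = 0.0  # running compensation for lost low-order bits
    for v in values:
        t = total + v
        if abs(total) >= abs(v):
            comp += (total - t) + v   # low-order bits of v were lost
        else:
            comp += (v - t) + total   # low-order bits of total were lost
        total = t
    return total + comp


def simple_sum(values):
    """Plain left-to-right summation with no branches."""
    total = 0.0
    for v in values:
        total += v
    return total


# Stand-in populations (D2Q9 lattice weights), all of similar magnitude.
f = [4 / 9, 1 / 9, 1 / 9, 1 / 9, 1 / 9, 1 / 36, 1 / 36, 1 / 36, 1 / 36]
exact = math.fsum(f)
err_simple = abs(simple_sum(f) - exact)
err_neumaier = abs(neumaier_sum(f) - exact)
print(f"simple: {err_simple:.3e}  neumaier: {err_neumaier:.3e}")

# Compensation only pays off when magnitudes differ wildly, as here:
ill_conditioned = [1.0, 1e100, 1.0, -1e100]
print(simple_sum(ill_conditioned), neumaier_sum(ill_conditioned))
```

For a handful of values of similar magnitude, as in a density sum over lattice directions, both variants stay close to the exact result, which is consistent with the safety margin quoted above.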

2. xlb/operator/macroscopic/first_moment.py - Velocity Computation

  • Removed: Neumaier compensated summation
  • Replaced with: direct summation over the lattice velocity components
  • Validation: 100% gradient match with JAX

3. xlb/operator/stream/stream.py - Streaming Operator

  • Removed: conditional boundary handling (if-else)
  • Replaced with: modulo arithmetic for periodic BCs
  • Validation: mathematically equivalent
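
The equivalence claimed for the streaming change can be checked with a small sketch. This is plain Python on a 1-D periodic lattice, not the Warp kernel from stream.py; the helper names are hypothetical and a pull-style stream (each cell reads from its upstream neighbour) is assumed for illustration.

```python
def pull_index_branching(i, c, n):
    """Upstream index with explicit wrap-around branches."""
    j = i - c
    if j < 0:
        j += n
    elif j >= n:
        j -= n
    return j


def pull_index_modulo(i, c, n):
    """Same index without branches; adding n keeps the operand non-negative."""
    return (i - c + n) % n


def stream_periodic(f, c):
    """Shift a 1-D population by lattice velocity c with periodic wrap-around."""
    n = len(f)
    return [f[pull_index_modulo(i, c, n)] for i in range(n)]


# Exhaustive check on a small lattice for velocities -1, 0, +1.
n = 8
for c in (-1, 0, 1):
    for i in range(n):
        assert pull_index_branching(i, c, n) == pull_index_modulo(i, c, n)

print(stream_periodic([0, 1, 2, 3], 1))
```

Adding `n` before taking the remainder is a defensive choice: it gives the same result in Python and also in languages where `%` can return a negative value for a negative operand.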

4. xlb/operator/stepper/nse_stepper.py - NSE Stepper

  • Removed: early returns in solid boundary cells
  • Replaced with: conditional assignment
  • Validation: same physics, different code structure
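
The restructuring described for the stepper can be sketched on a single scalar population. This is an illustration in plain Python, not the code in nse_stepper.py: the function names are hypothetical and a BGK-style relaxation toward an equilibrium value `f_eq` stands in for the full collide-and-stream update.

```python
def step_cell_early_return(f_in, is_solid, omega, f_eq):
    """Original style: solid cells leave the function before any update runs."""
    if is_solid:
        return f_in
    return f_in - omega * (f_in - f_eq)


def step_cell_conditional_assign(f_in, is_solid, omega, f_eq):
    """Restructured style: one exit point; solid cells overwrite the result."""
    f_out = f_in - omega * (f_in - f_eq)  # computed for every cell
    if is_solid:
        f_out = f_in                      # solid cells keep their input value
    return f_out


# Both variants agree for fluid and solid cells.
for solid in (False, True):
    a = step_cell_early_return(0.3, solid, 1.2, 0.5)
    b = step_cell_conditional_assign(0.3, solid, 1.2, 0.5)
    assert a == b
```

The second form does some redundant work for solid cells, but every cell follows the same code path to a single return, which is the structural property the PR relies on.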

Usage Requirements

Critical Pattern for Multi-Step Optimization

Warp's tape-based AD requires pre-allocating all intermediate states:

import warp as wp
from xlb.operator.stepper import IncompressibleNavierStokesStepper

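# Setup (grid, nx, ny, bc_mask, missing_mask, omega, learning_rate, and
# compute_loss are assumed to be defined elsewhere)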
wp.init()
stepper = IncompressibleNavierStokesStepper(grid=grid, collision_type="BGK")
f_init = wp.zeros((9, nx, ny), dtype=wp.float32, requires_grad=True)
# ... initialize f_init ...

num_steps = 20
for opt_iter in range(100):
    # CORRECT: Pre-allocate all intermediate states
    f_states = [wp.zeros((9, nx, ny), dtype=wp.float32, requires_grad=True) 
                for _ in range(num_steps + 1)]
    wp.copy(f_states[0], f_init)
    
    loss_wp = wp.zeros(1, dtype=wp.float32, requires_grad=True)
    
    # Forward with tape
    with wp.Tape() as tape:
        for step in range(num_steps):
            f_states[step + 1].zero_()  # Clear buffer, don't recreate
            _, f_states[step + 1] = stepper(
                f_states[step], f_states[step + 1],
                bc_mask, missing_mask, omega, step
            )
        
        # Compute loss
        compute_loss(f_states[num_steps], loss_wp)
    
    # Backward
    loss_wp.grad.fill_(1.0)
    tape.backward(loss=loss_wp)
    
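    # Update parameters. f_init was copied into f_states[0] before the tape
    # started recording, so its gradient is read from f_states[0].
    f_grad = f_states[0].grad
    f_init_np = f_init.numpy() - learning_rate * f_grad.numpy()
    f_init = wp.array(f_init_np, dtype=wp.float32, requires_grad=True)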
    
    tape.zero()

Common Mistake (Breaks Autodiff)

# DON'T DO THIS - recreates arrays, severs gradient flow
for step in range(num_steps):
    f_1 = wp.zeros_like(f_0)  # New array each iteration
    f_0, f_1 = stepper(f_0, f_1, ...)

Why it fails: Each wp.zeros_like() creates a new array, breaking the computational graph. Gradients cannot flow back through severed connections.
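
As a toy illustration of why a reverse pass needs every step's input state to survive until the backward sweep, consider the sketch below. It is pure Python, not Warp's tape, and it models a related hazard (a single buffer being overwritten) rather than the exact failure above; all names are hypothetical. Each step squares the state, so the backward pass needs the value that entered each step.

```python
def forward_preallocated(x0, num_steps):
    """Keep one slot per step so every intermediate value stays available."""
    states = [0.0] * (num_steps + 1)
    states[0] = x0
    for k in range(num_steps):
        states[k + 1] = states[k] ** 2
    return states


def backward_preallocated(states):
    """Chain rule for x_{k+1} = x_k**2, seeded with d(loss)/d(x_N) = 1."""
    grad = 1.0
    for k in reversed(range(len(states) - 1)):
        grad *= 2.0 * states[k]  # needs x_k exactly as it was in the forward pass
    return grad


def forward_overwriting(x0, num_steps):
    """Reuse one buffer; the record holds references to it, not copies."""
    buf = [x0]
    recorded = []
    for _ in range(num_steps):
        recorded.append(buf)
        buf[0] = buf[0] ** 2  # destroys the input the backward pass needs
    return recorded


def backward_overwriting(recorded):
    grad = 1.0
    for buf in reversed(recorded):
        grad *= 2.0 * buf[0]  # reads the final value at every step
    return grad


x0, n = 1.1, 3
analytic = (2 ** n) * x0 ** (2 ** n - 1)  # derivative of x0**(2**n)
grad = backward_preallocated(forward_preallocated(x0, n))
grad_wrong = backward_overwriting(forward_overwriting(x0, n))
print(f"analytic={analytic:.4f}  preallocated={grad:.4f}  overwriting={grad_wrong:.4f}")
```

With one slot per step the reverse sweep reproduces the analytic derivative; with the reused buffer it silently returns a different number, which is why the pattern above allocates `num_steps + 1` states up front.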

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update

How Has This Been Tested?

  • All pytest tests pass

The single test failure also exists on the main branch; it is an OSError for a missing libcudart.so.12 raised while importing neon:

=============================================================================================================== FAILURES ===============================================================================================================
___________________________________________________________________________________________ test_neon_editable_install_after_prior_warp_lang ___________________________________________________________________________________________

tmp_path = PosixPath('/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0')

    @pytest.mark.skipif(sys.version_info < (3, 11), reason="XLB requires Python >= 3.11")
    @pytest.mark.skipif(not _neon_wheels_supported(), reason="Neon wheels: Linux x86_64 / aarch64 only")
    def test_neon_editable_install_after_prior_warp_lang(tmp_path: Path) -> None:
        """Pre-install ``warp-lang``, then ``pip install -e .[neon,test]``; imports must work."""
        venv_dir = tmp_path / "venv"
        subprocess.run([sys.executable, "-m", "venv", str(venv_dir)], check=True)
        venv_py = _venv_python(venv_dir)
        assert venv_py.is_file(), f"missing venv python: {venv_py}"
    
        env = {**os.environ, "XLB_NEON_SKIP_UNINSTALL_WARP": ""}
    
        proc = _run([str(venv_py), "-m", "pip", "install", "--upgrade", "pip", "wheel"], cwd=REPO_ROOT, env=env)
        assert proc.returncode == 0, proc.stdout + proc.stderr
    
        proc = _run(
            [str(venv_py), "-m", "pip", "install", "warp-lang==1.10.0"],
            cwd=REPO_ROOT,
            env=env,
        )
        assert proc.returncode == 0, proc.stdout + proc.stderr
    
        proc = _run(
            [str(venv_py), "-m", "pip", "install", "-e", ".[neon,test]"],
            cwd=REPO_ROOT,
            env=env,
        )
        assert proc.returncode == 0, proc.stdout + proc.stderr
    
        proc = _run(
            [str(venv_py), "-c", "import neon; import warp; print('ok', neon.__file__, warp.__file__)"],
            cwd=REPO_ROOT,
            env=env,
        )
>       assert proc.returncode == 0, proc.stdout + proc.stderr
E       AssertionError: [19:30:06] [neon-py] [ERROR] Failed to load library: /tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/liblibNeonPy.so
E         Traceback (most recent call last):
E           File "<string>", line 1, in <module>
E           File "/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/__init__.py", line 20, in <module>
E             from .multires.__init__ import *
E           File "/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/multires/__init__.py", line 3, in <module>
E             from .mGrid import mGrid
E           File "/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/multires/mGrid.py", line 16, in <module>
E             from .mField import mField
E           File "/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/multires/mField.py", line 16, in <module>
E             import neon.multires.mPartition
E           File "/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/multires/mPartition.py", line 196, in <module>
E             mPartition_int8 = factory_mPartition(wp.int8)      # 8-bit signed integer partitions
E                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
E           File "/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/multires/mPartition.py", line 144, in factory_mPartition
E             neon_gate: neon.Gate = neon.Gate()
E                                    ^^^^^^^^^^^
E           File "/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/gate.py", line 111, in __init__
E             raise e
E           File "/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/gate.py", line 107, in __init__
E             self.lib = ctypes.CDLL(lib_path)
E                        ^^^^^^^^^^^^^^^^^^^^^
E           File "/usr/lib/python3.12/ctypes/__init__.py", line 379, in __init__
E             self._handle = _dlopen(self._name, mode)
E                            ^^^^^^^^^^^^^^^^^^^^^^^^^
E         OSError: libcudart.so.12: cannot open shared object file: No such file or directory
E         
E       assert 1 == 0
E        +  where 1 = CompletedProcess(args=['/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/bin/python', '-c', "import neon; import warp; print('ok', neon.__file__, warp.__file__)"], returncode=1, stdout='[19:30:06] [neon-py] [ERROR] Failed to load library: /tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/liblibNeonPy.so\n', stderr='Traceback (most recent call last):\n  File "<string>", line 1, in <module>\n  File "/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/__init__.py", line 20, in <module>\n    from .multires.__init__ import *\n  File "/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/multires/__init__.py", line 3, in <module>\n    from .mGrid import mGrid\n  File "/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/multires/mGrid.py", line 16, in <module>\n    from .mField import mField\n  File "/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/multires/mField.py", line 16, in <module>\n    import neon.multires.mPartition\n  File "/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/multires/mPartition.py", line 196, in <module>\n    mPartition_int8 = factory_mPartition(wp.int8)      # 8-bit signed integer partitions\n                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n  File "/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/multires/mPartition.py", line 144, in factory_mPartition\n    neon_gate: neon.Gate = neon.Gate()\n                           ^^^^^^^^^^^\n  File "/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/gate.py", line 111, in __init__\n    raise e\n  File 
"/tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/gate.py", line 107, in __init__\n    self.lib = ctypes.CDLL(lib_path)\n               ^^^^^^^^^^^^^^^^^^^^^\n  File "/usr/lib/python3.12/ctypes/__init__.py", line 379, in __init__\n    self._handle = _dlopen(self._name, mode)\n                   ^^^^^^^^^^^^^^^^^^^^^^^^^\nOSError: libcudart.so.12: cannot open shared object file: No such file or directory\n').returncode

tests/install/test_neon_install_warp_cleanup.py:79: AssertionError
======================================================================================================= short test summary info ========================================================================================================
FAILED tests/install/test_neon_install_warp_cleanup.py::test_neon_editable_install_after_prior_warp_lang - AssertionError: [19:30:06] [neon-py] [ERROR] Failed to load library: /tmp/pytest-of-medy/pytest-2/test_neon_editable_install_aft0/venv/lib/python3.12/site-packages/neon/liblibNeonPy.so
=============================================================================================== 1 failed, 93 passed in 177.66s (0:02:57) ===============================================================================================

Validation Results

Single-Step Gradient Accuracy

Test: test_gradient_comparison.py
Element-wise match: 9,216 / 9,216 (100%)
Magnitude:   JAX: 192.00, Warp: 192.00
Direction:   Cosine similarity: 1.0000
Max difference: 0.00e+00
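
For readers unfamiliar with the metrics reported above, the following sketch shows one way to compute them for two flattened gradient arrays. It is not the contents of test_gradient_comparison.py: the function name, the tolerance, and the sample vectors are assumptions chosen for illustration.

```python
import math


def compare_gradients(a, b, tol=1e-6):
    """Return element-wise match fraction, L2 norms, cosine similarity, max diff."""
    if len(a) != len(b) or not a:
        raise ValueError("gradients must be non-empty and of equal length")
    diffs = [abs(x - y) for x, y in zip(a, b)]
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    dot = sum(x * y for x, y in zip(a, b))
    return {
        "match_fraction": sum(d <= tol for d in diffs) / len(a),
        "norm_a": norm_a,
        "norm_b": norm_b,
        "cosine": dot / (norm_a * norm_b),
        "max_diff": max(diffs),
    }


# Hypothetical gradients, not data from the PR.
identical = compare_gradients([3.0, 4.0], [3.0, 4.0])
rescaled = compare_gradients([3.0, 4.0], [6.0, 8.0])
print(identical)
print(rescaled)
```

Reporting all of these together is useful because they fail in different ways: a rescaled gradient keeps a cosine similarity of 1 but shows up in the norms and the maximum difference.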

Forward Simulation Consistency

Test: test_forward_consistency.py
Max difference (Warp vs JAX): 0.00e+00 across 20 timesteps
Physics preserved: ✓

Multi-Step Optimization (Actual Example)

Test: examples/cfd/differentiable_lbm.py --shape circle --grid-size 64 --sim-steps 10 --iterations 20

Backend | Improvement | Final Loss
--------|-------------|-----------
JAX     | 97.46%      | 0.012724
Warp    | 97.56%      | 0.012216

Conclusion: Warp optimization performs on par with JAX (97.56% vs. 97.46% improvement) when the pre-allocation pattern described above is used.

Linting and Code Formatting

Make sure the code follows the project's linting and formatting standards. This project uses Ruff for linting.

To run Ruff, execute the following command from the root of the repository:

ruff check .
  • Ruff passes

Development

Successfully merging this pull request may close these issues.

Warp backend does not support gradient-based optimization (stepper returns zero gradients)
