
Conversation

@yhtang (Contributor) commented Nov 11, 2025

No description provided.

@yhtang yhtang marked this pull request as ready for review November 12, 2025 05:52
@yhtang yhtang requested a review from Copilot November 12, 2025 05:52
Copilot finished reviewing on behalf of yhtang November 12, 2025 05:54
Copilot AI left a comment

Pull Request Overview

This PR introduces a JAX-vLLM rollout offloading bridge that enables efficient coupling between JAX training and vLLM inference for reinforcement learning post-training workloads. The bridge offloads rollout generation to vLLM while keeping training in JAX, using NCCL for direct GPU-to-GPU weight transfers.

  • Implements a lightweight RPC gateway for control plane coordination between trainer and rollout engine
  • Provides NCCL-based data plane for fast GPU-to-GPU weight streaming with tensor resharding
  • Supports multiple transfer modes (fused, unfused, grouped) and flexible parallelism configurations (FSDP/TP)
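To make the resharding step above concrete, here is a minimal, hypothetical sketch (not the bridge's actual API) of how a weight tensor replicated on the trainer side might be split into per-rank tensor-parallel shards before being streamed over NCCL; the helper name `shard_for_tp` and the toy shapes are illustrative assumptions.

```python
import numpy as np


def shard_for_tp(weight: np.ndarray, tp_size: int, axis: int = 0) -> list[np.ndarray]:
    """Split a full weight tensor into equal tensor-parallel shards along `axis`.

    Hypothetical helper for illustration only; a real bridge would perform
    this on-GPU so each rollout rank receives only its own shard.
    """
    assert weight.shape[axis] % tp_size == 0, "dimension must divide evenly"
    return np.split(weight, tp_size, axis=axis)


# A toy projection weight of shape (out_features, in_features).
fused = np.arange(16, dtype=np.float32).reshape(8, 2)

# Column-parallel layers shard the output dimension across TP ranks.
shards = shard_for_tp(fused, tp_size=4, axis=0)
print([s.shape for s in shards])  # each of the 4 ranks gets a (2, 2) slice
```

The "fused" vs. "unfused" transfer modes mentioned above would differ in whether several such tensors are concatenated into one buffer before the NCCL send or streamed individually.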

Reviewed Changes

Copilot reviewed 59 out of 59 changed files in this pull request and generated 35 comments.

Summary per file:
  • setup.py: Package setup with protobuf build hooks and dependencies
  • pyproject.toml: Build system configuration with ruff linting rules
  • pep517_backend.py: Custom PEP 517 build backend for protobuf compilation
  • jax_inference_offloading/vllm/: vLLM wrapper and worker extension for weight updates
  • jax_inference_offloading/transport/: NCCL transport implementations (star topology, tensor/model transports)
  • jax_inference_offloading/controller/: Gateway server and client implementations for trainer/rollout coordination
  • jax_inference_offloading/models/: Model parameter mappings for Llama3 and Gemma families
  • jax_inference_offloading/jax/: Offloading bridge API for JAX integration
  • jax_inference_offloading/tunix/: Tunix-specific rollout and model loading utilities
  • examples/: Example scripts for single-node and multi-node deployments
Comments suppressed due to low confidence (1)

jax-inference-offloading/jax_inference_offloading/controller/rollout_client.py:139

  • This assignment to 'shutdown' is unnecessary as it is redefined before this value is used.
  def shutdown(self):
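For context, the pattern Copilot flagged is a dead store: a value bound to a name that is unconditionally rebound before it is ever read. A hypothetical minimal reproduction (not the actual rollout_client.py code) looks like this:

```python
class RolloutClientSketch:
    """Hypothetical minimal reproduction of the flagged dead-store pattern."""

    # An attribute named `shutdown` assigned at class-definition time...
    shutdown = None

    # ...is immediately rebound by the method definition below, so the
    # assignment above is never observable and can simply be removed.
    def shutdown(self):
        return "stopped"


client = RolloutClientSketch()
print(client.shutdown())  # prints "stopped"; the earlier None is never used
```

Removing the redundant assignment leaves behavior unchanged, which is why such findings are usually safe cleanups.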


@mjsML mjsML removed request for mjsML and nouiz November 12, 2025 09:50
@mjsML (Member) commented Nov 12, 2025

@jreiffers PTAL

@mjsML mjsML requested review from jreiffers and olupton November 12, 2025 09:55
yhtang and others added 3 commits November 12, 2025 16:57
@yhtang (Contributor, Author) commented Nov 13, 2025

The new CI workflow for the jax-inference-offloading subfolder is currently failing because the base container image does not yet include the changes (git-clone.sh) from this PR. Given that cyclic dependency, I suggest we treat this workflow as a pilot and proceed with merging as-is, as long as the main CI passes, then iterate to fix the workflow and address any issues once the updated base container is available.

@Steboss (Contributor) commented Nov 14, 2025

> The new CI workflow for the jax-inference-offloading subfolder is currently failing because the base container image does not yet include the changes (git-clone.sh) from this PR. Given that cyclic dependency, I suggest we treat this workflow as a pilot and proceed with merging as-is, as long as the main CI passes, then iterate to fix the workflow and address any issues once the updated base container is available.

@yhtang Isn't JIO part of the main CI? And I can't understand why the CI doesn't pick up your changes to git-clone.sh; the base container should already include them. Indeed, I can see your new sparse option in the error: https://github.com/NVIDIA/JAX-Toolbox/actions/runs/19322673777/job/55267036826?pr=1775#step:4:1184
Maybe I am wrong.

@Steboss (Contributor) commented Nov 14, 2025

@yhtang, another question:
I can see there's a lot of code. Does it come from a specific library or GitHub repo? How likely is it that this code will change (many times) in the next few months? Are we introducing any technical debt? Might it be hard to keep the code and its changes under control?

@aybchan (Member) left a comment

@yhtang Could you clarify what you would like the CI here to achieve? Is it to build and publish an image with JAX inference offloading based on nightly JAX (i.e. similar to what we do with maxtext)? Is there also an intention to run the examples you've added in the CI with this image, or some other way to test it?

aybchan previously approved these changes Nov 14, 2025
@yhtang (Contributor, Author) commented Nov 14, 2025

> The new CI workflow for the jax-inference-offloading subfolder is currently failing because the base container image does not yet include the changes (git-clone.sh) from this PR. Given that cyclic dependency, I suggest we treat this workflow as a pilot and proceed with merging as-is, as long as the main CI passes, then iterate to fix the workflow and address any issues once the updated base container is available.
>
> @yhtang JIO is not part of the main ci? and I can't understand why the CI doesn't pick up your changes to git-clone.sh. the base container should start working and having it. I can see indeed in the error your new sparse option https://github.com/NVIDIA/JAX-Toolbox/actions/runs/19322673777/job/55267036826?pr=1775#step:4:1184 maybe I am wrong

I think the issue is that this workflow is still using the vanilla CUDA DL base image rather than the JAX-Toolbox base image built in our workflow. As a result, it doesn’t see the updated git-clone.sh even though it’s included in this PR. Nevertheless, I've reverted git-clone.sh and will add it back when we merge the offloading CI workflow with the main workflow.

My plan is to introduce two Dockerfiles: one for pure OSS installation and another based on the JAX-Toolbox base image and JAX builds. I’d handle the second Dockerfile and the corresponding CI wiring in a follow-up PR to keep the scope of this change focused.

@yhtang (Contributor, Author) commented Nov 15, 2025

The standalone JIO CI workflow is now passing, and the remaining main CI failures (e.g. https://github.com/NVIDIA/JAX-Toolbox/actions/runs/19385835327/job/55473082584?pr=1775#step:4:2129) match those on the main branch (e.g. https://github.com/NVIDIA/JAX-Toolbox/actions/runs/19386305470/job/55474369243#step:4:2407).

Per our earlier agreement that this PR is scoped to changes within the jax-inference-offloading folder, I’ll go ahead and merge this PR and keep subsequent updates similarly contained until we work together to integrate the JIO CI into the main CI.

@yhtang yhtang merged commit 03c29c6 into main Nov 15, 2025
70 of 79 checks passed
@yhtang yhtang deleted the yhtang/jax-inference-offloading branch November 15, 2025 08:30