Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion nemo_rl/models/generation/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
verify_right_padding,
)
from nemo_rl.models.huggingface.common import ModelFlag
from nemo_rl.models.policy.utils import is_vllm_v1_engine_enabled


class VllmSpecificArgs(TypedDict):
Expand Down Expand Up @@ -313,7 +314,7 @@ def _patch_vllm_init_workers_ray():
# For non-parallel mode, explicitly set executor to None to avoid Ray issues
vllm_kwargs["distributed_executor_backend"] = None

os.environ["VLLM_USE_V1"] = os.environ.get("NRL_VLLM_USE_V1", "1")
os.environ["VLLM_USE_V1"] = "1" if is_vllm_v1_engine_enabled() else "0"
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

load_format = self.cfg["vllm_cfg"]["load_format"]
Expand Down
28 changes: 17 additions & 11 deletions nemo_rl/models/policy/dtensor_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
get_gpu_info,
get_runtime_env_for_policy_worker,
import_class_from_path,
is_vllm_v1_engine_enabled,
sliding_window_overwrite,
)
from nemo_rl.utils.native_checkpoint import (
Expand Down Expand Up @@ -418,6 +419,17 @@ def create_context_parallel_ctx(

# Refer to nemo impl. Below is original comment.
# based on https://github.com/pytorch/torchtitan/blob/cddd7dc809f36fe0ed51cdaaea0671c084d75442/torchtitan/distributed/utils.py#L178

def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor:
    """Divide `logits` in place by the configured generation temperature.

    The vLLM V1 engine returns raw logits (no temperature applied), while
    the V0 engine returns logits that were already scaled by temperature.
    To keep the two paths consistent, scaling is performed here only when
    the V0 engine is in use and a generation config is present.

    Args:
        logits: Logit tensor to scale; mutated in place via ``div_``.

    Returns:
        The same tensor object, scaled when applicable.
    """
    generation_cfg = self.cfg.get("generation")
    if generation_cfg is None:
        # No generation config — nothing to scale against.
        return logits
    if not is_vllm_v1_engine_enabled():
        # V0 engine path: mirror its temperature scaling on our logits.
        logits.div_(generation_cfg["temperature"])
    return logits

@staticmethod
@contextlib.contextmanager
def train_context(cp_context: Optional[Generator[None, None, None]] = None):
Expand Down Expand Up @@ -654,17 +666,8 @@ def train(
else:
logits = outputs.logits

# Divide logits by temperature
if (
"generation" in self.cfg
and self.cfg["generation"] is not None
):
# The V1 engine returns raw logits before temperature scaling.
# The V0 engine (when VLLM_USE_V1 is not '1') returns scaled logits.
# Therefore, we only divide if we are NOT using the V1 engine.
use_v1_engine = os.environ.get("VLLM_USE_V1") == "1"
if not use_v1_engine:
logits.div_(self.cfg["generation"]["temperature"])
# Apply temperature scaling
logits = self._apply_temperature_scaling(logits)

if self.cp_size > 1:
seq_index_dtensor = (
Expand Down Expand Up @@ -944,6 +947,9 @@ def get_logprobs(

logits = outputs.logits

# Apply temperature scaling
logits = self._apply_temperature_scaling(logits)

if self.cp_size > 1:
seq_index_tensor = (
DTensor.from_local(
Expand Down
9 changes: 9 additions & 0 deletions nemo_rl/models/policy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@
from nemo_rl.distributed.worker_group_utils import get_nsight_config_if_pattern_matches


def is_vllm_v1_engine_enabled() -> bool:
    """Report whether the vLLM V1 engine should be used.

    Controlled by the ``NRL_VLLM_USE_V1`` environment variable: any value
    other than ``"1"`` disables the V1 engine.

    Returns:
        bool: True when ``NRL_VLLM_USE_V1`` is unset or set to ``"1"``,
        False otherwise.
    """
    flag = os.environ.get("NRL_VLLM_USE_V1")
    return flag is None or flag == "1"


def import_class_from_path(name: str) -> Any:
"""Import a class from a string path (e.g. 'torch.optim.AdamW').

Expand Down
Loading