GPU utilization from Ray for both RL + SFT #1712

Merged: erictang000 merged 2 commits into main from achaloo/ray on May 28, 2026

@erictang000
Collaborator

Summary

  • Add RayGpuMonitor (skyrl/train/utils/ray_gpu_monitor.py): a background daemon thread that scrapes Ray node Prometheus endpoints each training step and logs per-GPU utilization, GPU memory, and CPU RAM — plus cluster-wide averages — to wandb under ray/ keys.
  • Wire it into both the RL trainer and SFT trainer behind a new enable_ray_gpu_monitor config flag (default True).
  • Add train/tokens_per_second_per_gpu to SFT (parallel to the RL metric).

Test plan

  • uv run --isolated --extra dev pytest tests/train/utils/test_ray_gpu_monitor.py (collection, averaging, thread-safety — no live Ray cluster or GPUs required)

Add `RayGpuMonitor`, a background daemon thread that scrapes Ray node Prometheus
endpoints each training step and logs per-GPU utilization, GPU memory, and CPU
RAM (plus cluster-wide averages) to wandb under `ray/` keys. Wire it into both
the RL and SFT trainers behind `enable_ray_gpu_monitor`, and add
`train/tokens_per_second_per_gpu` to SFT. Includes unit tests covering
collection, averaging, and thread-safety without a live Ray cluster or GPUs.

Co-Authored-By: Aditya Chaloo <achaloo@lila.ai>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
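
The averaging described in the summary could look roughly like the sketch below. It assumes the monitor buffers one snapshot per scrape and reduces the buffer to a mean per key at each training step; the helper names and the `cluster.gpu.*.avg` key are hypothetical and only the `node.<n>.gpu.<g>.<metric>` key shape and the `ray/` prefix come from this PR.

```python
from collections import defaultdict
from statistics import fmean
from typing import Dict, List


def average_snapshots(buffer: List[Dict[str, float]]) -> Dict[str, float]:
    """Average each key over the snapshots that reported it."""
    values = defaultdict(list)
    for snapshot in buffer:
        for key, value in snapshot.items():
            values[key].append(value)
    return {key: fmean(vals) for key, vals in values.items()}


def add_cluster_averages(per_gpu: Dict[str, float]) -> Dict[str, float]:
    """Add a cluster-wide mean over all per-GPU keys for each metric suffix."""
    out = dict(per_gpu)
    for suffix in ("util", "mem_used_gb"):
        vals = [v for k, v in per_gpu.items() if ".gpu." in k and k.endswith("." + suffix)]
        if vals:  # skip metrics that no node reported
            out[f"cluster.gpu.{suffix}.avg"] = fmean(vals)  # hypothetical key name
    return out


# Two snapshots collected between consecutive training steps.
buffer = [
    {"node.0.gpu.0.util": 80.0, "node.0.gpu.1.util": 60.0},
    {"node.0.gpu.0.util": 100.0, "node.0.gpu.1.util": 40.0},
]
step_metrics = add_cluster_averages(average_snapshots(buffer))
logged = {f"ray/{k}": v for k, v in step_metrics.items()}
```

Averaging per key first and across GPUs second keeps a GPU that missed one scrape from skewing the cluster-wide number.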
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a background GPU/RAM monitor (RayGpuMonitor) that scrapes Ray node Prometheus endpoints to collect and log per-node, per-GPU utilization and memory metrics to wandb. This monitoring is integrated into both SFTTrainer and Trainer (PPO), with corresponding configuration options and comprehensive unit tests. The reviewer feedback highlights several important improvements: guarding against a potential ZeroDivisionError when _num_training_gpus is zero, utilizing a threading.Event instead of time.sleep() to ensure responsive background thread shutdown, and wrapping per-node parsing in a try-except block so that a single node failure does not abort the entire metrics collection process.
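
The ZeroDivisionError guard mentioned in the review might be sketched as follows. The function name and parameters are hypothetical stand-ins, not the trainer's actual code; the only assumption taken from the PR is that the metric divides throughput by a GPU count that can be zero.

```python
def tokens_per_second_per_gpu(num_tokens: int, elapsed_s: float, num_training_gpus: int) -> float:
    """Per-GPU token throughput, returning 0.0 when a denominator is not positive."""
    if num_training_gpus <= 0 or elapsed_s <= 0:
        return 0.0  # avoid ZeroDivisionError; report "no throughput" instead
    return num_tokens / elapsed_s / num_training_gpus


normal = tokens_per_second_per_gpu(8000, 2.0, 4)
no_gpus = tokens_per_second_per_gpu(8000, 2.0, 0)
```

Returning 0.0 is one option; skipping the metric for that step would be another reasonable choice if a zero value would be misleading on a dashboard.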

Comment thread skyrl/train/sft_trainer.py
Comment thread skyrl/train/sft_trainer.py
Comment on lines +57 to +61
self._buffer: List[Dict[str, float]] = []
self._lock = threading.Lock()
self._thread: Optional[threading.Thread] = None
self._running = False
self._urls: List[str] = []
Contributor

medium

Using time.sleep() in a background thread can block shutdown for up to collection_interval * 2 seconds. Using a threading.Event allows the thread to wake up and exit immediately when stop() is called.

Suggested change
         self._buffer: List[Dict[str, float]] = []
         self._lock = threading.Lock()
         self._thread: Optional[threading.Thread] = None
         self._running = False
+        self._stop_event = threading.Event()
         self._urls: List[str] = []

Comment on lines +81 to +83
self._running = True
self._thread = threading.Thread(target=self._loop, daemon=True, name="ray-gpu-monitor")
self._thread.start()
Contributor

medium

Clear the stop event when starting the background thread to ensure it can be restarted cleanly if needed.

        self._running = True
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._loop, daemon=True, name="ray-gpu-monitor")
        self._thread.start()

Comment on lines +90 to +92
        self._running = False
        if self._thread is not None:
            self._thread.join(timeout=self.collection_interval * 2)
Contributor

medium

Set the stop event to wake up the background thread immediately, avoiding any shutdown delay caused by time.sleep().

Suggested change
         self._running = False
+        self._stop_event.set()
         if self._thread is not None:
-            self._thread.join(timeout=self.collection_interval * 2)
+            self._thread.join()

Comment on lines +147 to +155
        while self._running:
            try:
                snapshot = self._collect(client)
                if snapshot:
                    with self._lock:
                        self._buffer.append(snapshot)
            except Exception as exc:
                logger.debug(f"RayGpuMonitor: collection error: {exc}")
            time.sleep(self.collection_interval)
Contributor

medium

Use self._stop_event.wait(self.collection_interval) instead of time.sleep() to allow the thread to wake up and exit immediately upon shutdown.

Suggested change
         while self._running:
             try:
                 snapshot = self._collect(client)
                 if snapshot:
                     with self._lock:
                         self._buffer.append(snapshot)
             except Exception as exc:
                 logger.debug(f"RayGpuMonitor: collection error: {exc}")
-            time.sleep(self.collection_interval)
+            if self._stop_event.wait(self.collection_interval):
+                break
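
Taken together, the three `threading.Event` suggestions above amount to a pattern like the following minimal sketch. `StoppableWorker` is a hypothetical stand-in for the monitor, not the merged implementation; it only illustrates that `Event.wait()` returns as soon as `stop()` sets the event, so shutdown does not wait out the interval.

```python
import threading
import time
from typing import Callable, Optional


class StoppableWorker:
    """Runs `work` every `interval` seconds until stop() is called."""

    def __init__(self, work: Callable[[], None], interval: float) -> None:
        self._work = work
        self._interval = interval
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None

    def start(self) -> None:
        if self._thread is not None and self._thread.is_alive():
            return  # already running
        self._stop_event.clear()  # allow a clean restart after a previous stop()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        self._stop_event.set()  # wakes the thread out of wait() immediately
        if self._thread is not None:
            self._thread.join()
            self._thread = None

    def _loop(self) -> None:
        while not self._stop_event.is_set():
            try:
                self._work()
            except Exception:
                pass  # a real monitor would log the error here
            # wait() returns True as soon as the event is set, False on timeout.
            if self._stop_event.wait(self._interval):
                break


ran = threading.Event()
worker = StoppableWorker(ran.set, interval=30.0)
worker.start()
ran.wait(timeout=5.0)  # make sure the work callback ran at least once
t0 = time.monotonic()
worker.stop()
stop_latency = time.monotonic() - t0  # far below the 30 s interval
```

With `time.sleep(30.0)` in place of `wait()`, the same `stop()` call would block until the sleep finished.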

Comment on lines +160 to +181
        for node_idx, url in enumerate(self._urls):
            try:
                resp = client.get(url)
                resp.raise_for_status()
            except (httpx.RequestError, httpx.HTTPStatusError) as exc:
                logger.debug(f"RayGpuMonitor: failed to scrape {url}: {exc}")
                continue

            parsed = parse_metrics_text(resp.text)
            for (name, labels), value in parsed.items():
                labels_dict = dict(labels)
                if name == _GPU_UTIL:
                    gpu_idx = labels_dict.get("GpuIndex", "0")
                    result[f"node.{node_idx}.gpu.{gpu_idx}.util"] = value
                elif name == _GPU_MEM:
                    gpu_idx = labels_dict.get("GpuIndex", "0")
                    # Ray reports gram_used in MB; convert to GB.
                    result[f"node.{node_idx}.gpu.{gpu_idx}.mem_used_gb"] = value / 1024.0
                elif name == _RAM_USED:
                    # Ray reports mem_used in bytes; convert to GB (host CPU RAM utilized).
                    result[f"node.{node_idx}.cpu_ram_used_gb"] = value / (1024**3)
        return result
Contributor

medium

If parse_metrics_text or any processing inside the loop raises an exception, it will abort the entire collection step and skip any remaining nodes. Wrapping the per-node collection and parsing in a try...except block ensures that a failure on one node does not prevent scraping other nodes.

        for node_idx, url in enumerate(self._urls):
            try:
                resp = client.get(url)
                resp.raise_for_status()
                parsed = parse_metrics_text(resp.text)
                for (name, labels), value in parsed.items():
                    labels_dict = dict(labels)
                    if name == _GPU_UTIL:
                        gpu_idx = labels_dict.get("GpuIndex", "0")
                        result[f"node.{node_idx}.gpu.{gpu_idx}.util"] = value
                    elif name == _GPU_MEM:
                        gpu_idx = labels_dict.get("GpuIndex", "0")
                        # Ray reports gram_used in MB; convert to GB.
                        result[f"node.{node_idx}.gpu.{gpu_idx}.mem_used_gb"] = value / 1024.0
                    elif name == _RAM_USED:
                        # Ray reports mem_used in bytes; convert to GB (host CPU RAM utilized).
                        result[f"node.{node_idx}.cpu_ram_used_gb"] = value / (1024**3)
            except Exception as exc:
                logger.debug(f"RayGpuMonitor: failed to scrape or parse {url}: {exc}")
        return result
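
The effect of the per-node try/except can be checked without a Ray cluster using a small stand-in. In this sketch `collect_all` and `fake_scrape` are hypothetical helpers and the URLs are placeholders; it only demonstrates that a parse failure on one node leaves the other nodes' metrics intact.

```python
from typing import Callable, Dict, List


def collect_all(urls: List[str], scrape_node: Callable[[str], Dict[str, float]]) -> Dict[str, float]:
    """Scrape every node, skipping any node whose scrape or parse raises."""
    result: Dict[str, float] = {}
    for node_idx, url in enumerate(urls):
        try:
            node_metrics = scrape_node(url)
        except Exception as exc:
            print(f"failed to scrape or parse {url}: {exc}")
            continue  # keep going with the remaining nodes
        for key, value in node_metrics.items():
            result[f"node.{node_idx}.{key}"] = value
    return result


def fake_scrape(url: str) -> Dict[str, float]:
    """Stand-in for an HTTP GET plus Prometheus text parsing."""
    if "bad" in url:
        raise ValueError("malformed metrics payload")
    return {"gpu.0.util": 75.0}


metrics = collect_all(["http://good-a", "http://bad", "http://good-b"], fake_scrape)
```

The node index comes from the position in the URL list, so the surviving nodes keep their original indices (0 and 2 here) even when a node in between fails.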

@erictang000 erictang000 merged commit 172bc17 into main May 28, 2026
3 of 4 checks passed
@erictang000 erictang000 deleted the achaloo/ray branch May 28, 2026 02:06