GPU utilization from Ray for both RL + SFT#1712
Conversation
Add `RayGpuMonitor`, a background daemon thread that scrapes Ray node Prometheus endpoints each training step and logs per-GPU utilization, GPU memory, and CPU RAM (plus cluster-wide averages) to wandb under `ray/` keys. Wire it into both the RL and SFT trainers behind `enable_ray_gpu_monitor`, and add `train/tokens_per_second_per_gpu` to SFT. Includes unit tests covering collection, averaging, and thread-safety without a live Ray cluster or GPUs. Co-Authored-By: Aditya Chaloo <achaloo@lila.ai> Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request introduces a background GPU/RAM monitor (RayGpuMonitor) that scrapes Ray node Prometheus endpoints to collect and log per-node, per-GPU utilization and memory metrics to wandb. This monitoring is integrated into both SFTTrainer and Trainer (PPO), with corresponding configuration options and comprehensive unit tests. The reviewer feedback highlights several important improvements: guarding against a potential ZeroDivisionError when _num_training_gpus is zero, utilizing a threading.Event instead of time.sleep() to ensure responsive background thread shutdown, and wrapping per-node parsing in a try-except block so that a single node failure does not abort the entire metrics collection process.
| self._buffer: List[Dict[str, float]] = [] | ||
| self._lock = threading.Lock() | ||
| self._thread: Optional[threading.Thread] = None | ||
| self._running = False | ||
| self._urls: List[str] = [] |
There was a problem hiding this comment.
Using time.sleep() in a background thread can block shutdown for up to collection_interval * 2 seconds. Using a threading.Event allows the thread to wake up and exit immediately when stop() is called.
| self._buffer: List[Dict[str, float]] = [] | |
| self._lock = threading.Lock() | |
| self._thread: Optional[threading.Thread] = None | |
| self._running = False | |
| self._urls: List[str] = [] | |
| self._buffer: List[Dict[str, float]] = [] | |
| self._lock = threading.Lock() | |
| self._thread: Optional[threading.Thread] = None | |
| self._running = False | |
| self._stop_event = threading.Event() | |
| self._urls: List[str] = [] |
| self._running = True | ||
| self._thread = threading.Thread(target=self._loop, daemon=True, name="ray-gpu-monitor") | ||
| self._thread.start() |
There was a problem hiding this comment.
| self._running = False | ||
| if self._thread is not None: | ||
| self._thread.join(timeout=self.collection_interval * 2) |
There was a problem hiding this comment.
Set the stop event to wake up the background thread immediately, avoiding any shutdown delay caused by time.sleep().
| self._running = False | |
| if self._thread is not None: | |
| self._thread.join(timeout=self.collection_interval * 2) | |
| self._running = False | |
| self._stop_event.set() | |
| if self._thread is not None: | |
| self._thread.join() |
| while self._running: | ||
| try: | ||
| snapshot = self._collect(client) | ||
| if snapshot: | ||
| with self._lock: | ||
| self._buffer.append(snapshot) | ||
| except Exception as exc: | ||
| logger.debug(f"RayGpuMonitor: collection error: {exc}") | ||
| time.sleep(self.collection_interval) |
There was a problem hiding this comment.
Use self._stop_event.wait(self.collection_interval) instead of time.sleep() to allow the thread to wake up and exit immediately upon shutdown.
| while self._running: | |
| try: | |
| snapshot = self._collect(client) | |
| if snapshot: | |
| with self._lock: | |
| self._buffer.append(snapshot) | |
| except Exception as exc: | |
| logger.debug(f"RayGpuMonitor: collection error: {exc}") | |
| time.sleep(self.collection_interval) | |
| while self._running: | |
| try: | |
| snapshot = self._collect(client) | |
| if snapshot: | |
| with self._lock: | |
| self._buffer.append(snapshot) | |
| except Exception as exc: | |
| logger.debug(f"RayGpuMonitor: collection error: {exc}") | |
| if self._stop_event.wait(self.collection_interval): | |
| break |
| for node_idx, url in enumerate(self._urls): | ||
| try: | ||
| resp = client.get(url) | ||
| resp.raise_for_status() | ||
| except (httpx.RequestError, httpx.HTTPStatusError) as exc: | ||
| logger.debug(f"RayGpuMonitor: failed to scrape {url}: {exc}") | ||
| continue | ||
|
|
||
| parsed = parse_metrics_text(resp.text) | ||
| for (name, labels), value in parsed.items(): | ||
| labels_dict = dict(labels) | ||
| if name == _GPU_UTIL: | ||
| gpu_idx = labels_dict.get("GpuIndex", "0") | ||
| result[f"node.{node_idx}.gpu.{gpu_idx}.util"] = value | ||
| elif name == _GPU_MEM: | ||
| gpu_idx = labels_dict.get("GpuIndex", "0") | ||
| # Ray reports gram_used in MB; convert to GB. | ||
| result[f"node.{node_idx}.gpu.{gpu_idx}.mem_used_gb"] = value / 1024.0 | ||
| elif name == _RAM_USED: | ||
| # Ray reports mem_used in bytes; convert to GB (host CPU RAM utilized). | ||
| result[f"node.{node_idx}.cpu_ram_used_gb"] = value / (1024**3) | ||
| return result |
There was a problem hiding this comment.
If parse_metrics_text or any processing inside the loop raises an exception, it will abort the entire collection step and skip any remaining nodes. Wrapping the per-node collection and parsing in a try...except block ensures that a failure on one node does not prevent scraping other nodes.
for node_idx, url in enumerate(self._urls):
try:
resp = client.get(url)
resp.raise_for_status()
parsed = parse_metrics_text(resp.text)
for (name, labels), value in parsed.items():
labels_dict = dict(labels)
if name == _GPU_UTIL:
gpu_idx = labels_dict.get("GpuIndex", "0")
result[f"node.{node_idx}.gpu.{gpu_idx}.util"] = value
elif name == _GPU_MEM:
gpu_idx = labels_dict.get("GpuIndex", "0")
# Ray reports gram_used in MB; convert to GB.
result[f"node.{node_idx}.gpu.{gpu_idx}.mem_used_gb"] = value / 1024.0
elif name == _RAM_USED:
# Ray reports mem_used in bytes; convert to GB (host CPU RAM utilized).
result[f"node.{node_idx}.cpu_ram_used_gb"] = value / (1024**3)
except Exception as exc:
logger.debug(f"RayGpuMonitor: failed to scrape or parse {url}: {exc}")
return result
Summary
RayGpuMonitor(skyrl/train/utils/ray_gpu_monitor.py): a background daemon thread that scrapes Ray node Prometheus endpoints each training step and logs per-GPU utilization, GPU memory, and CPU RAM — plus cluster-wide averages — to wandb underray/keys.enable_ray_gpu_monitorconfig flag (defaultTrue).train/tokens_per_second_per_gputo SFT (parallel to the RL metric).Test plan
uv run --isolated --extra dev pytest tests/train/utils/test_ray_gpu_monitor.py(collection, averaging, thread-safety — no live Ray cluster or GPUs required)