In [1]:
from __future__ import annotations

import os
import time
import json
import socket
import platform
import traceback
from typing import Any, Dict, List, Optional, Tuple


def test_ray_cluster(
    *,
    address: str = "auto",
    namespace: str = "ray_cluster_test",
    include_internal_kv: bool = True,
    include_actors: bool = True,
    include_tasks: bool = True,
    include_placement_groups: bool = True,
    run_per_node_exec_probe: bool = True,
    run_object_store_probe: bool = True,
    run_cpu_burn_probe: bool = False,
    cpu_burn_seconds: float = 0.75,
    per_node_timeout_s: float = 15.0,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Comprehensive Ray cluster diagnostic.

    What it tries to answer:
      - Are we connected to a Ray cluster? Which address?
      - How many nodes does Ray detect? Which are alive?
      - Node properties: NodeID, IP, resources, labels (if any), node manager info, etc.
      - Cluster resources vs available resources (Ray scheduler view)
      - Autoscaler status (if present on head / internal KV)
      - Placement groups, actors, tasks summary (if enabled + available in your Ray version)
      - Per-node execution probe: can we schedule a task on each node?
      - Object store probe: can we put/get a few MB and report time?
      - Optional CPU burn probe: light compute to verify distributed execution

    Returns a JSON-serializable dict with results (best effort).
    """

    result: Dict[str, Any] = {
        "meta": {
            "timestamp_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "host": socket.gethostname(),
            "platform": platform.platform(),
            "python": platform.python_version(),
            "pid": os.getpid(),
            "env": {
                "RAY_ADDRESS": os.environ.get("RAY_ADDRESS"),
                "RAY_NAMESPACE": os.environ.get("RAY_NAMESPACE"),
            },
        },
        "connection": {},
        "cluster": {},
        "nodes": {},
        "probes": {},
        "warnings": [],
        "errors": [],
    }

    # --- Lazy imports: keep function usable even if some Ray submodules unavailable ---
    try:
        import ray
        result["meta"]["ray_version"] = getattr(ray, "__version__", "unknown")
    except Exception as e:
        result["errors"].append(f"Failed to import ray: {e}")
        return result

    # Connect
    try:
        ctx = ray.init(address=address, namespace=namespace, ignore_reinit_error=True, log_to_driver=False)
        # ctx can be ray.runtime_context.RuntimeContext-ish depending on version; keep minimal
        result["connection"] = {
            "address_requested": address,
            "dashboard_url": getattr(ctx, "dashboard_url", None),
            "redis_address": getattr(ctx, "address_info", {}).get("redis_address") if hasattr(ctx, "address_info") else None,
            "raylet_socket_name": getattr(ctx, "address_info", {}).get("raylet_socket_name") if hasattr(ctx, "address_info") else None,
        }
    except Exception as e:
        result["errors"].append(f"ray.init failed: {e}")
        result["errors"].append(traceback.format_exc())
        return result

    # Helpers
    def _safe(callable_, *args, **kwargs):
        try:
            return callable_(*args, **kwargs), None
        except Exception as e:
            return None, f"{e}\n{traceback.format_exc()}"

    # --- Cluster resource view ---
    import ray

    cr, err = _safe(ray.cluster_resources)
    ar, err2 = _safe(ray.available_resources)
    result["cluster"]["cluster_resources"] = cr or {}
    result["cluster"]["available_resources"] = ar or {}
    if err:
        result["warnings"].append(f"ray.cluster_resources() failed: {err}")
    if err2:
        result["warnings"].append(f"ray.available_resources() failed: {err2}")

    # --- Nodes: ray.nodes() returns dicts with rich info ---
    nodes, err = _safe(ray.nodes)
    node_list: List[Dict[str, Any]] = nodes or []
    if err:
        result["warnings"].append(f"ray.nodes() failed: {err}")
        node_list = []

    alive = [n for n in node_list if n.get("Alive", False)]
    dead = [n for n in node_list if not n.get("Alive", False)]

    # Normalize node fields into a cleaner structure
    def _node_summary(n: Dict[str, Any]) -> Dict[str, Any]:
        # Ray versions differ in exact keys; keep what exists.
        out: Dict[str, Any] = {
            "node_id": n.get("NodeID") or n.get("NodeId") or n.get("node_id"),
            "alive": bool(n.get("Alive", False)),
            "node_name": n.get("NodeName") or n.get("NodeManagerAddress") or n.get("Hostname"),
            "ip": n.get("NodeManagerAddress") or n.get("IPAddress") or n.get("ip"),
            "resources_total": n.get("Resources", {}),
            "labels": n.get("Labels", None),  # newer Ray may expose Labels
            "raylet": {
                "node_manager_hostname": n.get("NodeManagerHostname"),
                "node_manager_port": n.get("NodeManagerPort"),
                "object_manager_port": n.get("ObjectManagerPort"),
                "raylet_socket_name": n.get("raylet_socket_name"),
            },
            "additional": {},
        }

        # Keep extra keys but avoid huge dumps
        keep_extra = [
            "StartTime", "StartTimeMs", "StartTimeNS",
            "State", "DeathCause", "IsHead", "NodeManagerAddress",
        ]
        for k in keep_extra:
            if k in n:
                out["additional"][k] = n.get(k)

        # Some Ray builds include "Resources" and also "ResourcesTotal" or similar
        for k in ["ResourcesTotal", "ResourcesTotalMap", "ResourcesAvailable", "ResourcesAvailableMap"]:
            if k in n:
                out["additional"][k] = n.get(k)

        return out

    node_summaries = [_node_summary(n) for n in node_list]

    result["nodes"] = {
        "total_detected": len(node_list),
        "alive": len(alive),
        "dead": len(dead),
        "summaries": node_summaries,
    }

    if len(node_list) == 0:
        result["warnings"].append("No nodes returned by ray.nodes(); cluster may be unreachable or in a bad state.")

    # --- Autoscaler status (best effort, varies by Ray + deployment) ---
    autoscaler_info: Dict[str, Any] = {}
    if include_internal_kv:
        # Newer Ray: ray._private.internal_kv; key names vary.
        try:
            from ray._private import internal_kv  # type: ignore
            # internal KV needs init; should already be via ray.init
            internal_kv._internal_kv_initialized()  # may not exist in all versions
        except Exception:
            internal_kv = None  # type: ignore

        if internal_kv is not None:
            # Try common autoscaler keys
            keys_to_try = [
                b"autoscaler/status",                 # sometimes used
                b"autoscaler/summary",                # sometimes used
                b"ray_autoscaler/status",             # sometimes used
                b"RAY_AUTOSCALER_STATUS",             # sometimes used
            ]
            found_any = False
            for k in keys_to_try:
                v, e = _safe(internal_kv._internal_kv_get, k)  # type: ignore
                if v:
                    found_any = True
                    try:
                        autoscaler_info[k.decode("utf-8", "ignore")] = v.decode("utf-8", "ignore")
                    except Exception:
                        autoscaler_info[k.decode("utf-8", "ignore")] = str(v)
            if not found_any:
                autoscaler_info["note"] = "No known autoscaler keys found in internal KV (this can be normal)."
        else:
            autoscaler_info["note"] = "ray._private.internal_kv unavailable; skipping internal KV autoscaler lookup."
    result["cluster"]["autoscaler"] = autoscaler_info

    # --- Placement groups / Actors / Tasks (best effort: Ray APIs differ by version) ---
    # Placement groups
    if include_placement_groups:
        pg_out: Dict[str, Any] = {"available": False, "items": []}
        try:
            from ray.util.placement_group import list_placement_groups  # type: ignore
            pgs, e = _safe(list_placement_groups)
            if pgs is not None:
                pg_out["available"] = True
                # pgs can be list[PlacementGroup] or dict-like
                items = []
                for pg in pgs:
                    try:
                        # Try to access fields safely
                        items.append({
                            "id": str(getattr(pg, "id", None) or getattr(pg, "placement_group_id", None)),
                            "name": getattr(pg, "name", None),
                            "state": str(getattr(pg, "state", None)),
                            "bundles": getattr(pg, "bundles", None),
                            "strategy": getattr(pg, "strategy", None),
                        })
                    except Exception:
                        items.append({"raw": str(pg)})
                pg_out["items"] = items
            if e:
                pg_out["error"] = e
        except Exception as e:
            pg_out["error"] = str(e)
        result["cluster"]["placement_groups"] = pg_out

    # Actors / tasks: prefer state API if present
    state_out: Dict[str, Any] = {"available": False}
    if include_actors or include_tasks:
        try:
            from ray.util.state import list_actors, list_tasks  # type: ignore
            state_out["available"] = True
            if include_actors:
                actors, e = _safe(list_actors, limit=1000)
                state_out["actors"] = {
                    "count": len(actors) if actors else 0,
                    "sample": (actors[:25] if actors else []),
                    "error": e,
                }
            if include_tasks:
                tasks, e = _safe(list_tasks, limit=1000)
                state_out["tasks"] = {
                    "count": len(tasks) if tasks else 0,
                    "sample": (tasks[:25] if tasks else []),
                    "error": e,
                }
        except Exception as e:
            state_out["error"] = str(e)
    result["cluster"]["state_api"] = state_out

    # --- Probes ---
    probes: Dict[str, Any] = {}

    # A remote function to report execution context and node-level details.
    @ray.remote
    def _node_probe() -> Dict[str, Any]:
        import os, socket, time, platform
        import ray
        ctx = ray.get_runtime_context()
        # CPU count here is logical cores visible to the worker proc.
        try:
            import psutil  # optional, might not exist in your image
            mem = psutil.virtual_memory()._asdict()
            cpu_count = psutil.cpu_count(logical=True)
        except Exception:
            mem = None
            cpu_count = os.cpu_count()

        return {
            "ts": time.time(),
            "hostname": socket.gethostname(),
            "ip": ray.util.get_node_ip_address(),
            "pid": os.getpid(),
            "python": platform.python_version(),
            "platform": platform.platform(),
            "ray_node_id": str(ctx.get_node_id()),
            "worker_id": str(getattr(ctx, "get_worker_id", lambda: None)() or ""),
            "job_id": str(ctx.get_job_id()),
            "namespace": str(ctx.namespace),
            "cpu_count": cpu_count,
            "mem": mem,
            "env_sample": {
                "RAY_ADDRESS": os.environ.get("RAY_ADDRESS"),
                "OMP_NUM_THREADS": os.environ.get("OMP_NUM_THREADS"),
            },
        }

    # Determine unique node IDs & schedule one probe per node
    # We do this by using custom resources that Ray sets: "node:<node_id>" typically exists.
    # If absent, we fall back to scheduling N probes and deduplicating by node_id.
    node_resource_keys = []
    if isinstance(cr, dict):
        node_resource_keys = [k for k in cr.keys() if isinstance(k, str) and k.startswith("node:")]

    per_node_exec: Dict[str, Any] = {"attempted": False, "results": [], "errors": []}
    if run_per_node_exec_probe:
        per_node_exec["attempted"] = True
        try:
            futures = []
            if node_resource_keys:
                for nk in node_resource_keys:
                    # Force placement on that node by requiring its unique node resource.
                    futures.append(_node_probe.options(resources={nk: 0.001}).remote())
            else:
                # Fallback: schedule as many probes as alive nodes (or 1) and dedupe
                n = max(1, len(alive))
                futures = [_node_probe.remote() for _ in range(n)]

            # Gather with timeout
            start = time.time()
            out = []
            remaining = list(futures)
            while remaining and (time.time() - start) < per_node_timeout_s:
                done, pending = ray.wait(remaining, num_returns=1, timeout=0.5)
                if done:
                    out.extend(ray.get(done))
                remaining = pending

            if remaining:
                per_node_exec["errors"].append(f"Timeout waiting for {len(remaining)} node probe tasks.")
                # Attempt to fetch any that did finish later without blocking too long
                try:
                    done_now, _ = ray.wait(remaining, num_returns=len(remaining), timeout=0.1)
                    if done_now:
                        out.extend(ray.get(done_now))
                except Exception:
                    pass

            # Deduplicate by ray_node_id
            by_node: Dict[str, Dict[str, Any]] = {}
            for r in out:
                by_node[r.get("ray_node_id", f"unknown-{len(by_node)}")] = r

            per_node_exec["results"] = list(by_node.values())
            per_node_exec["summary"] = {
                "unique_nodes_returned": len(by_node),
                "requested": len(futures),
            }
        except Exception as e:
            per_node_exec["errors"].append(str(e))
            per_node_exec["errors"].append(traceback.format_exc())

    probes["per_node_exec_probe"] = per_node_exec

    # Object store probe (put/get a few MB)
    if run_object_store_probe:
        obj_probe: Dict[str, Any] = {"attempted": True}
        try:
            import numpy as np  # usually available
            size_mb = 8
            arr = np.random.randint(0, 255, size=(size_mb * 1024 * 1024,), dtype=np.uint8)

            t0 = time.time()
            ref = ray.put(arr)
            t1 = time.time()
            got = ray.get(ref)
            t2 = time.time()

            obj_probe.update({
                "size_mb": size_mb,
                "put_seconds": t1 - t0,
                "get_seconds": t2 - t1,
                "roundtrip_seconds": t2 - t0,
                "checksum": int(got[:1024].sum()),  # tiny sanity check
            })
        except Exception as e:
            obj_probe["error"] = str(e)
            obj_probe["traceback"] = traceback.format_exc()
        probes["object_store_probe"] = obj_probe

    # Optional CPU burn (distributed) to validate scheduling + concurrency
    if run_cpu_burn_probe:
        burn: Dict[str, Any] = {"attempted": True, "seconds_each": cpu_burn_seconds}
        try:
            @ray.remote
            def _burn(seconds: float) -> Dict[str, Any]:
                import time, math, socket
                import ray
                end = time.time() + seconds
                x = 0.0
                i = 0
                while time.time() < end:
                    x += math.sin(i) * math.cos(i)
                    i += 1
                return {
                    "hostname": socket.gethostname(),
                    "ip": ray.util.get_node_ip_address(),
                    "iters": i,
                    "acc": x,
                }

            # Try one burn per alive node, else 1
            n = max(1, len(alive))
            t0 = time.time()
            outs = ray.get([_burn.remote(cpu_burn_seconds) for _ in range(n)])
            t1 = time.time()

            burn["tasks"] = n
            burn["wall_seconds"] = t1 - t0
            burn["results_sample"] = outs[: min(25, len(outs))]
        except Exception as e:
            burn["error"] = str(e)
            burn["traceback"] = traceback.format_exc()
        probes["cpu_burn_probe"] = burn

    result["probes"] = probes

    # High-level sanity warnings
    if result["nodes"]["alive"] == 0:
        result["warnings"].append("Ray reports 0 alive nodes.")
    if isinstance(cr, dict) and "CPU" in cr and cr.get("CPU", 0) == 0:
        result["warnings"].append("Cluster resources show CPU=0; scheduler may be unhealthy or nodes not registered.")
    if run_per_node_exec_probe and per_node_exec.get("summary"):
        uniq = per_node_exec["summary"].get("unique_nodes_returned", 0)
        if uniq < max(1, len(alive)):
            result["warnings"].append(
                f"Per-node exec probe returned {uniq} unique nodes, but ray.nodes() shows {len(alive)} alive nodes."
            )

    if verbose:
        # Pretty print a compact summary to stdout
        print("=== Ray Cluster Test Summary ===")
        print("Ray:", result["meta"].get("ray_version"))
        print("Address:", result["connection"].get("address_requested"))
        print("Dashboard:", result["connection"].get("dashboard_url"))
        print("Nodes detected:", result["nodes"]["total_detected"],
              "| alive:", result["nodes"]["alive"],
              "| dead:", result["nodes"]["dead"])
        print("Cluster resources:", result["cluster"].get("cluster_resources", {}))
        print("Available resources:", result["cluster"].get("available_resources", {}))
        if run_per_node_exec_probe:
            s = result["probes"]["per_node_exec_probe"].get("summary", {})
            print("Per-node probe:", s)
        if run_object_store_probe:
            op = result["probes"].get("object_store_probe", {})
            if "error" in op:
                print("Object store probe: ERROR")
            else:
                print("Object store probe (MB):", op.get("size_mb"),
                      "| put:", round(op.get("put_seconds", 0), 4),
                      "| get:", round(op.get("get_seconds", 0), 4),
                      "| rt:", round(op.get("roundtrip_seconds", 0), 4))
        if result["warnings"]:
            print("Warnings:", len(result["warnings"]))
            for w in result["warnings"][:10]:
                print("-", w)
        if result["errors"]:
            print("Errors:", len(result["errors"]))
            for e in result["errors"][:3]:
                print("-", e)

    return result

In [3]:
out = test_ray_cluster(
    address="auto",
    run_cpu_burn_probe=False,   # flip to True if you want a distributed compute sanity test
    cpu_burn_seconds=1.0,
    verbose=True,
)

# If you want a JSON blob to save:
import json
print(json.dumps(out, indent=2, default=str))


2026-01-13 15:55:09,105	INFO worker.py:1821 -- Connecting to existing Ray cluster at address: 172.31.22.212:6379...
2026-01-13 15:55:09,106	INFO worker.py:1839 -- Calling ray.init() again after it has already been called.


=== Ray Cluster Test Summary ===
Ray: 2.53.0
Address: auto
Dashboard: 127.0.0.1:8265
Nodes detected: 3 | alive: 3 | dead: 0
Cluster resources: {'node:172.31.24.15': 1.0, 'CPU': 10.0, 'memory': 16230511002.0, 'object_store_memory': 6955933286.0, 'node:__internal_head__': 1.0, 'node:172.31.22.212': 1.0, 'node:172.31.31.22': 1.0}
Available resources: {'node:172.31.24.15': 1.0, 'CPU': 10.0, 'object_store_memory': 6955933286.0, 'memory': 16230511002.0, 'node:__internal_head__': 1.0, 'node:172.31.22.212': 1.0, 'node:172.31.31.22': 1.0}
Per-node probe: {'unique_nodes_returned': 3, 'requested': 4}
Object store probe (MB): 8 | put: 0.002 | get: 0.0006 | rt: 0.0026
{
  "meta": {
    "timestamp_utc": "2026-01-13T15:55:09Z",
    "host": "ip-172-31-22-212.eu-west-2.compute.internal",
    "platform": "Linux-6.1.159-181.297.amzn2023.x86_64-x86_64-with-glibc2.41",
    "python": "3.11.14",
    "pid": 762,
    "env": {
      "RAY_ADDRESS": null,
      "RAY_NAMESPACE": null
    },
    "ray_version": "2.5

10.0