# Train XGBoost (Ray on Spark - Plasma Object Store Tuning)

Distributed XGBoost training using Ray on Spark with **configurable Plasma/Object Store parameters**
and **per-worker system metrics collection**.

This notebook extends `train_xgb_ray.ipynb` with additional widget parameters for tuning Ray's
object store (shared memory), spilling configuration, and heap memory allocation.

**Experiment Goal:** Find optimal object store configuration for 10M+ row datasets.

**New Parameters (vs base notebook):**
- `obj_store_mem_gb`: Object store memory per worker node in GB (0 = Ray default, ~30% of RAM)
- `head_obj_store_mem_gb`: Object store memory for head node in GB (0 = Ray default)
- `heap_mem_gb`: Heap memory per worker in GB (0 = Ray default)
- `spill_dir`: Object spilling directory path (default: /local_disk0/ray_spill)
- `ray_temp_dir`: Ray temp root directory (default: /local_disk0/tmp)
- `allow_slow_storage`: Set RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE=1 to bypass /dev/shm cap

**Worker System Metrics:**
- Deploys `SystemMetricsMonitor` Ray actors on each worker node
- Logs CPU, memory, disk, network metrics per worker to the same MLflow run
- Metric names: `system/worker_0/cpu_utilization_percentage`, `system/worker_1/...`, etc.
- Credentials are read from the driver's `dbutils` context and passed to the workers via the Ray object store

**Requirements:**
- Databricks ML Runtime 17.3 LTS (includes Ray 2.37.0)
- Multi-node cluster (2+ workers recommended)

**MLflow:**
- System metrics enabled (driver + all worker nodes)
- All Plasma config params logged for experiment comparison

## Setup Widgets

In [None]:
# Global error tracking - captures errors from any cell
_notebook_errors = []

def log_error(error_msg, exc=None):
    """Record an error so it can be retrieved later in the exit cell.

    Args:
        error_msg: Description of the failure; stored as ``str(error_msg)``.
        exc: Optional. When an exception instance is passed, the traceback of
            that exception is recorded. Any other truthy value records the
            traceback of the exception currently being handled.
    """
    import traceback
    entry = {"error": str(error_msg)}
    if isinstance(exc, BaseException):
        # Format the exception that was actually passed in. format_exc() only
        # sees the exception *currently being handled*, so it records
        # "NoneType: None" when log_error is called outside an except block.
        entry["traceback"] = "".join(
            traceback.format_exception(type(exc), exc, exc.__traceback__)
        )
    elif exc:
        # Flag-style usage (e.g. exc=True) from inside an except block.
        entry["traceback"] = traceback.format_exc()
    _notebook_errors.append(entry)
    print(f"ERROR LOGGED: {error_msg}")

# === Standard parameters (same as train_xgb_ray.ipynb) ===
# Positional arguments below are: name, default value, [dropdown choices,] label.
# Numeric settings are declared as text widgets and parsed when they are read back.
dbutils.widgets.dropdown("data_size", "tiny", ["tiny", "small", "medium", "large", "xlarge"], "Data Size")
dbutils.widgets.text("node_type", "D8sv5", "Node Type")
dbutils.widgets.dropdown("run_mode", "full", ["full", "smoke"], "Run Mode")
dbutils.widgets.text("num_workers", "0", "Num Workers (0=auto)")
dbutils.widgets.text("cpus_per_worker", "0", "CPUs per Worker (0=auto)")
dbutils.widgets.text("warehouse_id", "148ccb90800933a1", "Databricks SQL Warehouse ID")
dbutils.widgets.text("catalog", "brian_gen_ai", "Catalog")
dbutils.widgets.text("schema", "xgb_scaling", "Schema")
dbutils.widgets.text("table_name", "", "Table Name (override)")

# === NEW: Plasma / Object Store tuning parameters ===
# A value of "0" for the three memory widgets means "leave Ray's default in place".
dbutils.widgets.text("obj_store_mem_gb", "0", "Object Store Memory per Worker (GB, 0=default)")
dbutils.widgets.text("head_obj_store_mem_gb", "0", "Object Store Memory Head (GB, 0=default)")
dbutils.widgets.text("heap_mem_gb", "0", "Heap Memory per Worker (GB, 0=default)")
dbutils.widgets.text("spill_dir", "/local_disk0/ray_spill", "Object Spill Directory")
dbutils.widgets.text("ray_temp_dir", "/local_disk0/tmp", "Ray Temp Root Directory")
dbutils.widgets.dropdown("allow_slow_storage", "0", ["0", "1"], "Allow Slow Storage (bypass /dev/shm)")

In [None]:
# Get widget values
# Worker/CPU counts are parsed as int; a non-numeric entry raises ValueError here.
data_size = dbutils.widgets.get("data_size")
node_type = dbutils.widgets.get("node_type")
run_mode = dbutils.widgets.get("run_mode")
num_workers_input = int(dbutils.widgets.get("num_workers"))
cpus_per_worker_input = int(dbutils.widgets.get("cpus_per_worker"))
warehouse_id = dbutils.widgets.get("warehouse_id").strip()
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
table_name_override = dbutils.widgets.get("table_name").strip()

# NEW: Plasma config values
# Memory sizes are parsed as float so fractional GB values are accepted.
obj_store_mem_gb = float(dbutils.widgets.get("obj_store_mem_gb"))
head_obj_store_mem_gb = float(dbutils.widgets.get("head_obj_store_mem_gb"))
heap_mem_gb = float(dbutils.widgets.get("heap_mem_gb"))
spill_dir = dbutils.widgets.get("spill_dir").strip()
ray_temp_dir = dbutils.widgets.get("ray_temp_dir").strip()
allow_slow_storage = dbutils.widgets.get("allow_slow_storage").strip()

# Convert GB to bytes (0 means use Ray default)
# Any value that is not strictly positive maps to None, i.e. "do not pass this setting".
obj_store_mem_bytes = int(obj_store_mem_gb * 1024 * 1024 * 1024) if obj_store_mem_gb > 0 else None
head_obj_store_mem_bytes = int(head_obj_store_mem_gb * 1024 * 1024 * 1024) if head_obj_store_mem_gb > 0 else None
heap_mem_bytes = int(heap_mem_gb * 1024 * 1024 * 1024) if heap_mem_gb > 0 else None

# Dataset size preset mapping
# Each preset names the table suffix and records the row/feature counts of that dataset.
SIZE_PRESETS = {
    "tiny": {"suffix": "10k", "rows": 10_000, "features": 20},
    "small": {"suffix": "1m", "rows": 1_000_000, "features": 100},
    "medium": {"suffix": "10m", "rows": 10_000_000, "features": 250},
    "large": {"suffix": "100m", "rows": 100_000_000, "features": 500},
    "xlarge": {"suffix": "500m", "rows": 500_000_000, "features": 500},
}

# Determine input table
# An explicit table name override wins over the data_size preset; in that case
# `preset` and `table_suffix` are not defined by this cell.
if table_name_override:
    input_table = f"{catalog}.{schema}.{table_name_override}"
    data_size_label = table_name_override.replace("imbalanced_", "")
else:
    preset = SIZE_PRESETS[data_size]
    table_suffix = preset["suffix"]
    input_table = f"{catalog}.{schema}.imbalanced_{table_suffix}"
    data_size_label = data_size

# Build a compact plasma config tag for the run name.
# The GB settings are floats, so they are formatted with "g": whole numbers stay
# compact ("8") while fractional sizes keep their value ("0.5"). The previous
# ".0f" rounded them, producing misleading or colliding tags
# (0.5 -> "os0g", 1.5 and 2.5 -> "os2g").
plasma_tag = f"os{obj_store_mem_gb:g}g" if obj_store_mem_gb > 0 else "osD"
if heap_mem_gb > 0:
    plasma_tag += f"_h{heap_mem_gb:g}g"
if allow_slow_storage == "1":
    plasma_tag += "_slow"

# Run naming (prefix with ray_ and worker config + plasma config)
# Smoke runs are named by node type only; full runs also carry the dataset label,
# the explicit worker count (when one was requested) and the plasma tag.
if run_mode == "smoke":
    run_name = f"plasma_smoke_{node_type}"
else:
    worker_suffix = f"_{num_workers_input}w" if num_workers_input > 0 else ""
    run_name = f"plasma_{data_size_label}{worker_suffix}_{node_type}_{plasma_tag}"

# Echo the resolved configuration so it is visible in the notebook output.
print(f"Data size: {data_size}")
print(f"Node type: {node_type}")
print(f"Run mode: {run_mode}")
print(f"Num workers input: {num_workers_input} (0=auto)")
print(f"CPUs per worker input: {cpus_per_worker_input} (0=auto)")
print(f"Warehouse ID: {warehouse_id}")
print(f"Input table: {input_table}")
print(f"Run name: {run_name}")
# The byte values print as None when the corresponding GB setting is left at its default.
print(f"\n--- Plasma Object Store Config ---")
print(f"  obj_store_mem_gb: {obj_store_mem_gb} ({obj_store_mem_bytes} bytes)")
print(f"  head_obj_store_mem_gb: {head_obj_store_mem_gb} ({head_obj_store_mem_bytes} bytes)")
print(f"  heap_mem_gb: {heap_mem_gb} ({heap_mem_bytes} bytes)")
print(f"  spill_dir: {spill_dir}")
print(f"  ray_temp_dir: {ray_temp_dir}")
print(f"  allow_slow_storage: {allow_slow_storage}")

## MLflow Setup

In [None]:
import os

# Enable system metrics logging BEFORE importing mlflow
os.environ["MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING"] = "true"

# Set slow storage env var if requested (must be set BEFORE Ray starts)
# The widget value is compared as a string because dropdown values are strings.
if allow_slow_storage == "1":
    os.environ["RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE"] = "1"
    print("RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE=1 (bypassing /dev/shm cap)")

import mlflow

# Also call the enable function after import
mlflow.enable_system_metrics_logging()

# Get current user for experiment path
user_email = spark.sql("SELECT current_user()").collect()[0][0]
experiment_path = f"/Users/{user_email}/xgb_scaling_benchmark"

# Set experiment
# All runs from this notebook go to one per-user experiment path.
mlflow.set_experiment(experiment_path)
print(f"MLflow experiment: {experiment_path}")
print(f"System metrics logging enabled: {os.environ.get('MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING')}")

## Initialize Ray on Spark (with Plasma Tuning)

In [None]:
import time
import os
import json as _json

# Check Ray version and availability
try:
    import ray
    print(f"Ray version: {ray.__version__}")
except ImportError as e:
    # Chain the original ImportError so the root cause stays in the traceback.
    raise RuntimeError(f"Ray not available: {e}") from e

from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster

# Setup Databricks environment for Ray workers (required for MLflow + Ray Data access)
databricks_host_url = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None)
databricks_token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)

# getOrElse(None) hands back None when the notebook context has no API URL or
# token. Fail here with a clear message instead of an AttributeError on
# .replace() or a TypeError when assigning None into os.environ below.
if not databricks_host_url or not databricks_token:
    raise RuntimeError(
        "Databricks API URL/token not available from the notebook context; "
        "cannot configure MLflow / Ray Data authentication for Ray workers."
    )

# Ray Data Databricks reader expects hostname without scheme.
databricks_host = databricks_host_url.replace("https://", "").replace("http://", "").rstrip("/")

# Default env setup for MLflow and general Databricks SDK usage.
os.environ["DATABRICKS_HOST"] = databricks_host_url
os.environ["DATABRICKS_TOKEN"] = databricks_token

# The token itself is never printed, only the fact that it is configured.
print(f"Databricks Host URL: {databricks_host_url}")
print(f"Databricks Hostname (Ray Data): {databricks_host}")
print("Databricks Token: [CONFIGURED]")

# Get cluster info
sc = spark.sparkContext
num_executors = sc._jsc.sc().getExecutorMemoryStatus().size() - 1  # Exclude driver
print(f"\nSpark executors (workers): {num_executors}")

if num_executors < 1:
    print("WARNING: No executors detected. This may indicate a single-node cluster.")
    print("Ray on Spark requires a multi-node cluster (num_workers >= 1)")

# Determine per-node vCPU and leave one CPU for Spark/system buffer
import re
node_type_lower = node_type.lower()
# Takes the first digit run that directly follows a "d" or "e" (e.g. "D8sv5" -> 8);
# node types that do not match fall back to 8 vCPUs.
node_vcpus_match = re.search(r"[de](\d+)", node_type_lower)
node_vcpus = int(node_vcpus_match.group(1)) if node_vcpus_match else 8
allocatable_cpus_per_node = max(1, node_vcpus - 1)

# Override OMP_NUM_THREADS — Databricks/Spark sets this to 1 on executors,
# which silently caps XGBoost nthread to 1 regardless of what you set.
# XGBoost's C++ layer does: min(nthread, omp_get_max_threads()) = min(14, 1) = 1
# Setting it here on the driver propagates to Ray workers spawned by setup_ray_cluster.
os.environ["OMP_NUM_THREADS"] = str(allocatable_cpus_per_node)
print(f"\nOMP_NUM_THREADS: {os.environ['OMP_NUM_THREADS']} (overriding Spark default of 1)")

# Determine number of Ray workers (best practice: one worker per executor)
if num_workers_input > 0:
    num_workers = num_workers_input
else:
    num_workers = max(1, num_executors)

# Keep worker count aligned with available executors, but never below one.
# Capping straight to num_executors would undo the max(1, ...) floor above and
# leave num_workers at 0 (or below) whenever no executors are reported.
if num_workers > max(1, num_executors):
    print(f"Capping num_workers from {num_workers} to {max(1, num_executors)} (executor count)")
    num_workers = max(1, num_executors)

# CPUs per worker for Ray training
if cpus_per_worker_input > 0:
    cpus_per_worker = cpus_per_worker_input
else:
    cpus_per_worker = allocatable_cpus_per_node

# Enforce per-node CPU safety cap
if cpus_per_worker > allocatable_cpus_per_node:
    print(
        f"Capping cpus_per_worker from {cpus_per_worker} to {allocatable_cpus_per_node} "
        f"for node type {node_type}"
    )
    cpus_per_worker = allocatable_cpus_per_node

# Use the same per-node CPU for Ray cluster worker allocation
num_cpus_worker_node = allocatable_cpus_per_node

print(f"\nResource sizing:")
print(f"  node_type: {node_type}")
print(f"  node_vcpus: {node_vcpus}")
print(f"  allocatable_cpus_per_node: {allocatable_cpus_per_node}")
print(f"  spark_executors: {num_executors}")
print(f"\nRay training configuration:")
print(f"  num_workers: {num_workers}")
print(f"  cpus_per_worker: {cpus_per_worker}")
print(f"  total_requested_cpus: {num_workers * cpus_per_worker}")

In [None]:
# Initialize Ray cluster on Spark WITH Plasma object store tuning
# AND runtime_env to set OMP_NUM_THREADS at process level (before any imports)
print("Starting Ray cluster on Spark (with Plasma tuning + OMP fix)...")
# Start of the timing window reported as the cluster initialization time.
ray_start = time.time()

# --- Build head_node_options with spilling config ---
# Stays an empty dict when no spill directory is configured, in which case
# nothing is passed on to the cluster setup call.
head_node_options = {}
if spill_dir:
    # The spilling config is itself JSON-encoded into a string inside system_config.
    head_node_options["system_config"] = {
        "object_spilling_config": _json.dumps({
            "type": "filesystem",
            "params": {
                "directory_path": spill_dir
            }
        })
    }
    print(f"Custom spill directory: {spill_dir}")

# --- Collect /dev/shm size for diagnostics ---
# Best-effort only: a failure here must never block cluster startup, but it is
# reported instead of silently dropped so missing output can be explained.
try:
    import subprocess
    shm_info = subprocess.run(["df", "-h", "/dev/shm"], capture_output=True, text=True)
    print(f"\n/dev/shm on driver:\n{shm_info.stdout}")
except Exception as e:
    print(f"\n(Could not inspect /dev/shm on driver: {e})")

# --- Build setup_ray_cluster kwargs ---
# min and max worker nodes are both set to the Spark executor count, so the
# cluster size is fixed for the duration of the run.
ray_cluster_kwargs = {
    "min_worker_nodes": num_executors,           # Fixed size cluster for benchmarks
    "max_worker_nodes": num_executors,           # No autoscaling for benchmarks
    "num_cpus_worker_node": num_cpus_worker_node,
    "num_gpus_worker_node": 0,                   # CPU-only training
    "collect_log_to_path": "/tmp/ray_logs",
}

# Plasma tuning parameters (only set if non-default)
# A None byte value means the widget was left at 0, so the key is omitted entirely.
if obj_store_mem_bytes is not None:
    ray_cluster_kwargs["object_store_memory_worker_node"] = obj_store_mem_bytes
    print(f"Object store memory per worker: {obj_store_mem_gb:.1f} GB ({obj_store_mem_bytes:,} bytes)")

if head_obj_store_mem_bytes is not None:
    ray_cluster_kwargs["object_store_memory_head_node"] = head_obj_store_mem_bytes
    print(f"Object store memory head node: {head_obj_store_mem_gb:.1f} GB ({head_obj_store_mem_bytes:,} bytes)")

if heap_mem_bytes is not None:
    ray_cluster_kwargs["memory_worker_node"] = heap_mem_bytes
    print(f"Heap memory per worker: {heap_mem_gb:.1f} GB ({heap_mem_bytes:,} bytes)")

# An empty temp-dir widget leaves the temp root unset.
if ray_temp_dir:
    ray_cluster_kwargs["ray_temp_root_dir"] = ray_temp_dir
    print(f"Ray temp root dir: {ray_temp_dir}")

if head_node_options:
    ray_cluster_kwargs["head_node_options"] = head_node_options

# Dump the final kwargs so the exact cluster configuration is visible in the output.
print(f"\nsetup_ray_cluster kwargs:")
for k, v in ray_cluster_kwargs.items():
    if k == "head_node_options":
        # Nested dict: pretty-print as JSON for readability.
        print(f"  {k}: {_json.dumps(v, indent=4)}")
    else:
        print(f"  {k}: {v}")

# Start the cluster, reconnect with a runtime_env, print node diagnostics and
# run the CPU preflight. Any failure prints debug info and is re-raised.
try:
    ray_cluster = setup_ray_cluster(**ray_cluster_kwargs)
    ray_init_time = time.time() - ray_start
    print(f"\nRay cluster initialized in {ray_init_time:.1f}s")

    # =======================================================================
    # CRITICAL OMP FIX: Reconnect to Ray with runtime_env that sets
    # OMP_NUM_THREADS at the OS level BEFORE the worker Python process starts.
    #
    # setup_ray_cluster() internally calls ray.init() for health checks then
    # ray.shutdown(), leaving RAY_ADDRESS set but no active connection.
    # We reconnect with runtime_env so ALL Ray workers (tasks + actors)
    # inherit OMP_NUM_THREADS=<cpus_per_worker> from process startup.
    #
    # This is the MOST RELIABLE fix because:
    # - The env var is set via os.execvp() before the Python interpreter starts
    # - OpenMP runtime reads it during first library load (import xgboost)
    # - No race condition, no wrong-library problem, no caching issue
    # =======================================================================
    # NOTE(review): this value is captured BEFORE the preflight auto-cap further
    # down, which may lower cpus_per_worker; the runtime_env value is not updated.
    omp_threads_str = str(cpus_per_worker)
    print(f"\nReconnecting to Ray with runtime_env OMP_NUM_THREADS={omp_threads_str}...")
    
    # Disconnect if setup_ray_cluster left a connection open
    if ray.is_initialized():
        ray.shutdown()
    
    ray.init(
        runtime_env={
            "env_vars": {
                "OMP_NUM_THREADS": omp_threads_str,
                # Also propagate Databricks auth for MLflow/Ray Data
                "DATABRICKS_HOST": databricks_host_url,
                "DATABRICKS_TOKEN": databricks_token,
            }
        }
    )
    print(f"Ray reconnected with runtime_env. OMP_NUM_THREADS={omp_threads_str}")
    # Prints the node ID of the process running this cell (first 8 characters).
    print(f"Ray head node: {ray.get_runtime_context().get_node_id()[:8]}")
    # =======================================================================

    print(f"\nRay cluster resources:")
    cluster_resources = ray.cluster_resources()
    print(cluster_resources)

    # Collect actual object store size from Ray nodes for diagnostics
    # Best-effort: a failure here is printed and does not abort the cell.
    try:
        nodes_info = ray.nodes()
        for node in nodes_info:
            node_id_short = node.get('NodeID', 'unknown')[:8]
            resources = node.get('Resources', {})
            obj_store = resources.get('object_store_memory', 0)
            mem = resources.get('memory', 0)
            cpus = resources.get('CPU', 0)
            alive = node.get('Alive', False)
            print(f"  Node {node_id_short}: CPU={cpus}, memory={mem/(1024**3):.1f}GB, object_store={obj_store/(1024**3):.2f}GB, alive={alive}")
    except Exception as e:
        print(f"  (Could not collect node details: {e})")

    # Preflight CPU validation to avoid trainer stalls from pending actors
    # AUTO_CAP_RESOURCES=True shrinks the request to fit; False raises instead.
    AUTO_CAP_RESOURCES = True
    ray_train_overhead_cpus = 1
    available_cpus = int(cluster_resources.get("CPU", 0))
    required_worker_cpus = int(num_workers * cpus_per_worker)
    required_total_cpus = required_worker_cpus + ray_train_overhead_cpus

    print("\nRay preflight CPU check:")
    print(f"  available_cpus: {available_cpus}")
    print(f"  requested_worker_cpus: {required_worker_cpus} ({num_workers} workers x {cpus_per_worker} CPU)")
    print(f"  ray_train_overhead_cpus: {ray_train_overhead_cpus}")
    print(f"  requested_total_cpus: {required_total_cpus}")

    if required_total_cpus > available_cpus:
        if not AUTO_CAP_RESOURCES:
            raise RuntimeError(
                f"Insufficient Ray CPUs: requested total {required_total_cpus}, available {available_cpus}. "
                "Reduce num_workers/cpus_per_worker or increase cluster size."
            )

        # Step 1: reduce CPUs per worker so all workers plus the overhead CPU fit.
        usable_cpus_for_workers = max(1, available_cpus - ray_train_overhead_cpus)
        capped_cpus_per_worker = max(1, usable_cpus_for_workers // max(1, num_workers))
        if capped_cpus_per_worker < cpus_per_worker:
            print(f"  auto-cap: cpus_per_worker {cpus_per_worker} -> {capped_cpus_per_worker}")
            cpus_per_worker = capped_cpus_per_worker

        # Step 2: if that is still too much, reduce the number of workers.
        required_worker_cpus = int(num_workers * cpus_per_worker)
        required_total_cpus = required_worker_cpus + ray_train_overhead_cpus
        if required_total_cpus > available_cpus:
            capped_workers = max(1, usable_cpus_for_workers // max(1, cpus_per_worker))
            if capped_workers < num_workers:
                print(f"  auto-cap: num_workers {num_workers} -> {capped_workers}")
                num_workers = capped_workers

        # Step 3: give up if even the capped request does not fit.
        required_worker_cpus = int(num_workers * cpus_per_worker)
        required_total_cpus = required_worker_cpus + ray_train_overhead_cpus
        if required_total_cpus > available_cpus:
            raise RuntimeError(
                f"Unable to fit Ray resources after auto-cap: requested total {required_total_cpus}, available {available_cpus}."
            )

    print(f"  final_num_workers: {num_workers}")
    print(f"  final_cpus_per_worker: {cpus_per_worker}")
    print(f"  final_requested_worker_cpus: {num_workers * cpus_per_worker}")
    print(f"  final_requested_total_cpus: {(num_workers * cpus_per_worker) + ray_train_overhead_cpus}")
    
except Exception as e:
    # Print context for debugging, then re-raise so the cell still fails.
    print(f"ERROR: Failed to initialize Ray cluster: {e}")
    print(f"\nDebug info:")
    print(f"  Spark version: {spark.version}")
    print(f"  Executors: {num_executors}")
    import traceback
    traceback.print_exc()
    raise

## Worker-Side System Metrics Collection

Deploy MLflow `SystemMetricsMonitor` actors on each Ray worker node to collect per-node CPU, memory,
disk, and network metrics. Each worker logs to the **same MLflow run** as the driver but with a unique
`node_id` prefix (e.g., `system/worker_0/cpu_utilization_percentage`).

**Key design decisions:**
- Uses `SystemMetricsMonitor` directly (NOT `mlflow.start_run`) to avoid side-effects on run status
- Passes Databricks auth credentials from driver via `ray.put()` to the object store
- Workers set `DATABRICKS_HOST`, `DATABRICKS_TOKEN`, and `MLFLOW_TRACKING_URI` env vars for MLflow auth
- Each actor is pinned to a unique worker node via Ray scheduling
- Monitors are started before training and stopped after, with `try/finally` for clean shutdown

In [None]:
# Define Ray actor for per-worker MLflow system metrics collection.
# This actor runs SystemMetricsMonitor directly — it does NOT call mlflow.start_run(),
# so it won't interfere with the driver's run status. It only logs system metrics
# (CPU, memory, disk, network) to the same run_id with a unique node_id prefix.
#
# IMPORTANT: We do NOT pass tracking_uri to SystemMetricsMonitor constructor because
# that parameter only exists in unreleased MLflow (master). Instead, we set the
# MLFLOW_TRACKING_URI environment variable before creating the monitor, which MLflow
# reads internally via mlflow.get_tracking_uri().

import ray

@ray.remote(num_cpus=0)  # Zero CPU requirement — monitoring only, no resource reservation
class WorkerMetricsMonitor:
    """Runs MLflow SystemMetricsMonitor on a Ray worker node.

    Logs system metrics to an existing MLflow run with a unique node_id prefix.
    Uses SystemMetricsMonitor directly (not mlflow.start_run) to avoid side effects.

    Metric names produced: system/{node_id}/cpu_utilization_percentage, etc.
    """

    def __init__(self, run_id: str, node_id: str, db_host: str, db_token: str,
                 sampling_interval: float = 10.0):
        """Store the monitor settings and export MLflow auth to this process.

        Args:
            run_id: MLflow run ID the metrics are logged to.
            node_id: Label used as the node_id of the monitor (e.g. "worker_0").
            db_host: Value exported as DATABRICKS_HOST.
            db_token: Value exported as DATABRICKS_TOKEN.
            sampling_interval: Seconds between metric samples.
        """
        import os

        # Set Databricks auth env vars so MLflow can talk to the tracking server.
        # These must be set BEFORE creating SystemMetricsMonitor because the monitor
        # internally creates a BatchMetricsLogger which reads mlflow.get_tracking_uri().
        os.environ["DATABRICKS_HOST"] = db_host
        os.environ["DATABRICKS_TOKEN"] = db_token
        os.environ["MLFLOW_TRACKING_URI"] = "databricks"

        self._run_id = run_id
        self._node_id = node_id
        self._sampling_interval = sampling_interval
        # None while no monitor is running; set only after a successful start().
        self._monitor = None

        # Capture the Ray node ID for diagnostics
        self._ray_node_id = ray.get_runtime_context().get_node_id()[:8]

    def start(self) -> str:
        """Start collecting system metrics. Returns node info string.

        Calling start() while a monitor is already running is a no-op, so the
        handle to the running monitor is never overwritten and lost.
        """
        if self._monitor is not None:
            return f"{self._node_id} already running on ray_node={self._ray_node_id}"

        from mlflow.system_metrics.system_metrics_monitor import SystemMetricsMonitor

        monitor = SystemMetricsMonitor(
            run_id=self._run_id,
            node_id=self._node_id,
            sampling_interval=self._sampling_interval,
            samples_before_logging=1,  # Log every sample for fine-grained visibility
        )
        monitor.start()
        # Record the monitor only once it started, so status() never reports
        # "running" for a monitor whose start() raised.
        self._monitor = monitor
        return f"{self._node_id} started on ray_node={self._ray_node_id}"

    def stop(self) -> str:
        """Stop collecting and flush remaining metrics."""
        if self._monitor is not None:
            self._monitor.finish()
            self._monitor = None
            return f"{self._node_id} stopped on ray_node={self._ray_node_id}"
        return f"{self._node_id} was not running"

    def status(self) -> dict:
        """Return current monitoring status."""
        return {
            "node_id": self._node_id,
            "ray_node": self._ray_node_id,
            "running": self._monitor is not None,
        }


def start_worker_monitors(run_id: str, db_host: str, db_token: str,
                           num_nodes: int, sampling_interval: float = 10.0):
    """Launch WorkerMetricsMonitor actors across Ray worker nodes.

    Creates one monitor actor per worker node. Each actor is scheduled
    on a specific node via NodeAffinitySchedulingStrategy.

    Args:
        run_id: MLflow run ID to log metrics to
        db_host: Databricks workspace URL (with https://)
        db_token: Databricks API token
        num_nodes: Number of Ray worker nodes to monitor
        sampling_interval: Seconds between metric samples

    Returns:
        List of actor handles (keep references alive!)

    Raises:
        Exception: Whatever actor creation or ``ray.get`` raised. Actors that
            were already created are killed before the error propagates.
    """
    # Import explicitly instead of relying on ray.util.scheduling_strategies
    # having been loaded as a side effect of another import.
    from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

    # Get unique alive worker node IDs (exclude head node)
    head_node_id = ray.get_runtime_context().get_node_id()
    alive_nodes = [n for n in ray.nodes() if n.get("Alive") and n["NodeID"] != head_node_id]

    # Cap to requested number of nodes
    target_nodes = alive_nodes[:num_nodes]

    print(f"Starting worker metrics monitors for {len(target_nodes)} nodes...")
    print(f"  Head node (excluded): {head_node_id[:8]}")
    if len(target_nodes) < num_nodes:
        print(
            f"  WARNING: requested {num_nodes} nodes but only "
            f"{len(target_nodes)} alive worker nodes were found"
        )

    actors = []
    start_futures = []

    try:
        for idx, node_info in enumerate(target_nodes):
            node_id_label = f"worker_{idx}"
            ray_node_id = node_info["NodeID"]

            # Schedule actor on this specific node using node affinity
            actor = WorkerMetricsMonitor.options(
                scheduling_strategy=NodeAffinitySchedulingStrategy(
                    node_id=ray_node_id,
                    soft=False,
                ),
                name=f"mlflow_metrics_{node_id_label}",
            ).remote(
                run_id=run_id,
                node_id=node_id_label,
                db_host=db_host,
                db_token=db_token,
                sampling_interval=sampling_interval,
            )

            # Track the handle before issuing start so cleanup can always reach it.
            actors.append(actor)
            start_futures.append(actor.start.remote())

        # Wait for all monitors to start
        results = ray.get(start_futures)
    except Exception:
        # The caller never receives the handles when startup fails, so the
        # actors created so far would otherwise be left behind.
        for actor in actors:
            try:
                ray.kill(actor)
            except Exception:
                # Best-effort cleanup; the original startup error is what matters.
                pass
        raise

    for r in results:
        print(f"  {r}")

    print(f"All {len(actors)} worker monitors started (sampling every {sampling_interval}s)")
    return actors


def stop_worker_monitors(actors):
    """Ask every monitor actor to stop, then kill the actors.

    Stop requests are issued to all actors first and awaited together with a
    30 second timeout. The actors are killed afterwards whether or not the
    stop calls succeeded.

    Args:
        actors: List of WorkerMetricsMonitor actor handles
    """
    if not actors:
        print("No worker monitors to stop.")
        return

    print(f"Stopping {len(actors)} worker metrics monitors...")
    pending_stops = []
    for handle in actors:
        pending_stops.append(handle.stop.remote())

    try:
        for message in ray.get(pending_stops, timeout=30):
            print(f"  {message}")
    except ray.exceptions.GetTimeoutError:
        print("  WARNING: Some monitors did not stop within 30s timeout")
    except Exception as e:
        print(f"  WARNING: Error stopping monitors: {e}")

    # Kill the actors to free resources (failures here are deliberately ignored)
    for handle in actors:
        try:
            ray.kill(handle)
        except Exception:
            pass

    print("Worker monitors stopped and cleaned up.")


# Confirm the definitions above; only the first 40 characters of the host URL
# are shown and the token value is never printed.
print("WorkerMetricsMonitor actor defined.")
print("Helper functions: start_worker_monitors(), stop_worker_monitors()")
print(f"\nCredentials available for workers:")
print(f"  DATABRICKS_HOST: {databricks_host_url[:40]}...")
print(f"  DATABRICKS_TOKEN: [CONFIGURED]")

## Load Data

In [None]:
import ray.data

# Distributed loading goes through a SQL warehouse, so an ID is mandatory
if not warehouse_id:
    raise ValueError("warehouse_id is required for distributed Ray Data loading")

print(f"Loading data from: {input_table}")
print(f"Using SQL Warehouse: {warehouse_id}")
load_start = time.time()

query = f"SELECT * FROM {input_table}"

# Expose the scheme-less hostname only while the Ray Data reader is created,
# then put back whatever DATABRICKS_HOST held before (or the full URL if unset).
_original_db_host = os.environ.get("DATABRICKS_HOST")
os.environ["DATABRICKS_HOST"] = databricks_host
try:
    full_ray_ds = ray.data.read_databricks_tables(warehouse_id=warehouse_id, query=query)
finally:
    os.environ["DATABRICKS_HOST"] = (
        databricks_host_url if _original_db_host is None else _original_db_host
    )

n_rows = full_ray_ds.count()
all_columns = list(full_ray_ds.schema().names)
if "label" not in all_columns:
    raise ValueError(f"Expected 'label' column in dataset schema, got: {all_columns}")

# Every column except the target is treated as a feature
feature_columns = [name for name in all_columns if name != "label"]
load_time = time.time() - load_start

print(f"Loaded {n_rows:,} rows x {len(all_columns)} columns in {load_time:.1f}s")
print(f"Feature count: {len(feature_columns)}")

# Only a bounded sample (at most 10,000 rows) is pulled to the driver for the
# MLflow input dataset, to avoid a driver OOM on large tables
mlflow_sample_rows = min(10_000, n_rows)
mlflow_sample_df = full_ray_ds.limit(mlflow_sample_rows).to_pandas()
mlflow_dataset = mlflow.data.from_pandas(
    mlflow_sample_df,
    source=input_table,
    name=data_size_label,
    targets="label",
)
print(f"MLflow dataset sample created: {mlflow_dataset.name} ({mlflow_sample_rows:,} rows)")

## Prepare Features and Labels

In [None]:
# Class distribution from Ray Dataset
# NOTE(review): Dataset.sum presumably yields None for an empty dataset — confirm;
# "or 0" keeps the counts numeric in that case instead of failing in int().
positive_count = int(full_ray_ds.sum("label") or 0)
negative_count = int(n_rows - positive_count)
minority_ratio = positive_count / n_rows if n_rows else 0.0

# The percentages use the same n_rows guard as minority_ratio, so an empty
# dataset prints 0.00% instead of raising ZeroDivisionError.
print("Class distribution:")
print(f"  Class 0 (majority): {negative_count:,} ({(negative_count / n_rows if n_rows else 0.0) * 100:.2f}%)")
print(f"  Class 1 (minority): {positive_count:,} ({minority_ratio * 100:.2f}%)")

# Calculate scale_pos_weight for imbalance
# max(..., 1) avoids dividing by zero when there are no positive rows.
scale_pos_weight = negative_count / max(positive_count, 1)
print(f"\nscale_pos_weight: {scale_pos_weight:.2f}")

In [None]:
# 80/20 train/test split done inside Ray Data so ingestion stays distributed.
# NOTE(review): seed=42 presumably only has an effect when shuffling is enabled
# on train_test_split — confirm against the Ray Data docs.
split_start = time.time()
train_ray_ds, test_ray_ds = full_ray_ds.train_test_split(test_size=0.2, seed=42)
split_time = time.time() - split_start

# Row counts first, then minority (label == 1) counts, for both splits
train_count, test_count = train_ray_ds.count(), test_ray_ds.count()
train_pos, test_pos = int(train_ray_ds.sum("label")), int(test_ray_ds.sum("label"))

print(f"Train set: {train_count:,} rows")
print(f"Test set: {test_count:,} rows")
print(f"Train minority: {train_pos:,} ({(train_pos / train_count) * 100:.2f}%)")
print(f"Test minority: {test_pos:,} ({(test_pos / test_count) * 100:.2f}%)")
print(f"Split time: {split_time:.1f}s")

# Local sklearn metrics run on a bounded sample (at most 200,000 test rows)
eval_sample_rows = min(200_000, test_count)
eval_test_df = test_ray_ds.limit(eval_sample_rows).to_pandas()
X_test_eval, y_test_eval = eval_test_df[feature_columns], eval_test_df["label"]
print(f"Evaluation sample rows: {len(eval_test_df):,}")

## XGBoost Training (Ray Distributed)

In [None]:
from ray.train.xgboost import RayTrainReportCallback, XGBoostConfig
from ray.train import ScalingConfig, RunConfig
from ray.train.data_parallel_trainer import DataParallelTrainer
import ray.data
import ray.train

# ======================================================================
# OMP Diagnostics Collector — zero-CPU Ray actor that collects OMP state
# from workers and makes it accessible to the driver for MLflow logging.
# This solves the problem that worker print() statements go to Ray worker
# stdout which is not accessible via the Databricks REST API.
# ======================================================================
@ray.remote(num_cpus=0)
class OmpDiagnosticsCollector:
    """Zero-CPU actor that gathers per-worker OMP diagnostics for the driver."""

    def __init__(self):
        # worker rank -> diagnostics dict reported by that worker
        self._results = {}

    def report(self, worker_rank: int, diagnostics: dict):
        """Store the diagnostics for ``worker_rank``, replacing any earlier report."""
        self._results.update({worker_rank: diagnostics})

    def get_all(self) -> dict:
        """Return a shallow copy of everything reported so far, keyed by worker rank."""
        return {rank: reported for rank, reported in self._results.items()}

print("OmpDiagnosticsCollector actor defined.")

# XGBoost thread count mirrors the CPUs reserved per Ray worker (named for logging)
xgb_nthread = cpus_per_worker

# Booster parameters; scale_pos_weight comes from the class-distribution cell
xgb_params = dict(
    objective="binary:logistic",
    tree_method="hist",
    nthread=xgb_nthread,
    max_depth=6,
    learning_rate=0.1,
    scale_pos_weight=scale_pos_weight,
    seed=42,
    verbosity=1,
)

num_boost_round = 100

# CPU-only training: each worker reserves cpus_per_worker CPUs
scaling_config = ScalingConfig(
    resources_per_worker={"CPU": cpus_per_worker},
    num_workers=num_workers,
    use_gpu=False,
)

# Ray Train results are stored under the catalog/schema volume
import os
ray_storage_path = f"/Volumes/{catalog}/{schema}/ray_results/"
os.makedirs(ray_storage_path, exist_ok=True)
print(f"Storage directory created: {ray_storage_path}")

run_config = RunConfig(name="xgb_ray_plasma_tune", storage_path=ray_storage_path)


def xgb_train_fn(config: dict):
    """Per-worker XGBoost training loop with the OMP_NUM_THREADS fix + diagnostics.

    ROOT CAUSE: Databricks/Spark sets OMP_NUM_THREADS=1 on executors.
    Ray on Spark workers inherit this. When xgboost is imported, it loads
    libxgboost.so which initializes the OpenMP runtime. The OpenMP runtime
    reads OMP_NUM_THREADS=1 at initialization and caps all thread pools to 1.
    XGBoost's C++ layer then does: effective_threads = min(nthread, omp_get_max_threads())
    = min(14, 1) = 1.

    FIX: Set OMP_NUM_THREADS env var BEFORE importing xgboost, so the
    OpenMP runtime initializes with the correct thread count. Also call
    omp_set_num_threads() on all available OpenMP libraries as belt-and-suspenders.

    DIAGNOSTICS: Reports OMP state to OmpDiagnosticsCollector actor so the
    driver can log it as MLflow params (worker stdout isn't accessible via API).

    Args:
        config: The ``train_loop_config`` dict. ``nthread`` (default 1) is the
            thread count to request. ``_omp_diag_ref`` is an optional
            collector actor handle; when absent nothing is reported.
            ``label_column`` and ``num_boost_round`` are required. Every key
            other than ``label_column``, ``num_boost_round``, ``dataset_keys``
            and ``_omp_diag_ref`` is forwarded to ``xgboost.train`` as a
            booster parameter.

    Returns:
        None. Metrics and checkpoints are reported through
        ``RayTrainReportCallback``.
    """
    import os
    import ctypes

    # OpenMP runtimes probed both before and after the xgboost import. Using a
    # single tuple keeps the two probes in sync (the post-import check used to
    # skip libiomp5.so).
    # NOTE(review): this dlopen()s every listed runtime that is installed,
    # possibly more than one in the same process - confirm that is safe on the
    # target runtime image.
    omp_lib_names = ("libgomp.so.1", "libomp.so", "libomp.so.5", "libiomp5.so")

    nthread = config.get("nthread", 1)
    diag_collector_ref = config.get("_omp_diag_ref")
    worker_rank = ray.train.get_context().get_world_rank()
    diagnostics = {}

    # ===================================================================
    # CRITICAL: Set OMP_NUM_THREADS BEFORE importing xgboost.
    # The OpenMP runtime reads this env var at initialization time (when
    # the first OpenMP-using shared library is loaded). Once initialized
    # with threads=1, some runtimes ignore later omp_set_num_threads().
    # ===================================================================
    omp_env_before = os.environ.get("OMP_NUM_THREADS", "NOT_SET")
    os.environ["OMP_NUM_THREADS"] = str(nthread)
    diagnostics["omp_env_at_start"] = omp_env_before
    diagnostics["omp_env_set_to"] = str(nthread)
    print(f"[Worker {worker_rank}] OMP_NUM_THREADS: was={omp_env_before}, now={os.environ['OMP_NUM_THREADS']}")
    print(f"[Worker {worker_rank}] Requested nthread: {nthread}")

    # Also try to set via ctypes on ALL available OpenMP libraries BEFORE import
    omp_libs_found = []
    for lib_name in omp_lib_names:
        try:
            lib = ctypes.CDLL(lib_name)
            lib.omp_get_max_threads.restype = ctypes.c_int
            lib.omp_set_num_threads.argtypes = [ctypes.c_int]
            omp_before = lib.omp_get_max_threads()
            lib.omp_set_num_threads(nthread)
            omp_after = lib.omp_get_max_threads()
        except (OSError, AttributeError):
            # OSError: library not installed. AttributeError: it loaded but
            # does not export the OpenMP symbols. Either way this is only a
            # best-effort probe and must never abort training.
            continue
        omp_libs_found.append(f"{lib_name}: {omp_before}->{omp_after}")
        diagnostics[f"ctypes_{lib_name}"] = f"{omp_before}->{omp_after}"
        print(f"[Worker {worker_rank}] {lib_name} omp_max_threads: {omp_before} -> {omp_after}")

    if not omp_libs_found:
        diagnostics["ctypes_libs"] = "NONE_FOUND"
        print(f"[Worker {worker_rank}] WARNING: No OpenMP library found via ctypes")
    else:
        diagnostics["ctypes_libs"] = ";".join(omp_libs_found)

    # ===================================================================
    # NOW import xgboost — the OpenMP runtime should initialize with
    # OMP_NUM_THREADS already set to the correct value.
    # ===================================================================
    import xgboost

    # Verify post-import: check that the OpenMP runtime has the right value
    for lib_name in omp_lib_names:
        try:
            lib = ctypes.CDLL(lib_name)
            lib.omp_get_max_threads.restype = ctypes.c_int
            actual = lib.omp_get_max_threads()
        except (OSError, AttributeError):
            # Same best-effort rule as the pre-import probe.
            continue
        diagnostics[f"post_import_{lib_name}"] = str(actual)
        print(f"[Worker {worker_rank}] Post-import {lib_name} omp_max_threads: {actual}")

    # Diagnostic: What OpenMP lib is XGBoost linked against?
    xgb_omp_libs = []
    try:
        import subprocess
        if hasattr(xgboost, '__file__'):
            xgb_dir = os.path.dirname(xgboost.__file__)
            ldd_result = subprocess.run(
                ["ldd", os.path.join(xgb_dir, "lib", "libxgboost.so")],
                capture_output=True, text=True, timeout=5
            )
            for line in ldd_result.stdout.split('\n'):
                # 'omp' also matches libgomp and libiomp5.
                if 'omp' in line.lower():
                    xgb_omp_libs.append(line.strip())
                    print(f"[Worker {worker_rank}] XGBoost linked OMP: {line.strip()}")
    except Exception as e:
        # Purely informational; a missing ldd or a timeout must not stop training.
        print(f"[Worker {worker_rank}] Could not inspect XGBoost libs: {e}")
    diagnostics["xgb_omp_linked_libs"] = "; ".join(xgb_omp_libs) if xgb_omp_libs else "NONE_OR_ERROR"

    # Report diagnostics back to driver via collector actor
    if diag_collector_ref is not None:
        try:
            ray.get(diag_collector_ref.report.remote(worker_rank, diagnostics), timeout=10)
            print(f"[Worker {worker_rank}] OMP diagnostics reported to collector")
        except Exception as e:
            print(f"[Worker {worker_rank}] Failed to report diagnostics: {e}")

    label_column = config["label_column"]
    num_boost_round = config["num_boost_round"]
    # Everything that is not bookkeeping for this function is a booster param.
    xgb_params = {k: v for k, v in config.items()
                  if k not in ("label_column", "num_boost_round", "dataset_keys",
                               "_omp_diag_ref")}

    # Get dataset shards for this worker
    train_ds = ray.train.get_dataset_shard("train")
    train_df = train_ds.materialize().to_pandas()
    train_X = train_df.drop(label_column, axis=1)
    train_y = train_df[label_column]
    dtrain = xgboost.DMatrix(train_X, label=train_y)

    evals = [(dtrain, "train")]

    # Validation set if available
    valid_ds = ray.train.get_dataset_shard("valid")
    if valid_ds is not None:
        valid_df = valid_ds.materialize().to_pandas()
        valid_X = valid_df.drop(label_column, axis=1)
        valid_y = valid_df[label_column]
        dvalid = xgboost.DMatrix(valid_X, label=valid_y)
        evals.append((dvalid, "valid"))

    # Resume from checkpoint if available: only train the rounds still missing.
    checkpoint = ray.train.get_checkpoint()
    starting_model = None
    remaining_iters = num_boost_round
    if checkpoint:
        starting_model = RayTrainReportCallback.get_model(checkpoint)
        starting_iter = starting_model.num_boosted_rounds()
        remaining_iters = num_boost_round - starting_iter

    print(f"[Worker {worker_rank}] Starting xgboost.train: nthread={xgb_params.get('nthread')}, "
          f"boost_rounds={remaining_iters}, OMP_NUM_THREADS={os.environ.get('OMP_NUM_THREADS')}")

    # Train with RayTrainReportCallback for metrics reporting + checkpointing.
    # The returned booster is not needed here, so it is not bound to a name.
    xgboost.train(
        xgb_params,
        dtrain=dtrain,
        evals=evals,
        num_boost_round=remaining_iters,
        xgb_model=starting_model,
        callbacks=[RayTrainReportCallback()],
    )


# Train-function config: a copy of the booster params plus the two metadata
# entries that xgb_train_fn reads for itself.
train_loop_config = dict(xgb_params)
train_loop_config.update(label_column="label", num_boost_round=num_boost_round)

# Assemble the whole configuration summary first, then emit it in one print.
_config_summary = ["XGBoost parameters:"]
_config_summary.extend(
    f"  {param_name}: {param_value}" for param_name, param_value in xgb_params.items()
)
_config_summary += [
    f"  num_boost_round: {num_boost_round}",
    "\n=== OMP FIX STRATEGY ===",
    f"1. runtime_env sets OMP_NUM_THREADS={xgb_nthread} at Ray worker process startup",
    f"2. xgb_train_fn sets os.environ['OMP_NUM_THREADS']={xgb_nthread} BEFORE import xgboost",
    f"3. ctypes omp_set_num_threads({xgb_nthread}) on all OMP libs BEFORE import xgboost",
    f"4. xgb_params['nthread']={xgb_nthread} as the XGBoost-level parameter",
    "5. OmpDiagnosticsCollector reports worker OMP state back as MLflow params",
    "\nRay ScalingConfig:",
    f"  num_workers: {scaling_config.num_workers}",
    f"  resources_per_worker: {scaling_config.resources_per_worker}",
    "\nUsing DataParallelTrainer with custom train_loop_per_worker",
]
print("\n".join(_config_summary))

In [None]:
# Driver cell: opens the MLflow run, starts the per-worker metrics monitors,
# logs the experiment configuration, trains with Ray Train, collects the OMP
# diagnostics reported by the workers, then evaluates the final booster on the
# bounded evaluation sample and logs the results.
print(f"Train Ray Dataset: {train_count:,} rows")
print(f"Test Ray Dataset: {test_count:,} rows")

# Track worker monitor actors for cleanup
_worker_monitor_actors = []

# Create OMP diagnostics collector actor and hand its handle to the workers
# through the train-loop config.
_omp_diag_collector = OmpDiagnosticsCollector.remote()
train_loop_config["_omp_diag_ref"] = _omp_diag_collector
print("OMP diagnostics collector created and passed to train config")

with mlflow.start_run(run_name=run_name, log_system_metrics=True) as run:
    run_id = run.info.run_id
    print(f"MLflow run ID: {run_id}")
    print(f"MLflow run name: {run_name}")

    # ======================================================================
    # START WORKER-SIDE SYSTEM METRICS MONITORS
    # Each worker node gets a SystemMetricsMonitor logging to the same run_id
    # with a unique node_id prefix (e.g., system/worker_0/cpu_utilization_%)
    # ======================================================================
    try:
        _worker_monitor_actors = start_worker_monitors(
            run_id=run_id,
            db_host=databricks_host_url,
            db_token=databricks_token,
            num_nodes=num_executors,
            sampling_interval=10.0,
        )
        mlflow.log_param("worker_metrics_monitors", len(_worker_monitor_actors))
        print(f"Worker metrics monitors active: {len(_worker_monitor_actors)}")
    except Exception as e:
        # Worker metrics are optional: record the failure and keep going.
        import traceback
        full_error = traceback.format_exc()
        print(f"WARNING: Could not start worker monitors: {e}")
        print(f"Full traceback:\n{full_error}")
        print("Continuing without worker-side system metrics.")
        mlflow.log_param("worker_metrics_monitors", 0)
        mlflow.log_param("worker_metrics_error", str(e)[:500])
    # ======================================================================

    # Log input dataset
    mlflow.log_input(mlflow_dataset, context="training")
    print(f"Logged input dataset: {input_table}")

    # Log standard parameters
    mlflow.log_param("training_mode", "ray_distributed_plasma_tune")
    mlflow.log_param("data_size", data_size)
    mlflow.log_param("node_type", node_type)
    mlflow.log_param("run_mode", run_mode)
    mlflow.log_param("warehouse_id", warehouse_id)
    mlflow.log_param("input_table", input_table)
    mlflow.log_param("n_rows", n_rows)
    mlflow.log_param("n_features", len(feature_columns))
    mlflow.log_param("minority_ratio", round(minority_ratio, 4))
    mlflow.log_param("train_size", train_count)
    mlflow.log_param("test_size", test_count)
    mlflow.log_param("eval_sample_rows", len(eval_test_df))

    # Log Ray params
    mlflow.log_param("num_workers", num_workers)
    mlflow.log_param("cpus_per_worker", cpus_per_worker)
    mlflow.log_param("spark_executors", num_executors)
    mlflow.log_param("num_boost_round", num_boost_round)

    # === Log OMP fix strategy ===
    mlflow.log_param("omp_fix_strategy", "runtime_env+env_before_import+ctypes+diag_collector")
    mlflow.log_param("omp_target_threads", cpus_per_worker)

    # === Log Plasma/Object Store config params ===
    mlflow.log_param("plasma_obj_store_mem_gb", obj_store_mem_gb)
    mlflow.log_param("plasma_head_obj_store_mem_gb", head_obj_store_mem_gb)
    mlflow.log_param("plasma_heap_mem_gb", heap_mem_gb)
    mlflow.log_param("plasma_spill_dir", spill_dir)
    mlflow.log_param("plasma_ray_temp_dir", ray_temp_dir)
    mlflow.log_param("plasma_allow_slow_storage", allow_slow_storage)
    mlflow.log_param("plasma_tag", plasma_tag)

    # Log actual object store sizes from Ray nodes (summed over alive nodes).
    try:
        nodes_info = ray.nodes()
        alive_nodes = [n for n in nodes_info if n.get('Alive')]
        total_obj_store = sum(n.get('Resources', {}).get('object_store_memory', 0) for n in alive_nodes)
        total_memory = sum(n.get('Resources', {}).get('memory', 0) for n in alive_nodes)
        mlflow.log_metric("actual_total_obj_store_gb", round(total_obj_store / (1024**3), 2))
        mlflow.log_metric("actual_total_memory_gb", round(total_memory / (1024**3), 2))
        mlflow.log_metric("actual_node_count", len(alive_nodes))
    except Exception as e:
        # Still best-effort, but say why the actual_* metrics are missing
        # instead of swallowing the failure silently.
        print(f"WARNING: Could not log Ray node resources: {e}")

    # Log XGBoost params
    for k, v in xgb_params.items():
        mlflow.log_param(f"xgb_{k}", v)

    # Log timing
    mlflow.log_metric("ray_init_time_sec", ray_init_time)
    mlflow.log_metric("data_load_time_sec", load_time)
    mlflow.log_metric("split_time_sec", split_time)

    # ======================================================================
    # TRAINING (with worker monitor cleanup in finally block)
    # ======================================================================
    try:
        # Train model
        print("\nTraining XGBoost with Ray Train (DataParallelTrainer + OMP fix)...")
        print(f"OMP fix: runtime_env OMP_NUM_THREADS={cpus_per_worker} + env before import + ctypes")
        train_start = time.time()

        trainer = DataParallelTrainer(
            train_loop_per_worker=xgb_train_fn,
            train_loop_config=train_loop_config,
            scaling_config=scaling_config,
            run_config=run_config,
            datasets={"train": train_ray_ds, "valid": test_ray_ds},
            backend_config=XGBoostConfig(),
        )
        result = trainer.fit()

        train_time = time.time() - train_start
        print(f"Training completed in {train_time:.1f}s")
        print(f"Best result: {result.metrics}")

        mlflow.log_metric("train_time_sec", train_time)

    finally:
        # ======================================================================
        # STOP WORKER-SIDE SYSTEM METRICS MONITORS
        # Always stop monitors, even if training fails, to flush buffered metrics
        # ======================================================================
        stop_worker_monitors(_worker_monitor_actors)
        _worker_monitor_actors = []

    # ======================================================================
    # COLLECT AND LOG OMP DIAGNOSTICS FROM WORKERS
    # ======================================================================
    try:
        omp_diag_results = ray.get(_omp_diag_collector.get_all.remote(), timeout=15)
        print(f"\n=== OMP DIAGNOSTICS FROM {len(omp_diag_results)} WORKERS ===")
        for rank in sorted(omp_diag_results.keys()):
            diag = omp_diag_results[rank]
            print(f"\nWorker {rank}:")
            for k, v in sorted(diag.items()):
                print(f"  {k}: {v}")
                # Log each diagnostic as an MLflow param
                param_key = f"omp_w{rank}_{k}"
                # MLflow param values are limited to 500 chars
                mlflow.log_param(param_key, str(v)[:500])
        mlflow.log_param("omp_diag_workers_reporting", len(omp_diag_results))
    except Exception as e:
        # Diagnostics are informational; never fail the run because of them.
        print(f"WARNING: Could not collect OMP diagnostics: {e}")
        mlflow.log_param("omp_diag_error", str(e)[:500])
    # ======================================================================

    # Predictions on bounded evaluation sample
    print("\nGenerating predictions on evaluation sample...")
    pred_start = time.time()

    import xgboost as xgb
    checkpoint = result.checkpoint
    booster = RayTrainReportCallback.get_model(checkpoint)

    dtest = xgb.DMatrix(X_test_eval)
    y_pred_proba = booster.predict(dtest)
    # Hard labels at the 0.5 probability threshold.
    y_pred = (y_pred_proba > 0.5).astype(int)

    pred_time = time.time() - pred_start
    mlflow.log_metric("predict_time_sec", pred_time)

    # Evaluation
    print("\nEvaluating...")
    from sklearn.metrics import (
        average_precision_score,
        roc_auc_score,
        f1_score,
        precision_score,
        recall_score,
        classification_report,
        confusion_matrix,
    )

    auc_pr = average_precision_score(y_test_eval, y_pred_proba)
    auc_roc = roc_auc_score(y_test_eval, y_pred_proba)
    # zero_division=0 matches precision/recall below: same 0.0 result when
    # nothing is predicted positive, without the undefined-metric warning.
    f1 = f1_score(y_test_eval, y_pred, zero_division=0)
    precision = precision_score(y_test_eval, y_pred, zero_division=0)
    recall = recall_score(y_test_eval, y_pred, zero_division=0)

    print("\nResults:")
    print(f"  AUC-PR (primary): {auc_pr:.4f}")
    print(f"  AUC-ROC: {auc_roc:.4f}")
    print(f"  F1: {f1:.4f}")
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall: {recall:.4f}")

    mlflow.log_metric("auc_pr", auc_pr)
    mlflow.log_metric("auc_roc", auc_roc)
    mlflow.log_metric("f1", f1)
    mlflow.log_metric("precision", precision)
    mlflow.log_metric("recall", recall)

    # Rows are true classes, columns are predicted classes.
    cm = confusion_matrix(y_test_eval, y_pred)
    print("\nConfusion Matrix:")
    print(f"  TN: {cm[0,0]:,}  FP: {cm[0,1]:,}")
    print(f"  FN: {cm[1,0]:,}  TP: {cm[1,1]:,}")

    mlflow.log_metric("true_negatives", cm[0, 0])
    mlflow.log_metric("false_positives", cm[0, 1])
    mlflow.log_metric("false_negatives", cm[1, 0])
    mlflow.log_metric("true_positives", cm[1, 1])

    print("\nClassification Report:")
    print(classification_report(y_test_eval, y_pred, zero_division=0))

    # Total time: sum of the individually timed phases.
    total_time = ray_init_time + load_time + split_time + train_time + pred_time
    mlflow.log_metric("total_time_sec", total_time)

    print("\n" + "=" * 60)
    print(f"Run complete: {run_name}")
    print(f"Plasma config: {plasma_tag}")
    print(f"OMP fix: runtime_env + env_before_import + ctypes (target={cpus_per_worker} threads)")
    print(f"Total time: {total_time:.1f}s (Ray init: {ray_init_time:.1f}s, Load: {load_time:.1f}s, Train: {train_time:.1f}s)")
    print(f"MLflow run ID: {run_id}")
    print("=" * 60)

## Shutdown Ray Cluster

In [None]:
# Tear down the Ray cluster, announcing the shutdown before and after the call.
print("Shutting down Ray cluster...")
shutdown_ray_cluster()
print("Ray cluster shutdown complete.")

## Exit

In [None]:
import json

# Assemble the JSON payload returned through dbutils.notebook.exit(). Names are
# read in a fixed order; if any of them was never defined (an earlier cell did
# not complete), the NameError handler produces a minimal error payload.
try:
    result = {"status": "error" if _notebook_errors else "ok"}
    result.update(
        run_name=run_name,
        run_id=run_id,
        training_mode="ray_distributed_plasma_tune",
        data_size=data_size,
        node_type=node_type,
        warehouse_id=warehouse_id,
        n_rows=n_rows,
        num_workers=num_workers,
        cpus_per_worker=cpus_per_worker,
        spark_executors=num_executors,
    )
    result["plasma_config"] = dict(
        obj_store_mem_gb=obj_store_mem_gb,
        head_obj_store_mem_gb=head_obj_store_mem_gb,
        heap_mem_gb=heap_mem_gb,
        spill_dir=spill_dir,
        ray_temp_dir=ray_temp_dir,
        allow_slow_storage=allow_slow_storage,
    )
    result["auc_pr"] = round(auc_pr, 4)
    result["train_time_sec"] = round(train_time, 1)
    result["total_time_sec"] = round(total_time, 1)
    # Attach the collected errors only when there are any.
    if _notebook_errors:
        result["errors"] = _notebook_errors
except NameError as missing_name_err:
    result = {
        "status": "error",
        "error": f"Notebook failed before completion: {missing_name_err}",
    }
    # The error list itself may be the undefined name, so look before reading.
    if '_notebook_errors' in dir():
        result["errors"] = _notebook_errors
    else:
        result["errors"] = []

result_json = json.dumps(result)
print(f"\nNotebook result: {result_json}")

dbutils.notebook.exit(result_json)