# Train XGBoost (Ray on Spark - Distributed)

Distributed XGBoost training using Ray on Spark with MLflow tracking,
**OMP_NUM_THREADS fix**, per-worker diagnostics, and system metrics collection.

**Key fix:** Databricks silently sets `OMP_NUM_THREADS=1` on executors, causing
XGBoost to use only 1 CPU core. This notebook applies a 3-layer fix:
1. `spark.executorEnv.OMP_NUM_THREADS` in Spark config (Layer 1 - databricks.yml)
2. `ray.init(runtime_env={"env_vars": {"OMP_NUM_THREADS": ...}})` (Layer 2)
3. Worker-level `os.environ` + ctypes `omp_set_num_threads` before `import xgboost` (Layer 3)

**Features:**
- OMP diagnostics via `OmpDiagnosticsCollector` actor (zero-CPU)
- Per-worker system metrics via `WorkerMetricsMonitor` actors
- `DataParallelTrainer` with custom train function for OMP control
- Environment validation gate (`src.validate_env`; non-blocking — failures are printed as warnings and the run continues)
- Shared config presets from `src.config`

## Setup Widgets

In [None]:
# Global error tracking: entries collected here are included in the notebook's exit payload.
_notebook_errors = []


def log_error(error_msg, exc=None):
    """Record a non-fatal error for the notebook's exit payload and print it.

    Args:
        error_msg: Human-readable description; stored via ``str()``.
        exc: Optional exception. When given, its formatted traceback is stored as well.
    """
    import traceback

    entry = {"error": str(error_msg)}
    if exc is not None:
        # Format the exception object itself. traceback.format_exc() only sees the
        # exception currently being handled, so it yields "NoneType: None" whenever
        # log_error is called outside an ``except`` block.
        entry["traceback"] = "".join(
            traceback.format_exception(type(exc), exc, exc.__traceback__)
        )
    _notebook_errors.append(entry)
    print(f"ERROR LOGGED: {error_msg}")

# Notebook parameters, created in this order. Each spec is
# (widget kind, name, default value, [choices,] label).
_widget_specs = (
    ("dropdown", "data_size", "tiny", ["tiny", "small", "medium", "medium_large", "large", "xlarge"], "Data Size"),
    ("text", "node_type", "D8sv5", "Node Type"),
    ("dropdown", "run_mode", "full", ["full", "smoke"], "Run Mode"),
    ("text", "num_workers", "0", "Num Workers (0=auto)"),
    ("text", "cpus_per_worker", "0", "CPUs per Worker (0=auto)"),
    ("text", "warehouse_id", "148ccb90800933a1", "Databricks SQL Warehouse ID"),
    ("text", "catalog", "brian_gen_ai", "Catalog"),
    ("text", "schema", "xgb_scaling", "Schema"),
    ("text", "table_name", "", "Table Name (override)"),
)
for _kind, *_widget_args in _widget_specs:
    getattr(dbutils.widgets, _kind)(*_widget_args)
del _widget_specs, _kind, _widget_args

In [None]:
# Get widget values
data_size = dbutils.widgets.get("data_size")
node_type = dbutils.widgets.get("node_type")
run_mode = dbutils.widgets.get("run_mode")
num_workers_input = int(dbutils.widgets.get("num_workers"))
cpus_per_worker_input = int(dbutils.widgets.get("cpus_per_worker"))
warehouse_id = dbutils.widgets.get("warehouse_id").strip()
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
table_name_override = dbutils.widgets.get("table_name").strip()

import sys, os
notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()
repo_root = "/".join(notebook_path.split("/")[:-2])
sys.path.insert(0, f"/Workspace{repo_root}")
from src.config import PRESETS as CONFIG_PRESETS, get_preset

SIZE_PRESETS = {n: {"suffix": p.table_suffix, "rows": p.rows, "features": p.total_features} for n, p in CONFIG_PRESETS.items()}

if table_name_override:
    input_table = f"{catalog}.{schema}.{table_name_override}"
    data_size_label = table_name_override.replace("imbalanced_", "")
else:
    preset = get_preset(data_size)
    input_table = f"{catalog}.{schema}.imbalanced_{preset.table_suffix}"
    data_size_label = data_size

run_name = f"ray_smoke_{node_type}" if run_mode == "smoke" else f"ray_{data_size_label}{'_'+str(num_workers_input)+'w' if num_workers_input > 0 else ''}_{node_type}"
print(f"Config: {data_size} | {node_type} | {run_mode} | table={input_table} | run={run_name}")

## Environment Validation Gate

In [None]:
from src.validate_env import validate_environment

# Non-blocking validation: with raise_on_failure=False problems are reported rather than
# raised, so keep the returned report and inspect it here. (Previously the result was
# discarded and `_env_report` below was an undefined name.)
# NOTE(review): assumes validate_environment returns a report exposing `.passed` and
# `.errors` — confirm against src/validate_env.py.
_env_report = validate_environment(
    track="ray-scaling",
    expected_workers=num_workers_input if num_workers_input > 0 else None,
    raise_on_failure=False,
)
if not _env_report.passed:
    print(f"WARNING: {len(_env_report.errors)} validation error(s) — continuing anyway for debugging")


## MLflow Setup

In [None]:
import os

# Turn MLflow system-metrics collection on twice: through the environment variable
# (set before mlflow is imported) and through the explicit API call after the import.
os.environ["MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING"] = "true"

import mlflow

mlflow.enable_system_metrics_logging()

# Every run of this notebook goes to one experiment under the calling user's home folder.
(user_email,) = spark.sql("SELECT current_user()").collect()[0]
experiment_path = "/Users/{}/xgb_scaling_benchmark".format(user_email)

mlflow.set_experiment(experiment_path)
print(f"MLflow experiment: {experiment_path}")
_metrics_flag = os.environ.get("MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING")
print(f"System metrics logging enabled: {_metrics_flag}")

## Initialize Ray on Spark

In [None]:
import time
import os

# Ray must be installed on the cluster; fail fast with a clear message otherwise.
try:
    import ray
except ImportError as e:
    raise RuntimeError(f"Ray not available: {e}") from e
print(f"Ray version: {ray.__version__}")

from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster

# Setup Databricks environment for Ray workers (required for MLflow + Ray Data access)
_notebook_ctx = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
databricks_host_url = _notebook_ctx.apiUrl().getOrElse(None)
databricks_token = _notebook_ctx.apiToken().getOrElse(None)
# Both values are used in string operations and os.environ below; a missing one would
# otherwise surface as an opaque AttributeError/TypeError.
if not databricks_host_url or not databricks_token:
    raise RuntimeError(
        "Could not read the Databricks API URL and token from the notebook context; "
        "both are required for MLflow and Ray Data access from Ray workers."
    )

# Ray Data Databricks reader expects hostname without scheme.
databricks_host = databricks_host_url.replace("https://", "").replace("http://", "").rstrip("/")

# Default env setup for MLflow and general Databricks SDK usage.
os.environ["DATABRICKS_HOST"] = databricks_host_url
os.environ["DATABRICKS_TOKEN"] = databricks_token
print(f"Databricks Host URL: {databricks_host_url}")
print(f"Databricks Hostname (Ray Data): {databricks_host}")
print("Databricks Token: [CONFIGURED]")

# Get cluster info
sc = spark.sparkContext
# getExecutorMemoryStatus() also lists the driver, hence the -1.
num_executors = sc._jsc.sc().getExecutorMemoryStatus().size() - 1
print(f"\nSpark executors (workers): {num_executors}")

if num_executors < 1:
    print("WARNING: No executors detected. This may indicate a single-node cluster.")
    print("Ray on Spark requires a multi-node cluster (num_workers >= 1)")

# Per-node vCPUs are parsed from the node type name: the digits after the first "d"/"e"
# (e.g. "D8sv5" -> 8), defaulting to 8 when nothing matches. One CPU per node is left
# for Spark/system processes.
# NOTE(review): names such as "Standard_DS3_v2" do not match this pattern and fall back to 8.
import re
node_type_lower = node_type.lower()
node_vcpus_match = re.search(r"[de](\d+)", node_type_lower)
node_vcpus = int(node_vcpus_match.group(1)) if node_vcpus_match else 8
allocatable_cpus_per_node = max(1, node_vcpus - 1)

# Determine number of Ray workers (best practice: one worker per executor)
if num_workers_input > 0:
    num_workers = num_workers_input
else:
    num_workers = max(1, num_executors)

# Keep the worker count aligned with the executors, but never below one worker.
# Capping straight to num_executors would yield 0 workers when no executors are
# detected, which contradicts the max(1, ...) rule above.
max_workers = max(1, num_executors)
if num_workers > max_workers:
    print(f"Capping num_workers from {num_workers} to {max_workers} (executor count)")
    num_workers = max_workers

# CPUs per worker for Ray training
if cpus_per_worker_input > 0:
    cpus_per_worker = cpus_per_worker_input
else:
    cpus_per_worker = allocatable_cpus_per_node

# Enforce per-node CPU safety cap
if cpus_per_worker > allocatable_cpus_per_node:
    print(
        f"Capping cpus_per_worker from {cpus_per_worker} to {allocatable_cpus_per_node} "
        f"for node type {node_type}"
    )
    cpus_per_worker = allocatable_cpus_per_node

# Use the same per-node CPU for Ray cluster worker allocation
num_cpus_worker_node = allocatable_cpus_per_node

print(f"\nResource sizing:")
print(f"  node_type: {node_type}")
print(f"  node_vcpus: {node_vcpus}")
print(f"  allocatable_cpus_per_node: {allocatable_cpus_per_node}")
print(f"  spark_executors: {num_executors}")
print(f"\nRay training configuration:")
print(f"  num_workers: {num_workers}")
print(f"  cpus_per_worker: {cpus_per_worker}")
print(f"  total_requested_cpus: {num_workers * cpus_per_worker}")

In [None]:
# Initialize Ray cluster + OMP fix
print("Starting Ray cluster...")
ray_start = time.time()
# Also set OMP_NUM_THREADS in the driver's own environment.
os.environ["OMP_NUM_THREADS"] = str(allocatable_cpus_per_node)


def _connect_ray(omp_threads):
    """(Re)connect this driver to the Ray cluster with worker env vars in runtime_env.

    OMP fix, Layer 2: the job-level runtime_env gives Ray worker processes
    OMP_NUM_THREADS plus the Databricks credentials MLflow needs.
    """
    if ray.is_initialized():
        ray.shutdown()
    ray.init(runtime_env={"env_vars": {
        "OMP_NUM_THREADS": omp_threads,
        "DATABRICKS_HOST": databricks_host_url,
        "DATABRICKS_TOKEN": databricks_token,
    }})


try:
    setup_ray_cluster(
        min_worker_nodes=num_executors,
        max_worker_nodes=num_executors,
        num_cpus_worker_node=num_cpus_worker_node,
        num_gpus_worker_node=0,
        collect_log_to_path="/tmp/ray_logs",
    )
    ray_init_time = time.time() - ray_start

    # CRITICAL OMP FIX (Layer 2): Reconnect with runtime_env
    omp_threads_str = str(cpus_per_worker)
    _connect_ray(omp_threads_str)
    print(f"Ray reconnected. OMP_NUM_THREADS={omp_threads_str}, init={ray_init_time:.1f}s")

    cluster_resources = ray.cluster_resources()
    print(f"Resources: {cluster_resources}")
    available_cpus = int(cluster_resources.get("CPU", 0))
    # Total CPU request of the workers plus one spare CPU (printed as "+1 overhead").
    required = num_workers * cpus_per_worker + 1
    if required > available_cpus:
        # Shrink to fit: first the CPUs per worker, then the worker count.
        usable = max(1, available_cpus - 1)
        cpus_per_worker = max(1, usable // max(1, num_workers))
        num_workers = max(1, usable // max(1, cpus_per_worker))
    print(f"Final: {num_workers}W x {cpus_per_worker}CPU = {num_workers*cpus_per_worker}+1 overhead")

    # The resize above may have lowered cpus_per_worker. Reconnect so the runtime_env
    # OMP value (Layer 2) matches the resized cpus_per_worker instead of the stale one.
    if str(cpus_per_worker) != omp_threads_str:
        omp_threads_str = str(cpus_per_worker)
        _connect_ray(omp_threads_str)
        print(f"Ray reconnected after resize. OMP_NUM_THREADS={omp_threads_str}")
except Exception:
    import traceback
    traceback.print_exc()
    raise

## Worker-Side System Metrics & OMP Diagnostics

In [None]:
import ray

@ray.remote(num_cpus=0)
class WorkerMetricsMonitor:
    """Ray actor running an MLflow ``SystemMetricsMonitor`` on the node it is placed on.

    Metrics go into an existing MLflow run (``run_id``) and are tagged with ``node_id``,
    so several nodes can report into the same run. The actor reserves no CPU.
    """

    def __init__(self, run_id, node_id, db_host, db_token, sampling_interval=10.0):
        import os

        # The actor process needs its own credentials to reach the Databricks tracking server.
        os.environ.update(
            {
                "DATABRICKS_HOST": db_host,
                "DATABRICKS_TOKEN": db_token,
                "MLFLOW_TRACKING_URI": "databricks",
            }
        )
        self._run_id = run_id
        self._node_id = node_id
        self._si = sampling_interval
        self._mon = None
        # Short Ray node id, used only in the status string returned by start().
        self._rn = ray.get_runtime_context().get_node_id()[:8]

    def start(self):
        """Create and start the system-metrics monitor; return a status string."""
        from mlflow.system_metrics.system_metrics_monitor import SystemMetricsMonitor

        monitor = SystemMetricsMonitor(
            run_id=self._run_id,
            node_id=self._node_id,
            sampling_interval=self._si,
            samples_before_logging=1,
        )
        self._mon = monitor
        monitor.start()
        return f"{self._node_id} on {self._rn}"

    def stop(self):
        """Finish the monitor if one is running; return a status string."""
        if not self._mon:
            return f"{self._node_id} n/a"
        self._mon.finish()
        self._mon = None
        return f"{self._node_id} stopped"

@ray.remote(num_cpus=0)
class OmpDiagnosticsCollector:
    """Zero-CPU Ray actor that gathers per-rank OpenMP diagnostics dicts from workers."""

    def __init__(self):
        # world rank -> diagnostics dict; a later report for the same rank wins.
        self._r = {}

    def report(self, rank, diag):
        """Store ``diag`` under ``rank``, replacing any earlier report."""
        self._r.update({rank: diag})

    def get_all(self):
        """Return a shallow copy of everything reported so far."""
        return {**self._r}

def start_worker_monitors(run_id, db_host, db_token, num_nodes, si=10.0):
    """Start one ``WorkerMetricsMonitor`` actor on each of up to ``num_nodes`` live Ray nodes.

    The node this driver runs on is skipped. Each actor is hard-pinned to its node and
    named ``metrics_w{i}``.

    Args:
        run_id: MLflow run to log system metrics into.
        db_host: Databricks host URL handed to the actors.
        db_token: Databricks token handed to the actors.
        num_nodes: Maximum number of nodes to monitor.
        si: Sampling interval in seconds.

    Returns:
        List of actor handles, to be passed to ``stop_worker_monitors``.

    Raises:
        Whatever actor creation or ``start()`` raises. Actors created so far are killed
        first, so a failed call does not leak named actors that would block a re-run.
    """
    # Imported explicitly rather than relying on ray.util.scheduling_strategies having
    # been loaded as a side effect of other imports.
    from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

    driver_node = ray.get_runtime_context().get_node_id()
    nodes = [n for n in ray.nodes() if n.get("Alive") and n["NodeID"] != driver_node][:num_nodes]
    actors, futs = [], []
    try:
        for i, n in enumerate(nodes):
            a = WorkerMetricsMonitor.options(
                scheduling_strategy=NodeAffinitySchedulingStrategy(node_id=n["NodeID"], soft=False),
                name=f"metrics_w{i}",
            ).remote(run_id, f"worker_{i}", db_host, db_token, si)
            actors.append(a)
            futs.append(a.start.remote())
        for r in ray.get(futs):
            print(f"  {r}")
    except Exception:
        for a in actors:
            try:
                ray.kill(a)
            except Exception as kill_err:
                print(f"  WARN: could not kill monitor actor: {kill_err}")
        raise
    return actors

def stop_worker_monitors(actors):
    """Stop and kill the monitor actors returned by ``start_worker_monitors``.

    Best effort: failures are printed as warnings and never raised, so this is safe to
    call from a ``finally`` block. Every actor is killed even if stopping fails.
    """
    if not actors:
        return
    try:
        for r in ray.get([a.stop.remote() for a in actors], timeout=30):
            print(f"  {r}")
    except Exception as e:
        print(f"  WARN: {e}")
    for a in actors:
        try:
            ray.kill(a)
        except Exception as e:
            # Catch Exception only: a bare except would also swallow KeyboardInterrupt.
            print(f"  WARN: could not kill monitor actor: {e}")

print("WorkerMetricsMonitor + OmpDiagnosticsCollector defined.")

## Load Data

In [None]:
import ray.data

if not warehouse_id:
    raise ValueError("warehouse_id is required for distributed Ray Data loading")

print(f"Loading data from: {input_table}")
print(f"Using SQL Warehouse: {warehouse_id}")
load_start = time.time()

query = f"SELECT * FROM {input_table}"

# Ray Data's Databricks reader expects DATABRICKS_HOST without scheme, while the rest of
# the notebook uses the full URL: swap it only for the duration of the reader call.
_original_db_host = os.environ.get("DATABRICKS_HOST")
os.environ["DATABRICKS_HOST"] = databricks_host
try:
    full_ray_ds = ray.data.read_databricks_tables(
        warehouse_id=warehouse_id,
        query=query,
    )
finally:
    # Restore default host URL for downstream APIs.
    if _original_db_host is not None:
        os.environ["DATABRICKS_HOST"] = _original_db_host
    else:
        os.environ["DATABRICKS_HOST"] = databricks_host_url

n_rows = full_ray_ds.count()
# Fail here with a clear message: an empty table would otherwise surface later as
# ZeroDivisionError in the class/split statistics or as metric errors.
if n_rows == 0:
    raise ValueError(f"Input table {input_table} returned no rows; nothing to train on")
all_columns = list(full_ray_ds.schema().names)
if "label" not in all_columns:
    raise ValueError(f"Expected 'label' column in dataset schema, got: {all_columns}")

feature_columns = [c for c in all_columns if c != "label"]
load_time = time.time() - load_start

print(f"Loaded {n_rows:,} rows x {len(all_columns)} columns in {load_time:.1f}s")
print(f"Feature count: {len(feature_columns)}")

# Create a lightweight MLflow input dataset sample (avoid driver OOM): at most the
# first 10k rows are pulled to the driver.
mlflow_sample_rows = min(10_000, n_rows)
mlflow_sample_df = full_ray_ds.limit(mlflow_sample_rows).to_pandas()
mlflow_dataset = mlflow.data.from_pandas(
    mlflow_sample_df,
    source=input_table,
    name=data_size_label,
    targets="label",
)
print(f"MLflow dataset sample created: {mlflow_dataset.name} ({mlflow_sample_rows:,} rows)")

## Prepare Features and Labels

In [None]:
# Class distribution from Ray Dataset
positive_count = int(full_ray_ds.sum("label"))
negative_count = int(n_rows - positive_count)
minority_ratio = positive_count / n_rows if n_rows else 0.0

print("Class distribution:")
print(f"  Class 0 (majority): {negative_count:,} ({(negative_count / n_rows) * 100:.2f}%)")
print(f"  Class 1 (minority): {positive_count:,} ({(positive_count / n_rows) * 100:.2f}%)")

# Calculate scale_pos_weight for imbalance
scale_pos_weight = negative_count / max(positive_count, 1)
print(f"\nscale_pos_weight: {scale_pos_weight:.2f}")

In [None]:
# Train/test split in Ray Data (keeps ingestion distributed).
# Ray ignores `seed` unless `shuffle=True`. Without shuffling, the split is positional: the
# trailing 20% of rows become the test set. Shuffle explicitly to get the seeded random
# split the seed implies. This also makes the head-of-split evaluation sample below random.
split_start = time.time()
train_ray_ds, test_ray_ds = full_ray_ds.train_test_split(test_size=0.2, shuffle=True, seed=42)
split_time = time.time() - split_start

train_count = train_ray_ds.count()
test_count = test_ray_ds.count()
# Dataset.sum returns None for an empty split; treat that as zero positives.
train_pos = int(train_ray_ds.sum("label") or 0)
test_pos = int(test_ray_ds.sum("label") or 0)
train_pos_pct = (train_pos / train_count) * 100 if train_count else 0.0
test_pos_pct = (test_pos / test_count) * 100 if test_count else 0.0

print(f"Train set: {train_count:,} rows")
print(f"Test set: {test_count:,} rows")
print(f"Train minority: {train_pos:,} ({train_pos_pct:.2f}%)")
print(f"Test minority: {test_pos:,} ({test_pos_pct:.2f}%)")
print(f"Split time: {split_time:.1f}s")

# Bounded evaluation sample for local sklearn metrics (at most 200k rows from the head of
# the test split). This avoids collecting the full test split to driver memory.
eval_sample_rows = min(200_000, test_count)
eval_test_df = test_ray_ds.limit(eval_sample_rows).to_pandas()
X_test_eval = eval_test_df[feature_columns]
y_test_eval = eval_test_df["label"]
print(f"Evaluation sample rows: {len(eval_test_df):,}")

## XGBoost Training (Ray Distributed)

In [None]:
from ray.train.xgboost import RayTrainReportCallback, XGBoostConfig
from ray.train import ScalingConfig, RunConfig
from ray.train.data_parallel_trainer import DataParallelTrainer
import ray.data, ray.train

# XGBoost booster parameters. `nthread` ties each worker's thread pool to its CPU reservation.
xgb_params = dict(
    objective="binary:logistic",
    tree_method="hist",
    nthread=cpus_per_worker,
    max_depth=6,
    learning_rate=0.1,
    scale_pos_weight=scale_pos_weight,
    seed=42,
    verbosity=1,
)
num_boost_round = 100

# Each Ray Train worker reserves cpus_per_worker CPUs; no GPUs.
scaling_config = ScalingConfig(
    num_workers=num_workers,
    use_gpu=False,
    resources_per_worker={"CPU": cpus_per_worker},
)

# Results and checkpoints are written under a /Volumes (Unity Catalog volume) path.
# NOTE(review): "ray_results" sits in the volume-name position of the path, so the volume
# presumably must already exist; makedirs cannot create it — confirm.
ray_storage_path = f"/Volumes/{catalog}/{schema}/ray_results/"
os.makedirs(ray_storage_path, exist_ok=True)
run_config = RunConfig(storage_path=ray_storage_path, name="xgb_ray_train")

def xgb_train_fn(config):
    """Per-worker Ray Train loop: pin OpenMP threads, then train XGBoost on this worker's shard.

    ``config`` is the trainer's ``train_loop_config``. It holds the XGBoost params plus
    bookkeeping keys that are stripped before the rest goes to ``xgboost.train``:
    ``label_column``, ``num_boost_round``, ``dataset_keys`` and the optional
    ``_omp_diag_ref`` (an ``OmpDiagnosticsCollector`` handle).

    OMP fix, Layer 3: set ``OMP_NUM_THREADS`` and call ``omp_set_num_threads`` via ctypes
    before the ``import xgboost`` below, and record before/after values for diagnostics.
    NOTE(review): this function references ``RayTrainReportCallback`` (from
    ``ray.train.xgboost``), so xgboost may already be imported in the worker by the time
    this runs. Also, the ``libgomp.so.1`` opened here may not be the same copy xgboost
    links against. The ``nthread`` param is the authoritative thread control — confirm.
    """
    import os
    import ctypes

    nthread = config.get("nthread", 1)
    diag_ref = config.get("_omp_diag_ref")
    rank = ray.train.get_context().get_world_rank()

    # Record what the worker inherited, then force OMP_NUM_THREADS to the requested value.
    diag = {"omp_before": os.environ.get("OMP_NUM_THREADS", "NOT_SET")}
    os.environ["OMP_NUM_THREADS"] = str(nthread)
    diag["omp_set_to"] = str(nthread)
    for ln in ["libgomp.so.1", "libomp.so", "libomp.so.5"]:
        try:
            lib = ctypes.CDLL(ln)
            lib.omp_get_max_threads.restype = ctypes.c_int
            lib.omp_set_num_threads.argtypes = [ctypes.c_int]
            before = lib.omp_get_max_threads()
            lib.omp_set_num_threads(nthread)
            diag[f"ctypes_{ln}"] = f"{before}->{lib.omp_get_max_threads()}"
        except (OSError, AttributeError):
            # Library not installed, or it lacks the OpenMP symbols: nothing to pin there.
            pass

    import xgboost

    # Read back the OpenMP thread count after the xgboost import (diagnostics only).
    try:
        lib = ctypes.CDLL("libgomp.so.1")
        lib.omp_get_max_threads.restype = ctypes.c_int
        diag["post_libgomp.so.1"] = str(lib.omp_get_max_threads())
    except (OSError, AttributeError):
        pass

    if diag_ref is not None:
        try:
            ray.get(diag_ref.report.remote(rank, diag), timeout=10)
        except Exception as e:
            # Diagnostics are best effort; never fail training because the collector is unreachable.
            print(f"[rank {rank}] WARN: OMP diag report failed: {e}")

    label_col, n_rounds = config["label_column"], config["num_boost_round"]
    xp = {k: v for k, v in config.items()
          if k not in ("label_column", "num_boost_round", "dataset_keys", "_omp_diag_ref")}

    # Pull this worker's shards into pandas and build DMatrix objects.
    tdf = ray.train.get_dataset_shard("train").materialize().to_pandas()
    dtrain = xgboost.DMatrix(tdf.drop(label_col, axis=1), label=tdf[label_col])
    evals = [(dtrain, "train")]
    vds = ray.train.get_dataset_shard("valid")
    if vds is not None:
        vdf = vds.materialize().to_pandas()
        evals.append((xgboost.DMatrix(vdf.drop(label_col, axis=1), label=vdf[label_col]), "valid"))

    # Resume from a checkpoint if Ray provides one, training only the remaining rounds.
    ckpt = ray.train.get_checkpoint()
    sm, iters = None, n_rounds
    if ckpt:
        sm = RayTrainReportCallback.get_model(ckpt)
        iters = n_rounds - sm.num_boosted_rounds()
    xgboost.train(xp, dtrain, evals=evals, num_boost_round=iters, xgb_model=sm,
                  callbacks=[RayTrainReportCallback()])

# Per-worker config: the XGBoost params plus the bookkeeping keys xgb_train_fn strips out.
train_loop_config = dict(xgb_params, label_column="label", num_boost_round=num_boost_round)
print(f"XGB: {xgb_params} | rounds={num_boost_round}")
print("OMP: spark_conf(L1) + runtime_env(L2) + env+ctypes before import(L3)")
print(f"DataParallelTrainer: {num_workers}W x {cpus_per_worker}CPU")

In [None]:
print(f"Train: {train_count:,} | Test: {test_count:,}")
# Zero-CPU collector that every training worker reports its OpenMP diagnostics to.
_mons, _omp = [], OmpDiagnosticsCollector.remote()
train_loop_config["_omp_diag_ref"] = _omp


def _report_omp_diagnostics(collector):
    """Print per-rank OMP diagnostics, log them as MLflow params, then kill ``collector``.

    Best effort: failures are printed as warnings and never raised, so this is safe to
    call from a ``finally`` block while a training error is propagating.
    """
    try:
        od = ray.get(collector.get_all.remote(), timeout=15)
        print(f"\nOMP diag ({len(od)} workers):")
        for r in sorted(od):
            for k, v in sorted(od[r].items()):
                print(f"  w{r}/{k}: {v}")
                # Truncate long diagnostic strings before logging them as params.
                mlflow.log_param(f"omp_w{r}_{k}", str(v)[:500])
    except Exception as e:
        print(f"WARN: OMP diag failed: {e}")
    try:
        ray.kill(collector)
    except Exception as e:
        print(f"WARN: could not kill OMP diagnostics collector: {e}")


with mlflow.start_run(run_name=run_name, log_system_metrics=True) as run:
    run_id = run.info.run_id
    print(f"MLflow: {run_id} ({run_name})")

    # Per-node system metrics are optional: train without them if the actors cannot start.
    # The param is logged exactly once (MLflow params cannot be re-logged with a new value).
    try:
        _mons = start_worker_monitors(run_id, databricks_host_url, databricks_token, num_executors)
    except Exception as e:
        print(f"WARN: monitors failed: {e}")
        _mons = []
    mlflow.log_param("worker_metrics_monitors", len(_mons))

    mlflow.log_input(mlflow_dataset, context="training")
    mlflow.log_params({
        "training_mode": "ray_distributed", "data_size": data_size, "node_type": node_type,
        "run_mode": run_mode, "input_table": input_table, "n_rows": n_rows,
        "n_features": len(feature_columns), "num_workers": num_workers,
        "cpus_per_worker": cpus_per_worker, "num_boost_round": num_boost_round,
        "omp_fix": "spark_conf+runtime_env+env_before_import+ctypes+diag",
        "omp_target": cpus_per_worker,
    })
    mlflow.log_params({f"xgb_{k}": v for k, v in xgb_params.items()})
    mlflow.log_metric("ray_init_time_sec", ray_init_time)
    mlflow.log_metric("data_load_time_sec", load_time)
    mlflow.log_metric("split_time_sec", split_time)

    try:
        print("\nTraining with DataParallelTrainer + OMP fix...")
        t0 = time.time()
        trainer = DataParallelTrainer(
            train_loop_per_worker=xgb_train_fn,
            train_loop_config=train_loop_config,
            scaling_config=scaling_config,
            run_config=run_config,
            datasets={"train": train_ray_ds, "valid": test_ray_ds},
            backend_config=XGBoostConfig(),
        )
        result = trainer.fit()
        train_time = time.time() - t0
        print(f"Done in {train_time:.1f}s")
        mlflow.log_metric("train_time_sec", train_time)
    finally:
        stop_worker_monitors(_mons)
        _mons = []
        # Report OMP diagnostics even when training failed; that is when they matter most.
        _report_omp_diagnostics(_omp)

    import xgboost as xgb
    if result.checkpoint is None:
        raise RuntimeError("Training finished without a checkpoint; cannot load the booster for evaluation")
    t0 = time.time()
    booster = RayTrainReportCallback.get_model(result.checkpoint)
    yp = booster.predict(xgb.DMatrix(X_test_eval))
    y_pred = (yp > 0.5).astype(int)
    pred_time = time.time() - t0
    mlflow.log_metric("predict_time_sec", pred_time)

    from sklearn.metrics import average_precision_score, roc_auc_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix
    # Ranking metrics are undefined for a single-class evaluation sample (roc_auc_score
    # raises); log NaN instead of aborting the run.
    if y_test_eval.nunique() > 1:
        auc_pr, auc_roc = average_precision_score(y_test_eval, yp), roc_auc_score(y_test_eval, yp)
    else:
        print("WARN: evaluation sample contains a single class; AUC metrics set to NaN")
        auc_pr = auc_roc = float("nan")
    f1 = f1_score(y_test_eval, y_pred)
    prec, rec = precision_score(y_test_eval, y_pred, zero_division=0), recall_score(y_test_eval, y_pred, zero_division=0)
    for n, v in [("auc_pr", auc_pr), ("auc_roc", auc_roc), ("f1", f1), ("precision", prec), ("recall", rec)]:
        mlflow.log_metric(n, v)
        print(f"  {n}: {v:.4f}")
    # labels=[0, 1] keeps the matrix 2x2 even when a class is missing from the sample or
    # the predictions; otherwise sklearn returns 1x1 and cm[0, 1] raises IndexError.
    cm = confusion_matrix(y_test_eval, y_pred, labels=[0, 1])
    for n, v in [("true_negatives", cm[0, 0]), ("false_positives", cm[0, 1]), ("false_negatives", cm[1, 0]), ("true_positives", cm[1, 1])]:
        mlflow.log_metric(n, v)
    print(classification_report(y_test_eval, y_pred, zero_division=0))
    total_time = ray_init_time + load_time + split_time + train_time + pred_time
    mlflow.log_metric("total_time_sec", total_time)
    print(f"\nDone: {run_name} | {total_time:.1f}s | OMP={cpus_per_worker} | {run_id}")

## Shutdown Ray Cluster

In [None]:
print("Shutting down Ray cluster...")
try:
    shutdown_ray_cluster()
    print("Ray cluster shutdown complete.")
except Exception as e:
    # A teardown failure should not stop the exit cell from reporting the run's results.
    # Record it so the exit payload carries status "error" and the traceback.
    log_error(f"Ray cluster shutdown failed: {e}", e)

## Exit

In [None]:
import json

# Assemble the JSON payload handed back through dbutils.notebook.exit(). If an earlier cell
# failed, some of the names below were never defined. The resulting NameError switches to
# a minimal error payload.
try:
    result = dict(
        status="error" if _notebook_errors else "ok",
        run_name=run_name,
        run_id=run_id,
        training_mode="ray_distributed",
        data_size=data_size,
        node_type=node_type,
        warehouse_id=warehouse_id,
        n_rows=n_rows,
        num_workers=num_workers,
        cpus_per_worker=cpus_per_worker,
        spark_executors=num_executors,
        auc_pr=round(auc_pr, 4),
        train_time_sec=round(train_time, 1),
        total_time_sec=round(total_time, 1),
    )
    if _notebook_errors:
        result["errors"] = _notebook_errors
except NameError as missing:
    result = {
        "status": "error",
        "error": f"Notebook failed before completion: {missing}",
        "errors": globals().get("_notebook_errors", []),
    }

result_json = json.dumps(result)
print(f"\nNotebook result: {result_json}")

dbutils.notebook.exit(result_json)