# Train XGBoost (Ray on Spark - Distributed)

Distributed XGBoost training using Ray on Spark with MLflow tracking.

**Requirements:**
- Databricks ML Runtime 17.3 LTS (includes Ray)
- Multi-node cluster (2+ workers recommended)

**Parameters:**
- `data_size`: Dataset size preset (tiny/small/medium/large/xlarge)
- `node_type`: Node type for run naming (e.g., D8sv5)
- `run_mode`: `full` or `smoke` (smoke uses tiny data)
- `num_workers`: Number of Ray workers (0 = auto based on cluster)
- `cpus_per_worker`: CPUs allocated per Ray worker (0 = auto)
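- `warehouse_id`: Databricks SQL Warehouse ID used by Ray Data to read the input table (required)
- `catalog` / `schema`: Unity Catalog location of the input table; the `ray_results` volume used for Ray Train storage also lives here
- `table_name`: Optional table name override; when set, it replaces the `imbalanced_{suffix}` table selected by `data_size`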

**MLflow:**
- System metrics enabled
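- Run name: `ray_{data_size}_{node_type}`, or `ray_{data_size}_{num_workers}w_{node_type}` when `num_workers` > 0 (with `table_name` set, `{data_size}` becomes the table name without its `imbalanced_` prefix); smoke runs use `ray_smoke_{node_type}`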

## Setup Widgets

In [None]:
# Global error tracking - captures errors from any cell
_notebook_errors = []


def log_error(error_msg, exc=None):
    """Record an error so the exit cell can report it.

    Appends ``{"error": str(error_msg)}`` to the module-level ``_notebook_errors``
    list and prints a one-line notice. When ``exc`` is supplied, the entry also
    gets a ``"traceback"`` key formatted from that exception object.

    Args:
        error_msg: Human-readable description of the failure (stored via ``str()``).
        exc: Optional exception instance whose traceback should be recorded.
    """
    import traceback

    entry = {"error": str(error_msg)}
    if exc is not None:
        # Format the traceback attached to ``exc`` itself. ``traceback.format_exc()``
        # only describes the exception currently being handled, so it records
        # "NoneType: None" whenever this helper is called outside an ``except`` block.
        entry["traceback"] = "".join(
            traceback.format_exception(type(exc), exc, exc.__traceback__)
        )
    _notebook_errors.append(entry)
    print(f"ERROR LOGGED: {error_msg}")

# Notebook parameters, declared in display order. Each row is
# (widget kind, name, default, [choices,] label) and is dispatched to the
# matching dbutils.widgets.<kind>(...) factory.
for _widget_kind, *_widget_args in (
    ("dropdown", "data_size", "tiny", ["tiny", "small", "medium", "large", "xlarge"], "Data Size"),
    ("text", "node_type", "D8sv5", "Node Type"),
    ("dropdown", "run_mode", "full", ["full", "smoke"], "Run Mode"),
    ("text", "num_workers", "0", "Num Workers (0=auto)"),
    ("text", "cpus_per_worker", "0", "CPUs per Worker (0=auto)"),
    ("text", "warehouse_id", "148ccb90800933a1", "Databricks SQL Warehouse ID"),
    ("text", "catalog", "brian_gen_ai", "Catalog"),
    ("text", "schema", "xgb_scaling", "Schema"),
    ("text", "table_name", "", "Table Name (override)"),
):
    getattr(dbutils.widgets, _widget_kind)(*_widget_args)
del _widget_kind, _widget_args

In [None]:
# Get widget values (text inputs are whitespace-stripped so a stray space cannot
# corrupt the fully qualified table name built below).
data_size = dbutils.widgets.get("data_size")
node_type = dbutils.widgets.get("node_type")
run_mode = dbutils.widgets.get("run_mode")
num_workers_input = int(dbutils.widgets.get("num_workers"))
cpus_per_worker_input = int(dbutils.widgets.get("cpus_per_worker"))
warehouse_id = dbutils.widgets.get("warehouse_id").strip()
catalog = dbutils.widgets.get("catalog").strip()
schema = dbutils.widgets.get("schema").strip()
table_name_override = dbutils.widgets.get("table_name").strip()

# Smoke runs always use the tiny preset, as documented for `run_mode`. Without
# this, smoke + a large data_size (or a table override) would load the full table
# while still being labelled "ray_smoke_*".
if run_mode == "smoke":
    if data_size != "tiny" or table_name_override:
        print(
            f"Smoke mode: ignoring data_size={data_size!r} and "
            f"table_name={table_name_override!r}; using the 'tiny' preset."
        )
    data_size = "tiny"
    table_name_override = ""

# Dataset size preset mapping: table suffix plus the row/feature counts the
# corresponding imbalanced_<suffix> table is expected to have.
SIZE_PRESETS = {
    "tiny": {"suffix": "10k", "rows": 10_000, "features": 20},
    "small": {"suffix": "1m", "rows": 1_000_000, "features": 100},
    "medium": {"suffix": "10m", "rows": 10_000_000, "features": 250},
    "large": {"suffix": "100m", "rows": 100_000_000, "features": 500},
    "xlarge": {"suffix": "500m", "rows": 500_000_000, "features": 500},
}

# Determine input table
_TABLE_PREFIX = "imbalanced_"
if table_name_override:
    input_table = f"{catalog}.{schema}.{table_name_override}"
    # Strip only a leading "imbalanced_" (str.replace would also remove it from
    # the middle of a name).
    if table_name_override.startswith(_TABLE_PREFIX):
        data_size_label = table_name_override[len(_TABLE_PREFIX):]
    else:
        data_size_label = table_name_override
else:
    preset = SIZE_PRESETS[data_size]
    table_suffix = preset["suffix"]
    input_table = f"{catalog}.{schema}.{_TABLE_PREFIX}{table_suffix}"
    data_size_label = data_size

# Run naming (prefix with ray_ and worker config; the _<N>w segment only appears
# when an explicit worker count was requested)
if run_mode == "smoke":
    run_name = f"ray_smoke_{node_type}"
else:
    worker_suffix = f"_{num_workers_input}w" if num_workers_input > 0 else ""
    run_name = f"ray_{data_size_label}{worker_suffix}_{node_type}"

print(f"Data size: {data_size}")
print(f"Node type: {node_type}")
print(f"Run mode: {run_mode}")
print(f"Num workers input: {num_workers_input} (0=auto)")
print(f"CPUs per worker input: {cpus_per_worker_input} (0=auto)")
print(f"Warehouse ID: {warehouse_id}")
print(f"Input table: {input_table}")
print(f"Run name: {run_name}")

## MLflow Setup

In [None]:
import os

# Turn on MLflow system-metrics collection twice over: export the environment
# flag before mlflow is imported, then call the explicit enable API afterwards.
os.environ["MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING"] = "true"

import mlflow

mlflow.enable_system_metrics_logging()

# All benchmark runs go to one experiment in the current user's workspace folder.
_user_row = spark.sql("SELECT current_user()").collect()[0]
user_email = _user_row[0]
experiment_path = "/Users/{}/xgb_scaling_benchmark".format(user_email)

mlflow.set_experiment(experiment_path)
_metrics_flag = os.environ.get("MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING")
print(f"MLflow experiment: {experiment_path}")
print(f"System metrics logging enabled: {_metrics_flag}")

## Initialize Ray on Spark

In [None]:
import time
import os
import re

# Check Ray version and availability
try:
    import ray
    print(f"Ray version: {ray.__version__}")
except ImportError as e:
    raise RuntimeError(f"Ray not available: {e}") from e

from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster

# Setup Databricks environment for Ray workers (required for MLflow + Ray Data access).
# apiUrl()/apiToken() return Option-like values; getOrElse(None) maps "absent" to None.
_notebook_ctx = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
databricks_host_url = _notebook_ctx.apiUrl().getOrElse(None)
databricks_token = _notebook_ctx.apiToken().getOrElse(None)
if not databricks_host_url or not databricks_token:
    # Fail clearly here instead of with an AttributeError on .replace() or a
    # TypeError when assigning None into os.environ below.
    raise RuntimeError(
        "Could not read the Databricks API URL and token from the notebook context; "
        "both are required to configure MLflow and Ray Data access for Ray workers."
    )

# Ray Data Databricks reader expects hostname without scheme.
databricks_host = databricks_host_url.replace("https://", "").replace("http://", "").rstrip("/")

# Default env setup for MLflow and general Databricks SDK usage.
os.environ["DATABRICKS_HOST"] = databricks_host_url
os.environ["DATABRICKS_TOKEN"] = databricks_token
print(f"Databricks Host URL: {databricks_host_url}")
print(f"Databricks Hostname (Ray Data): {databricks_host}")
print("Databricks Token: [CONFIGURED]")

# Get cluster info: the executor memory status includes the driver, hence -1.
sc = spark.sparkContext
num_executors = sc._jsc.sc().getExecutorMemoryStatus().size() - 1  # Exclude driver
print(f"\nSpark executors (workers): {num_executors}")

if num_executors < 1:
    print("WARNING: No executors detected. This may indicate a single-node cluster.")
    print("Ray on Spark requires a multi-node cluster (num_workers >= 1)")

# Per-node vCPUs parsed from the node type name: the first 'd' or 'e' followed by
# digits (e.g. "D8sv5" -> 8); defaults to 8 when the name does not match.
# One CPU per node is held back as a buffer for Spark/system processes.
node_type_lower = node_type.lower()
node_vcpus_match = re.search(r"[de](\d+)", node_type_lower)
node_vcpus = int(node_vcpus_match.group(1)) if node_vcpus_match else 8
allocatable_cpus_per_node = max(1, node_vcpus - 1)

# Determine number of Ray workers (best practice: one worker per executor)
if num_workers_input > 0:
    num_workers = num_workers_input
else:
    num_workers = max(1, num_executors)

# Keep worker count aligned with available executors, but never cap below 1:
# with zero executors the cap would produce num_workers=0, which Ray Train cannot
# run. The CPU preflight after the Ray cluster starts decides whether 1 worker fits.
if num_executors >= 1 and num_workers > num_executors:
    print(f"Capping num_workers from {num_workers} to {num_executors} (executor count)")
    num_workers = num_executors

# CPUs per worker for Ray training
if cpus_per_worker_input > 0:
    cpus_per_worker = cpus_per_worker_input
else:
    cpus_per_worker = allocatable_cpus_per_node

# Enforce per-node CPU safety cap (a worker cannot use more CPUs than one node offers)
if cpus_per_worker > allocatable_cpus_per_node:
    print(
        f"Capping cpus_per_worker from {cpus_per_worker} to {allocatable_cpus_per_node} "
        f"for node type {node_type}"
    )
    cpus_per_worker = allocatable_cpus_per_node

# Use the same per-node CPU for Ray cluster worker allocation
num_cpus_worker_node = allocatable_cpus_per_node

print(f"\nResource sizing:")
print(f"  node_type: {node_type}")
print(f"  node_vcpus: {node_vcpus}")
print(f"  allocatable_cpus_per_node: {allocatable_cpus_per_node}")
print(f"  spark_executors: {num_executors}")
print(f"\nRay training configuration:")
print(f"  num_workers: {num_workers}")
print(f"  cpus_per_worker: {cpus_per_worker}")
print(f"  total_requested_cpus: {num_workers * cpus_per_worker}")

In [None]:
# Initialize Ray cluster on Spark
print("Starting Ray cluster on Spark...")
ray_start = time.time()

try:
    # Fixed-size Ray-on-Spark cluster: one Ray worker node per Spark executor, each
    # given num_cpus_worker_node CPUs (node vCPUs minus the one-CPU buffer).
    ray_cluster = setup_ray_cluster(
        min_worker_nodes=num_executors,         # Fixed size cluster for benchmarks
        max_worker_nodes=num_executors,         # No autoscaling for benchmarks
        num_cpus_worker_node=num_cpus_worker_node,
        num_gpus_worker_node=0,                 # CPU-only training
        collect_log_to_path="/tmp/ray_logs"
    )
    ray_init_time = time.time() - ray_start
    print(f"Ray cluster initialized in {ray_init_time:.1f}s")
    # RuntimeContext.node_id is deprecated; get_node_id() is its replacement.
    print(f"Ray head node: {ray.get_runtime_context().get_node_id()}")

    print(f"\nRay cluster resources:")
    cluster_resources = ray.cluster_resources()
    print(cluster_resources)

    # Preflight CPU validation to avoid trainer stalls from pending actors:
    # if workers + overhead request more CPUs than the cluster reports, Ray Train
    # actors would sit pending forever instead of failing.
    AUTO_CAP_RESOURCES = True
    # Ray Train needs coordinator/driver CPU in addition to worker CPUs.
    ray_train_overhead_cpus = 1
    available_cpus = int(cluster_resources.get("CPU", 0))
    required_worker_cpus = int(num_workers * cpus_per_worker)
    required_total_cpus = required_worker_cpus + ray_train_overhead_cpus

    print("\nRay preflight CPU check:")
    print(f"  available_cpus: {available_cpus}")
    print(f"  requested_worker_cpus: {required_worker_cpus} ({num_workers} workers x {cpus_per_worker} CPU)")
    print(f"  ray_train_overhead_cpus: {ray_train_overhead_cpus}")
    print(f"  requested_total_cpus: {required_total_cpus}")

    if required_total_cpus > available_cpus:
        if not AUTO_CAP_RESOURCES:
            raise RuntimeError(
                f"Insufficient Ray CPUs: requested total {required_total_cpus}, available {available_cpus}. "
                "Reduce num_workers/cpus_per_worker or increase cluster size."
            )

        # Step 1: shrink CPUs per worker so the current worker count fits.
        usable_cpus_for_workers = max(1, available_cpus - ray_train_overhead_cpus)
        capped_cpus_per_worker = max(1, usable_cpus_for_workers // max(1, num_workers))
        if capped_cpus_per_worker < cpus_per_worker:
            print(
                f"  auto-cap: cpus_per_worker {cpus_per_worker} -> {capped_cpus_per_worker}"
            )
            cpus_per_worker = capped_cpus_per_worker

        # Step 2: if it still does not fit (already at 1 CPU/worker), drop workers.
        required_worker_cpus = int(num_workers * cpus_per_worker)
        required_total_cpus = required_worker_cpus + ray_train_overhead_cpus
        if required_total_cpus > available_cpus:
            capped_workers = max(1, usable_cpus_for_workers // max(1, cpus_per_worker))
            if capped_workers < num_workers:
                print(f"  auto-cap: num_workers {num_workers} -> {capped_workers}")
                num_workers = capped_workers

        # Step 3: give up if even the minimal configuration does not fit.
        required_worker_cpus = int(num_workers * cpus_per_worker)
        required_total_cpus = required_worker_cpus + ray_train_overhead_cpus
        if required_total_cpus > available_cpus:
            raise RuntimeError(
                f"Unable to fit Ray resources after auto-cap: requested total {required_total_cpus}, available {available_cpus}."
            )

    print(f"  final_num_workers: {num_workers}")
    print(f"  final_cpus_per_worker: {cpus_per_worker}")
    print(f"  final_requested_worker_cpus: {num_workers * cpus_per_worker}")
    print(f"  final_requested_total_cpus: {(num_workers * cpus_per_worker) + ray_train_overhead_cpus}")

except Exception as e:
    print(f"ERROR: Failed to initialize Ray cluster: {e}")
    print(f"\nDebug info:")
    print(f"  Spark version: {spark.version}")
    print(f"  Executors: {num_executors}")
    import traceback
    traceback.print_exc()
    raise

## Load Data

In [None]:
import ray.data

if not warehouse_id:
    raise ValueError("warehouse_id is required for distributed Ray Data loading")

print(f"Loading data from: {input_table}")
print(f"Using SQL Warehouse: {warehouse_id}")
load_start = time.time()

query = f"SELECT * FROM {input_table}"

# Ray Data's Databricks reader expects DATABRICKS_HOST without scheme, so swap it
# in only for the reader call and restore the full URL afterwards.
_original_db_host = os.environ.get("DATABRICKS_HOST")
os.environ["DATABRICKS_HOST"] = databricks_host
try:
    full_ray_ds = ray.data.read_databricks_tables(
        warehouse_id=warehouse_id,
        query=query,
    )
finally:
    # Restore default host URL for downstream APIs.
    if _original_db_host is not None:
        os.environ["DATABRICKS_HOST"] = _original_db_host
    else:
        os.environ["DATABRICKS_HOST"] = databricks_host_url

# Ray Datasets execute lazily: without this, each later action (count, schema,
# limit, sum, train_test_split) re-executes the read pipeline, and load_time would
# time only the first of them. Materializing once keeps the blocks in the Ray
# object store for reuse and makes data_load_time_sec cover the actual ingest.
# NOTE(review): the object store must hold the full table (it spills to disk
# otherwise); train_test_split materialized the data anyway.
full_ray_ds = full_ray_ds.materialize()

n_rows = full_ray_ds.count()
all_columns = list(full_ray_ds.schema().names)
if "label" not in all_columns:
    raise ValueError(f"Expected 'label' column in dataset schema, got: {all_columns}")

feature_columns = [c for c in all_columns if c != "label"]
load_time = time.time() - load_start

print(f"Loaded {n_rows:,} rows x {len(all_columns)} columns in {load_time:.1f}s")
print(f"Feature count: {len(feature_columns)}")

# Create a lightweight MLflow input dataset sample (avoid driver OOM)
mlflow_sample_rows = min(10_000, n_rows)
mlflow_sample_df = full_ray_ds.limit(mlflow_sample_rows).to_pandas()
mlflow_dataset = mlflow.data.from_pandas(
    mlflow_sample_df,
    source=input_table,
    name=data_size_label,
    targets="label",
)
print(f"MLflow dataset sample created: {mlflow_dataset.name} ({mlflow_sample_rows:,} rows)")

## Prepare Features and Labels

In [None]:
# Class distribution from the Ray Dataset. An empty table would otherwise fail
# below with int(None) (Dataset.sum returns None on empty input) or a
# ZeroDivisionError in the percentage prints, so reject it up front.
if n_rows == 0:
    raise ValueError(f"Input table {input_table} returned no rows; nothing to train on.")

# With nulls ignored (the default), Dataset.sum only returns None when every value
# is null, i.e. the label column carries no usable values.
_label_sum = full_ray_ds.sum("label")
if _label_sum is None:
    raise ValueError(f"'label' column in {input_table} has no non-null values.")
positive_count = int(_label_sum)
negative_count = int(n_rows - positive_count)
minority_ratio = positive_count / n_rows

print("Class distribution:")
print(f"  Class 0 (majority): {negative_count:,} ({(negative_count / n_rows) * 100:.2f}%)")
print(f"  Class 1 (minority): {positive_count:,} ({(positive_count / n_rows) * 100:.2f}%)")

# Calculate scale_pos_weight for imbalance: negatives per positive, with the
# denominator floored at 1 so an all-negative dataset does not divide by zero.
scale_pos_weight = negative_count / max(positive_count, 1)
print(f"\nscale_pos_weight: {scale_pos_weight:.2f}")

In [None]:
# Train/test split in Ray Data (keeps ingestion distributed).
# Dataset.train_test_split ignores `seed` unless shuffle=True: with shuffle=False
# the test set is simply the last 20% of rows in dataset order. Set
# SHUFFLE_SPLIT = True for a seeded random split, at the cost of a full
# distributed shuffle (expensive for large/xlarge). The default keeps the
# positional split so benchmark timings stay comparable with earlier runs.
SHUFFLE_SPLIT = False
split_start = time.time()
train_ray_ds, test_ray_ds = full_ray_ds.train_test_split(
    test_size=0.2, shuffle=SHUFFLE_SPLIT, seed=42
)
split_time = time.time() - split_start

train_count = train_ray_ds.count()
test_count = test_ray_ds.count()
# Dataset.sum returns None for an empty split; count that as zero positives.
train_pos = int(train_ray_ds.sum("label") or 0)
test_pos = int(test_ray_ds.sum("label") or 0)


def _pct(part, whole):
    """Return part/whole as a percentage, or 0.0 when whole is 0."""
    return (part / whole) * 100 if whole else 0.0


print(f"Train set: {train_count:,} rows")
print(f"Test set: {test_count:,} rows")
print(f"Train minority: {train_pos:,} ({_pct(train_pos, train_count):.2f}%)")
print(f"Test minority: {test_pos:,} ({_pct(test_pos, test_count):.2f}%)")
print(f"Split time: {split_time:.1f}s")

# Bounded evaluation sample for local sklearn metrics.
# This avoids collecting the full test split to driver memory. limit() takes the
# first rows of the test split, not a random subsample.
eval_sample_rows = min(200_000, test_count)
eval_test_df = test_ray_ds.limit(eval_sample_rows).to_pandas()
X_test_eval = eval_test_df[feature_columns]
y_test_eval = eval_test_df["label"]
print(f"Evaluation sample rows: {len(eval_test_df):,}")

## XGBoost Training (Ray Distributed)

In [None]:
# Use Ray Train's XGBoostTrainer (modern approach, replaces deprecated xgboost_ray)
import os

import ray.data
from ray.train import RunConfig, ScalingConfig
from ray.train.xgboost import XGBoostTrainer

# XGBoost hyperparameters - same as single-node for comparison
# Note: Uses xgboost.train() style params, not sklearn API
xgb_params = {
    "objective": "binary:logistic",
    "tree_method": "hist",
    "nthread": cpus_per_worker,  # Must be explicit — Ray Train does NOT auto-set this
    "max_depth": 6,
    "learning_rate": 0.1,
    "scale_pos_weight": scale_pos_weight,
    "seed": 42,
    "verbosity": 1,
}

# Number of boosting rounds
num_boost_round = 100

# Ray scaling config: one Ray Train worker per slot, each reserving cpus_per_worker CPUs
scaling_config = ScalingConfig(
    num_workers=num_workers,
    use_gpu=False,
    resources_per_worker={"CPU": cpus_per_worker},
)

# Storage path for Ray Train checkpoints/results: a Unity Catalog volume path,
# i.e. storage shared across nodes rather than node-local disk.
# NOTE(review): assumes the `ray_results` volume already exists under
# {catalog}.{schema}; os.makedirs presumably cannot create a UC volume itself.
ray_storage_path = f"/Volumes/{catalog}/{schema}/ray_results/"

# Ensure the storage directory exists (no-op if it is already there)
os.makedirs(ray_storage_path, exist_ok=True)
print(f"Storage directory ready: {ray_storage_path}")

# Run config pointing Ray Train at the shared volume.
# NOTE(review): the fixed name means every run writes under the same
# <storage_path>/xgb_ray_train directory.
run_config = RunConfig(
    storage_path=ray_storage_path,
    name="xgb_ray_train",
)

print("XGBoost parameters:")
for k, v in xgb_params.items():
    print(f"  {k}: {v}")
print(f"  num_boost_round: {num_boost_round}")

print(f"\nRay ScalingConfig:")
print(f"  num_workers: {scaling_config.num_workers}")
print(f"  resources_per_worker: {scaling_config.resources_per_worker}")

print(f"\nRay RunConfig:")
print(f"  storage_path: {ray_storage_path}")

In [None]:
import xgboost as xgb
from sklearn.metrics import (
    average_precision_score,
    roc_auc_score,
    f1_score,
    precision_score,
    recall_score,
    classification_report,
    confusion_matrix,
)

print(f"Train Ray Dataset: {train_count:,} rows")
print(f"Test Ray Dataset: {test_count:,} rows")

# Start MLflow run and train; everything below is recorded on this run.
with mlflow.start_run(run_name=run_name, log_system_metrics=True) as run:
    run_id = run.info.run_id
    print(f"MLflow run ID: {run_id}")
    print(f"MLflow run name: {run_name}")

    # Log input dataset (the bounded sample built in the load cell)
    mlflow.log_input(mlflow_dataset, context="training")
    print(f"Logged input dataset: {input_table}")

    # Log parameters
    mlflow.log_param("training_mode", "ray_distributed")
    mlflow.log_param("data_size", data_size)
    mlflow.log_param("node_type", node_type)
    mlflow.log_param("run_mode", run_mode)
    mlflow.log_param("warehouse_id", warehouse_id)
    mlflow.log_param("input_table", input_table)
    mlflow.log_param("n_rows", n_rows)
    mlflow.log_param("n_features", len(feature_columns))
    mlflow.log_param("minority_ratio", round(minority_ratio, 4))
    mlflow.log_param("train_size", train_count)
    mlflow.log_param("test_size", test_count)
    mlflow.log_param("eval_sample_rows", len(eval_test_df))

    # Log Ray params (final values after any preflight auto-capping)
    mlflow.log_param("num_workers", num_workers)
    mlflow.log_param("cpus_per_worker", cpus_per_worker)
    mlflow.log_param("spark_executors", num_executors)
    mlflow.log_param("num_boost_round", num_boost_round)

    # Log XGBoost params
    for k, v in xgb_params.items():
        mlflow.log_param(f"xgb_{k}", v)

    # Log timing of the pre-training phases
    mlflow.log_metric("ray_init_time_sec", ray_init_time)
    mlflow.log_metric("data_load_time_sec", load_time)
    mlflow.log_metric("split_time_sec", split_time)

    # Train model using Ray Train XGBoostTrainer
    print("\nTraining XGBoost with Ray Train...")
    train_start = time.time()

    trainer = XGBoostTrainer(
        scaling_config=scaling_config,
        run_config=run_config,
        label_column="label",
        num_boost_round=num_boost_round,
        params=xgb_params,
        datasets={"train": train_ray_ds, "valid": test_ray_ds},
    )
    result = trainer.fit()

    train_time = time.time() - train_start
    print(f"Training completed in {train_time:.1f}s")
    print(f"Final result metrics: {result.metrics}")

    mlflow.log_metric("train_time_sec", train_time)

    # Get the trained model for bounded local evaluation
    print("\nGenerating predictions on evaluation sample...")
    pred_start = time.time()

    checkpoint = result.checkpoint
    if checkpoint is None:
        raise RuntimeError("Ray Train finished without a checkpoint; cannot load the trained booster.")
    # Let the trainer decode its own checkpoint format instead of hard-coding the
    # model file name inside the checkpoint directory.
    booster = XGBoostTrainer.get_model(checkpoint)

    # Make predictions on bounded evaluation sample (columns in training order)
    dtest = xgb.DMatrix(X_test_eval)
    y_pred_proba = booster.predict(dtest)
    y_pred = (y_pred_proba > 0.5).astype(int)

    pred_time = time.time() - pred_start
    mlflow.log_metric("predict_time_sec", pred_time)

    # Evaluation
    print("\nEvaluating...")

    # Metrics on bounded evaluation sample. zero_division=0 is applied uniformly so
    # a model that predicts no positives reports 0 instead of warning/failing.
    auc_pr = average_precision_score(y_test_eval, y_pred_proba)
    auc_roc = roc_auc_score(y_test_eval, y_pred_proba)
    f1 = f1_score(y_test_eval, y_pred, zero_division=0)
    precision = precision_score(y_test_eval, y_pred, zero_division=0)
    recall = recall_score(y_test_eval, y_pred, zero_division=0)

    print(f"\nResults:")
    print(f"  AUC-PR (primary): {auc_pr:.4f}")
    print(f"  AUC-ROC: {auc_roc:.4f}")
    print(f"  F1: {f1:.4f}")
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall: {recall:.4f}")

    # Log metrics
    mlflow.log_metric("auc_pr", auc_pr)
    mlflow.log_metric("auc_roc", auc_roc)
    mlflow.log_metric("f1", f1)
    mlflow.log_metric("precision", precision)
    mlflow.log_metric("recall", recall)

    # Confusion matrix. labels=[0, 1] forces a 2x2 matrix even when the sample and
    # predictions contain a single class; otherwise cm[0, 1] etc. raise IndexError.
    cm = confusion_matrix(y_test_eval, y_pred, labels=[0, 1])
    tn, fp, fn, tp = (int(v) for v in cm.ravel())
    print(f"\nConfusion Matrix:")
    print(f"  TN: {tn:,}  FP: {fp:,}")
    print(f"  FN: {fn:,}  TP: {tp:,}")

    mlflow.log_metric("true_negatives", tn)
    mlflow.log_metric("false_positives", fp)
    mlflow.log_metric("false_negatives", fn)
    mlflow.log_metric("true_positives", tp)

    # Classification report
    print(f"\nClassification Report:")
    print(classification_report(y_test_eval, y_pred, zero_division=0))

    # Total time (excludes metric computation/logging)
    total_time = ray_init_time + load_time + split_time + train_time + pred_time
    mlflow.log_metric("total_time_sec", total_time)

    print(f"\n" + "="*50)
    print(f"Run complete: {run_name}")
    print(f"Total time: {total_time:.1f}s (Ray init: {ray_init_time:.1f}s, Load: {load_time:.1f}s, Train: {train_time:.1f}s)")
    print(f"MLflow run ID: {run_id}")
    print(f"="*50)

## Shutdown Ray Cluster

In [None]:
# Tear down the Ray-on-Spark cluster. Best effort: results are already logged to
# MLflow by now, so a shutdown failure is recorded via log_error (which makes the
# exit payload report status "error") instead of aborting the notebook before
# the exit cell can return its result.
print("Shutting down Ray cluster...")
try:
    shutdown_ray_cluster()
except Exception as e:
    log_error(f"Ray cluster shutdown failed: {e}", e)
else:
    print("Ray cluster shutdown complete.")

## Exit

In [None]:
import json

# Build the JSON summary handed back to the caller via dbutils.notebook.exit.
# Any summary variable that was never assigned means an earlier cell did not
# finish; that NameError is converted into an error payload instead.
try:
    result = dict(
        status="error" if _notebook_errors else "ok",
        run_name=run_name,
        run_id=run_id,
        training_mode="ray_distributed",
        data_size=data_size,
        node_type=node_type,
        warehouse_id=warehouse_id,
        n_rows=n_rows,
        num_workers=num_workers,
        cpus_per_worker=cpus_per_worker,
        spark_executors=num_executors,
        auc_pr=round(auc_pr, 4),
        train_time_sec=round(train_time, 1),
        total_time_sec=round(total_time, 1),
    )
    if _notebook_errors:
        result["errors"] = _notebook_errors
except NameError as e:
    result = dict(
        status="error",
        error=f"Notebook failed before completion: {e}",
        errors=globals().get("_notebook_errors", []),
    )

result_json = json.dumps(result)
print(f"\nNotebook result: {result_json}")

dbutils.notebook.exit(result_json)