|
| 1 | +''' |
| 2 | +Collection of examples for using xgboost.spark estimator interface |
| 3 | +================================================================== |
| 4 | +
|
| 5 | +@author: Weichen Xu |
| 6 | +''' |
| 7 | +from pyspark.sql import SparkSession |
| 8 | +from pyspark.sql.functions import rand |
| 9 | +from pyspark.ml.linalg import Vectors |
| 10 | +import sklearn.datasets |
| 11 | +from sklearn.model_selection import train_test_split |
| 12 | +from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor |
| 13 | +from pyspark.ml.evaluation import RegressionEvaluator, MulticlassClassificationEvaluator |
| 14 | + |
| 15 | + |
| 16 | +spark = SparkSession.builder.master("local[*]").getOrCreate() |
| 17 | + |
| 18 | + |
| 19 | +def create_spark_df(X, y): |
| 20 | + return spark.createDataFrame( |
| 21 | + spark.sparkContext.parallelize([ |
| 22 | + (Vectors.dense(features), float(label)) |
| 23 | + for features, label in zip(X, y) |
| 24 | + ]), |
| 25 | + ["features", "label"] |
| 26 | + ) |
| 27 | + |
| 28 | + |
| 29 | +# load diabetes dataset (regression dataset) |
| 30 | +diabetes_X, diabetes_y = sklearn.datasets.load_diabetes(return_X_y=True) |
| 31 | +diabetes_X_train, diabetes_X_test, diabetes_y_train, diabetes_y_test = \ |
| 32 | + train_test_split(diabetes_X, diabetes_y, test_size=0.3, shuffle=True) |
| 33 | + |
| 34 | +diabetes_train_spark_df = create_spark_df(diabetes_X_train, diabetes_y_train) |
| 35 | +diabetes_test_spark_df = create_spark_df(diabetes_X_test, diabetes_y_test) |
| 36 | + |
| 37 | +# train xgboost regressor model |
| 38 | +xgb_regressor = SparkXGBRegressor(max_depth=5) |
| 39 | +xgb_regressor_model = xgb_regressor.fit(diabetes_train_spark_df) |
| 40 | + |
| 41 | +transformed_diabetes_test_spark_df = xgb_regressor_model.transform(diabetes_test_spark_df) |
| 42 | +regressor_evaluator = RegressionEvaluator(metricName="rmse") |
| 43 | +print(f"regressor rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df)}") |
| 44 | + |
| 45 | +diabetes_train_spark_df2 = diabetes_train_spark_df.withColumn( |
| 46 | + "validationIndicatorCol", rand(1) > 0.7 |
| 47 | +) |
| 48 | + |
| 49 | +# train xgboost regressor model with validation dataset |
| 50 | +xgb_regressor2 = SparkXGBRegressor(max_depth=5, validation_indicator_col="validationIndicatorCol") |
| 51 | +xgb_regressor_model2 = xgb_regressor.fit(diabetes_train_spark_df2) |
| 52 | +transformed_diabetes_test_spark_df2 = xgb_regressor_model2.transform(diabetes_test_spark_df) |
| 53 | +print(f"regressor2 rmse={regressor_evaluator.evaluate(transformed_diabetes_test_spark_df2)}") |
| 54 | + |
| 55 | + |
| 56 | +# load iris dataset (classification dataset) |
| 57 | +iris_X, iris_y = sklearn.datasets.load_iris(return_X_y=True) |
| 58 | +iris_X_train, iris_X_test, iris_y_train, iris_y_test = \ |
| 59 | + train_test_split(iris_X, iris_y, test_size=0.3, shuffle=True) |
| 60 | + |
| 61 | +iris_train_spark_df = create_spark_df(iris_X_train, iris_y_train) |
| 62 | +iris_test_spark_df = create_spark_df(iris_X_test, iris_y_test) |
| 63 | + |
| 64 | +# train xgboost classifier model |
| 65 | +xgb_classifier = SparkXGBClassifier(max_depth=5) |
| 66 | +xgb_classifier_model = xgb_classifier.fit(iris_train_spark_df) |
| 67 | + |
| 68 | +transformed_iris_test_spark_df = xgb_classifier_model.transform(iris_test_spark_df) |
| 69 | +classifier_evaluator = MulticlassClassificationEvaluator(metricName="f1") |
| 70 | +print(f"classifier f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df)}") |
| 71 | + |
| 72 | +iris_train_spark_df2 = iris_train_spark_df.withColumn( |
| 73 | + "validationIndicatorCol", rand(1) > 0.7 |
| 74 | +) |
| 75 | + |
| 76 | +# train xgboost classifier model with validation dataset |
| 77 | +xgb_classifier2 = SparkXGBClassifier(max_depth=5, validation_indicator_col="validationIndicatorCol") |
| 78 | +xgb_classifier_model2 = xgb_classifier.fit(iris_train_spark_df2) |
| 79 | +transformed_iris_test_spark_df2 = xgb_classifier_model2.transform(iris_test_spark_df) |
| 80 | +print(f"classifier2 f1={classifier_evaluator.evaluate(transformed_iris_test_spark_df2)}") |
| 81 | + |
| 82 | +spark.stop() |
0 commit comments