In [1]:
"""
Lesson 5 (Advanced): Multiclass Gradient-Boosted Trees with One-vs-Rest (PySpark)
=================================================================================

Author: Deb
Date: 2024-06-10

Description:
------------
This script trains a multiclass classifier using Spark's GBTClassifier wrapped inside OneVsRest.
It demonstrates:
 - Data ingestion (download Iris CSV)
 - Feature engineering using VectorAssembler (and optional StandardScaler)
 - StringIndexer to convert categorical labels to numeric
 - OneVsRest wrapper to extend GBT (binary) into multiclass
 - Pipeline construction
 - Hyperparameter tuning with CrossValidator
 - Model saving and loading

Notes:
 - GBTClassifier in pyspark.ml is a binary classifier. Wrapping it with OneVsRest enables
   multiclass classification by training one binary GBT per class (one-versus-rest strategy).
 - For large-scale production, consider XGBoost or LightGBM integration for true distributed GBMs.
"""

from pathlib import Path
import urllib.request

from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler, StandardScaler
from pyspark.ml.classification import GBTClassifier, OneVsRest
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator


# ---------------------------
# 1) Initialize Spark Session
# ---------------------------
spark = SparkSession.builder \
    .appName("lesson5-gbt-onevsrest") \
    .master("local[*]") \
    .config("spark.sql.shuffle.partitions", "8") \
    .getOrCreate()

print("✅ Spark Session started")
print("Spark Version:", spark.version)
print("-" * 60)


# ---------------------------
# 2) Download & Load Dataset
# ---------------------------
iris_url = "https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv"
local_path = Path("/tmp/iris.csv")
print("⬇️  Downloading Iris dataset...")
urllib.request.urlretrieve(iris_url, str(local_path))
print(f"✅ Downloaded to {local_path}")

iris_df = spark.read.csv(str(local_path), header=True, inferSchema=True)
print("Schema:")
iris_df.printSchema()
iris_df.show(5)
print("-" * 60)


# ---------------------------
# 3) Preprocessing & Features
# ---------------------------
# 3.1 Convert species (string) to numeric label
label_indexer = StringIndexer(inputCol="species", outputCol="label")

# 3.2 Assemble features into a single vector
feature_cols = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features_assembled")

# 3.3 (Optional) Standard scaling - not strictly necessary for tree-based models,
# but kept here to show the pipeline pattern and for consistency with other models.
scaler = StandardScaler(inputCol="features_assembled", outputCol="features", withMean=True, withStd=True)


# ---------------------------
# 4) Define Base Binary Estimator (GBT) and One-vs-Rest wrapper
# ---------------------------
# GBTClassifier is binary; OneVsRest turns it into multiclass by training one binary classifier per class.
gbt = GBTClassifier(labelCol="label", featuresCol="features", maxIter=50, maxDepth=5, stepSize=0.2)

ovr = OneVsRest(classifier=gbt, labelCol="label", featuresCol="features")


# ---------------------------
# 5) Build Pipeline
# ---------------------------
pipeline = Pipeline(stages=[label_indexer, assembler, scaler, ovr])


# ---------------------------
# 6) Train/Test Split
# ---------------------------
train_df, test_df = iris_df.randomSplit([0.8, 0.2], seed=42)
print(f"Training samples: {train_df.count()}, Test samples: {test_df.count()}")
print("-" * 60)


# ---------------------------
# 7) Baseline fit (fast check)
# ---------------------------
print("Training baseline pipeline (no hyperparam tuning)...")
baseline_model = pipeline.fit(train_df)
baseline_preds = baseline_model.transform(test_df)

evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
baseline_acc = evaluator.evaluate(baseline_preds)
print(f"Baseline Accuracy: {baseline_acc:.4f}")
print("Sample predictions:")
baseline_preds.select("species", "label", "prediction").show(10)
print("-" * 60)


# ---------------------------
# 8) Hyperparameter Tuning with CrossValidator
# ---------------------------
# Build a parameter grid. Because OneVsRest wraps the GBT, we reference GBT params
# through the OneVsRest's classifier parameter path. To access nested params we use:
#   ovr.classifier.paramName  (but ParamGridBuilder accepts the Param object from the estimator)
#
# We fetch the param objects directly from the gbt instance used above for clarity.

param_grid = (
    ParamGridBuilder()
    .addGrid(gbt.maxDepth, [3, 5])        # tree depth
    .addGrid(gbt.maxIter, [20, 50])       # number of boosting iterations
    .addGrid(gbt.stepSize, [0.1, 0.2])    # learning rate
    .build()
)

# Note: CrossValidator internally will train OneVsRest models for each param combination.
crossval = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    numFolds=3,
    parallelism=2  # run up to 2 models in parallel (adjust to your machine)
)

print("Starting cross-validation (this may take some time)...")
cv_model = crossval.fit(train_df)
print("Cross-validation completed.")
print("-" * 60)


# ---------------------------
# 9) Evaluate Best Model
# ---------------------------
best_model = cv_model.bestModel
print("Best Model summary:")
# It's a PipelineModel; we can inspect stages or parameters if desired
# For brevity, show the params of the underlying GBT by traversing stages:
for i, stage in enumerate(best_model.stages):
    print(f"Stage {i}: {stage.uid} ({stage.__class__.__name__})")

# Make predictions on the test set
cv_preds = best_model.transform(test_df)
cv_acc = evaluator.evaluate(cv_preds)
print(f"Cross-Validated Model Accuracy on test set: {cv_acc:.4f}")

print("Sample predictions from best model:")
cv_preds.select("species", "label", "prediction").show(10)
print("-" * 60)


# ---------------------------
# 10) Save & Load the Best Model
# ---------------------------
model_path = "output/iris_gbt_onevsrest"
print(f"Saving best model to {model_path} ...")
best_model.write().overwrite().save(model_path)
print("Model saved.")

# Loading back (optional check)
from pyspark.ml import PipelineModel
loaded = PipelineModel.load(model_path)
print("Loaded model successfully. Running a quick predict check...")
loaded_preds = loaded.transform(test_df.limit(5)).select("species", "prediction")
loaded_preds.show()
print("-" * 60)


# ---------------------------
# 11) Cleanup
# ---------------------------
spark.stop()
print("Done. Spark session stopped.")


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/11/08 11:59:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


✅ Spark Session started
Spark Version: 3.5.7
------------------------------------------------------------
⬇️  Downloading Iris dataset...
✅ Downloaded to /tmp/iris.csv
Schema:
root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- species: string (nullable = true)

+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|species|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2| setosa|
|         4.9|        3.0|         1.4|        0.2| setosa|
|         4.7|        3.2|         1.3|        0.2| setosa|
|         4.6|        3.1|         1.5|        0.2| setosa|
|         5.0|        3.6|         1.4|        0.2| setosa|
+------------+-----------+------------+-----------+-------+
only showing top 5 rows

----------------------------------------

25/11/08 11:59:44 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
                                                                                

Baseline Accuracy: 0.9583
Sample predictions:


                                                                                

+----------+-----+----------+
|   species|label|prediction|
+----------+-----+----------+
|    setosa|  2.0|       2.0|
|    setosa|  2.0|       2.0|
|    setosa|  2.0|       2.0|
|    setosa|  2.0|       2.0|
|    setosa|  2.0|       2.0|
|versicolor|  0.0|       0.0|
|    setosa|  2.0|       2.0|
|    setosa|  2.0|       2.0|
|    setosa|  2.0|       2.0|
|versicolor|  0.0|       0.0|
+----------+-----+----------+
only showing top 10 rows

------------------------------------------------------------
Starting cross-validation (this may take some time)...


                                                                                

Cross-validation completed.
------------------------------------------------------------
Best Model summary:
Stage 0: StringIndexer_26d39cfd3601 (StringIndexerModel)
Stage 1: VectorAssembler_1d5ce6837e9c (VectorAssembler)
Stage 2: StandardScaler_f7ba0702dd28 (StandardScalerModel)
Stage 3: OneVsRestModel_746fd41be852 (OneVsRestModel)
Cross-Validated Model Accuracy on test set: 0.9583
Sample predictions from best model:
+----------+-----+----------+
|   species|label|prediction|
+----------+-----+----------+
|    setosa|  2.0|       2.0|
|    setosa|  2.0|       2.0|
|    setosa|  2.0|       2.0|
|    setosa|  2.0|       2.0|
|    setosa|  2.0|       2.0|
|versicolor|  0.0|       0.0|
|    setosa|  2.0|       2.0|
|    setosa|  2.0|       2.0|
|    setosa|  2.0|       2.0|
|versicolor|  0.0|       0.0|
+----------+-----+----------+
only showing top 10 rows

------------------------------------------------------------
Saving best model to output/iris_gbt_onevsrest ...


                                                                                

Model saved.




Loaded model successfully. Running a quick predict check...
+-------+----------+
|species|prediction|
+-------+----------+
| setosa|       2.0|
| setosa|       2.0|
| setosa|       2.0|
| setosa|       2.0|
| setosa|       2.0|
+-------+----------+

------------------------------------------------------------
Done. Spark session stopped.
