In [None]:
"""
Lesson 5: Machine Learning Pipelines in PySpark
===============================================

Author: Deb
Date:   2024-06-10

Description:
------------
End-to-end supervised classification example using PySpark ML pipelines on the Iris dataset.

Key Concepts:
-------------
1. Feature transformation with VectorAssembler & StandardScaler
2. Training a Logistic Regression classifier
3. Building a Pipeline with multiple stages
4. Model evaluation using accuracy metrics
5. Hyperparameter tuning via CrossValidator
6. Saving and loading trained models

Dataset:
--------
Iris dataset (UCI Machine Learning Repository)
https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv

Run:
----
$ source venv/bin/activate
$ python src/lesson5_ml_pipeline.py
"""

from pathlib import Path
import urllib.request
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler, StandardScaler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml import PipelineModel


# =========================================================
# 1. Initialize Spark Session
# =========================================================
# Local-mode session named "lesson5-ml-pipeline"; getOrCreate() hands back an
# already-active session if one exists instead of building a second one.
spark = (
    SparkSession.builder
    .appName("lesson5-ml-pipeline")
    .master("local[*]")
    .getOrCreate()
)

# Status marker is the U+2705 check-mark emoji; this file must be saved and
# read as UTF-8 for it to display correctly.
print("✅ Spark Session Started.")
print("Spark Version:", spark.version)
print("-" * 60)


# =========================================================
# 2. Download and Load Dataset
# =========================================================
# Fetch the CSV onto local disk first, then have Spark read it from that path.
iris_url = "https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv"
local_path = Path("/tmp/iris.csv")
urllib.request.urlretrieve(iris_url, local_path)

# The first row carries the column names; column types are inferred from the data.
iris_df = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv(str(local_path))
)
print("Schema:")
iris_df.printSchema()
iris_df.show(5)

# =========================================================
# 3. Data Preparation
# =========================================================
# Stage 1: index the string column 'species' into a numeric 'label' column.
label_indexer = StringIndexer(
    inputCol="species",
    outputCol="label",
)

# Stage 2: pack the four measurement columns into one vector column.
feature_cols = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
assembler = VectorAssembler(
    inputCols=feature_cols,
    outputCol="features_assembled",
)

# Stage 3: standardize the assembled vector (both mean-centering and
# scaling by standard deviation are switched on) into the 'features' column.
scaler = StandardScaler(
    inputCol="features_assembled",
    outputCol="features",
    withMean=True,
    withStd=True,
)

# =========================================================
# 4. Define Model (Estimator)
# =========================================================
# Logistic regression reading the 'features' vector and the numeric 'label',
# limited to 20 optimizer iterations.
lr = LogisticRegression(
    featuresCol="features",
    labelCol="label",
    maxIter=20,
)

# =========================================================
# 5. Build ML Pipeline
# =========================================================
# Stages run in list order: index label -> assemble -> scale -> classify.
pipeline = Pipeline(
    stages=[
        label_indexer,
        assembler,
        scaler,
        lr,
    ]
)

# =========================================================
# 6. Split Data into Train/Test
# =========================================================
# 80/20 random split; the fixed seed keeps the partition repeatable.
train_df, test_df = iris_df.randomSplit([0.8, 0.2], seed=42)

# Each count() is a separate Spark action; training rows are counted first.
n_train = train_df.count()
n_test = test_df.count()
print(f"Training samples: {n_train}, Test samples: {n_test}")

# =========================================================
# 7. Fit Model
# =========================================================
# Fit all pipeline stages on the training split, then score the test split.
model = pipeline.fit(train_df)
predictions = model.transform(test_df)

# Show a handful of rows: original class, indexed label, and model output.
print("=== Predictions Sample ===")
sample_view = predictions.select("species", "label", "prediction", "probability")
sample_view.show(5, truncate=False)

# =========================================================
# 8. Evaluate Model
# =========================================================
# Accuracy evaluator comparing the 'prediction' column against 'label'.
# The same evaluator object is reused later as the cross-validation metric.
evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="accuracy",
)

# Accuracy of the single (untuned) pipeline fit on the test split.
accuracy = evaluator.evaluate(predictions)
print(f"Model Accuracy = {accuracy:.4f}")
print("-" * 60)

# =========================================================
# 9. Hyperparameter Tuning (Cross Validation)
# =========================================================
# 3 x 3 grid (9 candidates) over the regularization strength and the
# elastic-net mixing parameter of the logistic regression stage.
param_grid = (
    ParamGridBuilder()
    .addGrid(lr.regParam, [0.01, 0.1, 0.5])
    .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])
    .build()
)

# The whole pipeline is the estimator, so indexing/assembling/scaling are
# re-fit inside every fold. The seed pins the fold assignment so repeated runs
# tune on identical folds, matching the fixed seed of the train/test split.
crossval = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    numFolds=3,
    seed=42,
)

cv_model = crossval.fit(train_df)

# transform() on the fitted CrossValidatorModel applies the best model found,
# so the value printed below is that model's accuracy on the test split.
cv_predictions = cv_model.transform(test_df)
cv_accuracy = evaluator.evaluate(cv_predictions)
print(f"Cross-Validated Accuracy = {cv_accuracy:.4f}")

# =========================================================
# 10. Save & Load Model
# =========================================================
# Persist the best pipeline chosen by cross-validation; overwrite() replaces
# any model already stored at this (working-directory-relative) path.
model_path = "output/iris_logreg_pipeline"
cv_model.bestModel.write().overwrite().save(model_path)
# Status markers are the U+2705 check-mark emoji (file must be UTF-8).
print(f"✅ Model saved to {model_path}")

# Read the saved pipeline back to confirm it can be restored from disk.
loaded_model = PipelineModel.load(model_path)
print("✅ Model pipeline reloaded successfully.")

# =========================================================
# 11. Stop Spark
# =========================================================
# Shut the session down once all work is finished.
spark.stop()
# Leading marker is the U+1F9E0 brain emoji (file must be UTF-8).
print("🧠 Spark session stopped. End of Lesson 5.")


✅ Spark Session Started.
Spark Version: 3.5.7
------------------------------------------------------------
Schema:
root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- species: string (nullable = true)

+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|species|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2| setosa|
|         4.9|        3.0|         1.4|        0.2| setosa|
|         4.7|        3.2|         1.3|        0.2| setosa|
|         4.6|        3.1|         1.5|        0.2| setosa|
|         5.0|        3.6|         1.4|        0.2| setosa|
+------------+-----------+------------+-----------+-------+
only showing top 5 rows

Training samples: 126, Test samples: 24
=== Predictions Sample ===
+-------+-----+----------+------