# In [None]:
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler, StandardScaler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# End-to-end Spark ML pipeline example: index a categorical column, assemble
# and scale a feature vector, fit a random forest classifier, and report the
# accuracy of the fitted model on the data it was trained on.

# Create a Spark session
spark = SparkSession.builder.appName("MLPipelineExample").getOrCreate()

# Everything that uses the session runs inside try/finally so the session is
# stopped even when a pipeline stage, show(), or the evaluator raises.
try:
    # Sample DataFrame with a categorical feature, numerical features, and a label
    data = [
        (0, "a", 1.0, 3.0, 0.0),
        (1, "b", 2.0, 4.0, 1.0),
        (2, "a", 3.0, 2.0, 0.0),
        (3, "b", 4.0, 1.0, 1.0),
        (4, "a", 5.0, 3.5, 0.0),
    ]
    columns = ["id", "category", "feature1", "feature2", "label"]

    df = spark.createDataFrame(data, columns)

    # Stage 1: Index the categorical column 'category' into 'categoryIndex'
    indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")

    # Stage 2: Assemble the indexed category and numeric columns into one vector
    assembler = VectorAssembler(
        inputCols=["categoryIndex", "feature1", "feature2"],
        outputCol="features",
    )

    # Stage 3 (optional): Scale features to unit standard deviation.
    # withMean=False means the values are not centered before scaling.
    scaler = StandardScaler(
        inputCol="features",
        outputCol="scaledFeatures",
        withStd=True,
        withMean=False,
    )

    # Stage 4: Define the classifier (trained on the scaled features).
    # A fixed seed keeps the forest reproducible between runs.
    rf = RandomForestClassifier(
        featuresCol="scaledFeatures",
        labelCol="label",
        predictionCol="prediction",
        numTrees=20,
        maxDepth=5,
        seed=42,
    )

    # Combine all stages into a pipeline; stages run in the listed order
    pipeline = Pipeline(stages=[indexer, assembler, scaler, rf])

    # Fit the pipeline on the data (this trains the model)
    pipelineModel = pipeline.fit(df)

    # Use the trained pipeline to make predictions on the same DataFrame
    predictions = pipelineModel.transform(df)
    predictions.select(
        "id",
        "category",
        "features",
        "scaledFeatures",
        "label",
        "prediction",
        "probability",
    ).show()

    # Evaluate the model's accuracy.
    # NOTE: `predictions` comes from the DataFrame the pipeline was fitted on,
    # so this is training accuracy, not an estimate of performance on unseen
    # data. Evaluate on a held-out split to obtain a real test accuracy.
    evaluator = MulticlassClassificationEvaluator(
        labelCol="label", predictionCol="prediction", metricName="accuracy"
    )
    accuracy = evaluator.evaluate(predictions)
    print("Training Accuracy:", accuracy)
finally:
    # Stop the Spark session
    spark.stop()
