# Imports

In [1]:
# from pyspark.sql import SparkSession
# from pyspark.sql.functions import round,avg
# from pyspark.sql.functions import min
# from pyspark.ml.feature import VectorAssembler
# from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, RandomForestClassifier
# from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Data Analysis

In [2]:
from pyspark.sql import SparkSession

In [3]:
spark=SparkSession.builder.appName('RevisionApp').getOrCreate()
spark

In [4]:
df = spark.read.csv('diabetes.csv',header = True, inferSchema=True)

### Mam's Analysis

In [5]:
df.show()

+-----------+-------+-------------+-------------+-------+----+------------------------+---+-------+
|Pregnancies|Glucose|BloodPressure|SkinThickness|Insulin| BMI|DiabetesPedigreeFunction|Age|Outcome|
+-----------+-------+-------------+-------------+-------+----+------------------------+---+-------+
|          6|    148|           72|           35|      0|33.6|                   0.627| 50|      1|
|          1|     85|           66|           29|      0|26.6|                   0.351| 31|      0|
|          8|    183|           64|            0|      0|23.3|                   0.672| 32|      1|
|          1|     89|           66|           23|     94|28.1|                   0.167| 21|      0|
|          0|    137|           40|           35|    168|43.1|                   2.288| 33|      1|
|          5|    116|           74|            0|      0|25.6|                   0.201| 30|      0|
|          3|     78|           50|           32|     88|31.0|                   0.248| 26|      1|


In [6]:
df.count()

768

In [7]:
len(df.columns)

9

In [8]:
out_df=df.groupBy('Outcome').count()


In [9]:
out_df['count']/df.count()

Column<'/(count, 768)'>

In [10]:
agg_df=df.groupBy('Outcome').avg('age')
agg_df.withColumnRenamed('avg(age)','Average').show()

+-------+-----------------+
|Outcome|          Average|
+-------+-----------------+
|      1|37.06716417910448|
|      0|            31.19|
+-------+-----------------+




### Practice Questions

#### Basic Data Exploration Questions

Question 1: Data Loading and Schema Exploration

Load the diabetes dataset and display basic information
Tasks:
1. Load the CSV file into a PySpark DataFrame
2. Display the schema
3. Show the first 5 rows
4. Print the total number of records

In [11]:
df.printSchema()

root
 |-- Pregnancies: integer (nullable = true)
 |-- Glucose: integer (nullable = true)
 |-- BloodPressure: integer (nullable = true)
 |-- SkinThickness: integer (nullable = true)
 |-- Insulin: integer (nullable = true)
 |-- BMI: double (nullable = true)
 |-- DiabetesPedigreeFunction: double (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Outcome: integer (nullable = true)



In [12]:
df.schema

StructType([StructField('Pregnancies', IntegerType(), True), StructField('Glucose', IntegerType(), True), StructField('BloodPressure', IntegerType(), True), StructField('SkinThickness', IntegerType(), True), StructField('Insulin', IntegerType(), True), StructField('BMI', DoubleType(), True), StructField('DiabetesPedigreeFunction', DoubleType(), True), StructField('Age', IntegerType(), True), StructField('Outcome', IntegerType(), True)])

In [13]:
df.show(5)

+-----------+-------+-------------+-------------+-------+----+------------------------+---+-------+
|Pregnancies|Glucose|BloodPressure|SkinThickness|Insulin| BMI|DiabetesPedigreeFunction|Age|Outcome|
+-----------+-------+-------------+-------------+-------+----+------------------------+---+-------+
|          6|    148|           72|           35|      0|33.6|                   0.627| 50|      1|
|          1|     85|           66|           29|      0|26.6|                   0.351| 31|      0|
|          8|    183|           64|            0|      0|23.3|                   0.672| 32|      1|
|          1|     89|           66|           23|     94|28.1|                   0.167| 21|      0|
|          0|    137|           40|           35|    168|43.1|                   2.288| 33|      1|
+-----------+-------+-------------+-------------+-------+----+------------------------+---+-------+
only showing top 5 rows


Q2. Basic Statistics:
Calculate and display summary statistics (count, mean, stddev, min, max) for the numerical columns (quantity and price).


In [14]:
df.count()

768

Question 2: Basic Statistics

Calculate basic statistics for numerical columns
Tasks:
1. Show summary statistics (count, mean, stddev, min, max)
2. Find the average age of patients

In [15]:
df.describe().show()

+-------+------------------+-----------------+------------------+------------------+------------------+------------------+------------------------+------------------+------------------+
|summary|       Pregnancies|          Glucose|     BloodPressure|     SkinThickness|           Insulin|               BMI|DiabetesPedigreeFunction|               Age|           Outcome|
+-------+------------------+-----------------+------------------+------------------+------------------+------------------+------------------------+------------------+------------------+
|  count|               768|              768|               768|               768|               768|               768|                     768|               768|               768|
|   mean|3.8450520833333335|     120.89453125|       69.10546875|20.536458333333332| 79.79947916666667|31.992578124999977|      0.4718763020833327|33.240885416666664|0.3489583333333333|
| stddev|  3.36957806269887|31.97261819513622|19.355807170644777|15.95

In [16]:
df.groupBy("Outcome").avg("Age").show()

+-------+-----------------+
|Outcome|         avg(Age)|
+-------+-----------------+
|      1|37.06716417910448|
|      0|            31.19|
+-------+-----------------+



#### Data Analysis and Aggregation Questions
Question 5: GroupBy Operations

Analyze data using groupby operations
Tasks:
1. Find average glucose level by outcome
2. Count patients by age group and diabetes outcome
3. Calculate average BMI for diabetic vs non-diabetic patients
4. Find the maximum glucose level for each individual age

In [17]:
df.show()

+-----------+-------+-------------+-------------+-------+----+------------------------+---+-------+
|Pregnancies|Glucose|BloodPressure|SkinThickness|Insulin| BMI|DiabetesPedigreeFunction|Age|Outcome|
+-----------+-------+-------------+-------------+-------+----+------------------------+---+-------+
|          6|    148|           72|           35|      0|33.6|                   0.627| 50|      1|
|          1|     85|           66|           29|      0|26.6|                   0.351| 31|      0|
|          8|    183|           64|            0|      0|23.3|                   0.672| 32|      1|
|          1|     89|           66|           23|     94|28.1|                   0.167| 21|      0|
|          0|    137|           40|           35|    168|43.1|                   2.288| 33|      1|
|          5|    116|           74|            0|      0|25.6|                   0.201| 30|      0|
|          3|     78|           50|           32|     88|31.0|                   0.248| 26|      1|


In [18]:
df.groupBy("Outcome").avg("Glucose").alias("Average_Glucose").show()

+-------+------------------+
|Outcome|      avg(Glucose)|
+-------+------------------+
|      1|141.25746268656715|
|      0|            109.98|
+-------+------------------+



In [19]:
from pyspark .sql.functions import * 

df.groupBy("Outcome").agg(avg("Glucose").alias("Average_Glucose")).show()

+-------+------------------+
|Outcome|   Average_Glucose|
+-------+------------------+
|      1|141.25746268656715|
|      0|            109.98|
+-------+------------------+



In [20]:
df.withColumn(
    "AgeGroup",
    when(col("Age")<30, "Young")
    .when((col("Age")>=30) & (col("Age")<50), "Middle-Aged")
    .otherwise("Senior")
).groupBy("AgeGroup","Outcome") \
 .agg(count("*").alias("Patient_Count")) \
 .orderBy("AgeGroup","Outcome").show()

+-----------+-------+-------------+
|   AgeGroup|Outcome|Patient_Count|
+-----------+-------+-------------+
|Middle-Aged|      0|          142|
|Middle-Aged|      1|          141|
|     Senior|      0|           46|
|     Senior|      1|           43|
|      Young|      0|          312|
|      Young|      1|           84|
+-----------+-------+-------------+



In [21]:
df.groupBy("Outcome").avg("BMI").show()

+-------+-----------------+
|Outcome|         avg(BMI)|
+-------+-----------------+
|      1|35.14253731343278|
|      0|30.30419999999996|
+-------+-----------------+



In [22]:
df_category=df.withColumn(
    "AgeGroup",
    when(col("Age")<30, "Young")
    .when((col("Age")>=30) & (col("Age")<50), "Middle-Aged")
    .otherwise("Senior")
)

df_category.groupBy("AgeGroup").max("Glucose").show()

+-----------+------------+
|   AgeGroup|max(Glucose)|
+-----------+------------+
|     Senior|         197|
|Middle-Aged|         197|
|      Young|         199|
+-----------+------------+



Question 6: Filtering and Conditional Analysis

Filter data based on conditions
Tasks:
1. Find patients with glucose level > 195
2. Count diabetic patients with BMI > 30
3. Find young patients (age < 35) with diabetes

df.filter(df.Glucose>195).show()

In [23]:
df.filter((df.Outcome==1) & (df.BMI>30)).count()

215

In [24]:
df.filter((df.Age<35) & (df.Outcome==1)).show()

+-----------+-------+-------------+-------------+-------+----+------------------------+---+-------+
|Pregnancies|Glucose|BloodPressure|SkinThickness|Insulin| BMI|DiabetesPedigreeFunction|Age|Outcome|
+-----------+-------+-------------+-------------+-------+----+------------------------+---+-------+
|          8|    183|           64|            0|      0|23.3|                   0.672| 32|      1|
|          0|    137|           40|           35|    168|43.1|                   2.288| 33|      1|
|          3|     78|           50|           32|     88|31.0|                   0.248| 26|      1|
|         10|    168|           74|            0|      0|38.0|                   0.537| 34|      1|
|          7|    100|            0|            0|      0|30.0|                   0.484| 32|      1|
|          0|    118|           84|           47|    230|45.8|                   0.551| 31|      1|
|          7|    107|           74|            0|      0|29.6|                   0.254| 31|      1|


# ML

## Load the Data

In [25]:
# from pyspark.sql import SparkSession
# spark = SparkSession.builder.appName("ML").getOrCreate()
# df=spark.read.csv("diabetes.csv",header=True,inferSchema=True)
# df.show()

In [26]:
df.columns

['Pregnancies',
 'Glucose',
 'BloodPressure',
 'SkinThickness',
 'Insulin',
 'BMI',
 'DiabetesPedigreeFunction',
 'Age',
 'Outcome']

## Data Preparation

In [27]:
from pyspark.ml.feature import VectorAssembler

In [28]:
assembler = VectorAssembler(inputCols=['Pregnancies',
                                        'Glucose',
                                         'BloodPressure',
                                         'SkinThickness',
                                         'Insulin',
                                         'BMI',
                                         'DiabetesPedigreeFunction',
                                         'Age'],outputCol="features")
all_df=assembler.transform(df)
all_df.show(truncate=False)

+-----------+-------+-------------+-------------+-------+----+------------------------+---+-------+-------------------------------------------+
|Pregnancies|Glucose|BloodPressure|SkinThickness|Insulin|BMI |DiabetesPedigreeFunction|Age|Outcome|features                                   |
+-----------+-------+-------------+-------------+-------+----+------------------------+---+-------+-------------------------------------------+
|6          |148    |72           |35           |0      |33.6|0.627                   |50 |1      |[6.0,148.0,72.0,35.0,0.0,33.6,0.627,50.0]  |
|1          |85     |66           |29           |0      |26.6|0.351                   |31 |0      |[1.0,85.0,66.0,29.0,0.0,26.6,0.351,31.0]   |
|8          |183    |64           |0            |0      |23.3|0.672                   |32 |1      |[8.0,183.0,64.0,0.0,0.0,23.3,0.672,32.0]   |
|1          |89     |66           |23           |94     |28.1|0.167                   |21 |0      |[1.0,89.0,66.0,23.0,94.0,28.1,0.167,2

In [29]:
assembled_df=assembler.transform(df).select("features","Outcome")
assembled_df.show(truncate=False)

+-------------------------------------------+-------+
|features                                   |Outcome|
+-------------------------------------------+-------+
|[6.0,148.0,72.0,35.0,0.0,33.6,0.627,50.0]  |1      |
|[1.0,85.0,66.0,29.0,0.0,26.6,0.351,31.0]   |0      |
|[8.0,183.0,64.0,0.0,0.0,23.3,0.672,32.0]   |1      |
|[1.0,89.0,66.0,23.0,94.0,28.1,0.167,21.0]  |0      |
|[0.0,137.0,40.0,35.0,168.0,43.1,2.288,33.0]|1      |
|[5.0,116.0,74.0,0.0,0.0,25.6,0.201,30.0]   |0      |
|[3.0,78.0,50.0,32.0,88.0,31.0,0.248,26.0]  |1      |
|[10.0,115.0,0.0,0.0,0.0,35.3,0.134,29.0]   |0      |
|[2.0,197.0,70.0,45.0,543.0,30.5,0.158,53.0]|1      |
|[8.0,125.0,96.0,0.0,0.0,0.0,0.232,54.0]    |1      |
|[4.0,110.0,92.0,0.0,0.0,37.6,0.191,30.0]   |0      |
|[10.0,168.0,74.0,0.0,0.0,38.0,0.537,34.0]  |1      |
|[10.0,139.0,80.0,0.0,0.0,27.1,1.441,57.0]  |0      |
|[1.0,189.0,60.0,23.0,846.0,30.1,0.398,59.0]|1      |
|[5.0,166.0,72.0,19.0,175.0,25.8,0.587,51.0]|1      |
|[7.0,100.0,0.0,0.0,0.0,30.0

## Train & Test Data Split

In [30]:
train_df,test_df=assembled_df.randomSplit([0.8,0.2],seed=42)

## Model Building

### Logistic Regression

In [31]:
from pyspark.ml.classification import LogisticRegression
lr = LogisticRegression(featuresCol="features",labelCol="Outcome")
model_lr=lr.fit(train_df)

In [32]:
lr_train_predictions = model_lr.transform(train_df)
lr_test_predictions = model_lr.transform(test_df)

### Decision Tree

In [33]:
from pyspark.ml.classification import DecisionTreeClassifier
dt = DecisionTreeClassifier(featuresCol="features",labelCol="Outcome")
model_dt=dt.fit(train_df)

In [34]:
dt_train_predictions=model_dt.transform(train_df)
dt_test_predictions=model_dt.transform(test_df)

### Random Forest

In [35]:
from pyspark.ml.classification import RandomForestClassifier
rf = RandomForestClassifier(featuresCol="features",labelCol="Outcome")
model_rf=rf.fit(train_df)

In [36]:
rf_train_predictions=model_rf.transform(train_df)
rf_test_predictions=model_rf.transform(test_df)

## Model Evaluation

In [37]:
### Logistic Regression

In [38]:
lr_train_predictions.show()

+--------------------+-------+--------------------+--------------------+----------+
|            features|Outcome|       rawPrediction|         probability|prediction|
+--------------------+-------+--------------------+--------------------+----------+
|(8,[0,1,6,7],[2.0...|      0|[4.81560259615597...|[0.99196278208874...|       0.0|
|(8,[0,1,6,7],[2.0...|      0|[4.31519705947199...|[0.98681232326030...|       0.0|
|(8,[0,1,6,7],[6.0...|      0|[2.94161295270562...|[0.94986559290546...|       0.0|
|(8,[0,1,6,7],[7.0...|      0|[3.02623474541700...|[0.95374535166608...|       0.0|
|(8,[0,1,6,7],[10....|      1|[2.30954296151672...|[0.90966430511917...|       0.0|
|(8,[1,5,6,7],[99....|      0|[1.94905417494776...|[0.87534347259153...|       0.0|
|(8,[1,5,6,7],[119...|      1|[0.75960133175394...|[0.68126717234068...|       0.0|
|(8,[1,5,6,7],[131...|      1|[-0.7134310789887...|[0.32884114089827...|       1.0|
|(8,[1,5,6,7],[138...|      1|[-0.9853740213663...|[0.27182676275860...|    

In [39]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
accuracy_evaluator = MulticlassClassificationEvaluator(predictionCol='prediction',labelCol='Outcome',metricName='accuracy')
f1_evaluator=MulticlassClassificationEvaluator(predictionCol='prediction',labelCol='Outcome',metricName='f1')
precision_evaluator=MulticlassClassificationEvaluator(predictionCol='prediction',labelCol='Outcome',metricName='weightedPrecision')
recall_evaluator=MulticlassClassificationEvaluator(predictionCol='prediction',labelCol='Outcome',metricName='weightedRecall')

def evalaute_model(model_predictions):
    print("Accuracy = ", accuracy_evaluator.evaluate(model_predictions))
    print("F1 Score = ", f1_evaluator.evaluate(model_predictions))
    print("Precision = ", precision_evaluator.evaluate(model_predictions))
    print("Recall = ", recall_evaluator.evaluate(model_predictions))
    model_predictions.groupBy("Outcome","prediction").count().show()

### Logistic Regression

In [40]:
print("Logistic Regression Training Performance")
print("-------------------------")
evalaute_model(lr_train_predictions)

Logistic Regression Training Performance
-------------------------
Accuracy =  0.7736434108527132
F1 Score =  0.7669366138447278
Precision =  0.7689769543372289
Recall =  0.7736434108527132
+-------+----------+-----+
|Outcome|prediction|count|
+-------+----------+-----+
|      1|       0.0|   96|
|      0|       0.0|  366|
|      1|       1.0|  133|
|      0|       1.0|   50|
+-------+----------+-----+



In [41]:
print("Logistic Regression Testing Performance")
print("-------------------------")
evalaute_model(lr_test_predictions)

Logistic Regression Testing Performance
-------------------------
Accuracy =  0.8211382113821138
F1 Score =  0.8167694607980562
Precision =  0.8169007144616901
Recall =  0.8211382113821138
+-------+----------+-----+
|Outcome|prediction|count|
+-------+----------+-----+
|      1|       0.0|   14|
|      0|       0.0|   76|
|      1|       1.0|   25|
|      0|       1.0|    8|
+-------+----------+-----+



### Decision Tree

In [42]:
print("Decision Tree Training Performance")
print("-------------------------")
evalaute_model(dt_train_predictions)

Decision Tree Training Performance
-------------------------
Accuracy =  0.8248062015503876
F1 Score =  0.821094438169818
Precision =  0.822801102817593
Recall =  0.8248062015503875
+-------+----------+-----+
|Outcome|prediction|count|
+-------+----------+-----+
|      1|       0.0|   74|
|      0|       0.0|  377|
|      1|       1.0|  155|
|      0|       1.0|   39|
+-------+----------+-----+



In [43]:
print("Decision Tree Training Performance")
print("-------------------------")
evalaute_model(dt_test_predictions)

Decision Tree Training Performance
-------------------------
Accuracy =  0.7967479674796748
F1 Score =  0.7960294864610061
Precision =  0.7954239975836291
Recall =  0.7967479674796748
+-------+----------+-----+
|Outcome|prediction|count|
+-------+----------+-----+
|      1|       0.0|   13|
|      0|       0.0|   72|
|      1|       1.0|   26|
|      0|       1.0|   12|
+-------+----------+-----+



### Random Forest

In [44]:
print("Random Forest Training Performance")
print("-------------------------")
evalaute_model(rf_train_predictions)

Random Forest Training Performance
-------------------------
Accuracy =  0.8310077519379845
F1 Score =  0.8253000826205281
Precision =  0.8316707409896288
Recall =  0.8310077519379845
+-------+----------+-----+
|Outcome|prediction|count|
+-------+----------+-----+
|      1|       0.0|   80|
|      0|       0.0|  387|
|      1|       1.0|  149|
|      0|       1.0|   29|
+-------+----------+-----+



In [45]:
print("Decision Tree Testing Performance")
print("-------------------------")
evalaute_model(rf_test_predictions)

Decision Tree Testing Performance
-------------------------
Accuracy =  0.8130081300813008
F1 Score =  0.8109085722436171
Precision =  0.809877581534436
Recall =  0.8130081300813009
+-------+----------+-----+
|Outcome|prediction|count|
+-------+----------+-----+
|      1|       0.0|   13|
|      0|       0.0|   74|
|      1|       1.0|   26|
|      0|       1.0|   10|
+-------+----------+-----+



## Hyperparameter Tuning

In [47]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# 1️⃣ Create Logistic Regression model
lr = LogisticRegression(featuresCol="features", labelCol="Outcome")

# 2️⃣ Build parameter grid for number of iterations (maxIter)
param_grid = ParamGridBuilder() \
    .addGrid(lr.maxIter, [10, 20, 30, 40, 50, 100]) \
    .build()

# 3️⃣ Define evaluator (accuracy metric)
accuracy_evaluator = MulticlassClassificationEvaluator(
    predictionCol="prediction",
    labelCol="Outcome",
    metricName="accuracy"
)

# 4️⃣ Create CrossValidator (5-fold cross validation)
lr_cv = CrossValidator(
    estimator=lr,
    estimatorParamMaps=param_grid,
    evaluator=accuracy_evaluator,
    numFolds=5,
    seed=10
)

# 5️⃣ Fit model using cross-validation
lr_cv_model = lr_cv.fit(train_df)

# 6️⃣ Retrieve the best model
best_lr_model = lr_cv_model.bestModel

# 7️⃣ Display best parameter and performance
print("Best Logistic Regression Model Parameters:")
print(f"➡️ Best maxIter: {best_lr_model._java_obj.getMaxIter()}")
print(f"➡️ Cross-Validation Accuracies: {lr_cv_model.avgMetrics}")


Best Logistic Regression Model Parameters:
➡️ Best maxIter: 10
➡️ Cross-Validation Accuracies: [np.float64(0.7491237049885783), np.float64(0.7491237049885783), np.float64(0.7491237049885783), np.float64(0.7491237049885783), np.float64(0.7491237049885783), np.float64(0.7491237049885783)]


In [46]:
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

dt = DecisionTreeClassifier(featuresCol="features",labelCol="Outcome")
param_grid=ParamGridBuilder().addGrid(dt.maxDepth,range(2,11)).build()
accuracy_evaluator=MulticlassClassificationEvaluator(predictionCol='prediction',labelCol='Outcome',metricName='accuracy')
dt_cv=CrossValidator(estimator=dt,estimatorParamMaps=param_grid,evaluator=accuracy_evaluator,numFolds=5, seed=10)
dt_cv_model=dt_cv.fit(train_df)
dt_cv_model.bestModel

DecisionTreeClassificationModel: uid=DecisionTreeClassifier_bc5ea5b2087c, depth=2, numNodes=5, numClasses=2, numFeatures=8

In [50]:
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# 1️⃣ Create Decision Tree model
dt = DecisionTreeClassifier(featuresCol="features", labelCol="Outcome")

# 2️⃣ Build parameter grid for maxDepth
param_grid = ParamGridBuilder() \
    .addGrid(dt.maxDepth, [2, 4, 6, 8, 10]) \
    .build()

# 3️⃣ Define evaluator (accuracy metric)
accuracy_evaluator = MulticlassClassificationEvaluator(
    predictionCol="prediction",
    labelCol="Outcome",
    metricName="accuracy"
)

# 4️⃣ Create CrossValidator (5-fold cross validation)
dt_cv = CrossValidator(
    estimator=dt,
    estimatorParamMaps=param_grid,
    evaluator=accuracy_evaluator,
    numFolds=5,
    seed=10
)

# 5️⃣ Fit model using cross-validation
dt_cv_model = dt_cv.fit(train_df)

# 6️⃣ Retrieve the best model
best_dt_model = dt_cv_model.bestModel

# 7️⃣ Display best parameters and performance
print("Best Decision Tree Model Parameters:")
print(f"➡️ Best maxDepth: {best_dt_model._java_obj.getMaxDepth()}")
print(f"➡️ Cross-Validation Accuracies: {dt_cv_model.avgMetrics}")


Best Decision Tree Model Parameters:
➡️ Best maxDepth: 2
➡️ Cross-Validation Accuracies: [np.float64(0.7252121197538777), np.float64(0.7080243862070574), np.float64(0.6969546588329824), np.float64(0.6638690446244403), np.float64(0.664295580865803)]


In [49]:
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# 1️⃣ Create Random Forest model
rf = RandomForestClassifier(featuresCol="features", labelCol="Outcome", seed=10)

# 2️⃣ Build parameter grid for number of trees and max depth
param_grid = ParamGridBuilder() \
    .addGrid(rf.numTrees, [10, 20, 50, 100]) \
    .addGrid(rf.maxDepth, [3, 5, 7, 10]) \
    .build()

# 3️⃣ Define evaluator (accuracy metric)
accuracy_evaluator = MulticlassClassificationEvaluator(
    predictionCol="prediction",
    labelCol="Outcome",
    metricName="accuracy"
)

# 4️⃣ Create CrossValidator (5-fold CV)
rf_cv = CrossValidator(
    estimator=rf,
    estimatorParamMaps=param_grid,
    evaluator=accuracy_evaluator,
    numFolds=5,
    seed=10
)

# 5️⃣ Fit cross-validation model on training data
rf_cv_model = rf_cv.fit(train_df)

# 6️⃣ Retrieve best model
best_rf_model = rf_cv_model.bestModel

# 7️⃣ Display best parameters and accuracies
print("Best Random Forest Model Parameters:")
print(f"➡️ Best numTrees: {best_rf_model.getNumTrees}")
print(f"➡️ Best maxDepth: {best_rf_model.getOrDefault('maxDepth')}")
print(f"➡️ Cross-Validation Accuracies: {rf_cv_model.avgMetrics}")


Best Random Forest Model Parameters:
➡️ Best numTrees: 100
➡️ Best maxDepth: 10
➡️ Cross-Validation Accuracies: [np.float64(0.7346103901311908), np.float64(0.7517933479303515), np.float64(0.7527874325653492), np.float64(0.7409319671049417), np.float64(0.7549449768248643), np.float64(0.7552490206071689), np.float64(0.752095149663182), np.float64(0.7470167435525583), np.float64(0.7569513741393629), np.float64(0.7551865530986582), np.float64(0.7422022248421998), np.float64(0.7498078034987637), np.float64(0.7445539889158908), np.float64(0.7636994873544763), np.float64(0.7669262522484293), np.float64(0.7695311856322179)]
