In [None]:
!pip install pyspark

Collecting pyspark
  Downloading pyspark-3.5.1.tar.gz (317.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.0/317.0 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.5.1-py2.py3-none-any.whl size=317488491 sha256=00d2d83e2bf08fb1c3f827896231cbb7e1018c4a3216271eb8e65607a51d9b37
  Stored in directory: /root/.cache/pip/wheels/80/1d/60/2c256ed38dddce2fdd93be545214a63e02fbd8d74fb0b7f3a6
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.5.1


In [None]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, VectorAssembler, StandardScaler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml import Pipeline

spark = SparkSession.builder \
    .appName("Customer Churn Prediction") \
    .getOrCreate()

df = spark.read.csv("/content/BDA-Project.csv", header=True, inferSchema=True)
df = df.drop("Satisfaction Score", "Churn Category", "Churn Reason")
categorical_columns = ["City", "Gender", "Senior Citizen", "Married", "Dependents", "Phone Service",
                       "Multiple Lines", "Internet Service", "Internet Type", "Online Security",
                       "Online Backup", "Device Protection Plan", "Premium Tech Support",
                       "Streaming TV", "Streaming Movies", "Streaming Music", "Unlimited Data",
                       "Contract", "Paperless Billing", "Payment Method"]

indexers = [StringIndexer(inputCol=column, outputCol=column+"_index").fit(df) for column in categorical_columns]
assembler = VectorAssembler(inputCols=[column+"_index" for column in categorical_columns] +
                            ["Age", "Number of Dependents", "Tenure in Months", "Total Revenue"],
                            outputCol="features")
scaler = StandardScaler(inputCol="features", outputCol="scaled_features")
label_indexer = StringIndexer(inputCol="Churn Label", outputCol="label")
pipeline = Pipeline(stages=indexers + [assembler, scaler, label_indexer])

prepared_df = pipeline.fit(df).transform(df)
train_df, test_df = prepared_df.randomSplit([0.8, 0.2], seed=42)
rf = RandomForestClassifier(featuresCol="scaled_features", labelCol="label", seed=42)
paramGrid = ParamGridBuilder().addGrid(rf.numTrees, [10, 20]).addGrid(rf.maxDepth, [5, 10]).build()
evaluator = BinaryClassificationEvaluator(labelCol="label")
cv = CrossValidator(estimator=rf, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=5)

cv_model = cv.fit(train_df)

predictions_test = cv_model.transform(test_df)
accuracy = evaluator.evaluate(predictions_test)
print(f"Test Accuracy: {accuracy}")

rf_model = cv_model.bestModel
print("Feature Importances: ", rf_model.featureImportances)

predictions_test.select("Customer ID", "Churn Label", "prediction").show()

Test Accuracy: 0.8821346461415441
Feature Importances:  (24,[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23],[0.06549499439107452,0.0101180157535006,0.031199235699997895,0.012497497128369688,0.04307187692174649,0.006078896195913981,0.012785448510790266,0.03561998592277536,0.0713194377953795,0.022929300482456717,0.011365409294024787,0.006291488260067146,0.021727482276040256,0.014688205254894204,0.011125927276829284,0.015032941782811029,0.017099020779312656,0.19399416177132295,0.018710725931583738,0.036841322110703414,0.06778868554253312,0.040669995192406584,0.1335228519948099,0.10002709373065585])
+-----------+-----------+----------+
|Customer ID|Churn Label|prediction|
+-----------+-----------+----------+
| 0004-TLHLJ|        Yes|       1.0|
| 0013-SMEOE|         No|       0.0|
| 0015-UOCOJ|         No|       0.0|
| 0019-EFAEP|         No|       0.0|
| 0023-HGHWL|        Yes|       1.0|
| 0030-FNXPP|         No|       0.0|
| 0042-RLHYP|         No|       0.0|
| 0057-QBUQ