In [93]:
from utils.spark_config import get_spark_session
from configs.paths import ROOT_DIR

spark = get_spark_session()

# Data preparation

In [94]:
churn = spark.read.format("delta").load(str(ROOT_DIR / "data" / "bronze" / "customer_churn_bronze"))
churn.show()

+----------+------+-------------+-------+----------+------+------------+----------------+---------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------+----------------+--------------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|   MultipleLines|InternetService|     OnlineSecurity|       OnlineBackup|   DeviceProtection|        TechSupport|        StreamingTV|    StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+----------------+---------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------+----------------+--------------------+--------------+------------+-----+
|7590-VHVEG|Female|            0|    Yes|        No|     1|  

In [95]:
churn.describe().show()

+-------+----------+------+------------------+-------+----------+------------------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+------------------+------------------+-----+
|summary|customerID|gender|     SeniorCitizen|Partner|Dependents|            tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|    MonthlyCharges|      TotalCharges|Churn|
+-------+----------+------+------------------+-------+----------+------------------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+------------------+------------------+-----+
|  count|      7043|  7043|              7043|   7043|      7043|     

In [96]:
print(churn.count())
churn = churn.filter(~churn.TotalCharges.isNull())
print(churn.count())

7043
7043


# fixing columns dtypes

In [97]:
from pyspark.sql.functions import col, when
from pyspark.sql.types import IntegerType

churn = churn.withColumns(
    {
        'gender': when(col('gender') == 'Female', True).otherwise(False).cast(IntegerType()),
        'SeniorCitizen': when(col('SeniorCitizen') == 1, True).otherwise(False).cast(IntegerType()),
        'Partner': when(col('Partner') == 'Yes', True).otherwise(False).cast(IntegerType()),
        'Dependents': when(col('Dependents') == 'Yes', True).otherwise(False).cast(IntegerType()),
        'PhoneService': when(col('PhoneService') == 'Yes', True).otherwise(False).cast(IntegerType()),
        'MultipleLines': when(col('MultipleLines') == 'Yes', True).otherwise(False).cast(IntegerType()),
        'OnlineSecurity': when(col('OnlineSecurity') == 'Yes', True).otherwise(False).cast(IntegerType()),
        'OnlineBackup': when(col('OnlineBackup') == 'Yes', True).otherwise(False).cast(IntegerType()),
        'DeviceProtection': when(col('DeviceProtection') == 'Yes', True).otherwise(False).cast(IntegerType()),
        'TechSupport': when(col('TechSupport') == 'Yes', True).otherwise(False).cast(IntegerType()),
        'StreamingTV': when(col('StreamingTV') == 'Yes', True).otherwise(False).cast(IntegerType()),
        'StreamingMovies': when(col('StreamingMovies') == 'Yes', True).otherwise(False).cast(IntegerType()),
        'PaperlessBilling': when(col('PaperlessBilling') == 'Yes', True).otherwise(False).cast(IntegerType()),
        'TotalCharges': col('TotalCharges').try_cast('double'),
        'Churn': when(col('Churn') == 'Yes', 1).otherwise(0).cast(IntegerType()),
        

    }
)
churn.show(5)

+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+
|7590-VHVEG|     1|            0|      1|         0|     1|           0|            0|            DSL|             0|           1|               0|          0|          0|    

# Feature engineering

In [98]:
categorical_columns = ['InternetService', 'Contract', 'PaymentMethod']
numerical_columns = ['SeniorCitizen', 'Partner', 'Dependents', 'tenure', 'PhoneService',
                    'MultipleLines', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection',
                    'TechSupport', 'StreamingTV', 'StreamingMovies', 'PaperlessBilling',
                    'MonthlyCharges', 'TotalCharges', 'Churn']
for column in categorical_columns:
    display(column)
    churn.groupBy(column).count().orderBy('count').show()

'InternetService'

+---------------+-----+
|InternetService|count|
+---------------+-----+
|             No| 1526|
|            DSL| 2421|
|    Fiber optic| 3096|
+---------------+-----+



'Contract'

+--------------+-----+
|      Contract|count|
+--------------+-----+
|      One year| 1473|
|      Two year| 1695|
|Month-to-month| 3875|
+--------------+-----+



'PaymentMethod'

+--------------------+-----+
|       PaymentMethod|count|
+--------------------+-----+
|Credit card (auto...| 1522|
|Bank transfer (au...| 1544|
|        Mailed check| 1612|
|    Electronic check| 2365|
+--------------------+-----+



## Indexer

In [99]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder

indexers = StringIndexer(inputCols=categorical_columns, outputCols=[f"{col}_index" for col in categorical_columns], handleInvalid='skip')
indexers_model = indexers.fit(churn)
indexed_df = indexers_model.transform(churn)

indexed_df.show(5)

+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+---------------------+--------------+-------------------+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|InternetService_index|Contract_index|PaymentMethod_index|
+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+---------------------+--------------+-------------------+
|759

## encoder

In [100]:
ohe_cols = [f"{col}_index" for col in categorical_columns]
encoders = OneHotEncoder(inputCols=ohe_cols, outputCols=[f"{col}_ohe" for col in ohe_cols], handleInvalid='skip')
encoders = OneHotEncoder(inputCols=ohe_cols, outputCols=[f"{col}_ohe" for col in ohe_cols], handleInvalid='keep')
ohe_model = encoders.fit(indexed_df)
ohe_df = ohe_model.transform(indexed_df)
ohe_df.show(5)

+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+---------------------+--------------+-------------------+-------------------------+------------------+-----------------------+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|InternetService_index|Contract_index|PaymentMethod_index|InternetService_index_ohe|Contract_index_ohe|PaymentMethod_index_ohe|
+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+-----------

## vector assembler

In [101]:
from pyspark.ml.feature import VectorAssembler

vector_assembler = VectorAssembler(inputCols=[f"{col}_ohe" for col in ohe_cols] + numerical_columns, outputCol='features', handleInvalid='keep')
vector_df = vector_assembler.transform(ohe_df)
vector_df.show(5)

+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+---------------------+--------------+-------------------+-------------------------+------------------+-----------------------+--------------------+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|InternetService_index|Contract_index|PaymentMethod_index|InternetService_index_ohe|Contract_index_ohe|PaymentMethod_index_ohe|            features|
+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+---------

# Split data

In [None]:
target_col = 'Churn'
primary_key = 'customerID'

label_indexer = StringIndexer(inputCol=target_col, outputCol='label', handleInvalid='skip')

train_df, test_df = churn.randomSplit([0.8, 0.2], seed=42)

## Pipeline

In [103]:
from pyspark.ml import Pipeline

pipeline = Pipeline(stages=[indexers, encoders, vector_assembler, label_indexer])
pipeline_model = pipeline.fit(train_df)
# Save the fitted pipeline model
pipeline_model_path = str(ROOT_DIR / "models" / "preprocessing_pipeline")
pipeline_model.write().overwrite().save(pipeline_model_path)
print(f"Pipeline model saved to: {pipeline_model_path}")

Pipeline model saved to: /home/administrator/Desktop/datascience/github/learning-spark/models/preprocessing_pipeline


In [None]:
churn_df = pipeline_model.transform(churn).select(primary_key, 'features', 'label')

In [111]:
churn_df.show()

+----------+--------------------+-----+
|customerID|            features|label|
+----------+--------------------+-----+
|7590-VHVEG|(26,[1,3,6,11,13,...|  0.0|
|5575-GNVDE|(26,[1,5,7,13,14,...|  0.0|
|3668-QPYBK|(26,[1,3,7,13,14,...|  1.0|
|7795-CFOCW|(26,[1,5,9,13,16,...|  0.0|
|9237-HQITU|(26,[0,3,6,13,14,...|  1.0|
|9305-CDSKC|(26,[0,3,6,13,14,...|  1.0|
|1452-KIOVK|(26,[0,3,8,12,13,...|  0.0|
|6713-OKOMC|(26,[1,3,7,13,16,...|  0.0|
|7892-POOKP|(26,[0,3,6,11,13,...|  1.0|
|6388-TABGU|(26,[1,5,9,12,13,...|  0.0|
|9763-GRSKD|(26,[1,3,7,11,12,...|  0.0|
|7469-LKBCI|(26,[2,4,8,13,14,...|  0.0|
|8091-TTVAX|(26,[0,5,8,11,13,...|  0.0|
|0280-XJGEX|(26,[0,3,9,13,14,...|  1.0|
|5129-JLPIS|(26,[0,3,6,13,14,...|  0.0|
|3655-SNQYZ|(26,[0,4,8,11,12,...|  0.0|
|8191-XWSZG|(26,[2,5,7,13,14,...|  0.0|
|9959-WOFKT|(26,[0,4,9,12,13,...|  0.0|
|4190-MFLUW|(26,[1,3,8,11,12,...|  1.0|
|4183-MYFRB|(26,[0,3,6,13,14,...|  0.0|
+----------+--------------------+-----+
only showing top 20 rows


In [113]:
churn_df.write.format("delta").mode("overwrite").save(str(ROOT_DIR / "data" / "silver" / "customer_churn_silver"))

In [114]:
spark.stop()