# Open session and import packages

In [1]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('project_churn').getOrCreate()

In [22]:
from pyspark.sql.types import *
import pyspark.sql.functions as F

import os
import sys

os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable

# Import data

In [3]:
# import train set
train = spark.read.csv('train.csv', header = True)
# import test set
test = spark.read.csv('test.csv', header = True)

# Data Format

In [4]:
train.show(truncate = False)

+---+----------+--------------+-----------+---------+------+----+------+---------+-------------+---------+--------------+---------------+------+
|id |CustomerId|Surname       |CreditScore|Geography|Gender|Age |Tenure|Balance  |NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|
+---+----------+--------------+-----------+---------+------+----+------+---------+-------------+---------+--------------+---------------+------+
|0  |15674932  |Okwudilichukwu|668        |France   |Male  |33.0|3     |0.0      |2            |1.0      |0.0           |181449.97      |0     |
|1  |15749177  |Okwudiliolisa |627        |France   |Male  |33.0|1     |0.0      |2            |1.0      |1.0           |49503.5        |0     |
|2  |15694510  |Hsueh         |678        |France   |Male  |40.0|10    |0.0      |2            |1.0      |0.0           |184866.69      |0     |
|3  |15741417  |Kao           |581        |France   |Male  |34.0|2     |148882.54|1            |1.0      |1.0           |84560.88 

In [5]:
# visualize all features type
train.printSchema()

root
 |-- id: string (nullable = true)
 |-- CustomerId: string (nullable = true)
 |-- Surname: string (nullable = true)
 |-- CreditScore: string (nullable = true)
 |-- Geography: string (nullable = true)
 |-- Gender: string (nullable = true)
 |-- Age: string (nullable = true)
 |-- Tenure: string (nullable = true)
 |-- Balance: string (nullable = true)
 |-- NumOfProducts: string (nullable = true)
 |-- HasCrCard: string (nullable = true)
 |-- IsActiveMember: string (nullable = true)
 |-- EstimatedSalary: string (nullable = true)
 |-- Exited: string (nullable = true)



In [6]:
# removing features that don't serve any purpose
train = train.drop('id','surname')

In [7]:
# rename all to lower case letters, feature and data, and removing all blank spaces
for col in train.columns:
    train = train.withColumnRenamed(col, col.lower())
    train = train.withColumn(col, F.lower(col))
    train = train.withColumn(col, F.trim(col))

In [8]:
# visualize how the feature data are arranged
train.show(5, truncate = False)

+----------+-----------+---------+------+----+------+---------+-------------+---------+--------------+---------------+------+
|CustomerId|CreditScore|Geography|Gender|Age |Tenure|Balance  |NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|
+----------+-----------+---------+------+----+------+---------+-------------+---------+--------------+---------------+------+
|15674932  |668        |france   |male  |33.0|3     |0.0      |2            |1.0      |0.0           |181449.97      |0     |
|15749177  |627        |france   |male  |33.0|1     |0.0      |2            |1.0      |1.0           |49503.5        |0     |
|15694510  |678        |france   |male  |40.0|10    |0.0      |2            |1.0      |0.0           |184866.69      |0     |
|15741417  |581        |france   |male  |34.0|2     |148882.54|1            |1.0      |1.0           |84560.88       |0     |
|15766172  |716        |spain    |male  |33.0|5     |0.0      |2            |1.0      |1.0           |15068.83       |

In [9]:
# changing data types to numeric type
numeric_columns = ['creditscore', 'tenure', 'balance', 'numofproducts', 'hascrcard', 'isactivemember', 'estimatedsalary', 'exited']

for col in numeric_columns:
    train = train.withColumn(col, F.col(col).cast(FloatType()))

In [10]:
train.printSchema()

root
 |-- CustomerId: string (nullable = true)
 |-- creditscore: float (nullable = true)
 |-- Geography: string (nullable = true)
 |-- Gender: string (nullable = true)
 |-- Age: string (nullable = true)
 |-- tenure: float (nullable = true)
 |-- balance: float (nullable = true)
 |-- numofproducts: float (nullable = true)
 |-- hascrcard: float (nullable = true)
 |-- isactivemember: float (nullable = true)
 |-- estimatedsalary: float (nullable = true)
 |-- exited: float (nullable = true)



# Exploratory Data Analysis

In [11]:
train.summary().show()

+-------+--------------------+----------------+---------+------+-----------------+------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+
|summary|          CustomerId|     creditscore|Geography|Gender|              Age|            tenure|           balance|     numofproducts|         hascrcard|     isactivemember|   estimatedsalary|             exited|
+-------+--------------------+----------------+---------+------+-----------------+------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+
|  count|              165034|          165034|   165034|165034|           165034|            165034|            165034|            165034|            165034|             165034|            165034|             165034|
|   mean|1.5692005019026382E7|656.454373038283|     NULL|  NULL|38.12588787764945| 5.020353381727402| 55478.08669040132|1.554455

In [None]:
+-------+--------------------+----------------+---------+------+-----------------+------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+
|summary|          CustomerId|     creditscore|Geography|Gender|              Age|            tenure|           balance|     numofproducts|         hascrcard|     isactivemember|   estimatedsalary|             exited|
+-------+--------------------+----------------+---------+------+-----------------+------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+
|  count|              165034|          165034|   165034|165034|           165034|            165034|            165034|            165034|            165034|             165034|            165034|             165034|
|   mean|1.5692005019026382E7|656.454373038283|     NULL|  NULL|38.12588787764945| 5.020353381727402| 55478.08669040132|1.5544554455445545|0.7539537307463916|0.49777015645260975|112574.82270602613|0.21159882206090866|
| stddev|   71397.81679067112|80.1033404871783|     NULL|  NULL|8.867204591410792|2.8061585665860913|62817.663267958495|0.5471536788441764|0.4307071240449495|0.49999654260421705| 50292.86554962783| 0.4084431067117287|
|    min|            15565701|           350.0|   france|female|             18.0|               0.0|               0.0|               1.0|               0.0|                0.0|             11.58|                0.0|
|    25%|         1.5633112E7|           597.0|     NULL|  NULL|             32.0|               3.0|               0.0|               1.0|               1.0|                0.0|          74637.57|                0.0|
|    50%|         1.5690164E7|           659.0|     NULL|  NULL|             37.0|               5.0|               0.0|               2.0|               1.0|                0.0|          117946.3|                0.0|
|    75%|         1.5756821E7|           710.0|     NULL|  NULL|             42.0|               7.0|         119919.12|               2.0|               1.0|                1.0|         155061.97|                0.0|
|    max|            15815690|           850.0|    spain|  male|             92.0|              10.0|          250898.1|               4.0|               1.0|                1.0|         199992.48|                1.0|
+-------+--------------------+----------------+---------+------+-----------------+------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+

In [12]:
def eda_basic(df, cols:list):
    # lista auxiliar para adicionar os registros
    summary_data = []
    
    for feature in cols:
        # retirada de espaços vazios nos valores das colunas
        trimmed_col = df.withColumn(feature, F.trim(F.col(feature)))
        # cálculo de registros únicos
        unique_values = df.select(feature).distinct().count()
        # cálculo do total de registros
        total_values = df.count()
        # cálculo da porcentagem de missings nas features
        percent_missings = (df.select(feature).where(f'{feature} == "None" or {feature} == "" ').count()/total_values)*100
        # adiciona os valores calculados a um registro referente à respectiva feature
        summary_data.append((feature, unique_values, total_values, percent_missings))
        
    # criação do Dataframe de saída
    df = spark.createDataFrame(summary_data, ["feature", "unique_value", "total_values", "percent_missing"])
    
    return df

In [14]:
eda = eda_basic(train, train.columns[1:])

In [15]:
eda.show()

Py4JJavaError: An error occurred while calling o526.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 255.0 failed 1 times, most recent failure: Lost task 0.0 in stage 255.0 (TID 246) (DESKTOP-CRUHD2O executor driver): org.apache.spark.SparkException: Python worker failed to connect back.
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:203)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:174)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:67)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
	at java.lang.Thread.run(Unknown Source)
Caused by: java.net.SocketTimeoutException: Accept timed out
	at java.net.DualStackPlainSocketImpl.waitForNewConnection(Native Method)
	at java.net.DualStackPlainSocketImpl.socketAccept(Unknown Source)
	at java.net.AbstractPlainSocketImpl.accept(Unknown Source)
	at java.net.PlainSocketImpl.accept(Unknown Source)
	at java.net.ServerSocket.implAccept(Unknown Source)
	at java.net.ServerSocket.accept(Unknown Source)
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:190)
	... 32 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2844)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2780)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2779)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2779)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1242)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1242)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1242)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3048)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2982)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2971)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:984)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2398)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2419)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2438)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:530)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:483)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:61)
	at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:4344)
	at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:3326)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:4334)
	at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:546)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:4332)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:125)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:201)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:108)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:66)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:4332)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:3326)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:3549)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:280)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:315)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(Unknown Source)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(Unknown Source)
	at java.lang.reflect.Method.invoke(Unknown Source)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.lang.Thread.run(Unknown Source)
Caused by: org.apache.spark.SparkException: Python worker failed to connect back.
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:203)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:174)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:67)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
	... 1 more
Caused by: java.net.SocketTimeoutException: Accept timed out
	at java.net.DualStackPlainSocketImpl.waitForNewConnection(Native Method)
	at java.net.DualStackPlainSocketImpl.socketAccept(Unknown Source)
	at java.net.AbstractPlainSocketImpl.accept(Unknown Source)
	at java.net.PlainSocketImpl.accept(Unknown Source)
	at java.net.ServerSocket.implAccept(Unknown Source)
	at java.net.ServerSocket.accept(Unknown Source)
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:190)
	... 32 more
