In [1]:
# Intialization
import os
import sys

os.environ["SPARK_HOME"] = "/home/talentum/spark"
os.environ["PYLIB"] = os.environ["SPARK_HOME"] + "/python/lib"
# In below two lines, use /usr/bin/python2.7 if you want to use Python 2
os.environ["PYSPARK_PYTHON"] = "/usr/bin/python3.6" 
os.environ["PYSPARK_DRIVER_PYTHON"] = "/usr/bin/python3"
sys.path.insert(0, os.environ["PYLIB"] +"/py4j-0.10.7-src.zip")
sys.path.insert(0, os.environ["PYLIB"] +"/pyspark.zip")

# NOTE: Whichever package you want mention here.
# os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages com.databricks:spark-xml_2.11:0.6.0 pyspark-shell' 
# os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages org.apache.spark:spark-avro_2.11:2.4.0 pyspark-shell'
os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages com.databricks:spark-xml_2.11:0.6.0,org.apache.spark:spark-avro_2.11:2.4.3 pyspark-shell'
# os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages com.databricks:spark-xml_2.11:0.6.0,org.apache.spark:spark-avro_2.11:2.4.0 pyspark-shell'

In [2]:
#Entrypoint 2.x
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("Bank Customer Segmentation").enableHiveSupport().getOrCreate()
sc = spark.sparkContext

In [3]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, to_date, when, regexp_replace
from pyspark.sql.types import StructType, StructField, StringType, FloatType, DateType, IntegerType
data_path = "Bank_dataset.csv"

# Print the file_path
print("The file_path is", data_path)

The file_path is Bank_dataset.csv


In [4]:
df = spark.read.csv(data_path, header=True, inferSchema=True)

# Inspect Data
df.show(5)
df.printSchema()

+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+-----------------------+
|TransactionID|CustomerID|CustomerDOB|CustGender|CustLocation|CustAccountBalance|TransactionDate|TransactionTime|TransactionAmount (INR)|
+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+-----------------------+
|           T1|  C5841053| 10-01-1994|         F|  JAMSHEDPUR|          17819.05|     02-08-2016|         143207|                   25.0|
|           T2|  C2142763| 04-04-1957|         M|     JHAJJAR|           2270.69|     02-08-2016|         141858|                27999.0|
|           T3|  C4417068| 26-11-1996|         F|      MUMBAI|          17874.44|     02-08-2016|         142712|                  459.0|
|           T4|  C5342380| 14-09-1973|         F|      MUMBAI|         866503.21|     02-08-2016|         142714|                 2060.0|
|           T5|  C9031234| 24-03-1

In [10]:
# Summary Statistics
df.describe().show()

+-------+-------------+----------+-----------+----------+--------------------+------------------+---------------+------------------+-----------------------+
|summary|TransactionID|CustomerID|CustomerDOB|CustGender|        CustLocation|CustAccountBalance|TransactionDate|   TransactionTime|TransactionAmount (INR)|
+-------+-------------+----------+-----------+----------+--------------------+------------------+---------------+------------------+-----------------------+
|  count|      1048567|   1048567|    1048567|   1047467|             1048416|           1046198|        1048567|           1048567|                1048567|
|   mean|         null|      null|       null|      null|            400012.0|115403.54005622343|           null|157087.52939297154|     1574.3350034571733|
| stddev|         null|      null|       null|      null|                 0.0|  846485.380600677|           null| 51261.85402233114|      6574.742978453954|
|    min|           T1|  C1010011| 01-01-1930|         F|(

In [11]:
# Checking for Null Values
from pyspark.sql.functions import col, isnan, when, count

null_counts = df.select([count(when(col(c).isNull() | isnan(c), c)).alias(c) for c in df.columns])
null_counts.show()

+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+-----------------------+
|TransactionID|CustomerID|CustomerDOB|CustGender|CustLocation|CustAccountBalance|TransactionDate|TransactionTime|TransactionAmount (INR)|
+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+-----------------------+
|            0|         0|          0|      1100|         151|              2369|              0|              0|                      0|
+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+-----------------------+



In [12]:
# Value Counts for Categorical Features
categorical_columns = [col_name for col_name, dtype in df.dtypes if dtype == 'string']

for column in categorical_columns:
    df.groupBy(column).count().show()

+-------------+-----+
|TransactionID|count|
+-------------+-----+
|         T352|    1|
|         T590|    1|
|         T855|    1|
|         T929|    1|
|         T947|    1|
|        T1118|    1|
|        T1401|    1|
|        T1508|    1|
|        T1767|    1|
|        T1872|    1|
|        T2345|    1|
|        T2463|    1|
|        T2837|    1|
|        T2947|    1|
|        T3091|    1|
|        T3230|    1|
|        T3271|    1|
|        T3337|    1|
|        T3396|    1|
|        T4155|    1|
+-------------+-----+
only showing top 20 rows

+----------+-----+
|CustomerID|count|
+----------+-----+
|  C8732711|    1|
|  C6421261|    2|
|  C2939112|    2|
|  C8741166|    1|
|  C4440931|    1|
|  C8116854|    1|
|  C4029960|    1|
|  C7817677|    4|
|  C4240562|    1|
|  C2221420|    1|
|  C8524541|    1|
|  C5638051|    3|
|  C2331867|    1|
|  C5534211|    1|
|  C4940219|    1|
|  C2230276|    1|
|  C4341556|    1|
|  C7283217|    1|
|  C3825339|    1|
|  C4138928|    1|
+--------

In [13]:
# Correlation Analysis (for numerical features)
numeric_columns = [col_name for col_name, dtype in df.dtypes if dtype in ['int', 'double']]

for column in numeric_columns:
    df.select(column).summary().show()

+-------+------------------+
|summary|CustAccountBalance|
+-------+------------------+
|  count|           1046198|
|   mean|115403.54005622343|
| stddev|  846485.380600677|
|    min|               0.0|
|    25%|           4721.76|
|    50%|          16794.15|
|    75%|          57646.03|
|    max|     1.150354951E8|
+-------+------------------+

+-------+------------------+
|summary|   TransactionTime|
+-------+------------------+
|  count|           1048567|
|   mean|157087.52939297154|
| stddev| 51261.85402233114|
|    min|                 0|
|    25%|            124033|
|    50%|            164221|
|    75%|            200012|
|    max|            235959|
+-------+------------------+

+-------+-----------------------+
|summary|TransactionAmount (INR)|
+-------+-----------------------+
|  count|                1048567|
|   mean|     1574.3350034571733|
| stddev|      6574.742978453954|
|    min|                    0.0|
|    25%|                  160.2|
|    50%|                 459.

In [14]:
# Drop rows with null values in specific columns
df = df.na.drop(subset=["CustGender", "CustLocation", "CustAccountBalance"])

# Verify that null values are dropped
null_counts_after_drop = df.select([count(when(col(c).isNull() | isnan(c), c)).alias(c) for c in df.columns])
null_counts_after_drop.show()


+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+-----------------------+
|TransactionID|CustomerID|CustomerDOB|CustGender|CustLocation|CustAccountBalance|TransactionDate|TransactionTime|TransactionAmount (INR)|
+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+-----------------------+
|            0|         0|          0|         0|           0|                 0|              0|              0|                      0|
+-------------+----------+-----------+----------+------------+------------------+---------------+---------------+-----------------------+



In [15]:
# Correlation matrix
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation

vector_col = "corr_features"
assembler = VectorAssembler(inputCols=numeric_columns, outputCol=vector_col)
vector_df = assembler.transform(df).select(vector_col)

In [16]:
matrix = Correlation.corr(vector_df, vector_col)
correlation_matrix = matrix.collect()[0][0]
print(correlation_matrix.toArray())

[[ 1.         -0.00410334  0.06264685]
 [-0.00410334  1.          0.00787137]
 [ 0.06264685  0.00787137  1.        ]]
