In [0]:
## libs

from functools import reduce
from pyspark.sql import SparkSession, Row, functions
from pyspark.sql.types import (
    StructField
    , StringType
    , IntegerType
    , DoubleType
    , BooleanType
    , StructType
)

In [0]:
## pipeline definition

class DataPipeline:
    def __init__(self, spark_session:SparkSession, debug_mode:bool=False) -> None:
        self.spark = spark_session
        self.debug = debug_mode
        self.file_paths:dict = {
            'test.csv':'/FileStore/tables/test.csv'
            , 'train.csv':'/FileStore/tables/train.csv'
            , 'Customer_Churn_Records.csv':'/FileStore/tables/Customer_Churn_Records.csv'
            , 'Bank_Customer_Churn_Prediction.csv':'/FileStore/tables/Bank_Customer_Churn_Prediction.csv'
            , 'Churn_Modeling.csv':'/FileStore/tables/Churn_Modeling.csv'
            , 'Churn_Modelling.csv':'/FileStore/tables/Churn_Modelling.csv'
            , 'Churn_Modelling-1.csv':'/FileStore/tables/Churn_Modelling-1.csv'
            , 'churn.csv':'/FileStore/tables/churn.csv'
        }
        self.spark_dataframes = None
        self.joined_df = None
        self.target_columns = None
        self.target_schema_dict = None
        self.target_df = None
        print('Pipeline initialised!') if self.debug else None
    def run(self) -> None:
        print('Starting to run pipeline...') if self.debug else None
        self.extract()
        self.transform()
        self.load()
        print('Pipeline ran successfully!') if self.debug else None
    def extract(self) -> None:
        self._extract()
    def transform(self) -> None:
        self._transform_rename_columns()
        self._transform_assert_existence_and_unique_id()
        self._transform_full_outer_join()
        self._transform_create_target_schema()
        self._transform_filter_valid_data()
        self._transform_populate_target_df()
    def load(self) -> None:
        self._load()
    def _extract(self) -> None:
        self.spark_dataframes = dict()
        self.spark_dataframes['test.csv'] = self.spark.read.csv(
            self.file_paths['test.csv']
            , header=True
            , schema=StructType(fields=[
                StructField('id', IntegerType(), True)
                , StructField('CustomerId', IntegerType(), True)
                , StructField('Surname', StringType(), True)
                , StructField('CreditScore', IntegerType(), True)
                , StructField('Geography', StringType(), True)
                , StructField('Gender', StringType(), True)
                , StructField('Age', IntegerType(), True)
                , StructField('Tenure', IntegerType(), True)
                , StructField('Balance', DoubleType(), True)
                , StructField('NumOfProducts', IntegerType(), True)
                , StructField('HasCrCard', IntegerType(), True)
                , StructField('IsActiveMember', IntegerType(), True)
                , StructField('EstimatedSalary', DoubleType(), True)
            ])
        )
        self.spark_dataframes['train.csv'] = self.spark.read.csv(
            self.file_paths['train.csv']
            , header=True
            , schema=StructType(fields=[
                StructField('id', IntegerType(), True)
                , StructField('CustomerId', IntegerType(), True)
                , StructField('Surname', StringType(), True)
                , StructField('CreditScore', IntegerType(), True)
                , StructField('Geography', StringType(), True)
                , StructField('Gender', StringType(), True)
                , StructField('Age', IntegerType(), True)
                , StructField('Tenure', IntegerType(), True)
                , StructField('Balance', DoubleType(), True)
                , StructField('NumOfProducts', IntegerType(), True)
                , StructField('HasCrCard', IntegerType(), True)
                , StructField('IsActiveMember', IntegerType(), True)
                , StructField('EstimatedSalary', DoubleType(), True)
                , StructField('Exited', IntegerType(), True)
            ])
        )
        self.spark_dataframes['Customer_Churn_Records.csv'] = self.spark.read.csv(
            self.file_paths['Customer_Churn_Records.csv']
            , header=True
            , schema=StructType(fields=[
                StructField('RowNumber', IntegerType(), True)
                , StructField('CustomerId', IntegerType(), True)
                , StructField('Surname', StringType(), True)
                , StructField('CreditScore', IntegerType(), True)
                , StructField('Geography', StringType(), True)
                , StructField('Gender', StringType(), True)
                , StructField('Age', IntegerType(), True)
                , StructField('Tenure', IntegerType(), True)
                , StructField('Balance', DoubleType(), True)
                , StructField('NumOfProducts', IntegerType(), True)
                , StructField('HasCrCard', IntegerType(), True)
                , StructField('IsActiveMember', IntegerType(), True)
                , StructField('EstimatedSalary', DoubleType(), True)
                , StructField('Exited', IntegerType(), True)
                , StructField('Complain', IntegerType(), True)
                , StructField('Satisfaction Score', IntegerType(), True)
                , StructField('Card Type', StringType(), True)
                , StructField('Point Earned', IntegerType(), True)
            ])
        )
        self.spark_dataframes['Bank_Customer_Churn_Prediction.csv'] = self.spark.read.csv(
            self.file_paths['Bank_Customer_Churn_Prediction.csv']
            , header=True
            , schema=StructType(fields=[
                StructField('customer_id', IntegerType(), True)
                , StructField('credit_score', IntegerType(), True)
                , StructField('country', StringType(), True)
                , StructField('gender', StringType(), True)
                , StructField('age', IntegerType(), True)
                , StructField('tenure', IntegerType(), True)
                , StructField('balance', DoubleType(), True)
                , StructField('products_number', IntegerType(), True)
                , StructField('credit_card', IntegerType(), True)
                , StructField('active_member', IntegerType(), True)
                , StructField('estimated_salary', DoubleType(), True)
                , StructField('churn', IntegerType(), True)
            ])
        )
        self.spark_dataframes['Churn_Modeling.csv'] = self.spark.read.csv(
            self.file_paths['Churn_Modeling.csv']
            , header=True
            , schema=StructType(fields=[
                StructField('RowNumber', IntegerType(), True)
                , StructField('CustomerId', IntegerType(), True)
                , StructField('Surname', StringType(), True)
                , StructField('CreditScore', IntegerType(), True)
                , StructField('Geography', StringType(), True)
                , StructField('Gender', StringType(), True)
                , StructField('Age', IntegerType(), True)
                , StructField('Tenure', IntegerType(), True)
                , StructField('Balance', DoubleType(), True)
                , StructField('NumOfProducts', IntegerType(), True)
                , StructField('HasCrCard', IntegerType(), True)
                , StructField('IsActiveMember', IntegerType(), True)
                , StructField('EstimatedSalary', DoubleType(), True)
                , StructField('Exited', IntegerType(), True)
            ])
        )
        self.spark_dataframes['Churn_Modelling.csv'] = self.spark.read.csv(
            self.file_paths['Churn_Modelling.csv']
            , header=True
            , schema=StructType(fields=[
                StructField('RowNumber', IntegerType(), True)
                , StructField('CustomerId', IntegerType(), True)
                , StructField('Surname', StringType(), True)
                , StructField('CreditScore', IntegerType(), True)
                , StructField('Geography', StringType(), True)
                , StructField('Gender', StringType(), True)
                , StructField('Age', IntegerType(), True)
                , StructField('Tenure', IntegerType(), True)
                , StructField('Balance', DoubleType(), True)
                , StructField('NumOfProducts', IntegerType(), True)
                , StructField('HasCrCard', IntegerType(), True)
                , StructField('IsActiveMember', IntegerType(), True)
                , StructField('EstimatedSalary', DoubleType(), True)
                , StructField('Exited', IntegerType(), True)
            ])
        )
        self.spark_dataframes['Churn_Modelling-1.csv'] = self.spark.read.csv(
            self.file_paths['Churn_Modelling-1.csv']
            , header=True
            , schema=StructType(fields=[
                StructField('RowNumber', IntegerType(), True)
                , StructField('CustomerId', IntegerType(), True)
                , StructField('Surname', StringType(), True)
                , StructField('CreditScore', IntegerType(), True)
                , StructField('Geography', StringType(), True)
                , StructField('Gender', StringType(), True)
                , StructField('Age', IntegerType(), True)
                , StructField('Tenure', IntegerType(), True)
                , StructField('Balance', DoubleType(), True)
                , StructField('NumOfProducts', IntegerType(), True)
                , StructField('HasCrCard', IntegerType(), True)
                , StructField('IsActiveMember', IntegerType(), True)
                , StructField('EstimatedSalary', DoubleType(), True)
                , StructField('Exited', IntegerType(), True)
            ])
        )
        self.spark_dataframes['churn.csv'] = self.spark.read.csv(
            self.file_paths['churn.csv']
            , header=True
            , schema=StructType(fields=[
                StructField('RowNumber', IntegerType(), True)
                , StructField('CustomerId', IntegerType(), True)
                , StructField('Surname', StringType(), True)
                , StructField('CreditScore', IntegerType(), True)
                , StructField('Geography', StringType(), True)
                , StructField('Gender', StringType(), True)
                , StructField('Age', IntegerType(), True)
                , StructField('Tenure', IntegerType(), True)
                , StructField('Balance', DoubleType(), True)
                , StructField('NumOfProducts', IntegerType(), True)
                , StructField('HasCrCard', IntegerType(), True)
                , StructField('IsActiveMember', IntegerType(), True)
                , StructField('EstimatedSalary', DoubleType(), True)
                , StructField('Exited', IntegerType(), True)
            ])
        )
        print('Data extracted.') if self.debug else None
    def _transform_rename_columns(self) -> None:
        new_column_names = [
            'id_test_csv'
            , 'customer_id'
            , 'surname_test_csv'
            , 'credit_score_test_csv'
            , 'geography_test_csv'
            , 'gender_test_csv'
            , 'age_test_csv'
            , 'tenure_test_csv'
            , 'balance_test_csv'
            , 'product_count_test_csv'
            , 'has_creditcard_test_csv'
            , 'active_member_test_csv'
            , 'estimated_salary_test_csv'
        ]
        self.spark_dataframes['test.csv'] = self.spark_dataframes['test.csv'].toDF(*new_column_names)
        new_column_names = [
            'id_train_csv'
            , 'customer_id'
            , 'surname_train_csv'
            , 'credit_score_train_csv'
            , 'geography_train_csv'
            , 'gender_train_csv'
            , 'age_train_csv'
            , 'tenure_train_csv'
            , 'balance_train_csv'
            , 'product_count_train_csv'
            , 'has_creditcard_train_csv'
            , 'active_member_train_csv'
            , 'estimated_salary_train_csv'
            , 'churn_train_csv'
        ]
        self.spark_dataframes['train.csv'] = self.spark_dataframes['train.csv'].toDF(*new_column_names)
        new_column_names = [
            'rownum_Customer_Churn_Records_csv'
            , 'customer_id'
            , 'surname_Customer_Churn_Records_csv'
            , 'credit_score_Customer_Churn_Records_csv'
            , 'geography_Customer_Churn_Records_csv'
            , 'gender_Customer_Churn_Records_csv'
            , 'age_Customer_Churn_Records_csv'
            , 'tenure_Customer_Churn_Records_csv'
            , 'balance_Customer_Churn_Records_csv'
            , 'product_count_Customer_Churn_Records_csv'
            , 'has_creditcard_Customer_Churn_Records_csv'
            , 'active_member_Customer_Churn_Records_csv'
            , 'estimated_salary_Customer_Churn_Records_csv'
            , 'churn_Customer_Churn_Records_csv'
            , 'complain_Customer_Churn_Records_csv'
            , 'satisfaction_score_Customer_Churn_Records_csv'
            , 'card_type_Customer_Churn_Records_csv'
            , 'points_earned_Customer_Churn_Records_csv'
        ]
        self.spark_dataframes['Customer_Churn_Records.csv'] = \
            self.spark_dataframes['Customer_Churn_Records.csv'].toDF(*new_column_names)
        new_column_names = [
            'customer_id'
            , 'credit_score_Bank_Customer_Churn_Prediction_csv'
            , 'geography_Bank_Customer_Churn_Prediction_csv'
            , 'gender_Bank_Customer_Churn_Prediction_csv'
            , 'age_Bank_Customer_Churn_Prediction_csv'
            , 'tenure_Bank_Customer_Churn_Prediction_csv'
            , 'balance_Bank_Customer_Churn_Prediction_csv'
            , 'product_count_Bank_Customer_Churn_Prediction_csv'
            , 'has_creditcard_Bank_Customer_Churn_Prediction_csv'
            , 'active_member_Bank_Customer_Churn_Prediction_csv'
            , 'estimated_salary_Bank_Customer_Churn_Prediction_csv'
            , 'churn_Bank_Customer_Churn_Prediction_csv'
        ]
        self.spark_dataframes['Bank_Customer_Churn_Prediction.csv'] = \
            self.spark_dataframes['Bank_Customer_Churn_Prediction.csv'].toDF(*new_column_names)
        new_column_names = [
            'rownum_Churn_Modeling_csv'
            , 'customer_id'
            , 'surname_Churn_Modeling_csv'
            , 'credit_score_Churn_Modeling_csv'
            , 'geography_Churn_Modeling_csv'
            , 'gender_Churn_Modeling_csv'
            , 'age_Churn_Modeling_csv'
            , 'tenure_Churn_Modeling_csv'
            , 'balance_Churn_Modeling_csv'
            , 'product_count_Churn_Modeling_csv'
            , 'has_creditcard_Churn_Modeling_csv'
            , 'active_member_Churn_Modeling_csv'
            , 'estimated_salary_Churn_Modeling_csv'
            , 'churn_Churn_Modeling_csv'
        ]
        self.spark_dataframes['Churn_Modeling.csv'] = \
            self.spark_dataframes['Churn_Modeling.csv'].toDF(*new_column_names)
        new_column_names = [
            'rownum_Churn_Modelling_csv'
            , 'customer_id'
            , 'surname_Churn_Modelling_csv'
            , 'credit_score_Churn_Modelling_csv'
            , 'geography_Churn_Modelling_csv'
            , 'gender_Churn_Modelling_csv'
            , 'age_Churn_Modelling_csv'
            , 'tenure_Churn_Modelling_csv'
            , 'balance_Churn_Modelling_csv'
            , 'product_count_Churn_Modelling_csv'
            , 'has_creditcard_Churn_Modelling_csv'
            , 'active_member_Churn_Modelling_csv'
            , 'estimated_salary_Churn_Modelling_csv'
            , 'churn_Churn_Modelling_csv'
        ]
        self.spark_dataframes['Churn_Modelling.csv'] = \
            self.spark_dataframes['Churn_Modelling.csv'].toDF(*new_column_names)
        new_column_names = [
            'rownum_Churn_Modelling-1_csv'
            , 'customer_id'
            , 'surname_Churn_Modelling-1_csv'
            , 'credit_score_Churn_Modelling-1_csv'
            , 'geography_Churn_Modelling-1_csv'
            , 'gender_Churn_Modelling-1_csv'
            , 'age_Churn_Modelling-1_csv'
            , 'tenure_Churn_Modelling-1_csv'
            , 'balance_Churn_Modelling-1_csv'
            , 'product_count_Churn_Modelling-1_csv'
            , 'has_creditcard_Churn_Modelling-1_csv'
            , 'active_member_Churn_Modelling-1_csv'
            , 'estimated_salary_Churn_Modelling-1_csv'
            , 'churn_Churn_Modelling-1_csv'
        ]
        self.spark_dataframes['Churn_Modelling-1.csv'] = \
            self.spark_dataframes['Churn_Modelling-1.csv'].toDF(*new_column_names)
        new_column_names = [
            'rownum_churn_csv'
            , 'customer_id'
            , 'surname_churn_csv'
            , 'credit_score_churn_csv'
            , 'geography_churn_csv'
            , 'gender_churn_csv'
            , 'age_churn_csv'
            , 'tenure_churn_csv'
            , 'balance_churn_csv'
            , 'product_count_churn_csv'
            , 'has_creditcard_churn_csv'
            , 'active_member_churn_csv'
            , 'estimated_salary_churn_csv'
            , 'churn_churn_csv'
        ]
        self.spark_dataframes['churn.csv'] = self.spark_dataframes['churn.csv'].toDF(*new_column_names)
        print('Columns synchronised.') if self.debug else None
    def _transform_assert_existence_and_unique_id(self) -> None:
        self.spark_dataframes = \
            {k:df.filter(df.customer_id.isNotNull()) for k,df in self.spark_dataframes.copy().items()}
        self.spark_dataframes = \
            {k:df.dropDuplicates(['customer_id']) for k,df in self.spark_dataframes.copy().items()}
        print('Asserted existence of unique IDs.') if self.debug else None
    def _transform_full_outer_join(self) -> None:
        ordered_df = [df for df in self.spark_dataframes.values()]
        self.joined_df = ordered_df[0]
        for df in ordered_df[1:]:
            self.joined_df = self.joined_df.join(df, on='customer_id', how='full')
        print(f'Data merged with {self.joined_df.count()} rows.') if self.debug else None
    def _transform_create_target_schema(self) -> None:
        target_struct = StructType(fields=[
            StructField('customer_id', IntegerType(), True)
            , StructField('surname', StringType(), True)
            , StructField('credit_score', IntegerType(), True)
            , StructField('geography', StringType(), True)
            , StructField('gender', StringType(), True)
            , StructField('age', IntegerType(), True)
            , StructField('tenure', IntegerType(), True)
            , StructField('balance', DoubleType(), True)
            , StructField('product_count', IntegerType(), True)
            , StructField('has_creditcard', IntegerType(), True)
            , StructField('active_member', IntegerType(), True)
            , StructField('estimated_salary', DoubleType(), True)
            , StructField('complain', IntegerType(), True)
            , StructField('satisfaction_score', IntegerType(), True)
            , StructField('card_type', StringType(), True)
            , StructField('points_earned', IntegerType(), True)
            , StructField('churn', IntegerType(), True)
        ])
        self.target_columns = [field.name for field in target_struct.fields]
        self.target_schema_dict = {field.name: field.dataType for field in target_struct.fields}
        print('Created target schema.') if self.debug else None
    def _transform_filter_valid_data(self) -> None:
        for target_column in self.target_columns[1:]:
            columns = [c for c in self.joined_df.columns if target_column in c]
            expression = ', '.join([f'`{c}`' for c in columns])
            expression = f'filter(array({expression}), x -> x IS NOT NULL)'
            number_of_values = functions.size(functions.array_distinct(functions.expr(expression)))
            new_column = functions.when(number_of_values>1, functions.lit(False)).otherwise(functions.lit(True))
            self.joined_df = self.joined_df.withColumn(f'valid_{target_column}', new_column)
        boolean_columns = [c for c in self.joined_df.columns if 'valid' in c]
        filter_cond = (functions.col(c) == True for c in boolean_columns)
        filter_expr = reduce(lambda x, y: x & y, filter_cond)
        self.joined_df = self.joined_df.filter(filter_expr)
        print('Filtered valid data.') if self.debug else None
    def _transform_populate_target_df(self) -> None:
        for target_column in self.target_columns[1:]:
            uncasted_columns = [c for c in self.joined_df.columns if c.startswith(target_column)]
            casted_columns = [functions.col(c).cast(self.target_schema_dict[target_column]) for c in uncasted_columns]
            self.joined_df = self.joined_df.withColumn(target_column, functions.coalesce(*casted_columns))
        final_columns = [functions.col(c) for c in self.target_columns]
        self.target_df = self.joined_df.select(*final_columns)
        print(f'Target dataframe populated with {self.target_df.count()} rows.') if self.debug else None
    def _load(self) -> None:
        self.target_df.write.mode('overwrite').option('header', 'true').csv('/FileStore/tables/pyspark_churn.csv')
        print('Data loaded.') if self.debug else None

In [0]:
## main

def main() -> None:
    spark = SparkSession.builder.appName('Data_Pipeline').getOrCreate()
    debug_mode = True
    pipeline = DataPipeline(spark, debug_mode)
    pipeline.run()

In [0]:
## run script

if __name__ == '__main__':
    main()

Pipeline initialised!
Starting to run pipeline...
Data extracted.
Columns synchronised.
Asserted existence of unique IDs.
Data merged with 28536 rows.
Created target schema.
Filtered valid data.
Target dataframe populated with 13545 rows.
Data loaded.
Pipeline ran successfully!
