In [5]:
import traceback

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, lit
from pyspark.sql.functions import col, regexp_replace, trim, when, regexp_extract
from pyspark.sql.types import *
from pyspark.sql.functions import col, isnan, when, count, date_format, to_date, to_timestamp
from pyspark.sql.functions import coalesce, concat, initcap, lower, md5, upper

# Build the Spark session used by every step below
# (getOrCreate hands back an already-running session if there is one)
spark = (
    SparkSession.builder
    .appName("DataProcessing")
    .getOrCreate()
)

# Customers schema: all three columns are read as nullable strings
customers_schema = StructType([
    StructField(field_name, StringType(), True)
    for field_name in ("customer_id", "customer_name", "city")
])

# Load the customers CSV (first line is the header) with the explicit schema
customers_df = spark.read.csv(
    "../data/customers.csv",
    schema=customers_schema,
    header=True,
)

# Confirm the schema that was applied
print("Customers DataFrame Schema:")
customers_df.printSchema()

### Row count of the raw data
num_rows = customers_df.count()
print(f"Number of rows: {num_rows}")

### Missing values per column
# `when` without `otherwise` yields NULL for non-matching rows and `count`
# skips NULLs, so each aggregate is the number of NULL cells in that column.
missing_values = customers_df.select(
    *(count(when(col(c).isNull(), c)).alias(c) for c in customers_df.columns)
)
missing_values.show()

### Substitute with a Default Value
def clean_customer_names(customers_df):
    """
    Replace missing customer names and cities with the placeholder "Unknown".

    Every column whose lower-cased name contains ``customer_name`` or ``city``
    is filled through a single ``na.fill`` call. The filled DataFrame is
    marked for caching and returned; the input DataFrame is not modified.
    """
    markers = ('customer_name', 'city')

    # Map each matching column to its replacement value
    replacements = {}
    for column_name in customers_df.columns:
        lowered = column_name.lower()
        if any(marker in lowered for marker in markers):
            replacements[column_name] = "Unknown"

    return customers_df.na.fill(replacements).cache()

cleaned_customers_df = clean_customer_names(customers_df)

# Re-count NULLs per column to verify the fill
# (count ignores the NULL produced by `when` for non-null cells)
missing_values = cleaned_customers_df.select(
    *(
        count(when(col(c).isNull(), c)).alias(c)
        for c in cleaned_customers_df.columns
    )
)
missing_values.show()

def standardize_customer_id(df, column_name: str = "customer_id", num_partitions: "int | None" = 200):
    """
    Convert an ID column to ``LongType``, nulling every non-numeric value.

    Values consisting solely of the digits 0-9 are cast to long; every other
    value (including NULL) becomes NULL. The column is replaced in place
    under the same name.

    Args:
        df: Spark DataFrame that holds the column.
        column_name: Name of the column to standardize.
        num_partitions: Partition count passed to ``repartition`` on the
            result. Defaults to 200, the value that used to be hard-coded.
            Pass ``None`` to leave the partitioning untouched.

    Returns:
        The transformed DataFrame, or ``None`` when ``df`` is ``None``, the
        column is missing, or any exception is raised (the error and its
        traceback are printed in that case).
    """
    try:
        if df is None or column_name not in df.columns:
            print("Invalid DataFrame or column name.")
            return None

        digits_only = "^[0-9]+$"
        standardized_df = df.withColumn(
            column_name,
            when(
                col(column_name).rlike(digits_only),
                col(column_name).cast(LongType())
            ).otherwise(None)
        )

        # count() is an action: it executes the transformation once so the
        # row count can be reported.
        row_count = standardized_df.count()
        print(f"Row count after transformation: {row_count}")

        # Repartitioning is optional; skipping it avoids a shuffle.
        if num_partitions is not None:
            standardized_df = standardized_df.repartition(num_partitions)
        return standardized_df
    except Exception as e:
        print(f"Error in function: {str(e)}")
        traceback.print_exc()
        return None

# standardize_customer_id returns None on failure (after printing the error),
# so only inspect the schema when a DataFrame actually came back.
standardized_customers_df = standardize_customer_id(customers_df, column_name="customer_id")
if standardized_customers_df is not None:
    standardized_customers_df.printSchema()

def filter_invalid_customer_id(df, column_name: str = "customer_id", invalid_values: "list | None" = None):
    """
    Return the rows whose ``column_name`` is not one of ``invalid_values``.

    Args:
        df: Spark DataFrame to filter.
        column_name: Column that holds the identifier to check.
        invalid_values: Values to exclude. ``None`` (the default) means
            ``[-99999, 1e6]``.

    Returns:
        The filtered DataFrame, or ``None`` when ``df`` is ``None``, the
        column is missing, or any exception is raised (the error and its
        traceback are printed in that case).

    NOTE(review): under Spark's NULL semantics a negated ``isin`` evaluates
    to NULL for rows where the column is NULL, so such rows are dropped as
    well. Confirm that this is intended.
    """
    try:
        if df is None or column_name not in df.columns:
            print("Invalid DataFrame or column name.")
            return None

        # The default is resolved here instead of in the signature so that no
        # mutable list is shared between calls.
        if invalid_values is None:
            invalid_values = (-99999, 1e6)

        # A set removes repeated values before they are passed to isin
        invalid_values_set = set(invalid_values)
        filter_condition = ~col(column_name).isin(invalid_values_set)
        df_filtered = df.filter(filter_condition)
        return df_filtered
    except Exception as e:
        print(f"Error in function: {str(e)}")
        traceback.print_exc()
        return None

# filter_invalid_customer_id returns None on failure (after printing the
# error), so only display the result when a DataFrame actually came back.
valid_customers_df = filter_invalid_customer_id(customers_df, column_name="customer_id", invalid_values=[-99999, 1e6])
if valid_customers_df is not None:
    valid_customers_df.show()

def show_invalid_customer_ids(df, column_name: str = "customer_id", invalid_values: "list | None" = None):
    """
    Return the rows whose ``column_name`` is one of ``invalid_values``.

    Despite the name, nothing is displayed here; the matching rows are
    returned so the caller can show or inspect them.

    Args:
        df: Spark DataFrame to inspect.
        column_name: Column that holds the identifier to check.
        invalid_values: Values considered invalid. ``None`` (the default)
            means ``[-99999, 1e6]``.

    Returns:
        A DataFrame with the matching rows, or ``None`` when ``df`` is
        ``None``, the column is missing, or any exception is raised (the
        error and its traceback are printed in that case).
    """
    try:
        if df is None or column_name not in df.columns:
            print("Invalid DataFrame or column name.")
            return None

        # The default is resolved here instead of in the signature so that no
        # mutable list is shared between calls.
        if invalid_values is None:
            invalid_values = (-99999, 1e6)

        # A set removes repeated values before they are passed to isin
        invalid_values_set = set(invalid_values)
        filter_condition = col(column_name).isin(invalid_values_set)
        df_invalid = df.filter(filter_condition)
        return df_invalid
    except Exception as e:
        print(f"Error in function: {str(e)}")
        traceback.print_exc()
        return None

# show_invalid_customer_ids returns None on failure (after printing the
# error), so only display the result when a DataFrame actually came back.
invalid_customer_ids_df = show_invalid_customer_ids(customers_df)
if invalid_customer_ids_df is not None:
    invalid_customer_ids_df.show()

# Hashing: derive an MD5 key from the (customer_name, city) pair
customers_df_opt = customers_df.select(col("customer_name"), col("city"))

# NULL names/cities are replaced by "" before concatenation, so the hashed
# text is always "<name>_<city>" with empty parts where data is missing.
customers_with_hash = customers_df_opt.withColumn(
    "customer_id_hash",
    md5(
        concat(
            coalesce(col("customer_name"), lit("")),
            lit("_"),
            coalesce(col("city"), lit("")),
        )
    ),
)

# Mark for caching: the write, count and show below all reuse this result
customers_with_hash.cache()
final_df = customers_with_hash.select(
    ["customer_id_hash", "customer_name", "city"]
)

# Persist as parquet, one directory per city, replacing earlier output
final_df.write.parquet("output_path", mode="overwrite", partitionBy="city")

print("Number of records processed:", final_df.count())
print("\nSchema of final dataframe:")
final_df.printSchema()
print("\nSample of processed data:")
final_df.show(n=5, truncate=False)

# Reference cities list
cities = [
    "Johannesburg", "Cape Town", "Durban", "Pretoria", "Port Elizabeth", 
    "East London", "Bloemfontein", "Nelspruit", "Polokwane", "Kimberley"
]

# Lower-cased reference list held in a broadcast variable. Its `.value` is
# read on the driver below, so `isin` receives an ordinary Python list.
cities_broadcast = spark.sparkContext.broadcast([city.lower() for city in cities])

# City with a trailing numeric suffix removed ("Polokwane 11",
# "port elizabeth-10") and runs of whitespace collapsed to one space.
# The same expression drives BOTH the lookup and the value written back;
# previously the written value skipped the whitespace collapse, so a city
# such as "cape  town" matched the list but kept its double space.
normalized_city = regexp_replace(
    regexp_replace(col("city"), r"[-_]\d+$|[\s]+\d+$", ""),
    r"\s+", " "
)

# Clean and standardize cities, handle null names and invalid cities
cleaned_customers_df = customers_df.withColumn(
    "customer_name",
    when(col("customer_name").isNull(), lit("Unknown"))
    .otherwise(col("customer_name"))
).withColumn(
    "city",
    when(
        # Placeholder cities ("INVALID_CITY", optionally numbered) and NULLs
        (upper(col("city")).rlike("INVALID_CITY[0-9]*")) |
        (col("city").isNull()),
        lit("Unknown")
    ).when(
        # Known city after normalization: store its canonical capitalization
        lower(normalized_city).isin(cities_broadcast.value),
        initcap(normalized_city)
    ).otherwise(col("city"))  # anything else is kept exactly as read
)

# Show results
cleaned_customers_df.select("customer_id", "customer_name", "city").show(truncate=False)

# Checking for duplicates
duplicate_check = cleaned_customers_df.groupBy("customer_id", "customer_name", "city") \
    .agg(count("*").alias("count")) \
    .filter(col("count") > 1)

# count() is an action; run it once and reuse the result instead of
# launching the same aggregation twice.
duplicate_count = duplicate_check.count()
print("Number of duplicate records found:", duplicate_count)

if duplicate_count > 0:
    print("\nDuplicate records found:")
    duplicate_check.show(truncate=False)

deduped_customers_df = cleaned_customers_df.dropDuplicates(["customer_id", "customer_name", "city"])

print("\nOriginal record count:", cleaned_customers_df.count())
print("Record count after removing duplicates:", deduped_customers_df.count())
print("\nFinal cleaned and deduplicated data:")
deduped_customers_df.select("customer_id", "customer_name", "city").show(truncate=False)

# Expose the result to Spark SQL under the name "cleaned_customers"
deduped_customers_df.createOrReplaceTempView("cleaned_customers")

# Products schema: all three columns are read as nullable strings
products_schema = StructType([
    StructField(field_name, StringType(), True)
    for field_name in ("product_id", "product_name", "category")
])

# Load the products CSV (first line is the header) with the explicit schema
products_df = spark.read.csv(
    "../data/products.csv",
    schema=products_schema,
    header=True,
)

print("Products DataFrame Schema:")
products_df.printSchema()

num_rows = products_df.count()
print(f"Number of rows: {num_rows}")

# Missing values per column: `when` yields NULL for non-null cells and
# `count` skips NULLs, so each aggregate is that column's NULL count.
missing_values = products_df.select(
    *(count(when(col(c).isNull(), c)).alias(c) for c in products_df.columns)
)
missing_values.show()

def clean_product_names(products_df):
    """
    Replace missing product names and categories with "Unknown".

    Every column whose lower-cased name contains ``product_name`` or
    ``category`` is filled through a single ``na.fill`` call. The filled
    DataFrame is marked for caching and returned.
    """
    markers = ('product_name', 'category')

    # Map each matching column to its replacement value
    replacements = {}
    for column_name in products_df.columns:
        lowered = column_name.lower()
        if any(marker in lowered for marker in markers):
            replacements[column_name] = "Unknown"

    return products_df.na.fill(replacements).cache()

cleaned_products_df = clean_product_names(products_df)

# Re-count NULLs per column to verify the fill, then preview a few rows
missing_values = cleaned_products_df.select(
    *(
        count(when(col(c).isNull(), c)).alias(c)
        for c in cleaned_products_df.columns
    )
)
missing_values.show()
cleaned_products_df.show(n=5)

# Swap the two labels: values read as product_name are exposed as product_id
# and vice versa (presumably the CSV header names are reversed; confirm
# against the source file). Note this starts again from products_df, not from
# the filled cleaned_products_df above.
products_df = products_df.select(
    products_df["product_name"].alias("product_id"),
    products_df["product_id"].alias("product_name"),
    products_df["category"],
)

print("Corrected column data:")
products_df.show(truncate=False)

# Cleaning expressions shared by both product pipelines below. They used to
# be written out twice, which made it easy for the copies to drift apart.

# product_name: drop underscores/hyphens and a trailing run of digits,
# collapse repeated whitespace, strip trailing whitespace, then capitalize
# each word.
clean_product_name = initcap(
    regexp_replace(
        regexp_replace(
            regexp_replace(
                col("product_name"),
                r'[_\-]|\d+$',
                ''
            ),
            r'\s+',
            ' '
        ),
        r'\s+$',
        ''
    )
).alias("product_name")

# category: NULL and the literal "InvalidCategory" both become "Unknown"
clean_category = when(
    (col("category").isNull()) |
    (col("category") == "InvalidCategory"),
    lit("Unknown")
).otherwise(col("category")).alias("category")

# Clean and standardize the dataset (keeps the original product_id)
cleaned_products_df = products_df.select(
    col("product_id"),
    clean_product_name,
    clean_category
)

final_products_df = cleaned_products_df.dropDuplicates(["product_id", "product_name", "category"])

print("Cleaned and standardized products data:")
final_products_df.show(truncate=False)

print("\nSummary of changes:")
print("Original row count:", products_df.count())
print("Final row count:", final_products_df.count())

# Create hashed product_id: the original id is dropped and replaced by an MD5
# of "<product_name>_<category>" (NULL parts are hashed as empty strings)
cleaned_products_df = products_df.select(
    clean_product_name,
    clean_category
)

final_products_df = cleaned_products_df.withColumn(
    "product_id",
    md5(concat(
        coalesce(col("product_name"), lit("")),
        lit("_"),
        coalesce(col("category"), lit(""))
    ))
).select(
    "product_id",
    "product_name",
    "category"
)

# Deduplicate first and cache the result: this is the DataFrame every later
# step reuses. Previously the cache was placed on the pre-dedup intermediate.
final_products_df = final_products_df.dropDuplicates(["product_name", "category"]).cache()

print("Products with hashed IDs:")
final_products_df.show(truncate=False)

print("\nSummary:")
print("Original row count:", products_df.count())
print("Final row count:", final_products_df.count())
print("\nSchema of final dataframe:")
final_products_df.printSchema()

# Dates schema: `date` is parsed as a date, the other two columns stay strings
date_schema = (
    StructType()
    .add("date", DateType(), True)
    .add("mmm_yy", StringType(), True)
    .add("week_no", StringType(), True)
)

# Load the dates CSV (first line is the header) with the explicit schema
date_df = spark.read.csv(
    "../data/dates.csv",
    schema=date_schema,
    header=True,
)

print("Dates DataFrame Schema:")
date_df.printSchema()

num_rows = date_df.count()
print(f"Number of rows: {num_rows}")

# NULL count per column before cleaning
missing_values = date_df.select(
    *(count(when(col(c).isNull(), c)).alias(c) for c in date_df.columns)
)
missing_values.show()

print("Original row count:", date_df.count())
# Discard every row that has a NULL in any column
date_df = date_df.dropna(how='any')

print("\nCleaned data (rows with no missing values):")
date_df.show(truncate=False)

print("\nSchema of cleaned dataframe:")
date_df.printSchema()

# NULL count per column after cleaning
missing_values = date_df.select(
    *(count(when(col(c).isNull(), c)).alias(c) for c in date_df.columns)
)
missing_values.show()

# Clean dates data: drop rows with NULLs, drop rows whose mmm_yy contains
# "invalid_date" (case-insensitive), and reduce week_no to its first run of
# digits.
cleaned_date_df = (
    date_df.dropna(how='any')
    .where(
        ~lower(col("mmm_yy")).contains("invalid_date")
        & col("mmm_yy").isNotNull()
    )
    .withColumn("week_no", regexp_extract(col("week_no"), r"(\d+)", 1))
)

# Save final processed datasets as parquet, replacing any earlier output
deduped_customers_df.write.parquet("output/customers", mode="overwrite")
final_products_df.write.parquet("output/products", mode="overwrite")
cleaned_date_df.write.parquet("output/dates", mode="overwrite")

# Show top 20 records from each dataset
print("\nTop 20 Customer Records:")
deduped_customers_df.show(n=20, truncate=False)

print("\nTop 20 Product Records:")
final_products_df.show(n=20, truncate=False)

print("\nTop 20 Date Records:")
cleaned_date_df.show(n=20, truncate=False)

# Shut the session down; no Spark work can run after this point
spark.stop()

Customers DataFrame Schema:
root
 |-- customer_id: string (nullable = true)
 |-- customer_name: string (nullable = true)
 |-- city: string (nullable = true)

Number of rows: 200
+-----------+-------------+----+
|customer_id|customer_name|city|
+-----------+-------------+----+
|          0|           22|  21|
+-----------+-------------+----+

25/01/14 19:36:32 WARN CacheManager: Asked to cache already cached data.
+-----------+-------------+----+
|customer_id|customer_name|city|
+-----------+-------------+----+
|          0|            0|   0|
+-----------+-------------+----+

Row count after transformation: 200
root
 |-- customer_id: long (nullable = true)
 |-- customer_name: string (nullable = true)
 |-- city: string (nullable = true)

+-----------+-----------------+-----------------+
|customer_id|    customer_name|             city|
+-----------+-----------------+-----------------+
|     789221|      Info Stores|     Polokwane 11|
|     789301|             null|           durban|
|  

                                                                                

Number of records processed: 200

Schema of final dataframe:
root
 |-- customer_id_hash: string (nullable = false)
 |-- customer_name: string (nullable = true)
 |-- city: string (nullable = true)


Sample of processed data:
+--------------------------------+-------------+-----------------+
|customer_id_hash                |customer_name|city             |
+--------------------------------+-------------+-----------------+
|1e7746579353fc48ce01dad8221156c9|Info Stores  |Polokwane 11     |
|c270645522c203f91626a17998317cc8|null         |durban           |
|b979555b640767d3f2603e29319024f0|Coolblue 96  |Cape Town        |
|6d9be2294eb49598988c36e46e3d630a|Logic Stores |port elizabeth-10|
|4ac660ea4ac050c24e503263f4e29107|coolblue     |port elizabeth   |
+--------------------------------+-------------+-----------------+
only showing top 5 rows

+-----------+-----------------+--------------+
|customer_id|customer_name    |city          |
+-----------+-----------------+--------------+
|789221

In [4]:
from pyspark.sql.functions import md5, concat, col, lit, coalesce
from pyspark.sql.types import StringType
from pyspark.sql.functions import regexp_replace, initcap, when, col, lower,upper
