# In [0]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import NumericType, StringType

# 1️⃣ Start (or reuse an existing) Spark session
spark = (
    SparkSession.builder
    .appName("CustomerDataCleaning")
    .getOrCreate()
)

# 2️⃣ Load the CSV file (change the path below if your file lives elsewhere).
# The first line is used as the header and column types are inferred from the data.
df = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("customer_data_partitioned.csv")
)

print("✅ Raw data loaded")
df.show(5, truncate=False)
raw_row_count = df.count()
print(f"Total rows before cleaning: {raw_row_count}")

# 3️⃣ Remove exact duplicate rows (values are compared as loaded, before any cleanup)
df = df.dropDuplicates()
print(f"✅ After dropping duplicates: {df.count()} rows")

# 4️⃣ Standardize column names (lowercase, trimmed, spaces -> underscores)
standardized_names = [c.lower().strip().replace(" ", "_") for c in df.columns]
# Distinct headers can normalize to the same name (e.g. "First Name" and
# "first_name"). toDF would accept the duplicates, but every later F.col(name)
# on them would fail with an ambiguous-reference error, so report it here instead.
colliding_names = sorted(
    {name for name in standardized_names if standardized_names.count(name) > 1}
)
if colliding_names:
    raise ValueError(
        f"Column names collide after standardization: {colliding_names}"
    )
df = df.toDF(*standardized_names)

# 5️⃣ Identify numeric & categorical columns.
# Classify by Spark type class rather than by the type's string form. The previous
# "everything that is not StringType is numeric" test also put boolean, date and
# timestamp columns (which inferSchema can produce) into numeric_cols, and
# F.mean / approxQuantile fail on those types. Columns that are neither numeric
# nor string are now left untouched by the cleaning steps.
numeric_cols = [
    f.name for f in df.schema.fields if isinstance(f.dataType, NumericType)
]
categorical_cols = [
    f.name for f in df.schema.fields if isinstance(f.dataType, StringType)
]

# 6️⃣ Remove invalid values for specific fields FIRST, so impossible values
# (e.g. a negative age) cannot skew the mean used for imputation or the
# quartiles used for outlier detection below.
# Missing ages are kept here (a null is missing, not invalid); step 8️⃣ fills them
# with the mean of the valid ages.
if "age" in numeric_cols:
    df = df.filter(F.col("age").isNull() | F.col("age").between(0, 120))

# 7️⃣ Clean categorical columns: trim spaces, title-case (optional), and replace
# missing values with "Unknown". Trimming happens before the missing check, so
# whitespace-only strings are treated as missing instead of surviving as "".
for name in categorical_cols:
    trimmed = F.trim(F.col(name))
    df = df.withColumn(
        name,
        F.when(trimmed.isNull() | (trimmed == ""), F.lit("Unknown")).otherwise(
            F.initcap(trimmed)
        ),
    )

# 8️⃣ Fill numeric nulls with the column mean.
# All means come from a single aggregation (one Spark job) instead of one job per
# column. df.agg() rejects an empty expression list, hence the guard.
if numeric_cols:
    mean_row = df.agg(*[F.mean(F.col(name)) for name in numeric_cols]).first()
    # A mean is None when the column has no non-null values; such columns are left
    # as-is. float() keeps the replacement within the types fillna documents
    # (int/float/bool/str).
    fill_values = {
        name: float(mean)
        for name, mean in zip(numeric_cols, mean_row)
        if mean is not None
    }
    if fill_values:
        df = df.fillna(fill_values)

# 9️⃣ Remove outliers (IQR method) for numeric columns.
IQR_MULTIPLIER = 1.5
QUANTILE_RELATIVE_ERROR = 0.05
if numeric_cols:
    # One approxQuantile call computes Q1/Q3 for every column on the same data.
    # Filtering column by column would compute later columns' quartiles on
    # already-trimmed data, making the result depend on column order.
    quartiles = df.approxQuantile(
        numeric_cols, [0.25, 0.75], QUANTILE_RELATIVE_ERROR
    )
    keep_condition = None
    for name, column_quartiles in zip(numeric_cols, quartiles):
        # approxQuantile returns an empty list for an empty or all-null column;
        # there is nothing to bound in that case.
        if len(column_quartiles) != 2:
            continue
        q1, q3 = column_quartiles
        iqr = q3 - q1
        in_range = F.col(name).between(
            q1 - IQR_MULTIPLIER * iqr, q3 + IQR_MULTIPLIER * iqr
        )
        keep_condition = (
            in_range if keep_condition is None else keep_condition & in_range
        )
    if keep_condition is not None:
        df = df.filter(keep_condition)

# 🔟 Save the cleaned data.
# Deduplicate once more first: trimming and title-casing the text columns can make
# rows identical that differed only by surrounding whitespace or letter case, and
# the earlier dropDuplicates ran before that normalization.
# Note: Spark writes a directory of part files at this path, not a single CSV file.
df = df.dropDuplicates()
df.write.csv("cleaned_customer_data_partitioned.csv", header=True, mode='overwrite')

print("✅ Data cleaning completed")
df.show(10, truncate=False)