# 02 – Data Cleaning (PySpark)

In this notebook we clean and preprocess the data. We load the raw CSV files using PySpark (taking a seeded random sample of up to 500,000 rows per table), drop duplicates, normalize join-key types, and handle missing values. No outlier detection is performed in this version. Cleaned data are saved to the `data/processed` directory for later use.

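Steps:
1. Load data from `data/raw` using PySpark, keeping a seeded random sample of up to 500,000 rows per table.
2. Drop exact duplicate rows and normalize join-key types (`user`, `adgroup_id`) to integers.
3. Fill missing values: median for numeric columns (plus `*_missing` indicator flags for selected columns) and the label `"UNKNOWN"` for categorical columns (no mode imputation).
4. Keep `time_stamp` as an integer for time-based splits and add a human-readable `time_stamp_str` column.
5. Save the cleaned tables as CSV (via pandas) to `data/processed` for downstream processing.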


In [2]:
# ============================
# Clean + Prepare (Spark) - FIXED VERSION
# ============================
# Fixes that directly impact CTR modeling quality:
# 1) Avoid non-random limit() sampling -> use random sampling with seed
# 2) Normalize join keys types NOW (user/adgroup_id) to prevent join-enrichment loss later
# 3) Safer missing handling:
#    - Categorical: fill with "UNKNOWN" (do NOT mode-impute)
#    - Numeric: fill with median + add missing indicator flags where useful
# 4) Keep time_stamp as integer (for future time-based split). Also store time_stamp_str for Windows safety.
# 5) Still saves as CSV to data/processed (same outputs as your pipeline expects)

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, from_unixtime
from pyspark.sql.types import NumericType
from pyspark.sql import functions as F
import os

# ============================
# 0. Spark session
# ============================
# Local session on 4 cores. The tunable settings live in one mapping so they are
# easy to find and change.
# NOTE(review): spark.driver.memory generally only takes effect if no Spark JVM is
# running yet in this kernel (i.e. on a fresh kernel).
SPARK_CONF = {
    "spark.driver.memory": "6g",
    "spark.sql.shuffle.partitions": "200",
}

session_builder = SparkSession.builder.appName("CTR_Project_Cleaning_Fixed").master("local[4]")
for conf_key, conf_value in SPARK_CONF.items():
    session_builder = session_builder.config(conf_key, conf_value)
spark = session_builder.getOrCreate()

spark.sparkContext.setLogLevel("WARN")
print("Spark version:", spark.version)

# ============================
# 1. Paths
# ============================
# Assumes the notebook's working directory is one level below the project root.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
data_root = os.path.join(project_root, "data")
raw_dir = os.path.join(data_root, "raw")
processed_dir = os.path.join(data_root, "processed")
os.makedirs(processed_dir, exist_ok=True)

for heading, location in (
    ("Project root:", project_root),
    ("Raw data dir:", raw_dir),
    ("Processed data dir:", processed_dir),
):
    print(heading, location)

# Logical table name -> CSV file name inside raw_dir.
file_names = {
    "raw_sample": "raw_sample.csv",
    "ad_feature": "ad_feature.csv",
    "user_profile": "user_profile.csv",
    "behavior_log": "behavior_log.csv",
}

# ============================
# 2. Verify files
# ============================
# Report every expected input, then fail once with the full list of absent files.
missing = []
for key, fname in file_names.items():
    path = os.path.join(raw_dir, fname)
    exists = os.path.exists(path)
    print(f"{key:15s} -> {path} | exists: {exists}")
    if exists:
        continue
    missing.append(path)

if missing:
    raise FileNotFoundError("Missing required files:\n" + "\n".join(missing))

# ============================
# 3. Load data
#    IMPORTANT: avoid .limit() for modeling datasets unless randomized
# ============================
SAMPLE_ROWS = 500_000
SEED = 42

def load_random_limit(path, n=None, seed=None, **read_options):
    """Read a CSV file with Spark and keep a seeded random sample of at most ``n`` rows.

    Parameters
    ----------
    path : str
        CSV file to read.
    n : int, optional
        Maximum number of rows to keep. Defaults to ``SAMPLE_ROWS``.
    seed : int, optional
        Seed for the random row ordering. Defaults to ``SEED``.
    **read_options
        Extra or overriding keyword arguments for ``spark.read.csv``. The defaults are
        ``header=True, inferSchema=True``. For example, pass ``nullValue="NULL"`` for a
        table that spells missing values as a literal string.

    Returns
    -------
    pyspark.sql.DataFrame
        The sampled rows, marked with ``cache()`` so repeated actions reuse them.

    Raises
    ------
    ValueError
        If ``n`` is negative.
    """
    if n is None:
        n = SAMPLE_ROWS
    if seed is None:
        seed = SEED
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    options = {"header": True, "inferSchema": True, **read_options}
    df = spark.read.csv(path, **options)
    # Order by a seeded rand() before limit() so the kept rows are a random subset,
    # not simply the first n rows of the file. cache() keeps later actions (counts,
    # dedup, joins) working on the same materialized sample.
    return df.orderBy(F.rand(seed=seed)).limit(n).cache()

print("\nLoading randomized sample (up to 500k rows each)...")

# NOTE(review): each table is sampled independently, so a sampled click row's
# user / adgroup_id is not guaranteed to have a row in the sampled user/ad tables.
# If join coverage turns out low, consider sampling clicks first and restricting
# the other tables to the keys that appear in that sample.
user_df, ad_df, click_df, behavior_df = [
    load_random_limit(os.path.join(raw_dir, file_names[table]))
    for table in ("user_profile", "ad_feature", "raw_sample", "behavior_log")
]

for heading, frame in (
    ("User rows:     ", user_df),
    ("Ad rows:       ", ad_df),
    ("Click rows:    ", click_df),
    ("Behavior rows: ", behavior_df),
):
    print(heading, frame.count())

# ============================
# 4. Drop exact duplicates
# ============================
print("\nDropping exact duplicate rows...")

# Only rows identical in every column are removed.
user_df, ad_df, click_df, behavior_df = [
    frame.dropDuplicates() for frame in (user_df, ad_df, click_df, behavior_df)
]

for heading, frame in (
    ("User rows after dropDuplicates:    ", user_df),
    ("Ad rows after dropDuplicates:      ", ad_df),
    ("Click rows after dropDuplicates:   ", click_df),
    ("Behavior rows after dropDuplicates:", behavior_df),
):
    print(heading, frame.count())

# ============================
# 5. Normalize key columns + schemas (critical for later joins)
# ============================
print("\nNormalizing join keys types...")


def _cast_to_int_if_present(frame, column):
    """Return ``frame`` with ``column`` cast to int, or ``frame`` unchanged if it lacks the column."""
    if column not in frame.columns:
        return frame
    return frame.withColumn(column, F.col(column).cast("int"))


# user_profile names its key `userid` (a double, per the schema notes); expose it as
# an int `user` so it lines up with the click and behavior tables.
if "userid" in user_df.columns:
    if "user" in user_df.columns:
        # Both present: keep `user`, drop the redundant `userid`.
        user_df = user_df.drop("userid")
    else:
        user_df = user_df.withColumn("user", F.col("userid").cast("int")).drop("userid")

# The remaining keys are cast defensively so every table agrees on int keys
# (ad_feature's adgroup_id is a double per the schema notes).
click_df = _cast_to_int_if_present(click_df, "user")
click_df = _cast_to_int_if_present(click_df, "adgroup_id")
ad_df = _cast_to_int_if_present(ad_df, "adgroup_id")
behavior_df = _cast_to_int_if_present(behavior_df, "user")

# brand is deliberately not unified: ad_df stores it as string and behavior_df as
# integer, and per the schema notes the two carry different meanings.

# ============================
# 6. Missing value handling (safer for CTR)
#    - Categorical: fill with "UNKNOWN" (do NOT mode-impute)
#    - Numeric: median + (optional) missing flags for selected numeric cols
# ============================
print("\nHandling missing values safely...")

def fill_numeric_with_median(df, numeric_cols, skip_cols=None, relative_error=0.01):
    """Fill nulls in numeric columns with each column's approximate median.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Input frame.
    numeric_cols : iterable of str
        Candidate columns. Names in ``skip_cols`` or absent from ``df`` are ignored.
    skip_cols : collection of str, optional
        Columns that must not be imputed (e.g. join keys, labels).
    relative_error : float, default 0.01
        Passed to ``approxQuantile``. 0 requests the exact median, at higher cost.

    Returns
    -------
    pyspark.sql.DataFrame
        ``df`` with nulls in the selected columns replaced by their medians. A column
        with no non-null values is filled with 0. If no column qualifies, ``df`` itself
        is returned.
    """
    skip = set(skip_cols or ())
    present = set(df.columns)
    # dict.fromkeys de-duplicates while keeping the caller's column order.
    targets = list(dict.fromkeys(c for c in numeric_cols if c not in skip and c in present))
    if not targets:
        return df

    # One multi-column approxQuantile call replaces the previous two Spark actions per
    # column (limit(1).count() + approxQuantile). Spark ignores nulls here and returns
    # an empty list for a column that only contains nulls.
    medians = df.approxQuantile(targets, [0.5], relative_error)
    fill_map = {c: (q[0] if q else 0) for c, q in zip(targets, medians)}
    # A single fillna instead of one per column keeps the query plan flat.
    return df.fillna(fill_map)

def fill_categorical_unknown(df, cat_cols, unknown="UNKNOWN"):
    """Replace nulls in string-typed categorical columns with a sentinel label.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Input frame.
    cat_cols : iterable of str
        Candidate columns. Names absent from ``df`` are ignored.
    unknown : str, default "UNKNOWN"
        Value written in place of nulls.

    Returns
    -------
    pyspark.sql.DataFrame
        ``df`` with the fill applied, or ``df`` itself if no column qualifies.

    Notes
    -----
    Only columns whose Spark type is string are filled. ``split_cols`` puts every
    non-numeric column (boolean, date, timestamp, ...) into the categorical bucket, and
    ``fillna`` with a dict casts the replacement to each column's type. A string
    sentinel aimed at such a column is a no-op at best, and under ANSI casting (the
    Spark 4 default) the cast may fail instead.
    """
    dtypes = dict(df.dtypes)
    fill_map = {
        c: unknown
        for c in cat_cols
        # "string" for plain strings; collated strings render as "string collate <name>".
        if c in dtypes and (dtypes[c] == "string" or dtypes[c].startswith("string collate"))
    }
    return df.fillna(fill_map) if fill_map else df

def add_missing_flags(df, cols):
    """Append an int ``<col>_missing`` indicator (1 = null, 0 = present) per listed column.

    Columns not present in ``df`` are skipped. The flags reflect the values in ``df``
    as passed in, so they must be added before any imputation of those columns.
    """
    flagged = df
    for name in cols:
        if name not in flagged.columns:
            continue
        flagged = flagged.withColumn(name + "_missing", F.col(name).isNull().cast("int"))
    return flagged

# Identify numeric/categorical columns automatically, but keep control
def split_cols(df):
    """Partition ``df``'s columns by Spark type.

    Returns
    -------
    (list of str, list of str)
        ``(numeric, other)`` in schema order. ``numeric`` holds columns whose type is a
        ``NumericType``. ``other`` holds every remaining column, which callers treat as
        categorical.
    """
    schema = df.schema
    numeric = [name for name in df.columns if isinstance(schema[name].dataType, NumericType)]
    numeric_names = set(numeric)
    other = [name for name in df.columns if name not in numeric_names]
    return numeric, other

# Columns that must never be median-imputed:
#  - join keys (user, adgroup_id): a filled id would join to an unrelated entity
#  - labels (clk, nonclk): imputing them fabricates training targets
#  - time_stamp: kept raw for the later time-based split; a median time would move
#    rows into the middle of the time range
CLICK_NO_IMPUTE = {"user", "adgroup_id", "clk", "nonclk", "time_stamp"}
BEHAVIOR_NO_IMPUTE = {"user", "time_stamp"}

# USER
user_num, user_cat = split_cols(user_df)
user_df = add_missing_flags(user_df, ["age_level", "shopping_level", "occupation", "final_gender_code"])
user_df = fill_numeric_with_median(user_df, user_num, skip_cols={"user"})
user_df = fill_categorical_unknown(user_df, user_cat, unknown="UNKNOWN")

# AD
# NOTE(review): cate_id / campaign_id / customer look like identifiers rather than
# measurements. A median fills them with an arbitrary existing id, so consider a
# sentinel (e.g. -1) instead. Confirm against the downstream feature code.
ad_num, ad_cat = split_cols(ad_df)
ad_df = add_missing_flags(ad_df, ["price", "cate_id", "campaign_id", "customer"])
ad_df = fill_numeric_with_median(ad_df, ad_num, skip_cols={"adgroup_id"})
ad_df = fill_categorical_unknown(ad_df, ad_cat, unknown="UNKNOWN")

# CLICK
click_num, click_cat = split_cols(click_df)
# Keep labels, keys and the raw timestamp untouched; fill any other numeric column.
click_df = add_missing_flags(click_df, ["pid"])
click_df = fill_numeric_with_median(click_df, click_num, skip_cols=CLICK_NO_IMPUTE)
click_df = fill_categorical_unknown(click_df, click_cat, unknown="UNKNOWN")

# BEHAVIOR
beh_num, beh_cat = split_cols(behavior_df)
behavior_df = fill_numeric_with_median(behavior_df, beh_num, skip_cols=BEHAVIOR_NO_IMPUTE)
behavior_df = fill_categorical_unknown(behavior_df, beh_cat, unknown="UNKNOWN")

print("Missing handling done.")

# ============================
# 7. Safe timestamp columns
#    - keep time_stamp as integer for time splits later
#    - add time_stamp_str for Windows / CSV friendliness
# ============================
print("\nAdding safe timestamp string columns...")


def _with_time_stamp_str(frame):
    """Return ``frame`` plus a formatted ``time_stamp_str`` column when it has ``time_stamp``.

    ``time_stamp`` is cast to bigint and converted with ``from_unixtime``, which
    interprets the value as seconds since the epoch. The integer column itself is
    left as-is.
    """
    if "time_stamp" not in frame.columns:
        return frame
    formatted = from_unixtime(col("time_stamp").cast("bigint")).cast("string")
    return frame.withColumn("time_stamp_str", formatted)


click_df = _with_time_stamp_str(click_df)
behavior_df = _with_time_stamp_str(behavior_df)

# ============================
# 8. Quick sanity checks that affect modeling
#    - label distribution
#    - join key null rates
#    - join coverage against the sampled user/ad tables
# ============================
print("\nSanity checks:")

# Mean of the 0/1 click label = click-through rate of the sample.
if "clk" in click_df.columns:
    click_df.select(F.mean("clk").alias("clk_rate")).show()

# Null rate of each join key that exists. This is guarded like the clk check, so a
# schema without one of the keys does not crash the cell.
key_null_rates = [
    F.mean(F.col(key).isNull().cast("int")).alias(alias)
    for key, alias in (("user", "user_null_rate"), ("adgroup_id", "adgroup_null_rate"))
    if key in click_df.columns
]
if key_null_rates:
    click_df.select(*key_null_rates).show()

# Join coverage: the user/ad tables were sampled independently of click_df, so a
# non-null key can still have no matching row. Report the share of click rows that a
# later join on each key would actually enrich.
n_click_rows = click_df.count()
if n_click_rows:
    for key, dim_df, dim_name in (("user", user_df, "user_profile"), ("adgroup_id", ad_df, "ad_feature")):
        if key in click_df.columns and key in dim_df.columns:
            n_matched = click_df.join(dim_df.select(key), on=key, how="left_semi").count()
            print(f"{dim_name} join coverage ({key}): {n_matched / n_click_rows:.2%} of click rows")

# ============================
# 9. Convert to Pandas and save CSV
#    (keeps your downstream notebook compatible)
# ============================
print("\nConverting Spark DataFrames to Pandas for saving... (can take time)")

# toPandas() collects each table to the driver; the four pandas frames stay bound
# to user_pdf / ad_pdf / click_pdf / behavior_pdf after this cell.
user_pdf, ad_pdf, click_pdf, behavior_pdf = [
    frame.toPandas() for frame in (user_df, ad_df, click_df, behavior_df)
]

print("Saving cleaned data as CSV...")

csv_outputs = (
    (user_pdf, "user_profile_clean.csv"),
    (ad_pdf, "ad_feature_clean.csv"),
    (click_pdf, "raw_sample_clean.csv"),
    (behavior_pdf, "behavior_log_clean.csv"),
)
for pdf, csv_name in csv_outputs:
    pdf.to_csv(os.path.join(processed_dir, csv_name), index=False)

print("\nSaved cleaned CSVs into:", processed_dir)

# ============================
# 10. Stop Spark
# ============================
# Spark DataFrames created above can no longer run actions after this; only the
# pandas copies remain usable.
spark.stop()
print("\nSpark session stopped.")


Spark version: 4.0.1
Project root: d:\projects\Ai\project_fusion_ecu
Raw data dir: d:\projects\Ai\project_fusion_ecu\data\raw
Processed data dir: d:\projects\Ai\project_fusion_ecu\data\processed
raw_sample      -> d:\projects\Ai\project_fusion_ecu\data\raw\raw_sample.csv | exists: True
ad_feature      -> d:\projects\Ai\project_fusion_ecu\data\raw\ad_feature.csv | exists: True
user_profile    -> d:\projects\Ai\project_fusion_ecu\data\raw\user_profile.csv | exists: True
behavior_log    -> d:\projects\Ai\project_fusion_ecu\data\raw\behavior_log.csv | exists: True

Loading randomized sample (up to 500k rows each)...
User rows:      500000
Ad rows:        500000
Click rows:     500000
Behavior rows:  500000

Dropping exact duplicate rows...
User rows after dropDuplicates:     500000
Ad rows after dropDuplicates:       500000
Click rows after dropDuplicates:    500000
Behavior rows after dropDuplicates: 499964

Normalizing join keys types...

Handling missing values safely...
Missing handlin