# 02 – Data Cleaning (PySpark)

In this notebook we clean and preprocess the data. We load the raw CSV files using PySpark (taking a random sample of up to 500,000 rows per file), drop duplicates, handle missing values, and perform outlier detection. Cleaned data are saved to the `data/processed` directory for later use.

Steps:
1. Load data from `data/raw` using PySpark.
2. Drop duplicate rows.
3. Fill missing values using appropriate strategies (median for numeric, an `"UNKNOWN"` placeholder for categorical) and add `*_missing` indicator flags for key columns.
4. Convert timestamp columns to datetime if needed.
5. Save cleaned data as CSV files in `data/processed` for downstream processing.


In [3]:
"""
1) Read raw CSV files from data/raw
2) Take a RANDOM sample up to 500K rows (efficient, no heavy shuffle)
3) Remove duplicate rows
4) Normalize join keys (user, adgroup_id)
5) Handle missing values:
   - Numeric  -> median
   - Categorical -> "UNKNOWN"
   - Add *_missing flags for important columns
6) Add readable timestamp columns (if time_stamp exists)
7) Run basic sanity checks
8) Save cleaned CSVs into data/processed

"""

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, from_unixtime
from pyspark.sql.types import NumericType
from pyspark.sql import functions as F
import os
import shutil
import pandas as pd

# Default cap on rows kept per input file (default `limit_rows` of
# load_random_sample_limit).
SAMPLE_ROWS = 500_000
# Default seed passed to DataFrame.sample() so the random sample is repeatable.
SEED = 42

# Spark session
# Local-mode session ("local[4]" master) with a 6g driver-memory setting and
# 200 shuffle partitions; getOrCreate() reuses an already-running session if
# the notebook kernel has one.
spark = (
    SparkSession.builder
        .appName("CTR_Project_Cleaning")
        .master("local[4]")
        .config("spark.driver.memory", "6g")
        .config("spark.sql.shuffle.partitions", "200")
        .getOrCreate()
)
# Keep Spark's own logging at WARN so the cell output stays readable.
spark.sparkContext.setLogLevel("WARN")
print("Spark version:", spark.version)

# Paths are resolved relative to the parent of the current working directory
# (presumably the notebook is run from a sub-folder of the project root —
# TODO confirm).
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
raw_dir = os.path.join(project_root, "data", "raw")
processed_dir = os.path.join(project_root, "data", "processed")
# Create the output directory up front; no error if it already exists.
os.makedirs(processed_dir, exist_ok=True)

# Logical dataset name -> CSV file name inside `raw_dir`.
file_names = {
    "raw_sample": "raw_sample.csv",
    "ad_feature": "ad_feature.csv",
    "user_profile": "user_profile.csv",
    "behavior_log": "behavior_log.csv",
}

def safe_read_csv(path):
    """Read a CSV file into a Spark DataFrame, failing fast when it is absent.

    Args:
        path: Filesystem path of the CSV to load.

    Returns:
        The DataFrame produced by ``spark.read.csv`` with the first line used
        as the header and column types inferred from the data.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if os.path.exists(path):
        reader = spark.read
        return reader.csv(path, header=True, inferSchema=True)

    # Raise a clear Python-side error instead of letting Spark fail on the read.
    raise FileNotFoundError(f"Missing file: {path}")

def load_random_sample_limit(path, limit_rows=SAMPLE_ROWS, seed=SEED):
    """Load a CSV and return at most ``limit_rows`` randomly sampled rows.

    Args:
        path: Filesystem path of the CSV to load (read via ``safe_read_csv``).
        limit_rows: Maximum number of rows to keep.
        seed: Seed forwarded to ``DataFrame.sample``.

    Returns:
        The full DataFrame when it is empty or has no more than ``limit_rows``
        rows; otherwise a sample (without replacement) capped with ``limit``.

    Raises:
        FileNotFoundError: Propagated from ``safe_read_csv`` if ``path`` is
            missing.
    """
    frame = safe_read_csv(path)
    row_count = frame.count()

    # Empty or already-small inputs are returned untouched.
    if row_count == 0 or row_count <= limit_rows:
        return frame

    # Sample 25% more than the target fraction (capped at 1.0) and let limit()
    # trim the result; presumably the margin is there so the sample rarely
    # falls short of `limit_rows`.
    fraction = (limit_rows / float(row_count)) * 1.25
    if fraction > 1.0:
        fraction = 1.0

    sampled = frame.sample(False, fraction, seed)
    return sampled.limit(limit_rows)

def split_cols(df):
    """Partition the columns of ``df`` by whether their Spark type is numeric.

    Args:
        df: DataFrame whose schema is inspected.

    Returns:
        A ``(num_cols, cat_cols)`` tuple of column-name lists, each in
        ``df.columns`` order. A column lands in ``num_cols`` when its data type
        is an instance of ``NumericType``; every other column lands in
        ``cat_cols``.
    """
    numeric_names, other_names = [], []
    for name in df.columns:
        is_numeric = isinstance(df.schema[name].dataType, NumericType)
        target = numeric_names if is_numeric else other_names
        target.append(name)
    return numeric_names, other_names

def fill_numeric_with_median(df, numeric_cols, skip_cols=None, rel_err=0.01):
    """Replace nulls in numeric columns with each column's approximate median.

    All medians are computed in a single ``approxQuantile`` pass over the
    selected columns and applied with a single ``fillna`` call, instead of
    running separate Spark jobs for every column.

    Args:
        df: DataFrame to fill.
        numeric_cols: Names of the candidate numeric columns.
        skip_cols: Optional collection of column names to leave untouched
            (e.g. join keys). ``None`` means nothing is skipped.
        rel_err: Relative error passed to ``approxQuantile``.

    Returns:
        A DataFrame with nulls in the selected columns replaced by the column
        median, or by 0 for columns that contain no non-null value. ``df``
        itself is returned when no column qualifies.
    """
    skip_cols = skip_cols or set()

    # Keep only columns that are requested, not skipped and actually present;
    # dict.fromkeys drops duplicate names while preserving order.
    targets = list(
        dict.fromkeys(
            c for c in numeric_cols if c not in skip_cols and c in df.columns
        )
    )
    if not targets:
        return df

    # With a list of columns approxQuantile returns one result list per column;
    # nulls are ignored and a column holding only nulls yields an empty list.
    quantiles = df.approxQuantile(targets, [0.5], rel_err)

    # Fall back to 0 when a column has no non-null value to take a median from.
    fill_map = {c: (q[0] if q else 0) for c, q in zip(targets, quantiles)}
    return df.fillna(fill_map)

def fill_categorical_unknown(df, cat_cols, unknown="UNKNOWN"):
    """Fill nulls in the given non-numeric columns with a placeholder value.

    Args:
        df: DataFrame to fill.
        cat_cols: Candidate column names; names not present in ``df`` are
            ignored.
        unknown: Placeholder passed to ``fillna`` for every selected column.

    Returns:
        The result of ``df.fillna`` for the selected columns, or ``df`` itself
        when none of ``cat_cols`` exists in it.
    """
    replacements = {}
    for name in cat_cols:
        if name in df.columns:
            replacements[name] = unknown

    if not replacements:
        return df
    return df.fillna(replacements)

def add_missing_flags(df, cols):
    """Add a 0/1 ``<col>_missing`` indicator column for each listed column.

    Must run before nulls are filled, otherwise every flag would be 0.

    Args:
        df: DataFrame to extend.
        cols: Column names to flag; names not present in ``df`` are skipped.

    Returns:
        DataFrame with one extra ``<col>_missing`` int column per flagged
        column (1 where the source value is null, 0 otherwise).
    """
    flagged = df
    for name in cols:
        if name not in flagged.columns:
            continue
        is_null_flag = F.col(name).isNull().cast("int")
        flagged = flagged.withColumn(f"{name}_missing", is_null_flag)
    return flagged

def normalize_keys(user_df, ad_df, click_df, behavior_df):
    """Give the join keys a common name and integer type across all tables.

    * ``user_df``: ``userid`` is renamed to ``user`` (or dropped when a
      ``user`` column already exists), and ``user`` is cast to int.
    * ``click_df``: ``user`` and ``adgroup_id`` are cast to int.
    * ``ad_df``: ``adgroup_id`` is cast to int.
    * ``behavior_df``: ``user`` is cast to int.

    Columns that are absent are left alone, so a frame without any key column
    is returned unchanged.

    Args:
        user_df: User profile DataFrame.
        ad_df: Ad feature DataFrame.
        click_df: Click / impression sample DataFrame.
        behavior_df: Behavior log DataFrame.

    Returns:
        Tuple ``(user_df, ad_df, click_df, behavior_df)`` with normalized keys.
    """

    def _cast_int(df, name):
        # Cast only when the column exists so frames without the key pass
        # through untouched.
        if name in df.columns:
            return df.withColumn(name, F.col(name).cast("int"))
        return df

    # user_profile: userid -> user
    user_cols = user_df.columns
    if "userid" in user_cols:
        if "user" not in user_cols:
            user_df = user_df.withColumn("user", F.col("userid"))
        user_df = user_df.drop("userid")
    # Cast `user` even when it was already present, so its type matches the
    # `user` key of click_df / behavior_df.
    user_df = _cast_int(user_df, "user")

    # click
    click_df = _cast_int(click_df, "user")
    click_df = _cast_int(click_df, "adgroup_id")

    # ad
    ad_df = _cast_int(ad_df, "adgroup_id")

    # behavior
    behavior_df = _cast_int(behavior_df, "user")

    return user_df, ad_df, click_df, behavior_df

# Helpers – timestamp
def add_timestamp_strings(click_df, behavior_df):
    """Add a human-readable ``time_stamp_str`` column next to ``time_stamp``.

    For each frame that has a ``time_stamp`` column, the value is cast to
    bigint and converted with ``from_unixtime``. Frames without the column are
    returned unchanged.

    Args:
        click_df: Click / impression sample DataFrame.
        behavior_df: Behavior log DataFrame.

    Returns:
        Tuple ``(click_df, behavior_df)``.
    """

    def _with_readable_time(frame):
        if "time_stamp" not in frame.columns:
            return frame
        epoch_value = col("time_stamp").cast("bigint")
        return frame.withColumn("time_stamp_str", from_unixtime(epoch_value))

    return _with_readable_time(click_df), _with_readable_time(behavior_df)

# Helpers – safe save (Windows friendly)
def _append_csv_chunk(rows, cols, out_csv_path, first_chunk):
    """Write ``rows`` to ``out_csv_path``: create it with a header, or append."""
    pd.DataFrame(rows, columns=cols).to_csv(
        out_csv_path,
        mode="w" if first_chunk else "a",
        header=first_chunk,
        index=False,
    )


def _write_csv_with_pandas(df, out_csv_path, chunk_rows):
    """Stream ``df`` to one CSV file via pandas, ``chunk_rows`` rows at a time."""
    cols = df.columns
    buffer = []
    wrote_header = False

    # Remove any stale output so a failed run cannot leave old data behind.
    if os.path.exists(out_csv_path):
        os.remove(out_csv_path)

    # toLocalIterator() streams rows to the driver instead of collecting the
    # whole DataFrame at once.
    for row in df.toLocalIterator():
        buffer.append([row[c] for c in cols])
        if len(buffer) >= chunk_rows:
            _append_csv_chunk(buffer, cols, out_csv_path, not wrote_header)
            wrote_header = True
            buffer = []

    # Flush the remainder. For an empty DataFrame this still writes a
    # header-only file, so the output always exists when "Saved" is reported.
    if buffer or not wrote_header:
        _append_csv_chunk(buffer, cols, out_csv_path, not wrote_header)


def save_df_safely(df, out_csv_path, chunk_rows=200_000):
    """Save ``df`` as a single CSV file, falling back to pandas if Spark fails.

    The Spark path writes one coalesced partition into ``<out_csv_path>_tmp``
    and moves the resulting ``part-*.csv`` to ``out_csv_path``. If anything in
    that path raises (the recorded run shows the Spark CSV writer failing with
    a ``FileNotFoundError`` on Windows), the rows are streamed to the driver
    and written with pandas in chunks instead. The temporary directory is
    removed in both cases.

    Args:
        df: DataFrame to save.
        out_csv_path: Destination path of the single CSV file. An existing
            file at this path is replaced.
        chunk_rows: Number of rows buffered per pandas write in the fallback.

    Returns:
        None. Progress is reported with ``print``.
    """
    tmp_dir = out_csv_path + "_tmp"

    try:
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir, ignore_errors=True)

        df.coalesce(1).write.mode("overwrite").option("header", True).csv(tmp_dir)

        # coalesce(1) should leave exactly one part file; report clearly if
        # none was produced rather than leaking a bare StopIteration.
        part_file = next(
            (
                os.path.join(tmp_dir, f)
                for f in os.listdir(tmp_dir)
                if f.startswith("part-") and f.endswith(".csv")
            ),
            None,
        )
        if part_file is None:
            raise FileNotFoundError(f"No part-*.csv file produced in {tmp_dir}")

        if os.path.exists(out_csv_path):
            os.remove(out_csv_path)

        shutil.move(part_file, out_csv_path)
        print(f"Saved with Spark: {out_csv_path}")
        return

    except Exception as e:
        # Deliberate best-effort: any Spark-side failure triggers the fallback.
        print("Spark write failed, using Pandas fallback:", str(e)[:120])

    finally:
        # Clean up the temp directory on success and on failure alike.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    # Pandas fallback (chunked)
    _write_csv_with_pandas(df, out_csv_path, chunk_rows)

    print(f"Saved with Pandas fallback: {out_csv_path}")

# Load data
print("\nLoading randomized samples...")

user_df = load_random_sample_limit(os.path.join(raw_dir, file_names["user_profile"]))
ad_df = load_random_sample_limit(os.path.join(raw_dir, file_names["ad_feature"]))
click_df = load_random_sample_limit(os.path.join(raw_dir, file_names["raw_sample"]))
behavior_df = load_random_sample_limit(os.path.join(raw_dir, file_names["behavior_log"]))

# Remove duplicates, then cache the result.
# Every later step (median computation, sanity check, save, pandas fallback)
# is a separate Spark action. Without caching each one would re-read the CSV
# and re-run sample().limit(), which is slow and - because limit() without an
# ordering is not guaranteed to return the same rows - could compute medians
# on a different sample than the one that gets saved.
user_df = user_df.dropDuplicates().cache()
ad_df = ad_df.dropDuplicates().cache()
click_df = click_df.dropDuplicates().cache()
behavior_df = behavior_df.dropDuplicates().cache()

# Normalize keys
user_df, ad_df, click_df, behavior_df = normalize_keys(
    user_df, ad_df, click_df, behavior_df
)

# Missing value handling
# For each table: split columns by type first, add *_missing flags while the
# nulls are still present, then fill numeric and non-numeric columns. Join
# keys are excluded from median filling via skip_cols.
user_num, user_cat = split_cols(user_df)
user_df = add_missing_flags(user_df, ["age_level", "shopping_level", "occupation", "final_gender_code"])
user_df = fill_numeric_with_median(user_df, user_num, skip_cols={"user"})
user_df = fill_categorical_unknown(user_df, user_cat)

ad_num, ad_cat = split_cols(ad_df)
ad_df = add_missing_flags(ad_df, ["price", "cate_id", "campaign_id", "customer"])
ad_df = fill_numeric_with_median(ad_df, ad_num, skip_cols={"adgroup_id"})
ad_df = fill_categorical_unknown(ad_df, ad_cat)

click_num, click_cat = split_cols(click_df)
click_df = add_missing_flags(click_df, ["pid"])
click_df = fill_numeric_with_median(click_df, click_num, skip_cols={"user", "adgroup_id"})
click_df = fill_categorical_unknown(click_df, click_cat)

beh_num, beh_cat = split_cols(behavior_df)
behavior_df = fill_numeric_with_median(behavior_df, beh_num, skip_cols={"user"})
behavior_df = fill_categorical_unknown(behavior_df, beh_cat)

# Timestamp strings
click_df, behavior_df = add_timestamp_strings(click_df, behavior_df)

# Sanity checks
# The mean of the `clk` column is shown as the click rate of the sample.
if "clk" in click_df.columns:
    click_df.select(F.mean("clk").alias("click_rate")).show()

# Save cleaned CSVs
print("\nSaving cleaned CSV files...")

save_df_safely(user_df, os.path.join(processed_dir, "user_profile_clean.csv"))
save_df_safely(ad_df, os.path.join(processed_dir, "ad_feature_clean.csv"))
save_df_safely(click_df, os.path.join(processed_dir, "raw_sample_clean.csv"))
save_df_safely(behavior_df, os.path.join(processed_dir, "behavior_log_clean.csv"))

# Stopping the session also releases the cached DataFrames.
spark.stop()
print("\nSpark session stopped.")


Spark version: 4.0.1

Loading randomized samples...
+----------+
|click_rate|
+----------+
|   0.05148|
+----------+


Saving cleaned CSV files...
Spark write failed, using Pandas fallback: An error occurred while calling o1823.csv.
: java.lang.RuntimeException: java.io.FileNotFoundException: java.io.FileNotF
Saved with Pandas fallback: d:\projects\Ai\project_fusion_ecu\data\processed\user_profile_clean.csv
Spark write failed, using Pandas fallback: An error occurred while calling o1833.csv.
: java.lang.RuntimeException: java.io.FileNotFoundException: java.io.FileNotF
Saved with Pandas fallback: d:\projects\Ai\project_fusion_ecu\data\processed\ad_feature_clean.csv
Spark write failed, using Pandas fallback: An error occurred while calling o1843.csv.
: java.lang.RuntimeException: java.io.FileNotFoundException: java.io.FileNotF
Saved with Pandas fallback: d:\projects\Ai\project_fusion_ecu\data\processed\raw_sample_clean.csv
Spark write failed, using Pandas fallback: An error occurred whil