# 03 – Exploratory Data Analysis and Data Fusion (PySpark)

This notebook performs exploratory data analysis (EDA) and demonstrates two data-fusion strategies. We investigate the distribution of variables, the relationships between features and the target, and how to combine the data sources.

**Fusion strategies:**
- **Early Fusion**: Join all relevant tables into a single wide table.
- **Hybrid Fusion**: Aggregate behaviour logs into features (counts) and merge them into the click log.

We use PySpark for heavy data processing and Pandas for visualisation on sampled data.


In [6]:
# ============================================================
# Single-file: Fusion + Visual EDA + ready outputs
# Language: English (comments/strings)
# ============================================================
# What this notebook does:
# 1) Read cleaned CSVs from data/processed
# 2) Fusion (Click + User + Ad) + Behavior Aggregation
# 3) Join quality checks (Enrichment Rates)
# 4) Visual EDA to support preprocessing decisions:
#    - Class balance (Bar)
#    - CTR over time (Line)
#    - CTR by hour (Bar)
#    - Join quality (Bar)
#    - CTR by cate_id and by brand (Top-N by count) (Bar)
#    - price vs label (Boxplot)
#    - Missingness (Bar)
#    - Correlation heatmap (Numeric)
#    - (optional) PR Curve + Score Distribution after baseline model if npz files exist
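# 5) Save figures in data/processed/figures
# 6) Save the fused data to data/processed (temporal 80/20 train/test split when time_stamp exists)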

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import os
import pandas as pd
import matplotlib.pyplot as plt

# seaborn is optional for visuals: if it cannot be imported (or its theme
# cannot be applied) HAS_SNS stays False and every plot below falls back to
# plain matplotlib.
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid")
except Exception:
    # Best-effort: any failure simply disables seaborn styling.
    HAS_SNS = False
else:
    HAS_SNS = True


# ============================================================
# 0) General settings
# ============================================================
SEED = 42             # seed for the random row order used to draw the EDA sample
SAMPLE_SIZE = 20_000  # number of rows collected into Pandas for plotting
TOP_N = 15            # how many of the most frequent cate_id / brand values to plot
FIG_DPI = 140         # resolution (dots per inch) of every saved figure

# ============================================================
# 1) Spark Session
# ============================================================
# Local session with 4 threads; shuffle-partition count and driver memory are
# set explicitly rather than left to Spark's defaults.
_spark_builder = (
    SparkSession.builder
        .appName("CTR_Fusion_Visual_EDA")
        .master("local[4]")
)
for _conf_key, _conf_value in {
    "spark.sql.shuffle.partitions": "100",
    "spark.driver.memory": "6g",
}.items():
    _spark_builder = _spark_builder.config(_conf_key, _conf_value)
spark = _spark_builder.getOrCreate()
spark.sparkContext.setLogLevel("WARN")  # keep notebook output readable
print("Spark version:", spark.version)

# ============================================================
# 2) Paths
# ============================================================
# The project root is the parent of the current working directory
# (presumably the notebook folder - adjust if run from elsewhere).
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
processed_dir = os.path.join(project_root, "data", "processed")
fig_dir = os.path.join(processed_dir, "figures")
os.makedirs(fig_dir, exist_ok=True)  # save_fig writes every figure here

for _caption, _path in (
    ("Project root:", project_root),
    ("Processed dir:", processed_dir),
    ("Figures dir:", fig_dir),
):
    print(_caption, _path)

# ============================================================
# 3) Load cleaned CSVs
# ============================================================
def _read_processed_csv(file_name):
    """Read one cleaned CSV from processed_dir (header row, Spark-inferred column types)."""
    return spark.read.csv(os.path.join(processed_dir, file_name), header=True, inferSchema=True)


user_df = _read_processed_csv("user_profile_clean.csv")
ad_df = _read_processed_csv("ad_feature_clean.csv")
click_df = _read_processed_csv("raw_sample_clean.csv")
behavior_df = _read_processed_csv("behavior_log_clean.csv")

def strip_columns(df):
    """Return *df* with leading/trailing whitespace removed from every column name.

    Column order is preserved; only the names change (via ``DataFrame.toDF``).
    """
    cleaned_names = (name.strip() for name in df.columns)
    return df.toDF(*cleaned_names)

# Normalise header names (surrounding whitespace removed) before any column is
# looked up by name; the four frames are processed in the same order as loaded.
user_df, ad_df, click_df, behavior_df = (
    strip_columns(frame) for frame in (user_df, ad_df, click_df, behavior_df)
)

def ensure_user_col(df, candidates=("user", "userid", "nick")):
    """Rename the first matching user-id column of *df* to ``"user"``.

    ``candidates`` are checked in order; the first one present in ``df.columns``
    wins. If that column is already called ``"user"``, or no candidate is
    present, *df* is returned unchanged.
    """
    found = next((name for name in candidates if name in df.columns), None)
    if found is None or found == "user":
        return df
    return df.withColumnRenamed(found, "user")

user_df, click_df, behavior_df = (
    ensure_user_col(frame) for frame in (user_df, click_df, behavior_df)
)

# ============================================================
# 4) Unify join key types (important)
# ============================================================
# Both sides of an equi-join should share a type, so every join key that is
# present is cast to int.
def _cast_key_to_int(frame, key):
    """Return *frame* with column *key* cast to int, or unchanged if *key* is absent."""
    return frame.withColumn(key, F.col(key).cast("int")) if key in frame.columns else frame


user_df = _cast_key_to_int(user_df, "user")
click_df = _cast_key_to_int(click_df, "user")
behavior_df = _cast_key_to_int(behavior_df, "user")

ad_df = _cast_key_to_int(ad_df, "adgroup_id")
click_df = _cast_key_to_int(click_df, "adgroup_id")

# Each count() triggers a full Spark job over that input.
_row_report = []
for _name, _frame in (
    ("user_df", user_df),
    ("ad_df", ad_df),
    ("click_df", click_df),
    ("behavior_df", behavior_df),
):
    _row_report += [_name, _frame.count()]
print("Rows:", *_row_report)

# ============================================================
# 5) Fusion: Click + User + Ad (avoid duplicate columns)
# ============================================================
# Two LEFT joins keep every click row even when the user/ad lookup misses.
# The right-hand join key is left out of each projection so "user" and
# "adgroup_id" appear only once in the fused table.
clicks_aliased = click_df.alias("c")
users_aliased = user_df.alias("u")
ads_aliased = ad_df.alias("a")

click_user_df = (
    clicks_aliased
    .join(users_aliased, F.col("c.user") == F.col("u.user"), how="left")
    .select(
        *[F.col("c." + name) for name in click_df.columns],
        *[F.col("u." + name) for name in user_df.columns if name != "user"],
    )
)

click_user_aliased = click_user_df.alias("cu")
click_user_non_key = [
    F.col("cu." + name) for name in click_user_df.columns if name != "adgroup_id"
]
ad_non_key = [F.col("a." + name) for name in ad_df.columns if name != "adgroup_id"]

# Column order: click+user columns, then the (click-side) adgroup_id, then ad columns.
full_df = (
    click_user_aliased
    .join(ads_aliased, F.col("cu.adgroup_id") == F.col("a.adgroup_id"), how="left")
    .select(
        *click_user_non_key,
        F.col("cu.adgroup_id").alias("adgroup_id"),
        *ad_non_key,
    )
)

print("After Early Fusion:",
      "rows:", full_df.count(),
      "cols:", len(full_df.columns)
)

# ============================================================
# 6) Behavior Aggregation (Pivot)
# Note: this aggregation is global per-user (not time-aware).
# For training, prefer time-aware aggregations, but we use this as general signal.
# ============================================================
if not {"btag", "user"}.issubset(behavior_df.columns):
    print("Warning: cannot perform Behavior Pivot (btag/user missing).")
else:
    # One row per user, one count column per distinct btag value.
    behaviour_counts = (
        behavior_df.groupBy("user")
        .pivot("btag")
        .agg(F.count("btag"))
        .fillna(0)
    )
    behaviour_feature_cols = [name for name in behaviour_counts.columns if name != "user"]

    fused_cols = full_df.columns  # column list of the table BEFORE this join
    left_side = full_df.alias("f")
    right_side = behaviour_counts.alias("b")
    full_df = (
        left_side
        .join(right_side, F.col("f.user") == F.col("b.user"), how="left")
        .select(
            *[F.col("f." + name) for name in fused_cols],
            *[F.col("b." + name) for name in behaviour_feature_cols],
        )
    )

    # Users with no behaviour rows get nulls from the left join; count them as zero.
    zero_fill = {name: 0 for name in behaviour_feature_cols if name in full_df.columns}
    if zero_fill:
        full_df = full_df.fillna(zero_fill)

    print("After Hybrid Fusion (Behavior):",
          "rows:", full_df.count(),
          "cols:", len(full_df.columns)
    )

# ============================================================
# 7) Add label and time features (important for preprocessing)
# ============================================================
if "clk" not in full_df.columns:
    print("Warning: clk column missing - cannot compute CTR precisely.")
else:
    full_df = full_df.withColumn("label", F.col("clk").cast("int"))

# time features
if "time_stamp" not in full_df.columns:
    print("Warning: time_stamp missing - time-based visuals will be skipped.")
else:
    # hour of day (0-23) and a running day number, derived by integer division
    # of time_stamp (assumed epoch seconds). NOTE(review): this yields the
    # hour of the raw epoch value, i.e. UTC - confirm whether local time is intended.
    epoch_seconds = F.col("time_stamp")
    full_df = full_df.withColumns({
        "hour": (epoch_seconds / 3600).cast("long") % 24,
        "day": (epoch_seconds / 86400).cast("long"),
    })

# ============================================================
# 8) Compute join quality (Enrichment Rates)
# ============================================================
# Enrichment rate = share of fused rows in which a column supplied only by the
# joined table is non-null; a low rate means the LEFT join found no match.
_candidate_join_checks = (
    ("User join (age_level not null)", "age_level"),
    ("Ad join (price not null)", "price"),
    ("Ad join (cate_id not null)", "cate_id"),
)
join_checks = [
    (label_name, col_name)
    for label_name, col_name in _candidate_join_checks
    if col_name in full_df.columns
]

# A single aggregation returns the total row count and every non-null count in
# one Spark job, instead of re-running the whole (uncached) join pipeline once
# per check. F.count(<column>) counts non-null values; F.count(F.lit(1)) counts rows.
_count_row = full_df.agg(
    F.count(F.lit(1)).alias("total"),
    *[
        F.count(F.col(col_name)).alias(f"nn_{i}")
        for i, (_, col_name) in enumerate(join_checks)
    ],
).first()
total_rows = _count_row["total"]

enrichment = [
    (label_name, _count_row[f"nn_{i}"] / total_rows if total_rows else 0.0)
    for i, (label_name, _) in enumerate(join_checks)
]

print("\nEnrichment rates after joins:")
for label_name, rate in enrichment:
    print("-", label_name, ":", round(rate * 100, 2), "%")

# ============================================================
# 9) Take random sample for Pandas EDA (reduce memory use)
# ============================================================
# Order rows by a seeded random key and keep the first SAMPLE_SIZE, then
# collect that subset to the driver as a pandas DataFrame.
random_order_key = F.rand(seed=SEED)
sample_sdf = full_df.orderBy(random_order_key).limit(SAMPLE_SIZE)
sample_pdf = sample_sdf.toPandas()
print("Sample size for EDA:", sample_pdf.shape)

# ============================================================
# 10) Helper plotting utils
# ============================================================
def save_fig(name):
    """Save the current matplotlib figure as ``fig_dir/name`` and close it.

    The figure is laid out tightly, written at FIG_DPI, its path is printed,
    and it is then closed so consecutive plots do not accumulate open figures.
    """
    figure = plt.gcf()
    out_path = os.path.join(fig_dir, name)
    figure.tight_layout()
    figure.savefig(out_path, dpi=FIG_DPI)
    print("Saved figure:", out_path)
    plt.close(figure)


def safe_has_cols(pdf, cols):
    """Return True when every name in *cols* is a column of *pdf* (True for empty *cols*)."""
    available = pdf.columns
    for name in cols:
        if name not in available:
            return False
    return True

# ============================================================
# 11) Visual 1: Class balance (Click vs No-Click)
# ============================================================
if "label" not in sample_pdf.columns:
    print("Skipping: label missing.")
else:
    # Count per label value, ordered by the label itself (0 before 1).
    label_counts = sample_pdf["label"].value_counts().sort_index()
    plt.figure(figsize=(6, 4))
    plt.bar(label_counts.index.astype(str), label_counts.values)
    plt.title("Target distribution: Click vs No-Click")
    plt.xlabel("label (0=No Click, 1=Click)")
    plt.ylabel("count")
    save_fig("01_label_distribution.png")

# ============================================================
# 12) Visual 2: CTR over time (daily)
# ============================================================
if not safe_has_cols(sample_pdf, ["day", "label"]):
    print("Skipping: day/label missing.")
else:
    # CTR per day = mean of the 0/1 label over the sampled rows of that day.
    daily_ctr = (
        sample_pdf.dropna(subset=["day", "label"])
        .groupby("day")["label"].mean()
        .reset_index()
        .sort_values("day")
    )

    plt.figure(figsize=(10, 4))
    plt.plot(daily_ctr["day"], daily_ctr["label"])
    plt.title("CTR over time (daily)")
    plt.xlabel("day (derived from time_stamp)")
    plt.ylabel("CTR = mean(label)")
    save_fig("02_ctr_over_time_day.png")

# ============================================================
# 13) Visual 3: CTR by hour (Bar)
# ============================================================
if not safe_has_cols(sample_pdf, ["hour", "label"]):
    print("Skipping: hour/label missing.")
else:
    hourly_ctr = (
        sample_pdf.dropna(subset=["hour", "label"])
        .groupby("hour")["label"].mean()
        .reset_index()
        .sort_values("hour")
    )

    plt.figure(figsize=(10, 4))
    plt.bar(hourly_ctr["hour"].astype(int), hourly_ctr["label"])
    plt.title("CTR by hour")
    plt.xlabel("hour (0-23)")
    plt.ylabel("CTR")
    save_fig("03_ctr_by_hour.png")

# ============================================================
# 14) Visual 4: Join enrichment (Bar)
# ============================================================
if enrichment:
    # enrichment holds (description, rate) pairs; split them into two series.
    check_names, check_rates = zip(*enrichment)

    plt.figure(figsize=(10, 4))
    plt.barh(list(check_names), list(check_rates))
    plt.title("Join enrichment rates")
    plt.xlabel("Enrichment Rate")
    plt.xlim(0, 1)  # rates are fractions in [0, 1]
    save_fig("04_join_enrichment_rates.png")

# ============================================================
# 15) Visual 5: CTR by cate_id (Top-N by count)
# ============================================================
def plot_ctr_topN(pdf, col_cat, top_n=TOP_N, fname="ctr_top.png", title=""):
    """Plot CTR (mean of ``label``) for the ``top_n`` most frequent values of ``col_cat``.

    Only the most frequent categories are shown, so rates estimated from a
    handful of rows do not dominate the chart. Bars are ordered by category
    volume (most frequent first) in both the seaborn and matplotlib branches.

    Args:
        pdf: pandas DataFrame containing ``col_cat`` and ``label``.
        col_cat: name of the categorical column to group by.
        top_n: number of most frequent categories to keep.
        fname: file name handed to ``save_fig``.
        title: plot title.
    """
    if not safe_has_cols(pdf, [col_cat, "label"]):
        print(f"Skipping: {col_cat} or label missing.")
        return

    tmp = pdf[[col_cat, "label"]].dropna()
    # top categories by volume to avoid misleading small-sample rates
    top_cats = tmp[col_cat].value_counts().head(top_n).index
    tmp2 = tmp[tmp[col_cat].isin(top_cats)]
    if tmp2.empty:
        print(f"Skipping: no non-null {col_cat}/label rows to plot.")
        return

    out = (
        tmp2.groupby(col_cat)["label"]
        .agg(count="size", ctr="mean")
        .reset_index()
        .sort_values("count", ascending=False)
    )
    # String tick labels plus an explicit order: seaborn sorts numeric
    # categories by value, which would discard the by-volume ordering above.
    out["_cat_label"] = out[col_cat].astype(str)
    order = out["_cat_label"].tolist()

    plt.figure(figsize=(12, 5))
    if HAS_SNS:
        sns.barplot(data=out, x="_cat_label", y="ctr", order=order)
    else:
        plt.bar(order, out["ctr"].values)

    plt.title(title)
    plt.xlabel(col_cat)
    plt.ylabel("CTR (mean label)")
    plt.xticks(rotation=45, ha="right")
    save_fig(fname)

# One Top-N CTR chart per categorical column that is present in the sample.
for cat_col, cat_fig_name in (
    ("cate_id", "05_ctr_by_cate_topN.png"),
    ("brand", "06_ctr_by_brand_topN.png"),
):
    if cat_col in sample_pdf.columns:
        plot_ctr_topN(
            sample_pdf,
            cat_col,
            top_n=TOP_N,
            fname=cat_fig_name,
            title=f"CTR by {cat_col} (Top {TOP_N} by count)",
        )

# ============================================================
# 16) Visual 6: price vs label (Boxplot)
# ============================================================
if safe_has_cols(sample_pdf, ["price", "label"]):
    tmp = sample_pdf[["price", "label"]].dropna()
    plt.figure(figsize=(6, 5))
    if HAS_SNS:
        sns.boxplot(x="label", y="price", data=tmp)
    else:
        # simple fallback without seaborn: one box per label value.
        g0 = tmp.loc[tmp["label"] == 0, "price"].values
        g1 = tmp.loc[tmp["label"] == 1, "price"].values
        # Tick labels are set via xticks (boxes sit at positions 1 and 2):
        # boxplot's `labels=` keyword was renamed `tick_labels=` in
        # Matplotlib 3.9 and the old name is deprecated, so this form works
        # on both older and newer versions.
        plt.boxplot([g0, g1])
        plt.xticks([1, 2], ["0", "1"])
        plt.xlabel("label")

    plt.title("Price distribution by click")
    plt.ylabel("price")
    save_fig("07_price_vs_label_boxplot.png")
else:
    print("Skipping: price/label missing.")

# ============================================================
# 17) Visual 7: Missingness per column
# ============================================================
# Fraction of NaN/None per column in the sample, most-missing first,
# keeping only columns with at least one missing value.
missing_rate = sample_pdf.isna().mean().sort_values(ascending=False)
missing_rate = missing_rate[missing_rate > 0]

if missing_rate.empty:
    print("No notable missing values in the sample.")
else:
    # barh draws the first entry at the bottom; reverse so the most-missing
    # column ends up at the top.
    bottom_to_top = missing_rate.iloc[::-1]
    plt.figure(figsize=(10, 6))
    plt.barh(bottom_to_top.index, bottom_to_top.values)
    plt.title("Missingness rate per column")
    plt.xlabel("Missing Rate")
    save_fig("08_missingness_rates.png")

# ============================================================
# 18) Visual 8: Correlation heatmap for numeric features
# ============================================================
# Identifier columns (join keys and the cate_id/brand codes treated as
# categories above) are numeric only by encoding, so a Pearson correlation on
# them is meaningless. "clk" is dropped because "label" is clk cast to int,
# which would only add a trivial +1 cell.
corr_exclude_cols = {"user", "adgroup_id", "cate_id", "brand", "clk"}
numeric_cols = [
    c for c in sample_pdf.columns
    if c not in corr_exclude_cols and pd.api.types.is_numeric_dtype(sample_pdf[c])
]
if len(numeric_cols) >= 2:
    corr = sample_pdf[numeric_cols].corr()
    plt.figure(figsize=(12, 10))
    if HAS_SNS:
        sns.heatmap(corr, annot=False, center=0, cmap="coolwarm")
    else:
        # Same diverging colormap as the seaborn branch, pinned to the full
        # correlation range so that 0 always sits at the colour midpoint.
        plt.imshow(corr.values, cmap="coolwarm", vmin=-1, vmax=1)
        plt.colorbar()
        plt.xticks(range(len(corr.columns)), corr.columns, rotation=90)
        plt.yticks(range(len(corr.columns)), corr.columns)

    plt.title("Correlation heatmap for numeric features")
    save_fig("09_numeric_correlation_heatmap.png")
else:
    print("Skipping: not enough numeric columns for correlation.")

# ============================================================
# 19) (optional) Save fused data + temporal train/test split
# ============================================================
if "time_stamp" not in full_df.columns:
    # No time axis to split on: persist the whole fused table instead.
    fused_path = os.path.join(processed_dir, "fused_data.csv")
    full_df.toPandas().to_csv(fused_path, index=False)
    print("Saved fused_data:", fused_path)
else:
    # The approximate 80th percentile of time_stamp (relative error 0.001) is
    # the boundary: rows at or before it go to train, strictly later rows to
    # test, so evaluation uses roughly the most recent 20% of events.
    # NOTE(review): the behaviour counts from step 6 are global per user (not
    # time-aware), so test-period activity can also feed the train features.
    cutoff = full_df.approxQuantile("time_stamp", [0.8], 0.001)[0]
    in_train_window = F.col("time_stamp") <= cutoff
    train_df = full_df.filter(in_train_window)
    test_df = full_df.filter(~in_train_window)

    print("\nTemporal split 80/20:")
    print("cutoff time_stamp:", cutoff)
    print("train rows:", train_df.count())
    print("test rows :", test_df.count())

    # Save CSV (use Pandas to avoid Windows/Hadoop issues); each split is
    # collected onto the driver by toPandas() before writing.
    train_path = os.path.join(processed_dir, "fused_train.csv")
    test_path = os.path.join(processed_dir, "fused_test.csv")
    for split_frame, split_path in ((train_df, train_path), (test_df, test_path)):
        split_frame.toPandas().to_csv(split_path, index=False)
    print("Saved:", train_path)
    print("Saved:", test_path)

# ============================================================
# 20) (optional) PR Curve + Score Distribution if processed training files exist
# ============================================================
# Best-effort: any failure (missing scipy/sklearn, unreadable files, fit error)
# is printed and skipped so execution always reaches the Spark shutdown below.
try:
    from scipy import sparse
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import precision_recall_curve, average_precision_score

    xtr_path = os.path.join(processed_dir, "X_train_processed.npz")
    xte_path = os.path.join(processed_dir, "X_test_processed.npz")
    ytr_path = os.path.join(processed_dir, "y_train.csv")
    yte_path = os.path.join(processed_dir, "y_test.csv")
    missing_files = [
        p for p in (xtr_path, xte_path, ytr_path, yte_path) if not os.path.exists(p)
    ]

    if missing_files:
        print("Skipping PR Curve: processed npz/y files missing.")
    else:
        X_train = sparse.load_npz(xtr_path)
        X_test = sparse.load_npz(xte_path)
        # Single-column CSVs -> squeeze() turns them into Series of 0/1 ints.
        y_train = pd.read_csv(ytr_path).squeeze().astype(int)
        y_test = pd.read_csv(yte_path).squeeze().astype(int)

        model = LogisticRegression(max_iter=3000, n_jobs=-1)
        model.fit(X_train, y_train)
        probas = model.predict_proba(X_test)[:, 1]  # P(click) for each test row

        precision, recall, _ = precision_recall_curve(y_test, probas)
        ap = average_precision_score(y_test, probas)

        plt.figure(figsize=(7, 5))
        plt.plot(recall, precision)
        plt.title(f"Precision-Recall Curve (Baseline LR) | AP={ap:.4f}")
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        save_fig("10_precision_recall_curve.png")

        # score distribution for pos vs neg, overlaid in one figure
        plt.figure(figsize=(8, 4))
        for class_value, legend_text in ((0, "No Click (0)"), (1, "Click (1)")):
            plt.hist(probas[y_test == class_value], bins=50, alpha=0.7, label=legend_text)
        plt.title("Predicted probability distribution (Baseline LR)")
        plt.xlabel("Predicted Probability")
        plt.ylabel("Count")
        plt.legend()
        save_fig("11_score_distribution.png")

        print("PR Curve and Score Distribution created (optional).")
except Exception as e:
    print("Skipping PR Curve due to error:", str(e))

# ============================================================
# 21) Stop Spark
# ============================================================
# Release the local Spark session once all outputs are written.
spark.stop()
print("Spark stopped.")

Spark version: 4.0.1
Project root: d:\projects\Ai\project_fusion_ecu
Processed dir: d:\projects\Ai\project_fusion_ecu\data\processed
Figures dir: d:\projects\Ai\project_fusion_ecu\data\processed\figures
Rows: user_df 500000 ad_df 500000 click_df 500000 behavior_df 499964
After Early Fusion: rows: 500000 cols: 21
After Hybrid Fusion (Behavior): rows: 500000 cols: 25

Enrichment rates after joins:
- User join (age_level not null) : 44.28 %
- Ad join (price not null) : 59.56 %
- Ad join (cate_id not null) : 59.56 %
Sample size for EDA: (20000, 28)
Saved figure: d:\projects\Ai\project_fusion_ecu\data\processed\figures\01_label_distribution.png
Saved figure: d:\projects\Ai\project_fusion_ecu\data\processed\figures\02_ctr_over_time_day.png
Saved figure: d:\projects\Ai\project_fusion_ecu\data\processed\figures\03_ctr_by_hour.png
Saved figure: d:\projects\Ai\project_fusion_ecu\data\processed\figures\04_join_enrichment_rates.png
Saved figure: d:\projects\Ai\project_fusion_ecu\data\processed\fig