# 03 – Exploratory Data Analysis and Data Fusion (PySpark)

This notebook performs exploratory data analysis and demonstrates different fusion strategies. We investigate the distributions of variables, the relationships between features and the target, and how to combine the data sources.

**Fusion strategies:**
- **Early Fusion**: Join all relevant tables into a single wide table.
- **Hybrid Fusion**: Aggregate behaviour logs into features (counts) and merge them into the click log.

We use PySpark for heavy data processing and Pandas for visualisation on sampled data.


In [None]:
import os
import sys
import pandas as pd
import matplotlib.pyplot as plt

from pyspark.sql import SparkSession
from pyspark.sql import functions as F


# Point PySpark (driver and workers) at the interpreter running this notebook
# and set the local IP Spark should use.
os.environ.update(
    {
        "PYSPARK_PYTHON": sys.executable,
        "PYSPARK_DRIVER_PYTHON": sys.executable,
        "SPARK_LOCAL_IP": "127.0.0.1",
    }
)

# 1) Settings
# Seed handed to F.rand(...) wherever rows are shuffled or sampled.
SEED = 42

# Number of rows pulled into pandas for the EDA plots.
SAMPLE_SIZE_EDA = 20_000

# Upper bound on click rows carried into the fusion steps (start safe).
MAX_FUSION_ROWS = 50_000

# Behavior lookback window before each click time, in days and in seconds.
LOOKBACK_DAYS = 7
LOOKBACK_SECONDS = 86400 * LOOKBACK_DAYS

# Plot settings: categories shown in the top-N charts and saved-figure DPI.
TOP_N = 15
FIG_DPI = 140

# 2) Spark Session
# Session options, applied in this order on top of the app name and master.
spark_conf = {
    "spark.driver.memory": "6g",
    "spark.sql.shuffle.partitions": "60",
    "spark.python.worker.reuse": "true",
    "spark.network.timeout": "800s",
    "spark.executor.heartbeatInterval": "60s",
    "spark.sql.execution.arrow.pyspark.enabled": "false",
}

# local[2] is safer than local[4]
builder = SparkSession.builder.appName("CTR_Fusion_Visual_EDA_Fixed").master("local[2]")
for conf_key, conf_value in spark_conf.items():
    builder = builder.config(conf_key, conf_value)
spark = builder.getOrCreate()

# Log verbosity; other levels include INFO and ERROR.
spark.sparkContext.setLogLevel("WARN")
print("Spark version:", spark.version)

# 3) Paths
# The project root is the parent of the current working directory; the
# cleaned inputs, the fused outputs and the figures all live below it.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
data_root = os.path.join(project_root, "data")
processed_dir = os.path.join(data_root, "processed")
fig_dir = os.path.join(processed_dir, "figures")

# Create the figure folder (and any missing parents) up front.
os.makedirs(fig_dir, exist_ok=True)



# 4) Load cleaned CSVs
# All four inputs are read the same way: first row is the header and the
# column types are inferred by Spark.
user_df, ad_df, click_df, behavior_df = [
    spark.read.csv(os.path.join(processed_dir, csv_name), header=True, inferSchema=True)
    for csv_name in (
        "user_profile_clean.csv",
        "ad_feature_clean.csv",
        "raw_sample_clean.csv",
        "behavior_log_clean.csv",
    )
]

def strip_columns(df):
    """Return ``df`` with leading/trailing whitespace removed from every column name."""
    cleaned_names = map(str.strip, df.columns)
    return df.toDF(*cleaned_names)

# Normalise the column names of all four inputs in one pass.
user_df, ad_df, click_df, behavior_df = map(
    strip_columns, (user_df, ad_df, click_df, behavior_df)
)

def ensure_user_col(df, candidates=("user", "userid", "nick")):
    """Make sure the user key column of ``df`` is called ``"user"``.

    The first name in ``candidates`` that is a column of ``df`` is taken as the
    user key. It is renamed to ``"user"`` unless it already has that name.
    When none of the candidates is present, ``df`` is returned unchanged.
    """
    found = next((name for name in candidates if name in df.columns), None)
    if found is None or found == "user":
        return df
    return df.withColumnRenamed(found, "user")

# Give the three user-keyed frames a common "user" column name.
user_df, click_df, behavior_df = (
    ensure_user_col(frame) for frame in (user_df, click_df, behavior_df)
)


# 5) Unify key types
def _int_key(frame, key):
    """Cast column ``key`` to int when ``frame`` has it; otherwise return ``frame`` as is."""
    if key not in frame.columns:
        return frame
    return frame.withColumn(key, F.col(key).cast("int"))


# Join keys must have the same type on both sides of each join.
user_df = _int_key(user_df, "user")
click_df = _int_key(click_df, "user")
behavior_df = _int_key(behavior_df, "user")

ad_df = _int_key(ad_df, "adgroup_id")
click_df = _int_key(click_df, "adgroup_id")

# 6) Cap click_df size safely
if "time_stamp" not in click_df.columns:
    # No click time available: keep a seeded random subset instead.
    click_df = click_df.orderBy(F.rand(seed=SEED)).limit(MAX_FUSION_ROWS)
else:
    click_df = click_df.withColumn("time_stamp", F.col("time_stamp").cast("long"))
    # take latest N then reorder chronologically
    newest_first = click_df.orderBy(F.desc("time_stamp"))
    latest_rows = newest_first.limit(MAX_FUSION_ROWS)
    click_df = latest_rows.orderBy("time_stamp")

# Peek at a few rows rather than counting the whole frame.
print("\nclick_df capped (no count). Showing 3 rows:")
click_df.show(3, truncate=False)

# 7) Early Fusion
# Alias every frame so that same-named columns can be addressed by side.
c, u, a = click_df.alias("c"), user_df.alias("u"), ad_df.alias("a")

# Step 1: clicks LEFT JOIN users. Keep every click column, then every user
# attribute except the join key (the click-side "user" is the one kept).
click_cols = [F.col("c." + name) for name in click_df.columns]
user_cols_except_user = [F.col("u." + name) for name in user_df.columns if name != "user"]

same_user = F.col("c.user") == F.col("u.user")
click_user_df = c.join(u, same_user, how="left").select(*(click_cols + user_cols_except_user))

# Step 2: (clicks + users) LEFT JOIN ads. adgroup_id is taken from the click
# side and placed after the remaining click/user columns; the ad attributes
# follow without their own copy of the key.
cu = click_user_df.alias("cu")
cu_cols_except_adgroup = [
    F.col("cu." + name) for name in click_user_df.columns if name != "adgroup_id"
]
ad_cols_except_key = [F.col("a." + name) for name in ad_df.columns if name != "adgroup_id"]

same_adgroup = F.col("cu.adgroup_id") == F.col("a.adgroup_id")
fused_projection = (
    cu_cols_except_adgroup
    + [F.col("cu.adgroup_id").alias("adgroup_id")]
    + ad_cols_except_key
)
full_df = cu.join(a, same_adgroup, how="left").select(*fused_projection)

print("\nAfter Early Fusion (no count). Columns:", len(full_df.columns))


# 8) Hybrid Fusion: time-aware behavior counts
# The features can only be built when the fused table carries a click time
# and the behavior log exposes user, event time and event tag.
can_time_aware = (
    ("time_stamp" in full_df.columns) and
    all(x in behavior_df.columns for x in ["user", "time_stamp", "btag"])
)

if can_time_aware:
    print("\nBuilding time-aware behavior features...")

    # Behavior events under b_* names so they cannot collide with click columns.
    beh = behavior_df.select(
        F.col("user").cast("int").alias("b_user"),
        F.col("time_stamp").cast("long").alias("b_time"),
        F.col("btag").cast("string").alias("btag"),
    )

    # Click rows plus typed copies of the join key and click time.
    f = full_df.select(
        "*",
        F.col("user").cast("int").alias("c_user"),
        F.col("time_stamp").cast("long").alias("c_time"),
    )

    # Per-row id used to fold the exploded join back to one row per click.
    # `f` is used twice below (range join and final join). The generated ids
    # depend on partition layout and row order, so two independent
    # evaluations of the lazy plan are not guaranteed to give the same id to
    # the same row. Materialise the frame once so both uses see the same ids.
    f = f.withColumn("row_id", F.monotonically_increasing_id()).localCheckpoint(eager=True)

    # Left join keeps clicks without any matching event. An event matches
    # when it belongs to the same user and falls in [c_time - lookback, c_time).
    joined = (
        f.join(
            beh,
            (F.col("c_user") == F.col("b_user")) &
            (F.col("b_time") < F.col("c_time")) &
            (F.col("b_time") >= (F.col("c_time") - F.lit(LOOKBACK_SECONDS))),
            how="left"
        )
    )

    # adjust these if your btag values differ
    # NOTE(review): matching is exact and case-sensitive; confirm the tag
    # spellings against the behavior log actually loaded.
    pv_tags = ["pv", "pageview"]
    cart_tags = ["cart"]
    fav_tags = ["fav", "favorite"]
    buy_tags = ["buy"]

    # One count per tag family and click row. Unmatched clicks have a null
    # btag, which falls into otherwise(0).
    agg = (
        joined.groupBy("row_id")
              .agg(
                  F.sum(F.when(F.col("btag").isin(pv_tags), 1).otherwise(0)).alias(f"pv_{LOOKBACK_DAYS}d"),
                  F.sum(F.when(F.col("btag").isin(cart_tags), 1).otherwise(0)).alias(f"cart_{LOOKBACK_DAYS}d"),
                  F.sum(F.when(F.col("btag").isin(fav_tags), 1).otherwise(0)).alias(f"fav_{LOOKBACK_DAYS}d"),
                  F.sum(F.when(F.col("btag").isin(buy_tags), 1).otherwise(0)).alias(f"buy_{LOOKBACK_DAYS}d"),
              )
    )

    # Attach the counts to the click rows. The helper columns are dropped,
    # while row_id stays in the result (as in the original output schema).
    full_df = (
        f.join(agg, on="row_id", how="left")
         .drop("c_user", "c_time", "b_user", "b_time", "btag")  # safe even if not present
         .fillna({
             f"pv_{LOOKBACK_DAYS}d": 0,
             f"cart_{LOOKBACK_DAYS}d": 0,
             f"fav_{LOOKBACK_DAYS}d": 0,
             f"buy_{LOOKBACK_DAYS}d": 0,
         })
    )

    print("Time-aware behavior added. Columns:", len(full_df.columns))
else:
    print("\nWarning: cannot compute time-aware behavior (missing columns).")

# 9) Add label + time features
if "clk" not in full_df.columns:
    print("Warning: clk missing; label will not be created.")
else:
    # The integer click flag becomes the modelling target.
    full_df = full_df.withColumn("label", F.col("clk").cast("int"))

if "time_stamp" not in full_df.columns:
    print("Warning: time_stamp missing; hour/day will not be created.")
else:
    # Assumes time_stamp holds epoch seconds (no timezone shift is applied):
    # hour = whole hours modulo 24, day = whole 86400-second days.
    epoch_seconds = F.col("time_stamp")
    full_df = (
        full_df.withColumn("hour", (epoch_seconds / 3600).cast("long") % 24)
               .withColumn("day", (epoch_seconds / 86400).cast("long"))
    )

# 10) Join enrichment checks
# Each entry is (display name, column): the share of non-null values in that
# column is reported as the enrichment rate of the corresponding join.
join_checks = []
if "age_level" in full_df.columns:
    join_checks.append(("User join (age_level not null)", "age_level"))
if "price" in full_df.columns:
    join_checks.append(("Ad join (price not null)", "price"))
if "cate_id" in full_df.columns:
    join_checks.append(("Ad join (cate_id not null)", "cate_id"))

# Rows used for the enrichment check.
QC_SAMPLE_SIZE = 20_000

# Lazy definitions kept for compatibility; nothing below triggers them.
qc_sample = full_df.orderBy(F.rand(seed=SEED)).limit(QC_SAMPLE_SIZE)
sample_sdf = full_df.orderBy(F.rand(seed=SEED)).limit(SAMPLE_SIZE_EDA)

# Both samples are prefixes of the same seeded shuffle, so the fused pipeline
# is run and collected once (for the larger size) and sliced in pandas,
# instead of executing two identical Spark jobs.
shuffled_pdf = (
    full_df.orderBy(F.rand(seed=SEED))
           .limit(max(QC_SAMPLE_SIZE, SAMPLE_SIZE_EDA))
           .toPandas()
)

print("\nEnrichment checks (using sample to avoid heavy counts):")
enrichment = []
qc_pdf = shuffled_pdf.head(QC_SAMPLE_SIZE)

for label_name, col_name in join_checks:
    # Fraction of sampled rows in which the joined attribute is present.
    rate = float((~qc_pdf[col_name].isna()).mean()) if col_name in qc_pdf.columns else 0.0
    enrichment.append((label_name, rate))
    print("-", label_name, ":", round(rate * 100, 2), "%")

# 11) Sample for EDA plots
sample_pdf = shuffled_pdf.head(SAMPLE_SIZE_EDA).copy()
print("\nEDA sample shape:", sample_pdf.shape)

def save_fig(name):
    """Write the current matplotlib figure to ``fig_dir`` under ``name`` and close it.

    The layout is tightened before saving, the file is written at ``FIG_DPI``
    and the destination path is printed.
    """
    target = os.path.join(fig_dir, name)
    plt.tight_layout()
    plt.savefig(target, dpi=FIG_DPI)
    plt.close()
    print("Saved figure:", target)

def safe_has_cols(pdf, cols):
    """Return True when every name in ``cols`` is a column of ``pdf``."""
    for wanted in cols:
        if wanted not in pdf.columns:
            return False
    return True

# 1) Class balance
if "label" in sample_pdf.columns:
    counts = sample_pdf["label"].value_counts().sort_index()
    _, ax = plt.subplots(figsize=(6, 4))
    ax.bar(counts.index.astype(str), counts.values)
    ax.set_title("Target distribution: Click vs No-Click")
    ax.set_xlabel("label (0=No Click, 1=Click)")
    ax.set_ylabel("count")
    save_fig("01_label_distribution.png")

# 2) CTR by day
if safe_has_cols(sample_pdf, ["day", "label"]):
    tmp = sample_pdf.dropna(subset=["day", "label"])
    # Mean of the 0/1 label per day is the daily click-through rate.
    ctr_by_day = tmp.groupby("day", as_index=False)["label"].mean().sort_values("day")
    _, ax = plt.subplots(figsize=(10, 4))
    ax.plot(ctr_by_day["day"], ctr_by_day["label"])
    ax.set_title("CTR over time (daily)")
    ax.set_xlabel("day")
    ax.set_ylabel("CTR")
    save_fig("02_ctr_over_time_day.png")

# 3) CTR by hour
if safe_has_cols(sample_pdf, ["hour", "label"]):
    tmp = sample_pdf.dropna(subset=["hour", "label"])
    ctr_by_hour = tmp.groupby("hour", as_index=False)["label"].mean().sort_values("hour")
    _, ax = plt.subplots(figsize=(10, 4))
    ax.bar(ctr_by_hour["hour"].astype(int), ctr_by_hour["label"])
    ax.set_title("CTR by hour")
    ax.set_xlabel("hour")
    ax.set_ylabel("CTR")
    save_fig("03_ctr_by_hour.png")

# 4) Join enrichment bar
if enrichment:
    # enrichment holds (display name, rate) pairs; split them into two lists.
    names = [entry[0] for entry in enrichment]
    vals = [entry[1] for entry in enrichment]
    _, ax = plt.subplots(figsize=(10, 4))
    ax.barh(names, vals)
    ax.set_title("Join enrichment rates (sample-based)")
    ax.set_xlabel("Enrichment Rate")
    ax.set_xlim(0, 1)
    save_fig("04_join_enrichment_rates.png")

# 5) CTR by cate_id/brand topN
def plot_ctr_topN(pdf, col_cat, top_n=TOP_N, fname="ctr_top.png", title=""):
    """Plot the click-through rate of the ``top_n`` most frequent values of ``col_cat``.

    Args:
        pdf: pandas DataFrame holding ``col_cat`` and a ``label`` column.
        col_cat: name of the categorical column to break the CTR down by.
        top_n: how many of the most frequent categories to show.
        fname: file name handed to ``save_fig``.
        title: plot title.

    Returns:
        None. Nothing is plotted when ``col_cat`` or ``label`` is missing.
    """
    if not safe_has_cols(pdf, [col_cat, "label"]):
        return
    tmp = pdf[[col_cat, "label"]].dropna()

    # Restrict to the most frequent categories so the chart stays readable.
    vc = tmp[col_cat].value_counts().head(top_n)
    top_cats = set(vc.index.tolist())
    tmp2 = tmp[tmp[col_cat].isin(top_cats)]

    # CTR (mean label) and support (row count) per category in one pass,
    # ordered by support.
    out = (
        tmp2.groupby(col_cat)["label"]
            .agg(["mean", "size"])
            .rename(columns={"mean": "label", "size": "count"})
            .reset_index()
            .sort_values("count", ascending=False)
    )

    # seaborn is optional and this file never defines HAS_SNS or sns itself.
    # Look both up defensively so a missing definition falls back to plain
    # matplotlib instead of raising NameError.
    use_sns = bool(globals().get("HAS_SNS", False)) and "sns" in globals()

    plt.figure(figsize=(12, 5))
    if use_sns:
        sns.barplot(data=out, x=col_cat, y="label")
    else:
        plt.bar(out[col_cat].astype(str), out["label"].values)

    plt.title(title)
    plt.xlabel(col_cat)
    plt.ylabel("CTR")
    plt.xticks(rotation=45, ha="right")
    save_fig(fname)

if "cate_id" in sample_pdf.columns:
    plot_ctr_topN(sample_pdf, "cate_id", fname="05_ctr_by_cate_topN.png", title=f"CTR by cate_id (Top {TOP_N})")
if "brand" in sample_pdf.columns:
    plot_ctr_topN(sample_pdf, "brand", fname="06_ctr_by_brand_topN.png", title=f"CTR by brand (Top {TOP_N})")

# 6) price vs label
if safe_has_cols(sample_pdf, ["price", "label"]):
    tmp = sample_pdf[["price", "label"]].dropna()
    plt.figure(figsize=(6, 5))
    # seaborn is optional and this file never defines HAS_SNS or sns itself.
    # Look both up defensively so a missing definition falls back to plain
    # matplotlib instead of raising NameError.
    if globals().get("HAS_SNS", False) and "sns" in globals():
        sns.boxplot(x="label", y="price", data=tmp)
    else:
        # One box per class: prices of non-clicked (0) and clicked (1) rows.
        g0 = tmp[tmp["label"] == 0]["price"].values
        g1 = tmp[tmp["label"] == 1]["price"].values
        plt.boxplot([g0, g1])
        # Boxes sit at x = 1 and 2 by default. Labelling them through xticks
        # avoids boxplot's `labels` keyword, which is deprecated in newer
        # Matplotlib releases in favour of `tick_labels`.
        plt.xticks([1, 2], ["0", "1"])
        plt.xlabel("label")
    plt.title("Price distribution by click")
    plt.ylabel("price")
    save_fig("07_price_vs_label_boxplot.png")

# 12) Save fused data + Temporal split
fused_path = os.path.join(processed_dir, "fused_data.csv")

# The quantile and every CSV export below is a separate Spark action. Without
# caching, each one re-executes the whole lazy pipeline (CSV reads, sort and
# limit, joins), which is slow and lets train/test/fused be computed from
# different evaluations of it. Cache so they all read the same rows.
full_df = full_df.cache()

# 80th percentile of the click time is the temporal cutoff. approxQuantile
# yields an empty list when there are no non-null values to rank, in which
# case no split is possible.
quantiles = (
    full_df.approxQuantile("time_stamp", [0.8], 0.001)
    if "time_stamp" in full_df.columns
    else []
)

if quantiles:
    cutoff = quantiles[0]
    # Rows at or before the cutoff train the model; later rows are held out.
    # Rows with a null time_stamp satisfy neither filter and land in neither.
    train_df = full_df.filter(F.col("time_stamp") <= cutoff)
    test_df  = full_df.filter(F.col("time_stamp") > cutoff)

    train_path = os.path.join(processed_dir, "fused_train.csv")
    test_path  = os.path.join(processed_dir, "fused_test.csv")

    # Save train/test
    train_df.toPandas().to_csv(train_path, index=False)
    test_df.toPandas().to_csv(test_path, index=False)

    # Also save full
    full_df.toPandas().to_csv(fused_path, index=False)

    print("\nSaved:")
    print(" -", train_path)
    print(" -", test_path)
    print(" -", fused_path)
else:
    if "time_stamp" in full_df.columns:
        print("Warning: no usable time_stamp values; temporal split skipped.")
    full_df.toPandas().to_csv(fused_path, index=False)
    print("\nSaved:", fused_path)

spark.stop()
print("\nSpark stopped.")


Spark version: 4.0.1

click_df capped (no count). Showing 3 rows:
+----+----------+------+----------+-----+---+-------------------+
|user|adgroup_id|rating|time_stamp|label|clk|time_stamp_str     |
+----+----------+------+----------+-----+---+-------------------+
|307 |7701      |0.5   |1186173318|0    |0  |2007-08-03 23:35:18|
|307 |3063      |1.0   |1186173320|0    |0  |2007-08-03 23:35:20|
|307 |2169      |1.5   |1186173322|0    |0  |2007-08-03 23:35:22|
+----+----------+------+----------+-----+---+-------------------+
only showing top 3 rows

After Early Fusion (no count). Columns: 13

Building time-aware behavior features...
Time-aware behavior added. Columns: 18

Enrichment checks (using sample to avoid heavy counts):

EDA sample shape: (20000, 20)
Saved figure: d:\projects\Ai\project_fusion_ecu\data\processed\figures\01_label_distribution.png
Saved figure: d:\projects\Ai\project_fusion_ecu\data\processed\figures\02_ctr_over_time_day.png
Saved figure: d:\projects\Ai\project_fusio