# PySpark Data Processing Pipeline

**Purpose**: Prepare training data

**Author**: Bai, Gengyuan

**Tasks**:
1. Process raw data (events.csv, item_properties) using PySpark
2. Build candidate sets (prefix, item covisitation, category covisitation, popularity, user history)
3. Generate all features required for training
4. Save as parquet files for subsequent training

**Data Windows**:
- Training: 2015-05-01 to 2015-07-01 (2 months)
- Validation: 2015-07-01 to 2015-08-01 (1 month)
- Session gap: 30 minutes


## Import Libraries and Configuration


In [64]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
import os
import warnings
import logging
from pathlib import Path
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F
from pyspark.sql.functions import col, lag, unix_timestamp, when, concat, lit
import numpy as np
import pandas as pd
from gensim.models import Word2Vec

# Suppress all Python warnings
warnings.filterwarnings('ignore')
os.environ['PYTHONWARNINGS'] = 'ignore'

# Set environment variables BEFORE creating Spark session
os.environ['SPARK_LOCAL_DIRS'] = '/tmp/spark-temp'
os.environ['SPARK_LOCAL_IP'] = '127.0.0.1'

# Suppress Spark and Hadoop warnings
os.environ['PYARROW_IGNORE_TIMEZONE'] = '1'

# Configure Python logging to suppress all warnings
logging.getLogger('py4j').setLevel(logging.CRITICAL)
logging.getLogger('pyspark').setLevel(logging.CRITICAL)

# Redirect Java stderr to suppress log4j errors
import subprocess
import tempfile


In [65]:
# Configuration
DATA_DIR = Path("data/raw")
OUTPUT_DIR = Path("data/processed")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

TRAIN_START = "2015-05-01"
TRAIN_END = "2015-07-01"
VALID_START = "2015-07-01"
VALID_END = "2015-08-01"
SESSION_GAP_MINUTES = 30

# Start timing
import time
from datetime import datetime

pipeline_start_time = time.time()
pipeline_start_datetime = datetime.now()

print("=" * 80)
print("PySpark Data Processing Pipeline Start")
print("=" * 80)
print(f"Start Time: {pipeline_start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Training window: [{TRAIN_START}, {TRAIN_END})")
print(f"Validation window: [{VALID_START}, {VALID_END})")
print(f"Session gap: {SESSION_GAP_MINUTES} minutes")
print()


PySpark Data Processing Pipeline Start
Start Time: 2025-12-08 13:52:45
Training window: [2015-05-01, 2015-07-01)
Validation window: [2015-07-01, 2015-08-01)
Session gap: 30 minutes



## STEP 1: Initialize Spark Session


In [66]:
import sys
from io import StringIO

# Create a minimal log4j configuration to suppress all logging
log4j_config_content = """
log4j.rootCategory=OFF
log4j.logger.org.apache.spark=OFF
log4j.logger.org.spark_project=OFF
log4j.logger.org.apache.hadoop=OFF
log4j.logger.akka=OFF
log4j.logger.org.eclipse.jetty=OFF
"""

log4j_config_path = "/tmp/spark-log4j.properties"
with open(log4j_config_path, "w") as f:
    f.write(log4j_config_content)

# Also create log4j2 configuration
log4j2_config_content = """
status = OFF
name = SparkConfig
appender.console.type = Console
appender.console.name = STDOUT
appender.console.layout.type = PatternLayout
rootLogger.level = OFF
"""

log4j2_config_path = "/tmp/spark-log4j2.properties"
with open(log4j2_config_path, "w") as f:
    f.write(log4j2_config_content)

# Capture stderr to suppress Spark initialization warnings
old_stderr = sys.stderr
sys.stderr = StringIO()

print("STEP 1: Initializing Spark session...")

# Create Spark session with comprehensive logging suppression
spark = SparkSession.builder \
    .appName("Ecommerce_Training_Data_Preparation") \
    .config("spark.driver.memory", "4g") \
    .config("spark.sql.shuffle.partitions", "200") \
    .config("spark.local.dir", "/tmp/spark-temp") \
    .config("spark.sql.debug.maxToStringFields", "1000") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.ui.showConsoleProgress", "false") \
    .config("spark.driver.extraJavaOptions", 
            f"-Dlog4j.configuration=file:{log4j_config_path} "
            f"-Dlog4j2.configurationFile=file:{log4j2_config_path} "
            "-Dlog4j.logLevel=OFF "
            "-Dlog4j2.level=OFF") \
    .config("spark.executor.extraJavaOptions", 
            f"-Dlog4j.configuration=file:{log4j_config_path} "
            "-Dlog4j.logLevel=OFF") \
    .master("local[*]") \
    .getOrCreate()

# Immediately set log level to OFF (most restrictive)
spark.sparkContext.setLogLevel("OFF")

# Suppress all Java/Scala logging through log4j
try:
    log4j = spark._jvm.org.apache.log4j
    log4j.LogManager.getRootLogger().setLevel(log4j.Level.OFF)
    log4j.Logger.getLogger("org").setLevel(log4j.Level.OFF)
    log4j.Logger.getLogger("akka").setLevel(log4j.Level.OFF)
except Exception:
    pass

# Restore stderr
sys.stderr = old_stderr

print("✓ Spark session created successfully")
print()


STEP 1: Initializing Spark session...
✓ Spark session created successfully



## STEP 2: Load and Sessionize Events


In [67]:
print("STEP 2: Loading events.csv and performing sessionization...")

events_path = f"file://{(DATA_DIR / 'events.csv').absolute()}"
events_df = spark.read.csv(events_path, header=True, inferSchema=True)

# Convert timestamp (from milliseconds to seconds)
events_df = events_df.withColumn(
    "ts", 
    F.from_unixtime(col("timestamp") / 1000).cast("timestamp")
)

# Filter time range (includes both training and validation windows)
events_df = events_df.filter(
    (col("ts") >= F.lit(TRAIN_START).cast("timestamp")) &
    (col("ts") < F.lit(VALID_END).cast("timestamp"))
)

# Rename columns
events_df = events_df.select(
    col("visitorid").cast("bigint").alias("user_id"),
    col("ts"),
    col("itemid").cast("bigint").alias("item_id"),
    col("event")
)

print(f"  Events loaded: {events_df.count():,}")


STEP 2: Loading events.csv and performing sessionization...
  Events loaded: 1,902,445


In [68]:
# Sessionization: 30-minute gap rule
window_user_time = Window.partitionBy("user_id").orderBy("ts")

events_df = events_df.withColumn(
    "prev_ts",
    lag("ts").over(window_user_time)
)

events_df = events_df.withColumn(
    "time_gap_sec",
    when(
        col("prev_ts").isNotNull(),
        unix_timestamp("ts") - unix_timestamp("prev_ts")
    ).otherwise(0)
)

events_df = events_df.withColumn(
    "is_new_session",
    when(
        (col("time_gap_sec") > SESSION_GAP_MINUTES * 60) | col("prev_ts").isNull(),
        1
    ).otherwise(0)
)

window_session = Window.partitionBy("user_id").orderBy("ts")
events_df = events_df.withColumn(
    "session_num",
    F.sum("is_new_session").over(window_session)
)

events_df = events_df.withColumn(
    "session_id",
    concat(col("user_id").cast("string"), lit("_"), col("session_num").cast("string"))
)

# Clean intermediate columns
events_df = events_df.select("session_id", "user_id", "ts", "item_id", "event")

# Cache to speed up subsequent operations
events_df.cache()

session_count = events_df.select("session_id").distinct().count()
print(f"  Sessions generated: {session_count:,}")
print("✓ Sessionization completed")
print()


  Sessions generated: 1,194,255
✓ Sessionization completed



## STEP 3: Load Item Category Information


In [69]:
print("STEP 3: Loading item properties and extracting category...")

part1_path = f"file://{(DATA_DIR / 'item_properties_part1.csv').absolute()}"
part2_path = f"file://{(DATA_DIR / 'item_properties_part2.csv').absolute()}"

props1 = spark.read.csv(part1_path, header=True, inferSchema=True)
props2 = spark.read.csv(part2_path, header=True, inferSchema=True)

# Merge two parts
item_props = props1.union(props2)

# Convert timestamp
item_props = item_props.withColumn(
    "ts",
    F.from_unixtime(col("timestamp") / 1000).cast("timestamp")
)

# Keep only categoryid property, and take the latest value
item_props = item_props.filter(col("property") == "categoryid")

item_props = item_props.select(
    col("itemid").cast("bigint").alias("item_id"),
    col("value").cast("bigint").alias("category_id"),
    col("ts")
)

# For each item, take the latest category_id
window_item = Window.partitionBy("item_id").orderBy(F.desc("ts"))
item_props = item_props.withColumn("rn", F.row_number().over(window_item))
item_category = item_props.filter(col("rn") == 1).select("item_id", "category_id")

item_category.cache()
print(f"  Items with category: {item_category.count():,}")
print("✓ Category information loaded")
print()


STEP 3: Loading item properties and extracting category...
  Items with category: 417,053
✓ Category information loaded



## STEP 4: Extract Add-to-Cart Events


In [70]:
print("STEP 4: Extracting add-to-cart events...")

# Join events with category
atc_events = events_df.filter(col("event") == "addtocart") \
    .join(item_category, "item_id", "inner") \
    .select(
        col("session_id"),
        col("user_id"),
        col("ts").alias("atc_ts"),
        col("item_id"),
        col("category_id")
    )

atc_events.cache()

# Split into training and validation sets
atc_train = atc_events.filter(
    (col("atc_ts") >= F.lit(TRAIN_START).cast("timestamp")) &
    (col("atc_ts") < F.lit(TRAIN_END).cast("timestamp"))
)

atc_valid = atc_events.filter(
    (col("atc_ts") >= F.lit(VALID_START).cast("timestamp")) &
    (col("atc_ts") < F.lit(VALID_END).cast("timestamp"))
)

atc_train.cache()
atc_valid.cache()

n_atc_train = atc_train.count()
n_atc_valid = atc_valid.count()

print(f"  Training ATC events: {n_atc_train:,}")
print(f"  Validation ATC events: {n_atc_valid:,}")
print("✓ ATC event extraction completed")
print()


STEP 4: Extracting add-to-cart events...
  Training ATC events: 29,244
  Validation ATC events: 17,151
✓ ATC event extraction completed



## STEP 5: Build Candidate Sets


In [71]:
print("STEP 5: Building candidate sets...")

def build_candidates_spark(atc_df, split_name, train_cutoff_str):
    """
    Build candidate category sets for given ATC events
    Includes: prefix, item covisitation, category covisitation, popularity, user history
    """
    print(f"  Building {split_name} candidates...")
    
    train_cutoff = F.lit(train_cutoff_str).cast("timestamp")
    
    # 1. Prefix candidates: all categories in session prefix
    prefix_cands = atc_df.alias("a") \
        .join(
            events_df.alias("se"),
            (col("a.session_id") == col("se.session_id")) & (col("se.ts") < col("a.atc_ts")),
            "inner"
        ) \
        .join(item_category.alias("ic"), col("se.item_id") == col("ic.item_id"), "inner") \
        .select(
            col("a.session_id"),
            col("a.atc_ts"),
            col("ic.category_id").alias("category_id")
        ).distinct()
    
    # 2. Item Co-visitation candidates
    # Calculate item-item co-occurrence (before training cutoff)
    train_events = events_df.filter(col("ts") < train_cutoff)
    
    item_covis = train_events.alias("a") \
        .join(
            train_events.alias("b"),
            (col("a.session_id") == col("b.session_id")) & (col("a.item_id") < col("b.item_id")),
            "inner"
        ) \
        .groupBy(col("a.item_id").alias("item_a"), col("b.item_id").alias("item_b")) \
        .agg(F.count("*").alias("covis")) \
        .filter(col("covis") >= 3)
    
    # For each ATC's prefix items, find co-occurring items and convert to categories
    prefix_items = atc_df.alias("a") \
        .join(
            events_df.alias("se"),
            (col("a.session_id") == col("se.session_id")) & (col("se.ts") < col("a.atc_ts")),
            "inner"
        ) \
        .select(
            col("a.session_id"),
            col("a.atc_ts"),
            col("se.item_id")
        )
    
    itemcovis_cands = prefix_items.alias("pi") \
        .join(item_covis.alias("iv"), col("pi.item_id") == col("iv.item_a"), "inner") \
        .join(item_category.alias("ic2"), col("iv.item_b") == col("ic2.item_id"), "inner") \
        .groupBy(col("pi.session_id"), col("pi.atc_ts"), col("ic2.category_id")) \
        .agg(F.max("iv.covis").alias("max_covis")) \
        .withColumn(
            "rn",
            F.row_number().over(
                Window.partitionBy("session_id", "atc_ts").orderBy(F.desc("max_covis"))
            )
        ) \
        .filter(col("rn") <= 15) \
        .select(col("session_id"), col("atc_ts"), col("ic2.category_id").alias("category_id"))
    
    # 3. Category Co-visitation candidates
    # Calculate category-category co-occurrence
    train_events_with_cat = train_events.alias("se") \
        .join(item_category.alias("ic"), col("se.item_id") == col("ic.item_id"), "inner") \
        .select(col("se.session_id"), col("ic.category_id"))
    
    cat_covis = train_events_with_cat.alias("a") \
        .join(
            train_events_with_cat.alias("b"),
            (col("a.session_id") == col("b.session_id")) & (col("a.category_id") < col("b.category_id")),
            "inner"
        ) \
        .groupBy(col("a.category_id").alias("cat_a"), col("b.category_id").alias("cat_b")) \
        .agg(F.countDistinct("a.session_id").alias("cooccur")) \
        .filter(col("cooccur") >= 5)
    
    prefix_cats = atc_df.alias("a") \
        .join(
            events_df.alias("se"),
            (col("a.session_id") == col("se.session_id")) & (col("se.ts") < col("a.atc_ts")),
            "inner"
        ) \
        .join(item_category.alias("ic"), col("se.item_id") == col("ic.item_id"), "inner") \
        .select(
            col("a.session_id"),
            col("a.atc_ts"),
            col("ic.category_id")
        )
    
    catcovis_cands = prefix_cats.alias("pc") \
        .join(cat_covis.alias("cc"), col("pc.category_id") == col("cc.cat_a"), "inner") \
        .groupBy(col("pc.session_id"), col("pc.atc_ts"), col("cc.cat_b")) \
        .agg(F.max("cc.cooccur").alias("max_cooccur")) \
        .withColumn(
            "rn",
            F.row_number().over(
                Window.partitionBy("session_id", "atc_ts").orderBy(F.desc("max_cooccur"))
            )
        ) \
        .filter(col("rn") <= 10) \
        .select(col("session_id"), col("atc_ts"), col("cat_b").alias("category_id"))
    
    # 4. Popularity candidates: top 20 globally most popular categories
    cat_pop = train_events.alias("se") \
        .join(item_category.alias("ic"), col("se.item_id") == col("ic.item_id"), "inner") \
        .groupBy("ic.category_id") \
        .agg(F.count("*").alias("cnt")) \
        .orderBy(F.desc("cnt")) \
        .limit(20)
    
    pop_cands = atc_df.alias("a").crossJoin(cat_pop.select(col("category_id").alias("pop_cat_id"))) \
        .select(col("a.session_id"), col("a.atc_ts"), col("pop_cat_id").alias("category_id"))
    
    # 5. User History candidates: user's historically viewed categories (recent 10)
    user_past_cats = train_events.alias("se") \
        .join(item_category.alias("ic"), col("se.item_id") == col("ic.item_id"), "inner") \
        .filter(col("se.ts") < train_cutoff) \
        .groupBy(col("se.user_id"), col("ic.category_id")) \
        .agg(F.max("se.ts").alias("last_seen"))
    
    userhist_cands = atc_df.alias("a") \
        .join(
            user_past_cats.alias("upc"),
            (col("a.user_id") == col("upc.user_id")) & (col("upc.last_seen") < col("a.atc_ts")),
            "inner"
        ) \
        .withColumn(
            "rn",
            F.row_number().over(
                Window.partitionBy("a.session_id", "a.atc_ts").orderBy(F.desc("upc.last_seen"))
            )
        ) \
        .filter(col("rn") <= 10) \
        .select(col("a.session_id"), col("a.atc_ts"), col("upc.category_id").alias("category_id"))
    
    # Merge all candidates
    all_candidates = prefix_cands \
        .union(itemcovis_cands) \
        .union(catcovis_cands) \
        .union(pop_cands) \
        .union(userhist_cands) \
        .distinct()
    
    n_cands = all_candidates.count()
    print(f"    {split_name}: {n_cands:,} candidates")
    
    return all_candidates


STEP 5: Building candidate sets...


In [72]:
# Build training and validation candidate sets
train_candidates = build_candidates_spark(atc_train, "train", TRAIN_END)
valid_candidates = build_candidates_spark(atc_valid, "valid", TRAIN_END)

train_candidates.cache()
valid_candidates.cache()

print("✓ Candidate set building completed")
print()


  Building train candidates...
    train: 948,570 candidates
  Building valid candidates...
    valid: 538,331 candidates
✓ Candidate set building completed



In [73]:
print("STEP 5.5: Training Category Embeddings (Word2Vec)...")

# Extract category sequences for each session (for Word2Vec training)
cat_seqs_spark = events_df.filter(col("ts") < F.lit(TRAIN_END).cast("timestamp")) \
    .join(item_category, "item_id", "inner") \
    .select("session_id", "ts", "category_id") \
    .orderBy("session_id", "ts")

cat_seqs_spark = cat_seqs_spark.groupBy("session_id").agg(
    F.collect_list("category_id").alias("cat_sequence")
)

# Convert to Pandas to use gensim
cat_seqs_pd = cat_seqs_spark.toPandas()

# Prepare training data (convert to string lists)
sequences = [[str(cat) for cat in seq if cat is not None] for seq in cat_seqs_pd['cat_sequence']]
sequences = [seq for seq in sequences if len(seq) >= 2]  # Filter short sequences

print(f"  Extracted {len(sequences):,} category sequences for training")

# Train Word2Vec model
w2v_model = Word2Vec(
    sentences=sequences,
    vector_size=16,
    window=5,
    min_count=3,
    workers=4,
    sg=1,
    epochs=10,
    seed=42
)

print(f"  Trained embeddings for {len(w2v_model.wv)} categories")

# Create embedding lookup dictionary
cat_embeddings = {int(cat): w2v_model.wv[cat] for cat in w2v_model.wv.index_to_key}

# Show similarity check
sample_cat = list(cat_embeddings.keys())[0]
similar = w2v_model.wv.most_similar(str(sample_cat), topn=5)
print(f"  Example: Category {sample_cat} most similar categories: {[(int(c), round(s, 3)) for c, s in similar]}")

print("✓ Word2Vec training completed")
print()


STEP 5.5: Training Category Embeddings (Word2Vec)...


  Extracted 160,240 category sequences for training


Exception ignored in: 'gensim.models.word2vec_inner.our_dot_float'


  Trained embeddings for 921 categories
  Example: Category 1051 most similar categories: [(1192, 0.863), (218, 0.851), (626, 0.839), (1213, 0.821), (568, 0.804)]
✓ Word2Vec training completed



## STEP 6: Feature Engineering


In [74]:
print("STEP 6: Feature engineering...")

def build_features_spark(atc_df, candidates_df, split_name, train_cutoff_str):
    """
    Build all features required for training for the candidate set
    """
    print(f"  Building {split_name} features...")
    
    train_cutoff = F.lit(train_cutoff_str).cast("timestamp")
    train_events = events_df.filter(col("ts") < train_cutoff)
    
    # Join ATC with candidates
    base = atc_df.alias("a") \
        .join(
            candidates_df.alias("c"),
            (col("a.session_id") == col("c.session_id")) & (col("a.atc_ts") == col("c.atc_ts")),
            "inner"
        ) \
        .select(
            col("a.session_id"),
            col("a.user_id"),
            col("a.atc_ts"),
            col("a.category_id").alias("true_category_id"),
            col("c.category_id").alias("cand_category_id")
        )
    
    # 1. Prefix statistics features
    prefix_events = base.alias("b") \
        .join(
            events_df.alias("se"),
            (col("b.session_id") == col("se.session_id")) & (col("se.ts") < col("b.atc_ts")),
            "left"
        ) \
        .join(item_category.alias("ic"), col("se.item_id") == col("ic.item_id"), "left")
    
    prefix_stats = prefix_events.groupBy(
        col("b.session_id"), col("b.atc_ts"), col("b.cand_category_id")
    ).agg(
        F.countDistinct(col("se.item_id")).alias("n_prefix_items"),
        F.count(col("se.item_id")).alias("n_prefix_events"),
        F.sum(when(col("ic.category_id") == col("b.cand_category_id"), 1).otherwise(0)).alias("cat_count_in_prefix"),
        F.max(
            when(col("ic.category_id") == col("b.cand_category_id"), 
                 unix_timestamp(col("b.atc_ts")) - unix_timestamp(col("se.ts")))
        ).alias("recency_sec"),
        F.min(col("se.ts")).alias("session_start"),
        F.countDistinct(col("ic.category_id")).alias("n_unique_cats_in_session")
    ).select(
        col("session_id"),
        col("atc_ts"),
        col("cand_category_id"),
        col("n_prefix_items"),
        col("n_prefix_events"),
        col("cat_count_in_prefix"),
        col("recency_sec"),
        col("session_start"),
        col("n_unique_cats_in_session")
    )
    
    # 2. Category global popularity
    cat_pop = train_events.alias("se") \
        .join(item_category.alias("ic"), col("se.item_id") == col("ic.item_id"), "inner") \
        .groupBy(col("ic.category_id")) \
        .agg(F.count("*").alias("global_pop")) \
        .select(
            col("category_id"),
            col("global_pop")
        )
    
    # 3. User-Category affinity
    user_cat_aff = train_events.alias("se") \
        .join(item_category.alias("ic"), col("se.item_id") == col("ic.item_id"), "inner") \
        .groupBy(col("se.user_id"), col("ic.category_id")) \
        .agg(
            F.count("*").alias("user_cat_interactions"),
            F.countDistinct(col("se.session_id")).alias("user_cat_sessions")
        ) \
        .select(
            col("user_id"),
            col("category_id"),
            col("user_cat_interactions"),
            col("user_cat_sessions")
        )
    
    # 4. User statistics
    user_stats = train_events.groupBy("user_id", "session_id").agg(
        (F.max("ts").cast("long") - F.min("ts").cast("long")).alias("session_duration")
    ).groupBy("user_id").agg(
        F.countDistinct("session_id").alias("total_sessions"),
        F.avg("session_duration").alias("avg_session_duration")
    )
    
    # Join all features
    features = base.alias("base") \
        .join(
            prefix_stats.alias("ps"),
            (col("base.session_id") == col("ps.session_id")) &
            (col("base.atc_ts") == col("ps.atc_ts")) &
            (col("base.cand_category_id") == col("ps.cand_category_id")),
            "left"
        ) \
        .join(
            cat_pop.alias("cp"),
            col("base.cand_category_id") == col("cp.category_id"),
            "left"
        ) \
        .join(
            user_cat_aff.alias("uca"),
            (col("base.user_id") == col("uca.user_id")) &
            (col("base.cand_category_id") == col("uca.category_id")),
            "left"
        ) \
        .join(
            user_stats.alias("us"),
            col("base.user_id") == col("us.user_id"),
            "left"
        )
    
    # Calculate derived features
    features = features.select(
        col("base.session_id"),
        col("base.atc_ts"),
        col("base.cand_category_id").alias("category_id"),
        
        # Prefix features
        F.coalesce(col("ps.n_prefix_items"), F.lit(0)).alias("n_prefix_items"),
        F.coalesce(col("ps.n_prefix_events"), F.lit(0)).alias("n_prefix_events"),
        F.coalesce(col("ps.cat_count_in_prefix"), F.lit(0)).alias("cat_count_in_prefix"),
        (F.coalesce(col("ps.cat_count_in_prefix"), F.lit(0)) / 
         F.greatest(F.coalesce(col("ps.n_prefix_events"), F.lit(1)), F.lit(1))).alias("cat_share_in_prefix"),
        F.coalesce(col("ps.recency_sec"), F.lit(999999)).alias("recency_sec"),
        F.log1p(F.coalesce(col("ps.recency_sec"), F.lit(999999))).alias("log_recency"),
        
        # Time features
        F.hour(col("base.atc_ts")).alias("hour_of_day"),
        F.dayofweek(col("base.atc_ts")).alias("day_of_week"),
        when(F.dayofweek(col("base.atc_ts")).isin([1, 7]), 1).otherwise(0).alias("is_weekend"),
        F.coalesce(unix_timestamp(col("base.atc_ts")) - unix_timestamp(col("ps.session_start")), F.lit(0)).alias("time_since_session_start"),
        F.coalesce(col("ps.n_unique_cats_in_session"), F.lit(0)).alias("session_cat_diversity"),
        
        # Category popularity
        F.coalesce(col("cp.global_pop"), F.lit(1)).alias("cat_popularity"),
        F.log1p(F.coalesce(col("cp.global_pop"), F.lit(1))).alias("log_cat_pop"),
        
        # User-Category affinity
        F.coalesce(col("uca.user_cat_interactions"), F.lit(0)).alias("user_cat_hist"),
        F.log1p(F.coalesce(col("uca.user_cat_interactions"), F.lit(0))).alias("log_user_cat_hist"),
        F.coalesce(col("uca.user_cat_sessions"), F.lit(0)).alias("user_cat_sessions"),
        
        # User statistics
        F.coalesce(col("us.total_sessions"), F.lit(0)).alias("user_total_sessions"),
        F.coalesce(col("us.avg_session_duration"), F.lit(0)).alias("user_avg_session_dur"),
        
        # Label
        when(col("base.true_category_id") == col("base.cand_category_id"), 1).otherwise(0).alias("y")
    )
    
    # Count rows first (before adding embeddings)
    n_rows = features.count()
    
    print(f"    {split_name}: {n_rows:,} rows of base features")
    print(f"    Adding 16-dimensional category embeddings...")
    
    # Broadcast embedding dictionary to improve performance
    emb_broadcast = spark.sparkContext.broadcast(cat_embeddings)
    
    # Define UDF to get a specific dimension of the embedding
    def get_embedding_dim(cat_id, dim_idx):
        emb_dict = emb_broadcast.value
        if cat_id in emb_dict:
            return float(emb_dict[cat_id][dim_idx])
        else:
            return 0.0
    
    # Register UDF
    from pyspark.sql.types import FloatType
    get_emb_udf = F.udf(get_embedding_dim, FloatType())
    
    # Add embedding dimensions one by one
    for dim in range(16):
        features = features.withColumn(
            f'cat_emb_{dim}',
            get_emb_udf(col("category_id"), F.lit(dim))
        )
    
    print(f"    {split_name}: {n_rows:,} rows x {len(features.columns)} columns (with embeddings)")
    
    return features


STEP 6: Feature engineering...


In [75]:
# Build training and validation features
X_train_spark = build_features_spark(atc_train, train_candidates, "train", TRAIN_END)
X_valid_spark = build_features_spark(atc_valid, valid_candidates, "valid", TRAIN_END)

print("✓ Feature engineering completed")
print()


  Building train features...
    train: 970,995 rows of base features
    Adding 16-dimensional category embeddings...
    train: 970,995 rows x 38 columns (with embeddings)
  Building valid features...
    valid: 551,775 rows of base features
    Adding 16-dimensional category embeddings...
    valid: 551,775 rows x 38 columns (with embeddings)
✓ Feature engineering completed



In [76]:
print("STEP 7: Saving training data...")

# Optimize Spark configurations for saving large DataFrames
spark.conf.set("spark.sql.debug.maxToStringFields", "1000")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128MB")

train_output_path = f"file://{(OUTPUT_DIR / 'X_train_spark.parquet').absolute()}"
valid_output_path = f"file://{(OUTPUT_DIR / 'X_valid_spark.parquet').absolute()}"

# Save with optimized settings
X_train_spark.write.mode("overwrite").option("maxRecordsPerFile", 50000).parquet(train_output_path)
X_valid_spark.write.mode("overwrite").option("maxRecordsPerFile", 50000).parquet(valid_output_path)

print(f"  Training set saved to: {OUTPUT_DIR / 'X_train_spark.parquet'}")
print(f"  Validation set saved to: {OUTPUT_DIR / 'X_valid_spark.parquet'}")
print("✓ Data saving completed")
print()


STEP 7: Saving training data...
  Training set saved to: data/processed/X_train_spark.parquet
  Validation set saved to: data/processed/X_valid_spark.parquet
✓ Data saving completed



## STEP 8: Verify output Data


In [77]:
print("=" * 80)
print("Reading Generated Parquet Files")
print("=" * 80)

# Read the saved parquet files
train_output_path = f"file://{(OUTPUT_DIR / 'X_train_spark.parquet').absolute()}"
df_train = spark.read.parquet(train_output_path)

valid_output_path = f"file://{(OUTPUT_DIR / 'X_valid_spark.parquet').absolute()}"
df_valid = spark.read.parquet(valid_output_path)

# Display Schema
print("\n" + "=" * 80)
print("TRAINING DATA SCHEMA")
print("=" * 80)
df_train.printSchema()

print("\n" + "=" * 80)
print("VALIDATION DATA SCHEMA")
print("=" * 80)
df_valid.printSchema()

# Display basic statistics
print("\n" + "=" * 80)
print("DATASET STATISTICS")
print("=" * 80)
print(f"Training rows: {df_train.count():,}")
print(f"Validation rows: {df_valid.count():,}")
print(f"Number of columns: {len(df_train.columns)}")
print(f"Column names: {', '.join(df_train.columns)}")

# Display first 10 rows of training data
print("\n" + "=" * 80)
print("TRAINING DATA - First 10 Rows")
print("=" * 80)
df_train.show(10, truncate=False)

# Display first 10 rows of validation data
print("\n" + "=" * 80)
print("VALIDATION DATA - First 10 Rows")
print("=" * 80)
df_valid.show(10, truncate=False)

# Display label distribution
print("\n" + "=" * 80)
print("LABEL DISTRIBUTION")
print("=" * 80)
print("\nTraining set:")
df_train.groupBy("y").count().orderBy("y").show()

print("Validation set:")
df_valid.groupBy("y").count().orderBy("y").show()

# Display summary statistics for key numerical features
print("\n" + "=" * 80)
print("SUMMARY STATISTICS - Key Features")
print("=" * 80)
key_features = ['n_prefix_items', 'n_prefix_events', 'cat_count_in_prefix', 
                'cat_popularity', 'user_cat_hist', 'recency_sec']
df_train.select(key_features).describe().show()

print("\n✓ Data verification completed successfully!")
print("=" * 80)


Reading Generated Parquet Files

TRAINING DATA SCHEMA
root
 |-- session_id: string (nullable = true)
 |-- atc_ts: timestamp (nullable = true)
 |-- category_id: long (nullable = true)
 |-- n_prefix_items: long (nullable = true)
 |-- n_prefix_events: long (nullable = true)
 |-- cat_count_in_prefix: long (nullable = true)
 |-- cat_share_in_prefix: double (nullable = true)
 |-- recency_sec: long (nullable = true)
 |-- log_recency: double (nullable = true)
 |-- hour_of_day: integer (nullable = true)
 |-- day_of_week: integer (nullable = true)
 |-- is_weekend: integer (nullable = true)
 |-- time_since_session_start: long (nullable = true)
 |-- session_cat_diversity: long (nullable = true)
 |-- cat_popularity: long (nullable = true)
 |-- log_cat_pop: double (nullable = true)
 |-- user_cat_hist: long (nullable = true)
 |-- log_user_cat_hist: double (nullable = true)
 |-- user_cat_sessions: long (nullable = true)
 |-- user_total_sessions: long (nullable = true)
 |-- user_avg_session_dur: double

In [78]:
print("=" * 80)
print("Data Processing Completion Summary")
print("=" * 80)

# Calculate label distribution
train_label_dist = X_train_spark.groupBy("y").count().collect()
valid_label_dist = X_valid_spark.groupBy("y").count().collect()

train_pos = [r['count'] for r in train_label_dist if r['y'] == 1][0] if any(r['y'] == 1 for r in train_label_dist) else 0
train_total = sum(r['count'] for r in train_label_dist)
valid_pos = [r['count'] for r in valid_label_dist if r['y'] == 1][0] if any(r['y'] == 1 for r in valid_label_dist) else 0
valid_total = sum(r['count'] for r in valid_label_dist)

print(f"Training Set:")
print(f"  Total rows: {train_total:,}")
print(f"  Positive samples: {train_pos:,} ({train_pos/train_total*100:.2f}%)")
print(f"  Negative samples: {train_total - train_pos:,} ({(train_total-train_pos)/train_total*100:.2f}%)")

print(f"\nValidation Set:")
print(f"  Total rows: {valid_total:,}")
print(f"  Positive samples: {valid_pos:,} ({valid_pos/valid_total*100:.2f}%)")
print(f"  Negative samples: {valid_total - valid_pos:,} ({(valid_total-valid_pos)/valid_total*100:.2f}%)")

print("\nFeature Columns:")
feature_cols = [c for c in X_train_spark.columns if c not in ['session_id', 'atc_ts', 'category_id', 'y']]
print(f"  Total features: {len(feature_cols)}")
print(f"    - Base features: 18")
print(f"    - Category Embeddings: 16")
print(f"  Feature list: {', '.join(feature_cols)}")


print("\n" + "=" * 80)


Data Processing Completion Summary
Training Set:
  Total rows: 970,995
  Positive samples: 26,548 (2.73%)
  Negative samples: 944,447 (97.27%)

Validation Set:
  Total rows: 551,775
  Positive samples: 15,325 (2.78%)
  Negative samples: 536,450 (97.22%)

Feature Columns:
  Total features: 34
    - Base features: 18
    - Category Embeddings: 16
  Feature list: n_prefix_items, n_prefix_events, cat_count_in_prefix, cat_share_in_prefix, recency_sec, log_recency, hour_of_day, day_of_week, is_weekend, time_since_session_start, session_cat_diversity, cat_popularity, log_cat_pop, user_cat_hist, log_user_cat_hist, user_cat_sessions, user_total_sessions, user_avg_session_dur, cat_emb_0, cat_emb_1, cat_emb_2, cat_emb_3, cat_emb_4, cat_emb_5, cat_emb_6, cat_emb_7, cat_emb_8, cat_emb_9, cat_emb_10, cat_emb_11, cat_emb_12, cat_emb_13, cat_emb_14, cat_emb_15



In [79]:
# Calculate total execution time
pipeline_end_time = time.time()
pipeline_end_datetime = datetime.now()

total_seconds = pipeline_end_time - pipeline_start_time
hours = int(total_seconds // 3600)
minutes = int((total_seconds % 3600) // 60)
seconds = int(total_seconds % 60)
milliseconds = int((total_seconds % 1) * 1000)

print("=" * 80)
print("PIPELINE EXECUTION TIME SUMMARY")
print("=" * 80)
print(f"Start Time:    {pipeline_start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"End Time:      {pipeline_end_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Total Duration: {hours:02d}:{minutes:02d}:{seconds:02d}.{milliseconds:03d}")
print(f"               ({total_seconds:.2f} seconds)")
print("=" * 80)

# Breakdown by major steps (if needed, estimate based on typical execution)
print("\nEstimated Time Breakdown:")
print(f"  • Data Loading & Sessionization:  ~{int(total_seconds * 0.10):,} seconds ({int(total_seconds * 0.10 / 60)} min)")
print(f"  • Candidate Set Building:         ~{int(total_seconds * 0.30):,} seconds ({int(total_seconds * 0.30 / 60)} min)")
print(f"  • Word2Vec Training:              ~{int(total_seconds * 0.05):,} seconds ({int(total_seconds * 0.05 / 60)} min)")
print(f"  • Feature Engineering:            ~{int(total_seconds * 0.35):,} seconds ({int(total_seconds * 0.35 / 60)} min)")
print(f"  • Data Saving:                    ~{int(total_seconds * 0.15):,} seconds ({int(total_seconds * 0.15 / 60)} min)")
print(f"  • Verification & Summary:         ~{int(total_seconds * 0.05):,} seconds ({int(total_seconds * 0.05 / 60)} min)")




PIPELINE EXECUTION TIME SUMMARY
Start Time:    2025-12-08 13:52:45
End Time:      2025-12-08 14:00:09
Total Duration: 00:07:24.131
               (444.13 seconds)

Estimated Time Breakdown:
  • Data Loading & Sessionization:  ~44 seconds (0 min)
  • Candidate Set Building:         ~133 seconds (2 min)
  • Word2Vec Training:              ~22 seconds (0 min)
  • Feature Engineering:            ~155 seconds (2 min)
  • Data Saving:                    ~66 seconds (1 min)
  • Verification & Summary:         ~22 seconds (0 min)


## Stop Spark Session


In [80]:
# Stop Spark session
spark.stop()
print("Spark session stopped")


Spark session stopped
