# Libraries

In [None]:
from google.colab import drive
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.functions import count, min, max, sum, countDistinct
from pyspark.ml.feature import StandardScaler, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.sql.types import FloatType
from pyspark.sql.functions import date_add, lit, col, max, min, sum, count, countDistinct, when, datediff

# I/ Introduction + content loading

In [None]:
# Names: REICHARD/ZORKANI/READY
# Course name: Big data

In [None]:
#Setting up Spark in Colab
!pip install -q pyspark
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").appName("Spark_in_Colab").getOrCreate()
spark.createDataFrame([{"status": "Spark is working!", "location": "Google Colab"}]).show()

+------------+-----------------+
|    location|           status|
+------------+-----------------+
|Google Colab|Spark is working!|
+------------+-----------------+



In [None]:
# Mount our Google Drive folder so that we can access our data files from Colab
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# File paths for the datasets used in this project.
# The shared Drive folder is declared once, so relocating the project means editing a single line.
PROJECT_DIR = "/content/drive/MyDrive/Big data group project"
DATA_DIR = f"{PROJECT_DIR}/Data"
HOLDOUT_DIR = f"{PROJECT_DIR}/Holdout data"

# Training / exploration tables
order_items_filepath = f"{DATA_DIR}/order_items.parquet"
orders_filepath = f"{DATA_DIR}/orders.parquet"
products_filepath = f"{DATA_DIR}/products.parquet"
website_pageviews_filepath = f"{DATA_DIR}/website_pageviews.parquet"
website_sessions_filepath = f"{DATA_DIR}/website_sessions.parquet"

# Holdout tables
website_pageviews_holdout_filepath = f"{HOLDOUT_DIR}/website_pageviews_holdout.parquet"
website_sessions_holdout_filepath = f"{HOLDOUT_DIR}/website_sessions_holdout.parquet"

# II/ Data exploration

## 1/ Order_items

In [None]:
# order_items: load, preview and schema.
# Parquet files carry their own schema, so the CSV-only "header" / "inferSchema"
# reader options are dropped (the Parquet source does not use them).
orderit = spark.read.parquet(order_items_filepath)

orderit.show()
orderit.printSchema()

+-------------+-------------------+--------+----------+---------------+----------+---------+
|order_item_id|         created_at|order_id|product_id|is_primary_item|price_euro|cogs_euro|
+-------------+-------------------+--------+----------+---------------+----------+---------+
|            1|2022-03-19 10:42:46|       1|         1|              1|    149.99|    69.49|
|            2|2022-03-19 19:27:37|       2|         1|              1|    149.99|    69.49|
|            3|2022-03-20 06:44:45|       3|         1|              1|    149.99|    69.49|
|            4|2022-03-20 09:41:45|       4|         1|              1|    149.99|    69.49|
|            5|2022-03-20 11:28:15|       5|         1|              1|    149.99|    69.49|
|            6|2022-03-20 16:12:47|       6|         1|              1|    149.99|    69.49|
|            7|2022-03-20 17:03:41|       7|         1|              1|    149.99|    69.49|
|            8|2022-03-20 23:35:27|       8|         1|              1

In [None]:
# Parse the created_at string into a timestamp column (created_at_ts) and discard the string version
orderit = (
    orderit
    .withColumn("created_at_ts", to_timestamp(col("created_at"), "yyyy-MM-dd HH:mm:ss"))
    .drop("created_at")
)
orderit.show()
orderit.printSchema()

+-------------+--------+----------+---------------+----------+---------+-------------------+
|order_item_id|order_id|product_id|is_primary_item|price_euro|cogs_euro|      created_at_ts|
+-------------+--------+----------+---------------+----------+---------+-------------------+
|            1|       1|         1|              1|    149.99|    69.49|2022-03-19 10:42:46|
|            2|       2|         1|              1|    149.99|    69.49|2022-03-19 19:27:37|
|            3|       3|         1|              1|    149.99|    69.49|2022-03-20 06:44:45|
|            4|       4|         1|              1|    149.99|    69.49|2022-03-20 09:41:45|
|            5|       5|         1|              1|    149.99|    69.49|2022-03-20 11:28:15|
|            6|       6|         1|              1|    149.99|    69.49|2022-03-20 16:12:47|
|            7|       7|         1|              1|    149.99|    69.49|2022-03-20 17:03:41|
|            8|       8|         1|              1|    149.99|    69.4

In [None]:
# Null tally per column: `when` emits a marker only where the cell is null, and `count` counts those markers
orderit.agg(*[count(when(col(name).isNull(), name)).alias(name) for name in orderit.columns]).show()

+-------------+--------+----------+---------------+----------+---------+-------------+
|order_item_id|order_id|product_id|is_primary_item|price_euro|cogs_euro|created_at_ts|
+-------------+--------+----------+---------------+----------+---------+-------------+
|            0|       0|         0|              0|         0|        0|            0|
+-------------+--------+----------+---------------+----------+---------+-------------+



In [None]:
# Duplicates check on order_item_id: total rows minus rows left after de-duplicating on that id
distinct_ids = orderit.dropDuplicates(["order_item_id"]).count()
total_rows = orderit.count()

duplicates = total_rows - distinct_ids
print("Number of duplicates:", duplicates)

Number of duplicates: 0


## 2/ Orders

In [None]:
# orders: load, preview and schema.
# Parquet files carry their own schema, so the CSV-only "header" / "inferSchema"
# reader options are dropped (the Parquet source does not use them).
orders = spark.read.parquet(orders_filepath)

orders.show()
orders.printSchema()

+--------+-------------------+------------------+-------+------------------+---------------+----------+---------+
|order_id|         created_at|website_session_id|user_id|primary_product_id|items_purchased|price_euro|cogs_euro|
+--------+-------------------+------------------+-------+------------------+---------------+----------+---------+
|       1|2022-03-19 10:42:46|                20|     20|                 1|              1|    149.99|    69.49|
|       2|2022-03-19 19:27:37|               104|    104|                 1|              1|    149.99|    69.49|
|       3|2022-03-20 06:44:45|               147|    147|                 1|              1|    149.99|    69.49|
|       4|2022-03-20 09:41:45|               160|    160|                 1|              1|    149.99|    69.49|
|       5|2022-03-20 11:28:15|               177|    177|                 1|              1|    149.99|    69.49|
|       6|2022-03-20 16:12:47|               232|    232|                 1|            

In [None]:
# Parse the created_at string into a timestamp column (created_at_ts) and discard the string version
orders = (
    orders
    .withColumn("created_at_ts", to_timestamp(col("created_at"), "yyyy-MM-dd HH:mm:ss"))
    .drop("created_at")
)
orders.show()
orders.printSchema()

+--------+------------------+-------+------------------+---------------+----------+---------+-------------------+
|order_id|website_session_id|user_id|primary_product_id|items_purchased|price_euro|cogs_euro|      created_at_ts|
+--------+------------------+-------+------------------+---------------+----------+---------+-------------------+
|       1|                20|     20|                 1|              1|    149.99|    69.49|2022-03-19 10:42:46|
|       2|               104|    104|                 1|              1|    149.99|    69.49|2022-03-19 19:27:37|
|       3|               147|    147|                 1|              1|    149.99|    69.49|2022-03-20 06:44:45|
|       4|               160|    160|                 1|              1|    149.99|    69.49|2022-03-20 09:41:45|
|       5|               177|    177|                 1|              1|    149.99|    69.49|2022-03-20 11:28:15|
|       6|               232|    232|                 1|              1|    149.99|    6

In [None]:
# Null tally per column: `when` emits a marker only where the cell is null, and `count` counts those markers
orders.agg(*[count(when(col(name).isNull(), name)).alias(name) for name in orders.columns]).show()

+--------+------------------+-------+------------------+---------------+----------+---------+-------------+
|order_id|website_session_id|user_id|primary_product_id|items_purchased|price_euro|cogs_euro|created_at_ts|
+--------+------------------+-------+------------------+---------------+----------+---------+-------------+
|       0|                 0|      0|                 0|              0|         0|        0|            0|
+--------+------------------+-------+------------------+---------------+----------+---------+-------------+



In [None]:
# Duplicates check on order_id: total rows minus rows left after de-duplicating on that id
distinct_ids = orders.dropDuplicates(["order_id"]).count()
total_rows = orders.count()

duplicates = total_rows - distinct_ids
print("Number of duplicates:", duplicates)

Number of duplicates: 0


## 3/ Products

In [None]:
# products: load, preview and schema.
# Parquet files carry their own schema, so the CSV-only "header" / "inferSchema"
# reader options are dropped (the Parquet source does not use them).
products = spark.read.parquet(products_filepath)

products.show()
products.printSchema()

+----------+-------------------+------------+
|product_id|         created_at|product_name|
+----------+-------------------+------------+
|         1|2022-03-19 08:00:00|    CorePack|
|         2|2023-01-06 13:00:00|TechFortress|
|         3|2023-12-12 09:00:00|     AirLite|
|         4|2024-02-05 10:00:00|    EcoShell|
+----------+-------------------+------------+

root
 |-- product_id: integer (nullable = true)
 |-- created_at: string (nullable = true)
 |-- product_name: string (nullable = true)



In [None]:
# Parse the created_at string into a timestamp column (created_at_ts) and discard the string version
products = (
    products
    .withColumn("created_at_ts", to_timestamp(col("created_at"), "yyyy-MM-dd HH:mm:ss"))
    .drop("created_at")
)
products.show()
products.printSchema()

+----------+------------+-------------------+
|product_id|product_name|      created_at_ts|
+----------+------------+-------------------+
|         1|    CorePack|2022-03-19 08:00:00|
|         2|TechFortress|2023-01-06 13:00:00|
|         3|     AirLite|2023-12-12 09:00:00|
|         4|    EcoShell|2024-02-05 10:00:00|
+----------+------------+-------------------+

root
 |-- product_id: integer (nullable = true)
 |-- product_name: string (nullable = true)
 |-- created_at_ts: timestamp (nullable = true)



In [None]:
# Null tally per column: `when` emits a marker only where the cell is null, and `count` counts those markers
products.agg(*[count(when(col(name).isNull(), name)).alias(name) for name in products.columns]).show()

+----------+------------+-------------+
|product_id|product_name|created_at_ts|
+----------+------------+-------------+
|         0|           0|            0|
+----------+------------+-------------+



In [None]:
# No duplicates check needed here: the table has only four rows, each with a distinct product_id (see the output above)

## 4/ Website_pageviews

In [None]:
# website_pageviews: load, preview and schema.
# Parquet files carry their own schema, so the CSV-only "header" / "inferSchema"
# reader options are dropped (the Parquet source does not use them).
wpv = spark.read.parquet(website_pageviews_filepath)

wpv.show()
wpv.printSchema()

+-------------------+-------------------+------------------+-------------+
|website_pageview_id|         created_at|website_session_id| pageview_url|
+-------------------+-------------------+------------------+-------------+
|                  1|2022-03-19 08:04:16|                 1|        /home|
|                  2|2022-03-19 08:16:49|                 2|        /home|
|                  3|2022-03-19 08:26:55|                 3|        /home|
|                  4|2022-03-19 08:37:33|                 4|        /home|
|                  5|2022-03-19 09:00:55|                 5|        /home|
|                  6|2022-03-19 09:05:46|                 6|        /home|
|                  7|2022-03-19 09:06:27|                 7|        /home|
|                  8|2022-03-19 09:10:08|                 6|    /products|
|                  9|2022-03-19 09:10:52|                 6|/the-corepack|
|                 10|2022-03-19 09:14:02|                 6|        /cart|
|                 11|2022

In [None]:
# Parse created_at into a timestamp column (created_at_ts) and discard the string version.
# try_to_timestamp is used for robustness: values that do not match the pattern become null.
wpv = (
    wpv
    .withColumn("created_at_ts", try_to_timestamp(col("created_at"), lit("yyyy-MM-dd HH:mm:ss")))
    .drop("created_at")
)
wpv.show()
wpv.printSchema()

+-------------------+------------------+-------------+-------------------+
|website_pageview_id|website_session_id| pageview_url|      created_at_ts|
+-------------------+------------------+-------------+-------------------+
|                  1|                 1|        /home|2022-03-19 08:04:16|
|                  2|                 2|        /home|2022-03-19 08:16:49|
|                  3|                 3|        /home|2022-03-19 08:26:55|
|                  4|                 4|        /home|2022-03-19 08:37:33|
|                  5|                 5|        /home|2022-03-19 09:00:55|
|                  6|                 6|        /home|2022-03-19 09:05:46|
|                  7|                 7|        /home|2022-03-19 09:06:27|
|                  8|                 6|    /products|2022-03-19 09:10:08|
|                  9|                 6|/the-corepack|2022-03-19 09:10:52|
|                 10|                 6|        /cart|2022-03-19 09:14:02|
|                 11|    

In [None]:
# Null tally per column: `when` emits a marker only where the cell is null, and `count` counts those markers
wpv.agg(*[count(when(col(name).isNull(), name)).alias(name) for name in wpv.columns]).show()

+-------------------+------------------+------------+-------------+
|website_pageview_id|website_session_id|pageview_url|created_at_ts|
+-------------------+------------------+------------+-------------+
|                  0|                 0|           0|            5|
+-------------------+------------------+------------+-------------+



In [None]:
# Remove every row that has a null in any column, then re-run the null tally to confirm none remain
wpv = wpv.dropna()
wpv.agg(*[count(when(col(name).isNull(), name)).alias(name) for name in wpv.columns]).show()

+-------------------+------------------+------------+-------------+
|website_pageview_id|website_session_id|pageview_url|created_at_ts|
+-------------------+------------------+------------+-------------+
|                  0|                 0|           0|            0|
+-------------------+------------------+------------+-------------+



In [None]:
# Duplicates check on website_pageview_id: total rows minus rows left after de-duplicating on that id
distinct_ids = wpv.dropDuplicates(["website_pageview_id"]).count()
total_rows = wpv.count()

duplicates = total_rows - distinct_ids
print("Number of duplicates:", duplicates)

Number of duplicates: 0


In [None]:
# Number of pageviews per URL, least visited first
(
    wpv
    .groupBy("pageview_url")
    .count()
    .sort(col("count").asc())
    .show(truncate=False)
)

+-----------------+------+
|pageview_url     |count |
+-----------------+------+
|/the-ecoshell    |1463  |
|/lander-4        |9385  |
|/the-airlite     |16476 |
|/the-techfortress|22135 |
|/billing         |46876 |
|/lander-1        |47574 |
|/lander-5        |49518 |
|/shipping        |58057 |
|/lander-3        |72143 |
|/cart            |85443 |
|/home            |124214|
|/lander-2        |131170|
|/the-corepack    |150229|
|/products        |237835|
+-----------------+------+



## 5/ Website_sessions

In [None]:
# website_sessions: load, preview and schema.
# Parquet files carry their own schema, so the CSV-only "header" / "inferSchema"
# reader options are dropped (the Parquet source does not use them).
wss = spark.read.parquet(website_sessions_filepath)

wss.show()
wss.printSchema()

+------------------+-------------------+-------+-----------------+----------+------------+-----------+-----------+--------------------+--------------+
|website_session_id|         created_at|user_id|is_repeat_session|utm_source|utm_campaign|utm_content|device_type|        http_referer|traffic_source|
+------------------+-------------------+-------+-----------------+----------+------------+-----------+-----------+--------------------+--------------+
|                 1|2022-03-19 08:04:16|      1|                0|   gsearch|    nonbrand|     g_ad_1|     mobile|https://www.gsear...|   paid_search|
|                 2|2022-03-19 08:16:49|      2|                0|   gsearch|    nonbrand|     g_ad_1|    desktop|https://www.gsear...|   paid_search|
|                 3|2022-03-19 08:26:55|      3|                0|   gsearch|    nonbrand|     g_ad_1|    desktop|https://www.gsear...|   paid_search|
|                 4|2022-03-19 08:37:33|      4|                0|   gsearch|    nonbrand|    

In [None]:
# Parse created_at into a timestamp column (created_at_ts) and discard the string version.
# With try_to_timestamp, values that do not match the pattern become null.
wss = (
    wss
    .withColumn("created_at_ts", try_to_timestamp(col("created_at"), lit("yyyy-MM-dd HH:mm:ss")))
    .drop("created_at")
)
wss.show()
wss.printSchema()

+------------------+-------+-----------------+----------+------------+-----------+-----------+--------------------+--------------+-------------------+
|website_session_id|user_id|is_repeat_session|utm_source|utm_campaign|utm_content|device_type|        http_referer|traffic_source|      created_at_ts|
+------------------+-------+-----------------+----------+------------+-----------+-----------+--------------------+--------------+-------------------+
|                 1|      1|                0|   gsearch|    nonbrand|     g_ad_1|     mobile|https://www.gsear...|   paid_search|2022-03-19 08:04:16|
|                 2|      2|                0|   gsearch|    nonbrand|     g_ad_1|    desktop|https://www.gsear...|   paid_search|2022-03-19 08:16:49|
|                 3|      3|                0|   gsearch|    nonbrand|     g_ad_1|    desktop|https://www.gsear...|   paid_search|2022-03-19 08:26:55|
|                 4|      4|                0|   gsearch|    nonbrand|     g_ad_1|    desktop|

In [None]:
# Null tally per column: `when` emits a marker only where the cell is null, and `count` counts those markers
wss.agg(*[count(when(col(name).isNull(), name)).alias(name) for name in wss.columns]).show()

+------------------+-------+-----------------+----------+------------+-----------+-----------+------------+--------------+-------------+
|website_session_id|user_id|is_repeat_session|utm_source|utm_campaign|utm_content|device_type|http_referer|traffic_source|created_at_ts|
+------------------+-------+-----------------+----------+------------+-----------+-----------+------------+--------------+-------------+
|                 0|      0|                0|         0|           0|          0|          0|           0|             0|            4|
+------------------+-------+-----------------+----------+------------+-----------+-----------+------------+--------------+-------------+



In [None]:
# Remove every row that has a null in any column, then re-run the null tally to confirm none remain
wss = wss.dropna()
wss.agg(*[count(when(col(name).isNull(), name)).alias(name) for name in wss.columns]).show()

+------------------+-------+-----------------+----------+------------+-----------+-----------+------------+--------------+-------------+
|website_session_id|user_id|is_repeat_session|utm_source|utm_campaign|utm_content|device_type|http_referer|traffic_source|created_at_ts|
+------------------+-------+-----------------+----------+------------+-----------+-----------+------------+--------------+-------------+
|                 0|      0|                0|         0|           0|          0|          0|           0|             0|            0|
+------------------+-------+-----------------+----------+------------+-----------+-----------+------------+--------------+-------------+



In [None]:
# Duplicates check on website_session_id: total rows minus rows left after de-duplicating on that id
distinct_ids = wss.dropDuplicates(["website_session_id"]).count()
total_rows = wss.count()

duplicates = total_rows - distinct_ids
print("Number of duplicates:", duplicates)

Number of duplicates: 0


In [None]:
# Number of sessions per traffic source, least frequent first
(
    wss
    .groupBy("traffic_source")
    .count()
    .sort(col("count").asc())
    .show()
)

+--------------+------+
|traffic_source| count|
+--------------+------+
|   paid_social| 10685|
|        direct| 35663|
|organic_search| 38713|
|   paid_search|348943|
+--------------+------+



In [None]:
# Number of sessions per HTTP referer, least frequent first.
# NOTE(review): the null tally above reports 0 nulls for http_referer and rows with nulls were dropped,
# so a "NULL" group in this output would be the literal string "NULL", not a real null — confirm before modelling.
(
    wss
    .groupBy("http_referer")
    .count()
    .sort(col("count").asc())
    .show(truncate=False)
)

+-------------------------+------+
|http_referer             |count |
+-------------------------+------+
|https://www.instaview.com|10685 |
|NULL                     |35663 |
|https://www.bsearch.com  |65170 |
|https://www.gsearch.com  |322486|
+-------------------------+------+



## 6/ Website_pageviews_holdout

In [None]:
# website_pageviews_holdout: load, preview and schema.
# Parquet files carry their own schema, so the CSV-only "header" / "inferSchema"
# reader options are dropped (the Parquet source does not use them).
wpvsh = spark.read.parquet(website_pageviews_holdout_filepath)

wpvsh.show()
wpvsh.printSchema()

+-------------------+-------------------+------------------+-------------+
|website_pageview_id|         created_at|website_session_id| pageview_url|
+-------------------+-------------------+------------------+-------------+
|                  1|2025-02-01 00:01:38|            434011|    /lander-5|
|                  2|2025-02-01 00:04:59|            434011|    /products|
|                  3|2025-02-01 00:05:08|            434012|    /lander-3|
|                  4|2025-02-01 00:06:11|            434011|/the-corepack|
|                  5|2025-02-01 00:08:11|            434013|    /lander-5|
|                  6|2025-02-01 00:09:48|            434014|    /lander-3|
|                  7|2025-02-01 00:10:52|            434014|    /products|
|                  8|2025-02-01 00:11:05|            434015|    /lander-5|
|                  9|2025-02-01 00:11:45|            434015|    /products|
|                 10|2025-02-01 00:12:03|            434016|    /lander-5|
|                 11|2025

In [None]:
# Parse created_at into a timestamp column (created_at_ts) and discard the string version.
# With try_to_timestamp, values that do not match the pattern become null.
wpvsh = (
    wpvsh
    .withColumn("created_at_ts", try_to_timestamp(col("created_at"), lit("yyyy-MM-dd HH:mm:ss")))
    .drop("created_at")
)
wpvsh.show()
wpvsh.printSchema()

+-------------------+------------------+-------------+-------------------+
|website_pageview_id|website_session_id| pageview_url|      created_at_ts|
+-------------------+------------------+-------------+-------------------+
|                  1|            434011|    /lander-5|2025-02-01 00:01:38|
|                  2|            434011|    /products|2025-02-01 00:04:59|
|                  3|            434012|    /lander-3|2025-02-01 00:05:08|
|                  4|            434011|/the-corepack|2025-02-01 00:06:11|
|                  5|            434013|    /lander-5|2025-02-01 00:08:11|
|                  6|            434014|    /lander-3|2025-02-01 00:09:48|
|                  7|            434014|    /products|2025-02-01 00:10:52|
|                  8|            434015|    /lander-5|2025-02-01 00:11:05|
|                  9|            434015|    /products|2025-02-01 00:11:45|
|                 10|            434016|    /lander-5|2025-02-01 00:12:03|
|                 11|    

In [None]:
# Null tally per column: `when` emits a marker only where the cell is null, and `count` counts those markers
wpvsh.agg(*[count(when(col(name).isNull(), name)).alias(name) for name in wpvsh.columns]).show()

+-------------------+------------------+------------+-------------+
|website_pageview_id|website_session_id|pageview_url|created_at_ts|
+-------------------+------------------+------------+-------------+
|                  0|                 0|           0|            2|
+-------------------+------------------+------------+-------------+



In [None]:
# Remove every row that has a null in any column, then re-run the null tally to confirm none remain
wpvsh = wpvsh.dropna()
wpvsh.agg(*[count(when(col(name).isNull(), name)).alias(name) for name in wpvsh.columns]).show()

+-------------------+------------------+------------+-------------+
|website_pageview_id|website_session_id|pageview_url|created_at_ts|
+-------------------+------------------+------------+-------------+
|                  0|                 0|           0|            0|
+-------------------+------------------+------------+-------------+



In [None]:
# Duplicates check on website_pageview_id: total rows minus rows left after de-duplicating on that id
distinct_ids = wpvsh.dropDuplicates(["website_pageview_id"]).count()
total_rows = wpvsh.count()

duplicates = total_rows - distinct_ids
print("Number of duplicates:", duplicates)

Number of duplicates: 0


## 7/ Website_sessions_holdout

In [None]:
# website_sessions_holdout: load, preview and schema.
# Parquet files carry their own schema, so the CSV-only "header" / "inferSchema"
# reader options are dropped (the Parquet source does not use them).
wsessh = spark.read.parquet(website_sessions_holdout_filepath)

wsessh.show()
wsessh.printSchema()

+------------------+-------------------+-------+-----------------+----------+------------+-----------+-----------+--------------------+--------------+
|website_session_id|         created_at|user_id|is_repeat_session|utm_source|utm_campaign|utm_content|device_type|        http_referer|traffic_source|
+------------------+-------------------+-------+-----------------+----------+------------+-----------+-----------+--------------------+--------------+
|            434011|2025-02-01 00:01:38| 363707|                0|   gsearch|    nonbrand|     g_ad_1|    desktop|https://www.gsear...|   paid_search|
|            434012|2025-02-01 00:05:08| 363708|                0|   gsearch|    nonbrand|     g_ad_1|     mobile|https://www.gsear...|   paid_search|
|            434013|2025-02-01 00:08:11| 363709|                0|   bsearch|    nonbrand|     b_ad_1|    desktop|https://www.bsear...|   paid_search|
|            434014|2025-02-01 00:09:48| 363710|                0|   gsearch|    nonbrand|    

In [None]:
# Parse created_at into a timestamp column (created_at_ts) and discard the string version.
# With try_to_timestamp, values that do not match the pattern become null.
wsessh = (
    wsessh
    .withColumn("created_at_ts", try_to_timestamp(col("created_at"), lit("yyyy-MM-dd HH:mm:ss")))
    .drop("created_at")
)
wsessh.show()
wsessh.printSchema()

+------------------+-------+-----------------+----------+------------+-----------+-----------+--------------------+--------------+-------------------+
|website_session_id|user_id|is_repeat_session|utm_source|utm_campaign|utm_content|device_type|        http_referer|traffic_source|      created_at_ts|
+------------------+-------+-----------------+----------+------------+-----------+-----------+--------------------+--------------+-------------------+
|            434011| 363707|                0|   gsearch|    nonbrand|     g_ad_1|    desktop|https://www.gsear...|   paid_search|2025-02-01 00:01:38|
|            434012| 363708|                0|   gsearch|    nonbrand|     g_ad_1|     mobile|https://www.gsear...|   paid_search|2025-02-01 00:05:08|
|            434013| 363709|                0|   bsearch|    nonbrand|     b_ad_1|    desktop|https://www.bsear...|   paid_search|2025-02-01 00:08:11|
|            434014| 363710|                0|   gsearch|    nonbrand|     g_ad_1|     mobile|

In [None]:
# Null tally per column: `when` emits a marker only where the cell is null, and `count` counts those markers
wsessh.agg(*[count(when(col(name).isNull(), name)).alias(name) for name in wsessh.columns]).show()

+------------------+-------+-----------------+----------+------------+-----------+-----------+------------+--------------+-------------+
|website_session_id|user_id|is_repeat_session|utm_source|utm_campaign|utm_content|device_type|http_referer|traffic_source|created_at_ts|
+------------------+-------+-----------------+----------+------------+-----------+-----------+------------+--------------+-------------+
|                 0|      0|                0|         0|           0|          0|          0|           0|             0|            2|
+------------------+-------+-----------------+----------+------------+-----------+-----------+------------+--------------+-------------+



In [None]:
# Remove every row that has a null in any column, then re-run the null tally to confirm none remain
wsessh = wsessh.dropna()
wsessh.agg(*[count(when(col(name).isNull(), name)).alias(name) for name in wsessh.columns]).show()

+------------------+-------+-----------------+----------+------------+-----------+-----------+------------+--------------+-------------+
|website_session_id|user_id|is_repeat_session|utm_source|utm_campaign|utm_content|device_type|http_referer|traffic_source|created_at_ts|
+------------------+-------+-----------------+----------+------------+-----------+-----------+------------+--------------+-------------+
|                 0|      0|                0|         0|           0|          0|          0|           0|             0|            0|
+------------------+-------+-----------------+----------+------------+-----------+-----------+------------+--------------+-------------+



In [None]:
# Duplicates check on website_session_id: total rows minus rows left after de-duplicating on that id
distinct_ids = wsessh.dropDuplicates(["website_session_id"]).count()
total_rows = wsessh.count()

duplicates = total_rows - distinct_ids
print("Number of duplicates:", duplicates)

Number of duplicates: 0


# III/ Basetable creation

## 1/ Temporal split

In [None]:
# In a project like this, it is important to separate the data into training and validation sets
# While not strictly mandatory, doing so helps prevent data leakage and is more realistic in a business context

In [None]:
# Handling time windows (still exploring the data):
# display the latest, then the earliest, created_at_ts of every table
for table in (orderit, orders, products, wpv, wss):
    table.select(max("created_at_ts")).show()
    table.select(min("created_at_ts")).show()

+-------------------+
| max(created_at_ts)|
+-------------------+
|2025-01-31 23:30:48|
+-------------------+

+-------------------+
| min(created_at_ts)|
+-------------------+
|2022-03-19 10:42:46|
+-------------------+

+-------------------+
| max(created_at_ts)|
+-------------------+
|2025-01-31 23:30:48|
+-------------------+

+-------------------+
| min(created_at_ts)|
+-------------------+
|2022-03-19 10:42:46|
+-------------------+

+-------------------+
| max(created_at_ts)|
+-------------------+
|2024-02-05 10:00:00|
+-------------------+

+-------------------+
| min(created_at_ts)|
+-------------------+
|2022-03-19 08:00:00|
+-------------------+

+-------------------+
| max(created_at_ts)|
+-------------------+
|2025-01-31 23:55:44|
+-------------------+

+-------------------+
| min(created_at_ts)|
+-------------------+
|2022-03-19 08:04:16|
+-------------------+

+-------------------+
| max(created_at_ts)|
+-------------------+
|2025-01-31 23:54:52|
+-------------------+

+

In [None]:
# Cut-off timestamps: each one is the last instant (inclusive) of its period
train_end = "2023-10-31 23:59:59"
val_end = "2024-12-31 23:59:59"
test_end = "2025-01-31 23:59:59"

# Gap between consecutive periods --> simulates real business speed + prevents temporal leakage between tables.
# date_add() returns a DATE, which discards the 23:59:59 time of day and silently shortens the gap
# by almost a full day. Adding an interval to a real timestamp keeps the gap at exactly GAP_DAYS days.
GAP_DAYS = 7
train_end_gap = to_timestamp(lit(train_end), "yyyy-MM-dd HH:mm:ss") + expr(f"INTERVAL {GAP_DAYS} DAYS")
val_end_gap = to_timestamp(lit(val_end), "yyyy-MM-dd HH:mm:ss") + expr(f"INTERVAL {GAP_DAYS} DAYS")


def _temporal_split(df):
    """Split `df` on its created_at_ts column into (train, validation, test) DataFrames.

    - train:      created_at_ts <= train_end
    - validation: train_end_gap < created_at_ts <= val_end
    - test:       val_end_gap   < created_at_ts <= test_end

    Rows falling inside a gap belong to no split.
    """
    ts = col("created_at_ts")
    train_df = df.where(ts <= lit(train_end))
    val_df = df.where((ts > train_end_gap) & (ts <= lit(val_end)))
    test_df = df.where((ts > val_end_gap) & (ts <= lit(test_end)))
    return train_df, val_df, test_df


# The same boundaries are applied to every table so that the splits stay aligned
orderit_train, orderit_val, orderit_test = _temporal_split(orderit)
orders_train, orders_val, orders_test = _temporal_split(orders)
wpv_train, wpv_val, wpv_test = _temporal_split(wpv)
wss_train, wss_val, wss_test = _temporal_split(wss)

In [None]:
# Sanity check: the latest timestamp of every training table must not exceed train_end
for train_df in (orderit_train, orders_train, wpv_train, wss_train):
    train_df.select(max("created_at_ts")).show()

+-------------------+
| max(created_at_ts)|
+-------------------+
|2023-10-31 21:21:45|
+-------------------+

+-------------------+
| max(created_at_ts)|
+-------------------+
|2023-10-31 21:21:45|
+-------------------+

+-------------------+
| max(created_at_ts)|
+-------------------+
|2023-10-31 23:57:47|
+-------------------+

+-------------------+
| max(created_at_ts)|
+-------------------+
|2023-10-31 23:57:47|
+-------------------+



## 2/ Feature engineering

In [None]:
# We decided to wrap the feature engineering into a single function to be reapplied on the train, validation and test sets
# It is a cleaner way to work since it avoids code repetition
# Also, it ensures consistency between the basetables

In [None]:
from pyspark.sql.functions import date_add, lit, col, max, min, sum, count, countDistinct, when, datediff

def create_snapshot_basetable(wss_df, wpv_df, orders_df, snapshot_date, prediction_window_days=30):
    """
    Build a user-level basetable (features + binary label) around a snapshot date.

    Temporal contract:
    - Features are derived ONLY from sessions/pageviews with created_at_ts < snapshot_date.
    - label = 1 when the user has at least one order with
      snapshot_date <= created_at_ts < snapshot_date + prediction_window_days, otherwise 0.

    Parameters
    ----------
    wss_df : DataFrame
        Website sessions. Columns read: user_id, website_session_id, created_at_ts,
        is_repeat_session, traffic_source, device_type.
    wpv_df : DataFrame
        Website pageviews. Columns read: website_session_id, created_at_ts, pageview_url.
    orders_df : DataFrame
        Orders. Columns read: user_id, created_at_ts.
    snapshot_date : str
        Snapshot date, e.g. "2023-10-01". It is wrapped in lit() for the Spark filters.
    prediction_window_days : int, default 30
        Length of the performance (label) window in days.

    Returns
    -------
    DataFrame
        One row per user that has at least one session before snapshot_date, with columns
        user_id, label, views_corepack, views_techfortress, views_airlite, views_ecoshell,
        ever_saw_lander, num_sessions, num_repeat_sessions, is_paid_search_user,
        is_organic_search_user, is_direct_user, days_since_last_visit,
        customer_tenure_days, mobile_desktop_ratio. Missing feature values are filled with 0.
    """
    from datetime import date, timedelta

    # 1. Define Temporal Boundaries
    # Observation Window: Start of time -> Snapshot Date (Exclusive)
    # Performance Window: Snapshot Date (Inclusive) -> Snapshot Date + Window (Exclusive)
    perf_window_end = date_add(lit(snapshot_date), prediction_window_days)

    # perf_window_end is a Spark Column: interpolating it in an f-string prints the
    # expression (Column<'date_add(...)'>), not a date. Compute a readable end date
    # in plain Python for the log line only; the filters below still use the Column.
    try:
        perf_window_end_label = (
            date.fromisoformat(str(snapshot_date)[:10]) + timedelta(days=prediction_window_days)
        ).isoformat()
    except (ValueError, TypeError):
        # Snapshot date is not in ISO format: fall back to a descriptive label
        perf_window_end_label = f"{snapshot_date} + {prediction_window_days} days"

    print(f"--- Creating Snapshot for {snapshot_date} ---")
    print(f"Features: Data < {snapshot_date}")
    print(f"Target:   {snapshot_date} <= Order < {perf_window_end_label}")

    # 2. Filter Data for Features (The "Past")
    wss_hist = wss_df.filter(col("created_at_ts") < lit(snapshot_date))
    wpv_hist = wpv_df.filter(col("created_at_ts") < lit(snapshot_date))

    # Only users with at least one session before the snapshot enter the basetable;
    # users without any history are not scored at all
    basetable = wss_hist.select("user_id").distinct()

    # 3. Filter Data for Target (The "Future")
    # Label = 1 if they ordered in the window
    orders_future = orders_df.filter(
        (col("created_at_ts") >= lit(snapshot_date)) &
        (col("created_at_ts") < perf_window_end)
    ).select("user_id").distinct().withColumn("label", lit(1))

    # Join Label (Left Join: Non-buyers get 0)
    basetable = basetable.join(orders_future, "user_id", "left").fillna(0, subset=["label"])

    # Now since we work at the user-level of granularity, we are supposed to aggregate their behavior in order to merge them into the basetable
    # Also, we can only use the wss and wpv tables since product is not relevant here and orderit and orders would cause an issue later
    # Indeed, predicting a behavior by taking into account if the client already bought might mislead the model into associating a high
    # correlation between buying and buying again, and mislabelling potential first buyers

    # Pageview Features (Behavioral Interest)
    # Pageviews carry no user_id here, so attach it through the (historical) sessions.
    # The inner join also drops pageviews whose session is not in the history window.
    wpv_with_user = wpv_hist.join(wss_hist.select("website_session_id", "user_id"), "website_session_id")

    wpv_feats = wpv_with_user.groupBy("user_id").agg(
        # Note: We removed the "Total Pageviews" leak previously.
        # We assume specific product interest is safe IF it happened in the past.
        sum(when(col("pageview_url").contains("/the-corepack"), 1).otherwise(0)).alias("views_corepack"),
        sum(when(col("pageview_url").contains("/the-techfortress"), 1).otherwise(0)).alias("views_techfortress"),
        sum(when(col("pageview_url").contains("/airlite"), 1).otherwise(0)).alias("views_airlite"),
        sum(when(col("pageview_url").contains("/ecoshell"), 1).otherwise(0)).alias("views_ecoshell"),
        max(when(col("pageview_url").like("/lander%"), 1).otherwise(0)).alias("ever_saw_lander")
    )

    # Session Features (Recency, Frequency, Loyalty)
    wss_feats = wss_hist.groupBy("user_id").agg(
        count("website_session_id").alias("num_sessions"),
        sum("is_repeat_session").alias("num_repeat_sessions"),
        # Traffic Sources: 1 if the user ever had a session from that source
        max(when(col("traffic_source") == "paid_search", 1).otherwise(0)).alias("is_paid_search_user"),
        max(when(col("traffic_source") == "organic_search",1).otherwise(0)).alias("is_organic_search_user"),
        max(when(col("traffic_source") == "direct", 1).otherwise(0)).alias("is_direct_user"),
        # Recency: Calculated relative to SNAPSHOT DATE (Not dataset end)
        # This fixes the "Recency Leak"
        datediff(lit(snapshot_date), max("created_at_ts")).alias("days_since_last_visit"),
        # Tenure: days between the first historical session and the snapshot
        datediff(lit(snapshot_date), min("created_at_ts")).alias("customer_tenure_days"),
        # Device: despite the name, this is the share of the user's sessions made on mobile (0.0 - 1.0)
        avg(when(col("device_type") == "mobile", 1).otherwise(0)).alias("mobile_desktop_ratio")
    )

    # Merge Features
    final_df = basetable.join(wpv_feats, "user_id", "left").join(wss_feats, "user_id", "left").fillna(0) # since NULL values would indicate no activity here, which is a relevant metric also

    return final_df

## 3/ Basetables creation

In [None]:
# Here we apply the function above to create 3 ready-to-use basetables
# We define distinct dates for Train, Val, and Test snapshots.
# We choose dates that have enough history behind them.

# Train Snapshot: "Who will buy in Oct 2023 based on history pre-Oct 2023?"
train_snapshot_date = "2023-10-01"

# Val Snapshot: "Who will buy in Dec 2024 based on history pre-Dec 2024?"
val_snapshot_date = "2024-12-01"

# Test Snapshot: "Who will buy in Jan 2025 based on history pre-Jan 2025?"
test_snapshot_date = "2025-01-01"

# Create the clean basetables
# Note: Pass the RAW dataframes (wss, wpv, orders) because the function handles the filtering internally.
df_train = create_snapshot_basetable(wss, wpv, orders, train_snapshot_date)
df_val = create_snapshot_basetable(wss, wpv, orders, val_snapshot_date)
df_test = create_snapshot_basetable(wss, wpv, orders, test_snapshot_date)

# Calculate Class Weights for the NEW training set
# Both counts come from a single aggregation so the (unpersisted) basetable is only
# computed once instead of once per .count() call.
# Note: `count` and `sum` here are the pyspark.sql.functions versions imported at the top.
_train_counts = df_train.agg(
    count(lit(1)).alias("total"),
    sum(when(col("label") == 1, 1).otherwise(0)).alias("buyers"),
).first()
total_train = _train_counts["total"]
buyer_train = _train_counts["buyers"] or 0  # the sum is NULL when the frame is empty

# Without any buyer the ratio is undefined: fail with an explicit message
if buyer_train == 0:
    raise ValueError(
        f"No buyers found in the training basetable for snapshot {train_snapshot_date}; "
        "cannot compute class weights."
    )

# Each buyer is weighted by (#non-buyers / #buyers); non-buyers keep weight 1.0
balancing_ratio = (total_train - buyer_train) / buyer_train
df_train = df_train.withColumn("class_weight", when(col("label") == 1, balancing_ratio).otherwise(1.0))

print(f"Balancing Ratio: {balancing_ratio:.2f}")
df_train.show(5)

--- Creating Snapshot for 2023-10-01 ---
Features: Data < 2023-10-01
Target:   2023-10-01 <= Order < Column<'date_add('2023-10-01', 30)'>
--- Creating Snapshot for 2024-12-01 ---
Features: Data < 2024-12-01
Target:   2024-12-01 <= Order < Column<'date_add('2024-12-01', 30)'>
--- Creating Snapshot for 2025-01-01 ---
Features: Data < 2025-01-01
Target:   2025-01-01 <= Order < Column<'date_add('2025-01-01', 30)'>
Balancing Ratio: 1292.95
+-------+-----+--------------+------------------+-------------+--------------+---------------+------------+-------------------+-------------------+----------------------+--------------+---------------------+--------------------+--------------------+------------+
|user_id|label|views_corepack|views_techfortress|views_airlite|views_ecoshell|ever_saw_lander|num_sessions|num_repeat_sessions|is_paid_search_user|is_organic_search_user|is_direct_user|days_since_last_visit|customer_tenure_days|mobile_desktop_ratio|class_weight|
+-------+-----+--------------+-----

In [None]:
# Before standardizing, preview a few training rows to check which columns actually need standardizing
# (at this point df_train already carries the class_weight column)
df_train.show(7)

+-------+-----+--------------+------------------+-------------+--------------+---------------+------------+-------------------+-------------------+----------------------+--------------+---------------------+--------------------+--------------------+------------+
|user_id|label|views_corepack|views_techfortress|views_airlite|views_ecoshell|ever_saw_lander|num_sessions|num_repeat_sessions|is_paid_search_user|is_organic_search_user|is_direct_user|days_since_last_visit|customer_tenure_days|mobile_desktop_ratio|class_weight|
+-------+-----+--------------+------------------+-------------+--------------+---------------+------------+-------------------+-------------------+----------------------+--------------+---------------------+--------------------+--------------------+------------+
|    148|    0|             0|                 0|            0|             0|              0|           1|                  0|                  1|                     0|             0|                  560|    

In [None]:
# Verifying class imbalance: label distribution of the train, validation and test basetables
for _split_df in (df_train, df_val, df_test):
    _split_df.groupBy("label").count().show()

+-----+------+
|label| count|
+-----+------+
|    1|    91|
|    0|117658|
+-----+------+

+-----+------+
|label| count|
+-----+------+
|    1|   347|
|    0|319503|
+-----+------+

+-----+------+
|label| count|
+-----+------+
|    1|   421|
|    0|343883|
+-----+------+



In [None]:
# Calculate the balancing ratio based ONLY on training data
# (this repeats the computation done right after the basetables were created;
#  withColumn simply replaces the existing class_weight column)
# Both counts come from a single aggregation so the basetable is only computed once.
# Note: `count` and `sum` here are the pyspark.sql.functions versions imported at the top.
_train_counts = df_train.agg(
    count(lit(1)).alias("total"),
    sum(when(col("label") == 1, 1).otherwise(0)).alias("buyers"),
).first()
total_train = _train_counts["total"]
buyer_train = _train_counts["buyers"] or 0  # the sum is NULL when the frame is empty

# Without any buyer the ratio is undefined: fail with an explicit message
if buyer_train == 0:
    raise ValueError("No buyers found in the training basetable; cannot compute class weights.")

balancing_ratio = (total_train - buyer_train) / buyer_train

# Add the weight column to df_train: buyers get the ratio, non-buyers get 1.0
df_train = df_train.withColumn("class_weight", when(col("label") == 1, balancing_ratio).otherwise(1.0))
print(f"Balancing Ratio: {balancing_ratio:.2f}")

Balancing Ratio: 1292.95


In [None]:
# Standardize the columns that need it, then assemble every predictor into a single
# "features" vector so the tables can be modelled with Spark ML

# A Pipeline guarantees the steps always run in the same order and lets us fit once
# and re-apply the exact same transformation to every split
cols_to_standardize = ["days_since_last_visit", "customer_tenure_days"]

# All remaining columns are passed through unscaled, except the identifier, the label
# and 'class_weight' (a label weight, not a feature)
_non_feature_cols = {"user_id", "label", "class_weight"}
other_features = [
    c for c in df_train.columns
    if c not in cols_to_standardize and c not in _non_feature_cols
]

# Stage 1: pack the columns to scale into a temporary vector
assembler_to_scale = VectorAssembler(outputCol="temp_vector", inputCols=cols_to_standardize)
# Stage 2: center and scale that vector
scaler = StandardScaler(withMean=True, withStd=True, inputCol="temp_vector", outputCol="scaled_vector")
# Stage 3: scaled vector first, then the untouched features
final_assembler = VectorAssembler(outputCol="features", inputCols=["scaled_vector"] + other_features)

scaling_pipeline = Pipeline(stages=[assembler_to_scale, scaler, final_assembler])

# Fit on the training set only, so no validation/test statistics leak into the scaler
scaling_model = scaling_pipeline.fit(df_train)

# Apply the fitted pipeline to every split; only the training table keeps its class weights
_model_cols = ["user_id", "label", "features"]
basetable_train = scaling_model.transform(df_train).select(*_model_cols, "class_weight")
basetable_val = scaling_model.transform(df_val).select(*_model_cols)
basetable_test = scaling_model.transform(df_test).select(*_model_cols)

In [None]:
# Preview the assembled training rows without truncating the feature vectors
basetable_train.show(5, truncate=False)

+-------+-----+--------------------------------------------------------------------------+------------+
|user_id|label|features                                                                  |class_weight|
+-------+-----+--------------------------------------------------------------------------+------------+
|148    |0    |(13,[0,1,7,9],[2.0741468107166474,2.0518961109615814,1.0,1.0])            |1.0         |
|463    |0    |(13,[0,1,7,9],[2.067550406149753,2.04528343523231,1.0,1.0])               |1.0         |
|471    |0    |(13,[0,1,7,9,12],[2.067550406149753,2.04528343523231,1.0,1.0,1.0])        |1.0         |
|496    |0    |(13,[0,1,7,9],[2.0609540015828585,2.0386707595030384,1.0,1.0])            |1.0         |
|833    |0    |(13,[0,1,2,7,9,12],[2.0477611924490695,2.025445408044496,1.0,1.0,1.0,1.0])|1.0         |
+-------+-----+--------------------------------------------------------------------------+------------+
only showing top 5 rows


# IV/ Binary models implementation

--> As we have seen in the basetable creating stage, the basetables are severely imbalanced.

--> Therefore we have to take it into consideration in the way we implement our models.

--> One famous method would be to use SMOTE.

--> However, SMOTE is not as easy to implement in PySpark as it is in the scikit-learn ecosystem (e.g. with imbalanced-learn).

--> One alternative that presents advantages as well is Class Weights.

--> It preserves data integrity (because unlike SMOTE, it does not create new synthetic users).

--> It preserves the size of the dataset, and therefore saves computing time, power and cost.

--> Since it penalizes the model more for misclassifying a buyer than for misclassifying a simple browser, the model becomes able to pick up the small details that lead to a purchase and therefore achieves better recall.

--> We are aware that it could also lead to overfitting to outliers.

--> We are also aware that it could lead to overestimation of the probability to buy, which could lead to some marketing "waste".

--> But we think it is still better than SMOTE because SMOTE could lead to harder interpretation of the results since it produces "artificial" results.

In [None]:
# Here we will define some hyperparameters
# Train one model per hyperparameter combination on the training set
# Then compare the combinations on the validation set and keep the best one
# Finally we will score the selected model on the test set

In [None]:
# Candidate hyperparameter values
reg_params = [0.01, 0.1, 1.0]           # Regularization strength
elastic_net_params = [0.0, 0.5, 1.0]    # 0 = L2 (Ridge), 1 = L1 (Lasso), 0.5 = Mix

# Area under the ROC curve is the selection metric
evaluator = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction", metricName="areaUnderROC")

# Best candidate seen so far
best_auc, best_model, best_params = 0.0, None, {}

print("--- Starting Grid Search ---")

# Every (regParam, elasticNetParam) pair, regularization strength varying slowest
_param_grid = [(r, e) for r in reg_params for e in elastic_net_params]

for reg, enet in _param_grid:
    # Weighted logistic regression: class_weight handles the class imbalance as explained above
    lr = LogisticRegression(
        featuresCol="features",
        labelCol="label",
        weightCol="class_weight",
        regParam=reg,
        elasticNetParam=enet,
    )

    # Fit on the training set, then score the validation set
    model = lr.fit(basetable_train)
    predictions_val = model.transform(basetable_val)
    auc_val = evaluator.evaluate(predictions_val)

    print(f"Params: reg={reg}, elasticNet={enet} -> Val AUC: {auc_val:.4f}")

    # Keep this candidate only when it strictly beats the current best
    if auc_val > best_auc:
        best_auc, best_model = auc_val, model
        best_params = {"regParam": reg, "elasticNetParam": enet}

print("-" * 30)
print(f"Best Parameters found: {best_params}")
print(f"Best Validation AUC: {best_auc:.4f}")

--- Starting Grid Search ---
Params: reg=0.01, elasticNet=0.0 -> Val AUC: 0.9466
Params: reg=0.01, elasticNet=0.5 -> Val AUC: 0.9476
Params: reg=0.01, elasticNet=1.0 -> Val AUC: 0.9484
Params: reg=0.1, elasticNet=0.0 -> Val AUC: 0.9429
Params: reg=0.1, elasticNet=0.5 -> Val AUC: 0.9475
Params: reg=0.1, elasticNet=1.0 -> Val AUC: 0.9464
Params: reg=1.0, elasticNet=0.0 -> Val AUC: 0.9309
Params: reg=1.0, elasticNet=0.5 -> Val AUC: 0.5000
Params: reg=1.0, elasticNet=1.0 -> Val AUC: 0.5000
------------------------------
Best Parameters found: {'regParam': 0.01, 'elasticNetParam': 1.0}
Best Validation AUC: 0.9484


In [None]:
# Score the held-out test set with the model selected on the validation set
predictions_test = best_model.transform(basetable_test)

# AUC on the test set, using the same evaluator as during tuning
auc_test = evaluator.evaluate(predictions_test)
print(f"Final Test Set AUC: {auc_test:.4f}")

# Confusion Matrix & Other Metrics (Precision/Recall)

# Spark's MulticlassMetrics requires RDDs of (prediction, label)
# Both columns are cast to Float because MulticlassMetrics can sometimes be picky with types
_pred_col = col("prediction").cast(FloatType())
_label_col = col("label").cast(FloatType())
prediction_and_labels = predictions_test.select(_pred_col, _label_col).rdd

metrics = MulticlassMetrics(prediction_and_labels)

print("\n--- Confusion Matrix ---")
print(metrics.confusionMatrix().toArray())

print("\n--- Detailed Metrics ---")
# Per-class metrics are reported for the positive class (label 1.0)
_positive = 1.0
_report = [
    ("Accuracy", metrics.accuracy),
    ("Precision (Label 1)", metrics.precision(_positive)),
    ("Recall (Label 1)", metrics.recall(_positive)),
    ("F1 Score (Label 1)", metrics.fMeasure(_positive)),
]
for _metric_name, _metric_value in _report:
    print(f"{_metric_name}: {_metric_value:.4f}")

Final Test Set AUC: 0.9454

--- Confusion Matrix ---
[[289652.  54231.]
 [     0.    421.]]

--- Detailed Metrics ---
Accuracy: 0.8425
Precision (Label 1): 0.0077
Recall (Label 1): 1.0000
F1 Score (Label 1): 0.0153
