In [None]:
# Shared notebook setup. USER_MYSQL_URL, ORDER_MYSQL_URL, PRODUCT_MYSQL_URL and
# MYSQL_PROPERTIES are used later but not defined in this notebook; presumably
# they come from here — verify against common.ipynb.
%run common.ipynb

In [1]:
from pyspark.sql import SparkSession
import random

# Build (or reuse) the Spark session for this job. spark.jars.packages makes Spark
# resolve the MySQL Connector/J artifact (from Maven Central) so that
# spark.read.jdbc can connect to MySQL.
spark = (
    SparkSession.builder
    .appName("churn-prediction")
    .config("spark.jars.packages", "mysql:mysql-connector-java:8.0.33")
    .getOrCreate()
)


SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/home/glue_user/spark/jars/log4j-slf4j-impl-2.17.2.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/home/glue_user/spark/jars/slf4j-reload4j-1.7.36.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/home/glue_user/aws-glue-libs/jars/log4j-slf4j-impl-2.17.2.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/home/glue_user/aws-glue-libs/jars/slf4j-reload4j-1.7.36.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.apache.logging.slf4j.Log4jLoggerFactory]


:: loading settings :: url = jar:file:/home/glue_user/spark/jars/ivy-2.5.0.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/glue_user/.ivy2/cache
The jars for the packages stored in: /home/glue_user/.ivy2/jars
mysql#mysql-connector-java added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-36b06eb3-3159-4e87-ab3c-c0488d8105bc;1.0
	confs: [default]
	found mysql#mysql-connector-java;8.0.33 in central
	found com.mysql#mysql-connector-j;8.0.33 in central
	found com.google.protobuf#protobuf-java;3.21.9 in central
:: resolution report :: resolve 265ms :: artifacts dl 6ms
	:: modules in use:
	com.google.protobuf#protobuf-java;3.21.9 from central in [default]
	com.mysql#mysql-connector-j;8.0.33 from central in [default]
	mysql#mysql-connector-java;8.0.33 from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	-------------------------------------------------------------

In [3]:
# Read the three MySQL source tables into Spark DataFrames over JDBC.
# The URL/property globals are not defined in this notebook; presumably they come
# from the %run common.ipynb setup — verify.
customer_df = spark.read.jdbc(url=USER_MYSQL_URL, table="Customers", properties=MYSQL_PROPERTIES)
order_df = spark.read.jdbc(url=ORDER_MYSQL_URL, table="Orders", properties=MYSQL_PROPERTIES)
products_df = spark.read.jdbc(url=PRODUCT_MYSQL_URL, table="Products", properties=MYSQL_PROPERTIES)

# Register every loaded table as a temp view so Spark SQL cells can query any of
# them. Products is registered too, for consistency with customers/orders.
customer_df.createOrReplaceTempView("customers")
order_df.createOrReplaceTempView("orders")
products_df.createOrReplaceTempView("products")


In [4]:
# Preview: each order shown next to the same customer's previous order date.
# LAG over a per-customer window ordered by order_date returns NULL for a
# customer's first order.
# NOTE(review): the outer query has no ORDER BY, so the row order printed by
# .show() is not guaranteed. Two orders with the same order_date for one
# customer also make LAG's choice between them arbitrary.
spark.sql("""
    SELECT
        customer_id,
        order_id,
        order_date,
        LAG(order_date) OVER (PARTITION BY customer_id ORDER BY order_date) AS prev_order_date
    FROM orders """).show()

[Stage 2:>                                                          (0 + 1) / 1]

+-----------+--------+-------------------+-------------------+
|customer_id|order_id|         order_date|    prev_order_date|
+-----------+--------+-------------------+-------------------+
|          2|    9898|2024-06-14 00:00:00|               null|
|          3|    9426|2023-12-05 00:00:00|               null|
|          4|    7885|2020-11-25 00:00:00|               null|
|          4|     393|2021-07-11 00:00:00|2020-11-25 00:00:00|
|          5|     534|2021-08-04 00:00:00|               null|
|          5|    9077|2022-09-03 00:00:00|2021-08-04 00:00:00|
|          5|    2349|2024-07-15 00:00:00|2022-09-03 00:00:00|
|          6|    9643|2020-09-03 00:00:00|               null|
|          6|    1232|2021-02-18 00:00:00|2020-09-03 00:00:00|
|          8|    7769|2020-11-02 00:00:00|               null|
|          8|    6016|2024-05-31 00:00:00|2020-11-02 00:00:00|
|          8|    5599|2024-07-20 00:00:00|2024-05-31 00:00:00|
|          9|    9277|2020-08-16 00:00:00|             

                                                                                

In [5]:
# Churn-risk report in Spark SQL.
#   customer_activity: every order plus the customer's previous order date (LAG).
#   churn_risk: per customer, the order count, the last order date, the days
#     since that order (relative to current_date), and the mean gap in days
#     between consecutive orders. AVG skips the NULL gap on the first order.
# The query keeps customers whose days since last purchase exceed twice their
# average gap, longest-inactive first.
# NOTE(review): a customer with a single order has a NULL avg_order_gap, so the
# WHERE comparison is NULL and that customer is silently excluded. Confirm this
# is intended, since one-time buyers are common churners.
# NOTE(review): current_date makes the result change day to day, so re-runs are
# not reproducible; consider a fixed reference date.
spark.sql(""" WITH customer_activity AS (
    SELECT
        customer_id,
        order_id,
        order_date,
        LAG(order_date) OVER (PARTITION BY customer_id ORDER BY order_date) AS prev_order_date
    FROM orders
),
churn_risk AS (
    SELECT
        customer_id,
        COUNT(order_id) AS total_orders,
        MAX(order_date) AS last_order_date,
        DATEDIFF(current_date, MAX(order_date)) AS days_since_last_purchase,  
        AVG(DATEDIFF(order_date, prev_order_date)) AS avg_order_gap  
    FROM customer_activity
    GROUP BY customer_id
)
SELECT *
FROM churn_risk
WHERE days_since_last_purchase > (avg_order_gap * 2)  -- Customers inactive for double their average gap
ORDER BY days_since_last_purchase DESC;
 """).show()

+-----------+------------+-------------------+------------------------+-------------+
|customer_id|total_orders|    last_order_date|days_since_last_purchase|avg_order_gap|
+-----------+------------+-------------------+------------------------+-------------+
|       8615|           2|2020-01-22 00:00:00|                    1867|         17.0|
|       8099|           2|2020-03-07 00:00:00|                    1822|         21.0|
|       1343|           2|2020-03-20 00:00:00|                    1809|         17.0|
|       2443|           2|2020-04-03 00:00:00|                    1795|         76.0|
|       1814|           2|2020-04-17 00:00:00|                    1781|         71.0|
|       6126|           2|2020-04-30 00:00:00|                    1768|         17.0|
|       7014|           2|2020-05-14 00:00:00|                    1754|        109.0|
|       3617|           2|2020-05-15 00:00:00|                    1753|         50.0|
|       5367|           2|2020-05-22 00:00:00|        

In [6]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Per-customer window in purchase order. LAG with offset 1 over it yields the
# previous order's date, or NULL on a customer's first order.
window_spec = Window.partitionBy(F.col("customer_id")).orderBy(F.col("order_date"))

customer_activity = order_df.withColumn(
    "prev_order_date",
    F.lag(F.col("order_date"), 1).over(window_spec),
)

# One row per customer, with:
#   - order count;
#   - most recent order date;
#   - days elapsed since that order, relative to current_date;
#   - mean gap in days between consecutive orders (AVG ignores the NULL first gap).
churn_risk = customer_activity.groupBy("customer_id").agg(
    F.count(F.col("order_id")).alias("total_orders"),
    F.max(F.col("order_date")).alias("last_order_date"),
    F.datediff(F.current_date(), F.max(F.col("order_date"))).alias("days_since_last_purchase"),
    F.avg(F.datediff(F.col("order_date"), F.col("prev_order_date"))).alias("avg_order_gap"),
)

# Keep customers idle for more than twice their average gap, longest-idle first.
# A single-order customer has a NULL avg_order_gap, so the comparison is NULL
# and that row is dropped here.
churn_risk_filtered = (
    churn_risk
    .where(F.col("days_since_last_purchase") > F.col("avg_order_gap") * 2)
    .sort(F.col("days_since_last_purchase").desc())
)

churn_risk_filtered.show(n=10)


+-----------+------------+-------------------+------------------------+-------------+
|customer_id|total_orders|    last_order_date|days_since_last_purchase|avg_order_gap|
+-----------+------------+-------------------+------------------------+-------------+
|       8615|           2|2020-01-22 00:00:00|                    1867|         17.0|
|       8099|           2|2020-03-07 00:00:00|                    1822|         21.0|
|       1343|           2|2020-03-20 00:00:00|                    1809|         17.0|
|       2443|           2|2020-04-03 00:00:00|                    1795|         76.0|
|       1814|           2|2020-04-17 00:00:00|                    1781|         71.0|
|       6126|           2|2020-04-30 00:00:00|                    1768|         17.0|
|       7014|           2|2020-05-14 00:00:00|                    1754|        109.0|
|       3617|           2|2020-05-15 00:00:00|                    1753|         50.0|
|       5367|           2|2020-05-22 00:00:00|        

In [7]:
# Target AWS Glue Data Catalog database and table for the churn-risk results.
glue_database = "customer_analytics"
glue_table = "churn_risk"

# S3 location that holds the table's Parquet files.
s3_output_path = "s3://feb2025-training-bucket/analytics/churn_risk/"

# Write the filtered churn-risk rows as Parquet at s3_output_path and register
# them as <glue_database>.<glue_table> in the catalog. mode("overwrite")
# replaces any existing table and its data.
(
    churn_risk_filtered.write
    .format("parquet")
    .mode("overwrite")
    .option("path", s3_output_path)
    .saveAsTable(f"{glue_database}.{glue_table}")
)

# These lines only run if saveAsTable returned without raising.
# The first message previously said "Aggregated sales data", a copy-paste error;
# this cell writes churn-risk data.
print(f"Churn-risk data written to S3: {s3_output_path}")
print(f"Glue table '{glue_database}.{glue_table}' created successfully.")


25/03/03 12:51:30 INFO HiveConf: Found configuration file file:/home/glue_user/spark/conf/hive-site.xml
25/03/03 12:51:32 WARN InstanceMetadataServiceResourceFetcher: Fail to retrieve token 
com.amazonaws.SdkClientException: Failed to connect to service endpoint: 
	at com.amazonaws.internal.EC2ResourceFetcher.doReadResource(EC2ResourceFetcher.java:100)
	at com.amazonaws.internal.InstanceMetadataServiceResourceFetcher.getToken(InstanceMetadataServiceResourceFetcher.java:91)
	at com.amazonaws.internal.InstanceMetadataServiceResourceFetcher.readResource(InstanceMetadataServiceResourceFetcher.java:69)
	at com.amazonaws.internal.EC2ResourceFetcher.readResource(EC2ResourceFetcher.java:66)
	at com.amazonaws.util.EC2MetadataUtils.getItems(EC2MetadataUtils.java:407)
	at com.amazonaws.util.EC2MetadataUtils.getData(EC2MetadataUtils.java:376)
	at com.amazonaws.util.EC2MetadataUtils.getData(EC2MetadataUtils.java:372)
	at com.amazonaws.util.EC2MetadataUtils.getEC2InstanceRegion(EC2MetadataUtils.java

Aggregated sales data written to S3: s3://feb2025-training-bucket/analytics/churn_risk/
Glue table 'customer_analytics.churn_risk' created successfully.
