In [26]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, DateType
from datetime import datetime
from pyspark.sql.functions import *
from pyspark.sql.window import Window
# Initialize Spark Session
spark = SparkSession.builder.appName("PySparkTransactions").getOrCreate()

# Function to Convert String to Date
def parse_date(date_str):
    return datetime.strptime(date_str, "%Y-%m-%d").date()

# Define Schema for Transactions Table
transactions_schema = StructType([
    StructField("order_id", IntegerType(), True),
    StructField("cust_id", IntegerType(), True),
    StructField("order_date", DateType(), True),
    StructField("amount", IntegerType(), True)
])

# Transactions Data with Converted Dates
transactions_data = [
    (1, 1, parse_date('2020-01-15'), 150),
    (2, 1, parse_date('2020-02-10'), 150),
    (3, 2, parse_date('2020-01-16'), 150),
    (4, 2, parse_date('2020-02-25'), 150),
    (5, 3, parse_date('2020-01-10'), 150),
    (6, 3, parse_date('2020-02-20'), 150),
    (7, 4, parse_date('2020-01-20'), 150),
    (8, 5, parse_date('2020-02-20'), 150),
    (9, 5, parse_date('2020-03-20'), 150)

]

# Create Transactions DataFrame
transactions_df = spark.createDataFrame(transactions_data, schema=transactions_schema)
transactions_df.createOrReplaceTempView("transactions")

# Verify Transactions Table
spark.sql("SELECT * FROM transactions").show()


+--------+-------+----------+------+
|order_id|cust_id|order_date|amount|
+--------+-------+----------+------+
|       1|      1|2020-01-15|   150|
|       2|      1|2020-02-10|   150|
|       3|      2|2020-01-16|   150|
|       4|      2|2020-02-25|   150|
|       5|      3|2020-01-10|   150|
|       6|      3|2020-02-20|   150|
|       7|      4|2020-01-20|   150|
|       8|      5|2020-02-20|   150|
|       9|      5|2020-03-20|   150|
+--------+-------+----------+------+



In [23]:
spark.sql(
"""
    select post_month, count(*) from (select *, post_month - month as diff from (
    select *, month(order_date) as month,
    lead(month(order_date)) over(partition by cust_id order by order_date) as post_month
    from transactions)) where diff = 1
    group by post_month
    
""").show()

+----------+--------+
|post_month|count(1)|
+----------+--------+
|         3|       1|
|         2|       3|
+----------+--------+



In [41]:
spark.sql(
"""
WITH monthly_orders AS (
    SELECT 
        cust_id, 
        month(order_date) AS month,
        LEAD(month(order_date)) OVER (PARTITION BY cust_id ORDER BY order_date) AS post_month
    FROM transactions
)
SELECT 
    post_month, 
    COUNT(*) AS retained_customers
FROM monthly_orders
WHERE post_month - month = 1
GROUP BY post_month
ORDER BY post_month;

""").show()

+----------+------------------+
|post_month|retained_customers|
+----------+------------------+
|         2|                 3|
|         3|                 1|
+----------+------------------+



In [37]:
window_spec = Window.partitionBy(col("cust_id")).orderBy(col("order_date"))
monthly_orders = transactions_df.select("cust_id", 
                        month(col("order_date")).alias("month"), 
                        lead(month(col("order_date"))).over(window_spec).alias("post_month"))
                                        

In [40]:
monthly_orders.filter(col("post_month") - col("month") == 1).groupBy(col("post_month")).agg(
    count("*").alias("retained_customers")
).show()

+----------+------------------+
|post_month|retained_customers|
+----------+------------------+
|         3|                 1|
|         2|                 3|
+----------+------------------+

