In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

In [2]:
# Before doing any DataFrame ops:
from pyspark.sql import SparkSession

# Build (or reuse) a Spark session named "Codeway" with a local[*] master
# and spark.driver.memory set to 4g.
builder = SparkSession.builder.appName("Codeway").master("local[*]")
builder = builder.config("spark.driver.memory", "4g")
spark = builder.getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/07/25 13:14:55 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/07/25 13:14:57 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [3]:
# Load the raw event-level dataset from its Parquet file.
parquet_path = '/Users/macbookpro/Desktop/Codeway Case Study/ltvprediction_case.parquet'
df = spark.read.parquet(parquet_path)

                                                                                

# Pre-processing

In [4]:
# Parse both date columns with the yyyy-MM-dd pattern so they carry Spark's date type.
for date_col in ("first_event_date", "event_date"):
    df = df.withColumn(date_col, F.to_date(F.col(date_col), "yyyy-MM-dd"))

In [5]:
# Earliest event_date found anywhere in the dataset.
earliest = F.min("event_date").alias("earliest_event_date")
df.select(earliest).show()



+-------------------+
|earliest_event_date|
+-------------------+
|         2022-08-19|
+-------------------+



                                                                                

In [6]:
# Latest event_date found anywhere in the dataset.
latest = F.max("event_date").alias("latest_event_date")
df.select(latest).show()



+-----------------+
|latest_event_date|
+-----------------+
|       2022-12-31|
+-----------------+



                                                                                

In [9]:
# Users whose rows carry more than one distinct first_year_revenue value.
fyr_per_user = df.groupBy("user_id").agg(
    F.countDistinct("first_year_revenue").alias("fyr_count")
)
users_multi_fyr = fyr_per_user.where(F.col("fyr_count") > 1).select("user_id")

# Print the matching user_ids without truncating them.
users_multi_fyr.show(truncate=False)

                                                                                

+--------------+
|user_id       |
+--------------+
|66298D8D5959-7|
|037B19192323-5|
|113C1010A5A5-8|
|3042EBEB6666-8|
|CBFC2C2C4141-F|
|E84DEAEAD9D9-3|
+--------------+



There are 6 users with multiple distinct "first_year_revenue" values.

In [10]:
# Print column names, types and nullability of df after the date parsing above.
df.printSchema()

root
 |-- user_id: string (nullable = true)
 |-- first_event_date: date (nullable = true)
 |-- operating_system: string (nullable = true)
 |-- country: string (nullable = true)
 |-- event_name: string (nullable = true)
 |-- event_time: long (nullable = true)
 |-- event_date: date (nullable = true)
 |-- revenue: double (nullable = true)
 |-- first_year_revenue: double (nullable = true)



For each user, return the row where "event_date" is the latest among all the events of the user

In [11]:
from pyspark.sql.window import Window

# 1. Window for latest event per user.
#    event_date only has day granularity, so several events of one user can share
#    the latest date. Ordering by event_date alone lets row_number() pick an
#    arbitrary row among them, which may differ between runs. event_time is used as
#    a tie-breaker so the most recent event of that day is chosen.
#    (Rows with an identical event_time can still tie.)
w = Window.partitionBy("user_id").orderBy(
    F.col("event_date").desc(),
    F.col("event_time").desc(),
)

# 2. Pick each user's most recent event and expose its date as last_event_date
latest_event = (
    df
    .withColumn("rn", F.row_number().over(w))
    .filter(F.col("rn") == 1)
    .drop("rn")
    .withColumn("last_event_date", F.col("event_date"))
)

In [12]:
# Sanity check: count the user_ids that occur on more than one row of latest_event.
rows_per_user = latest_event.groupBy("user_id").count()
duplicate_count = rows_per_user.where(F.col("count") > 1).count()

message = (
    "No duplicate user_id values in latest_event."
    if duplicate_count == 0
    else f"Found {duplicate_count} user_id(s) with duplicates in latest_event."
)
print(message)



No duplicate user_id values in latest_event.


                                                                                

This part is redundant, but just to make sure I still wanted to check: \
Only keep the rows where: \
"first_event_date" + 15 days >= "event_date"

In [16]:
# Keep events dated no later than first_event_date + 15 days (boundary inclusive).
within_window = F.col("event_date") <= F.date_add(F.col("first_event_date"), 15)
df_15 = df.where(within_window)

In [17]:
# How many rows did the 15-day filter remove?
orig_count, first_count = df.count(), df_15.count()
rows_eliminated = orig_count - first_count

print(f"Rows eliminated from df to df_15: {rows_eliminated}")



Rows eliminated from df to df_15: 0


                                                                                

It turns out that the data INDEED does not include any records for any user after the first 15 days of their registration.

Check: \
"first_event_date" <= "event_date" \
"revenue" >= 0

In [18]:
# Count the rows breaking each sanity rule:
#   - first_event_date later than event_date
#   - negative revenue
date_order_broken = F.col("first_event_date") > F.col("event_date")
revenue_below_zero = F.col("revenue") < 0

invalid_counts = df_15.agg(
    F.sum(F.when(date_order_broken, 1).otherwise(0)).alias("invalid_date_order"),
    F.sum(F.when(revenue_below_zero, 1).otherwise(0)).alias("negative_revenue"),
)

invalid_counts.show()



+------------------+----------------+
|invalid_date_order|negative_revenue|
+------------------+----------------+
|                 0|            2040|
+------------------+----------------+



                                                                                

Check: \
"revenue" = Null \
"revenue" = 0

In [19]:
# Count rows whose revenue is missing and rows whose revenue is exactly zero.
revenue_missing = F.col("revenue").isNull()
revenue_zero = F.col("revenue") == 0

revenue_checks = df_15.agg(
    F.sum(F.when(revenue_missing, 1).otherwise(0)).alias("revenue_null_count"),
    F.sum(F.when(revenue_zero, 1).otherwise(0)).alias("revenue_zero_count"),
)
revenue_checks.show()



+------------------+------------------+
|revenue_null_count|revenue_zero_count|
+------------------+------------------+
|           6818187|                96|
+------------------+------------------+



                                                                                

Replace "revenue" = null with 0 \
Remove rows with "revenue" < 0 (by definition, revenue is not negative).

In [20]:
# Treat missing revenue as 0, then drop every row whose revenue is negative.
revenue_filled = df_15.withColumn(
    "revenue", F.coalesce(F.col("revenue"), F.lit(0))
)
df_15_clean = revenue_filled.where(F.col("revenue") >= 0)

In [21]:
# Re-run the sanity counts on the cleaned frame.
date_order_broken = F.col("first_event_date") > F.col("event_date")
revenue_below_zero = F.col("revenue") < 0

invalid_counts_1 = df_15_clean.agg(
    F.sum(F.when(date_order_broken, 1).otherwise(0)).alias("invalid_date_order"),
    F.sum(F.when(revenue_below_zero, 1).otherwise(0)).alias("negative_revenue"),
)

invalid_counts_1.show()



+------------------+----------------+
|invalid_date_order|negative_revenue|
+------------------+----------------+
|                 0|               0|
+------------------+----------------+



                                                                                

In [22]:
# Print the schema of the cleaned frame.
df_15_clean.printSchema()

root
 |-- user_id: string (nullable = true)
 |-- first_event_date: date (nullable = true)
 |-- operating_system: string (nullable = true)
 |-- country: string (nullable = true)
 |-- event_name: string (nullable = true)
 |-- event_time: long (nullable = true)
 |-- event_date: date (nullable = true)
 |-- revenue: double (nullable = false)
 |-- first_year_revenue: double (nullable = true)



# Additional Pre-processing before User-level Analysis

In [23]:
# Per user: number of distinct first_event_date values across their rows.
user_first = df_15_clean.groupBy("user_id").agg(
    F.countDistinct("first_event_date").alias("first_count")
)

In [24]:
# Number of users that have more than one distinct first_event_date.
user_first.where(F.col("first_count") > 1).count()

                                                                                

36

In [25]:
# 1. Users whose rows all share a single distinct first_event_date
first_date_counts = df_15_clean.groupBy("user_id").agg(
    F.countDistinct("first_event_date").alias("date_cnt")
)
single_date_users = first_date_counts.where(F.col("date_cnt") == 1).select("user_id")

# 2. The inner join keeps only the rows of those users
df_15_clean_1 = df_15_clean.join(single_date_users, "user_id", "inner")

In [26]:
# Same per-user distinct first_event_date count, now on the filtered frame.
user_first_1 = df_15_clean_1.groupBy("user_id").agg(
    F.countDistinct("first_event_date").alias("first_count")
)

In [27]:
# Users still having more than one distinct first_event_date after the filter.
user_first_1.where(F.col("first_count") > 1).count()

                                                                                

0

In [28]:
# 1. Distinct operating systems and countries seen per user
user_counts = df_15_clean_1.groupBy("user_id").agg(
    F.countDistinct("operating_system").alias("os_count"),
    F.countDistinct("country").alias("country_count"),
)

# 2. How many users have several OSes, several countries, or both
multi_os = F.col("os_count") > 1
multi_country = F.col("country_count") > 1

summary_counts = user_counts.agg(
    F.sum(F.when(multi_os, 1).otherwise(0)).alias("users_multi_os"),
    F.sum(F.when(multi_country, 1).otherwise(0)).alias("users_multi_country"),
    F.sum(F.when(multi_os & multi_country, 1).otherwise(0)).alias(
        "users_multi_os_and_country"
    ),
)

summary_counts.show()



+--------------+-------------------+--------------------------+
|users_multi_os|users_multi_country|users_multi_os_and_country|
+--------------+-------------------+--------------------------+
|             0|                  0|                         0|
+--------------+-------------------+--------------------------+



                                                                                

There were 36 users with multi-country and/or multi-OS, which happened to be exactly the users with multiple "first_event_date". This implies that the application stores the "first_event_date" for each user per the "operating_system" they use (not per "country", though).

Use seasons in the algorithm:

In [29]:
# Map the month of first_event_date to a season label:
#   Dec-Feb -> winter, Mar-May -> spring, Jun-Aug -> summer, anything else -> autumn.
first_month = F.month("first_event_date")
season_label = (
    F.when(first_month.isin(12, 1, 2), "winter")
    .when(first_month.isin(3, 4, 5), "spring")
    .when(first_month.isin(6, 7, 8), "summer")
    .otherwise("autumn")
)
df_season = df_15_clean_1.withColumn("season", season_label)

In [30]:
# Drop rows that have no country.
has_country = F.col("country").isNotNull()
df_season = df_season.where(has_country)

In [31]:
# Null count for every column of df_season, shown as a single row.
null_count_exprs = [
    F.sum(F.when(F.col(name).isNull(), 1).otherwise(0)).alias(name)
    for name in df_season.columns
]
null_counts = df_season.agg(*null_count_exprs)

null_counts.show()



+-------+----------------+----------------+-------+----------+----------+----------+-------+------------------+------+
|user_id|first_event_date|operating_system|country|event_name|event_time|event_date|revenue|first_year_revenue|season|
+-------+----------------+----------------+-------+----------+----------+----------+-------+------------------+------+
|      0|               0|               0|      0|         0|         0|         0|      0|                 0|     0|
+-------+----------------+----------------+-------+----------+----------+----------+-------+------------------+------+



                                                                                

In [32]:
# Users in df_season whose rows carry more than one distinct first_year_revenue value.
fyr_per_user = df_season.groupBy("user_id").agg(
    F.countDistinct("first_year_revenue").alias("fyr_count")
)
users_multi_fyr = fyr_per_user.where(F.col("fyr_count") > 1).select("user_id")

# Print the matching user_ids without truncating them.
users_multi_fyr.show(truncate=False)



+-------+
|user_id|
+-------+
+-------+



                                                                                

No user has multiple "first_year_revenue" values.

In [27]:
# Print the schema of df_season, which now includes the derived season column.
df_season.printSchema()

root
 |-- user_id: string (nullable = true)
 |-- first_event_date: date (nullable = true)
 |-- operating_system: string (nullable = true)
 |-- country: string (nullable = true)
 |-- event_name: string (nullable = true)
 |-- event_time: long (nullable = true)
 |-- event_date: date (nullable = true)
 |-- revenue: double (nullable = false)
 |-- first_year_revenue: double (nullable = true)
 |-- season: string (nullable = false)



# Getting ready for the analysis

Extract “hour of day” from event_time (assumes event_time is UNIX seconds)

In [33]:
# Derive a timestamp and the hour of day for every event.
#
# Exact duplicate rows are dropped first. The per-user aggregates and event counts
# built from df_events would otherwise count a duplicated event twice, while the
# static user attributes are taken from a df_season that is de-duplicated only
# later in the notebook. Reassigning df_season later does not change frames that
# were already derived from it.
#
# NOTE(review): from_unixtime interprets event_time as seconds since the epoch and
# formats it in the Spark session time zone — confirm both the unit and the zone.
df_events = (
    df_season
    .dropDuplicates()
    .withColumn("event_ts", F.from_unixtime(F.col("event_time")))
    .withColumn("event_hour", F.hour(F.col("event_ts")))
)

Compute per-user aggregates: total_revenue, avg_event_hour, stickiness_ratio

In [34]:
# One row per user: summed revenue, mean event hour, and the number of distinct
# active dates divided by 15.
# NOTE(review): the earlier filter `first_event_date + 15 >= event_date` admits 16
# calendar dates (offsets 0..15), so stickiness_ratio could exceed 1 — confirm the
# intended denominator.
total_revenue_expr = F.sum("revenue").alias("total_revenue")
avg_hour_expr = F.avg("event_hour").alias("avg_event_hour")
stickiness_expr = (F.countDistinct("event_date") / F.lit(15)).alias("stickiness_ratio")

user_agg = df_events.groupBy("user_id").agg(
    total_revenue_expr,
    avg_hour_expr,
    stickiness_expr,
)

Pivot event_name into per-user event-count features

In [35]:
# List the distinct event types present in the data.
event_names = df_events.select("event_name").distinct()
event_names.show()



+--------------+
|    event_name|
+--------------+
|       paywall|
|auto_renew_off|
|    free_trial|
|     subscribe|
|        refund|
|       renewal|
+--------------+



                                                                                

In [36]:
# One row per user and one column per event_name, holding that user's event count.
events_by_user = df_events.groupBy("user_id").pivot("event_name").count()
# Event types a user never triggered come out as null; replace them with 0.
user_events = events_by_user.fillna(0)

                                                                                

In [37]:
# Preview the first 10 rows of the per-user event counts.
user_events.show(10)

[Stage 174:>                                                        (0 + 1) / 1]

+--------------+--------------+----------+-------+------+-------+---------+
|       user_id|auto_renew_off|free_trial|paywall|refund|renewal|subscribe|
+--------------+--------------+----------+-------+------+-------+---------+
|00004B4BB8B8-1|             0|         0|      1|     0|      0|        0|
|0000CFCF6B6B-E|             0|         1|      1|     0|      1|        1|
|00013737EEEE-2|             0|         1|      1|     0|      1|        1|
|0001BCBCDADA-0|             0|         0|      1|     0|      0|        0|
|00023A3A5959-6|             1|         1|      1|     0|      0|        0|
|00024D4DD7D7-6|             0|         0|      1|     0|      0|        0|
|00029191ACAC-0|             0|         0|      1|     0|      0|        0|
|000358581B1B-F|             0|         0|      2|     0|      0|        0|
|0003AEAE3F3F-F|             0|         0|      6|     0|      0|        0|
|0003CECE2424-B|             0|         0|      4|     0|      0|        0|
+-----------

                                                                                

Check whether there are any users who are active in more than one season (highly unlikely given our 15-day data, but just to make sure).

In [38]:
# 1. Users whose rows carry more than one distinct season label
season_counts = df_season.groupBy("user_id").agg(
    F.countDistinct("season").alias("season_count")
)
multi_season_users = season_counts.where(F.col("season_count") > 1).select("user_id")

# 2. All df_season rows belonging to those users
df_season_multi = df_season.join(multi_season_users, "user_id", "inner")

In [39]:
# Show up to 20 rows of users with more than one season label.
df_season_multi.show(20)



+-------+----------------+----------------+-------+----------+----------+----------+-------+------------------+------+
|user_id|first_event_date|operating_system|country|event_name|event_time|event_date|revenue|first_year_revenue|season|
+-------+----------------+----------------+-------+----------+----------+----------+-------+------------------+------+
+-------+----------------+----------------+-------+----------+----------+----------+-------+------------------+------+



                                                                                

No user is active during more than one season.

Bring in your static features (OS, country, season)

In [40]:
# List the season labels that actually occur in the data.
seasons_present = df_season.select("season").distinct()
seasons_present.show(10)



+------+
|season|
+------+
|winter|
|summer|
|autumn|
+------+



                                                                                

In [42]:
# Remove rows that are identical across all columns.
# Frames already derived from the previous df_season are not affected by this reassignment.
df_season = df_season.distinct()

In [43]:
# Compare the number of distinct rows (over all columns) with the total row count.
unique_rows = df_season.distinct().count()
total_rows = df_season.count()
print(f"Distinct rows: {unique_rows}")
print(f"Total rows:    {total_rows}")

                                                                                

Distinct rows: 6974445




Total rows:    6974445


                                                                                

In [44]:
# Per-user static attributes; distinct() collapses the event rows into unique combinations.
static_columns = [
    "user_id",
    "operating_system",
    "country",
    "season",
    "first_event_date",
    "first_year_revenue",
]
user_static = df_season.select(*static_columns).distinct()

Drop exact duplicates across all columns

Join everything into one features DataFrame

In [45]:
# Assemble the feature table on user_id:
# static attributes, then aggregates (inner join), then event counts (left join).
static_with_agg = user_static.join(user_agg, "user_id", "inner")
df_features = static_with_agg.join(user_events, "user_id", "left")

In [46]:
# Print the schema of the assembled feature table.
df_features.printSchema()

root
 |-- user_id: string (nullable = true)
 |-- operating_system: string (nullable = true)
 |-- country: string (nullable = true)
 |-- season: string (nullable = false)
 |-- first_event_date: date (nullable = true)
 |-- first_year_revenue: double (nullable = true)
 |-- total_revenue: double (nullable = true)
 |-- avg_event_hour: double (nullable = true)
 |-- stickiness_ratio: double (nullable = true)
 |-- auto_renew_off: long (nullable = true)
 |-- free_trial: long (nullable = true)
 |-- paywall: long (nullable = true)
 |-- refund: long (nullable = true)
 |-- renewal: long (nullable = true)
 |-- subscribe: long (nullable = true)



# Write the dataframe

In [47]:
# Number of rows in the feature table.
df_features.count()

                                                                                

2198579

In [48]:
# Number of partitions of the feature table's underlying RDD.
df_features.rdd.getNumPartitions()



9

In [49]:
# Report the Spark master and the configured driver memory.
spark_context = spark.sparkContext
print("master =", spark_context.master)
print("driver.memory =", spark_context.getConf().get("spark.driver.memory"))

master = local[*]
driver.memory = 4g


In [50]:
# Repartition into 50 partitions and write the features as JSON,
# overwriting any existing output at the target path.
output_path = "/Users/macbookpro/PyCharmMiscProject/df_features_json"
features_writer = df_features.repartition(50).write.mode("overwrite")
features_writer.json(output_path)

                                                                                