In [2]:
import os

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

In [3]:
# Create (or reuse) the notebook's Spark session: local mode on all cores,
# application name "Codeway", with spark.driver.memory set to "4g".
session_builder = SparkSession.builder.master("local[*]").appName("Codeway")
session_builder = session_builder.config("spark.driver.memory", "4g")
spark = session_builder.getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/07/26 21:30:09 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Read the dataframe

In [4]:
# Directory holding the notebook's input data.  It can be overridden with the
# CODEWAY_DATA_DIR environment variable so the notebook is not tied to one
# machine; the default is the location that was previously hard-coded.
DATA_DIR = os.environ.get("CODEWAY_DATA_DIR", "/Users/macbookpro/PyCharmMiscProject")

# Load the user-level feature table from JSON (no explicit schema is passed,
# so the column types are whatever Spark infers from the files).
df_features = spark.read.json(os.path.join(DATA_DIR, "df_features_json"))

                                                                                

In [4]:
# Print the column names, types and nullability of the loaded feature table.
df_features.printSchema()

root
 |-- auto_renew_off: long (nullable = true)
 |-- avg_event_hour: double (nullable = true)
 |-- country: string (nullable = true)
 |-- first_event_date: string (nullable = true)
 |-- first_year_revenue: double (nullable = true)
 |-- free_trial: long (nullable = true)
 |-- operating_system: string (nullable = true)
 |-- paywall: long (nullable = true)
 |-- refund: long (nullable = true)
 |-- renewal: long (nullable = true)
 |-- season: string (nullable = true)
 |-- stickiness_ratio: double (nullable = true)
 |-- subscribe: long (nullable = true)
 |-- total_revenue: double (nullable = true)
 |-- user_id: string (nullable = true)



In [5]:
# Count rows per user_id and keep only the ids that occur more than once.
rows_per_user = df_features.groupBy("user_id").agg(F.count("*").alias("cnt"))
dup_df = rows_per_user.filter("cnt > 1")

# Report the outcome; the offending ids are listed only when some exist.
dup_count = dup_df.count()
if dup_count != 0:
    print(f"⚠️ Found {dup_count} user_id(s) appearing more than once:")
    dup_df.show(truncate=False)
else:
    print("✅ All user_id values are unique.")

[Stage 4:>                                                          (0 + 8) / 8]

✅ All user_id values are unique.


                                                                                

Build an expression for each column that sums 1 whenever it's null

In [6]:
# Build one aggregate per column: each NULL cell contributes 1, every other
# cell contributes 0, and the sum is reported under the column's own name.
null_count_exprs = []
for col_name in df_features.columns:
    null_flag = F.when(F.col(col_name).isNull(), 1).otherwise(0)
    null_count_exprs.append(F.sum(null_flag).alias(col_name))

# Single-row frame holding the NULL count of every column (evaluated lazily).
null_counts = df_features.select(null_count_exprs)

In [7]:
# Run the aggregation and print the per-column NULL counts without truncating values.
null_counts.show(truncate=False)



+--------------+--------------+-------+----------------+------------------+----------+----------------+-------+------+-------+------+----------------+---------+-------------+-------+
|auto_renew_off|avg_event_hour|country|first_event_date|first_year_revenue|free_trial|operating_system|paywall|refund|renewal|season|stickiness_ratio|subscribe|total_revenue|user_id|
+--------------+--------------+-------+----------------+------------------+----------+----------------+-------+------+-------+------+----------------+---------+-------------+-------+
|0             |0             |0      |0               |0                 |0         |0               |0      |0     |0      |0     |0               |0        |0            |0      |
+--------------+--------------+-------+----------------+------------------+----------+----------------+-------+------+-------+------+----------------+---------+-------------+-------+



                                                                                

# Pre-processing for Cohort-level Analysis

In [8]:
# Per-cohort diversity check: for every first_event_date, count how many
# distinct countries, operating systems and seasons appear among its users.
check = (
    df_features
      .groupBy("first_event_date")
      .agg(
          F.countDistinct("country").alias("n_countries"),
          F.countDistinct("operating_system").alias("n_operating_systems"),
          F.countDistinct("season").alias("n_seasons"),
      )
)

# `check` is lazy: without an action nothing is computed and the claim that a
# cohort maps to a single season would go unverified.  Report (without
# failing) how many cohorts contain more than one distinct season.
multi_season_cohorts = check.filter(F.col("n_seasons") > 1).count()
print(f"Cohorts spanning more than one season: {multi_season_cohorts}")

Each cohort corresponds to a single `first_event_date`, so it falls in exactly one season, and one season value is enough to represent the whole cohort.

In [5]:
from pyspark.sql.window import Window

Normalize any casing of “ios” (e.g. “IOS”, “Ios”) to “iOS” and leave every other value untouched

In [6]:
# Replace operating_system with a normalised version: values whose lower-cased
# form is "ios" become "iOS"; everything else keeps its current value.
os_name = F.col("operating_system")
normalized_os = F.when(F.lower(os_name) == "ios", F.lit("iOS")).otherwise(os_name)
df_features = df_features.withColumn("operating_system", normalized_os)

2) Aggregate to get cohort-level stats

In [7]:
# Cohort-level statistics, one output row per first_event_date.
# Each entry is (aggregate function, source column, output column); the list
# order is the output column order.
cohort_stats = [
    # F.first picks an arbitrary row's season; this relies on every cohort
    # having a single season -- NOTE(review): confirm against the season check.
    (F.first,  "season",             "cohort_season"),
    (F.count,  "user_id",            "cohort_size"),

    # user-level stats averaged up
    (F.avg,    "stickiness_ratio",   "avg_stickiness_ratio"),
    (F.avg,    "auto_renew_off",     "avg_auto_renew_off"),
    (F.stddev, "auto_renew_off",     "std_auto_renew_off"),
    (F.avg,    "free_trial",         "avg_free_trial"),
    (F.stddev, "free_trial",         "std_free_trial"),
    (F.avg,    "paywall",            "avg_paywall"),
    (F.stddev, "paywall",            "std_paywall"),
    (F.avg,    "refund",             "avg_refund"),
    (F.stddev, "refund",             "std_refund"),
    (F.avg,    "renewal",            "avg_renewal"),
    (F.stddev, "renewal",            "std_renewal"),
    (F.avg,    "subscribe",          "avg_subscribe"),
    (F.stddev, "subscribe",          "std_subscribe"),

    # engagement hour & revenues
    (F.avg,    "avg_event_hour",     "mean_event_hour"),
    (F.stddev, "avg_event_hour",     "std_event_hour"),
    (F.avg,    "total_revenue",      "mean_revenue_15d"),
    (F.stddev, "total_revenue",      "std_revenue_15d"),
    (F.avg,    "first_year_revenue", "mean_revenue_1y"),
    (F.stddev, "first_year_revenue", "std_revenue_1y"),
]

cohort_exprs = [agg_fn(source).alias(target) for agg_fn, source, target in cohort_stats]
base_agg = df_features.groupBy("first_event_date").agg(*cohort_exprs)

3. Compute per-cohort OS counts and each OS's share of the cohort (`os_pct` is a fraction between 0 and 1, not a percentage)

In [8]:
# Number of users for every (cohort, operating system) pair.
os_counts = df_features.groupBy("first_event_date", "operating_system").agg(
    F.count("*").alias("os_count")
)

# Attach each cohort's size and express the OS count as a share of the cohort
# (os_count / cohort_size).
cohort_sizes = base_agg.select("first_event_date", "cohort_size")
os_share = F.col("os_count") / F.col("cohort_size")
os_pct = os_counts.join(cohort_sizes, on="first_event_date").withColumn("os_pct", os_share)

4. Pivot to wide format: one column per OS holding that OS's share of the cohort

In [9]:
# Wide table of OS shares: one row per cohort, one column per operating system.
# os_pct has a single row per (cohort, OS) pair, so F.first just picks that value.
os_pivot = (
    os_pct
      .groupBy("first_event_date")
      .pivot("operating_system")
      .agg(F.first("os_pct"))
      # A cohort with no users on a given OS has no source row, so the pivot
      # yields NULL there; the actual share is 0.  Only the numeric (share)
      # columns are affected -- the string key column is left alone.
      .na.fill(0.0)
)

                                                                                

5) Join base_agg with OS-pct pivot

In [10]:
# Left join on the cohort key: every row of base_agg is kept and gains the
# per-OS share columns from os_pivot.
cohort_agg = base_agg.join(os_pivot, ["first_event_date"], "left")

6) Add 1-based cohort_index ordered by date

In [15]:
# Print the schema of the joined cohort table (statistics plus per-OS share columns).
cohort_agg.printSchema()

root
 |-- first_event_date: string (nullable = true)
 |-- cohort_season: string (nullable = true)
 |-- cohort_size: long (nullable = false)
 |-- avg_auto_renew_off: double (nullable = true)
 |-- std_auto_renew_off: double (nullable = true)
 |-- avg_free_trial: double (nullable = true)
 |-- std_free_trial: double (nullable = true)
 |-- avg_paywall: double (nullable = true)
 |-- std_paywall: double (nullable = true)
 |-- avg_refund: double (nullable = true)
 |-- std_refund: double (nullable = true)
 |-- avg_renewal: double (nullable = true)
 |-- std_renewal: double (nullable = true)
 |-- avg_subscribe: double (nullable = true)
 |-- std_subscribe: double (nullable = true)
 |-- mean_event_hour: double (nullable = true)
 |-- std_event_hour: double (nullable = true)
 |-- mean_revenue_15d: double (nullable = true)
 |-- std_revenue_15d: double (nullable = true)
 |-- mean_revenue_1y: double (nullable = true)
 |-- std_revenue_1y: double (nullable = true)
 |-- iOS: double (nullable = true)
 |-- i

In [11]:
# Number the cohorts 1..N in first_event_date order.
# The window has no partitionBy, so Spark moves all rows to one partition to
# assign the numbers (hence the WindowExec warning).
# NOTE(review): first_event_date is a string column, so the order is lexical;
# it is chronological only if the date format sorts that way (e.g. yyyy-MM-dd)
# -- confirm.
w = Window.orderBy("first_event_date")
df_with_cohort = (
    cohort_agg
      .withColumn("cohort_index", F.row_number().over(w))
      # Several later actions (null check, partition count, write) consume this
      # frame; caching avoids re-running the joins, the pivot and the
      # single-partition window for each of them.
      .cache()
)

In [17]:
# Print the schema again to confirm the cohort_index column was added.
df_with_cohort.printSchema()

root
 |-- first_event_date: string (nullable = true)
 |-- cohort_season: string (nullable = true)
 |-- cohort_size: long (nullable = false)
 |-- avg_auto_renew_off: double (nullable = true)
 |-- std_auto_renew_off: double (nullable = true)
 |-- avg_free_trial: double (nullable = true)
 |-- std_free_trial: double (nullable = true)
 |-- avg_paywall: double (nullable = true)
 |-- std_paywall: double (nullable = true)
 |-- avg_refund: double (nullable = true)
 |-- std_refund: double (nullable = true)
 |-- avg_renewal: double (nullable = true)
 |-- std_renewal: double (nullable = true)
 |-- avg_subscribe: double (nullable = true)
 |-- std_subscribe: double (nullable = true)
 |-- mean_event_hour: double (nullable = true)
 |-- std_event_hour: double (nullable = true)
 |-- mean_revenue_15d: double (nullable = true)
 |-- std_revenue_15d: double (nullable = true)
 |-- mean_revenue_1y: double (nullable = true)
 |-- std_revenue_1y: double (nullable = true)
 |-- iOS: double (nullable = true)
 |-- i

Build an aggregate of null‐counts per column

In [13]:
def _null_total(col_name):
    """Return an aggregate counting the NULL cells of `col_name`, aliased to that name."""
    null_flag = F.when(F.col(col_name).isNull(), 1).otherwise(0)
    return F.sum(null_flag).alias(col_name)


# Single-row frame with the NULL count of every cohort-level column.
null_counts = df_with_cohort.select(
    [_null_total(col_name) for col_name in df_with_cohort.columns]
)

In [14]:
# Run the aggregation and print the per-column NULL counts of the cohort
# table without truncating values.
null_counts.show(truncate=False)

25/07/26 21:31:11 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
[Stage 29:>                                                         (0 + 1) / 1]

+----------------+-------------+-----------+--------------------+------------------+------------------+--------------+--------------+-----------+-----------+----------+----------+-----------+-----------+-------------+-------------+---------------+--------------+----------------+---------------+---------------+--------------+---+------+------------+
|first_event_date|cohort_season|cohort_size|avg_stickiness_ratio|avg_auto_renew_off|std_auto_renew_off|avg_free_trial|std_free_trial|avg_paywall|std_paywall|avg_refund|std_refund|avg_renewal|std_renewal|avg_subscribe|std_subscribe|mean_event_hour|std_event_hour|mean_revenue_15d|std_revenue_15d|mean_revenue_1y|std_revenue_1y|iOS|iPadOS|cohort_index|
+----------------+-------------+-----------+--------------------+------------------+------------------+--------------+--------------+-----------+-----------+----------+----------+-----------+-----------+-------------+-------------+---------------+--------------+----------------+---------------+---

                                                                                

# Write

In [15]:
# 1) Redistribute the cohort table over as many partitions as the Spark
#    context's default parallelism, then print the resulting partition count.
spark_context = spark.sparkContext
num_parts = spark_context.defaultParallelism
df_out = df_with_cohort.repartition(num_parts)
partitions_after = df_out.rdd.getNumPartitions()
print("After:", partitions_after)

25/07/26 21:31:40 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/07/26 21:31:40 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/07/26 21:31:48 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/07/26 21:31:48 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/07/26 21:32:02 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/07/26 21:32:02 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/07/26 2

After: 8


In [16]:
# 2) Write the cohort table as gzip-compressed JSON, replacing any previous output.
# The target directory can be overridden with the CODEWAY_DATA_DIR environment
# variable so the notebook is not tied to one machine; the default is the
# location that was previously hard-coded.
OUTPUT_DIR = os.environ.get("CODEWAY_DATA_DIR", "/Users/macbookpro/PyCharmMiscProject")

(
    df_out.write
      .mode("overwrite")
      .option("compression", "gzip")
      .json(os.path.join(OUTPUT_DIR, "df_with_cohort_json"))
)

25/07/26 21:32:46 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/07/26 21:32:46 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/07/26 21:32:51 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/07/26 21:32:51 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/07/26 21:32:55 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/07/26 21:32:55 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/07/26 2