<a href="https://colab.research.google.com/github/AgnesElza/kkbox-churn-prediction-pyspark/blob/main/kkbox_chrun_prediction_pyspark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# KKBox Churn Prediction — PySpark (single node)

This project predicts customer churn for KKBox, a subscription-based music streaming service. The goal is to identify at-risk customers from millions of user activity logs, plus transaction and membership records, so that they can be targeted with retention efforts.

The main challenge is a dataset too large to process comfortably in pandas, so I built the pipeline in PySpark. It runs on a single node (Colab) here, but the same DataFrame code can run on a cluster without changing the logic.

This notebook demonstrates how to build an end-to-end churn prediction pipeline using **PySpark** on large-scale customer activity data.  
We process logs, transactions, and member information, engineer rolling features with window functions, and train machine learning models at scale.


```python
# Mount Google Drive in Colab
from google.colab import drive
drive.mount('/content/drive')
```

```
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
```


## 1. Setup and Configuration

We start by initializing Spark with settings sized for a single-node Colab runtime:
- Pinned PySpark version (3.5.1)
- 8 shuffle partitions instead of the default 200
- Adaptive Query Execution (AQE) enabled
- 6 GB of driver memory

We also define schemas up front (in the ingestion step) so column types are consistent and Spark does not need an extra pass over the large CSVs to infer them.

```python
# Install PySpark (pinned to 3.5.1); the Spark session itself is created below
!pip -q install pyspark==3.5.1
```

```python
# Path to your data
DATA = "/content/drive/MyDrive/data_science_projects/kkbox_project/data"
```

```python
from pyspark.sql import SparkSession, functions as F, Window as W
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
from pyspark.storagelevel import StorageLevel
```

```python
spark = (
    SparkSession.builder
    .appName("kkbox-churn-pyspark")
    .config("spark.sql.shuffle.partitions", "8")   # fewer shuffle partitions for Colab (default is 200)
    .config("spark.sql.adaptive.enabled", "true")  # Adaptive Query Execution on
    .config("spark.driver.memory", "6g")           # only takes effect if set before the JVM starts
    .getOrCreate()
)

spark.version
```

```
'3.5.1'
```

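Make small sample CSVs in `/content/` (local disk is faster than reading from Drive), then list them. `head` takes the first rows of each file, so these are convenience samples rather than random ones, and the log and transaction samples need not cover the same users.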

```python
# create small local samples (head keeps the header row)
!head -n 50001  "$DATA/user_logs.csv"      > /content/sample_user_logs.csv        # header + 50k rows
!head -n 10001  "$DATA/transactions.csv"   > /content/sample_transactions.csv     # header + 10k rows
!cp             "$DATA/members_v3.csv"        /content/sample_members.csv         # full file (~409 MB)
!cp             "$DATA/train.csv"             /content/sample_train.csv           # full file (~45 MB)

# show sizes so we know what we have
!ls -lh /content/sample_*.csv
```

```
-rw------- 1 root root 409M Sep  5 18:18 /content/sample_members.csv
-rw------- 1 root root  45M Sep  5 18:18 /content/sample_train.csv
-rw-r--r-- 1 root root 785K Sep  5 18:17 /content/sample_transactions.csv
-rw-r--r-- 1 root root 3.8M Sep  5 18:17 /content/sample_user_logs.csv
```
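Outside Colab, where `!head` shell magics may not be available, the same header-preserving sample can be taken in plain Python. This is a minimal sketch; `head_sample` is a hypothetical helper, and like `head` it works line by line, so it assumes no quoted CSV field spans several lines.

```python
import itertools

def head_sample(src_path: str, dst_path: str, n_rows: int) -> int:
    """Copy the header line plus the first n_rows data lines of a CSV file.

    Equivalent to `head -n <n_rows + 1> src > dst`; returns the number of data lines written.
    """
    with open(src_path, "r", newline="") as src, open(dst_path, "w", newline="") as dst:
        dst.write(src.readline())                    # header line
        written = 0
        for line in itertools.islice(src, n_rows):   # first n_rows data lines
            dst.write(line)
            written += 1
    return written

# Hypothetical usage mirroring the cell above:
# head_sample(f"{DATA}/user_logs.csv", "/content/sample_user_logs.csv", 50_000)
```

If the file has fewer than `n_rows` data lines, the whole file is copied and the return value says how many lines were actually written.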


## 2. Data Ingestion

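We load three feature sources plus the labels:
- **Members** → demographics and registration (city, gender, registration channel and date; the age column `bd` is read but not used)
- **Transactions** → subscription history (payment method, plan price, amount paid, auto-renew, cancellation, expiry)
- **User Logs** → daily listening activity (play counts by completion bucket, unique tracks, seconds played)
- **Train** → the churn label `is_churn` for each user

Dates are stored as `yyyymmdd` integers and are parsed into Spark dates during feature engineering. A user-supplied CSV schema is applied by column position, so it must list every column in file order; with `enforceSchema=False`, Spark checks the schema against the CSV header instead of ignoring the header.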
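Why column order matters: the toy sketch below mimics, in plain Python, a reader that ignores the header and maps fields by position, which is how Spark treats a CSV schema when `enforceSchema` is left at its default of `True`. The data row and helper names are made up for illustration; the header follows the `user_logs.csv` layout, which includes a `num_100` column. A schema that omits `num_100` shifts every later field by one, and a header check catches the mismatch.

```python
import csv
import io

# Made-up record in the user_logs.csv column layout (header + one row)
SAMPLE = (
    "msno,date,num_25,num_50,num_75,num_985,num_100,num_unq,total_secs\n"
    "user_a,20170201,3,1,0,1,20,22,6309.27\n"
)

# A schema that forgets num_100 (8 names for 9 columns), and the corrected one
SHORT_SCHEMA = ["msno", "date", "num_25", "num_50", "num_75", "num_985", "num_unq", "total_secs"]
FULL_SCHEMA = SHORT_SCHEMA[:6] + ["num_100"] + SHORT_SCHEMA[6:]

def read_by_position(text, schema):
    """Ignore the header and map fields to schema names by position.
    zip() silently drops trailing fields that have no schema name."""
    rows = list(csv.reader(io.StringIO(text)))
    return [dict(zip(schema, row)) for row in rows[1:]]

def check_header(text, schema):
    """Fail fast if the CSV header does not match the schema names, in order."""
    header = next(csv.reader(io.StringIO(text)))
    if header != schema:
        raise ValueError(f"CSV header {header} does not match schema {schema}")

row = read_by_position(SAMPLE, SHORT_SCHEMA)[0]
print(row["num_unq"], row["total_secs"])   # 20 22  (really num_100 and num_unq)
check_header(SAMPLE, FULL_SCHEMA)          # passes; the short schema would raise ValueError
```

Spark's real reader differs in details (type casting, malformed-row modes); the sketch only shows the alignment problem.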

```python
LOGS_PATH  = "/content/sample_user_logs.csv"
TX_PATH    = "/content/sample_transactions.csv"
MEM_PATH   = "/content/sample_members.csv"
TRAIN_PATH = "/content/sample_train.csv"

# Spark applies these schemas by column position, so every column must be listed in file order.
schema_train = StructType([
    StructField("msno", StringType(), False),
    StructField("is_churn", IntegerType(), True),
])

schema_members = StructType([
    StructField("msno", StringType(), False),
    StructField("city", IntegerType(), True),
    StructField("bd", IntegerType(), True),
    StructField("gender", StringType(), True),
    StructField("registered_via", IntegerType(), True),
    StructField("registration_init_time", IntegerType(), True),  # yyyymmdd
])

schema_logs = StructType([
    StructField("msno", StringType(), False),
    StructField("date", IntegerType(), False),   # yyyymmdd
    StructField("num_25", IntegerType(), True),
    StructField("num_50", IntegerType(), True),
    StructField("num_75", IntegerType(), True),
    StructField("num_985", IntegerType(), True),
    StructField("num_100", IntegerType(), True),  # in user_logs.csv; omitting it shifts num_unq/total_secs
    StructField("num_unq", IntegerType(), True),
    StructField("total_secs", DoubleType(), True),
])

schema_tx = StructType([
    StructField("msno", StringType(), False),
    StructField("payment_method_id", IntegerType(), True),
    StructField("payment_plan_days", IntegerType(), True),
    StructField("plan_list_price", IntegerType(), True),
    StructField("actual_amount_paid", IntegerType(), True),
    StructField("is_auto_renew", IntegerType(), True),
    StructField("transaction_date", IntegerType(), True),        # yyyymmdd
    StructField("membership_expire_date", IntegerType(), True),  # yyyymmdd
    StructField("is_cancel", IntegerType(), True),
])

# enforceSchema=False validates each schema against the CSV header instead of ignoring the header
train   = spark.read.csv(TRAIN_PATH,  header=True, schema=schema_train,   enforceSchema=False)
members = spark.read.csv(MEM_PATH,    header=True, schema=schema_members, enforceSchema=False) \
                   .select("msno","city","gender","registered_via","registration_init_time")
logs    = spark.read.csv(LOGS_PATH,   header=True, schema=schema_logs,    enforceSchema=False)
tx      = spark.read.csv(TX_PATH,     header=True, schema=schema_tx,      enforceSchema=False)

print("rows:", dict(train=train.count(), members=members.count(), logs=logs.count(), tx=tx.count()))
```

```
rows: {'train': 992931, 'members': 6769473, 'logs': 50000, 'tx': 10000}
```


## 3. Feature Engineering

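Feature engineering uses Spark DataFrame APIs and **window functions**:

- **Daily aggregation**: collapse raw logs to one row per user and day (listening seconds, unique tracks, total plays).
- **Rolling windows**:  
  - 7-day and 30-day rolling sums of listening seconds, plays, and active days, plus 30-day unique tracks  
  - Lag features (listening seconds on the previous logged day and seven logged days earlier)  
  - Deltas between the current day and those lags  
  - Recency: days since the user was last seen, measured at the cutoff  

- **Transaction aggregates** (transactions on or before the cutoff only):  
  - Counts of transactions and cancellations  
  - Number of distinct plan prices, and whether the user ever auto-renewed  
  - Average amount paid, price of the most recent plan, and days since the last transaction  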

These features capture **short-term activity trends** and **long-term subscription behavior**, which are crucial for churn prediction.
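Because the daily table only contains days on which a user was active, a window over the last 7 *rows* and a window over the last 7 *calendar days* can differ a lot. The plain-Python sketch below (hypothetical `rolling_sums` helper, toy data) contrasts the two: the row-based sum corresponds to Spark's `rowsBetween(-6, 0)`, and the calendar-day sum to `rangeBetween(-6, 0)` on a window ordered by an integer day number.

```python
from datetime import date

def rolling_sums(days, values, window_days):
    """Row-based vs calendar-day rolling sums for one user's sorted activity days.

    Row-based: current row plus the previous (window_days - 1) rows, however far apart.
    Calendar:  rows dated within the last window_days days, including the current day.
    """
    ords = [d.toordinal() for d in days]   # integer day numbers as the ordering key
    row_based, calendar = [], []
    for i in range(len(days)):
        lo = max(0, i - (window_days - 1))
        row_based.append(sum(values[lo:i + 1]))
        calendar.append(sum(v for o, v in zip(ords[:i + 1], values[:i + 1])
                            if ords[i] - o <= window_days - 1))
    return row_based, calendar

days = [date(2017, 1, 1), date(2017, 1, 2), date(2017, 1, 20)]
secs = [10.0, 20.0, 30.0]
print(rolling_sums(days, secs, 7))   # ([10.0, 30.0, 60.0], [10.0, 30.0, 30.0])
```

With only three active days in January, the 7-row window on 2017-01-20 still includes listening from almost three weeks earlier, while the calendar window does not.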

transactions → pre-cutoff features (anti-leakage)

```python
def to_dt(col):
    """Parse a yyyymmdd integer column into a DateType column."""
    return F.to_date(F.col(col).cast("string"), "yyyyMMdd")

cutoff = F.to_date(F.lit("2017-02-28"))

tx_dt = (tx
  .withColumn("trans_dt",  to_dt("transaction_date"))
  .withColumn("expire_dt", to_dt("membership_expire_date"))
)

# anti-leakage: keep only transactions on or before the cutoff
tx_pre = tx_dt.where(F.col("trans_dt") <= cutoff)

tx_agg = (tx_pre.groupBy("msno").agg(
    F.count("*").alias("n_tx_all"),
    F.countDistinct("plan_list_price").alias("n_plan_prices"),
    F.max("is_auto_renew").alias("ever_auto_renew"),
    F.sum("is_cancel").alias("n_cancel"),
    F.avg("actual_amount_paid").alias("avg_paid"),
    F.max("trans_dt").alias("last_trans_dt")
))

# price of the most recent transaction, exactly one row per user
# (joining on the max date would duplicate users with several transactions that day)
last_price_df = (
    tx_pre
    .withColumn("rn", F.row_number().over(
        W.partitionBy("msno").orderBy(F.desc("trans_dt"), F.desc("expire_dt"))))
    .where(F.col("rn") == 1)
    .select("msno", F.col("plan_list_price").alias("last_price"))
)

tx_feat = (
    tx_agg.withColumn("days_since_last_tx", F.datediff(cutoff, F.col("last_trans_dt")))
          .join(last_price_df, "msno", "left")
)
tx_feat.show(5, truncate=False)
```

```
+--------------------------------------------+--------+-------------+---------------+--------+--------+-------------+------------------+----------+
|msno                                        |n_tx_all|n_plan_prices|ever_auto_renew|n_cancel|avg_paid|last_trans_dt|days_since_last_tx|last_price|
+--------------------------------------------+--------+-------------+---------------+--------+--------+-------------+------------------+----------+
|Ityacc48f1VRMmg25MXiEX8CFI5F36dVE/w1ZrcHXUY=|1       |1            |1              |0       |149.0   |2015-01-31   |759               |149       |
|Rb3CKdJKm7gpm+LqeroiMFncn3hYvAhFQYo6zg3E06E=|1       |1            |1              |0       |149.0   |2015-03-31   |700               |149       |
|U3++o9j5gV9Bq6Bps8inQ1B1M0R86+WoEcZaGkRYxUQ=|1       |1            |1              |0       |149.0   |2015-01-31   |759               |149       |
|dX69UrKgU+L7ICTSN8EFTQpoA0cwLNtZr/sTfPC4xx8=|1       |1            |1              |0       |149.0   |2015-02-2
```
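The same per-user transaction features can be written out in plain Python, which makes the anti-leakage filter and the "one latest row per user" rule explicit. This is an illustrative sketch with a hypothetical `Tx` record type and made-up transactions; breaking ties on the latest expiry date is an assumption of the sketch, not something the dataset prescribes.

```python
from dataclasses import dataclass
from datetime import date

@dataclass
class Tx:
    msno: str
    trans_dt: date
    expire_dt: date
    plan_list_price: int
    actual_amount_paid: int
    is_auto_renew: int
    is_cancel: int

CUTOFF = date(2017, 2, 28)

def tx_features(txs, cutoff=CUTOFF):
    """Per-user transaction features built only from transactions on or before the cutoff."""
    by_user = {}
    for t in txs:
        if t.trans_dt <= cutoff:                       # anti-leakage filter
            by_user.setdefault(t.msno, []).append(t)
    feats = {}
    for msno, rows in by_user.items():
        # exactly one "latest" transaction per user (ties broken by expiry date)
        last = max(rows, key=lambda t: (t.trans_dt, t.expire_dt))
        feats[msno] = {
            "n_tx_all": len(rows),
            "n_plan_prices": len({t.plan_list_price for t in rows}),
            "ever_auto_renew": max(t.is_auto_renew for t in rows),
            "n_cancel": sum(t.is_cancel for t in rows),
            "avg_paid": sum(t.actual_amount_paid for t in rows) / len(rows),
            "days_since_last_tx": (cutoff - last.trans_dt).days,
            "last_price": last.plan_list_price,
        }
    return feats

txs = [
    Tx("u1", date(2017, 1, 31), date(2017, 2, 28), 149, 149, 1, 0),
    Tx("u1", date(2017, 2, 28), date(2017, 3, 31), 99, 99, 0, 1),
    Tx("u1", date(2017, 3, 31), date(2017, 4, 30), 149, 149, 1, 0),  # after the cutoff: ignored
]
print(tx_features(txs)["u1"])
```

For `u1` only the first two transactions count, so `avg_paid` is (149 + 99) / 2 = 124.0 and `last_price` is 99.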

## 4. Data Persistence

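The daily table is reused by the window features and by every later action (joins, counts, cross-validation), so it is **persisted** with `MEMORY_AND_DISK` and materialized once with `count()`.  
This avoids recomputing the daily rollup from the raw logs on every pass.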

logs → daily table → rolling 7d/30d features → snapshot@cutoff

```python
# daily rollup (one row per user-day)
# play-completion buckets in the logs schema (num_100 = plays of over 98.5% of a song)
play_cols = [c for c in ("num_25", "num_50", "num_75", "num_985", "num_100") if c in logs.columns]
daily = (
    logs.withColumn("dt", to_dt("date"))
        .groupBy("msno","dt")
        .agg(
            F.sum("total_secs").alias("secs"),
            F.sum("num_unq").alias("unq_tracks"),
            F.sum(sum((F.col(c) for c in play_cols), F.lit(0))).alias("plays")
        )
)
# persist and materialize once: the window passes below reuse this table
daily = daily.persist(StorageLevel.MEMORY_AND_DISK); _ = daily.count()

# Rolling windows over calendar days. rangeBetween needs a numeric ordering key, and
# casting a DATE to long does not give a day count in Spark, so order by days since epoch.
day_num = F.datediff(F.col("dt"), F.lit("1970-01-01"))
w = W.partitionBy("msno").orderBy(day_num)
w7, w30 = w.rangeBetween(-6, 0), w.rangeBetween(-29, 0)   # current day + previous 6 / 29 days

daily_feat = (daily
  .withColumn("active", (F.col("secs")>0).cast("int"))
  .withColumn("secs_7d",  F.sum("secs").over(w7))
  .withColumn("secs_30d", F.sum("secs").over(w30))
  .withColumn("days_active_7d",  F.sum("active").over(w7))
  .withColumn("days_active_30d", F.sum("active").over(w30))
  .withColumn("unq_30d",  F.sum("unq_tracks").over(w30))
  .withColumn("plays_7d", F.sum("plays").over(w7))
  .withColumn("plays_30d",F.sum("plays").over(w30))
  .withColumn("secs_lag1", F.lag("secs",1).over(w))   # previous logged day (row-based)
  .withColumn("secs_lag7", F.lag("secs",7).over(w))   # 7 logged days earlier (row-based)
  .withColumn("delta_secs_1d", F.col("secs")-F.col("secs_lag1"))
  .withColumn("delta_secs_7d", F.col("secs")-F.col("secs_lag7"))
)

# pick latest row <= cutoff per user, then add recency
snap = (daily_feat
  .where(F.col("dt") <= cutoff)
  .withColumn("r", F.row_number().over(W.partitionBy("msno").orderBy(F.desc("dt"))))
  .where("r=1").drop("r")
  .withColumn("days_since_last_seen", F.datediff(cutoff, F.col("dt")))
  .select("msno","dt","secs_7d","secs_30d","days_active_7d","days_active_30d",
          "unq_30d","plays_7d","plays_30d","secs_lag1","secs_lag7",
          "delta_secs_1d","delta_secs_7d","days_since_last_seen")
)
snap.show(5, truncate=False)
```

```
+--------------------------------------------+----------+-------+--------+--------------+---------------+-------+--------+---------+---------+---------+-------------+-------------+--------------------+
|msno                                        |dt        |secs_7d|secs_30d|days_active_7d|days_active_30d|unq_30d|plays_7d|plays_30d|secs_lag1|secs_lag7|delta_secs_1d|delta_secs_7d|days_since_last_seen|
+--------------------------------------------+----------+-------+--------+--------------+---------------+-------+--------+---------+---------+---------+-------------+-------------+--------------------+
|++am6f+rLDE3gjQM7pKLVAthwCgaI46WHbTNuKtgpbI=|2016-10-04|100.0  |100.0   |2             |2              |88     |23      |23       |94.0     |NULL     |-88.0        |NULL         |147                 |
|+0rHkv3z1sVolW6mAza1aSV/YiJ0k8/fuXFtf15ey1s=|2017-02-12|87.0   |87.0    |3             |3              |75     |29      |29       |14.0     |NULL     |48.0         |NULL         |16          
```
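The snapshot step keeps, for each user, the latest daily row on or before the cutoff and measures recency from the cutoff. A plain-Python sketch of that rule follows; `snapshot` is a hypothetical helper and the rows are toy data, except that the two last-seen dates are taken from the output above, so the sketch reproduces its `days_since_last_seen` values of 147 and 16.

```python
from datetime import date

CUTOFF = date(2017, 2, 28)

def snapshot(daily_rows, cutoff=CUTOFF):
    """daily_rows: iterable of (msno, dt, features) tuples.

    Keeps the latest row with dt <= cutoff for each user and adds
    days_since_last_seen = (cutoff - dt) in days.
    """
    latest = {}
    for msno, dt, feats in daily_rows:
        if dt <= cutoff and (msno not in latest or dt > latest[msno][0]):
            latest[msno] = (dt, feats)
    return {msno: {**feats, "dt": dt, "days_since_last_seen": (cutoff - dt).days}
            for msno, (dt, feats) in latest.items()}

toy_daily = [
    ("a", date(2016, 10, 3), {"secs_7d": 1.0}),
    ("a", date(2016, 10, 4), {"secs_7d": 2.0}),
    ("b", date(2017, 2, 12), {"secs_7d": 3.0}),
    ("b", date(2017, 3, 5),  {"secs_7d": 4.0}),   # after the cutoff: never selected
]
for msno, feats in sorted(snapshot(toy_daily).items()):
    print(msno, feats["dt"], feats["days_since_last_seen"])
# a 2016-10-04 147
# b 2017-02-12 16
```

Note that the rolling features are taken as of the user's last active day, not as of the cutoff; `days_since_last_seen` is what tells the model how stale they are.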

## 5. Feature Merging

Engineered features are joined into a single training table:
- Member attributes  
- Transaction aggregates  
- Rolling activity features  

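This consolidated dataset forms the foundation for churn model training. The activity snapshot is the base table and the labels are inner-joined, so only labelled users who appear in the sampled logs are kept (4,734 users in the run below, 246 of them churners).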
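The join logic can be summarized in plain Python: the activity snapshot is the base table, transaction and member features are left-joined onto it, the listed transaction and tenure columns are filled with defaults, and labels are inner-joined so that unlabelled users are dropped. This is an illustrative sketch with a hypothetical `merge_features` helper and toy dictionaries; the fill values match the `fillna` call in the cell below.

```python
FILL = {"n_tx_all": 0, "n_plan_prices": 0, "ever_auto_renew": 0, "n_cancel": 0,
        "avg_paid": 0.0, "last_price": 0, "tenure_months": 0.0}

def merge_features(snap, tx_feat, members_feat, labels, fill=FILL):
    """snap/tx_feat/members_feat map msno -> feature dict; labels maps msno -> is_churn."""
    rows = []
    for msno, feats in snap.items():
        if msno not in labels:                   # inner join with the labels
            continue
        row = {"msno": msno, **feats}
        row.update(tx_feat.get(msno, {}))        # left join: no match adds nothing
        row.update(members_feat.get(msno, {}))   # left join
        for col, default in fill.items():        # fill only the listed columns
            if row.get(col) is None:
                row[col] = default
        row["is_churn"] = labels[msno]
        rows.append(row)
    return rows

toy_snap = {"a": {"secs_7d": 100.0}, "b": {"secs_7d": 87.0}, "c": {"secs_7d": 5.0}}
toy_tx = {"a": {"n_tx_all": 3, "avg_paid": 149.0}}
toy_members = {"b": {"city": 1, "tenure_months": 12.5}}
toy_labels = {"a": 0, "b": 1}
merged = merge_features(toy_snap, toy_tx, toy_members, toy_labels)
```

User `c` has no label and is dropped; `b` has no transactions, so its transaction columns take the defaults, while columns outside the fill list (such as `city` for user `a`) stay missing and are left to the imputer and indexers later on.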

members → simple features, then join everything + labels

```python
members_feat = (members
  .withColumn("registration_date", to_dt("registration_init_time"))
  .withColumn("tenure_months", F.round(F.months_between(cutoff, F.col("registration_date")),1))
  .select("msno","city","gender","registered_via","tenure_months")
)

from pyspark.sql.functions import broadcast

features = (snap
  .join(broadcast(tx_feat), "msno", "left")   # small per-user aggregate
  .join(members_feat, "msno", "left")         # ~6.8M rows: let AQE choose the strategy instead of forcing a broadcast
)

df = (features
  .join(train, "msno", "inner")
  .fillna({"n_tx_all":0,"n_plan_prices":0,"ever_auto_renew":0,"n_cancel":0,
           "avg_paid":0.0,"last_price":0,"tenure_months":0.0})
)
df.select("is_churn").groupBy("is_churn").count().show()
```

```
+--------+-----+
|is_churn|count|
+--------+-----+
|       1|  246|
|       0| 4488|
+--------+-----+
```



## 6. Model Training with PySpark ML

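We build a Spark ML pipeline that includes:
- **StringIndexer** and **OneHotEncoder** for categorical variables  
- **Imputer** (mean) for missing numeric features  
- **VectorAssembler** and **StandardScaler** to build and scale the feature vector  
- **Logistic Regression** with class weights to offset the churn/non-churn imbalance  

The data is split 70/15/15 into train, validation, and test sets. A 3-fold cross-validation on the training split over `regParam` and `elasticNetParam` picks the model with the best area under the PR curve.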
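The class weight is simple arithmetic: positives get weight β = negatives / positives and negatives get weight 1, so both classes contribute the same total weight to the logistic loss. With the label counts shown above (246 churners, 4,488 non-churners), β ≈ 18.24. A small sketch of that calculation, using a hypothetical `class_weights` helper that follows the notebook's formula:

```python
import math

def class_weights(n_pos: int, n_neg: int) -> dict:
    """Weight positives by beta = n_neg / n_pos and negatives by 1.0 (both 1.0 if n_pos == 0)."""
    beta = n_neg / n_pos if n_pos else 1.0
    return {1: beta, 0: 1.0}

weights = class_weights(n_pos=246, n_neg=4488)
print(round(weights[1], 2))   # 18.24

# Total weight per class is now equal: 246 * beta == 4488 * 1.0 (up to float rounding)
assert math.isclose(246 * weights[1], 4488 * weights[0])
```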

Spark ML pipeline (index → one-hot → impute → assemble → scale → weighted LR)

```python
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, Imputer, VectorAssembler, StandardScaler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

labelCol = "is_churn"
num_cols = ["secs_7d","secs_30d","days_active_7d","days_active_30d","unq_30d",
            "plays_7d","plays_30d","secs_lag1","secs_lag7",
            "delta_secs_1d","delta_secs_7d","days_since_last_seen",
            "tenure_months","n_tx_all","n_plan_prices","n_cancel","avg_paid","last_price"]
cat_cols = ["city","gender","registered_via","ever_auto_renew"]

# class weights
pos = df.filter(F.col(labelCol)==1).count()
neg = df.filter(F.col(labelCol)==0).count()
beta = (neg / max(1.0, float(pos))) if pos else 1.0
df_w = df.withColumn("w", F.when(F.col(labelCol)==1, F.lit(beta)).otherwise(F.lit(1.0)))

train_df, valid_df, test_df = df_w.randomSplit([0.7,0.15,0.15], seed=42)

indexers = [StringIndexer(inputCol=c, outputCol=f"{c}_idx", handleInvalid="keep") for c in cat_cols]
ohes = [OneHotEncoder(inputCol=f"{c}_idx", outputCol=f"{c}_oh") for c in cat_cols]
imp = Imputer(inputCols=num_cols, outputCols=[f"{c}_imp" for c in num_cols])
assembler = VectorAssembler(
    inputCols=[f"{c}_imp" for c in num_cols] + [f"{c}_oh" for c in cat_cols],
    outputCol="features_raw"
)
scaler = StandardScaler(inputCol="features_raw", outputCol="features")

lr = LogisticRegression(featuresCol="features", labelCol=labelCol, weightCol="w", maxIter=60)

pipe = Pipeline(stages=indexers + ohes + [imp, assembler, scaler, lr])

evaluator_pr  = BinaryClassificationEvaluator(labelCol=labelCol, rawPredictionCol="rawPrediction", metricName="areaUnderPR")
evaluator_roc = BinaryClassificationEvaluator(labelCol=labelCol, rawPredictionCol="rawPrediction", metricName="areaUnderROC")

grid = (ParamGridBuilder()
        .addGrid(lr.regParam, [0.0, 0.01, 0.1])
        .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])
        .build())

cv = CrossValidator(estimator=pipe, estimatorParamMaps=grid, evaluator=evaluator_pr, numFolds=3, parallelism=2)
cv_model = cv.fit(train_df)

print("Valid AUPRC:", evaluator_pr.evaluate(cv_model.transform(valid_df)))
```

```
Valid AUPRC: 0.15901430287364737
```


## 7. Model Evaluation

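Evaluation metrics include:
- **Area Under ROC (AUROC)**
- **Area Under PR Curve (AUPRC)**, which is more informative than AUROC when positives are rare
- **Lift@5%**: the churn rate among the 5% of test users with the highest predicted churn probability, divided by the overall test churn rate (a business-focused metric for targeting the most at-risk customers)

A small Python UDF extracts the positive-class probability (element 1) from Spark’s probability vector so the test rows can be ranked for top-k evaluation.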
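Lift@k compares the churn rate among the top-scored fraction of users with the overall churn rate. The plain-Python sketch below (hypothetical `lift_at_k` helper, toy scores) follows the same steps as the evaluation cell: take the k = max(1, ⌊n · k_frac⌋) highest scores, count the positives among them, and divide that precision by the base rate.

```python
def lift_at_k(scores, labels, k_frac=0.05):
    """Precision among the top k_frac of scores divided by the overall positive rate."""
    n = len(scores)
    k = max(1, int(n * k_frac))
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    tp = sum(label for _, label in ranked[:k])   # positives among the k highest scores
    base = sum(labels) / n                       # overall positive rate
    return (tp / k) / base if base else None

# 20 toy users, 2 churners; with k_frac=0.10 the top k is 2 users
toy_scores = [0.9, 0.8] + [0.1] * 18
toy_labels = [1, 0] + [0] * 17 + [1]
print(lift_at_k(toy_scores, toy_labels, k_frac=0.10))
```

Here the top 2 users contain 1 of the 2 churners: a precision of 0.5 against a base rate of 0.1, so the lift is 5. Ties at the cut-off are broken arbitrarily, both here and in Spark's `orderBy(...).limit(k)`.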

final test metrics (+ simple top-K lift)

```python
test_pred = cv_model.transform(test_df)

print("Test AUPRC:", evaluator_pr.evaluate(test_pred))
print("Test AUROC:", evaluator_roc.evaluate(test_pred))

# lift@5%: how many positives in the top 5% of scores vs base rate
k_frac = 0.05
n = test_pred.count()
k = max(1, int(n * k_frac))

# UDF that extracts the positive-class probability (index 1) from the probability vector
from pyspark.sql.types import DoubleType
get_prob = F.udf(lambda x: float(x[1]), DoubleType())

topk = (test_pred
        .withColumn("churn_probability", get_prob(F.col("probability")))
        .orderBy(F.col("churn_probability").desc())
        .limit(k)
        .select(F.sum(F.col(labelCol)).alias("tp"))
       ).collect()[0]["tp"]

base = test_pred.select(F.avg(F.col(labelCol).cast("double")).alias("base")).collect()[0]["base"]
lift_at_5 = (topk / float(k)) / base if base else None
print(f"Lift@5%: {lift_at_5:.2f}" if lift_at_5 is not None else "Lift@5%: n/a")
```

```
Test AUPRC: 0.14450448372829747
Test AUROC: 0.7187471588326206
Lift@5%: 2.95
```


## 8. Results and Insights

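- AUROC: ~0.72 on the test set  
- AUPRC: ~0.14 on test (~0.16 on validation); a random ranking would score roughly the churn rate, about 0.05 here (246 / 4,734)  
- Lift@5%: ~2.95, so the 5% of test users with the highest scores churn at about three times the base rate. The test split has roughly 700 users, so the top 5% is only about 35 of them and this estimate is noisy.  

**Key takeaway:** Most of the model's inputs are recent-engagement features (7-day and 30-day windows, lags, recency), and the lift suggests that recent engagement carries real churn signal. The notebook does not inspect the fitted coefficients, so how much each individual feature contributes is not measured here.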

That’s the whole prototype: explicit schemas → daily rollup → windows → pre-cutoff features → Spark ML pipeline → PR-AUC CV → lift@k.

## 9. Conclusion

This notebook showcases:
- **Big-data feature engineering** with Spark window functions  
- **Efficient joins and persistence** for large-scale pipelines  
- **End-to-end ML workflow** in PySpark, from ingestion to evaluation  
- **Business impact metrics** beyond standard accuracy  

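The results above come from sampled files, but the same code can be pointed at the full files (millions of log rows) by changing the input paths. The approach applies directly to subscription businesses such as music streaming, SaaS, or telecom.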