This project uses PySpark to predict churn based on a 12GB dataset from a fictitious music streaming service, "Sparkify". Check out my blog post for more details!
Losing customers (i.e., churn) is an important business problem because the cost of acquiring a new customer is often much higher than that of retaining an existing one. Being able to predict churn helps a company prioritize retention programs for the customers who are most likely to leave, taking action before churn happens.
In this project, I used Spark to analyze a user activity dataset and build a machine learning model that identifies the users most likely to churn.
- User activity dataset from Udacity: logs user demographic information (e.g. user name, gender, location state) and activity (e.g. song listened to, event type, device used) at individual timestamps.
- Census regions table from cphalpert's GitHub: links state names to geographical divisions.
A small subset (~120MB) of the full dataset was used for exploratory data analysis and pilot modeling; the full dataset (~12GB) was used for tuning the machine learning model.
- Data loading
    - Load subset from JSON
    - Assess missing values
- Exploratory data analysis
    - Overview of numerical columns: descriptive statistics
    - Overview of non-numerical columns: possible categories
    - Define churn as cancellation of service
    - Compare behavior of churned vs. non-churned users in terms of:
        - Usage at different hours of the day
        - Usage on different days of the week
        - User level (free vs. paid)
        - Event types (e.g. add a friend, advertisement, thumbs up)
        - Device used (e.g. Mac, Windows)
        - User location (e.g. New England, Pacific)
        - Time from downgrade to churn
- Feature engineering for machine learning
    - Create features on a per-user basis:
        - Latest user level
        - Time since registration
        - Gender of user
        - Time engaged, and the number of artists, songs, and sessions the user has engaged with
        - Mean and standard deviation of the number of songs listened to per artist, the number of songs listened to per session, and the time spent per session
        - Device used
        - Count and proportion of each event type
        - User location
    - Remove strongly correlated features (one from each pair)
    - Transform features to have distributions closer to normal
    - Compile feature engineering code to scale up later
- Develop machine learning pipeline
    - Split training and testing sets
    - Choose evaluation metrics
    - Create functions to build a cross-validation pipeline (documentation), train the machine learning model, and evaluate model performance
    - Initial model evaluation with:
        - Naive predictor
        - Logistic regression (documentation)
        - Random forest (documentation)
        - Gradient-boosted trees (documentation)
- Scale up machine learning on the full dataset on AWS
    - Tune hyperparameters of the gradient-boosted trees
    - Evaluate model performance
    - Evaluate feature importance
Model performance on the testing set:

| Testing accuracy | Testing F1 score |
| --- | --- |
| 0.8387 | 0.8229 |

Note that the scores are reasonably good, although not superb. One limitation is that I was not able to run the model for a long time period on my AWS tier, so only a few hyperparameters were tuned. The model performance could be further improved by tuning broader ranges of hyperparameters.
Churn relates to users who received more advertisements, disliked songs more often than they liked them, and registered more recently.
Check out my blog post for more details!
- Prototype on local machine: the code was developed using the Anaconda distribution of Python 3. Libraries used include PySpark, Pandas, Seaborn, and Matplotlib.
- Cloud deployment on AWS EMR:
    - Release: emr-5.20.0
    - Applications: Spark 2.4.0 on Hadoop 2.8.5 YARN with Ganglia 3.7.2 and Zeppelin 0.8.0
    - Instance type: m4.xlarge
    - Number of instances: 3
- `Sparkify.ipynb`: exploratory data analysis, data preprocessing, and pilot development of the machine learning model on a local machine using the data subset.
- `Sparkify_AWS.ipynb`: data preprocessing and model tuning on AWS using the full dataset.
- `mini_sparkify_event_data.json`: subset of user activity data.
- `region.csv`: census regions table.