# Sparkify - Prediction of Customer Churn

In [33]:
# import libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import avg, col, concat, desc, explode, lit, min, max, split, udf
#from pyspark.sql.types import IntegerType

#from pyspark.ml import Pipeline
#from pyspark.ml.classification import LogisticRegression
#from pyspark.ml.evaluation import MulticlassClassificationEvaluator
#from pyspark.ml.feature import CountVectorizer, IDF, Normalizer, PCA, RegexTokenizer, StandardScaler, StopWordsRemover, StringIndexer, VectorAssembler
#from pyspark.ml.regression import LinearRegression
#from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

#import re

In [2]:
# Build a Spark session running on the local master; getOrCreate() hands back
# an already-running session instead of starting a second one.
spark = (
    SparkSession.builder
    .master("local")
    .appName("Sparkify Model")
    .getOrCreate()
)

## Load and Clean Dataset

In [4]:
# Mini dataset, so that the code can be run on a local machine.
df = spark.read.json(path='mini_sparkify_event_data.json')
# Mark the DataFrame for caching; evaluation is lazy.
df.persist()

DataFrame[artist: string, auth: string, firstName: string, gender: string, itemInSession: bigint, lastName: string, length: double, level: string, location: string, method: string, page: string, registration: bigint, sessionId: bigint, song: string, status: bigint, ts: bigint, userAgent: string, userId: string]

In [9]:
# Peek at the first event record to see what a raw row looks like.
df.first()

Row(artist='Martha Tilston', auth='Logged In', firstName='Colin', gender='M', itemInSession=50, lastName='Freeman', length=277.89016, level='paid', location='Bakersfield, CA', method='PUT', page='NextSong', registration=1538173362000, sessionId=29, song='Rockpools', status=200, ts=1538352117000, userAgent='Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0', userId='30')

In [10]:
# Print the inferred schema as a tree: column name, type and nullability.
df.printSchema()

root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)



### Investigate Session ID & User ID

In [71]:
# Identifier columns to inspect for value range and missing entries.
id_list = ['sessionId','userId']

# df.count() is a Spark action; run it once and reuse the result instead of
# re-counting the whole dataset on every loop iteration.
total_rows = df.count()
print('The dataset has {} rows.\n'.format(total_rows))

for idx in id_list:
    # min and max values
    # (min/max are the pyspark.sql.functions aggregates imported at the top
    # of the notebook, not the Python builtins)
    df.agg(min(col(idx)), max(col(idx))).show()
    # count and share of missing values: a value counts as missing when it
    # is null or an empty string
    null_count = df.where((df[idx].isNull()) | (df[idx] == '')).count()
    # guard against an empty dataset so the share never divides by zero
    null_share = (null_count / total_rows) * 100 if total_rows else 0.0
    print('{} rows ({:.2f}%) have missing {}.\n'.format(null_count, null_share, idx))

The dataset has 286500 rows.

+--------------+--------------+
|min(sessionId)|max(sessionId)|
+--------------+--------------+
|             1|          2474|
+--------------+--------------+

0 rows (0.00%) have missing sessionId.

+-----------+-----------+
|min(userId)|max(userId)|
+-----------+-----------+
|           |         99|
+-----------+-----------+

8346 rows (2.91%) have missing userId.



In [60]:
# Count events per page for rows whose userId is the empty string,
# most frequent page first.
(
    df.select('userId', 'page', 'ts')
    .where(col('userId') == '')
    .groupBy('page')
    .count()
    .orderBy(desc('count'))
    .show()
)

+-------------------+-----+
|               page|count|
+-------------------+-----+
|               Home| 4375|
|              Login| 3241|
|              About|  429|
|               Help|  272|
|           Register|   18|
|              Error|    6|
|Submit Registration|    5|
+-------------------+-----+



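Events with a missing user ID appear to come from visitors who are not logged in: most of them occur on the "Home" and "Login" pages, and the remainder on "About", "Help", "Register", "Error" and "Submit Registration". Because these events cannot be attributed to a user, they are excluded from further analysis.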

In [72]:
# Keep only events tied to a user by discarding rows with an empty userId,
# then report how many rows remain.
df_clean = df.where(col('userId') != '')
print('The clean dataset is reduced to {} rows.\n'.format(df_clean.count()))

The clean dataset is reduced to 278154 rows.

