In [1]:
import numpy as np
import pandas as pd

from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import *
from pyspark.sql.types import IntegerType

from pyspark.ml.stat import Correlation
from pyspark.ml.feature import MinMaxScaler, VectorAssembler

from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, GBTClassifier 
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from time import time

In [4]:
spark = SparkSession.builder \
    .appName("Sparkify_Project_Spark_Local") \
    .getOrCreate()

data_path = "./mini_sparkify_event_data.json"
df = spark.read.json(data_path)
# check the schema of the dataset
df.printSchema()

root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)



In [5]:
print('The dataset has {} rows.'.format(df.count()))

The dataset has 286500 rows.


In [6]:
df.select(min(to_timestamp(col('ts')/1000)).alias('Start time')).show()

+-------------------+
|         Start time|
+-------------------+
|2018-09-30 20:01:57|
+-------------------+



In [7]:
df.select(max(to_timestamp(col('ts')/1000)).alias('End time')).show()

+-------------------+
|           End time|
+-------------------+
|2018-12-02 20:11:16|
+-------------------+



In [8]:
display(df.select([count(when(isnan(c),c)).alias(c) for c in df.columns]))

DataFrame[artist: bigint, auth: bigint, firstName: bigint, gender: bigint, itemInSession: bigint, lastName: bigint, length: bigint, level: bigint, location: bigint, method: bigint, page: bigint, registration: bigint, sessionId: bigint, song: bigint, status: bigint, ts: bigint, userAgent: bigint, userId: bigint]

In [12]:
df.select("page").distinct().show()

+--------------------+
|                page|
+--------------------+
|              Cancel|
|    Submit Downgrade|
|         Thumbs Down|
|                Home|
|           Downgrade|
|         Roll Advert|
|              Logout|
|       Save Settings|
|Cancellation Conf...|
|               About|
|            Settings|
|               Login|
|     Add to Playlist|
|          Add Friend|
|            NextSong|
|           Thumbs Up|
|                Help|
|             Upgrade|
|               Error|
|      Submit Upgrade|
+--------------------+
only showing top 20 rows



In [13]:
cancellation_check_function = udf(lambda x: 1 if x == "Cancellation Confirmation" else 0, IntegerType())

# Applied and generated new columns named churn
df = df.withColumn("churn", cancellation_check_function("page"))

In [14]:
from pyspark.sql import Window
from pyspark.sql.functions import col, sum as Fsum

# Define the window bounds to use Fsum to count for the churn
windowval = Window.partitionBy("userId").rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)

# Apply the window function to the DataFrame df
df = df.withColumn("churn", Fsum(col("churn")).over(windowval))

In [15]:
def missing_values(df, col):
    return df.filter((isnan(df[col])) | (df[col].isNull()) | (df[col] == "")).count()

print("\n[Missing values]\n")
for col in df.columns:
    missing_count = missing_values(df, col)
    if missing_count > 0:
        print("{}: {}".format(col, missing_count))


[Missing values]

artist: 58392
firstName: 8346
gender: 8346
lastName: 8346
length: 58392
location: 8346
registration: 8346
song: 58392
userAgent: 8346
userId: 8346


In [16]:
df = df.filter(df["userId"] != "")
df.count()

278154

In [17]:
stat_df = spark.createDataFrame(df.dropDuplicates(['userId']).collect())
stat_df1 = stat_df[['gender', 'churn']]
print('The avg churn rate of females is:', stat_df1.groupby(['gender']).mean().collect()[0][1]*100)
print('The avg churn rate of males is:', stat_df1.groupby(['gender']).mean().collect()[1][1]*100)

The avg churn rate of females is: 19.230769230769234
The avg churn rate of males is: 26.446280991735538


In [18]:
stat_df1 = stat_df[['artist', 'churn']]
display(stat_df1.groupBy(['artist']).sum().orderBy('sum(churn)', ascending = False).collect()[:5])

[Row(artist=None, sum(churn)=27),
 Row(artist='P!nk', sum(churn)=1),
 Row(artist='Gorillaz', sum(churn)=1),
 Row(artist='Modjo', sum(churn)=1),
 Row(artist="Christopher O'Riley", sum(churn)=1)]