In [1]:
import glob

from pyspark.sql import SparkSession
from pyspark.sql import Row
from pyspark.sql import functions
from pyspark.sql.types import *
from pyspark.sql.functions import *

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Start (or reuse) the Spark session that every later cell relies on.
spark = (
    SparkSession.builder
    .appName("downsampling")
    .getOrCreate()
)

## Random Sample valid user id 

In [3]:
# Load the table of valid user ids from CSV into a pandas DataFrame.
valid_user_id = pd.read_csv('../data/valid_user_id.csv')

In [4]:
# Show column dtypes, non-null counts and memory usage of the loaded table.
valid_user_id.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 135549 entries, 0 to 135548
Data columns (total 2 columns):
uid      135549 non-null float64
count    135549 non-null int64
dtypes: float64(1), int64(1)
memory usage: 2.1 MB


In [8]:
# Open the raw play log as an RDD of text lines; each element is later parsed by parseLine.
lines = spark.sparkContext.textFile("../data/all_play.log.fn")

In [10]:
def parseLine(line):
    """Parse one tab-separated play-log line into a Row.

    A well-formed line has exactly 10 tab-separated fields, in this order:
    uid, device, song_id, song_type, song_name, singer, play_time,
    song_length, paid_flag, fn. The fields uid, song_type, song_length and
    paid_flag are converted to float; the others are kept as strings.

    Args:
        line: one raw text line from the play log.

    Returns:
        A 10-value Row on success. If the line does not have 10 fields, or
        one of the numeric fields cannot be converted to float, a
        single-value ``Row(None)`` is returned so the caller can discard the
        record by checking the Row's length.

    TODO: write these parseLine into a class. The chosen_uid argument can be object attribute
    """
    fields = line.split('\t')
    if len(fields) != 10:
        return Row(None)

    (uid, device, song_id, song_type, song_name,
     singer, play_time, song_length, paid_flag, fn) = fields

    try:
        return Row(float(uid), device, song_id, float(song_type), song_name,
                   singer, play_time, float(song_length), float(paid_flag), fn)
    except ValueError:
        # Only a failed numeric conversion marks the line as malformed; anything
        # else (e.g. KeyboardInterrupt) must propagate instead of being swallowed.
        return Row(None)


# Column layout for the parsed play log, in the same order parseLine emits values.
# The third argument of each add() is the nullable flag.
# NOTE(review): FloatType is single precision. Every uid shown in this notebook's
# outputs is a multiple of 16, which is consistent with ids of this magnitude being
# rounded. Confirm whether a wider type is needed, and whether valid_user_id.csv
# holds the same rounded values, before changing it.
schema = (
    StructType()
    .add('uid', FloatType(), False)
    .add('device', StringType(), True)
    .add('song_id', StringType(), False)
    .add('song_type', FloatType(), True)
    .add('song_name', StringType(), True)
    .add('singer', StringType(), True)
    .add('play_time', StringType(), False)
    .add('song_length', FloatType(), True)
    .add('paid_flag', FloatType(), True)
    .add('fn', StringType(), True)
)

## Filter User ID (remove test uid) and drop uninformative columns

In [13]:
# Parse every raw line and keep only the rows that carry one value per schema
# field (parseLine hands back a single-value Row for lines it rejects).
parsed_rows = lines.map(parseLine)
songs = parsed_rows.filter(lambda row: len(row) == len(schema))

# Turn the surviving rows into a DataFrame, then drop the columns that are not
# used afterwards and keep only plays from users listed in valid_user_id.
valid_uids = list(valid_user_id.uid.values)
unused_columns = ['song_type', 'song_name', 'singer', 'paid_flag']
df_all_valid = spark.createDataFrame(songs, schema)
df_all_valid = df_all_valid.drop(*unused_columns).filter(df_all_valid['uid'].isin(valid_uids))

## Transform file name (`fn`) to date

In [15]:
# Clean up `device` and derive a `date` column from the log file name
# (e.g. "20170301_play.log" -> 2017-03-01).
#
# Only the first 8 characters of the trimmed file name are used, so the string
# handed to the parser matches the 'yyyyMMdd' pattern exactly. Taking 9
# characters would leave a trailing '_' that only lenient parsers accept.
# The file stamped '20170339' is not a real calendar date and is mapped to
# '20170329' before parsing.
df_all_valid = (
    df_all_valid
    .withColumn('device', trim(df_all_valid.device))
    .withColumn('date_string',
                regexp_replace(trim(df_all_valid.fn).substr(1, 8), '20170339', '20170329'))
    .withColumn('date', to_date('date_string', 'yyyyMMdd'))
    .drop('date_string')  # helper column only
    .cache()
)

In [16]:
# Fetch the first five rows and show them as a pandas frame for a quick sanity check.
pd.DataFrame(df_all_valid.take(5), columns=df_all_valid.columns)

Unnamed: 0,uid,device,song_id,play_time,song_length,fn,date
0,154422688.0,ar,20870993,22013,332.0,20170301_play.log,2017-03-01
1,154421904.0,ip,6560858,96,161.0,20170301_play.log,2017-03-01
2,154422624.0,ar,3385963,235868,235.0,20170301_play.log,2017-03-01
3,154410272.0,ar,6777172,164,237.0,20170301_play.log,2017-03-01
4,154407792.0,ar,19472465,24,201.0,20170301_play.log,2017-03-01


In [17]:
# The play log is dated from 20170301 to 20170512; the last two weeks are used
# as the churn window. A user with at least one play on or after the cutoff
# date counts as active.
CHURN_WINDOW_START = '2017-04-29'

active_uid = (
    df_all_valid
    .filter(df_all_valid.date >= CHURN_WINDOW_START)
    .select(df_all_valid.uid.alias('active_uid'))
    .distinct()
    .cache()  # evaluated several times: CSV export, count, and a later collect
)

# Overwrite any previous export so the cell can be re-run; the default save mode
# raises an error when the output directory already exists.
active_uid.repartition(1).write.mode('overwrite').csv('../data/active_uid', header=True)
print('total number of active user is', active_uid.count())

total number of active user is 83928


In [18]:
# Bring the distinct active user ids back to the driver as a plain Python list.
active_list = [row['active_uid'] for row in active_uid.collect()]

# Flag each valid user by membership in the active list. Despite the column
# name, True marks users that were active in the churn window and False marks
# users that were not.
is_active = valid_user_id['uid'].isin(active_list)
valid_user_id['churn_label'] = is_active

### Check the ratio of active and churned users

In [23]:
# Non-null row counts per churn_label value (True = active in the churn window, False = not).
valid_user_id.groupby('churn_label').count()

Unnamed: 0_level_0,uid,count
churn_label,Unnamed: 1_level_1,Unnamed: 2_level_1
False,51621,51621
True,83928,83928


In [25]:
# Persist the table, now including churn_label.
# NOTE(review): this overwrites the same CSV that was read at the top of the
# notebook, so a later run will load a file that already has a churn_label column.
valid_user_id.to_csv('../data/valid_user_id.csv', index=False)

## Downsample: draw the same 5% fraction from churned and active users

In [28]:
# Keep 5% of the valid user ids for training and modeling. The sample is drawn
# separately inside each churn_label group, so both groups are reduced by the
# same fraction. A fixed seed makes the draw repeatable across notebook runs.
SAMPLE_FRAC = 0.05
SAMPLE_SEED = 42

sampled_uid = (
    valid_user_id
    .groupby('churn_label')
    .apply(lambda group: group.sample(frac=SAMPLE_FRAC, random_state=SAMPLE_SEED))
    .set_index('uid')
)
# we have 6777 sampled user in total

In [30]:
# Keep only the plays that belong to the sampled users and cache the result,
# since it is previewed and then exported in the following cells.
sampled_uid_values = list(sampled_uid.index.values)
belongs_to_sample = df_all_valid['uid'].isin(sampled_uid_values)
sampled_playlog = df_all_valid.filter(belongs_to_sample).cache()

In [31]:
# Fetch the first five sampled rows and show them as a pandas frame for a quick sanity check.
pd.DataFrame(sampled_playlog.take(5),columns=sampled_playlog.columns)

Unnamed: 0,uid,device,song_id,play_time,song_length,fn,date
0,154421168.0,ar,1967689,139,275.0,20170301_play.log,2017-03-01
1,154422592.0,ar,6468891,261,261.0,20170301_play.log,2017-03-01
2,154422592.0,ar,20870993,332,332.0,20170301_play.log,2017-03-01
3,154416928.0,ar,1691087,86,358.0,20170301_play.log,2017-03-01
4,154421664.0,ip,7153193,4,256.0,20170301_play.log,2017-03-01


In [33]:
# Export the sampled play log in JSON format; coalesce(1) reduces it to one
# partition first. The target path is a directory that Spark fills with part file(s).
# NOTE(review): no save mode is set, so this presumably fails if the directory
# already exists from a previous run - confirm whether .mode('overwrite') is wanted.
sampled_playlog.coalesce(1).write.format('json').save('../data/sampled_playlog.json')

In [37]:
# Load the exported sample back into pandas. Spark names its part files with a
# fresh unique id on every write, so the file is located by pattern instead of
# hard-coding a name that stops existing as soon as the export is regenerated.
sampled_part_files = sorted(glob.glob('../data/sampled_playlog.json/part-*.json'))
if not sampled_part_files:
    raise FileNotFoundError(
        "no part-*.json file found under '../data/sampled_playlog.json'")

# The export was coalesced to one partition, so normally there is exactly one
# part file; concatenating keeps this correct if there are several.
df_sampled = pd.concat(
    (pd.read_json(part_file, lines=True) for part_file in sampled_part_files),
    ignore_index=True,
)

In [38]:
# Preview the first rows of the sample as loaded back into pandas.
df_sampled.head()

Unnamed: 0,date,device,fn,play_time,song_id,song_length,uid
0,2017-03-01,ar,20170301_play.log,139,1967689.0,275.0,154421168.0
1,2017-03-01,ar,20170301_play.log,261,6468891.0,261.0,154422592.0
2,2017-03-01,ar,20170301_play.log,332,20870993.0,332.0,154422592.0
3,2017-03-01,ar,20170301_play.log,86,1691087.0,358.0,154416928.0
4,2017-03-01,ip,20170301_play.log,4,7153193.0,256.0,154421664.0


In [39]:
# Shut down the Spark session now that all Spark work in this notebook is done.
spark.stop()