In [1]:
# import findspark
import findspark
# initialize findspark with spark directory
#findspark.init("C:\Program Files\Spark\spark-3.3.1-bin-hadoop3")
findspark.init("/Users/wouterdewitte/spark/")
# import pyspark
import pyspark
# create spark context
sc = pyspark.SparkContext()
# create spark session 
spark = pyspark.sql.SparkSession(sc)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/12/07 18:38:26 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
# import packages
import os 
import pickle
import re
from datetime import datetime
import requests
import pytz
import emojis
import pandas as pd
import numpy as np
import ast
import pyspark.sql.functions as F
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.sql.functions import array_contains
import matplotlib.pyplot as plt 

## General

In this notebook we will buid a model that predicts if the trend of a certain topic goes up or down on a certain day based on Twitter data of that day.

## 1. Import Data

### 1.1 Google Trends

In [3]:
# read trend data 
trend = spark.read.csv(".././../data/Google_trends/daily_trends.csv", header=True, inferSchema=True, sep=';')

In [4]:
trend

DataFrame[date: timestamp, dependent_vegan: int]

In [5]:
from pyspark.sql.window import Window

w = Window().partitionBy().orderBy(col("date"))
trend.withColumn("dependent_vegan", lag("dependent_vegan", -1, 0).over(w)).show()

22/12/07 18:38:30 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
22/12/07 18:38:30 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
22/12/07 18:38:30 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
22/12/07 18:38:30 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
22/12/07 18:38:30 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+-------------------+---------------+
|               date|dependent_vegan|
+-------------------+---------------+
|2021-10-04 00:00:00|              1|
|2021-10-05 00:00:00|  

In [6]:
# create SQL view
trend.createOrReplaceTempView("trendSQL")

The binary variable indicates if the trend goes up or down.

### 1.2 Twitter

In [7]:
# define data dir
data_dir = "../../data/Topic/"

# get all twitter files
tweet_files = [os.path.join(data_dir, obs) for obs in os.listdir(data_dir)] 

In [8]:
# import twitter data 
#twitter_df = spark.read.json(tweet_files)

In [9]:
list_hashtags = ["vegan",
               "veganism",
               "vegetarian",
                "veganfood",
                "vegano",
                "veganrecipes",
                "vegansofig",
                "vegansofinstagram"]

data_dir = ".././../data/Topic/"
tweet_files = [os.path.join(data_dir, obs) for obs in os.listdir(data_dir)]
files_hashtags = [file for file in tweet_files if (file.find(list_hashtags[0]) != -1)]             
twitter_df = spark.read.option("multiline","true").json(files_hashtags) 
twitter_df.count()

                                                                                

22/12/07 18:38:57 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


                                                                                

1827680

In [10]:
# select interesting features
twitter_df = twitter_df.select(F.col('user.name'),
                                F.col('user.screen_name'),
                                F.col('user.followers_count'),
                                F.col('user.following'),
                                F.col('user.statuses_count'),
                                F.col('user.listed_count'),
                                F.col('created_at'),
                                F.col('full_text'),
                                F.col('entities.hashtags'),
                                F.col('favorite_count'),
                                F.col('retweet_count'),
                                F.col('user.friends_count'))

## 2. Data Preprocessing

#### 2.1 Check time period

In [11]:
# function to convert Twitter date string format
def getDate(date):
    if date is not None:
        return str(datetime.strptime(date,'%a %b %d %H:%M:%S +0000 %Y').replace(tzinfo=pytz.UTC).strftime("%Y-%m-%d %H:%M:%S"))
    else:
        return None

# UDF declaration
date_udf = F.udf(getDate, StringType())

# apply udf
twitter_df = twitter_df.withColumn('post_created_at', F.to_utc_timestamp(date_udf("created_at"), "UTC"))

In [12]:
# get first post
first_post = F.min('post_created_at').alias('earliest')
# get latest post
latest_post = F.max('post_created_at').alias('latest')
# show tweet period in our dataset
twitter_df.select(first_post, latest_post).show()



+-------------------+-------------------+
|           earliest|             latest|
+-------------------+-------------------+
|2021-10-25 07:19:40|2022-10-11 23:17:33|
+-------------------+-------------------+



                                                                                

#### 2.2 Remove retweets and duplicates

In [13]:
# drop all retweets from dataset
no_retweets_df = twitter_df.filter(~F.col("full_text").startswith("RT"))

In [14]:
# first sort no_retweets_df based on date in chronological order (most recent ones on top)
no_retweets_sorted_df = no_retweets_df.sort("post_created_at", ascending=False)

In [15]:
# number of observations before dropping duplicates
no_retweets_sorted_df.count()

                                                                                

745916

In [16]:
# drop duplicates based on tweet text and the profile it was posted from
final_no_duplicates_df = no_retweets_sorted_df.drop_duplicates(["full_text", "screen_name"])

In [17]:
# number of observations after dropping duplicates
final_no_duplicates_df.count()

                                                                                

693932

In [18]:
# rename dataframe
final_twitter_df = final_no_duplicates_df

## 3. Independent Variables

For our independent variables we need to design a pipeline that transforms the data into the desired aggregated metrics per day.

In [19]:
# create SQL view
final_twitter_df.createOrReplaceTempView("twitterSQL")

### 3.1 Volume of tweets 

In [20]:
# select the relevant data
tweet_volume = spark.sql("SELECT DATE_FORMAT(post_created_at, 'Y-M-dd') as date, COUNT(*) as tweet_volume \
                                    FROM twitterSQL \
                                    GROUP BY DATE_FORMAT(post_created_at, 'Y-M-dd') \
                                    ORDER BY DATE_FORMAT(post_created_at, 'Y-M-dd')")

In [21]:
# show 
spark.sql("set spark.sql.legacy.timeParserPolicy=LEGACY")
tweet_volume.show(100)

[Stage 25:>                                                         (0 + 8) / 9]

+----------+------------+
|      date|tweet_volume|
+----------+------------+
|2021-10-25|          50|
|2021-10-26|          45|
|2021-10-27|         894|
|2021-10-28|        2825|
|2021-10-29|       14021|
|2021-10-30|       12497|
|2021-10-31|       12414|
|2021-11-01|       24108|
|2021-11-02|       17623|
|2021-11-03|        3316|
|2021-11-04|        2560|
|2021-11-05|         593|
|2021-11-06|           6|
|2021-12-03|           4|
|2021-12-04|          66|
|2021-12-05|          72|
|2021-12-06|        1336|
|2021-12-07|        4560|
|2021-12-08|       13077|
|2021-12-09|       12693|
|2021-12-10|       13848|
|2021-12-11|       12213|
|2021-12-12|       10589|
|2021-12-13|        2930|
|2021-12-14|        1941|
|2021-12-15|        1596|
|2021-12-16|         107|
|2021-12-25|         637|
| 2022-1-01|        1744|
| 2022-1-02|         973|
| 2022-1-08|        1326|
| 2022-1-09|        1672|
| 2022-1-10|        1886|
| 2022-1-11|        1891|
| 2022-1-12|        1996|
| 2022-1-13|

                                                                                

In [22]:
# create SQL view
tweet_volume.createOrReplaceTempView("tweet_volumeSQL")

### 3.2 Average likes

We exclude tweets with 0 likes.

In [23]:
# select the relevant data
avg_likes = spark.sql("SELECT DATE_FORMAT(post_created_at, 'Y-M-dd') as date, AVG(favorite_count) as avg_likes \
                           FROM twitterSQL \
                           WHERE favorite_count > 0 \
                           GROUP BY DATE_FORMAT(post_created_at, 'Y-M-dd') \
                           ORDER BY DATE_FORMAT(post_created_at, 'Y-M-dd')")

In [24]:
# show 
spark.sql("set spark.sql.legacy.timeParserPolicy=LEGACY")
avg_likes.show()

[Stage 36:>                                                         (0 + 8) / 9]

+----------+------------------+
|      date|         avg_likes|
+----------+------------------+
|2021-10-25|           4.65625|
|2021-10-26|             5.125|
|2021-10-27|11.645669291338583|
|2021-10-28|11.103731815306768|
|2021-10-29| 12.31424108305129|
|2021-10-30|11.979163693449408|
|2021-10-31|12.956186317321688|
|2021-11-01| 13.27580421620833|
|2021-11-02| 8.794319501636576|
|2021-11-03|15.065796937039138|
|2021-11-04|10.239657631954351|
|2021-11-05| 3.459016393442623|
|2021-11-06|               2.5|
|2021-12-03|              20.0|
|2021-12-04|             10.08|
|2021-12-05|5.7105263157894735|
|2021-12-06| 11.53735255570118|
|2021-12-07|26.699334319526628|
|2021-12-08| 14.52754383542731|
|2021-12-09|12.793646370349729|
+----------+------------------+
only showing top 20 rows



                                                                                

In [25]:
# create SQL view
avg_likes.createOrReplaceTempView("avg_likesSQL")

### 3.3 Average Retweets

We exclude tweets with 0 retweets.

In [26]:
# select the relevant data
avg_retweets = spark.sql("SELECT DATE_FORMAT(post_created_at, 'Y-M-dd') as date, AVG(retweet_count) as avg_retweets \
                          FROM twitterSQL \
                          WHERE retweet_count > 0 \
                          GROUP BY DATE_FORMAT(post_created_at, 'Y-M-dd') \
                          ORDER BY DATE_FORMAT(post_created_at, 'Y-M-dd')")

In [27]:
# show 
spark.sql("set spark.sql.legacy.timeParserPolicy=LEGACY")
avg_retweets.show()

[Stage 47:>                                                         (0 + 8) / 9]

+----------+------------------+
|      date|      avg_retweets|
+----------+------------------+
|2021-10-25|               3.0|
|2021-10-26| 4.166666666666667|
|2021-10-27| 5.993006993006993|
|2021-10-28| 5.175879396984925|
|2021-10-29|6.7106673161227475|
|2021-10-30| 5.183630640083946|
|2021-10-31| 6.077004219409282|
|2021-11-01| 6.752923976608187|
|2021-11-02| 4.722175732217573|
|2021-11-03| 8.869379014989294|
|2021-11-04| 5.420485175202156|
|2021-11-05|1.7222222222222223|
|2021-11-06|               1.5|
|2021-12-03|               2.5|
|2021-12-04|2.5714285714285716|
|2021-12-05|2.7777777777777777|
|2021-12-06| 7.503703703703704|
|2021-12-07|13.976780185758514|
|2021-12-08| 6.731100963977676|
|2021-12-09|  6.67574931880109|
+----------+------------------+
only showing top 20 rows



                                                                                

In [28]:
# create SQL view
avg_retweets.createOrReplaceTempView("avg_retweetsSQL")

### 3.4 Engagement rate

We define engagement rate of a tweet as the sum of likes and retweets divided by the amount of followers of the account that sent out the tweet. For our purpose we will take the avergage engagement rate per day. We exclude accounts who have no followers and we only take tweets into account which are liked and retweeted at least once.

In [29]:
# select the relevant data
avg_engagement_rate = spark.sql("SELECT DATE_FORMAT(post_created_at, 'Y-M-dd') as date, AVG(engagement_rate) as avg_engagement_rate \
                                     FROM (  SELECT screen_name, post_created_at, (favorite_count+retweet_count)/followers_count as engagement_rate \
                                             FROM twitterSQL \
                                             WHERE favorite_count > 0 AND retweet_count > 0 AND followers_count > 0 ) \
                                     GROUP BY DATE_FORMAT(post_created_at, 'Y-M-dd') \
                                     ORDER BY DATE_FORMAT(post_created_at, 'Y-M-dd')")

In [30]:
# show
spark.sql("set spark.sql.legacy.timeParserPolicy=LEGACY")
avg_engagement_rate.show()



22/12/07 18:44:24 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:44:24 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:44:24 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:44:24 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:44:24 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:44:24 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:44:24 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.




+----------+--------------------+
|      date| avg_engagement_rate|
+----------+--------------------+
|2021-10-25|0.035312352622552404|
|2021-10-26|0.041442045473123704|
|2021-10-27|  0.0741752578942463|
|2021-10-28|  0.2127516303746722|
|2021-10-29| 0.07491809784621552|
|2021-10-30| 0.11431344329991702|
|2021-10-31| 0.19613345872986776|
|2021-11-01| 0.09034112086921374|
|2021-11-02| 0.06331765741485894|
|2021-11-03|   0.342346333831607|
|2021-11-04|0.056622317733272476|
|2021-11-05| 0.21296895770236038|
|2021-11-06|0.005484460694698354|
|2021-12-03| 0.01529917011031044|
|2021-12-04|  0.3279277126010579|
|2021-12-05|0.009540321788060942|
|2021-12-06| 0.06278172079703152|
|2021-12-07| 0.22602079420407264|
|2021-12-08|  0.1298943012098277|
|2021-12-09| 0.06074244817993019|
+----------+--------------------+
only showing top 20 rows



                                                                                

In [31]:
# create SQL view
avg_engagement_rate.createOrReplaceTempView("avg_engagement_rateSQL")

### 3.5 Number of influencers

We will calculate how many influencers actively tweeted a certain day. We define an influencer as someone with:
- followers > 1000 
- engagement_rate > 0.20 
- weekly tweet frequency > 5

In [32]:
def get_influencers(follower_count_tresh, eng_rate_tresh, freq_week_tresh, data):

    #df
    df = data
    
    # get all users with their amount of followers
    influencers = df.groupBy("screen_name") \
                    .agg(first("followers_count").alias("followers_count"))

    # average engagement rate for each user
    eng_rate = df.withColumn('eng_rate', ((df['favorite_count'] + df['retweet_count'])/df['followers_count']))

    eng_rate_user = eng_rate.groupBy("screen_name") \
                            .agg(avg("eng_rate").alias("eng_rate"))

    # average freq_weekly per user
    freq_week = df.withColumn("year", year(df["post_created_at"]))
    freq_week = freq_week.withColumn('week', weekofyear('post_created_at'))

    freq_week = freq_week.groupBy('screen_name', 'year', 'week').agg(countDistinct("full_text"))\
                    .withColumnRenamed("count(full_text)", "freq") \
                        .sort('screen_name', 'year', 'week', ascending = True)
    freq_week = freq_week.select('screen_name', 'freq')

    freq_week = freq_week.groupby("screen_name").agg(avg(freq_week.freq).alias('freq'))

    # put the data together
    data_joined = eng_rate_user.join(influencers, "screen_name").join(freq_week, "screen_name")

    # filter the data
    data_joined = data_joined.filter((data_joined.followers_count > follower_count_tresh) & (data_joined.eng_rate > eng_rate_tresh) & (data_joined.freq > freq_week_tresh))
    
    # show the data
    data_joined.show()
    return data_joined

In [33]:
influencers = get_influencers(1000, 0.002, 2, final_twitter_df)

[Stage 71:(140 + 3) / 143][Stage 73:=>  (3 + 5) / 9][Stage 75:>   (0 + 0) / 9]3]

22/12/07 18:50:08 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:50:08 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 86:>                                                         (0 + 8) / 9]

+---------------+--------------------+---------------+------------------+
|    screen_name|            eng_rate|followers_count|              freq|
+---------------+--------------------+---------------+------------------+
|      5GenocIDe|0.003218405440206144|           1140| 4.942857142857143|
|        AQUAB23|0.022003034901365705|           1318|               3.0|
|AlsJane_therapy|0.008247976142192238|           6226|               2.5|
|AmazingArbuckle|0.003063373540111...|           3482|               3.0|
|   AmeliaLynn70|0.014513189093212512|           2234|2.3333333333333335|
|Antoniosaiyajin|0.005135345260946718|           3699|               3.0|
|   BDAWOSBranch|0.002719854941069...|           1103|               3.0|
|    BlogofVegan|0.003437569278129488|           9257| 5.115384615384615|
|   BrianKateman|0.004763913172491486|           1542|               3.5|
|   CathyGreen67|0.003029875597498...|           1161|3.1666666666666665|
|   ChubbieVegan|0.003564221783895...|

                                                                                

In [34]:
# create SQL view
influencers.createOrReplaceTempView("influencersSQL")

In [35]:
# select the relevant data
number_of_influencers = spark.sql(" SELECT DATE_FORMAT(a.post_created_at, 'Y-M-dd') as date, COUNT(b.screen_name) as influencers \
                                    FROM twitterSQL a \
                                    RIGHT OUTER JOIN influencersSQL b ON a.screen_name = b.screen_name\
                                    GROUP BY DATE_FORMAT(post_created_at, 'Y-M-dd') \
                                    ORDER BY DATE_FORMAT(post_created_at, 'Y-M-dd')")

In [36]:
# show
spark.sql("set spark.sql.legacy.timeParserPolicy=LEGACY")
number_of_influencers.show()



22/12/07 18:53:43 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:53:43 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.




+----------+-----------+
|      date|influencers|
+----------+-----------+
|2021-10-26|          1|
|2021-10-27|         47|
|2021-10-28|        124|
|2021-10-29|        764|
|2021-10-30|        580|
|2021-10-31|        631|
|2021-11-01|       1216|
|2021-11-02|        887|
|2021-11-03|        180|
|2021-11-04|        147|
|2021-11-05|         20|
|2021-12-04|          3|
|2021-12-05|          6|
|2021-12-06|         26|
|2021-12-07|        264|
|2021-12-08|        584|
|2021-12-09|        698|
|2021-12-10|        759|
|2021-12-11|        775|
|2021-12-12|        528|
+----------+-----------+
only showing top 20 rows



                                                                                

In [37]:
# create SQL view
number_of_influencers.createOrReplaceTempView("number_of_influencersSQL")

## 4. Basetable

In [38]:
# create basetable
basetable = spark.sql("SELECT DATE_FORMAT(a.date, 'Y-M-dd') as date, a.dependent_vegan, b.tweet_volume, COALESCE(c.avg_likes,0) as avg_likes, \
                       COALESCE(d.avg_retweets,0) as avg_retweets, \
                       COALESCE(e.avg_engagement_rate,0) as avg_engagement_rate, COALESCE(f.influencers,0) as influencers \
                       FROM trendSQL a \
                       INNER JOIN tweet_volumeSQL b ON DATE_FORMAT(a.date, 'Y-M-dd') = b.date \
                       LEFT OUTER JOIN avg_likesSQL c ON b.date = c.date \
                       LEFT OUTER JOIN avg_retweetsSQL d ON c.date = d.date \
                       LEFT OUTER JOIN avg_engagement_rateSQL e ON d.date = e.date \
                       LEFT OUTER JOIN number_of_influencersSQL f ON e.date = f.date")

In [39]:
# show
basetable.show(50)

[Stage 167:(141 + 2) / 143][Stage 171:>  (0 + 6) / 9][Stage 173:>  (0 + 0) / 9]]

22/12/07 18:59:56 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:59:56 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 171:==>(7 + 2) / 9][Stage 173:>  (0 + 6) / 9][Stage 175:>  (0 + 0) / 9]]

22/12/07 18:59:56 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 171:==>(8 + 1) / 9][Stage 173:>  (0 + 7) / 9][Stage 175:>  (0 + 0) / 9]

22/12/07 18:59:57 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 175:>  (0 + 8) / 9][Stage 177:>  (0 + 0) / 9][Stage 179:>  (0 + 0) / 9]

22/12/07 18:59:57 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:59:57 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:59:57 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:59:57 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 179:>  (0 + 8) / 9][Stage 182:>  (0 + 0) / 9][Stage 184:>  (0 + 0) / 9]

22/12/07 18:59:59 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:59:59 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:59:59 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:59:59 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:59:59 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 18:59:59 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


                                                                                

+----------+---------------+------------+------------------+------------------+--------------------+-----------+
|      date|dependent_vegan|tweet_volume|         avg_likes|      avg_retweets| avg_engagement_rate|influencers|
+----------+---------------+------------+------------------+------------------+--------------------+-----------+
|2021-11-03|              1|        3316|15.065796937039138| 8.869379014989294|   0.342346333831607|        180|
| 2022-8-15|              0|        1456| 22.95566502463054|        8.76171875|0.061455990873796947|        209|
| 2022-3-03|              0|          26|2.6470588235294117|               1.0| 0.03432893276873259|          1|
|2021-10-25|              0|          50|           4.65625|               3.0|0.035312352622552404|          0|
| 2022-6-13|              0|         108| 7.879310344827586| 2.675675675675676| 0.03734894322976568|         13|
| 2022-8-14|              1|        1194|13.847765363128492| 5.109913793103448| 0.19406459767799

In [40]:
# import the required functions
from pyspark.ml.feature import Binarizer, StringIndexer, VectorIndexer, VectorAssembler, OneHotEncoder
from pyspark.ml import Pipeline
from pyspark.sql.types import DoubleType

In [41]:
# define string indexer to index price 
SI = StringIndexer(inputCol = 'dependent_vegan', outputCol = 'label')

# define vector assembler for numeric variables
numColumns = ['avg_likes','avg_retweets','avg_engagement_rate','influencers']
VAnum = VectorAssembler(inputCols=numColumns, outputCol="numFeatures")

In [42]:
# define pipeline stages
stages = [SI, VAnum]
# define pipeline and fit on data
preprocessingPipeline = Pipeline().setStages(stages).fit(basetable)
# apply pipeline on data
basetable = preprocessingPipeline.transform(basetable)

[Stage 268:(142 + 1) / 143][Stage 272:>  (0 + 7) / 9][Stage 274:>  (0 + 0) / 9] 

22/12/07 19:08:27 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 272:==>(6 + 3) / 9][Stage 274:>  (0 + 5) / 9][Stage 276:>  (0 + 0) / 9]

22/12/07 19:08:27 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:08:27 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:08:27 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:08:28 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 276:>  (0 + 8) / 9][Stage 278:>  (0 + 0) / 9][Stage 280:>  (0 + 0) / 9]

22/12/07 19:08:28 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:08:28 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:08:28 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:08:28 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:08:28 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 280:>  (0 + 8) / 9][Stage 283:>  (0 + 0) / 9][Stage 285:>  (0 + 0) / 9]

22/12/07 19:08:29 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:08:29 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:08:29 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:08:29 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:08:29 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:08:30 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


                                                                                

In [43]:
# select features and labels
basetable = basetable.select(["numFeatures", "label"])

In [44]:
# check
basetable.show(5)

[Stage 375:(142 + 1) / 143][Stage 381:>  (0 + 7) / 9][Stage 383:>  (0 + 0) / 9] 

22/12/07 19:14:45 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:14:45 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:14:45 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:14:45 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 383:>  (0 + 8) / 9][Stage 385:>  (0 + 0) / 9][Stage 387:>  (0 + 0) / 9]

22/12/07 19:14:45 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:14:45 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:14:45 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:14:45 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 387:>  (0 + 8) / 9][Stage 390:>  (0 + 0) / 9][Stage 393:>  (0 + 0) / 9]

22/12/07 19:14:47 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:14:47 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:14:47 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:14:47 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:14:47 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:14:47 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


                                                                                

+--------------------+-----+
|         numFeatures|label|
+--------------------+-----+
|[15.0657969370391...|  0.0|
|[22.9556650246305...|  1.0|
|[2.64705882352941...|  1.0|
|[4.65625,3.0,0.03...|  1.0|
|[7.87931034482758...|  1.0|
+--------------------+-----+
only showing top 5 rows



**Logistic Regression**
- Split the data in a train and test set (70/30).
- Build one pipeline that:
  - standardizes the numerical variables
  - applies a logistic regression to the data
  - check the performance using the AUC.

We cannot use the randomsplit function, because we have time series data, so we have to use another approach

First we look at the amount of observations that will be assigned to the training set 

In [45]:
nr_train = int(basetable.count()*0.7)
nr_train

[Stage 476:(141 + 2) / 143][Stage 478:==>(6 + 3) / 9][Stage 480:>  (0 + 3) / 9] 

22/12/07 19:21:04 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:21:04 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 480:==>(8 + 1) / 9][Stage 482:>  (0 + 7) / 9][Stage 484:>  (0 + 0) / 9]]

22/12/07 19:21:05 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:21:05 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 482:==>(8 + 1) / 9][Stage 484:>  (0 + 7) / 9][Stage 486:>  (0 + 0) / 9]

22/12/07 19:21:05 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:21:05 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:21:05 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:21:05 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:21:05 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 488:>  (0 + 8) / 9][Stage 491:>  (0 + 0) / 9][Stage 493:>  (0 + 0) / 9]

22/12/07 19:21:06 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:21:06 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:21:06 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:21:06 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:21:06 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:21:06 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


                                                                                

165

convert the final basetable to a pandas dataset 

In [46]:
basetable_pd = basetable.toPandas()
basetable_pd.head()

[Stage 583:(140 + 3) / 143][Stage 585:==>(6 + 3) / 9][Stage 587:>  (0 + 2) / 9]]

22/12/07 19:27:16 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:27:17 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 583:(142 + 1) / 143][Stage 587:==>(6 + 3) / 9][Stage 589:>  (0 + 4) / 9]

22/12/07 19:27:17 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:27:17 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:27:17 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:27:17 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 589:==>(7 + 2) / 9][Stage 591:>  (0 + 6) / 9][Stage 593:>  (0 + 0) / 9]

22/12/07 19:27:18 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:27:18 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 593:==>(8 + 1) / 9][Stage 595:>  (0 + 7) / 9][Stage 598:>  (0 + 0) / 9]

22/12/07 19:27:19 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:27:19 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:27:19 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:27:19 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:27:19 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


                                                                                

Unnamed: 0,numFeatures,label
0,"[15.065796937039138, 8.869379014989294, 0.3423...",0.0
1,"[22.95566502463054, 8.76171875, 0.061455990873...",1.0
2,"[2.6470588235294117, 1.0, 0.03432893276873259,...",1.0
3,"[4.65625, 3.0, 0.035312352622552404, 0.0]",1.0
4,"[7.879310344827586, 2.675675675675676, 0.03734...",1.0


Split the dataframe into train and test 

In [47]:
train_pd = basetable_pd.iloc[:nr_train,:]
test_pd = basetable_pd.iloc[nr_train:,:]
train = spark.createDataFrame(train_pd)
test = spark.createDataFrame(test_pd)

  for column, series in pdf.iteritems():
  for column, series in pdf.iteritems():


In [48]:
# check number of observations in train and test set
print(train.count())
print(test.count())

[Stage 670:>                                                        (0 + 8) / 8]

165
71


                                                                                

In [49]:
# inspect distribution of label in train and test set
basetable.groupBy("label").count().show()
train.groupBy("label").count().show()
test.groupBy("label").count().show()

[Stage 690:(140 + 3) / 143][Stage 692:==>(6 + 3) / 9][Stage 694:>  (0 + 2) / 9]]

22/12/07 19:33:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 696:==>(6 + 3) / 9][Stage 698:>  (0 + 5) / 9][Stage 700:>  (0 + 0) / 9]]

22/12/07 19:33:55 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 698:>  (0 + 8) / 9][Stage 700:>  (0 + 0) / 9][Stage 702:>  (0 + 0) / 9]

22/12/07 19:33:55 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 700:==>(7 + 2) / 9][Stage 702:>  (0 + 6) / 9][Stage 705:>  (0 + 0) / 9]

22/12/07 19:33:56 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:33:56 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:33:56 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:33:56 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:33:56 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:33:56 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/12/07 19:33:56 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


                                                                                

+-----+-----+
|label|count|
+-----+-----+
|  0.0|  121|
|  1.0|  115|
+-----+-----+

+-----+-----+
|label|count|
+-----+-----+
|  0.0|   84|
|  1.0|   81|
+-----+-----+

+-----+-----+
|label|count|
+-----+-----+
|  0.0|   37|
|  1.0|   34|
+-----+-----+



In [50]:
# import required features
from pyspark.ml.feature import StandardScaler, VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [51]:
# define scaler
SS = StandardScaler(inputCol = 'numFeatures', outputCol = 'scaledNumFeatures', withStd = True, withMean = False)

# define vector assembler
VA = VectorAssembler(inputCols = ['scaledNumFeatures'], outputCol = 'features')

# define logistic regression model
LR = LogisticRegression(labelCol = 'label', featuresCol = 'features', maxIter = 10)

In [52]:
# define pipeline stages
stages = [SS, VA, LR]
# create pipeline and fit on training set
lrModelPipeline = Pipeline().setStages(stages).fit(train)
# apply pipeline on test set to get predictions
predictions = lrModelPipeline.transform(test)

22/12/07 19:34:01 WARN InstanceBuilder$JavaBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.VectorBLAS
22/12/07 19:34:01 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
22/12/07 19:34:01 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.ForeignLinkerBLAS


In [53]:
# inspect predictions
predictions.show(5)

+--------------------+-----+--------------------+--------------------+--------------------+--------------------+----------+
|         numFeatures|label|   scaledNumFeatures|            features|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+--------------------+--------------------+----------+
|[5.49019607843137...|  1.0|[1.03296060626405...|[1.03296060626405...|[-0.0504193216419...|[0.48739783914756...|       1.0|
|[12.6830835117773...|  1.0|[2.38627645469564...|[2.38627645469564...|[-0.1032990966531...|[0.47419816540094...|       1.0|
|[11.2695214105793...|  0.0|[2.12031983963381...|[2.12031983963381...|[0.25179346522256...|[0.56261788446566...|       0.0|
|[42.0873859045338...|  1.0|[7.91858998093160...|[7.91858998093160...|[-0.1845362087328...|[0.45399642278111...|       1.0|
|[17.9372881355932...|  0.0|[3.37483611925374...|[3.37483611925374...|[-0.0175440398049...|[0.49561410254401...|       1.0|
+-------

In [54]:
# define evaluator
evaluator = BinaryClassificationEvaluator()
# get evaluation metric
lrAUC = evaluator.evaluate(predictions, {evaluator.metricName: 'areaUnderROC'})
# inspect model performance
print('AUC lr: %f' %(lrAUC))

AUC lr: 0.503975
