# Initialize pyspark environment

In [1]:
import findspark

# initialize findspark with spark directory

#ALWAYS HAVE TO BE CHANGED 
#path = "/Users/konstantinlazarov/Desktop/Big_Data/PySpark/Week_5/spark"
path = "/Users/Artur/spark"
findspark.init(path) 

# import pyspark
import pyspark
# create spark context
sc = pyspark.SparkContext()
# create spark session 
spark = pyspark.sql.SparkSession(sc)

In [2]:
'''Voor Konstantin 
# import pyspark
import pyspark
# create spark context
sc = pyspark.SparkContext()
# create spark session 
spark = pyspark.sql.SparkSession(sc)
'''

'Voor Konstantin \n# import pyspark\nimport pyspark\n# create spark context\nsc = pyspark.SparkContext()\n# create spark session \nspark = pyspark.sql.SparkSession(sc)\n'

# Import necessary packages and data

#### Import necessary packages

In [3]:
# import packages
import os 
import pickle

import re
from datetime import datetime
import requests

import pytz
import emojis

import pandas as pd
import numpy as np

import ast

import pyspark.sql.functions as F
from pyspark.sql.types import *

from pyspark.ml.feature import Tokenizer
from pyspark.ml.feature import StopWordsRemover

import tweepy
import csv
import time
import pandas as pd
import datetime
import os
import json
from pandas.tseries.holiday import nearest_workday, \
    AbstractHolidayCalendar, Holiday, \
    USMartinLutherKingJr, USPresidentsDay, GoodFriday, \
    USMemorialDay, USLaborDay, USThanksgivingDay
from datetime import date

from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
from string import punctuation
from nltk.corpus import stopwords
from textblob import TextBlob
from nltk.tokenize import word_tokenize

from nltk.tokenize.treebank import TreebankWordDetokenizer
import re

#### Import the twitter data 

In [4]:
list_brands = ["healthyfood",
               "healthylifestyle",
               "vegan",
               "keto",
               "ketodiet",
               "ketolifestyle",
               "veganism",
               "vegetarian"]
from re import search



data_dir = ".././../data/Topic/"
tweet_files = [os.path.join(data_dir, obs) for obs in os.listdir(data_dir)]



files_brand = [file for file in tweet_files if (file.find(list_brands[2]) != -1)]
files_brand               
               
df_json = spark.read.option("multiline","true").json(files_brand)  
df_json.count()



1827680

# Predict the engagement rate of a tweet

## 1. Goal of our analysis

In this notebook, we are going to predict the engagement rate of tweets. Further, it will be interesting to see the driving factors behind the engagement rate. This can be valuable information when creating an own social media brand or when you want to increase the reach of your tweets.

We start by selecting the interesting variables for this analysis.

In [5]:
# Select the interesting variables
basetable_engr = df_json.select(F.col('created_at').alias('tweet_created'), \
                                   F.col('entities.symbols').alias('symbols'), \
                                   F.col('display_text_range').alias('text_range'), \
                                   F.col('extended_entities.media.type').alias('media_type'), \
                                   F.col('favorite_count'), \
                                   F.col('full_text'), \
                                   F.col('is_quote_status').alias('quoted'), \
                                   F.col('lang').alias('language'), \
                                   F.col('retweet_count'),\
                                   F.col('user.created_at').alias('user_created'), \
                                   F.col('user.followers_count').alias('user_followers'), \
                                   F.col('user.friends_count').alias('user_following'), \
                                   F.col('user.verified').alias('user_verified'), \
                                   F.col("user.screen_name"), \
                                   F.col('user.statuses_count').alias('nr_tweets_by_user'))

Next, we perform some basic preprocessing steps on this dataset.

In [6]:
# First, we  convert Twitter date string format
def getDate(date):
    if date is not None:
        return str(datetime.datetime.strptime(date,'%a %b %d %H:%M:%S +0000 %Y').replace(tzinfo=pytz.UTC).strftime("%Y-%m-%d %H:%M:%S"))
    else:
        return None

# UDF declaration
date_udf = F.udf(getDate, StringType())

# apply udf
basetable_engr = basetable_engr.withColumn('tweet_created', F.to_utc_timestamp(date_udf("tweet_created"), "UTC"))
basetable_engr = basetable_engr.withColumn('user_created', F.to_utc_timestamp(date_udf("user_created"), "UTC"))

In [7]:
#drop duplicates and retweets 
basetable_engr = basetable_engr.filter(~F.col("full_text").startswith("RT"))\
                        .drop_duplicates()

#sorting such when dropping later we only keep the most recent post 
basetable_engr = basetable_engr.sort("tweet_created", ascending=False)

#removing spam accounts 
basetable_engr = basetable_engr.drop_duplicates(["full_text", "screen_name"])

Before we start the feature engineering, we need to filter our data. As we will use the vader package to determine the sentiment later in this notebook, we will only work with English tweets. As our insights in the data will be primarily for European and American companies and most of our tweets are in English, we do not see this as a problem.

In [8]:
# filter the data on language
basetable_engr = basetable_engr.filter(F.col("language") == "en")

In [9]:
## testing if we can compile the model for a smaller data set 
basetable_engr = basetable_engr.sample(0.01) #takiong 10 procent of the data 
basetable_engr.count()

4664

# 2. Create the dependent variable

AAN TE PASSEN

First, we start by defining our dependent variable. The engagement rate has already been discussed in the data exploration section. We will use the same definition in order to create a model to predict the engagement rate. Below, we repeat this definition:

Engagement on Twitter is measured by the number of retweets, follows, replies, favorites, and other people’s reactions to your tweets, including the clicks on the links and hashtags in those tweets. Your Twitter engagement rate is your engagement figure divided by the number of impressions on the tweet.

In [10]:
# add engagement rate to the dataframe
basetable_engr = basetable_engr.withColumn('eng_rate', ((basetable_engr['favorite_count'] + basetable_engr['retweet_count'])/basetable_engr['user_followers']))


Check if the dependent variable has no null values, this could be the case if the user has no followers. In this case, we will set the value of the engagement rate to zero.

In [11]:
# check missing values for the dependent variable
df = basetable_engr.select(F.col('eng_rate'))
df.select([F.count(F.when(F.isnull(c), c)).alias(c) for c in df.columns]).toPandas()



Unnamed: 0,eng_rate
0,60


In [13]:
# handle the null values in the dependent variable for the created dataframe
df = df.fillna(0)

# inspect
df.select([F.count(F.when(F.isnull(c), c)).alias(c) for c in df.columns]).toPandas()

Unnamed: 0,eng_rate
0,0


In [14]:
# handle the dependent variable in the basetable
basetable_engr = basetable_engr.fillna({'eng_rate': 0})

# 3. Feature engineering

In order to predict the engagement of a tweet, we will create some additional features in order to improve the performance of our model. For most of these features, we will first creata a function. The following features will be created:

    1) number of words
    2) number of hashtags
    3) number of tags
    4) number of emojis
    5) get the number of exclamation marks
    6) the month
    7) day of the month
    8) day of the week
    9) hour of the day
    10) The number of upper case words
    11) tweeted quote
    12) presence of a symbol
    13) The age of the account
    14) The number of media elements
    15) The media type present
    16) The number of text characters in the tweet
    17) Indicator if the account is verified

These features will be used next to some variables that are already present in the dataset.

    1) For the number of words, we can just use the function F.size() and apply it to the tokenized text. 

In [15]:
# 2) Define function to count hashtags
def get_hashtags(text):
    counter = 0
    for letter in text:
        if letter == "#":
            counter += 1
    return(counter)

In [16]:
# 3) Define function to count tags
def get_tags(text):
    counter = 0
    for letter in text:
        if letter == "@":
            counter += 1
    return(counter)

In [17]:
# 4) Define a function to get the number of emojis
def emoji_counter(text):
    nr_emojis = emojis.count(text)
    return(nr_emojis)

In [18]:
# 5) Define function to count exclamation marks
def get_exclamation_marks(text):
    counter = 0
    for letter in text:
        if letter ==  "!":
            counter += 1
    return(counter)

    6) the month
    7) day of the month
    8) day of the week
    9) hour of the day

We saw how to create these variables when solving the questions for this assignment. The same code will be used here. 

In [19]:
# 10) Define number of upper case words
def get_upper_case_words(text):
    counter = 0
    
    ## Tokenize
    word_tokens = word_tokenize(text)

    ## Check for uppercase words
    for word in word_tokens:
        if word.isupper():
            counter += 1
    return(counter)

In [20]:
# 11) Define a function that indicates if the tweet was a quote
def tweeted_quote_indicator(quoted):
    quote = 0
    if quoted == True:
        quote = 1
    return quote

In [21]:
# 12) define a function that indicates the presence of a symbol
def symbol_indicator(symbols):
    symbol = 0
    if(symbols > 0):
        symbol = 1
    return symbol

    13) Define the age of the account. This is defined as the number of days since the account has been created and the last day of scraping (2022-10-11). The last day of scraping was calculated in the exploration phase of the data. For this variable, we will use the function datediff.

In [22]:
# 14) define a function to get help get the number of media types included in the tweet
def adjust_nr_media(number):
    if number == -1:
        number = 0
        
    return number

In [23]:
# 15) define a function to get the first media element
def get_media_type(media):
    if media == None:
        media = 'no_media'
    else:
        media = media[0]
       
    return media

In [24]:
# 16) define a function to get the first media element
def get_nr_text_characters(text_range):
    number = text_range[1] - text_range[0]  
    return number



In [25]:
# 17) Look if the user is a verified user
def verified_ind(verified):
    indicator = 0
    if verified == True:
        indicator = 1
    return indicator

In [26]:
# register the functions as an udf
get_upper_case_words_UDF = F.udf(get_upper_case_words, IntegerType()) 
emoji_counter_udf = F.udf(emoji_counter, IntegerType())
get_hashtags_udf = F.udf(get_hashtags, IntegerType())
get_tags_udf = F.udf(get_tags, IntegerType())
get_exclamation_marks_UDF = F.udf(get_exclamation_marks, IntegerType())
tweeted_quote_indicator_UDF = F.udf(tweeted_quote_indicator, IntegerType())
symbol_indicator_udf = F.udf(symbol_indicator, IntegerType())
adjust_nr_media_udf = F.udf(adjust_nr_media, IntegerType())
get_media_type_udf = F.udf(get_media_type, StringType())
get_nr_text_characters_udf = F.udf(get_nr_text_characters, IntegerType())
verified_ind_udf = F.udf(verified_ind, IntegerType())

In [27]:
# Already define some text cleaning steps in order to define the correct number of words.

# define puncutation and stopwords
PUNCTUATION = [char for char in punctuation if char not in ["!", "@", "#"]]
STOPWORDS = stopwords.words("english")

# define function to remove punctuation
def remove_punct(text):
    ## Remove punctuation
    text = "".join([char for char in text if char not in PUNCTUATION])
    return(text)

# define function to remove stopwords
def remove_stops(text_tokenized):
    # remove stopwords
    text_tokenized = [word for word in text_tokenized if word not in STOPWORDS]
    return(text_tokenized)

# register as udf
remove_punct_UDF = F.udf(remove_punct, StringType())
remove_stops_UDF = F.udf(remove_stops, ArrayType(StringType()))

In [28]:
# create the final basetable for our analysis
basetable_engr_final = basetable_engr.withColumn("num_emojis", emoji_counter_udf(F.col("full_text")))\
                            .withColumn('upper_case_words', get_upper_case_words_UDF('full_text'))\
                            .withColumn("text_lower", F.lower("full_text")) \
                            .withColumn("text_cleaned", remove_punct_UDF("text_lower")) \
                            .withColumn("text_tokenized", F.split("text_cleaned", " ")) \
                            .withColumn("text_tokenized_no_stops", remove_stops_UDF("text_tokenized")) \
                            .withColumn("num_words", F.size("text_tokenized_no_stops")) \
                            .withColumn("num_hashtags", get_hashtags_udf("text_tokenized_no_stops")) \
                            .withColumn("num_mentions", get_tags_udf("text_tokenized_no_stops")) \
                            .withColumn('nr_exlcamations', get_exclamation_marks_UDF('text_tokenized_no_stops'))\
                            .withColumn("week_day", F.date_format(F.col("tweet_created"), "E"))\
                            .withColumn("hour", F.date_format(F.col("tweet_created"), "H").cast('string'))\
                            .withColumn("month", F.date_format(F.col("tweet_created"), "M"))\
                            .withColumn("day_month", F.date_format(F.col("tweet_created"), "d"))\
                            .withColumn('quoted_ind', tweeted_quote_indicator_UDF('quoted'))\
                            .withColumn('symbol_ind', F.size('symbols'))\
                            .withColumn('symbol_ind', symbol_indicator_udf('symbol_ind'))\
                            .withColumn('user_age_days', F.datediff(F.lit("2022-10-11"), F.col("user_created")))\
                            .withColumn('verified', verified_ind_udf('user_verified'))\
                            .withColumn("nr_media_elements", F.size("media_type"))\
                            .withColumn("nr_media_elements", adjust_nr_media_udf("nr_media_elements"))\
                            .withColumn("media_type", get_media_type_udf('media_type'))\
                            .withColumn("nr_text_char", get_nr_text_characters_udf('text_range'))\
                            .drop('tweet_created')\
                            .drop('quoted')\
                            .drop('symbols')\
                            .drop('user_created')\
                            .drop('user_verified')\
                            .drop('display_text_range')\
                            .drop('text_lower')\
                            .drop('text_cleaned')\
                            .drop('text_tokenized')\
                            .drop('text_tokenized_no_stops')\
                            .drop('text_range')\
                            .filter("num_words > 0")



In [29]:
# inspect the data
basetable_engr_final.show(5)

+----------+--------------+--------------------+--------+-------------+--------------+--------------+--------------+-----------------+--------------------+----------+----------------+---------+------------+------------+---------------+--------+----+-----+---------+----------+----------+-------------+--------+-----------------+------------+
|media_type|favorite_count|           full_text|language|retweet_count|user_followers|user_following|   screen_name|nr_tweets_by_user|            eng_rate|num_emojis|upper_case_words|num_words|num_hashtags|num_mentions|nr_exlcamations|week_day|hour|month|day_month|quoted_ind|symbol_ind|user_age_days|verified|nr_media_elements|nr_text_char|
+----------+--------------+--------------------+--------+-------------+--------------+--------------+--------------+-----------------+--------------------+----------+----------------+---------+------------+------------+---------------+--------+----+-----+---------+----------+----------+-------------+--------+------

# 3. Text Cleaning 

In [30]:
import emoji

Now, we are going to clean the text of the twitter data. Then, we can use this cleaned text to extract features about the sensitivity of the tweet.

Here, we remove numbers, punctuation, urls... Further, we transform the emojis to words. Emojis are at the very core of communication over social channels. One small image can completely describe one or more human emotions. A naive thing to do during pre-processing would be to remove all emojis. This could result in significant loss of meaning.
A good way to achieve this is to replace the emoji with corresponding text explaining the emoji.

In [31]:
# define function to clean text
def clean_text(string):
    
    # define numbers
    NUMBERS = '0123456789'
    PUNCT = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'
    
    # convert text to lower case
    cleaned_string = string.lower()
    
    # remove URLS
    cleaned_string = re.sub(r'http\S+', ' ', cleaned_string)
    
    # replace emojis by words
    cleaned_string = emoji.demojize(cleaned_string)
    cleaned_string = cleaned_string.replace(":"," ").replace("_"," ")
    cleaned_string = ' '.join(cleaned_string.split())
    
    # remove numbers
    cleaned_string = "".join([char for char in cleaned_string if char not in NUMBERS])
    
    # remove punctuation
    cleaned_string = "".join([char for char in cleaned_string if char not in PUNCT])
    
    # remove words conisting of one character (or less)
    cleaned_string = ' '.join([w for w in cleaned_string.split() if len(w) > 1])
    
    # return
    return(cleaned_string)

In [32]:
# convert to udf
clean_text_udf = F.udf(clean_text, StringType())

In [33]:
# clean string
basetable_engr_final = basetable_engr_final.withColumn("cleaned_text", clean_text_udf(F.col("full_text")))

In [30]:
#basetable_engr_final.select("full_text", "cleaned_text").limit(5).toPandas()

                                                                                

Unnamed: 0,full_text,cleaned_text
0,#BeforeIWasVegan I ate dairy cheese &amp; Cadb...,beforeiwasvegan ate dairy cheese amp cadbury m...
1,#CocaCola #cocacolacruelty #cocacoladairyfarm ...,cocacola cocacolacruelty cocacoladairyfarm fai...
2,#Gunpowder &amp; Lead #ManCave #SoyCandle | Ha...,gunpowder amp lead mancave soycandle hand pour...
3,#ICYMI from Folio.YVR Magazine: Folio.YVR Issu...,icymi from folioyvr magazine folioyvr issue lo...
4,#ICYMI ☆ NUDESTIX: On-the-Go Palettes &amp; P...,icymi nudestix onthego palettes amp products f...


Next, we tokenize the text. Then, we remove stop words. Finally, we use a spelling correction library to correct the spelling of the tweet data. This library from the textblob package is based on the Levenshtein distance. To correct the spelling, we first define a helper function.

In [31]:
#from textblob import TextBlob
# define helper function for spelling
#correct_spelling_udf = F.udf(lambda tokens: [TextBlob(token).correct() for token in tokens], ArrayType(StringType()))



In [32]:
#tokenize the cleaned_text variable 
#tokenizer = Tokenizer(inputCol="cleaned_text", outputCol="tokens")
#basetable_engr_final = tokenizer.transform(basetable_engr_final)

#remove stop words 
#remover = StopWordsRemover(inputCol="tokens", outputCol="clean_tokens")
#basetable_engr_final = remover.transform(basetable_engr_final)
#basetable_engr_final.select('tokens', 'clean_tokens').show()

# correct spelling
#basetable_engr_final = basetable_engr_final.withColumn("tokens_stemmed", correct_spelling_udf("clean_tokens"))

In [33]:
# inspect the data
#basetable_engr_final.select('clean_tokens', 'tokens_stemmed').show()

Now that the text is cleaned, we can derive the sentiment of the text. In this next section, we derive the sentiment, the subjectivity and the polarity. These are the tree final features we add to our model.

In [34]:
# define the function to extract the sentiment
def get_sentiment(sentence):

    # initialize sentiment analyzer
    sid_obj = SentimentIntensityAnalyzer()

    # get sentiment dict
    sentiment_dict = sid_obj.polarity_scores(sentence)
    
    # get positive sentiment score
    pos_sentiment = sentiment_dict["pos"]
    
    # return positive sentiment score
    return(pos_sentiment)

# define function to get polarity score of text 
def get_polarity(row):
    textBlob_review = TextBlob(row)
    return textBlob_review.sentiment[0]

# define function to get subjectivity score of text 
def get_subjectivity(row):
    textBlob_review = TextBlob(row)
    return textBlob_review.sentiment[1]


# register the functions as udf
get_sentiment_udf = F.udf(get_sentiment, DoubleType())
get_polarity_udf = F.udf(get_polarity, DoubleType())
get_subjectivity_udf = F.udf(get_subjectivity, DoubleType())

In [35]:
# Create the final basetable for our analysis
basetable_engr_final = basetable_engr_final.withColumn("sentiment", get_sentiment_udf(F.col("cleaned_text")))\
                                .withColumn('polarity', get_polarity_udf(F.col('cleaned_text')))\
                                .withColumn('subjectivity', get_subjectivity_udf(F.col('cleaned_text')))

# 4. Basetable creation

In [36]:
# Inspect the structure of the basetable:
basetable_engr_final.printSchema()

root
 |-- media_type: string (nullable = true)
 |-- favorite_count: long (nullable = true)
 |-- full_text: string (nullable = true)
 |-- language: string (nullable = true)
 |-- retweet_count: long (nullable = true)
 |-- user_followers: long (nullable = true)
 |-- user_following: long (nullable = true)
 |-- screen_name: string (nullable = true)
 |-- nr_tweets_by_user: long (nullable = true)
 |-- eng_rate: double (nullable = false)
 |-- num_emojis: integer (nullable = true)
 |-- upper_case_words: integer (nullable = true)
 |-- num_words: integer (nullable = false)
 |-- num_hashtags: integer (nullable = true)
 |-- num_mentions: integer (nullable = true)
 |-- nr_exlcamations: integer (nullable = true)
 |-- week_day: string (nullable = true)
 |-- hour: string (nullable = true)
 |-- month: string (nullable = true)
 |-- day_month: string (nullable = true)
 |-- quoted_ind: integer (nullable = true)
 |-- symbol_ind: integer (nullable = true)
 |-- user_age_days: integer (nullable = true)
 |-- ve

    - Check for missing values. If there are some - handle them (delete, impute,..). In this exercise: write code to handle them even if they aren't present.
    - Adjust datatypes where needed
    - Remove unnecessary columns
    - Add data pre-processing steps to the pipeline

## 4.1 Drop unnecessary columns

In [37]:
# drop (the 3 last variables were used to define the dependent variable)
basetable_engr_final = basetable_engr_final.drop('full_text')\
                                .drop('screen_name')\
                                .drop('language')\
                                .drop('cleaned_text')\
                                .drop('retweet_count')\
                                .drop('tokens_stemmed')\
                                .drop('favorite_count')\
                                .drop('user_followers')

#.drop('tokens')\
#.drop('clean_tokens')\

In [38]:
# inspect the data
basetable_engr_final.printSchema()

root
 |-- media_type: string (nullable = true)
 |-- user_following: long (nullable = true)
 |-- nr_tweets_by_user: long (nullable = true)
 |-- eng_rate: double (nullable = false)
 |-- num_emojis: integer (nullable = true)
 |-- upper_case_words: integer (nullable = true)
 |-- num_words: integer (nullable = false)
 |-- num_hashtags: integer (nullable = true)
 |-- num_mentions: integer (nullable = true)
 |-- nr_exlcamations: integer (nullable = true)
 |-- week_day: string (nullable = true)
 |-- hour: string (nullable = true)
 |-- month: string (nullable = true)
 |-- day_month: string (nullable = true)
 |-- quoted_ind: integer (nullable = true)
 |-- symbol_ind: integer (nullable = true)
 |-- user_age_days: integer (nullable = true)
 |-- verified: integer (nullable = true)
 |-- nr_media_elements: integer (nullable = true)
 |-- nr_text_char: integer (nullable = true)
 |-- sentiment: double (nullable = true)
 |-- polarity: double (nullable = true)
 |-- subjectivity: double (nullable = true)



In [39]:
# inspect the data
#basetable_engr_final.show(5)

[Stage 57:>                                                         (0 + 1) / 1]

+----------+--------------+-----------------+--------------------+----------+----------------+---------+------------+------------+---------------+--------+----+-----+---------+----------+----------+-------------+--------+-----------------+------------+---------+-------------------+------------------+
|media_type|user_following|nr_tweets_by_user|            eng_rate|num_emojis|upper_case_words|num_words|num_hashtags|num_mentions|nr_exlcamations|week_day|hour|month|day_month|quoted_ind|symbol_ind|user_age_days|verified|nr_media_elements|nr_text_char|sentiment|           polarity|      subjectivity|
+----------+--------------+-----------------+--------------------+----------+----------------+---------+------------+------------+---------------+--------+----+-----+---------+----------+----------+-------------+--------+-----------------+------------+---------+-------------------+------------------+
|     photo|          2497|            15906|0.008379888268156424|         0|               6|

                                                                                

## 4.2 Handle missing values

We see that there are no missing values in our dataset, so this step is completed.

In [39]:
# check number of missing values per column
for col in basetable_engr_final.columns:
    
    # look at the perecentage of null values
    print("Number of null values for variable", col,":", basetable_engr_final.filter(F.col(col).isNull()).count())

                                         

Number of null values for variable media_type : 0
Number of null values for variable user_following : 0
Number of null values for variable nr_tweets_by_user : 0
Number of null values for variable eng_rate : 0
Number of null values for variable num_emojis : 0
Number of null values for variable upper_case_words : 0
Number of null values for variable num_words : 0
Number of null values for variable num_hashtags : 0
Number of null values for variable num_mentions : 0
Number of null values for variable nr_exlcamations : 0
Number of null values for variable week_day : 0
Number of null values for variable hour : 0
Number of null values for variable month : 0
Number of null values for variable day_month : 0
Number of null values for variable quoted_ind : 0
Number of null values for variable symbol_ind : 0
Number of null values for variable user_age_days : 0
Number of null values for variable verified : 0
Number of null values for variable nr_media_elements : 0
Number of null values for variabl

In [41]:
# check missing values
basetable_engr_final.select([F.count(F.when(F.isnan(c), c)).alias(c) for c in basetable_engr_final.columns]).toPandas().head()


                                                                                

Unnamed: 0,media_type,user_following,nr_tweets_by_user,eng_rate,num_emojis,upper_case_words,num_words,num_hashtags,num_mentions,nr_exlcamations,...,day_month,quoted_ind,symbol_ind,user_age_days,verified,nr_media_elements,nr_text_char,sentiment,polarity,subjectivity
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


Further, we see that our variables all have the correct datatype.

# 5. Modelling

## 5.1 Pipelines

**Import required transformers and estimators**

In [57]:
# import pyspark ml packages
from pyspark.ml import Pipeline
from pyspark.ml.feature import StopWordsRemover, StandardScaler, Word2Vec
from pyspark.ml.feature import OneHotEncoder, VectorAssembler, StringIndexer, VectorIndexer
from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor

# import models and evaluator
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

**Create a vector for each type of features (numeric, categorical)**

#### handle numeric features

In [48]:
# define the numeric variables
num_cols = ['user_following', 'nr_tweets_by_user', 'num_emojis', 'upper_case_words',
           'num_words', 'num_hashtags', 'num_mentions', 'nr_exlcamations',  'user_age_days', 
           'nr_media_elements', 'nr_text_char']

# define the assembler
num_vec_assembler = VectorAssembler(inputCols=num_cols, outputCol="num_features")



In [49]:
# define the standard scaler for the linear regression
SS = StandardScaler(inputCol="num_features", outputCol="scaled_numeric_features")

#### handle categorical features

First, we transform the text variables in our dataset into numerical categories

In [50]:
# define the categorical variables
cat_cols = ['media_type', 'hour', 'week_day', 'month', 'day_month', 'quoted_ind',
           'symbol_ind', 'verified']

# create an object of StringIndexer class for each categorical variable
SI_media = StringIndexer(inputCol= 'media_type', outputCol= 'media_type_index').setHandleInvalid("skip")
SI_hour = StringIndexer(inputCol= 'hour', outputCol= 'hour_index').setHandleInvalid("skip")
SI_week_day = StringIndexer(inputCol= 'week_day', outputCol= 'week_day_index').setHandleInvalid("skip")
SI_month = StringIndexer(inputCol= 'month', outputCol= 'month_index').setHandleInvalid("skip")
SI_day_month = StringIndexer(inputCol= 'day_month', outputCol= 'day_month_index').setHandleInvalid("skip")




Next, we use a VectorAssembler to assemble all indexed variables together.

In [51]:
# define the assembler
cat_cols_indexed = ['media_type_index', 'hour_index', 'week_day_index', 'month_index', 
                    'day_month_index','quoted_ind', 'symbol_ind', 'verified']
cat_vec_assembler = VectorAssembler(inputCols= cat_cols_indexed, outputCol="cat_features")


Define a oneHotEncoder in order to add to the pipeline if you perform linear regression

In [52]:
# Define one hot encoded variables 
encoder = OneHotEncoder(inputCol= "cat_features", outputCol="cat_features_OHE")

Finally ,we use the VectorIndexer that helps index categorical features in datasets of vectors. This is required for tree based methods, which we will use.

In [53]:
# define indexer
indexer = VectorIndexer(inputCol= "cat_features", outputCol="cat_features_indexed")

In [54]:
# define assembler
VA_all_RF = VectorAssembler(inputCols=["cat_features_indexed", "num_features", "polarity", "subjectivity", "sensitivity"], outputCol="features")

In [58]:
# define the Linear regression model
LR = LinearRegression(labelCol="label", featuresCol="features")

In [59]:
# define the RF model
RF = RandomForestClassifier(labelCol="label", featuresCol="features")

**Define, fit and apply Pipeline on data**

Now, we create preprocessed data that we will use for both models.

In [61]:
# define pipeline stages
stages = [num_vec_assembler, SI_media, SI_hour, SI_week_day, SI_month, SI_day_month,
          cat_vec_assembler]

# define pipeline model and fit on data
pipeline_model_preprocess = Pipeline().setStages(stages).fit(basetable_engr_final)

# transform data by applying pipeline model on data
preprocessed_data = pipeline_model_preprocess.transform(basetable_engr_final)


**Select features and label**

In [63]:
# select features and labels
preprocessed_data = preprocessed_data.select(["num_features", "cat_features", "eng_rate"])
#preprocessed_data = preprocessed_data.select(["num_features", "eng_rate"])

# rename engagement rate to label
preprocessed_data = preprocessed_data.withColumnRenamed("eng_rate", "label")

In [60]:
# look at the data
preprocessed_data.show(5)

[Stage 287:>                                                        (0 + 1) / 1]

+--------------------+--------------------+
|        num_features|               label|
+--------------------+--------------------+
|(17,[0,1,4,10,12,...|                 0.0|
|(17,[0,1,3,4,10,1...|3.619254433586681E-4|
|(17,[0,1,4,10,13,...|7.874015748031496E-4|
|(17,[0,1,3,4,10,1...|0.003115264797507788|
|(17,[0,1,4,10,12,...|0.025477707006369428|
+--------------------+--------------------+
only showing top 5 rows



                                                                                

#### Now create the preprocessed data for the linear regression

In [65]:
# define pipeline stages
stages = [encoder]

# define pipeline model and fit on data
preprocessed_data_LR = Pipeline().setStages(stages).fit(preprocessed_data)

# transform data by applying pipeline model on data
preprocessed_data_LR = pipeline_model_preprocess.transform(preprocessed_data_LR)



IllegalArgumentException: requirement failed: Column cat_features must be of type numeric but was actually of type struct<type:tinyint,size:int,indices:array<int>,values:array<double>>.

#### Now create the preprocessed data for the random forest

In [None]:
# define pipeline stages
stages = [indexer]

# define pipeline model and fit on data
preprocessed_data_RF = Pipeline().setStages(stages).fit(preprocessed_data)

# transform data by applying pipeline model on data
preprocessed_data_RF = pipeline_model_preprocess.transform(preprocessed_data_RF)

In [48]:
# check number of observations in both datasets
#print("Number of observations in the training set: %s " % train.count())
#print("Number of observations in the test set: %s " %test.count())

                                                                                

Number of observations in the training set: 32657 




Number of observations in the test set: 14287 


                                                                                

## 5.2 Perform linear regression

In [None]:
#### 5.2.1 Split dataset into training and test set

In [None]:
# split data in train and test set
train, test = preprocessed_data_LR.randomSplit([0.7, 0.3], seed= 100)

**Define scaler to standardize numeric features**

The `StandardScaler` should only be performed on the trainingset, because an equal `mean` and `standard deviation` between the training- and testset need to be assumed to avoid methodological mistakes. However, as we will only train a random forest model it is not needed to standardize.

In [None]:
# define scaler
scaler = StandardScaler(inputCol="num_features", outputCol="num_features_scaled")

**Create final feature vector containing the indexed categorical features and scaled numeric features**

Most ML algorithms can only handle one vector as input, so we add all preprocessed columns (here two vectors) into one vector, using the ``VectorAssembler`` again.

In [49]:
# define vector assembler
VA_all_LR = VectorAssembler(inputCols=["cat_features_OHE", "scaled_numeric_features", "polarity", "subjectivity", "sensitivity"], outputCol="features")

**Define, fit and apply Pipeline on data**

In [50]:
# define pipeline model stages
stages = [scaler, vec_assembler]

# define pipeline model
pipeline_model_split = Pipeline().setStages(stages)

# fit pipeline model on training data
pipeline_model_split = pipeline_model_split.fit(train)

# transform training set
train_final = pipeline_model_split.transform(train).select(["label", "features"])

# transform test set
test_final = pipeline_model_split.transform(test).select(["label", "features"])

In [51]:
# inspect the final data
train_final.select('features').show(10)

[Stage 132:>                                                        (0 + 1) / 1]

+--------------------+
|            features|
+--------------------+
|(17,[0,1,2,3,4,10...|
|(17,[0,1,2,3,4,10...|
|(17,[0,1,2,3,4,10...|
|(17,[0,1,2,3,4,10...|
|(17,[0,1,2,3,4,10...|
|(17,[0,1,2,3,4,10...|
|(17,[0,1,2,3,4,10...|
|(17,[0,1,2,3,4,10...|
|(17,[0,1,2,3,4,10...|
|(17,[0,1,2,3,4,10...|
+--------------------+
only showing top 10 rows



                                                                                

# 5. Modelling

In this part, we are going to train a random forest model on the training data. This is an advanced machine learning model of which we expect good performance. After we train this model, we will evaluate it. As the goal of our analysis was to find the driving factors behind the engagement rate, we will use the shap values to determine variable importance. This way, we know in which direction the variable affects the engagement rate.

In [52]:
# import models and evaluator
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

In [53]:
# define rf model
RF = RandomForestRegressor(labelCol="label", featuresCol="features", numTrees=500)

In [54]:
# fit model on training set
rf_model = RF.fit(train_final)

[Stage 148:>                                                        (0 + 8) / 9]

22/12/08 20:58:28 ERROR Executor: Exception in task 5.0 in stage 148.0 (TID 3356)
scala.MatchError: [null,1.0,(17,[0,1,3,4,10,13,14,15,16],[14.0,69.0,10.0,9.0,718.0,66.0,0.452,0.5,0.7])] (of class org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema)
	at org.apache.spark.ml.PredictorParams.$anonfun$extractInstances$1(Predictor.scala:81)
	at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.TraversableOnce.foldLeft(TraversableOnce.scala:199)
	at scala.collection.TraversableOnce.foldLeft$(TraversableOnce.scala:192)
	at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1431)
	at scala.collection.TraversableOnce.aggregate(TraversableOnce.scala:260)
	at scala.collection.TraversableOnce.aggregate$(TraversableOnce.scala:260)
	at scala.collection.AbstractIterator.a

22/12/08 20:58:29 WARN TaskSetManager: Lost task 8.0 in stage 148.0 (TID 3359) (192.168.1.22 executor driver): TaskKilled (Stage cancelled)


[Stage 148:>                                                        (0 + 7) / 9]

22/12/08 20:58:31 WARN PythonUDFRunner: Incomplete task 0.0 in stage 148 (TID 3351) interrupted: Attempting to kill Python Worker
22/12/08 20:58:31 WARN PythonUDFRunner: Incomplete task 0.0 in stage 148 (TID 3351) interrupted: Attempting to kill Python Worker
22/12/08 20:58:31 WARN PythonUDFRunner: Incomplete task 0.0 in stage 148 (TID 3351) interrupted: Attempting to kill Python Worker
22/12/08 20:58:31 WARN PythonUDFRunner: Incomplete task 4.0 in stage 148 (TID 3355) interrupted: Attempting to kill Python Worker
22/12/08 20:58:31 WARN PythonUDFRunner: Incomplete task 4.0 in stage 148 (TID 3355) interrupted: Attempting to kill Python Worker
22/12/08 20:58:31 WARN PythonUDFRunner: Incomplete task 0.0 in stage 148 (TID 3351) interrupted: Attempting to kill Python Worker
22/12/08 20:58:31 WARN TaskSetManager: Lost task 4.0 in stage 148.0 (TID 3355) (192.168.1.22 executor driver): TaskKilled (Stage cancelled)
22/12/08 20:58:31 WARN TaskSetManager: Lost task 0.0 in stage 148.0 (TID 3351) (

[Stage 148:>                                                        (0 + 5) / 9]

22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 2.0 in stage 148 (TID 3353) interrupted: Attempting to kill Python Worker


Py4JJavaError: An error occurred while calling o704.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 5 in stage 148.0 failed 1 times, most recent failure: Lost task 5.0 in stage 148.0 (TID 3356) (192.168.1.22 executor driver): scala.MatchError: [null,1.0,(17,[0,1,3,4,10,13,14,15,16],[14.0,69.0,10.0,9.0,718.0,66.0,0.452,0.5,0.7])] (of class org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema)
	at org.apache.spark.ml.PredictorParams.$anonfun$extractInstances$1(Predictor.scala:81)
	at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.TraversableOnce.foldLeft(TraversableOnce.scala:199)
	at scala.collection.TraversableOnce.foldLeft$(TraversableOnce.scala:192)
	at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1431)
	at scala.collection.TraversableOnce.aggregate(TraversableOnce.scala:260)
	at scala.collection.TraversableOnce.aggregate$(TraversableOnce.scala:260)
	at scala.collection.AbstractIterator.aggregate(Iterator.scala:1431)
	at org.apache.spark.rdd.RDD.$anonfun$aggregate$2(RDD.scala:1198)
	at org.apache.spark.SparkContext.$anonfun$runJob$6(SparkContext.scala:2322)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:829)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2860)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:952)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2228)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2323)
	at org.apache.spark.rdd.RDD.$anonfun$aggregate$1(RDD.scala:1200)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
	at org.apache.spark.rdd.RDD.aggregate(RDD.scala:1193)
	at org.apache.spark.ml.tree.impl.DecisionTreeMetadata$.buildMetadata(DecisionTreeMetadata.scala:125)
	at org.apache.spark.ml.tree.impl.RandomForest$.run(RandomForest.scala:274)
	at org.apache.spark.ml.regression.RandomForestRegressor.$anonfun$train$1(RandomForestRegressor.scala:150)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.regression.RandomForestRegressor.train(RandomForestRegressor.scala:134)
	at org.apache.spark.ml.regression.RandomForestRegressor.train(RandomForestRegressor.scala:43)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:151)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:115)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: scala.MatchError: [null,1.0,(17,[0,1,3,4,10,13,14,15,16],[14.0,69.0,10.0,9.0,718.0,66.0,0.452,0.5,0.7])] (of class org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema)
	at org.apache.spark.ml.PredictorParams.$anonfun$extractInstances$1(Predictor.scala:81)
	at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.TraversableOnce.foldLeft(TraversableOnce.scala:199)
	at scala.collection.TraversableOnce.foldLeft$(TraversableOnce.scala:192)
	at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1431)
	at scala.collection.TraversableOnce.aggregate(TraversableOnce.scala:260)
	at scala.collection.TraversableOnce.aggregate$(TraversableOnce.scala:260)
	at scala.collection.AbstractIterator.aggregate(Iterator.scala:1431)
	at org.apache.spark.rdd.RDD.$anonfun$aggregate$2(RDD.scala:1198)
	at org.apache.spark.SparkContext.$anonfun$runJob$6(SparkContext.scala:2322)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	... 1 more


22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 2.0 in stage 148 (TID 3353) interrupted: Attempting to kill Python Worker
22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 2.0 in stage 148 (TID 3353) interrupted: Attempting to kill Python Worker
22/12/08 20:58:32 WARN TaskSetManager: Lost task 2.0 in stage 148.0 (TID 3353) (192.168.1.22 executor driver): TaskKilled (Stage cancelled)
22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 6.0 in stage 148 (TID 3357) interrupted: Attempting to kill Python Worker


Traceback (most recent call last):
  File "/usr/local/Cellar/jupyterlab/3.4.8/libexec/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/daemon.py", line 187, in manager
  File "/usr/local/Cellar/jupyterlab/3.4.8/libexec/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/daemon.py", line 74, in worker
  File "/usr/local/Cellar/jupyterlab/3.4.8/libexec/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 730, in main
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
  File "/usr/local/Cellar/jupyterlab/3.4.8/libexec/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 595, in read_int
    raise EOFError
EOFError
[Stage 148:>                                                        (0 + 4) / 9]

22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 1.0 in stage 148 (TID 3352) interrupted: Attempting to kill Python Worker
22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 6.0 in stage 148 (TID 3357) interrupted: Attempting to kill Python Worker
22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 1.0 in stage 148 (TID 3352) interrupted: Attempting to kill Python Worker
22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 3.0 in stage 148 (TID 3354) interrupted: Attempting to kill Python Worker
22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 3.0 in stage 148 (TID 3354) interrupted: Attempting to kill Python Worker
22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 3.0 in stage 148 (TID 3354) interrupted: Attempting to kill Python Worker
22/12/08 20:58:32 WARN TaskSetManager: Lost task 3.0 in stage 148.0 (TID 3354) (192.168.1.22 executor driver): TaskKilled (Stage cancelled)


Traceback (most recent call last):
  File "/usr/local/Cellar/jupyterlab/3.4.8/libexec/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/daemon.py", line 187, in manager
  File "/usr/local/Cellar/jupyterlab/3.4.8/libexec/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/daemon.py", line 74, in worker
  File "/usr/local/Cellar/jupyterlab/3.4.8/libexec/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 730, in main
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
  File "/usr/local/Cellar/jupyterlab/3.4.8/libexec/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 595, in read_int
    raise EOFError
EOFError
Traceback (most recent call last):
  File "/usr/local/Cellar/jupyterlab/3.4.8/libexec/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/daemon.py", line 187, in manager
  File "/usr/local/Cellar/jupyterlab/3.4.8/libexec/lib/python3.10/site-packages/pysp

22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 6.0 in stage 148 (TID 3357) interrupted: Attempting to kill Python Worker
22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 1.0 in stage 148 (TID 3352) interrupted: Attempting to kill Python Worker
22/12/08 20:58:32 WARN TaskSetManager: Lost task 6.0 in stage 148.0 (TID 3357) (192.168.1.22 executor driver): TaskKilled (Stage cancelled)
22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 1.0 in stage 148 (TID 3352) interrupted: Attempting to kill Python Worker
22/12/08 20:58:32 WARN TaskSetManager: Lost task 1.0 in stage 148.0 (TID 3352) (192.168.1.22 executor driver): TaskKilled (Stage cancelled)
22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 7.0 in stage 148 (TID 3358) interrupted: Attempting to kill Python Worker
22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 7.0 in stage 148 (TID 3358) interrupted: Attempting to kill Python Worker
22/12/08 20:58:32 WARN PythonUDFRunner: Incomplete task 7.0 in stage 1

Traceback (most recent call last):
  File "/usr/local/Cellar/jupyterlab/3.4.8/libexec/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/daemon.py", line 187, in manager
  File "/usr/local/Cellar/jupyterlab/3.4.8/libexec/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/daemon.py", line 74, in worker
  File "/usr/local/Cellar/jupyterlab/3.4.8/libexec/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 730, in main
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
  File "/usr/local/Cellar/jupyterlab/3.4.8/libexec/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 595, in read_int
    raise EOFError
EOFError


In [None]:
# get predictions on test set
rf_preds = rf_model.transform(test_final)

In [None]:
# create evaluator object
rfEvaluator = RegressionEvaluator(labelCol = 'label', predictionCol = 'prediction')

In [None]:
# get different metrics using your created evaluator object
rfsq = rfEvaluator.evaluate(rf_preds, {rfEvaluator.metricName: 'r2'})
rfmae = rfEvaluator.evaluate(rf_preds, {rfEvaluator.metricName: 'mae'})
rfrmse = rfEvaluator.evaluate(rf_preds, {rfEvaluator.metricName: 'rmse'})
rfmse = rfEvaluator.evaluate(rf_preds, {rfEvaluator.metricName: 'mse'})

In [None]:
# inspect evaluation metrics
print('R^2  : %g' % rfsq)
print('MAE  : %g' % rfmae)
print('RMSE : %g' % rfrmse)
print('MSE  : %g' % rfmse)

#### Use cross validation to optimize the model

In [None]:
# define the parameter space
rfparamGrid = (ParamGridBuilder().addGrid(RF.maxDepth, [3, 7, 10])
                                   .addGrid(RF.maxBins, [15, 25, 40])
                                   .addGrid(RF.numTrees, [5, 25, 60])
                                   .build())

In [None]:
# perform 5-fold cross validation
rfcv_model = CrossValidator(estimator=RF, #random forest model we created before
                          estimatorParamMaps=rfparamGrid, 
                          evaluator=rfEvaluator,
                          numFolds=5)

In [None]:
# run cross validation on trainig set
rfcv_model = rfcv_model.fit(train_final)

In [None]:
# inspect best params
print("best max depth: %s" %rfcv_model.bestModel._java_obj.getMaxDepth())
print("best max bins: %s" %rfcv_model.bestModel._java_obj.getMaxBins())
print("best num trees: %s" %rfcv_model.bestModel._java_obj.getNumTrees())

In [None]:
# get predictions of best model of on test set (cv_model automatically uses best model)
rfcv_preds = rfcv_model.transform(test_final)

In [None]:
# define evaluator
rfcv_evaluator = RegressionEvaluator(labelCol="label", predictionCol="prediction")

In [None]:
# check which of both algorithms is the best:
print("RANDOM FOREST WITHOUT CV:")
print('  R^2  : %g' % rfsq)
print('  MAE  : %g' % rfmae)
print('  RMSE : %g' % rfrmse)
print('  MSE  : %g' % rfmse)
print("------------------")
print("RANDOM FOREST WITH CV:")
print('  R^2  : %g' % rfcv_evaluator.evaluate(rfcv_preds, {rfcv_evaluator.metricName: 'r2'}))
print('  MAE  : %g' % rfcv_evaluator.evaluate(rfcv_preds, {rfcv_evaluator.metricName: 'mae'}))
print('  RMSE : %g' % rfcv_evaluator.evaluate(rfcv_preds, {rfcv_evaluator.metricName: 'rmse'}))
print('  MSE  : %g' % rfcv_evaluator.evaluate(rfcv_preds, {rfcv_evaluator.metricName: 'mse'}))

# 6. Interpretation Random Forest

As we want use the random forest model, we would also like some insights into the drivers behind the algorithm.

In [None]:
!pip3 install shap
#!pip3 install numpy
import shap

# explain the model's predictions using SHAP
explainer = shap.Explainer(rf_model)

In [None]:
# plot the results
shap_values = explainer.shap_values(X_test)
shap.initjs()
shap.summary_plot(shap_values,  X_test)