In [1]:
from collections import defaultdict
import numpy as np
from pyspark.ml.linalg import SparseVector
from pyspark.sql.functions import explode
from pyspark import SparkFiles
from pyspark.sql import Row

import bz2
import json
import time
from pyspark.ml import Pipeline
from pyspark.ml.feature import * # CountVectorizer, Tokenizer, RegexTokenizer, HashingTF
from pyspark.ml.regression import * # RandomForestRegressor, LinearRegression, DecisionTreeRegressor
from pyspark.ml.evaluation import RegressionEvaluator

VBox()

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
0,application_1607295858388_0001,pyspark,idle,Link,Link,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [2]:
def timeit(method):
    '''
    Decorator to time functions.
    '''
    def timed(*args, **kw):
        ts = time.time()
        result = method(*args, **kw)
        te = time.time()

        print('%r took %2.2f sec\n' % (method.__name__, te-ts))
              
        return result
    return timed

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [3]:
import re
# DATAFILE_PATTERN = '^(.+),"(.+)",(.*),(.*),(.*)'
ID_PATTERN = '"id":(.*?(?=,|}))'
UPS_PATTERN = '"ups":(.*?(?=,|}))'
BODY_PATTERN = '"body":(.*?(?=,|}))'
# DOWNS_PATTERN = '"downs":(.*?(?=,|}))'
SCORE_PATTERN = '"score":(.*?(?=,|}))'
# CONTROVERSIALITY_PATTERN = '"controversiality":(.*?(?=,|}))'

def removeQuotes(s):
    """ Remove quotation marks from an input string
    Args:
        s (str): input string that might have the quote "" characters
    Returns:
        str: a string without the quote characters
    """
    return ''.join(i for i in s if i!='"')

def parseDatafileLine(datafileLine):
    """ Parse a line of the data file using the specified regular expression pattern
    Args:
        datafileLine (str): input string that is a line from the data file
    Returns:
        tuple: a tuple including the parsed results using the given regular expression and without the quote characters
    """
    id_match = re.search(ID_PATTERN, datafileLine.decode('utf-8'))
    ups_match = re.search(UPS_PATTERN, datafileLine.decode('utf-8'))
    body_match = re.search(BODY_PATTERN, datafileLine.decode('utf-8'))
    score_match = re.search(SCORE_PATTERN, datafileLine.decode('utf-8'))
    
    if (id_match is None) or (ups_match is None) or (body_match is None) or (score_match is None):
        print('Invalid datafile line: %s' % datafileLine)
        return (datafileLine, -1)
    else:
        viralness = 0
        if int(score_match.group(1)) < -10 or int(score_match.group(1)) > 10:
            viralness = 1
        comment = (id_match.group(1), int(ups_match.group(1)), removeQuotes(body_match.group(1)), int(score_match.group(1)), viralness)
        return (comment, 1)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [9]:
import sys
import os
from pyspark import SparkFiles

#RC_PATH = '/FileStore/shared_uploads/ddk1@andrew.cmu.edu/RC_2007_10'
RC_PATH = 's3://dsml-vasu-simar-daniel/RC_2015-0*'

def parseData(path):
    """ Parse a data file
    Args:
        filename (str): input file name of the data file
    Returns:
        RDD: a RDD of parsed lines
    """
#     sc.addFile(path)
    return (sc
            .textFile(path, 4, 0)
            .map(parseDatafileLine)
            .cache())
@timeit
def loadData(path):
    """ Load a data file
    Args:
        path (str): input file name of the data file
    Returns:
        RDD: a RDD of parsed valid lines
    """

    raw = parseData(path).cache()
    
    failed = (raw
              .filter(lambda s: s[1] == -1)
              .map(lambda s: s[0]))
    for line in failed.take(10):
        print('%s - Invalid datafile line: %s' % (path, line))
    
    deleted = (raw
             .filter(lambda s: s[0][2] == '[deleted]')
             .map(lambda s: s[0]))
    
    valid = (raw
             .filter(lambda s: s[1] == 1)
             .filter(lambda s: s[0][2] != '[deleted]')
             .map(lambda s: s[0])
             .cache())
    print('%s - Read %d lines, successfully parsed %d lines, failed to parse %d lines, %d lines were deleted' % (path,
                                                                                                                 raw.count(),
                                                                                                                 valid.count(),
                                                                                                                 failed.count(),
                                                                                                                 deleted.count()))
    return valid

reddit = loadData(RC_PATH)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

s3://dsml-vasu-simar-daniel/RC_2015-0* - Read 156758730 lines, successfully parsed 147697364 lines, failed to parse 0 lines, 9061366 lines were deleted
'loadData' took 211.51 sec

In [10]:
sentenceDF = reddit.toDF().selectExpr("_1 as id", "_2 as ups", "_3 as body", "_4 as score", "_5 as viralness")
sentenceDF.show(n=5)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------+---+--------------------+-----+---------+
|       id|ups|                body|score|viralness|
+---------+---+--------------------+-----+---------+
|"cnas8zv"| 14|Most of us have s...|   14|        1|
|"cnas8zw"|  3|But Mill's career...|    3|        0|
|"cnas8zx"|  1|Mine uses a strai...|    1|        0|
|"cnas8zz"|  2|           Very fast|    2|        0|
|"cnas900"|  6|The guy is a prof...|    6|        0|
+---------+---+--------------------+-----+---------+
only showing top 5 rows

In [11]:
split_regex = r'\W+'
linebreak_regex = r'\\r\\n\\r\\n'

def simpleTokenize(string):
    """ A simple implementation of input string tokenization
    Args:
        string (str): input string
    Returns:
        list: a list of tokens
    """
    linebreak_removed_string = re.sub(linebreak_regex, " ", string)
    return list(filter(None, re.split(split_regex, linebreak_removed_string.lower())))

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [15]:
# stopfile = "https://raw.githubusercontent.com/10605/data/master/hw1/stopwords.txt"
# sc.addFile(stopfile)
# stopwords = set(sc.textFile("file://" + SparkFiles.get("stopwords.txt")).collect())
stopwords = set(sc.textFile("s3://dsml-vasu-simar-daniel/stopwords.txt").collect())
print('These are the stopwords: %s' % stopwords)

def tokenize(string):
    """ An implementation of input string tokenization that excludes stopwords
    Args:
        string (str): input string
    Returns:
        list: a list of tokens without stopwords
    """
    return list(filter(lambda word: word not in stopwords,simpleTokenize(string)))



# pattern = "\\W"
# # tokenizer = RegexTokenizer(inputCol="body", outputCol="words", pattern=pattern)
# tokenizer = Tokenizer(inputCol="body", outputCol="words")
# wordsDF = tokenizer.transform(sentenceDF)

# # Remove stop words
# remover = StopWordsRemover(inputCol="words", outputCol="filtered_words")
# wordsFilteredDF = remover.transform(wordsDF)

# # Remove body and words since they will no longer be used
# wordsFilteredDF = wordsFilteredDF.select('id','ups','filtered_words','score','viralness')

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

These are the stopwords: {'out', 'we', 'was', 'how', 'myself', 'for', 'they', 'about', 'then', 'both', 'so', 'don', 'as', 'any', 'after', 'you', 'why', 'been', 'where', 'by', 'yourself', 'a', 'did', 'their', 'doing', 'be', 'further', 'ours', 'now', 'am', 'her', 'yourselves', 'that', 'what', 'my', 'to', 'not', 'own', 'there', 'this', 'each', 'all', 'more', 'me', 'which', 'himself', 'nor', 'other', 'who', 'same', 'at', 'such', 't', 'up', 'than', 'can', 'too', 'these', 'while', 'before', 'ourselves', 'he', 'i', 'our', 'its', 'but', 'with', 'because', 'those', 'the', 'it', 'hers', 'just', 'over', 'between', 'had', 'does', 'have', 'and', 'some', 'or', 'only', 'when', 'below', 'in', 'if', 'theirs', 'again', 'his', 'whom', 'above', 'should', 'itself', 'themselves', 'until', 'are', 'she', 'will', 'from', 'into', 'no', 'your', 'few', 'herself', 'of', 'has', 'down', 'were', 'once', 'having', 'them', 'under', 'him', 'do', 'on', 'an', 'yours', 'being', 'off', 'very', 'through', 'most', 'against', 

In [16]:
redditRecToToken = reddit.map(lambda line: (line[0], line[1], tokenize(line[2]),line[3], line[4]))

print(redditRecToToken.take(5))

def countTokens(vendorRDD):
    """ Count and return the number of tokens
    Args:
        vendorRDD (RDD of (recordId, tokenizedValue)): Pair tuple of record ID to tokenized output
    Returns:
        count: count of all tokens
    """
    # TODO: Uncomment the template below and replace <FILL IN> with appropriate code
    recordCount = vendorRDD.map(lambda line: len(line[0]))
    recordSum = recordCount.sum()
    return recordSum

totalTokens = countTokens(redditRecToToken)
print('There are %s tokens in the combined datasets' % totalTokens)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

[('"cnas8zv"', 14, ['us', 'family', 'members', 'like', 'family', 'like'], 14, 1), ('"cnas8zw"', 3, ['mill', 'career', 'way', 'better', 'bentham', 'like'], 3, 0), ('"cnas8zx"', 1, ['mine', 'uses', 'strait', 'razor'], 1, 0), ('"cnas8zz"', 2, ['fast'], 2, 0), ('"cnas900"', 6, ['guy', 'professional'], 6, 0)]
There are 1329276276 tokens in the combined datasets

In [17]:
@timeit
def term_frequency(df, inputCol, outputCol, hashFeatures=None):
    '''
    Returns a DataFrame object containing a new row with the extracted features. 
    Passing hashed=True will return a Featured Hashed matrix.
    
    @params:
        df - DataFrame
        inputCol - name of input column from DataFrame to find features
        outputCol - name of the column to save the features
        hashFeatures - number of features for HashingTF, if None will perform 
            CountVectorization
    '''
    
    # since the number of features was not passed perform standard CountVectorization
    if hashFeatures is None:
        cv = CountVectorizer(inputCol=inputCol, outputCol=outputCol)
        feature_extractor = cv.fit(df)
    # otherwise perform a feature extractor with 
    else:
        feature_extractor = HashingTF(\
                              inputCol=inputCol, outputCol=outputCol, numFeatures=hashFeatures)
    
    # create a new DataFrame using either feature extraction method
    return feature_extractor.transform(df)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [18]:
wordsFilteredDF = spark.createDataFrame(redditRecToToken).toDF("id", "ups", "filtered_words", "score", "viralness")

# Feature Hash the comment content
# number of features for Feature Hash matrix, reccomended too use power of 2
hashDF = term_frequency(\
    df=wordsFilteredDF, inputCol="filtered_words", outputCol="features", hashFeatures=1024)

# Display snippet of new DataFrame
hashDF.select('filtered_words','features').show(5)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

'term_frequency' took 0.36 sec

+--------------------+--------------------+
|      filtered_words|            features|
+--------------------+--------------------+
|[us, family, memb...|(1024,[368,386,45...|
|[mill, career, wa...|(1024,[102,205,31...|
|[mine, uses, stra...|(1024,[120,423,55...|
|              [fast]|  (1024,[863],[1.0])|
| [guy, professional]|(1024,[931,955],[...|
+--------------------+--------------------+
only showing top 5 rows

In [19]:
@timeit
def random_forest_regression(df, featuresCol, labelCol):
    '''
    Returns a DataFrame containing a column of predicted values of the labelCol.
    Predict the output of labelCol using values in featuresCol y = rf(x).
    
    @params:
        df - DataFrame
        featuresCol - input features, x
        labelCol - output variable, y
    '''
    # split the training and test data using the holdout method
    (trainingData, testData) = df.randomSplit([0.8, 0.2])
    
    # create the random forest regressor, limit number of trees to ten
    dtr = RandomForestRegressor(\
       featuresCol=featuresCol, labelCol=labelCol)
    
    # fit the training data to the regressor to create the model
    model = dtr.fit(trainingData)
    
    # create a DataFrame contained a column with predicted values of the labelCol
    predictions = model.transform(testData)
    
    return predictions

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [22]:
# train random forest regression
rfpredictions = random_forest_regression(df=hashDF,featuresCol="features",labelCol="viralness")

# compute the error
evaluator = RegressionEvaluator(labelCol="viralness", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(rfpredictions)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

'random_forest_regression' took 1777.49 sec

Root Mean Squared Error (RMSE) on test data = 0.266187

In [25]:
rfpredictions.show(100)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------+---+--------------------+-----+---------+--------------------+-------------------+
|       id|ups|      filtered_words|score|viralness|            features|         prediction|
+---------+---+--------------------+-----+---------+--------------------+-------------------+
|"cnas90j"|  2|[religion, doesn,...|    2|        0|(1024,[50,60,255,...|0.07515022047761011|
|"cnas90k"|  2|        [hey, rocky]|    2|        0|(1024,[594,724],[...|0.07515022047761011|
|"cnas90o"|  1| [agreed, get, sale]|    1|        0|(1024,[511,567,88...|0.07515022047761011|
|"cnas90t"|  1|[chappelle, https...|    1|        0|(1024,[257,322,38...|0.07806345262040863|
|"cnas90y"|  2|[thought, wanted,...|    2|        0|(1024,[303,386,69...| 0.0756647834679763|
|"cnas912"|  3|[ll, try, find, g...|    3|        0|(1024,[95,159,174...|0.07056895981187288|
|"cnas914"|  1|        [like, idea]|    1|        0|(1024,[386,726],[...| 0.0756647834679763|
|"cnas918"|  3|[haha, guilty, ve...|    3|        0|(1024,[0

In [26]:
@timeit
def random_forest_regression_filtered(df, featuresCol, labelCol):
    '''
    Returns a DataFrame containing a column of predicted values of the labelCol.
    Predict the output of labelCol using values in featuresCol y = rf(x).
    
    @params:
        df - DataFrame
        featuresCol - input features, x
        labelCol - output variable, y
    '''
    
    print(df.count())
    df = df.filter((df.score >=0) & (df.score <10))
    print(df.show(10))
    print(df.count())
    # split the training and test data using the holdout method
    (trainingData, testData) = df.randomSplit([0.8, 0.2])
    
    # create the random forest regressor, limit number of trees to ten
    dtr = RandomForestRegressor(\
       featuresCol=featuresCol, labelCol=labelCol)
    
    # fit the training data to the regressor to create the model
    model = dtr.fit(trainingData)
    
    # create a DataFrame contained a column with predicted values of the labelCol
    predictions = model.transform(testData)
    
    return predictions

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [27]:
# train random forest regression
predictions = random_forest_regression_filtered(df=hashDF,featuresCol="features",labelCol="score")

# compute the error
evaluator = RegressionEvaluator(labelCol="score", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

147697364
+---------+---+--------------------+-----+---------+--------------------+
|       id|ups|      filtered_words|score|viralness|            features|
+---------+---+--------------------+-----+---------+--------------------+
|"cnas8zw"|  3|[mill, career, wa...|    3|        0|(1024,[102,205,31...|
|"cnas8zx"|  1|[mine, uses, stra...|    1|        0|(1024,[120,423,55...|
|"cnas8zz"|  2|              [fast]|    2|        0|  (1024,[863],[1.0])|
|"cnas900"|  6| [guy, professional]|    6|        0|(1024,[931,955],[...|
|"cnas901"|  1|   [great, question]|    1|        0|(1024,[116,131],[...|
|"cnas902"|  1|[ie, shiv, ghostb...|    1|        0|(1024,[114,200,36...|
|"cnas903"|  1|                 [d]|    1|        0|  (1024,[902],[1.0])|
|"cnas905"|  2|[know, describe, ...|    2|        0|(1024,[47,57,210,...|
|"cnas906"|  2|           [says, g]|    2|        0|(1024,[34,305],[1...|
|"cnas908"|  1|       [love, music]|    1|        0|(1024,[112,979],[...|
+---------+---+-------------

In [28]:
predictions.show(10)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------+---+--------------------+-----+---------+--------------------+------------------+
|       id|ups|      filtered_words|score|viralness|            features|        prediction|
+---------+---+--------------------+-----+---------+--------------------+------------------+
|"cnas900"|  6| [guy, professional]|    6|        0|(1024,[931,955],[...|2.0639390632929513|
|"cnas905"|  2|[know, describe, ...|    2|        0|(1024,[47,57,210,...|2.0471774635344673|
|"cnas90i"|  1|        [wheredugit]|    1|        0|  (1024,[187],[1.0])|2.0471774635344673|
|"cnas90j"|  2|[religion, doesn,...|    2|        0|(1024,[50,60,255,...|2.0471774635344673|
|"cnas90z"|  1|[slightly, strong...|    1|        0|(1024,[561,621,84...|2.0471774635344673|
|"cnas912"|  3|[ll, try, find, g...|    3|        0|(1024,[95,159,174...|2.0471774635344673|
|"cnas915"|  4|               [yes]|    4|        0|  (1024,[835],[1.0])|2.0471774635344673|
|"cnas918"|  3|[haha, guilty, ve...|    3|        0|(1024,[0,150,266,.

In [29]:
from pyspark.ml.classification import LogisticRegression
from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel, LogisticRegressionWithSGD
from pyspark.mllib.regression import LabeledPoint
@timeit
def logistic_regression(df, featuresCol, labelCol):
    '''
    Returns a DataFrame containing a column of predicted values of the labelCol.
    Predict the output of labelCol using values in featuresCol y = rf(x).
    
    @params:
        df - DataFrame
        featuresCol - input features, x
        labelCol - output variable, y
    '''
    # split the training and test data using the holdout method
    print(df.count())
    df = df.filter((df.score >=0) & (df.score <10))
    print(df.show(10))
    print(df.count())
    (trainingData, testData) = df.randomSplit([0.8, 0.2])
    
    
    # TODO: Uncomment the lines below and replace <FILL IN> with appropriate code
    # Given hyperparameters
    standardization = False
    elastic_net_param = 0.8
    reg_param = .3
    max_iter = 10

    lr = (LogisticRegression(featuresCol=featuresCol, labelCol=labelCol, regParam = reg_param, standardization = standardization, maxIter = max_iter,elasticNetParam = elastic_net_param))
  
    lr_model_basic = lr.fit(trainingData)

    trainingSummary = lr_model_basic.summary
    accuracy = trainingSummary.accuracy
    print(accuracy)
    # create a DataFrame contained a column with predicted values of the labelCol
    predictions = lr_model_basic.transform(testData)
    
    return predictions

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [30]:
# from pyspark.ml.classification import LogisticRegression
# from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel, LogisticRegressionWithSGD
# from pyspark.mllib.regression import LabeledPoint
@timeit
def logistic_regression_viral(df, featuresCol, labelCol):
    '''
    Returns a DataFrame containing a column of predicted values of the labelCol.
    Predict the output of labelCol using values in featuresCol y = rf(x).
    
    @params:
        df - DataFrame
        featuresCol - input features, x
        labelCol - output variable, y
    '''
    # split the training and test data using the holdout method
#     print(df.count())
#     df = df.filter((df.score >=0) & (df.score <10))
#     print(df.show(10))
#     print(df.count())
    (trainingData, testData) = df.randomSplit([0.8, 0.2])
    
    
    # TODO: Uncomment the lines below and replace <FILL IN> with appropriate code
    # Given hyperparameters
    standardization = False
    elastic_net_param = 0.8
    reg_param = .3
    max_iter = 10

    lr = (LogisticRegression(featuresCol=featuresCol, labelCol=labelCol, regParam = reg_param, standardization = standardization, maxIter = max_iter,elasticNetParam = elastic_net_param))
  
    lr_model_basic = lr.fit(trainingData)

    trainingSummary = lr_model_basic.summary
    accuracy = trainingSummary.accuracy
    print(accuracy)
    # create a DataFrame contained a column with predicted values of the labelCol
    predictions = lr_model_basic.transform(testData)
    
    return predictions

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [31]:
# train random forest regression
lrpredictions = logistic_regression(df=hashDF,featuresCol="features",labelCol="score")

# compute the error
evaluator = RegressionEvaluator(labelCol="score", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(lrpredictions)
print ("Root Mean Squared Error (RMSE) on test data = %g" % rmse)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

147697364
+---------+---+--------------------+-----+---------+--------------------+
|       id|ups|      filtered_words|score|viralness|            features|
+---------+---+--------------------+-----+---------+--------------------+
|"cnas8zw"|  3|[mill, career, wa...|    3|        0|(1024,[102,205,31...|
|"cnas8zx"|  1|[mine, uses, stra...|    1|        0|(1024,[120,423,55...|
|"cnas8zz"|  2|              [fast]|    2|        0|  (1024,[863],[1.0])|
|"cnas900"|  6| [guy, professional]|    6|        0|(1024,[931,955],[...|
|"cnas901"|  1|   [great, question]|    1|        0|(1024,[116,131],[...|
|"cnas902"|  1|[ie, shiv, ghostb...|    1|        0|(1024,[114,200,36...|
|"cnas903"|  1|                 [d]|    1|        0|  (1024,[902],[1.0])|
|"cnas905"|  2|[know, describe, ...|    2|        0|(1024,[47,57,210,...|
|"cnas906"|  2|           [says, g]|    2|        0|(1024,[34,305],[1...|
|"cnas908"|  1|       [love, music]|    1|        0|(1024,[112,979],[...|
+---------+---+-------------

In [33]:
lrpredictions.show(100)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------+---+--------------------+-----+---------+--------------------+--------------------+--------------------+----------+
|       id|ups|      filtered_words|score|viralness|            features|       rawPrediction|         probability|prediction|
+---------+---+--------------------+-----+---------+--------------------+--------------------+--------------------+----------+
|"cnas8zw"|  3|[mill, career, wa...|    3|        0|(1024,[102,205,31...|[0.04919917672070...|[0.04955881868249...|       1.0|
|"cnas8zx"|  1|[mine, uses, stra...|    1|        0|(1024,[120,423,55...|[0.04919917672070...|[0.04955881868249...|       1.0|
|"cnas900"|  6| [guy, professional]|    6|        0|(1024,[931,955],[...|[0.04919917672070...|[0.04955881868249...|       1.0|
|"cnas90a"|  2|[always, forget, ...|    2|        0|(1024,[119,277,46...|[0.04919917672070...|[0.04955881868249...|       1.0|
|"cnas90e"|  1|[haha, awesome, m...|    1|        0|(1024,[342,537,55...|[0.04919917672070...|[0.04955881868249

In [34]:
# train random forest regression
lrViralpredictions = logistic_regression_viral(df=hashDF,featuresCol="features",labelCol="viralness")

# compute the error
evaluator = RegressionEvaluator(labelCol="viralness", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(lrViralpredictions)
print ("Root Mean Squared Error (RMSE) on test data = %g" % rmse)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

0.9229822832333946
'logistic_regression_viral' took 578.81 sec

Root Mean Squared Error (RMSE) on test data = 0.277614

In [35]:
lrViralpredictions.show(10)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------+---+--------------------+-----+---------+--------------------+--------------------+--------------------+----------+
|       id|ups|      filtered_words|score|viralness|            features|       rawPrediction|         probability|prediction|
+---------+---+--------------------+-----+---------+--------------------+--------------------+--------------------+----------+
|"cnas901"|  1|   [great, question]|    1|        0|(1024,[116,131],[...|[2.48357455628972...|[0.92298228323339...|       0.0|
|"cnas903"|  1|                 [d]|    1|        0|  (1024,[902],[1.0])|[2.48357455628972...|[0.92298228323339...|       0.0|
|"cnas906"|  2|           [says, g]|    2|        0|(1024,[34,305],[1...|[2.48357455628972...|[0.92298228323339...|       0.0|
|"cnas90f"|  3|[completely, agre...|    3|        0|(1024,[1,60,392,4...|[2.48357455628972...|[0.92298228323339...|       0.0|
|"cnas90j"|  2|[religion, doesn,...|    2|        0|(1024,[50,60,255,...|[2.48357455628972...|[0.92298228323339