In [35]:
import findspark
findspark.init()

In [36]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, DoubleType
from nlp_preproc import preproc_pipeline

In [37]:
spark = SparkSession.builder.appName('hatespeech-detector').getOrCreate()

In [38]:
preproc_udf = udf(preproc_pipeline, StringType())

In [39]:
df = spark.read.csv('data/labelled_data.csv', header=True)

In [40]:
df = df.filter(df.tweet.isNotNull())

In [41]:
df = df.withColumn('text', preproc_udf(df['tweet']))

In [42]:
df = df.dropDuplicates(['text', 'class'])

In [43]:
df = df.withColumnRenamed('class', 'label')

In [45]:
df = df.withColumn('label', df['label'].cast(DoubleType()))

In [47]:
data_set = df.select(df['text'], df['label'])

In [48]:
train_df, test_df = data_set.randomSplit([0.8, 0.2])

In [49]:
train_df.count(), test_df.count()

(18638, 4667)

In [50]:
from pyspark.ml.feature import HashingTF, IDF, Tokenizer
from pyspark.ml import Pipeline
from pyspark.ml.classification import NaiveBayes
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder
from pyspark.ml.tuning import CrossValidator

In [51]:
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
idf = IDF(minDocFreq=3, inputCol="features", outputCol="idf")
nb = NaiveBayes()
pipeline = Pipeline(stages=[tokenizer, hashingTF, idf, nb])


paramGrid = ParamGridBuilder().addGrid(nb.smoothing, [0.0, 1.0]).build()


cv = CrossValidator(estimator=pipeline, 
                    estimatorParamMaps=paramGrid, 
                    evaluator=MulticlassClassificationEvaluator(), 
                    numFolds=4)

cvModel = cv.fit(train_df)

In [52]:
result = cvModel.transform(test_df)

In [53]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
evaluator.evaluate(result, {evaluator.metricName: "accuracy"})

0.79537175916006