In [0]:
from pyspark.sql import SparkSession
# Obtain the Spark session for this notebook (an already-active one is reused).
spark = (
    SparkSession.builder
    .appName('NBGridSearch')
    .getOrCreate()
)

In [0]:
# Read the cleaned news part-file; no header option is passed, so the columns
# arrive with Spark's positional names (_c0, _c1, ...).
news_csv_path = 'CleanedNews.csv/part-00000-e0c20413-d9a2-4ae3-bc41-a77b460c6a58-c000.csv'
df = spark.read.csv(news_csv_path, inferSchema=True)

# Give each positional column its meaningful name, in file order.
for position, column_name in enumerate(['claim', 'claimant', 'articles', 'label']):
    df = df.withColumnRenamed('_c{}'.format(position), column_name)

# Show the resulting column names and inferred types.
df.printSchema()

root
 |-- claim: string (nullable = true)
 |-- claimant: string (nullable = true)
 |-- articles: string (nullable = true)
 |-- label: integer (nullable = true)



In [0]:
from pyspark.ml.feature import Tokenizer,StopWordsRemover,CountVectorizer,IDF#,StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline

# Text featurisation of the 'articles' column, stage by stage:
#   articles -> token_text -> stop_token -> c_vec -> tf_idf -> features
# Each stage reads the previous stage's output column.
tokenizer = Tokenizer(inputCol='articles', outputCol='token_text')
stop_remove = StopWordsRemover(
    inputCol=tokenizer.getOutputCol(),
    outputCol='stop_token',
)
count_vec = CountVectorizer(
    inputCol=stop_remove.getOutputCol(),
    outputCol='c_vec',
)
idf = IDF(
    inputCol=count_vec.getOutputCol(),
    outputCol='tf_idf',
)
# Wrap the single TF-IDF vector into the 'features' column.
assembler = VectorAssembler(
    inputCols=[idf.getOutputCol()],
    outputCol='features',
)

feature_stages = [tokenizer, stop_remove, count_vec, idf, assembler]
pipe = Pipeline(stages=feature_stages)

# NOTE(review): the pipeline is fitted on the complete DataFrame, so the
# CountVectorizer and IDF stages see every row, including rows that are
# held out later for testing. Consider fitting on the training rows only.
pipelineFit = pipe.fit(df)
dataset = pipelineFit.transform(df)

In [0]:
# Seeded 80/20 split of the featurised rows into training and test sets.
split_weights = [0.8, 0.2]
training, test = dataset.randomSplit(split_weights, seed=0)

In [0]:
from pyspark.ml.tuning import ParamGridBuilder,CrossValidator
from pyspark.ml.classification import NaiveBayes
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Naive Bayes classifier with its default parameters; only the smoothing
# value is tuned below.
nb = NaiveBayes()

# Candidate smoothing values to compare (one parameter map per value).
gridSearch = ParamGridBuilder().addGrid(nb.smoothing,[0.0,0.2,0.4,0.6,0.8,1.0]).build()

# Candidates are ranked by weighted precision of the 'prediction' column.
cvEvaluater = MulticlassClassificationEvaluator(metricName="weightedPrecision",predictionCol="prediction")

# A fixed seed makes the fold assignment, and therefore the reported
# metrics, reproducible across runs. Without it CrossValidator draws a new
# random fold split every time the cell is executed.
cv = CrossValidator(estimator=nb,estimatorParamMaps=gridSearch,evaluator=cvEvaluater,seed=0)

# Fit every candidate on the training split and keep the best model.
cvModel = cv.fit(training)

In [0]:
# Display the cross-validation metrics stored on the fitted model.
# NOTE(review): presumably one averaged value per smoothing candidate, in the
# same order as the parameter grid — confirm against the pyspark docs.
cvModel.avgMetrics

[0.519050392151873,
 0.569111548730109,
 0.5696235436736343,
 0.5697474534866707,
 0.5708469714860119,
 0.5707518265656956]

In [0]:
from sklearn.metrics import classification_report
# Score the held-out test split with the model chosen by cross-validation.
prediction = cvModel.transform(test)

# Collect labels and predictions together in ONE action. Two separate
# collect() calls evaluate the scoring plan twice and rely on both jobs
# returning rows in the same order; a single collect keeps every label
# paired with its own prediction.
label_prediction_rows = prediction.select('label','prediction').collect()

# Unpack the Row objects into flat sequences of plain scalars for sklearn.
y_true = [row['label'] for row in label_prediction_rows]
# Cast so both sequences hold integer class ids (assumes predictions are
# integral class indices, as the class ids in the report suggest).
y_pred = [int(row['prediction']) for row in label_prediction_rows]

# Per-class precision / recall / F1 on the test split.
print (classification_report(y_true,y_pred))

              precision    recall  f1-score   support

           0       0.66      0.65      0.65      1494
           1       0.60      0.57      0.58      1292
           2       0.19      0.23      0.21       349

    accuracy                           0.57      3135
   macro avg       0.48      0.48      0.48      3135
weighted avg       0.58      0.57      0.58      3135

