## Import Libraries

In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import StopWordsRemover, HashingTF, IDF, RegexTokenizer, VectorAssembler
from pyspark.ml.classification import NaiveBayes
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from sklearn.datasets import fetch_20newsgroups

## Create Spark Session

In [2]:
def create_spark():
    """Build the SparkSession used by this notebook.

    The session runs with master ``local[*]`` under the application name
    "TestSuite"; the options listed in ``settings`` are applied to the
    builder before ``getOrCreate()`` is called.

    Returns:
        The SparkSession produced by ``getOrCreate()``.
    """
    # Configuration applied to the builder, in this order.
    settings = {
        'spark.sql.shuffle.partitions': '4',
        'spark.default.parallelism': '4',
        'spark.sql.session.timeZone': 'UTC',
        'spark.ui.enabled': 'false',
        'spark.app.id': 'Test',
        'spark.driver.host': 'localhost',
        'spark.executor.memory': '8g',
    }

    builder = SparkSession.builder.master("local[*]").appName("TestSuite")
    for option, setting in settings.items():
        builder = builder.config(key=option, value=setting)

    return builder.getOrCreate()

In [3]:
# Notebook-wide Spark session; later cells use it to create DataFrames and register UDFs.
spark = create_spark()

## Load Dataset

In [8]:
# The 14 newsgroups used in this experiment.
categories = [
    'alt.atheism',
    'comp.graphics',
    'comp.os.ms-windows.misc',
    'comp.sys.ibm.pc.hardware',
    'comp.sys.mac.hardware',
    'comp.windows.x',
    'misc.forsale',
    'rec.autos',
    'rec.motorcycles',
    'rec.sport.baseball',
    'rec.sport.hockey',
    'talk.politics.guns',
    'talk.politics.mideast',
    'talk.politics.misc',
]

# Fetch the selected groups with headers, footers and quotes removed,
# shuffled with a fixed random_state.
newsgroups = fetch_20newsgroups(
    subset='all',
    categories=categories,
    remove=('headers', 'footers', 'quotes'),
    shuffle=True,
    random_state=42,
)

# Pair every document with its integer target and load the pairs into Spark
# as a two-column DataFrame: text, label.
news_spark_df = spark.createDataFrame(
    zip(newsgroups.data, newsgroups.target.tolist()),
    schema=['text', 'label'],
)

## Data Preprocessing

In [None]:
# --- Stage definitions -------------------------------------------------------
# Column flow: text -> tokens -> filtered_tokens -> raw_features -> features
#              -> vectorized_features
tokenizer = RegexTokenizer(inputCol='text', outputCol='tokens', pattern='\\W+')
stop_words_remover = StopWordsRemover(inputCol='tokens', outputCol='filtered_tokens')
hashing_tf = HashingTF(inputCol='filtered_tokens', outputCol='raw_features')
idf = IDF(inputCol='raw_features', outputCol='features')
# NOTE(review): the classifier cell reads 'features', not 'vectorized_features',
# so this assembler's output looks unused downstream -- confirm before removing.
assembler = VectorAssembler(inputCols=['features'], outputCol='vectorized_features')

# --- Apply the stages in order -----------------------------------------------
tokenized_df = tokenizer.transform(news_spark_df)
filtered_df = stop_words_remover.transform(tokenized_df)
tf_data = hashing_tf.transform(filtered_df)

# NOTE(review): the IDF is fitted on the full dataset here, i.e. before the
# train/test split performed in a later cell.
idf_model = idf.fit(tf_data)
tfidf_data = idf_model.transform(tf_data)

assembled_data = assembler.transform(tfidf_data)

## Split Dataset

In [None]:
# 80/20 random split with a fixed seed.
# NOTE(review): the IDF model was fitted on the whole dataset before this split,
# so rows that end up in `test` contributed to the IDF weights.
train, test = assembled_data.randomSplit(weights=[0.8, 0.2], seed=42)

## Build Naive Bayes model

In [None]:
# Multinomial Naive Bayes trained on the 'features' column against 'label'.
nb = NaiveBayes(
    modelType='multinomial',
    labelCol='label',
    featuresCol='features',
)
nb_model = nb.fit(train)

## Evaluation

In [None]:
# Score the held-out rows with the fitted model.
predicted = nb_model.transform(test)

# Index -> newsgroup name lookup, exposed to Spark as the "mapLabel" UDF.
label_map = dict(enumerate(newsgroups.target_names))
map_label_udf = spark.udf.register("mapLabel", lambda x: label_map[x])

# Attach readable names for both the true label and the predicted class.
predicted = (
    predicted
    .withColumn('label_name', map_label_udf(predicted['label']))
    .withColumn('prediction_name', map_label_udf(predicted['prediction']))
)

# Overall accuracy of 'prediction' against 'label'.
evaluator = MulticlassClassificationEvaluator(
    metricName='accuracy',
    labelCol='label',
    predictionCol='prediction',
)
accuracy = evaluator.evaluate(predicted)
print(f'Accuracy: {accuracy:.4f}') # Accuracy: 0.7349

## Show Predictions

In [None]:
# Show a few true-vs-predicted pairs, ordered by the numeric label.
# The readable 'label_name' / 'prediction_name' columns added in the evaluation
# cell are included next to the numeric indices; without them the table only
# shows class numbers. The original 'label' and 'prediction' columns are kept.
selected_predictions = (
    predicted
    .select('label', 'label_name', 'prediction', 'prediction_name')
    .orderBy('label')
)
# truncate=False so long newsgroup names are not cut off at 20 characters.
selected_predictions.show(5, truncate=False)

In [None]:
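# NOTE(review): pasted sample of label_name vs. prediction_name pairs (50 rows);
# presumably captured from an earlier run -- it is not the live output of the
# `show(5)` call in the cell above.
# +--------------------+--------------------+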
# |          label_name|     prediction_name|
# +--------------------+--------------------+
# |         alt.atheism|    rec.sport.hockey|
# |       comp.graphics|    rec.sport.hockey|
# |       comp.graphics|    rec.sport.hockey|
# |comp.os.ms-window...|    rec.sport.hockey|
# |comp.sys.mac.hard...|    rec.sport.hockey|
# |           rec.autos|    rec.sport.hockey|
# |     rec.motorcycles|    rec.sport.hockey|
# |  rec.sport.baseball|    rec.sport.hockey|
# |  talk.politics.guns|    rec.sport.hockey|
# |talk.politics.mid...|    rec.sport.hockey|
# |talk.politics.mid...|    rec.sport.hockey|
# |talk.politics.mid...|    rec.sport.hockey|
# |talk.politics.mid...|    rec.sport.hockey|
# |        misc.forsale|        misc.forsale|
# |        misc.forsale|       comp.graphics|
# |    rec.sport.hockey|    rec.sport.hockey|
# |        misc.forsale|    rec.sport.hockey|
# |comp.os.ms-window...|comp.sys.mac.hard...|
# |  talk.politics.misc|  talk.politics.misc|
# |     rec.motorcycles|     rec.motorcycles|
# |       comp.graphics|       comp.graphics|
# |comp.sys.mac.hard...|comp.sys.mac.hard...|
# |     rec.motorcycles|     rec.motorcycles|
# |     rec.motorcycles|  talk.politics.misc|
# |comp.os.ms-window...|comp.sys.ibm.pc.h...|
# |talk.politics.mid...|talk.politics.mid...|
# |    rec.sport.hockey|    rec.sport.hockey|
# |           rec.autos|           rec.autos|
# |comp.os.ms-window...|    rec.sport.hockey|
# |  rec.sport.baseball|  rec.sport.baseball|
# |comp.os.ms-window...|comp.os.ms-window...|
# |    rec.sport.hockey|talk.politics.mid...|
# |  talk.politics.guns|  talk.politics.guns|
# |  talk.politics.misc|     rec.motorcycles|
# |comp.sys.mac.hard...|         alt.atheism|
# |      comp.windows.x|      comp.windows.x|
# |comp.sys.ibm.pc.h...|comp.sys.ibm.pc.h...|
# |      comp.windows.x|      comp.windows.x|
# |         alt.atheism|         alt.atheism|
# |  talk.politics.guns|  talk.politics.guns|
# |         alt.atheism|         alt.atheism|
# |  talk.politics.misc|  talk.politics.misc|
# |  talk.politics.guns|  talk.politics.guns|
# |     rec.motorcycles|           rec.autos|
# |  rec.sport.baseball|  rec.sport.baseball|
# |      comp.windows.x|      comp.windows.x|
# |  rec.sport.baseball|  rec.sport.baseball|
# |       comp.graphics|       comp.graphics|
# |talk.politics.mid...|talk.politics.mid...|
# |  rec.sport.baseball|  rec.sport.baseball|
# +--------------------+--------------------+