In [4]:
from __future__ import print_function

# $example on$
from pyspark.ml.classification import NaiveBayes
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# $example off$
from pyspark.sql import SparkSession

# Location of the LIBSVM-formatted training data.
# This must be a raw string. In a normal literal, "\U" (from "\Users") starts a
# \UXXXXXXXX unicode escape, which is a SyntaxError in Python 3. "\d", "\D" and
# "\M" are also invalid escape sequences.
DATA_PATH = r"C:\Users\dues1\Desktop\dataset-lab3\dataset-lab3\MyText.txt"

if __name__ == "__main__":
    spark = SparkSession\
        .builder\
        .appName("NaiveBayesExample")\
        .getOrCreate()

    # Release the Spark session even if loading, training or evaluation fails.
    try:
        # $example on$
        # Load training data in LIBSVM format.
        data = spark.read.format("libsvm").load(DATA_PATH)

        # Split the data into train and test (weights 0.6 / 0.4). The fixed
        # seed makes the split reproducible across runs.
        train, test = data.randomSplit([0.6, 0.4], 1234)

        # Create the trainer and set its parameters. smoothing=1.0 is additive
        # (Laplace) smoothing. NOTE(review): the multinomial model expects
        # non-negative feature values -- confirm this holds for MyText.txt.
        naive = NaiveBayes(smoothing=1.0, modelType="multinomial")

        # Train the model.
        model = naive.fit(train)

        # Predict on the test set and display some example rows.
        pred = model.transform(test)
        pred.show()

        # Compute accuracy on the test set.
        evaluator = MulticlassClassificationEvaluator(labelCol="label",
                                                      predictionCol="prediction",
                                                      metricName="accuracy")
        accuracy = evaluator.evaluate(pred)
        print("Test set accuracy = " + str(accuracy))
        # $example off$
    finally:
        spark.stop()


+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(100,[0,1,2,4,5,1...|[-132.08841666947...|[1.0,2.6112979117...|       0.0|
|  0.0|(100,[0,1,2,4,5,1...|[-132.08841666947...|[1.0,2.6112979117...|       0.0|
|  0.0|(100,[0,1,2,4,5,1...|[-132.08841666947...|[1.0,2.6112979117...|       0.0|
|  0.0|(100,[0,1,2,4,8,9...|[-101.47798502549...|[1.0,1.4000530024...|       0.0|
|  0.0|(100,[0,1,2,4,8,9...|[-101.47798502549...|[1.0,1.4000530024...|       0.0|
|  0.0|(100,[0,1,2,4,8,9...|[-101.47798502549...|[1.0,1.4000530024...|       0.0|
|  0.0|(100,[0,1,2,4,8,9...|[-101.47798502549...|[1.0,1.4000530024...|       0.0|
|  0.0|(100,[0,1,2,8,24,...|[-76.850690653443...|[1.0,1.1800593050...|       0.0|
|  0.0|(100,[0,1,2,8,24,...|[-76.850690653443...|[1.0,1.1800593050...|       0.0|
|  0.0|(100,[0,1