<a href="https://colab.research.google.com/github/JeevithaR3/Online_News/blob/main/Online_news_NB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pyspark




In [None]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import NaiveBayes
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import Tokenizer, StopWordsRemover, CountVectorizer, IDF, StringIndexer, IndexToString
from pyspark.sql import SparkSession
from pyspark.sql.functions import col


In [None]:
# Start the Spark session used by every cell below, or reuse one that is
# already running in this kernel.
spark = (
    SparkSession.builder
    .appName("NewsCategoryNB")
    .getOrCreate()
)


In [None]:
# Load the raw dataset. The path is relative, so the JSON file must be in
# the notebook's working directory.
df = spark.read.json(path="News_Category_Dataset_v3.json")

# Quick look at the first few records to confirm the schema.
df.show(n=5)


+--------------------+---------+----------+--------------------+--------------------+--------------------+
|             authors| category|      date|            headline|                link|   short_description|
+--------------------+---------+----------+--------------------+--------------------+--------------------+
|Carla K. Johnson, AP|U.S. NEWS|2022-09-23|Over 4 Million Am...|https://www.huffp...|Health experts sa...|
|      Mary Papenfuss|U.S. NEWS|2022-09-23|American Airlines...|https://www.huffp...|He was subdued by...|
|       Elyse Wanshel|   COMEDY|2022-09-23|23 Of The Funnies...|https://www.huffp...|"Until you have a...|
|    Caroline Bologna|PARENTING|2022-09-23|The Funniest Twee...|https://www.huffp...|"Accidentally put...|
|      Nina Golgowski|U.S. NEWS|2022-09-22|Woman Who Called ...|https://www.huffp...|Amy Cooper accuse...|
+--------------------+---------+----------+--------------------+--------------------+--------------------+
only showing top 5 rows



In [None]:
# Keep only what the classifier needs: the headline (renamed to "text", the
# column the Tokenizer stage reads) and its category.
#
# Headlines that are null or contain no non-whitespace character are dropped.
# They carry no words to learn from, yet would otherwise still be split into
# train/test and scored, adding noise to both the model and the accuracy.
# rlike yields NULL for a NULL headline, and where() discards NULL rows.
data = (
    df.select(col("headline").alias("text"), col("category"))
      .where(col("text").rlike(r"\S"))
)
data.show(5)


+--------------------+---------+
|                text| category|
+--------------------+---------+
|Over 4 Million Am...|U.S. NEWS|
|American Airlines...|U.S. NEWS|
|23 Of The Funnies...|   COMEDY|
|The Funniest Twee...|PARENTING|
|Woman Who Called ...|U.S. NEWS|
+--------------------+---------+
only showing top 5 rows



In [None]:
# Encode the string "category" column as a numeric "label" column, which is
# the column the classifier is configured to train on.
label_indexer = StringIndexer(
    inputCol="category",
    outputCol="label",
)


In [None]:
# Text preparation, step 1: split each headline in "text" into word tokens.
tokenizer = Tokenizer(
    inputCol="text",
    outputCol="words",
)

# Text preparation, step 2: remove stop words from the token list.
stopremover = StopWordsRemover(
    inputCol="words",
    outputCol="filtered",
)


In [None]:
# TF-IDF featurisation: count the remaining tokens per headline ...
cv = CountVectorizer(
    inputCol="filtered",
    outputCol="rawFeatures",
)

# ... then re-weight those counts by inverse document frequency. The result
# lands in "features", the column the classifier consumes.
idf = IDF(
    inputCol="rawFeatures",
    outputCol="features",
)


In [None]:
# Naive Bayes classifier trained on the TF-IDF "features" vector against the
# indexed "label" column (all other hyper-parameters left at their defaults).
nb = NaiveBayes(
    labelCol="label",
    featuresCol="features",
)


In [None]:
# Assemble every step into a single pipeline. Order matters: each stage
# reads the column written by the one before it.
pipeline = Pipeline(
    stages=[
        label_indexer,  # category    -> label
        tokenizer,      # text        -> words
        stopremover,    # words       -> filtered
        cv,             # filtered    -> rawFeatures
        idf,            # rawFeatures -> features
        nb,             # features + label -> prediction
    ]
)


In [None]:
# Hold out 20% of the rows for evaluation; the fixed seed keeps the split
# repeatable between runs.
train, test = data.randomSplit(weights=[0.8, 0.2], seed=42)


In [None]:
# Fit all pipeline stages on the training split only, so the held-out rows
# stay unseen until evaluation.
model = pipeline.fit(dataset=train)


In [None]:
# Score the held-out split with the fitted pipeline.
predictions = model.transform(test)

# "prediction" is a numeric label index, which cannot be compared by eye with
# the string "category" column. Decode it back to a category name using the
# label list learned by the fitted StringIndexer (first pipeline stage), so
# the preview shows true and predicted categories side by side.
prediction_decoder = IndexToString(
    inputCol="prediction",
    outputCol="predicted_category",
    labels=model.stages[0].labels,
)
(
    prediction_decoder.transform(predictions)
    .select("text", "category", "predicted_category")
    .show(10)
)


+--------------------+-------------+----------+
|                text|     category|prediction|
+--------------------+-------------+----------+
|                    |     POLITICS|      41.0|
|"An International...|ENTERTAINMENT|       2.0|
|"Bible Believing"...|     RELIGION|       0.0|
|"Coming Out" in H...| QUEER VOICES|      18.0|
|"F*ck the Police"...|     POLITICS|       5.0|
|"How Do Asian Ame...|    EDUCATION|      19.0|
|"Jungle Book" Fan...|ENTERTAINMENT|       5.0|
|"New Hampshire": ...|     POLITICS|      39.0|
|"Scandal" Loves S...|     POLITICS|       8.0|
|"Sing" Is An Opti...|ENTERTAINMENT|      14.0|
+--------------------+-------------+----------+
only showing top 10 rows



In [None]:
# Overall accuracy on the held-out split: compares the true "label" index
# with the model's "prediction" index for every test row.
# (MulticlassClassificationEvaluator is imported in the notebook's import cell.)
evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="accuracy",
)
accuracy = evaluator.evaluate(predictions)
print("Test Accuracy =", accuracy)


Test Accuracy = 0.529135024119979


In [None]:
# A hand-written headline to try the model on; edit freely.
sample_text = "The stock market crashed after the new economic report."

# Wrap it in a one-row DataFrame whose only column is "text", the column
# the pipeline's tokenizer reads. No "category" column is needed to predict.
sample_df = spark.createDataFrame(data=[(sample_text,)], schema=["text"])

# Push the row through the fitted pipeline.
sample_pred = model.transform(sample_df)

# Display the full headline next to the predicted label index.
sample_pred.select(
    "text",
    "prediction",
).show(truncate=False)


+-------------------------------------------------------+----------+
|text                                                   |prediction|
+-------------------------------------------------------+----------+
|The stock market crashed after the new economic report.|9.0       |
+-------------------------------------------------------+----------+



In [None]:
# Get indexer labels
labels = model.stages[0].labels

pred_value = sample_pred.collect()[0]['prediction']
print("Predicted Category:", labels[int(pred_value)])


Predicted Category: BUSINESS


In [None]:
# List of sample news texts
samples = [
    "The government passed a new healthcare reform bill today.",
    "The football team won their third consecutive championship.",
    "Scientists discovered a new exoplanet similar to Earth.",
    "A major tech company announced a breakthrough AI chip.",
    "The actor won an award for best performance in a drama movie."
]

# Convert list into DataFrame
sample_df = spark.createDataFrame([(s,) for s in samples], ["text"])

# Run the model
sample_pred = model.transform(sample_df)

# Show predictions
sample_pred.select("text", "prediction").show(truncate=False)


+-------------------------------------------------------------+----------+
|text                                                         |prediction|
+-------------------------------------------------------------+----------+
|The government passed a new healthcare reform bill today.    |0.0       |
|The football team won their third consecutive championship.  |11.0      |
|Scientists discovered a new exoplanet similar to Earth.      |28.0      |
|A major tech company announced a breakthrough AI chip.       |29.0      |
|The actor won an award for best performance in a drama movie.|2.0       |
+-------------------------------------------------------------+----------+



In [None]:
# Category names learned by the fitted StringIndexer (first pipeline stage);
# the list position is the label index.
labels = model.stages[0].labels

# Score the sample headlines and bring the (text, prediction) rows to the driver.
scored_samples = model.transform(sample_df).select("text", "prediction")
preds = scored_samples.collect()

# Print each headline with its predicted category name.
for row in preds:
    category_index = int(row['prediction'])  # stored as float, e.g. 2.0
    print(f"\nText: {row['text']}")
    print("Predicted Category:", labels[category_index])



Text: The government passed a new healthcare reform bill today.
Predicted Category: POLITICS

Text: The football team won their third consecutive championship.
Predicted Category: SPORTS

Text: Scientists discovered a new exoplanet similar to Earth.
Predicted Category: SCIENCE

Text: A major tech company announced a breakthrough AI chip.
Predicted Category: TECH

Text: The actor won an award for best performance in a drama movie.
Predicted Category: ENTERTAINMENT
