Extracting the [News Category Dataset](https://www.kaggle.com/datasets/rmisra/news-category-dataset) archive

In [1]:
!unzip "/content/News_Category_Dataset_v3.json.zip"

Archive:  /content/News_Category_Dataset_v3.json.zip
  inflating: News_Category_Dataset_v3.json  
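`!unzip` is a shell command that works in Colab. Outside a notebook, the same extraction can be done with Python's standard-library `zipfile` module. The following is a minimal sketch in which `extract_archive` is a hypothetical helper. So that it runs anywhere, the demo builds a small throwaway archive instead of using the real dataset file.

```python
import json
import os
import tempfile
import zipfile

def extract_archive(zip_path, target_dir):
    """Extract every member of zip_path into target_dir and return the member names."""
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(target_dir)
        return archive.namelist()

# Demo on a throwaway archive so the sketch runs without the real dataset
workdir = tempfile.mkdtemp()
demo_zip = os.path.join(workdir, "demo.json.zip")
with zipfile.ZipFile(demo_zip, "w") as archive:
    # One JSON object per line, mimicking the JSON Lines layout of the dataset
    archive.writestr("demo.json", json.dumps({"headline": "h", "category": "c"}) + "\n")

extracted = extract_archive(demo_zip, workdir)
print(extracted)  # ['demo.json']
```

For the real file, the call would be `extract_archive("/content/News_Category_Dataset_v3.json.zip", "/content")`.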


Installing the necessary packages

In [2]:
!pip install jsonlines

Collecting jsonlines
  Downloading jsonlines-4.0.0-py3-none-any.whl (8.7 kB)
Installing collected packages: jsonlines
Successfully installed jsonlines-4.0.0


In [3]:
!pip install sentence-transformers

Collecting sentence-transformers
  Downloading sentence_transformers-2.6.1-py3-none-any.whl (163 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/163.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.3/163.3 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m36.7 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m47.1 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from tor

Importing required packages

In [47]:
import pandas as pd
import jsonlines
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
import re
import pickle

Loading the dataset

In [5]:
with jsonlines.open('/content/News_Category_Dataset_v3.json') as reader:
  df = pd.DataFrame(reader)
df.head()

,link,headline,category,short_description,authors,date
0,https://www.huffpost.com/entry/covid-boosters-...,Over 4 Million Americans Roll Up Sleeves For O...,U.S. NEWS,Health experts said it is too early to predict...,"Carla K. Johnson, AP",2022-09-23
1,https://www.huffpost.com/entry/american-airlin...,"American Airlines Flyer Charged, Banned For Li...",U.S. NEWS,He was subdued by passengers and crew when he ...,Mary Papenfuss,2022-09-23
2,https://www.huffpost.com/entry/funniest-tweets...,23 Of The Funniest Tweets About Cats And Dogs ...,COMEDY,"""Until you have a dog you don't understand wha...",Elyse Wanshel,2022-09-23
3,https://www.huffpost.com/entry/funniest-parent...,The Funniest Tweets From Parents This Week (Se...,PARENTING,"""Accidentally put grown-up toothpaste on my to...",Caroline Bologna,2022-09-23
4,https://www.huffpost.com/entry/amy-cooper-lose...,Woman Who Called Cops On Black Bird-Watcher Lo...,U.S. NEWS,Amy Cooper accused investment firm Franklin Te...,Nina Golgowski,2022-09-22
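Each line of `News_Category_Dataset_v3.json` is a separate JSON object (the JSON Lines format), which is why the `jsonlines` reader is used instead of a single `json.load`. The sketch below shows what such a reader does, using only the standard-library `json` module. Because this notebook later uses only `headline` and `category`, it keeps just those fields to save memory. The function name `load_fields` is hypothetical, and the two sample records are made up. With pandas, `pd.read_json(path, lines=True)` is the built-in equivalent.

```python
import io
import json

def load_fields(lines, fields=("headline", "category")):
    """Parse JSON Lines text, keeping only the requested fields of each record."""
    records = []
    for line in lines:
        line = line.strip()
        if not line:  # tolerate blank lines, e.g. a trailing newline
            continue
        obj = json.loads(line)
        records.append({field: obj.get(field) for field in fields})
    return records

# Two hypothetical records in the dataset's layout
sample = io.StringIO(
    '{"link": "https://example.com/a", "headline": "First headline", "category": "POLITICS"}\n'
    '{"link": "https://example.com/b", "headline": "Second headline", "category": "TRAVEL"}\n'
)
records = load_fields(sample)
print(records[1])  # {'headline': 'Second headline', 'category': 'TRAVEL'}
```

Applied to the real file, this would be `with open('/content/News_Category_Dataset_v3.json') as f: df = pd.DataFrame(load_fields(f))`.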


In [6]:
df['category'].value_counts()

category
POLITICS          35602
WELLNESS          17945
ENTERTAINMENT     17362
TRAVEL             9900
STYLE & BEAUTY     9814
PARENTING          8791
HEALTHY LIVING     6694
QUEER VOICES       6347
FOOD & DRINK       6340
BUSINESS           5992
COMEDY             5400
SPORTS             5077
BLACK VOICES       4583
HOME & LIVING      4320
PARENTS            3955
THE WORLDPOST      3664
WEDDINGS           3653
WOMEN              3572
CRIME              3562
IMPACT             3484
DIVORCE            3426
WORLD NEWS         3299
MEDIA              2944
WEIRD NEWS         2777
GREEN              2622
WORLDPOST          2579
RELIGION           2577
STYLE              2254
SCIENCE            2206
TECH               2104
TASTE              2096
MONEY              1756
ARTS               1509
ENVIRONMENT        1444
FIFTY              1401
GOOD NEWS          1398
U.S. NEWS          1377
ARTS & CULTURE     1339
COLLEGE            1144
LATINO VOICES      1130
CULTURE & ARTS     1074
EDUCATI

In [7]:
df.shape

(209527, 6)

Keeping only the headline and category columns, restricted to the five most frequent categories, to reduce the number of rows and keep training quick

In [8]:
top_categories = ['POLITICS', 'WELLNESS', 'ENTERTAINMENT', 'TRAVEL', 'STYLE & BEAUTY']
# Filter rows and columns in one .loc call; .copy() makes df_data an independent frame,
# so the in-place dropna no longer triggers a SettingWithCopyWarning
df_data = df.loc[df['category'].isin(top_categories), ['headline', 'category']].copy()
df_data.dropna(inplace=True)
df_data.shape


(90623, 2)
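The five categories above are exactly the five largest classes in the value counts, so they can be selected programmatically instead of hardcoded. This is a sketch with `collections.Counter` that uses the counts printed earlier (only the first few rows are copied here). With pandas, `df['category'].value_counts().head(5).index` gives the same list.

```python
from collections import Counter

# Counts copied from the df['category'].value_counts() output above (top rows only)
category_counts = Counter({
    "POLITICS": 35602,
    "WELLNESS": 17945,
    "ENTERTAINMENT": 17362,
    "TRAVEL": 9900,
    "STYLE & BEAUTY": 9814,
    "PARENTING": 8791,
    "HEALTHY LIVING": 6694,
})

top_k = 5
top_categories = [name for name, _ in category_counts.most_common(top_k)]
kept_rows = sum(category_counts[name] for name in top_categories)

print(top_categories)  # ['POLITICS', 'WELLNESS', 'ENTERTAINMENT', 'TRAVEL', 'STYLE & BEAUTY']
print(kept_rows)       # 90623, matching df_data.shape[0]
```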

In [40]:
headlines = df_data['headline'].tolist()
categories = df_data['category'].tolist()

Basic text cleaning: remove every character that is not a letter, digit, space, or period. No stop-word removal or lemmatization is applied here.

In [41]:
def text_clean(text):
  text = re.sub(r'[^A-Za-z0-9 .]+', '', text)
  return text

headlines = [text_clean(text) for text in headlines]
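Because the pattern replaces disallowed characters with an empty string, it also has some side effects. It deletes apostrophes (`Don't` becomes `Dont`), joins hyphenated words (`Post-Election` becomes `PostElection`), and leaves a double space wherever a standalone `&` is removed. If that matters, one alternative is the hypothetical `text_clean_spaced` below. It keeps apostrophes, substitutes a space for other removed characters, and collapses repeated whitespace. Whether this changes the classifier's scores has not been measured here.

```python
import re

def text_clean(text):
    # Notebook version: drop everything except letters, digits, spaces and periods
    return re.sub(r'[^A-Za-z0-9 .]+', '', text)

def text_clean_spaced(text):
    # Hypothetical variant: keep apostrophes, replace other removed runs with a space,
    # then collapse whitespace so no double spaces remain
    text = re.sub(r"[^A-Za-z0-9 .']+", ' ', text)
    return re.sub(r'\s+', ' ', text).strip()

example = "Don't Miss: 5 Post-Election Tips & Tricks!"
print(repr(text_clean(example)))         # 'Dont Miss 5 PostElection Tips  Tricks'
print(repr(text_clean_spaced(example)))  # "Don't Miss 5 Post Election Tips Tricks"
```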

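Encoding each headline as a sentence embedding (one vector per headline, not per word) with the Sentence-BERT model `bert-base-nli-mean-tokens`. Note that this model's card marks it as deprecated because of its low-quality embeddings, so a newer Sentence-Transformers model may give better results.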

In [42]:
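import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is present
model = SentenceTransformer('bert-base-nli-mean-tokens', device=device)
headlineEmbeddings = model.encode(headlines)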

In [43]:
len(headlineEmbeddings), len(headlineEmbeddings[0])

(90623, 768)
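The `mean-tokens` in the model name refers to its pooling step. BERT produces one 768-dimensional vector per token, and the sentence embedding is the average of those token vectors, with padding positions excluded through the attention mask. The sketch below shows that averaging in plain Python on tiny 2-dimensional vectors. The library performs the same step internally on tensors, and `mean_pool` is a hypothetical name.

```python
def mean_pool(token_vectors, attention_mask):
    """Average the token vectors whose mask entry is 1 (real tokens), skipping padding."""
    dim = len(token_vectors[0])
    totals = [0.0] * dim
    count = 0
    for vector, keep in zip(token_vectors, attention_mask):
        if keep:
            for i, value in enumerate(vector):
                totals[i] += value
            count += 1
    if count == 0:
        raise ValueError("attention mask selects no tokens")
    return [total / count for total in totals]

# Three token positions; the last one is padding and must not affect the result
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
print(mean_pool(tokens, mask))  # [2.0, 3.0]
```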

Train and test split

In [44]:
X_train, X_test, y_train, y_test = train_test_split(headlineEmbeddings, categories, test_size=0.2, random_state=13)
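The categories are imbalanced: POLITICS alone accounts for about 39% of the kept rows (35602 of 90623), and a purely random split preserves class proportions only approximately. Passing `stratify=categories` to `train_test_split` makes scikit-learn keep each class's share the same in both splits. The sketch below illustrates the idea behind stratification in plain Python, using a hypothetical `stratified_split` helper. scikit-learn's own allocation of rounding remainders differs in detail.

```python
import random
from collections import defaultdict

def stratified_split(labels, test_size=0.2, seed=13):
    """Return (train_idx, test_idx) with each label's share preserved in the test set."""
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    rng = random.Random(seed)
    train_idx, test_idx = [], []
    for indices in by_label.values():
        rng.shuffle(indices)
        n_test = round(len(indices) * test_size)  # per-class test count
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return sorted(train_idx), sorted(test_idx)

labels = ["POLITICS"] * 80 + ["TRAVEL"] * 20
train_idx, test_idx = stratified_split(labels, test_size=0.25)
test_labels = [labels[i] for i in test_idx]
print(test_labels.count("POLITICS"), test_labels.count("TRAVEL"))  # 20 5
```

In the notebook, the equivalent change would be `train_test_split(headlineEmbeddings, categories, test_size=0.2, random_state=13, stratify=categories)`.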

Training a basic Random Forest Classifier

In [45]:
rf_classifier = RandomForestClassifier(random_state=13)
rf_classifier.fit(X_train, y_train)

y_pred = rf_classifier.predict(X_test)
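A random forest trains many decision trees, each on a bootstrap sample of the training data. At prediction time, scikit-learn's `RandomForestClassifier` does not count one hard vote per tree. Instead, it averages the class probabilities predicted by the individual trees and picks the class with the highest mean. The toy three-tree case below, which uses hypothetical probabilities in plain Python, shows how the two rules can disagree.

```python
from collections import Counter

# Hypothetical per-tree class probabilities for one headline
tree_probas = [
    {"POLITICS": 0.9, "TRAVEL": 0.1},
    {"POLITICS": 0.4, "TRAVEL": 0.6},
    {"POLITICS": 0.4, "TRAVEL": 0.6},
]

def soft_vote(probas):
    """Average probabilities across trees, then take the most probable class."""
    classes = probas[0].keys()
    mean = {c: sum(p[c] for p in probas) / len(probas) for c in classes}
    return max(mean, key=mean.get)

def hard_vote(probas):
    """Each tree votes for its own most probable class; the majority wins."""
    votes = Counter(max(p, key=p.get) for p in probas)
    return votes.most_common(1)[0][0]

print(soft_vote(tree_probas))  # POLITICS (mean 0.567 vs 0.433)
print(hard_vote(tree_probas))  # TRAVEL (2 votes to 1)
```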

In [46]:
print(classification_report(y_test, y_pred))

                precision    recall  f1-score   support

 ENTERTAINMENT       0.71      0.66      0.69      3480
      POLITICS       0.81      0.91      0.86      7116
STYLE & BEAUTY       0.86      0.61      0.72      1962
        TRAVEL       0.81      0.62      0.70      1971
      WELLNESS       0.75      0.82      0.78      3596

      accuracy                           0.78     18125
     macro avg       0.79      0.73      0.75     18125
  weighted avg       0.78      0.78      0.78     18125
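The summary rows can be reproduced from the per-class rows:

- F1 is the harmonic mean of precision and recall.
- The macro average is the plain mean over classes.
- The weighted average weights each class by its support.

The sketch below recomputes these from the rounded values printed above, so the last digit can differ from the report (for example, macro recall comes out as 0.724 here versus 0.73 in the report). It also computes the accuracy of always predicting the largest class, POLITICS, which is a useful baseline for the 0.78 accuracy.

```python
# Per-class (precision, recall, f1, support) copied from the classification report above
report = {
    "ENTERTAINMENT": (0.71, 0.66, 0.69, 3480),
    "POLITICS": (0.81, 0.91, 0.86, 7116),
    "STYLE & BEAUTY": (0.86, 0.61, 0.72, 1962),
    "TRAVEL": (0.81, 0.62, 0.70, 1971),
    "WELLNESS": (0.75, 0.82, 0.78, 3596),
}

def f1(precision, recall):
    # Harmonic mean of precision and recall
    return 2 * precision * recall / (precision + recall)

total_support = sum(row[3] for row in report.values())
macro_f1 = sum(row[2] for row in report.values()) / len(report)
weighted_precision = sum(row[0] * row[3] for row in report.values()) / total_support
majority_baseline = max(row[3] for row in report.values()) / total_support

print(round(f1(0.81, 0.91), 2))      # 0.86, the POLITICS f1-score
print(total_support)                 # 18125
print(round(macro_f1, 2))            # 0.75
print(round(weighted_precision, 2))  # 0.78
print(round(majority_baseline, 2))   # 0.39
```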



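Saving the trained classifier to a pickle file so it can be loaded later for predictions. New headlines must first be cleaned with `text_clean` and encoded with the same Sentence-BERT model, because the classifier expects 768-dimensional embeddings rather than raw text.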

In [48]:
with open('random_forest_model.pkl', 'wb') as file:
    pickle.dump(rf_classifier, file)
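To classify new headlines with the saved model, they need to go through the same pipeline as the training data: first `text_clean`, then the same Sentence-BERT encoder, then the classifier. Below is a sketch with a hypothetical `predict_categories` helper. It accepts any encoder with an `encode` method and any classifier with a `predict` method, so the demo can use small stand-in objects instead of downloading the BERT model. When loading the real file, only unpickle files from trusted sources, since loading a pickle can execute arbitrary code. Also use the same scikit-learn version that saved the model.

```python
import re

def text_clean(text):
    # Same cleaning that was applied to the training headlines
    return re.sub(r'[^A-Za-z0-9 .]+', '', text)

def predict_categories(new_headlines, classifier, encoder):
    """Clean, embed and classify new headlines the same way as the training data."""
    cleaned = [text_clean(h) for h in new_headlines]
    embeddings = encoder.encode(cleaned)
    return list(classifier.predict(embeddings))

# Stand-ins for SentenceTransformer and the fitted RandomForestClassifier (demo only)
class StubEncoder:
    def encode(self, texts):
        return [[len(t)] for t in texts]  # one-number "embedding": the text length

class StubClassifier:
    def predict(self, embeddings):
        return ["LONG" if e[0] > 20 else "SHORT" for e in embeddings]

print(predict_categories(["Short one!", "A considerably longer headline"],
                         StubClassifier(), StubEncoder()))  # ['SHORT', 'LONG']

# In the notebook itself (requires the saved file and the loaded `model`):
# import pickle
# with open('random_forest_model.pkl', 'rb') as file:
#     loaded_rf = pickle.load(file)
# predict_categories(["Some new headline"], loaded_rf, model)
```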