<a href="https://colab.research.google.com/github/LUMII-AILab/NLP_Course/blob/main/notebooks/NaiveBayes.ipynb" target="_new"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"/></a>

# Naïve Bayes text classifier

Hands-on dataset: the *20 Newsgroups* collection, assembled by Ken Lang at CMU.

https://www.kaggle.com/datasets/au1206/20-newsgroup-original

We use a format-converted, single-file version (one `topic<TAB>text` sample per line) available from the course GitHub repo.

## Setting up the environment

In [None]:
# -nc (no-clobber): skip files that already exist, so re-running this cell
# does not pile up duplicate copies such as 20_newsgroup.tsv.1
!wget -nc https://raw.githubusercontent.com/LUMII-AILab/NLP_Course/main/notebooks/resources/news20/20_newsgroup.tsv

!wget -nc https://raw.githubusercontent.com/LUMII-AILab/NLP_Course/main/notebooks/resources/news20/20_newsgroup-freq.tsv

!wget -nc https://raw.githubusercontent.com/LUMII-AILab/NLP_Course/main/notebooks/resources/news20/stoplist.txt

In [None]:
# %pip (unlike !pip) installs into the environment of the running kernel
%pip install -q nltk scikit-learn seaborn

In [None]:
import nltk
nltk.download('punkt')
# Recent NLTK releases make word_tokenize look up the Punkt model in the
# 'punkt_tab' resource instead of the pickled 'punkt' one; fetch both so
# vectorize_text works on old and new NLTK versions alike.
nltk.download('punkt_tab')

from sklearn.model_selection import KFold
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

import numpy

import seaborn
import matplotlib.pyplot as mplot

In [4]:
import re
import os
import sys
import pickle
import datetime

## Text preprocessing

In [5]:
def initialise(stop_txt, freq_tsv, min_freq=3):
	"""Load the global word stoplist and the frequency-based word whitelist.

	Every entry is passed through `normalize_text`, the same normalisation
	that `vectorize_text` applies before tokenising, so the lists can be
	compared directly with the extracted features.

	Args:
		stop_txt: path to a text file with one stopword per line.
		freq_tsv: path to a TSV file with lines of the form "<freq>\t<word>".
		min_freq: words with a frequency below this value are left out of the
			whitelist (default 3, the previously hard-coded threshold).

	Side effects:
		Sets the module-level globals STOPLIST and WHITELIST (sets of str)
		and prints their sizes.
	"""
	global STOPLIST, WHITELIST

	STOPLIST = set()
	with open(stop_txt, encoding="utf-8") as txt:
		for line in txt:
			word = line.strip()
			if word:  # a blank line must not add "" to the stoplist
				STOPLIST.add(normalize_text(word))

	print("[I] Word stoplist is read:", len(STOPLIST))

	WHITELIST = set()
	with open(freq_tsv, encoding="utf-8") as tsv:
		for line in tsv:
			entry = line.strip()
			if not entry:
				continue  # tolerate blank lines instead of failing to unpack

			freq, word = entry.split("\t")

			# Ignore the long tail: most words occur fewer than `min_freq` times.
			# TODO: experiment with the threshold (e.g., 3 / 5 / 10)
			if int(freq) < min_freq:
				continue

			WHITELIST.add(normalize_text(word))

	print("[I] Word whitelist is read:", len(WHITELIST))

In [6]:
def normalize_text(text):
	"""Lower-case `text` and remove or replace tokens with no topical signal.

	- e-mail addresses are removed,
	- http(s):// and www. URLs are removed,
	- every run of digits is replaced by the placeholder "100", so all
	  numbers collapse into one feature.

	Returns the resulting string with leading/trailing whitespace stripped.
	"""
	text = text.lower()
	# e-mail addresses. The TLD class must not contain "|": inside [...] it
	# is a literal pipe, which made strings like "a@b.c|om" count as e-mails.
	text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b', '', text)
	# URLs
	text = re.sub(r'https?://[A-Za-z0-9./-]+|www\.[A-Za-z0-9./-]+', '', text)
	# numbers
	text = re.sub(r'\d+', "100", text)

	return text.strip()


def normalize_vector(vector):
	"""Drop uninformative features from `vector` in place and return it.

	A feature (dict key) is removed when it is a stopword, is a single
	character, or is missing from the frequency whitelist. Relies on the
	globals STOPLIST and WHITELIST set up by `initialise`.
	"""
	# Collect first, delete afterwards: a dict must not change size while
	# it is being iterated.
	rejected = [
		token for token in vector
		if token in STOPLIST or len(token) == 1 or token not in WHITELIST
	]
	for token in rejected:
		del vector[token]

	return vector


def vectorize_text(text):
	"""Turn raw text into an NLTK feature dict of the form {token: True}.

	The text is normalised, tokenised with nltk.word_tokenize, and the
	resulting binary bag-of-words is filtered by `normalize_vector`.
	"""
	tokens = nltk.word_tokenize(normalize_text(text))
	features = dict.fromkeys(tokens, True)
	return normalize_vector(features)

In [7]:
def read_data(file, vectorize=None):
	"""Read a "<topic>\t<text>" file and group the vectorised samples by topic.

	Args:
		file: path to the data file, one sample per line.
		vectorize: optional callable mapping a text to a feature dict;
			defaults to `vectorize_text`. Passing another function makes it
			easy to experiment with alternative feature extraction.

	Returns:
		dict mapping each topic to a list of (features, topic) tuples; topics
		appear in order of their first occurrence in the file.
	"""
	if vectorize is None:
		vectorize = vectorize_text

	data_set = {}  # topic => samples

	with open(file, encoding="utf-8") as data:
		for entry in data:
			entry = entry.strip()
			if not entry:
				continue  # tolerate blank lines

			# maxsplit=1: a stray tab inside the text must not break unpacking
			topic, text = entry.split("\t", 1)

			data_set.setdefault(topic, []).append((vectorize(text), topic))

	return data_set


def join_data(data_set):
	"""Flatten a topic => samples dict into a single list of samples.

	Samples are concatenated topic by topic, in the dict's iteration order.
	A new list is returned and `data_set` is left untouched.
	"""
	return [sample for samples in data_set.values() for sample in samples]

## Experimentation & evaluation

In [8]:
def cross_validate(data_set, k, random_state=None):
	"""Run per-topic (stratified) k-fold cross-validation of an NLTK NB classifier.

	Every topic is split into k folds on its own, and fold i pools the i-th
	split of every topic. Training and test sets therefore keep roughly the
	topic proportions of the full data set.

	Args:
		data_set: dict topic => list of (features, topic) samples, as returned
			by `read_data`. Every topic needs at least k samples (KFold raises
			ValueError otherwise).
		k: number of folds.
		random_state: optional seed for the fold shuffling. Pass an int to make
			the splits, and hence the scores, reproducible.

	Returns:
		(validations, gold_result, silver_result): the accuracy of each fold,
		followed by the gold and predicted labels of all test samples across
		all folds.

	Side effects:
		Sets the module-level global LABELS to the topics in `data_set` order.
	"""
	global LABELS
	LABELS = list(data_set)

	kfold = KFold(n_splits=k, shuffle=True, random_state=random_state)

	# folds[i] = (train samples, test samples) of fold i, pooled over all
	# topics. Plain lists replace numpy object arrays: NLTK only needs an
	# iterable of (features, label) pairs, and list.extend avoids the repeated
	# re-allocation done by numpy.append.
	folds = [([], []) for _ in range(k)]

	for samples in data_set.values():
		for i, (train_idx, test_idx) in enumerate(kfold.split(samples)):
			train_data, test_data = folds[i]
			train_data.extend(samples[j] for j in train_idx)
			test_data.extend(samples[j] for j in test_idx)

	validations = []
	gold_result = []
	silver_result = []

	for train_data, test_data in folds:
		# Naive Bayes classifier: training and evaluation
		nb = nltk.NaiveBayesClassifier.train(train_data)

		gold = [label for _, label in test_data]
		# Classify each test sample once and reuse the predictions both for the
		# accuracy (same formula as nltk.classify.accuracy) and for the report.
		silver = nb.classify_many([features for features, _ in test_data])

		correct = sum(g == s for g, s in zip(gold, silver))
		validations.append(correct / len(gold) if gold else 0)

		gold_result.extend(gold)
		silver_result.extend(silver)

	return (validations, gold_result, silver_result)

In [9]:
def run_validation(data_path, k):
	"""Cross-validate the NB classifier on `data_path` and report the results.

	Prints the accuracy of each fold and the mean accuracy, the elapsed time,
	and a scikit-learn classification report, then plots the confusion matrix
	as a heatmap. `initialise()` must run first: feature filtering reads the
	STOPLIST / WHITELIST globals it sets.
	"""
	print("{0}-fold cross-validation:\n".format(k))

	start_time = datetime.datetime.now().replace(microsecond=0)

	# Run k-fold cross-validation (this also sets the global LABELS)
	validations, gold, silver = cross_validate(read_data(data_path), k)

	# Print the accuracy of each cross-validation step, then the overall mean
	for step in validations:
		print("{0:.2f}  ".format(step), end='')
	print("{0:.0%}".format(numpy.mean(validations)))

	end_time = datetime.datetime.now().replace(microsecond=0)
	print("\nTotal validation time:", end_time - start_time, "\n")

	# Print an evaluation report
	print(classification_report(gold, silver))

	# Print a fancy confusion matrix.
	# labels=LABELS is required: by default confusion_matrix orders rows and
	# columns by the *sorted* label set. That order need not match LABELS,
	# which is used for the tick labels below.
	matrix = confusion_matrix(gold, silver, labels=LABELS)
	seaborn.heatmap(matrix, xticklabels=LABELS, yticklabels=LABELS)
	mplot.xlabel("Predicted topic")
	mplot.ylabel("Gold topic")
	mplot.title("{0}-fold cross-validation confusion matrix".format(k))
	mplot.xticks(rotation=90)
	mplot.show()
	# cf. print(nltk.ConfusionMatrix(gold, silver))

## Training for production

In [10]:
def run_training(data_path, verbose, model_path="nb_classifier.pickle"):
	"""Train the production NB classifier on all of `data_path` and pickle it.

	Args:
		data_path: topic/text file read by `read_data`.
		verbose: if true, print the 10 most informative features.
		model_path: where to write the pickled classifier. The default is the
			file name `run_inference` reads.
	"""
	print("[I] Training an NB classifier...")
	start_time = datetime.datetime.now().replace(microsecond=0)

	# TRAINING
	# The final (production) model is trained on all available data (train+test)
	nb = nltk.NaiveBayesClassifier.train(join_data(read_data(data_path)))

	end_time = datetime.datetime.now().replace(microsecond=0)
	print("[I] Training time:", end_time - start_time)

	if verbose:
		nb.show_most_informative_features(n=10)  # Try with n=100

	# Save the model for later use
	with open(model_path, "wb") as dmp:
		pickle.dump(nb, dmp)
	# Report success only after the file has been closed (and flushed)
	print("[I] NB classifier stored in a file")

## The inference part

In [11]:
def run_inference(model_path="nb_classifier.pickle"):
	"""Interactively classify user-typed texts with a pickled NB classifier.

	Loads the classifier, then repeatedly prompts for a text. For each text it
	prints the extracted features, the probability of every class (most
	probable first) and the predicted class. An empty line, or end of input,
	ends the loop.

	Args:
		model_path: pickled classifier to load (default: the file written by
			`run_training`). Only load files you created yourself, because
			unpickling can execute arbitrary code.
	"""
	# Load the pre-trained model
	with open(model_path, "rb") as dmp:
		nb = pickle.load(dmp)
	print("[I] NB classifier loaded from a file")

	while True:
		try:
			text = input("\nEnter a text to classify: ")
		except EOFError:  # stdin closed (e.g. input piped in a script run)
			break
		if not text.strip():
			break

		# Extract text features for classification
		text_feat = vectorize_text(text)
		print("\nFeatures:", text_feat.keys(), "\n")
		if not text_feat:
			print("[W] No known features: the prediction reflects class priors only")

		# INFERENCE
		# Calculate a probability distribution over the classes
		prob_dist = nb.prob_classify(text_feat)

		# Print the probability distribution, most probable class first
		for label in sorted(prob_dist.samples(), key=prob_dist.prob, reverse=True):
			print("{0}: {1:.3f}".format(label, prob_dist.prob(label)))

		# Print the most probable class
		print("\nPrediction:", prob_dist.max())

## Execution

In [None]:
# Load the stopword list and the frequency-filtered word whitelist
# (required before validation, training or inference: feature filtering uses them)
initialise(stop_txt='stoplist.txt', freq_tsv='20_newsgroup-freq.tsv')

In [None]:
# Estimate classifier quality with 5-fold cross-validation
run_validation(data_path="20_newsgroup.tsv", k=5)

# TODO:
# * experiment with preprocessing, feature extraction and 'hyperparameters'
# * evaluate and compare results

In [None]:
# Train the final model on all data, show its most informative features and save it
run_training("20_newsgroup.tsv", verbose=True)

In [None]:
# Load the saved model and classify texts interactively (enter an empty line to stop)
run_inference()

Some test sentences (generated with ChatGPT), each listed with its expected newsgroup:
* `alt.atheism`: *Religion lacks empirical evidence to justify supernatural claims.*
* `soc.religion.christian`: *Faith in Jesus brings profound peace and eternal salvation.*
* `sci.med`: *Regular exercise contributes to overall well-being and disease prevention.*
* `sci.space`: *Satellite technology advances have revolutionized global communication networks.*
* `sci.space`: *The Kepler telescope's discovery of exoplanets revolutionizes our search for extraterrestrial life.*