<a href="https://colab.research.google.com/github/LUMII-AILab/NLP_Course/blob/main/notebooks/fastText.ipynb" target="_new"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"/></a>

# fastText-based classifier

## Setting up the environment

In [None]:
!pip install fasttext
!pip install scikit-learn
!pip install nltk

In [None]:
!wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz

!gunzip cc.en.300.bin.gz

In [3]:
import fasttext
import numpy
import re

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

In [None]:
import nltk

# nltk.word_tokenize() needs the Punkt tokenizer data. Recent NLTK releases
# (3.9+) load it from the "punkt_tab" resource instead of the older "punkt"
# one, so fetch both to work with whichever NLTK version pip installed.
nltk.download('punkt')
nltk.download('punkt_tab')

In [None]:
# Read the pre-trained fastText binary unpacked above; its word vectors
# are what every token of the corpus gets embedded with.
ft_model = fasttext.load_model(path="cc.en.300.bin")

In [None]:
!wget https://raw.githubusercontent.com/LUMII-AILab/NLP_Course/main/notebooks/resources/news20/20_newsgroup.tsv

!wget https://raw.githubusercontent.com/LUMII-AILab/NLP_Course/main/notebooks/resources/news20/20_newsgroup-freq.tsv

!wget https://raw.githubusercontent.com/LUMII-AILab/NLP_Course/main/notebooks/resources/news20/stoplist.txt

## Text preprocessing

### Shared with the Naive Bayes (NB) classifier

In [7]:
def initialise(stop_txt, freq_tsv, min_freq=3):
    """Build the global STOPLIST and WHITELIST sets used by filter_tokens().

    Every word read from either file is passed through normalize_text(), so
    the sets contain the same normalized forms as the tokens they are later
    compared against.

    Parameters
    ----------
    stop_txt : str
        Path to a UTF-8 text file with one stop word per line.
        Blank lines are ignored.
    freq_tsv : str
        Path to a UTF-8 TSV file with ``frequency<TAB>word`` lines.
        Blank lines are ignored.
    min_freq : int, optional
        Words whose frequency is below this threshold are not whitelisted.
        Defaults to 3 (TODO: experiment with the threshold).

    Raises
    ------
    ValueError
        If a non-blank frequency line has no tab or a non-integer frequency.
    """
    global STOPLIST, WHITELIST

    STOPLIST = set()
    # Explicit encoding: same as load_dataset(), independent of the OS locale.
    with open(stop_txt, encoding="utf-8") as txt:
        for line in txt:
            word = normalize_text(line.strip())
            if word:  # skip blank lines instead of storing ""
                STOPLIST.add(word)

    WHITELIST = set()
    with open(freq_tsv, encoding="utf-8") as tsv:
        for entry in tsv:
            entry = entry.strip()
            if not entry:  # tolerate blank (e.g. trailing) lines
                continue

            freq, word = entry.split("\t", 1)
            if int(freq) >= min_freq:
                WHITELIST.add(normalize_text(word))

In [8]:
def normalize_text(text):
    """Lower-case ``text`` and mask noisy substrings before tokenization.

    E-mail addresses and URLs are removed, every run of digits is replaced
    by the placeholder ``"100"``, and surrounding whitespace is stripped.

    Parameters
    ----------
    text : str
        Raw input text.

    Returns
    -------
    str
        The normalized text.
    """
    text = text.lower()
    # e-mail addresses. The TLD class is letters only: inside [...] a "|" is
    # a literal pipe, not alternation.
    text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b', '', text)
    # URLs, including the usual path/query characters (_ ~ : ? # & = % +),
    # so query strings are not left behind as junk tokens.
    text = re.sub(r'https?://[A-Za-z0-9._~:/?#&=%+-]+|www\.[A-Za-z0-9._~:/?#&=%+-]+', '', text)
    # numbers
    text = re.sub(r'\d+', "100", text)

    return text.strip()


def filter_tokens(tokens):
    """Return the informative tokens of ``tokens``, in order, duplicates kept.

    Counterpart of normalize_vector() in the NB implementation. A token is
    kept only if it is not in the global STOPLIST, is not a single
    character, and is in the global WHITELIST (both built by initialise()).
    """
    def keep(tok):
        return tok not in STOPLIST and len(tok) != 1 and tok in WHITELIST

    return [tok for tok in tokens if keep(tok)]

### fastText-specific vectorization

In [16]:
def get_sentence_vector(sentence, pooling="sum"):
    """Embed a text as one fixed-size vector by pooling fastText word vectors.

    The text is normalized, tokenized and filtered exactly as for the NB
    classifier. Each remaining token is looked up with the global
    ``ft_model``, and the token vectors are pooled element-wise.

    Parameters
    ----------
    sentence : str
        Raw input text.
    pooling : {"sum", "mean", "max"}, optional
        How token vectors are combined:
        - "sum" (default, the original behaviour) grows with text length;
        - "mean" averages the token vectors;
        - "max" takes the element-wise maximum.

    Returns
    -------
    numpy.ndarray or list
        The pooled vector, or an empty list when no token survives
        filtering (callers test ``len(...) > 0``).

    Raises
    ------
    ValueError
        If ``pooling`` is not one of the supported strategies.
    """
    poolers = {"sum": numpy.sum, "mean": numpy.mean, "max": numpy.amax}
    if pooling not in poolers:
        raise ValueError(f"Unknown pooling strategy: {pooling!r}")

    # Normalization, tokenization and filtering - as for the NB classifier
    tokens = filter_tokens(nltk.word_tokenize(normalize_text(sentence)))
    if not tokens:
        return []

    # One fastText vector per token, then pooled into a single text vector
    embeddings = [ft_model.get_word_vector(t) for t in tokens]
    return poolers[pooling](embeddings, axis=0)

In [13]:
def load_dataset(filename, vectorize=None):
    """Read a ``label<TAB>text`` file and vectorize every text.

    Parameters
    ----------
    filename : str
        Path to a UTF-8 TSV file with one ``label<TAB>text`` sample per line.
    vectorize : callable, optional
        Maps a text to its feature vector. An empty result means the sample
        is skipped. Defaults to ``get_sentence_vector``.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The stacked feature vectors and their labels, aligned by row.
    """
    if vectorize is None:
        vectorize = get_sentence_vector

    vectors, labels = [], []
    n_malformed = n_empty = 0

    with open(filename, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            if not line:
                continue

            # Split on the first tab only, so texts that themselves contain
            # tabs are kept instead of being silently dropped.
            cols = line.split('\t', 1)
            if len(cols) != 2:
                n_malformed += 1
                continue

            label, sent = cols
            sent_vec = vectorize(sent)

            # Texts whose every token was filtered out yield no vector
            if len(sent_vec) == 0:
                n_empty += 1
                continue

            vectors.append(sent_vec)
            labels.append(label)

    print("[I] Samples loaded and vectorized:", len(vectors))
    if n_malformed or n_empty:
        print(f"[W] Skipped {n_malformed} malformed line(s) and {n_empty} text(s) with no usable tokens")

    return numpy.array(vectors), numpy.array(labels)

## Experimentation & evaluation

In [None]:
# Token filtering during vectorization relies on the stop-word list and
# the frequency whitelist, so build both before loading the corpus.
initialise(stop_txt='stoplist.txt', freq_tsv='20_newsgroup-freq.tsv')

# Turn every labelled text of the corpus into a feature vector
X, y = load_dataset(filename="20_newsgroup.tsv")

In [None]:
# Hold out 20% of the samples for evaluation. Stratifying on the labels keeps
# each class's share of the test set equal to its share of the whole corpus,
# so the per-class report is not distorted by an unlucky split.
# (stratify requires at least two samples per class.)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Train a logistic regression model on the pooled fastText vectors
lr_model = LogisticRegression(max_iter=1000)
lr_model.fit(X_train, y_train)

# Evaluate on the held-out samples (per-class precision / recall / F1)
predictions = lr_model.predict(X_test)
print(classification_report(y_test, predictions))

In [None]:
# Persist the trained classifier so it can be reused without retraining.
# NOTE(review): only the logistic-regression model is pickled; classifying
# new texts also needs ft_model and the STOPLIST/WHITELIST from initialise().
import pickle

with open("ft_classifier.pickle", "wb") as model_file:
    pickle.dump(lr_model, model_file)

print("[I] FT classifier stored in a file")