In [6]:
from random import shuffle

from tokenize import group
from typing import Iterator
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from gensim.models.doc2vec import Doc2Vec

from src.classifiers.base import DatasetEntry, DatasetParser

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_selection import SelectFpr, chi2
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

import pickle

# Classifiers
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from gensim.utils import simple_preprocess

# Metrics
from sklearn.metrics import classification_report

In [7]:
# Load doc2vec trained model on newsgroups

doc2vec = Doc2Vec.load('./20newsgroups-18828(Doc2Vec Model).model')

str_to_doc2vec_embedding = lambda sentence: doc2vec.infer_vector(simple_preprocess(sentence))
id_to_doc2vec_embedding = lambda id: doc2vec.dv

# Example
embedding = str_to_doc2vec_embedding('Hello world')

print(embedding)
print(embedding.shape)

[ 0.06866241 -0.26297903  0.07269929  0.24940334 -0.12856945 -0.905505
  0.28001404  0.46869963 -0.49396008  0.49471778  0.16190854 -0.300937
 -0.19951914  0.45683843  0.08871276 -0.3749264   0.8437531  -0.43066245
 -0.71053845 -0.07456572  0.29725227  0.42502564  0.17985488  0.57995033
 -0.4592125  -0.5038251   0.37507877  0.29996717 -0.12883967 -0.493351
 -0.09473423 -0.03056421 -0.41688246 -0.18476976 -0.2343755  -0.11148989
  0.2709551  -0.37814376  0.6519001  -0.40003008 -1.0897918   0.44535235
 -0.10015228 -0.32676652  0.5154579   0.32467648  0.05818143  0.3651494
  0.17330344 -0.01803419 -0.0208741   0.3086182  -0.06492205 -0.25627473
  0.5756725  -0.5938626   0.26274246 -0.6802011  -0.15170975  0.89259833
 -0.16566792 -0.27607304 -0.13363057  0.19446003 -0.00382854 -0.09263676
  0.4336044  -0.31560203 -0.26633114 -0.18227287 -0.00822984 -0.47730404
  0.6964229  -0.02441906 -0.47603273  0.7067244   0.60662615  0.36952233
  0.16568188 -0.20176555  0.06361324 -0.1877929   0.180606

In [24]:
class NewsgroupsEntry(DatasetEntry):

    def __init__(self, group_id: int, entry_path: Path):
        try:
            text = entry_path.read_text(errors="ignore")
        except UnicodeDecodeError as e:
            print(entry_path)
            raise e
        end_of_line1 = text.find("\n")
        end_of_line2 = text.find("\n", end_of_line1 + 1)
        line1 = text[:end_of_line1]
        line2 = text[end_of_line1 + 1: end_of_line2]

        super(NewsgroupsEntry, self).__init__(f"{group_id}_{entry_path.name}")

        self.path = entry_path
        self.group = entry_path.parent.name

        if line1.startswith("From: "):
            self.from_ = line1[6:]
            assert line2.startswith("Subject: ")
            self.subject = line2[9:]
        elif line1.startswith("Subject: "):
            self.subject = line1[9:]
            assert line2.startswith("From: ")
            self.from_ = line2[6:]
        else:
            assert False, f"From/Subject not found in {entry_path}"

        self.text = text[end_of_line2:].strip()

    @property
    def raw_text(self):
        return self.path.read_text(errors="ignore")


class NewsgroupsParser(DatasetParser):
    """Parser for the 20 Newsgroups dataset"""

    def __init__(self):
        super(NewsgroupsParser, self).__init__(
            data=self.root / "20newsgroups-18828",
            count_vzer=CountVectorizer(
                input="filename",
                decode_error="ignore",
                stop_words="english"
            ),
            total=18828
        )

        self.entries: list[NewsgroupsEntry] = []

        for group_id, folder in enumerate(self.data.iterdir()):
            for file in folder.iterdir():
                self.entries.append(NewsgroupsEntry(group_id, file))

        assert len(self.entries) == self.total

    def __iter__(self) -> Iterator[NewsgroupsEntry]:
        return iter(self.entries)\


    def fit_transform(self):
        return self.count_vzer.fit_transform(
            tuple(str(entry.path) for entry in self)
        )

In [25]:
parser = NewsgroupsParser()

In [46]:
entries = parser.entries
shuffle(entries)

sentences = np.array([(entry.id, entry.group) for entry in entries])
groups = set([sentence[1] for sentence in sentences])
groups_dict = {label: i for i, label in enumerate(groups)}

In [47]:
# Vectorize input
X = np.array([doc2vec.dv[sentence[0]] for sentence in sentences])
y = np.array([groups_dict[sentence[1]] for sentence in sentences])

len(sentences)

18828

In [48]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
X_train.shape, y_train.shape

((15062, 100), (15062,))

In [55]:
clf = MultinomialNB(random_state=47)
clf.fit(X_train, y_train)

TypeError: MultinomialNB.__init__() got an unexpected keyword argument 'random_state'

In [53]:
y_pred = clf.predict(X_test)

In [54]:
report = classification_report(y_test, y_pred, zero_division=0)
matrix = confusion_matrix(y_test, y_pred)
print(report)

              precision    recall  f1-score   support

           0       0.24      0.26      0.25       192
           1       0.17      0.16      0.16       202
           2       0.19      0.21      0.20       184
           3       0.12      0.12      0.12       164
           4       0.30      0.29      0.30       194
           5       0.24      0.26      0.25       193
           6       0.13      0.14      0.13       180
           7       0.27      0.27      0.27       193
           8       0.18      0.17      0.17       202
           9       0.11      0.10      0.10       197
          10       0.19      0.21      0.20       185
          11       0.32      0.30      0.31       213
          12       0.28      0.23      0.25       209
          13       0.13      0.15      0.14       187
          14       0.14      0.15      0.14       124
          15       0.11      0.10      0.11       162
          16       0.23      0.26      0.24       194
          17       0.28    

In [3]:
# Save model

pickle.dump(clf, open('newsgroups_clf.model', 'wb'))

NameError: name 'clf' is not defined

In [4]:
# For loading model
loaded_clf = pickle.load(open('newsgroups_clf.model', 'rb'))

In [5]:
loaded_clf.predict(X_test)

NameError: name 'X_test' is not defined