# Embeddings

In [None]:
import re
import string
from typing import Any, Optional

import numpy as np
import pandas as pd

In [None]:
from sklearn.base import TransformerMixin
from sentence_transformers import SentenceTransformer
from gensim.models import KeyedVectors

from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer

In [None]:
# Each corpus is read from its own CSV; the X_* series holds the raw text
# column and the y_* series holds the label column of the same frame.

# Agatha Christie: labels are taken from the "book" column as-is.
ac_frame = pd.read_csv("../data/agatha_christie.csv")
X_ac, y_ac = ac_frame["text"], ac_frame["book"]

# Newspapers: integer section codes are mapped to readable section names
# (codes missing from the mapping become NaN under Series.map).
np_frame = pd.read_csv("../data/newspaper_articles.csv")
X_np = np_frame["STORY"]
y_np = np_frame["SECTION"].map(
    {0: "Politics", 1: "Technology", 2: "Entertainment", 3: "Business"}
)

# Jane Austen: labels are taken from the "y_book" column as-is.
ja_frame = pd.read_csv("../data/jane_austen.csv")
X_ja, y_ja = ja_frame["x_text"], ja_frame["y_book"]

# Sherlock Holmes: integer labels are mapped to book titles.
sh_frame = pd.read_csv("../data/sherlock_holmes.csv")
X_sh = sh_frame["rawtext"]
y_sh = sh_frame["label"].map(
    {
        0: "The Valley of Fear",
        1: "The Memoirs of Sherlock Holmes",
        2: "The Return of Sherlock Holmes",
        3: "Adventures of Sherlock Holmes",
    }
)

# Parallel lists: names[i], X_datasets[i] and y_datasets[i] describe the
# same corpus.
names = ["Agatha Christie", "Sherlock Holmes", "Jane Austen", "Newspapers"]
X_datasets = [X_ac, X_sh, X_ja, X_np]
y_datasets = [y_ac, y_sh, y_ja, y_np]

## Average FastText

In [None]:
# !wget https://dl.fbaipublicfiles.com/fasttext/vectors-english/wiki-news-300d-1M.vec.zip && unzip wiki-news-300d-1M.vec.zip

In [None]:
# English stop-word set; FastTextTransformer._word_vec returns no vector for
# a token as soon as one of its candidate spellings is found in this set.
stop_words = set(stopwords.words("english"))
# English Snowball stemmer; used by FastTextTransformer._word_vec to build
# the last candidate spelling tried for a token.
stemmer = SnowballStemmer("english")

class FastTextTransformer(TransformerMixin):
    """Embed each document as the pooled (mean or sum) vector of its words.

    Documents are tokenized with ``word_tokenize``; every token is looked up
    in the module-level ``w2v`` word-vector mapping and the vectors found are
    pooled into a single row per document.

    NOTE(review): ``w2v`` is not defined in the visible cells. It is presumably
    a gensim ``KeyedVectors`` loaded from the FastText ``.vec`` file -- confirm
    it is in scope before calling ``transform``.

    Parameters
    ----------
    func : str
        ``"mean"`` averages the word vectors of a document; any other value
        sums them.
    """

    # Splits a token on every non-word character and on underscores.
    __splitter = re.compile(r"[\W_]")
    # Translation table that deletes all ASCII punctuation characters.
    __table = str.maketrans("", "", string.punctuation)

    def __init__(self, func: str = "mean") -> None:
        self.func = func

    def fit(self, X: Any, y: Any = None) -> "FastTextTransformer":
        """Do nothing and return ``self``; the transformer has no fitted state."""
        return self

    def _word_vec(self, word: str) -> Optional[np.ndarray]:
        """Return the vector for ``word``, or ``None`` if it should be skipped.

        Candidate spellings are tried in order: the raw token, its lowercase
        form, the punctuation-stripped form, the stripped lowercase form and
        finally the stem of the latter. The first candidate that is a stop
        word aborts the lookup with ``None``; the first candidate present in
        ``w2v`` supplies the vector.

        NOTE(review): because candidates are checked in order, a raw spelling
        that is in ``w2v`` but not in ``stop_words`` is returned before its
        lowercase form is tested against the stop-word set -- confirm this
        is intended.
        """
        stripped = word.translate(self.__table)
        lowered = stripped.lower()
        candidates = (
            word,
            word.lower(),
            stripped,
            lowered,
            stemmer.stem(lowered),
        )

        for w in candidates:
            if w in stop_words:
                return None

            if w in w2v:
                return w2v[w]

        return None

    def transform(self, X: Any, y: Any = None) -> np.ndarray:
        """Return one pooled word-vector row per document in ``X``.

        Parameters
        ----------
        X : iterable of str
            Documents to embed.
        y : ignored

        Returns
        -------
        np.ndarray
            Array with one row per document. A document in which no token
            yields a vector is represented by an all-zero row.
        """
        pool = np.mean if self.func == "mean" else np.sum
        results = []

        for x in X:
            vectors = []

            for token in word_tokenize(x):
                w = self._word_vec(token)
                if w is not None:
                    vectors.append(w)
                    continue

                # The whole token is unknown: retry on its pieces, e.g.
                # "well-known" -> "well", "known".
                for sub in self.__splitter.split(token):
                    # Empty pieces carry no word, and a piece equal to the
                    # token has already been looked up above.
                    if not sub or sub == token:
                        continue
                    w = self._word_vec(sub)
                    if w is not None:
                        vectors.append(w)

            if vectors:
                results.append(pool(vectors, axis=0))
            else:
                # Pooling an empty list yields a NaN scalar that cannot be
                # stacked with real rows; use a zero vector instead.
                # Assumes w2v exposes `vector_size` (gensim KeyedVectors does).
                results.append(np.zeros(w2v.vector_size))

        if not results:
            # np.vstack rejects an empty sequence.
            return np.empty((0, w2v.vector_size))

        return np.vstack(results)

In [None]:
fs = FastTextTransformer()

# One embedding matrix per corpus, in the same order as `names`.
# NOTE(review): `tqdm` is not imported in the visible cells -- confirm it is
# in scope before running this cell.
X_fasttext = []

for X in tqdm(X_datasets, total=len(X_datasets)):
    X_fasttext.append(fs.fit_transform(X))

## Distil USE

In [None]:
# Load the sentence encoder by model name.
s_transformer = SentenceTransformer("distiluse-base-multilingual-cased-v2")

# For every corpus, embed each document as the mean of the embeddings of its
# sentences; X_distiluse[i] corresponds to names[i].
X_distiluse = []
for name, X in tqdm(zip(names, X_datasets), total=len(X_datasets)):
    X_distiluse.append(
        [
            s_transformer.encode(sent_tokenize(x)).mean(axis=0)
            for x in tqdm(X, desc=name)
        ]
    )

## Average GloVe

In [None]:
# Load the sentence encoder by model name (rebinds `s_transformer`).
s_transformer = SentenceTransformer("average_word_embeddings_glove.6B.300d")

# For every corpus, embed each document as the mean of the embeddings of its
# sentences; X_avg_glove[i] corresponds to names[i].
X_avg_glove = []
for name, X in tqdm(zip(names, X_datasets), total=len(X_datasets)):
    X_avg_glove.append(
        [
            s_transformer.encode(sent_tokenize(x)).mean(axis=0)
            for x in tqdm(X, desc=name)
        ]
    )

## STSB Roberta

In [None]:
# Load the sentence encoder by model name (rebinds `s_transformer`).
s_transformer = SentenceTransformer("stsb-roberta-base")

# For every corpus, embed each document as the mean of the embeddings of its
# sentences; X_roberta[i] corresponds to names[i].
X_roberta = []
for name, X in tqdm(zip(names, X_datasets), total=len(X_datasets)):
    X_roberta.append(
        [
            s_transformer.encode(sent_tokenize(x)).mean(axis=0)
            for x in tqdm(X, desc=name)
        ]
    )