# NLP Preprocessing

## Contents:


1. Imports
2. Directory Setup
3. Preprocessing Unit
4. Vectorization Unit
5. Pair-Wise Fast-Similarity Matrix Creation Unit
6. Similarity Computation Unit
7. Execute (aka. final :) )

## Imports

In [52]:
## Imports
'''Update code from Python 3.6.10 to a stable Kernel Python Version 3.8.0 '''

# Standard libs
import os
import sys
import json
import warnings
import re
import io
from io import StringIO
import inspect
import shutil
import ast
import string
import time
import pickle
import glob
import traceback
import multiprocessing
import requests
import logging
import math
import pytz
from itertools import chain
from string import Template
from datetime import datetime, timedelta
from dateutil import parser
import base64
from collections import defaultdict, Counter, OrderedDict
from contextlib import contextmanager
import unicodedata
from functools import reduce
import itertools
import tempfile
from typing import Any, Dict, List, Callable, Optional, Tuple, NamedTuple, Union
from functools import wraps

# graph
import networkx as nx

# Required pkgs
import numpy as np
from numpy import array, argmax
import pandas as pd
import ntpath
import tqdm

# General text correction - fit text for you (ftfy) and others
import ftfy
from fuzzywuzzy import fuzz
#from wordcloud import WordCloud
from spellchecker import SpellChecker

# imbalanced-learn
from imblearn.over_sampling import SMOTE, SVMSMOTE, ADASYN

# scikit-learn
from sklearn.utils import shuffle
from sklearn import preprocessing
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, jaccard_score, silhouette_score, homogeneity_score, calinski_harabasz_score
from sklearn.metrics.pairwise import euclidean_distances, cosine_similarity
from sklearn.neighbors import NearestNeighbors, LocalOutlierFactor
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.cluster import KMeans, AgglomerativeClustering
from sklearn.base import BaseEstimator, TransformerMixin

# scipy
from scipy import spatial, sparse
from scipy.sparse import coo_matrix, vstack, hstack
from scipy.spatial.distance import euclidean, jensenshannon, cosine, cdist
from scipy.io import mmwrite, mmread
from scipy.stats import entropy
from scipy.cluster.hierarchy import dendrogram, ward, fcluster
import scipy.cluster.hierarchy as sch
from scipy.sparse.csr import csr_matrix
from scipy.sparse.lil import lil_matrix
from scipy.sparse.csgraph import connected_components

# sparse_dot_topn: matrix multiplier
from sparse_dot_topn import awesome_cossim_topn
import sparse_dot_topn.sparse_dot_topn as ct

# Gensim
import gensim
from gensim.models import Phrases, Word2Vec, KeyedVectors, FastText, LdaModel
from gensim import utils
from gensim.utils import simple_preprocess
from gensim.test.utils import datapath, get_tmpfile
from gensim.scripts.glove2word2vec import glove2word2vec
import gensim.downloader as api
from gensim import models, corpora, similarities

# NLTK
import nltk
#nltk_model_data_path = "/someppath/"
#nltk.data.path.append(nltk_model_data_path)
from nltk import FreqDist, tokenize, sent_tokenize, word_tokenize, pos_tag
from nltk.corpus import stopwords, PlaintextCorpusReader
from nltk.tokenize import RegexpTokenizer
from nltk.stem import WordNetLemmatizer
from nltk.stem.lancaster import LancasterStemmer
from nltk.stem.porter import *
from nltk.translate.bleu_score import sentence_bleu
print("NLTK loaded.")

# Spacy
import spacy
from spacy import displacy
from spacy.matcher import Matcher
from spacy.lang.en import English
print("Spacy loaded.")

# Pytorch
import torch
from torch import optim, nn
import torch.nn.functional as Functional
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import AutoTokenizer
from transformers import AutoModelWithLMHead
from transformers import pipeline
from transformers import AutoModel
print("PyTorch loaded.")

# Plots
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly import offline
%matplotlib inline

# Theme settings
pd.set_option("display.max_columns", 80)
# Context, style and rc overrides are applied in ONE call: sns.set() re-applies its own
# default context, so calling it after sns.set_context('talk') discarded the 'talk' scaling.
sns.set(context='talk', style='darkgrid', rc={'figure.figsize': (15, 10)})
# Silence all warnings for the rest of the session (notebook readability).
warnings.filterwarnings('ignore')

NLTK loaded.
Spacy loaded.
PyTorch loaded.


## Directory Setup

In [53]:
# Project folders: everything hangs off the parent of the current working directory.
root_dir = os.path.abspath("../")
data_dir, models_dir, output_dir = (
    os.path.join(root_dir, sub_dir) for sub_dir in ("data", "models", "output")
)

# Shared resources that live outside the project tree (paths relative to the working directory).
nlp_resources_fp = "../../../_Resources_/nlp_resources/"
spacy_model_data_path = "../../../_Resources_/spacy/en_core_web_lg-3.5.0/en_core_web_lg/en_core_web_lg-3.5.0/"
sbert_model_fp = "../../../_Resources_/transformer_models/all-distilroberta-v1/"

# Load the spaCy pipeline from disk with every component enabled.
# To skip NER, use: spacy.load(spacy_model_data_path, disable=['ner'])
nlp = spacy.load(spacy_model_data_path)

## Data

In [56]:
# Stack the two yearly exports into one frame.
# DataFrame.append was removed in pandas 2.0; pd.concat is the supported equivalent.
# Row labels are kept as-is (no ignore_index), which matches what append produced.
df = pd.concat([
    pd.read_csv(data_dir + "/TwitterData_2018_01_01_2020-12-31_cuboulder.csv"),
    pd.read_csv(data_dir + "/TwitterData_2021_01_01_2022-12-31_cuboulder.csv"),
])

In [58]:
df.shape

(92132, 14)

In [60]:
df.head()

Unnamed: 0,date,tweet,lang,retweetCount,likeCount,replyCount,username,user_followersCount,user_friendsCount,verifiedStatus,tweet_url,hastags,chr_count,topic
0,2018-01-01,@CREWcrew @HillaryClinton @GOP @HustonTillotso...,qam,0,0,0,jccj171,382,1475,False,https://twitter.com/jccj171/status/94797812392...,,309,CUBoulder
1,2018-01-01,Applications can now be turned in for those in...,en,8,19,0,CUBuffsRalphie,9723,271,False,https://twitter.com/CUBuffsRalphie/status/9479...,"['GoBuffs', 'RunRalphieRun', 'CUBoulder']",251,CUBoulder
2,2018-01-01,A classic Darwinian ecological hypothesis hold...,en,1,6,1,CUBoulder,90944,893,True,https://twitter.com/CUBoulder/status/947943030...,,112,CUBoulder
3,2018-01-01,#MondayMotivation to live this year pushing th...,en,0,0,0,CUBoulderHealth,895,131,False,https://twitter.com/CUBoulderHealth/status/947...,"['MondayMotivation', 'CUBoulder']",120,CUBoulder
4,2018-01-01,@CUBoulder @CUBoulderLiving @CUBuffsRalphie 👍😎,und,0,0,0,OnlyInBoulder,16138,9680,False,https://twitter.com/OnlyInBoulder/status/94791...,,46,CUBoulder


## Preprocessing

In [82]:
class preprocessText:
    """
    Text-cleaning pipeline offering three modes: `basic`, `deep` and `spacy`.

    Resource files read from `resources_dir_path` (see `load_resources`):
      - stopwords.txt, special_stopwords.txt, special_characters.txt : one entry per line
      - contractions.json : JSON mapping, contraction -> expansion
      - chatwords.txt     : one `ABBREVIATION=expansion` pair per line
      - emoticons.json    : JSON whose keys are OR-ed together into one regex
      - greeting_words.txt, signature_words.txt : one entry per line

    The `deep` and `spacy` modes rely on a module-level spaCy pipeline named `nlp`.
    """

    # Patterns shared by the cleaning modes; compiled once instead of once per sentence.
    _WHITESPACE_RE = re.compile(r"([\s\n\t\r]+)")
    _TAG_RE = re.compile(r"\<(.*?)\>")
    _CONTRACTION_RE = re.compile(r"(\w+\'\w+)")
    # NOTE: the domain character class previously contained an en-dash ("0–9") instead of the
    # range "0-9", so bare domains containing the digits 1-8 (e.g. "t3n.de/...") never matched.
    _URL_RE = re.compile(r'''(?i)\b((?:https?://|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:'".,<>?«»""'']))''')
    _DATE_RE = re.compile(r"(\d{4}\-\d{2}\-\d{2}\s\d{2}\:\d{2}\:\d{2}\s\:)")
    _TWITTER_HANDLE_RE = re.compile(r"(?<=^|(?<=[^a-zA-Z0-9-_\.]))@([A-Za-z]+[A-Za-z0-9-_]+)")
    # Punctuation handling for `deep_clean`, applied in this exact order.
    _PUNCT_SUBS = (
        (re.compile(r"[\$|\#\@\*\%]+\d+[\$|\#\@\*\%]+"), " "),
        (re.compile(r"\'s"), " 's"),
        (re.compile(r"\'ve"), " 've"),
        (re.compile(r"n\'t"), " n't"),
        (re.compile(r"\'re"), " 're"),
        (re.compile(r"\'d"), " 'd"),
        (re.compile(r"\'ll"), " 'll"),
        (re.compile(r"[\/,\@,\#,\\,\{,\},\(,\),\[,\],\$,\%,\^,\&,\*,\<,\>]"), " "),
        (re.compile(r"[\,,\;,\:,\-]"), " "),      # main puncts
    )

    def __init__(self, resources_dir_path, custom_vocab=None, do_lemma=False):
        """
        Args:
            resources_dir_path: directory holding the resource files listed on the class.
            custom_vocab: words/phrases to protect from cleaning (None or empty = none).
            do_lemma: when True, `deep_clean` replaces every token by its spaCy lemma.
        """
        # None instead of a shared mutable default list.
        custom_vocab = [] if custom_vocab is None else custom_vocab
        self.stopwords_file = os.path.join(resources_dir_path, "stopwords.txt")
        self.special_stopwords_file = os.path.join(resources_dir_path, "special_stopwords.txt")
        self.special_characters_file = os.path.join(resources_dir_path, "special_characters.txt")
        self.contractions_file = os.path.join(resources_dir_path, "contractions.json")
        self.chatwords_file = os.path.join(resources_dir_path, "chatwords.txt")
        self.emoticons_file = os.path.join(resources_dir_path, "emoticons.json")
        self.greeting_file = os.path.join(resources_dir_path, "greeting_words.txt")
        self.signature_file = os.path.join(resources_dir_path, "signature_words.txt")
        self.preserve_key = "<$>" # preserve special vocab
        self.vocab_list = custom_vocab
        # (sic) attribute name kept as-is so existing readers of `preseve` keep working
        self.preseve = len(custom_vocab) > 0
        self.do_lemma = do_lemma
        self.load_resources()

    @staticmethod
    def _read_lines(path):
        """Return the lines of a UTF-8 text file with trailing whitespace stripped."""
        with open(path, 'r', encoding='utf-8', errors='ignore') as f:
            return [line.rstrip() for line in f]

    def _build_vocab_model(self):
        """Lower-case the custom vocabulary and compile the regex that finds its entries."""
        self.vocab_list = set(map(str.lower, self.vocab_list))
        # multi-word entries are glued together with `preserve_key`: "do i" -> "do<$>i"
        self.vocab_dict = {w: self.preserve_key.join(w.split()) for w in self.vocab_list}
        # longest alternatives first so "can you tell" wins over "can"
        alternatives = '|'.join(sorted(map(re.escape, self.vocab_dict), key=len, reverse=True))
        # Look-arounds restrict matches to whole words/phrases; without them "am i" also
        # matched inside "team is".
        self.re_retain_words = re.compile(r'(?<!\w)(?:' + alternatives + r')(?!\w)')

    def load_resources(self):
        """Read every resource file and build the lookup structures used by the cleaners."""

        ### Build Vocab Model --> Words to keep
        self._build_vocab_model()

        ### Build Stopwords Model --> Words to drop/delete
        stopwords = self._read_lines(self.stopwords_file)
        stopwords.extend(self._read_lines(self.special_stopwords_file))
        stopwords.extend(self._read_lines(self.special_characters_file))
        # custom vocabulary is never treated as a stopword
        self.stopwords = sorted(set(stopwords).difference(self.vocab_list))

        ### Build Contractions
        with open(self.contractions_file, 'r', encoding='utf-8', errors='ignore') as f:
            self.contractions = dict(json.load(f))

        ### Build Chat-words
        self.chat_words_map_dict = {}
        for line in self._read_lines(self.chatwords_file):
            parts = line.split("=")
            if len(parts) < 2:
                # blank or malformed line: skip instead of raising IndexError
                continue
            self.chat_words_map_dict[parts[0]] = parts[1]
        self.chat_words_list = set(self.chat_words_map_dict)

        ### Build social markups
        # emoticons: keys are joined verbatim, so each key must already be a valid regex fragment
        # explicit encoding so the load does not depend on the platform default
        with open(self.emoticons_file, "r", encoding="utf-8") as f:
            self.emoticons = re.compile(u'(' + u'|'.join(k for k in json.load(f)) + u')')
        # emojis
        self.emojis = re.compile("["
                                   u"\U0001F600-\U0001F64F"  # emoticons
                                   u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                                   u"\U0001F680-\U0001F6FF"  # transport & map symbols
                                   u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                                   u"\U00002702-\U000027B0"
                                   u"\U000024C2-\U0001F251"
                                   "]+", flags=re.UNICODE)
        # greeting / signature words (loaded here; not used by the methods of this class)
        self.greeting_words = self._read_lines(self.greeting_file)
        self.signature_words = self._read_lines(self.signature_file)
        # spell-corrector (only referenced by the disabled optional step in `deep_clean`)
        self.spell_checker = SpellChecker()

    def reserve_keywords_from_cleaning(self, text, reset=False):
        """
        Protect custom-vocabulary phrases during cleaning.

        With `reset=False`, every vocabulary phrase found in `text` has its words joined with
        `self.preserve_key`. With any other `reset` value the join is undone by replacing
        `self.preserve_key` with a space. Whitespace runs are collapsed in both directions.
        `text` is expected to be lower-cased already, because the vocabulary is lower-cased.
        """
        if reset is False:
            # an empty vocabulary has nothing to substitute (and its pattern matches "")
            if self.vocab_dict:
                text = self.re_retain_words.sub(lambda m: self.vocab_dict[m.group()], text)
            return self._WHITESPACE_RE.sub(" ", text).strip()
        # reverse the change! - use this at the end of preprocessing
        text = text.replace(self.preserve_key, " ")
        return self._WHITESPACE_RE.sub(" ", text).strip()

    def _expand_contraction(self, match):
        """Regex callback: lower-case the matched word and expand it when it is a known contraction."""
        word = match.group().lower()
        return self.contractions.get(word, word)

    def basic_clean(self, input_sentences):
        """
        Light cleaning that keeps sentence structure: encoding repair, ASCII folding, removal of
        <...> tags, contraction expansion, URL and timestamp removal, whitespace collapsing.

        Returns a list with one cleaned string per input item.
        """
        cleaned_sentences = []
        for sent in input_sentences:
            sent = str(sent).strip()
            # FIX text
            sent = ftfy.fix_text(sent)
            # Normalize accented chars (anything that cannot be folded to ASCII is dropped)
            sent = unicodedata.normalize('NFKD', sent).encode('ascii', 'ignore').decode('utf-8', 'ignore')
            # Removing <…> web scrape tags
            sent = self._TAG_RE.sub(" ", sent)
            # Expanding contractions using contractions_file
            sent = self._CONTRACTION_RE.sub(self._expand_contraction, sent)
            # Removing web urls
            sent = self._URL_RE.sub(" ", sent)
            # Removing date formats ("YYYY-MM-DD hh:mm:ss :")
            sent = self._DATE_RE.sub(" ", sent)
            # Removing extra whitespaces
            sent = self._WHITESPACE_RE.sub(" ", sent).strip()
            cleaned_sentences.append(sent)
        return cleaned_sentences

    def deep_clean(self, input_sentences):
        """
        Aggressive cleaning for bag-of-words style use: everything `basic_clean` does plus
        lower-casing, removal of Twitter @handles, emojis, emoticons, chat-word expansion,
        punctuation stripping, stopword removal and (when `self.do_lemma`) lemmatization.

        Uses the module-level spaCy pipeline `nlp`. Returns one cleaned string per input item.
        """
        # set lookups instead of scanning the stopword list for every token
        stopword_set = set(self.stopwords)
        cleaned_sentences = []
        for sent in input_sentences:
            # fold to ASCII (drops every non-ASCII character, emojis included)
            sent = unicodedata.normalize('NFKD', str(sent)).encode('ascii', 'ignore').decode('utf-8', 'ignore')
            # lowercasing
            sent = str(sent).strip().lower()

            # <----------------------------- CUSTOM CLEANING ----------------------------- >
            # TWITTER SPECIFIC: drop @handles
            sent = self._TWITTER_HANDLE_RE.sub(" ", sent)

            # Join custom-vocabulary phrases with `self.preserve_key`; un-joined again at the end.
            # NOTE(review): the tag-stripping regex below also matches the default key "<$>",
            # so with that key the join does not survive until stopword removal - confirm intent.
            if self.preseve:
                sent = self.reserve_keywords_from_cleaning(sent, reset=False)
            # <----------------------------- CUSTOM CLEANING ----------------------------- >

            # remove Emojis
            sent = self.emojis.sub(r'', sent)
            # remove emoticons
            # NOTE(review): URLs are stripped only further down; if emoticons.json contains
            # sequences such as ":/" the "http://" prefix is altered first - verify.
            sent = self.emoticons.sub(r'', sent)
            # expand common chat-words (looked up in upper case)
            sent = " ".join(self.chat_words_map_dict[w.upper()] if w.upper() in self.chat_words_list else w
                            for w in sent.split())
            # FIX text
            sent = ftfy.fix_text(sent)
            # Normalize accented chars
            sent = unicodedata.normalize('NFKD', sent).encode('ascii', 'ignore').decode('utf-8', 'ignore')
            # Removing <…> web scrape tags
            sent = self._TAG_RE.sub(" ", sent)
            # Expanding contractions using contractions_file
            sent = self._CONTRACTION_RE.sub(self._expand_contraction, sent)
            # Removing web urls
            sent = self._URL_RE.sub(" ", sent)
            # Removing date formats
            sent = self._DATE_RE.sub(" ", sent)

            # <----------------------------- OPTIONAL CLEANING ----------------------------- >
            #
            # removing punctuations
            # *** disable them, when sentence structure needs to be retained ***
            for pattern, replacement in self._PUNCT_SUBS:
                sent = pattern.sub(replacement, sent)

            # remove sentence de-limitters
            # *** disable them, when sentence boundary/ending is important ***
            # sent = re.sub(r"[\!,\?,\.]", " ", sent)

            # keep only text & numbers
            # *** enable them, when only text and numbers matter! ***
            # sent = re.sub(r"\s+", " ", re.sub(r"[\\|\/|\||\{|\}|\[|\]\(|\)]+", " ", re.sub(r"[^A-z0-9]", " ", str(sent))))

            # correct spelling mistakes
            # *** enable them when english spelling mistakes matter ***
            # sent = " ".join([self.spell_checker.correction(w) if w in self.spell_checker.unknown(sent.split()) else w for w in sent.split()])
            #
            # <----------------------------- OPTIONAL CLEANING ----------------------------- >

            # Remove stopwords (a token is dropped when its text OR its lemma is a stopword)
            sent = " ".join(token.text for token in nlp(sent)
                            if token.text not in stopword_set and token.lemma_ not in stopword_set)
            # Lemmatize
            if self.do_lemma:
                sent = " ".join(token.lemma_ for token in nlp(sent))
            # Removing extra whitespaces
            sent = self._WHITESPACE_RE.sub(" ", sent).lower().strip()

            # <----------------------------- CUSTOM CLEANING ----------------------------- >
            # Reverse the custom joining now to un-join the special words found!
            if self.preseve:
                sent = self.reserve_keywords_from_cleaning(sent, reset=True)
            # <----------------------------- CUSTOM CLEANING ----------------------------- >

            cleaned_sentences.append(sent.strip().lower())
        return cleaned_sentences

    def spacy_get_pos_list(self, results):
        """
        Flatten the output of `spacy_generate_features` into parallel lists.

        Returns a dict with `word_list`, `lemma_list`, `token_start_end_list`, `pos_list` and
        `ner_list`. Tokens whose lower-cased lemma is a stopword are still present in
        `word_list` / `pos_list` but are skipped for the lemma, NER and offset lists, so those
        two groups of lists can differ in length.
        """
        stopword_set = set(self.stopwords)
        word_list, pos_list, lemma_list, ner_list, start_end_list = [], [], [], [], []
        for line in results['sentences']:
            for token in line['tokens']:
                # (1). save tokens
                word_list.append(token['word'])
                # (2). save pos
                pos_list.append(token['pos'])
                # (3). save lemmas (stopword lemmas end the processing of this token)
                lemma = token['lemma'].lower()
                if lemma in stopword_set:
                    continue
                lemma_list.append(lemma)
                # (4). save NER
                ner_list.append(token['ner'])
                # (5). save "start_end" character offsets
                start_end_list.append(str(token['characterOffsetBegin']) + "_" + str(token['characterOffsetEnd']))
        output = {"word_list": word_list,
                  "lemma_list": lemma_list,
                  "token_start_end_list": start_end_list,
                  "pos_list": pos_list, "ner_list": ner_list}
        return output

    def spacy_generate_features(self, doc, operations='tokenize,ssplit,pos,lemma,ner'):
        """
        Convert a spaCy `doc` (doc = nlp(text)) into a sentence/token JSON structure.

        Args:
            doc: a processed spaCy document.
            operations: comma-separated names; `lemma`, `pos`, `ner` add the matching token
                fields and `parse` adds the dependency lists to each sentence. Checked by
                substring containment.

        Returns:
            {"sentences": [{"index", "line", "tokens": [...], ...}, ...]} with 1-based token
            indices (the original code comments describe this as the CoreNLP convention).
        """
        # spacy doc
        doc_json = doc.to_json()  # Includes all operations given by spacy pipeline

        # Get text
        text = doc_json['text']

        # ---------------------------------------- OPERATIONS  ---------------------------------------- #
        # 1. Extract Entity List
        entity_list = doc_json["ents"]

        # 2. Create token lib (token id -> token dict) for head look-ups
        token_lib = {token["id"]: token for token in doc_json["tokens"]}

        # init output json
        output_json = {"sentences": []}

        # Perform spacy operations on each sent in text
        for i, sentence in enumerate(doc_json["sents"]):
            # init parsers
            parse = ""
            basicDependencies = []

            # init output json
            out_sentence = {"index": i, "line": 1, "tokens": []}
            output_json["sentences"].append(out_sentence)

            # 3. Split sentences by indices(i), add labels (pos, ner, dep, etc.)
            for token in doc_json["tokens"]:
                # keep only tokens whose character span lies inside this sentence
                if not (sentence["start"] <= token["start"] and token["end"] <= sentence["end"]):
                    continue

                # >>> Extract Entity label ("O" when the token is inside no entity span;
                #     when several spans contain it, the last one wins)
                ner = "O"
                for entity in entity_list:
                    if entity["start"] <= token["start"] and token["end"] <= entity["end"]:
                        ner = entity["label"]

                # >>> Extract dependency info (a token that is its own head is the root)
                is_root = token["head"] == token["id"]
                dep = token["dep"]
                governor = 0 if is_root else (token["head"] + 1)  # CoreNLP index = pipeline index +1
                if is_root:
                    governorGloss = "ROOT"
                else:
                    head_token = token_lib[token["head"]]
                    governorGloss = text[head_token["start"]:head_token["end"]]
                dependent = token["id"] + 1
                dependentGloss = text[token["start"]:token["end"]]

                # >>> Extract lemma
                lemma = doc[token["id"]].lemma_

                # 4. Add dependencies
                basicDependencies.append({"dep": dep,
                                          "governor": governor,
                                          "governorGloss": governorGloss,
                                          "dependent": dependent,
                                          "dependentGloss": dependentGloss})
                # 5. Add tokens
                out_token = {"index": token["id"] + 1,
                             "word": dependentGloss,
                             "originalText": dependentGloss,
                             "characterOffsetBegin": token["start"],
                             "characterOffsetEnd": token["end"]}

                # 6. Add lemmas
                if "lemma" in operations:
                    out_token["lemma"] = lemma

                # 7. Add POS tagging (spaCy's fine-grained `tag`)
                if "pos" in operations:
                    out_token["pos"] = token["tag"]

                # 8. Add NER
                if "ner" in operations:
                    out_token["ner"] = ner

                # Update output json
                out_sentence["tokens"].append(out_token)

            # 9. Add dependencies operation (all three keys share the same list object)
            if "parse" in operations:
                out_sentence["parse"] = parse
                out_sentence["basicDependencies"] = basicDependencies
                out_sentence["enhancedDependencies"] = out_sentence["basicDependencies"]
                out_sentence["enhancedPlusPlusDependencies"] = out_sentence["basicDependencies"]
        # ---------------------------------------- OPERATIONS  ---------------------------------------- #
        return output_json

    def spacy_clean(self, input_sentences):
        """
        Run the module-level `nlp` pipeline over `input_sentences` and return, per sentence,
        the feature dict produced by `spacy_get_pos_list`.
        """
        # nothing to do, and a batch size of 0 would otherwise be handed to nlp.pipe
        if len(input_sentences) == 0:
            return []
        # roughly 1% of the input per batch, capped at 500
        batch_size = min(int(np.ceil(len(input_sentences)/100)), 500)

        # Part 1: generate spacy textual features (pos, ner, lemma, dependencies)
        sentences = [self.spacy_generate_features(doc)
                     for doc in nlp.pipe(input_sentences, batch_size=batch_size, n_process=-1)]

        # Part 2: collect all the features for each sentence
        return [self.spacy_get_pos_list(sent) for sent in sentences]

    ## MAIN ##
    def run_pipeline(self, sentences, operation):
        """
        Main module to execute pipeline. Accepts list of strings, and desired operation.

        Args:
            sentences: iterable of texts to clean.
            operation: `basic`, `deep` or `spacy` (case-insensitive).

        Raises:
            ValueError: when `operation` is empty or not one of the three supported modes
                (an unknown mode used to return None silently).
        """
        op = "" if operation is None else str(operation).strip().lower()
        if op == "":
            raise ValueError("Please pass a cleaning type - `basic`, `deep` or `spacy` !!")

        handlers = {"basic": self.basic_clean,
                    "deep": self.deep_clean,
                    "spacy": self.spacy_clean}
        if op not in handlers:
            raise ValueError("Unknown cleaning type %r - expected `basic`, `deep` or `spacy`." % (operation,))
        return handlers[op](sentences)

### Execute

In [83]:
## settings ##

# CUSTOM VOCABULARY ::
# - Words and phrases to mark and retain across the preprocessing steps - very important!
# - Typically task-specific or domain-specific keywords. Here: question words for FAQs.
custom_vocab = [
    # single question / modal words
    "who", "what", "where", "when", "would", "which", "how", "why", "can", "may",
    "will", "won't", "does", "does not", "doesn't", "do",
    # question openers
    "do i", "do you", "is it", "would you",
    "is there", "are there", "is it so", "is this true", "to know", "is that true", "are we",
    "am i", "question is", "can i", "can we", "tell me", "can you explain", "how ain't",
    # Q&A nouns / verbs
    "question", "answer", "questions", "answers", "ask", "can you tell",
]

# Utilities:
# - Truncate words to their root-known-word form, stripping off their adjectives, verbs, etc.
do_lemmatizing = True
# do_chinking = False
# do_chunking = False
# do_dependencyParser = False

In [84]:
## Preprocessing ##

# One shared cleaner, configured with the custom vocabulary and lemmatization flag above.
preprocessText_obj = preprocessText(nlp_resources_fp, custom_vocab, do_lemmatizing)


def cleaning(data, text_col):
    """
    Add a deep-cleaned version of `data[text_col]` as the column "Deep_<text_col>".

    `data` is modified in place and also returned. The "basic" and "spacy" variants can be
    added the same way by calling `run_pipeline` with those operation names.
    """
    deep_col = "Deep_%s" % text_col
    data[deep_col] = preprocessText_obj.run_pipeline(data[text_col], "deep")
    return data


## SAMPLE
# df = cleaning(df, <_TEXT_COLUMN_>)

In [None]:
## Execute ##

# Deep-clean the "tweet" column. `cleaning` adds the "Deep_tweet" column to `df` itself and
# returns that same object, so `df_clean` and `df` refer to the same DataFrame.
df_clean = cleaning(df, "tweet")
# Persist the processed frame without writing the row index.
df_clean.to_csv("../data/Twitter_2018_2022_cuboulder_processed.csv", index=False)

---

## Vectorization

In [50]:
class BertTransformer(BaseEstimator, TransformerMixin):
    """
    scikit-learn style transformer that turns texts into sentence embeddings by mean-pooling
    the token embeddings of a transformer model over the attention mask.

    Args:
        tokenizer: callable tokenizer; invoked with `padding`, `truncation`, `max_length` and
            `return_tensors='pt'`, and expected to return a mapping with an `attention_mask`.
        model: model called with the tokenizer output; element 0 of its output is used as
            the token embeddings. It is switched to eval mode on construction.
        max_length: truncation length passed to the tokenizer.
        embedding_func: stored on the instance (defaults to first-token pooling), but not
            used by `transform`, which always mean-pools.
    """

    def __init__(self, tokenizer, model, max_length=128, embedding_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,):
        self.tokenizer = tokenizer
        self.model = model
        self.model.eval()
        self.max_length = max_length
        self.embedding_func = embedding_func
        if self.embedding_func is None:
            self.embedding_func = lambda x: x[0][:, 0, :].squeeze()

    @staticmethod
    def _mean_pooling(model_output, attention_mask):
        """Average the token embeddings, counting only positions where the mask is set."""
        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        # clamp avoids a division by zero for an all-zero mask
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def _tokenize(self, text):
        """Tokenize `text`, run the model without gradients and return mean-pooled embeddings."""
        # Use the tokenizer given to THIS instance (previously the module-level `tokenizer`
        # was called, which only worked when a global of that name happened to exist).
        encoded_input = self.tokenizer(text, padding=True, truncation=True,
                                       max_length=self.max_length, return_tensors='pt')

        # Compute token embeddings
        with torch.no_grad():
            model_output = self.model(**encoded_input)

        # Perform mean pooling
        return self._mean_pooling(model_output, encoded_input['attention_mask'])

    def transform(self, text: List[str], formatt='tensor'):
        """
        Embed `text` and return the result in the requested container.

        Args:
            text: list of strings or a pandas Series (converted to a list).
            formatt: 'tensor' (default), 'numpy', or 'csr' (rows L2-normalized, float64 scipy
                CSR matrix). A falsy value returns the raw tensor, as the earlier default did.

        Raises:
            Exception: for any other `formatt` value.

        NOTE: the whole input is encoded in a single forward pass.
        """
        if isinstance(text, pd.Series):
            text = text.tolist()

        embeddings = self._tokenize(text)

        # falsy formatt (None, "") used to fall through and return None
        if not formatt:
            return embeddings

        formatt = str(formatt).strip().lower()
        if formatt == 'tensor':
            return embeddings
        if formatt == 'numpy':
            return embeddings.numpy()
        if formatt == 'csr':
            embeddings_matrix = Functional.normalize(embeddings, p=2, dim=1)
            return csr_matrix(embeddings_matrix.numpy().astype(np.float64))
        raise Exception("Invalid input for formatt!")

    def fit(self, X, y=None):
        """No fitting required so we just return ourselves. For fine-tuning, refer to shared gpu-code!"""
        return self

In [51]:
# SENTENCE-BERT VECTORIZATION #

# load tokenizer, model classes from the local model directory `sbert_model_fp`
tokenizer = AutoTokenizer.from_pretrained(sbert_model_fp)
model_bert = AutoModel.from_pretrained(sbert_model_fp)

# load vectorizer
# (the embedding_func passed here is the same expression BertTransformer uses as its default)
bert_vectorizer = BertTransformer(tokenizer, model_bert, embedding_func=lambda x: x[0][:, 0, :].squeeze())
print("Bert Model '%s' loaded." % sbert_model_fp)

## SAMPLE FOR VECTORIZATION
# corpus = df['text_col']
# embeddings = bert_vectorizer.transform(corpus)                     # torch tensor (default)
# embeddings = bert_vectorizer.transform(corpus, formatt='numpy')    # numpy array
# embeddings = bert_vectorizer.transform(corpus, formatt='csr')      # L2-normalized scipy CSR matrix

Bert Model '../../../_Resources_/transformer_models/all-distilroberta-v1/' loaded.


---