In [1]:
from IPython.display import clear_output

In [2]:
!pip install minisom

clear_output()

In [3]:
!pip install spacy==3.0.5

clear_output()

In [4]:
# Local path of the stopword list: written by download_stopwords(), read by Thesaurus.remove_stopwords().
STOPWORDS_FILE: str = 'stopwords.txt'

In [5]:
import requests
import os
import tarfile


def download_model(url='https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0'
                       '/en_core_web_sm-3.0.0.tar.gz',
                   filename='en_core_web_sm_temporary'):
    """Download the spaCy en_core_web_sm-3.0.0 archive and extract it into the working directory.

    The archive is saved as ``filename``. If that file already exists, nothing
    is done, so re-running the cell is cheap.

    Args:
        url: Location of the ``.tar.gz`` model archive.
        filename: Local path for the archive; its existence also marks
            "already downloaded and extracted".

    Raises:
        requests.HTTPError: if the server answers with an error status.
        tarfile.TarError / OSError: if extraction fails (the archive is then
            removed so the next call retries the download).
    """
    if os.path.exists(filename):
        return

    response = requests.get(url, allow_redirects=True, timeout=60)
    # Check the status before touching the disk. Otherwise an HTML error page
    # would be saved under `filename`, and every later run would skip the download.
    response.raise_for_status()

    with open(filename, 'wb') as archive:
        archive.write(response.content)

    try:
        # NOTE(review): extractall trusts member paths inside the archive. That is
        # acceptable for this official spaCy release; do not reuse for untrusted archives.
        with tarfile.open(filename, 'r:gz') as tar:
            tar.extractall()
    except (tarfile.TarError, OSError):
        # Drop the marker file so a broken download/extract is retried next time.
        os.remove(filename)
        raise


def download_stopwords(filename=None,
                       url='https://github.com/DinarZayahov/thesaurus/releases/download/0.0.1/extended_stopwords.txt'):
    """Download the extended stopword list to ``filename`` unless it already exists.

    Args:
        filename: Destination path; defaults to STOPWORDS_FILE (read at call time).
        url: Location of the stopword text file.

    Raises:
        requests.HTTPError: if the server answers with an error status, so an
            error page is never saved as the stopword list.
    """
    if filename is None:
        filename = STOPWORDS_FILE
    if os.path.exists(filename):
        return

    response = requests.get(url, allow_redirects=True, timeout=60)
    response.raise_for_status()
    with open(filename, 'wb') as stopwords_file:
        stopwords_file.write(response.content)

In [6]:
# Fetch the spaCy model archive and the stopword list. Each helper skips its
# download when its target file already exists, so re-running this cell is cheap.
download_model()
download_stopwords()

In [13]:
import spacy
import numpy as np
from gensim.utils import tokenize
import time
import pickle
import os
from minisom import MiniSom

from bokeh.models import ColumnDataSource, HoverTool
from bokeh.io import show, output_notebook
from bokeh.plotting import figure

# Configure bokeh so that show() renders plots inline in this notebook.
output_notebook()


# Character limit assigned to the spaCy pipeline's max_length in Thesaurus.set_spacy_model.
MAX_LENGTH = 1_250_000
# Texts at least this long take the pipe() branch (parser/NER disabled) in Thesaurus.lemmatize.
LEMMATIZATION_THRESHOLD = 500_000


class Thesaurus:
    """Pipeline that lemmatizes, filters and embeds two corpora, then plots them on a shared SOM.

    Intended call order: read_text -> lemmatize -> tokenize -> remove_stopwords
    -> make_embeddings -> plot_bokeh. The background corpus trains the
    self-organising map; foreground words are projected onto the trained map.
    """

    def __init__(self):
        # Set by set_spacy_model(); lemmatize() and make_embeddings() require it.
        self.spacy_model = None

    @staticmethod
    def read_text(file):
        """Return the full contents of an open text file (any iterable of lines) as one string."""
        return ''.join(file)

    def set_spacy_model(self, model):
        """Load the spaCy pipeline ``model`` (package name or path) and raise its max_length to MAX_LENGTH."""
        self.spacy_model = spacy.load(model)
        self.spacy_model.max_length = MAX_LENGTH

    def lemmatize(self, text, length):
        """Return ``text`` with each token replaced by its lemma, joined with single spaces.

        Args:
            text: Raw text to process with ``self.spacy_model``.
            length: Size of ``text`` (callers pass ``len(text)``); selects the branch below.
        """
        if length < LEMMATIZATION_THRESHOLD:
            doc = self.spacy_model(text)
        else:
            # Large input: run without the parser and NER to save time and memory.
            # There is only one text, so multiprocessing (the former n_process=3)
            # could not parallelise anything and only spawned idle workers.
            doc = next(self.spacy_model.pipe([text], disable=["parser", "ner"]))
        return " ".join(token.lemma_ for token in doc)

    @staticmethod
    def tokenize(text):
        """Lower-cased word tokens of ``text``.

        Delegates to gensim's module-level ``tokenize``. The class attribute of the
        same name is not in scope inside a method body, so there is no recursion.
        """
        return list(tokenize(text, to_lower=True))

    @staticmethod
    def get_stopwords(path):
        """Read one stopword per line from ``path`` and return them as a list, in file order."""
        with open(path, 'r') as stopwords_file:
            # rstrip('\n') rather than line[:-1]: the final line may have no trailing
            # newline, and slicing would then cut off its last character.
            return [line.rstrip('\n') for line in stopwords_file]

    def remove_stopwords(self, tokens: list, path=None):
        """Remove stopwords from ``tokens``.

        Args:
            tokens: Tokens to filter.
            path: Stopword file to use; defaults to STOPWORDS_FILE.

        Returns:
            ``(filtered, unique)``:
                filtered: the surviving tokens in order, duplicates kept.
                unique: the same tokens deduplicated, in first-occurrence order.
        """
        if path is None:
            path = STOPWORDS_FILE
        # A set gives O(1) membership tests instead of scanning a list per token.
        stopwords = set(self.get_stopwords(path))
        filtered_tokens = [token for token in tokens if token not in stopwords]
        return filtered_tokens, list(dict.fromkeys(filtered_tokens))

    def make_embeddings(self, tokens: list) -> list:
        """Return one spaCy vector per entry of ``tokens``, cached on disk in ``embeddings.pickle``.

        Cached vectors are reused. Missing ones are computed with
        ``self.spacy_model(token).vector`` and the cache file is rewritten. The
        result is aligned 1:1 with ``tokens``, duplicates included.

        NOTE(review): the cache is keyed by token text only, so it goes stale if
        the spaCy model changes. Delete ``embeddings.pickle`` in that case.
        """
        embeddings_filename = 'embeddings.pickle'
        cache_found = os.path.exists(embeddings_filename)
        if cache_found:
            print('Found cache..')
            # pickle.load can execute arbitrary code: only load a cache this notebook wrote.
            with open(embeddings_filename, 'rb') as embeddings_file:
                dictionary = pickle.load(embeddings_file)
        else:
            print('Cache not found..')
            dictionary = {}

        # dict.fromkeys deduplicates while keeping order, so each new token is embedded once.
        missing = [token for token in dict.fromkeys(tokens) if token not in dictionary]
        for token in missing:
            dictionary[token] = self.spacy_model(token).vector

        if missing or not cache_found:
            if cache_found:
                print('Rewriting cache..')
            # Write to a temporary file and swap it in atomically. A crash mid-write
            # can then no longer destroy the existing cache.
            tmp_filename = embeddings_filename + '.tmp'
            with open(tmp_filename, 'wb') as tmp_file:
                pickle.dump(dictionary, tmp_file)
            os.replace(tmp_filename, embeddings_filename)

        return [dictionary[token] for token in tokens]

    @staticmethod
    def get_grid_size(n):
        """Side length of a square SOM grid holding at least 5*sqrt(n) neurons for ``n`` samples."""
        neurons_num = 5 * np.sqrt(n)
        return int(np.ceil(np.sqrt(neurons_num)))

    @staticmethod
    def _project_onto_som(som, embeddings, labels):
        """Map each embedding to the plot position of its winning neuron.

        Returns:
            ``(xs, ys, point_labels)`` lists; ``labels`` is indexed in parallel with
            ``embeddings``.
        """
        xs, ys, point_labels = [], [], []
        for index, vector in enumerate(embeddings):
            wx, wy = som.convert_map_to_euclidean(xy=som.winner(vector))
            xs.append(wx)
            # Compress rows by sqrt(3)/2 so neighbouring hexagon rows interlock.
            ys.append(wy * np.sqrt(3) / 2)
            point_labels.append(labels[index])
        return xs, ys, point_labels

    def plot_bokeh(self, embeddings_f, embeddings_b, filtered_ftext_set, filtered_btext_set):
        """Train a hexagonal SOM on the background embeddings and plot both corpora on it.

        Each background word is drawn as a hexagon at its winning neuron; each
        foreground word is an orange dot on the same map. Hovering shows the word.

        Args:
            embeddings_f: Foreground vectors, parallel to ``filtered_ftext_set``.
            embeddings_b: Background vectors (SOM training data), parallel to
                ``filtered_btext_set``.
            filtered_ftext_set: Foreground words.
            filtered_btext_set: Background words.

        Returns:
            The bokeh figure (not shown; call ``show(fig)``).
        """
        HEXAGON_SIZE = 54
        DOT_SIZE = 20

        background = np.asarray(embeddings_b)
        GRID_SIZE = self.get_grid_size(len(background))
        PLOT_SIZE = HEXAGON_SIZE * (GRID_SIZE + 1)

        som = MiniSom(GRID_SIZE, GRID_SIZE, background.shape[1], sigma=5, learning_rate=.2,
                      activation_distance='euclidean', topology='hexagonal', neighborhood_function='bubble',
                      random_seed=10)
        som.train(background, 1000, verbose=True)

        b_weight_x, b_weight_y, b_label = self._project_onto_som(som, background, filtered_btext_set)
        f_weight_x, f_weight_y, f_label = self._project_onto_som(som, embeddings_f, filtered_ftext_set)

        # width/height work across Bokeh versions; plot_width/plot_height were removed in Bokeh 3.
        fig = figure(height=PLOT_SIZE, width=PLOT_SIZE,
                     match_aspect=True,
                     tools="pan")

        fig.axis.visible = False
        fig.xgrid.grid_line_color = None
        fig.ygrid.grid_line_color = None

        # 'species' is the column the hover tooltip reads (@species); it holds the words.
        b_source_pages = ColumnDataSource(data=dict(wx=b_weight_x, wy=b_weight_y, species=b_label))
        f_source_pages = ColumnDataSource(data=dict(wx=f_weight_x, wy=f_weight_y, species=f_label))

        # NOTE(review): x/y are deliberately transposed relative to the SOM coordinates,
        # presumably to match the orientation of bokeh's hex marker. Verify before changing.
        fig.hex(x='wy', y='wx', source=b_source_pages,
                fill_alpha=1.0, line_alpha=1.0,
                size=HEXAGON_SIZE)

        fig.scatter(x='wy', y='wx', source=f_source_pages,
                    fill_color='orange',
                    size=DOT_SIZE)

        fig.add_tools(HoverTool(
            tooltips=[("label", '@species')],
            mode="mouse",
            point_policy="follow_mouse"
        ))

        return fig

In [14]:
from google.colab import files

FOREGROUND_FILE = '2108.06252v1.txt'
BACKGROUND_FILE = '2111.06414v1.txt'

MODEL = 'en_core_web_sm-3.0.0/en_core_web_sm/en_core_web_sm-3.0.0'

obj = Thesaurus()

# `with` closes the input files once read. open() raises on failure rather than
# returning None, so no None check is needed.
with open(FOREGROUND_FILE, 'r', encoding='utf-8') as foreground_file:
    foreground = obj.read_text(foreground_file)
with open(BACKGROUND_FILE, 'r', encoding='utf-8') as background_file:
    background = obj.read_text(background_file)

obj.set_spacy_model(MODEL)

lemmatized_f = obj.lemmatize(foreground, len(foreground))
lemmatized_b = obj.lemmatize(background, len(background))

tokenized_f = obj.tokenize(lemmatized_f)
tokenized_b = obj.tokenize(lemmatized_b)

filtered_tokens_f, filtered_tokens_f_set = obj.remove_stopwords(tokenized_f)
filtered_tokens_b, filtered_tokens_b_set = obj.remove_stopwords(tokenized_b)

# Embed the deduplicated words: one vector (and one plotted point) per distinct word.
embeddings_f = obj.make_embeddings(filtered_tokens_f_set)
embeddings_b = obj.make_embeddings(filtered_tokens_b_set)

fig = obj.plot_bokeh(embeddings_f, embeddings_b, filtered_tokens_f_set, filtered_tokens_b_set)

show(fig)

Found cache..
Found cache..
 [ 1000 / 1000 ] 100% - 0:00:00 left 
 quantization error: 4.022099235330194
