<a href="https://colab.research.google.com/github/MNoichl/tttms_public/blob/main/formulas_training_and_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install necessary packages

In [None]:

!pip install --upgrade gensim
!pip install --upgrade umap-learn
!pip install compress-pickle
!pip install cylouvain
!sudo apt-get install latexml
!sudo apt-get install libtext-unidecode-perl 
!apt-get update
!sudo apt-get install latexml
!latexmlmath \\frac{-b\\pm\\sqrt{b^2-4ac}}{2a}
!pip install hdbscan
!pip install loess
!pip install cmocean

# Load packages

In [None]:
import tqdm
import numpy as np
import pandas as pd
import re
import os
import sys

import matplotlib.pyplot as plt
import seaborn as sns

import umap
from scipy import stats

import subprocess
import platform
import shutil

from compress_pickle import dump, load
import zipfile
import cylouvain
import networkx as nx

from IPython.display import display, Math, Latex


from collections import Counter



# Load Tangent-functions

In [None]:
!gdown --fuzzy https://drive.google.com/file/d/1NSx7sQw8Kk1zQl5aaLWftj8T9SHvaeCC/view?usp=sharing -O "TangentS.zip"
!unzip TangentS.zip

In [None]:

from xml.dom import minidom
from io import StringIO
from xml.dom.minidom import parseString


from TangentS.math_tan.math_extractor import MathExtractor
from TangentS.math_tan.symbol_tree import SymbolTree

def convert_latex_formula_to_tuple_list(tex_query):
    # print("Convert LaTeX to MathML:$"+tex_query+"$",flush=True)
    qvar_template_file = os.path.join(os.path.abspath(''), "mws.sty.ltxml")
    if not os.path.exists(qvar_template_file):
        print('Tried %s' % qvar_template_file, end=": ")
        sys.exit("Stylesheet for wildcard is missing")

    # Make sure there are no isolated % signs in tex_query (introduced by latexmlmath, for example, in 13C.mml test file) (FWT)
    tex_query = re.sub(r'([^\\])%', r'\1', tex_query)  # remove % not preceded by backslashes (FWT)
    
    with open('temporary_tex.tex', "w", encoding='utf-8') as text_file:
        text_file.write('\\begin{equation} ' +tex_query + ' \\end{equation}')

    cmd = 'timeout 5s latexml' + ' --preload=amsmath' + ' --preload=amsfonts' + ' --destination=temporary_xml.xml' +' temporary_tex.tex '+ ' --preload=mws.sty.ltxml'
    !{cmd}
    cmd = 'timeout 5s latexmlpost' + ' --contentmathml' + ' --destination=temporary_xml_2.xml' + ' temporary_xml.xml'
    !{cmd}


    xmldoc = minidom.parse('temporary_xml_2.xml')
    equation_type_objects = []
    paraNode = xmldoc.getElementsByTagName('para')[0]
    for child in paraNode.childNodes:
        if 'equation' in str(child):
            equation_type_objects.append(child)

    for i,this_equation_object in enumerate(equation_type_objects): 
        xml_formula_list = this_equation_object.getElementsByTagName('Math')
        for j,this_xml_formula in enumerate(xml_formula_list):
            this_xml_formula = this_xml_formula.firstChild


            this_xml_formula.setAttribute("xmlns", "http://www.w3.org/1998/Math/MathML")
            out = StringIO()
            this_xml_formula.writexml(out)
            mathml = out.getvalue()
            mathml = mathml.replace('m:','')
            mathml = mathml.replace('<semantics>','').replace('<\semantics>','')
            mathml = '<?xml version="1.0" encoding="UTF-8"?>\n\n' + mathml
            cmml = MathExtractor.isolate_cmml(mathml)
            cmml = re.sub('<share href=\".*\"><\/share>','',cmml)
            current_tree = MathExtractor.convert_to_semanticsymbol(cmml)
            temp = SymbolTree(current_tree)
            tuple_list = temp.get_pairs(window=2, eob=True)

    return tuple_list



In [None]:
# Mount Google Drive at /content/drive (Colab only); later cells read and
# write files below drive/MyDrive/.
from google.colab import drive
drive.mount('/content/drive')

# Aggregate data 
(unnecessary for re-runs: Just load data below)

In [None]:
from os.path import exists

# (start, stop) pairs; each pair names one parsing-result file on Drive and the
# row slice that is copied from it.
# NOTE(review): (190000, 196000) is followed by (194000, 200000), so the two
# ranges overlap -- confirm against the actual file names on Drive.
files_to_load = [
    (0, 48000),
    (48000, 50000),
    (50000, 58000),
    (58000, 74000),
    (74000, 78000),
    (78000, 84000),
    (84000, 98000),
    (98000, 100000),
    (100000, 130000),
    (130000, 134000),
    (134000, 142000),
    (142000, 150000),
    (150000, 152000),
    (152000, 158000),
    (158000, 160000),
    (160000, 162000),
    (162000, 172000),
    (172000, 176000),
    (176000, 190000),
    (190000, 196000),
    (194000, 200000),
    (200000, 202000),
    (202000, 208000),
    (208000, 220000),
    (220000, 224000),
    (224000, 226000),
    (226000, 228000),
    (228000, 230000),
    (230000, 236000),
    (236000, 240000),
    (240000, 244000),
    (244000, 246000),
    (246000, 250000),
    (250000, 258000),
    (258000, 266000),
    (266000, 270000),
    (270000, 278000),
    (278000, 280000),
    (280000, 290000),
    (290000, 300000),
    (300000, 316000),
    (316000, 326000),
    (326000, 330000),
    (330000, 336000),
    (336000, 350000),
    (350000, 370000),
    (370000, 383960),
]


def _chunk_path(start, stop):
    """Return the Drive path of the parsing-result file for rows start..stop."""
    return f"drive/MyDrive/combined_formula_parsing_results_v3/formula_tuple_list_{start}_{stop}.bz"


# Report, file by file, whether the expected chunk is present on Drive.
for start, stop in tqdm.tqdm_notebook(files_to_load):
    file_exists = exists(_chunk_path(start, stop))
    print(file_exists)

# The first chunk serves as the base table that the other chunks are copied into.
first_start, first_stop = files_to_load[0]
base_df = load(_chunk_path(first_start, first_stop))

for start, stop in tqdm.tqdm_notebook(files_to_load[1:]):
    this_df = load(_chunk_path(start, stop))
    # Copy only the row slice this chunk is responsible for.
    base_df[start:stop] = this_df[start:stop]
    this_df = pd.DataFrame(this_df)
    # this_df.columns = ['formula_tuples']
    # Number of parsed formulas in this chunk, skipping the 'no formulas' marker.
    formula_counts = [len(entry) for entry in this_df['formula_tuples'] if entry != 'no formulas']
    print(np.sum(formula_counts))







In [None]:
# Wrap the aggregated parsing results in a DataFrame and display it.
formula_df = pd.DataFrame(base_df)

formula_df


In [None]:
# Load the article table from Drive (per its file name, already joined with
# cluster assignments -- TODO confirm contents).
full_data = load("drive/MyDrive/ARXIV_FORMULA_PARSING/full_data_rejoined_with_clusters.bz")

In [None]:
# Attach the formula columns to the article table.
# NOTE(review): this assumes formula_df has exactly these three columns in this
# order and shares full_data's index -- confirm before re-running.
full_data[['filtered_formulas','formula_tuples','actually_transformed_formulas']] = formula_df

In [None]:
# Display the joined table for a visual check.
full_data

In [15]:
# Persist the joined table to Drive so replication runs can start from it.
dump(full_data, "drive/MyDrive/ARXIV_FORMULA_PARSING/full_data_joined_with_formulas.bz")

# Load dataset (start here for replication runs)

In [None]:
# Download the prepared dataset and load it.
# NOTE(review): compress_pickle's load() unpickles the file; only use it on
# downloads from a trusted source.
!gdown --fuzzy https://drive.google.com/file/d/1-5ZrFDj_6tff7tCRd8nqlVf6Ls-l8-x1/view?usp=sharing -O "full_data_joined_with_formulas.bz"
full_data = load("full_data_joined_with_formulas.bz")

In [None]:
# Flatten the per-article formula lists into one record per formula.
formula_collection = []
for _, article in tqdm.tqdm_notebook(full_data.iterrows()):
    transformed = article['actually_transformed_formulas']
    if transformed == 'no formulas':
        continue  # article without formulas

    parsed = article['formula_tuples']
    # Formula text and tuples can only be paired reliably when both lists
    # have the same length.
    usable_label = len(transformed) == len(parsed)

    for formula, tuples in zip(transformed, parsed):
        if len(tuples) > 1:  # skip formulas with fewer than two tuples
            formula_collection.append({
                'formula': formula,
                'tuples': tuples,
                'cluster': article['cluster'],
                'color': article['color'],
                'id': article['id'],
                'origin': article['origin'],
                'usable_label': usable_label,
            })

In [None]:
# Count available formulas per source archive:
for origin in ('biorxiv', 'arxiv'):
    n_from_origin = sum(1 for entry in formula_collection if entry['origin'] == origin)
    print(origin + ': ', n_from_origin)

In [None]:
# Keep only formulas whose text could be paired reliably with its tuples.
formula_collection = [entry for entry in formula_collection if entry['usable_label']]

In [None]:
# One tuple list per formula, in the same order as formula_collection;
# this is the corpus used for training and embedding.
flat_formula_tuples = [x['tuples'] for x in formula_collection] 
len(flat_formula_tuples)

# Training the model
(unnecessary for re-runs: we can load the pretrained model below)

In [None]:
from gensim.test.utils import get_tmpfile
from gensim.models.callbacks import CallbackAny2Vec

class EpochLogger(CallbackAny2Vec):
    """Training callback that reports epochs and checkpoints the model.

    After every epoch the model is saved to Drive under a file name that
    carries the zero-based epoch number.
    """

    def __init__(self):
        # Zero-based index of the epoch currently running.
        self.epoch = 0

    def on_epoch_begin(self, model):
        """Announce the start of the current epoch."""
        print(f"Epoch #{self.epoch} start")

    def on_epoch_end(self, model):
        """Save a checkpoint, announce the end of the epoch, advance the counter."""
        checkpoint_path = (
            r'drive/MyDrive/formulaFT_models/combined_equation_fasttext_model_v02_'
            + str(self.epoch)
            + '.model'
        )
        model.save(checkpoint_path)
        print(f"Epoch #{self.epoch} end")
        self.epoch += 1
        
# Callback instance that logs progress and checkpoints after every epoch.
epoch_logger = EpochLogger()  

from gensim.models import FastText
# Training corpus: one list of tuple tokens per formula.
my_tokens = flat_formula_tuples
print('tokens retrieved')
# 300-dimensional FastText model. Per gensim's parameter names: sg=1 selects
# skip-gram, hs=1 and negative=15 enable hierarchical softmax and negative
# sampling, min_n/max_n bound the character n-gram lengths, and min_count
# drops tokens seen fewer than 7 times.
model = FastText(vector_size=300, window=13, sg=1,
                 hs=1,workers=8, negative=15,
                 min_n=5, max_n=40, #word_ngrams=3,
                 min_count=7)


# Build the vocabulary from the corpus before training.
model.build_vocab(my_tokens) 
print('voc built, training')

# Train for 9 epochs; the callback saves one checkpoint per epoch.
model.train(my_tokens,total_examples=model.corpus_count,epochs =9, callbacks=[epoch_logger])
print('training complete')

In [16]:
# Saving the model after training:
# model.save(r'drive/MyDrive/formulaFT_models/combined_equation_fasttext_model_v01.model')


# Load the pretrained model

In [23]:
# Download the pretrained model (the epoch-6 checkpoint, per its file name) and
# its companion .npy arrays. All files share the prefix of the .model file so
# that FastText.load below can find them -- presumably gensim's split storage
# for large arrays; verify against the gensim docs.
!gdown --fuzzy https://drive.google.com/file/d/1-yzwqBBsCh9FQIs6iBmUYh0M0xWhAFLS/view?usp=sharing -O "combined_equation_fasttext_model_v02_6.model"
!gdown --fuzzy https://drive.google.com/file/d/1-vkI7clJx62w2GIuiVj6Ia9c7mFYnhYh/view?usp=sharing -O "combined_equation_fasttext_model_v02_6.model.wv.vectors_vocab.npy"
!gdown --fuzzy https://drive.google.com/file/d/1-x4htbDWiN3gDb_WpSMFfrJ99QvycHiD/view?usp=sharing -O 'combined_equation_fasttext_model_v02_6.model.wv.vectors_ngrams.npy'
!gdown --fuzzy https://drive.google.com/file/d/1-xLAmr4WvvH10MpRQwd_5JH0TvIv-Xls/view?usp=sharing -O 'combined_equation_fasttext_model_v02_6.model.syn1.npy'
!gdown --fuzzy https://drive.google.com/file/d/1-yf6oFOqpKfTpsyAsQeXMYl3Ri464vK4/view?usp=sharing -O 'combined_equation_fasttext_model_v02_6.model.syn1neg.npy'

Downloading...
From: https://drive.google.com/uc?id=1-yzwqBBsCh9FQIs6iBmUYh0M0xWhAFLS
To: /content/combined_equation_fasttext_model_v02_6.model
100% 149M/149M [00:00<00:00, 241MB/s]
Downloading...
From: https://drive.google.com/uc?id=1-vkI7clJx62w2GIuiVj6Ia9c7mFYnhYh
To: /content/combined_equation_fasttext_model_v02_6.model.wv.vectors_vocab.npy
100% 1.00G/1.00G [00:03<00:00, 256MB/s]
Downloading...
From: https://drive.google.com/uc?id=1-x4htbDWiN3gDb_WpSMFfrJ99QvycHiD
To: /content/combined_equation_fasttext_model_v02_6.model.wv.vectors_ngrams.npy
100% 2.40G/2.40G [00:10<00:00, 221MB/s]
Downloading...
From: https://drive.google.com/uc?id=1-xLAmr4WvvH10MpRQwd_5JH0TvIv-Xls
To: /content/combined_equation_fasttext_model_v02_6.model.syn1.npy
100% 1.00G/1.00G [00:05<00:00, 169MB/s]
Downloading...
From: https://drive.google.com/uc?id=1-yf6oFOqpKfTpsyAsQeXMYl3Ri464vK4
To: /content/combined_equation_fasttext_model_v02_6.model.syn1neg.npy
100% 1.00G/1.00G [00:11<00:00, 88.6MB/s]


In [24]:
from gensim.test.utils import get_tmpfile
from gensim.models.callbacks import CallbackAny2Vec
from gensim.models import FastText

class EpochLogger(CallbackAny2Vec):
    """Logging-only training callback (no checkpointing).

    NOTE(review): presumably defined here so that a model trained with an
    ``EpochLogger`` callback can be loaded in a fresh kernel -- confirm.
    """

    def __init__(self):
        # Zero-based index of the epoch currently running.
        self.epoch = 0

    def on_epoch_begin(self, model):
        """Announce the start of the current epoch."""
        print(f"Epoch #{self.epoch} start")

    def on_epoch_end(self, model):
        """Announce the end of the current epoch and advance the counter."""
        print(f"Epoch #{self.epoch} end")
        self.epoch += 1

epoch_logger = EpochLogger()  

# Load the downloaded checkpoint; the companion .npy files must sit next to it.
model = FastText.load(r'combined_equation_fasttext_model_v02_6.model')

# Draw a random subsample

In [None]:
# Size of the random subsample and the seed that makes the draw reproducible.
# Previously the draw was unseeded, so every re-run embedded, laid out and
# clustered a different set of formulas.
SUBSAMPLE_SIZE = 500000
SUBSAMPLE_SEED = 42

# Draw indices uniformly WITH replacement (as before), so the same formula can
# appear more than once in the sample.
_subsample_rng = np.random.default_rng(SUBSAMPLE_SEED)
w_rand = _subsample_rng.integers(0, len(flat_formula_tuples), SUBSAMPLE_SIZE)
flat_formula_tuples_small = [flat_formula_tuples[x] for x in w_rand]
w_in_collection = [formula_collection[x] for x in w_rand] # select data for random sample

# Convert tuples to vectors

In [None]:

from scipy.linalg import norm
from scipy.stats import gmean
from joblib import Parallel, delayed
import multiprocessing
import functools
from multiprocessing.dummy import Pool  # This is a thread-based Pool
from multiprocessing import cpu_count
import time
         
            
def get_mean_vectors(model,x):
    """Embed one formula as the mean of its unit-length token vectors.

    Args:
        model: object whose ``model.wv[token]`` returns the token's vector
            (here a gensim FastText model).
        x: iterable of tokens, i.e. the tuples of one formula.

    Returns:
        1-D array holding the mean of ``vector / ||vector||`` over all tokens
        that yielded a usable vector.

    Raises:
        ValueError: if no token yielded a usable vector.
    """
    vectors = []
    for y in x:
        try:
            vector = model.wv[y]
            length = np.linalg.norm(vector)
            # A zero (or non-finite) norm would turn the division below into
            # NaN and poison the whole mean, so such tokens are skipped.
            if not length > 0:
                continue
            vectors.append(vector / length)
        except Exception as e:
            # Best effort: report the token-level failure and carry on.
            print(e)
    if not vectors:
        raise ValueError("No usable token vectors; cannot embed this formula.")
    return np.mean(np.array(vectors),axis=0) #powermean(np.array(vectors), 2, axis = 0)

# Embed every sampled formula on a pool of 8 worker threads. The context
# manager shuts the pool down once all results are collected (it was never
# closed before), and total= gives the progress bar a known length.
with Pool(8) as pool:
    formula_vectors = list(tqdm.tqdm_notebook(
        pool.imap(functools.partial(get_mean_vectors, model), flat_formula_tuples_small),
        total=len(flat_formula_tuples_small)))

In [None]:
# Stack the per-formula vectors into one matrix with a row per sampled formula.
embeddings = np.vstack(formula_vectors)

# Compute layout

In [None]:
from sklearn.decomposition import TruncatedSVD
# Reduce the formula embeddings to 150 components with truncated SVD before
# running UMAP; the fixed random_state keeps this step repeatable.
SVD = TruncatedSVD(n_components= 150, n_iter=7, random_state=42)
XSVD = SVD.fit_transform(embeddings)
print(XSVD.shape)

In [None]:
import umap

# 2-D UMAP layout of the SVD-reduced embeddings under the cosine metric.
# The layout is initialised with the first two SVD components rather than
# UMAP's default initialisation.
umapped_equations_full = umap.UMAP(densmap=False,random_state=42,
                    n_components=2,
                    n_neighbors=10,
                    min_dist=0.1,
                    init=XSVD[:,0:2],
                    metric='cosine',#n_epochs=200,#disconnection_distance =0.3,
                    verbose=True,
                    low_memory=True)

umapped_equations_full.fit(XSVD)#[:,1:])#[:,1:])#
# Coordinates of every sampled formula in the 2-D layout.
umapped_equations = umapped_equations_full.embedding_

In [None]:
# Scatter plot of the 2-D layout; every formula takes the colour stored with
# its record in w_in_collection.
sns.set(font="STIXGeneral", font_scale=2.1)
sns.set_style("white")
fig, ax = plt.subplots(figsize=(30, 20))

hfont = {'fontname': 'STIXGeneral'}

point_colors = [entry['color'] for entry in w_in_collection]
layout_x = umapped_equations[:, 0]
layout_y = umapped_equations[:, 1]
plt.scatter(layout_x, layout_y, s=0.6, c=point_colors, alpha=0.9)


# Clustering


In [None]:
import hdbscan
# Cluster the 2-D layout (not the high-dimensional embeddings) with HDBSCAN.
clusterer = hdbscan.HDBSCAN(min_cluster_size=40,
                            min_samples=50, prediction_data=False).fit(umapped_equations)
# Number of distinct labels; later cells treat labels below 0 as unclustered.
print(len(set(clusterer.labels_)))
clustering_solution = clusterer.labels_
# soft_clusters = hdbscan.all_points_membership_vectors(clusterer)
# print(soft_clusters.shape)

In [None]:
# Layout coloured by formula cluster; unclustered points are drawn in grey.
sns.set(font="STIXGeneral", font_scale=2.1)
sns.set_style("white")
fig, ax = plt.subplots(figsize=(30, 20))

hfont = {'fontname': 'STIXGeneral'}

# True for points that were assigned to a cluster (label >= 0).
clustered = (clustering_solution >= 0)
unclustered_xy = umapped_equations[~clustered]
clustered_xy = umapped_equations[clustered]

plt.scatter(unclustered_xy[:, 0], unclustered_xy[:, 1],
            color=(0.5, 0.5, 0.5), s=0.1, alpha=0.5)
plt.scatter(clustered_xy[:, 0], clustered_xy[:, 1],
            c=clustering_solution[clustered], s=0.1, cmap='Spectral');

plt.axis('equal')

# Build nearest neighbours graph of formulas

In [None]:
# %%timeit
import pynndescent
# Approximate nearest-neighbour index over the full (non-reduced) formula
# embeddings, using cosine distance and 90 neighbours per point.
index = pynndescent.NNDescent(embeddings,#embeddings
                              metric="cosine",n_neighbors=90)
# Prepare the index up front so the queries below do not pay that cost.
index.prepare()

# Calculate textual distances to formulas' nearest neighbours

In [None]:
# Download and load the thematic (text-based) SVD vectors.
# NOTE(review): later cells index this array with row positions of full_data,
# so its rows are assumed to be in full_data's order -- confirm.
!gdown --fuzzy https://drive.google.com/file/d/12FKOy0jN9n1AJnSV-JeJlu1yw-1C39Q2/view?usp=sharing -O "thematic_SVD_vectors.bz"

thematic_SVD = load("thematic_SVD_vectors.bz")

In [None]:
# For every sampled formula, query its 40 nearest formulas. The result is
# indexed as [0] for neighbour indices and [1] for distances further below.
accurate_neighbors = index.query(embeddings,
                                 epsilon=0.2,k =40) 

In [None]:
from sklearn.metrics.pairwise import pairwise_distances
import warnings
import numpy as np
warnings.simplefilter(action='ignore', category=FutureWarning)

n_neighbors_to_consider = 5
average_thematic_dists = []
null_model = []

# Map every article id to its row position(s) in full_data ONCE. The previous
# version ran two full `isin` scans of full_data per formula, which made the
# loop quadratic in practice.
rows_by_id = {}
for row_position, article_id in enumerate(full_data['id']):
    rows_by_id.setdefault(article_id, []).append(row_position)

for ix, these_neighbours in tqdm.tqdm_notebook(enumerate(accurate_neighbors[0]),
                                               total=len(accurate_neighbors[0])):

  source_id = w_in_collection[ix]['id']
  neighbour_ids = [w_in_collection[x]['id'] for x in these_neighbours]
  # Distinct article ids in order of increasing formula distance, with the
  # source article removed explicitly. The old `[1:]` slice assumed the first
  # distinct id is always the source, which fails when an identical formula
  # from another article ties at distance 0.
  id_s = [article_id for article_id in dict.fromkeys(neighbour_ids)
          if article_id != source_id][:n_neighbors_to_consider]

  # Row positions of the source article and of the neighbouring articles;
  # thematic_SVD is assumed to be aligned with full_data's row order.
  w_id_source = np.array(rows_by_id.get(source_id, []), dtype=int)
  w_ids_targets = np.array(sorted(position
                                  for article_id in id_s
                                  for position in rows_by_id.get(article_id, [])),
                           dtype=int)

  # Mean thematic distance between the source article and the articles that
  # contain its nearest formulas.
  dists = pairwise_distances(thematic_SVD[w_id_source,:], thematic_SVD[w_ids_targets,:], metric='cosine')
  average_thematic_dists.append(np.mean(dists))

  # Null model: mean thematic distance to the same number of random articles.
  dists = pairwise_distances(thematic_SVD[w_id_source,:], thematic_SVD[np.random.randint(0,len(thematic_SVD),n_neighbors_to_consider),:], metric='cosine')
  null_model.append(np.mean(dists))

# Plot Similarities


In [None]:
# Compare the distribution of thematic distances to formula neighbours (red)
# with the random-article null model (grey).
sns.set(font_scale=.5)
sns.set_style("ticks")

fig, ax = plt.subplots(figsize=(10,5))
# NOTE(review): sns.distplot is deprecated in recent seaborn releases; a
# histplot/kdeplot equivalent will be needed eventually.
sns.distplot(average_thematic_dists,color='#cc4157',hist=True, kde_kws={"lw": 3},)
sns.distplot(null_model,color='#b3c3b9',hist=True, kde_kws={'color':'#686868',"lw": 3},)
# Tick marks every 0.02 between 0 and 1.
plt.xticks([np.round(x,decimals=2) for x in np.linspace(0,1,51)],)

plt.savefig('dist_comp_random_matter_10_with_hist.png', dpi=300)


from scipy import stats
# Independent two-sample t-test between observed and null-model distances.
stats.ttest_ind(average_thematic_dists, null_model)

In [None]:
# Give every formula the mean thematic distance of its formula cluster.
# Only formulas that already have a thematic distance are considered.
n_treated = len(average_thematic_dists)
treated_part_cl_sol = clustering_solution[0:n_treated]
mean_average_dist_by_cluster = np.zeros(len(treated_part_cl_sol))
treated_dists = np.asarray(average_thematic_dists)
for small_clust in tqdm.tqdm_notebook(np.unique(treated_part_cl_sol)):
    where_small_clust = np.flatnonzero(np.array(treated_part_cl_sol) == small_clust)
    mean_average_dist_by_cluster[where_small_clust] = np.mean(treated_dists[where_small_clust])


In [None]:

import cmocean

# Layout coloured by each cluster's mean thematic distance; unclustered
# formulas are drawn in grey.
sns.set(font="STIXGeneral",font_scale=2.1)
sns.set_style("white")
fig, ax = plt.subplots(figsize=(30,20))


clustered = (treated_part_cl_sol >= 0)

hfont = {'fontname':'STIXGeneral'}


plt.scatter(umapped_equations[0:len(average_thematic_dists), 0][~clustered],
            umapped_equations[0:len(average_thematic_dists), 1][~clustered],
            color=(0.5, 0.5, 0.5),
            s=0.1,
            alpha=0.3)


# The colour range is clipped to the 1st..99th percentile so that a few
# extreme clusters do not wash out the colour scale.
plt.scatter(umapped_equations[0:len(average_thematic_dists), 0][clustered],
            umapped_equations[0:len(average_thematic_dists), 1][clustered],
            c= np.array(mean_average_dist_by_cluster)[clustered],#average_thematic_dists ,#[average_thematic_dists[x] for x in np.argsort(average_thematic_dists)],#[np.log(x) for x in average_thematic_dists],
            s=1,alpha=.3,
            cmap=cmocean.cm.matter_r,
            vmin=np.percentile(mean_average_dist_by_cluster, 1),
            vmax=np.percentile(mean_average_dist_by_cluster, 99)
            )
cbar = plt.colorbar()
# Draw the colorbar fully opaque although the points use alpha=.3.
# Colorbar.draw_all() is deprecated and gone from newer Matplotlib releases,
# so the alpha is set on the colorbar's solids directly.
cbar.solids.set_alpha(1)
plt.axis('equal')

plt.savefig('clustering_with_distances_matter.png', dpi=300)


# Get formulas closest to clusters' centroids, for labeling

In [None]:
from sklearn.metrics.pairwise import pairwise_distances

# One centre per formula cluster: the component-wise median of the members'
# embeddings, ordered like np.unique(clustering_solution).
cluster_centers = np.array([
    np.median(embeddings[np.flatnonzero(np.array(clustering_solution) == small_clust), :], axis=0)
    for small_clust in tqdm.tqdm_notebook(np.unique(clustering_solution))
])

In [None]:
# plot cluster-positions
# Same scatter as the distance plot, with a numeric label drawn at the median
# layout position of every cluster. Relies on `clustered` and
# `mean_average_dist_by_cluster` from the preceding cells.
sns.set(font="STIXGeneral",font_scale=.4
        )
sns.set_style("white")
fig, ax = plt.subplots(figsize=(30,20))

hfont = {'fontname':'STIXGeneral'}

plt.scatter(umapped_equations[0:len(average_thematic_dists), 0][~clustered],
            umapped_equations[0:len(average_thematic_dists), 1][~clustered],
            color=(0.5, 0.5, 0.5),
            s=0.1,
            alpha=0.5)


plt.scatter(umapped_equations[0:len(average_thematic_dists), 0][clustered],
            umapped_equations[0:len(average_thematic_dists), 1][clustered],
            c= np.array(mean_average_dist_by_cluster)[clustered],#average_thematic_dists ,#[average_thematic_dists[x] for x in np.argsort(average_thematic_dists)],#[np.log(x) for x in average_thematic_dists],
            s=1,alpha=1,
            cmap=cmocean.cm.haline,
            vmin=np.percentile(mean_average_dist_by_cluster, 1),
            vmax=np.percentile(mean_average_dist_by_cluster, 99)
            )
plt.colorbar()

# The label is the cluster's position in np.unique(clustering_solution), not
# the HDBSCAN label itself; the report cell further below numbers clusters by
# position in the same way.
for ix,small_clust in tqdm.tqdm_notebook(enumerate(np.unique(clustering_solution))):
    where_small_clust = np.where(np.array(clustering_solution)==small_clust)[0]
    center = np.median(umapped_equations[where_small_clust,:],axis=0)
    # print(center[0])

    plt.text(x=center[0],
          y=center[1],
          s=str(ix), horizontalalignment='center',
     verticalalignment='center')
  
plt.axis('equal')
plt.savefig('clustering_with_labels.png', dpi=300)



# Prepare relative cluster-membership-graphics

In [None]:
# c: number of sampled formulas per text cluster (used as the denominator for
# relative shares). c_zero: the same keys with all counts at zero, copied as
# the starting point for every formula cluster.
c = Counter([x['cluster'] for x in w_in_collection])
c_zero = Counter({x:0 for x in c})

In [None]:
# Distinct text clusters and distinct colours, each in order of first
# appearance in the sample.
text_clusters = list(dict.fromkeys(entry['cluster'] for entry in w_in_collection))
text_colors = list(dict.fromkeys(entry['color'] for entry in w_in_collection))

In [None]:
# Map every text cluster to its colour straight from the records. Zipping two
# independently de-duplicated lists only pairs them correctly if clusters and
# colours correspond strictly one-to-one; reading both from the same record
# cannot misalign them.
palette = {entry['cluster']: entry['color'] for entry in w_in_collection}

In [None]:
# For every formula cluster, build a small frame holding the share of each
# text cluster's sampled formulas that falls into it.

# Display order of the text clusters, chosen to better fit the visual order of
# the textual UMAP.
text_cluster_order = [-1, 10, 1, 3, 4, 2, 0, 5, 6, 7, 11, 16, 17, 20, 8, 9,
                      19, 12, 13, 15,
                      21, 14, 22, 18, 23]

cluster_composition_count_frames = []
for ix, small_clust in tqdm.tqdm_notebook(enumerate(np.unique(clustering_solution))):
    where_small_clust = np.flatnonzero(np.array(clustering_solution) == small_clust)

    # Count this cluster's members per text cluster, starting from zero for
    # every text cluster so that absent ones still get a row.
    this_counter = c_zero.copy()
    this_counter.update(w_in_collection[member]['cluster'] for member in where_small_clust)

    # Normalise by the text cluster's total number of sampled formulas.
    relative_shares = {key: count / c[key] for key, count in this_counter.items()}

    count_frame = pd.DataFrame(relative_shares, index=[0]).T.reset_index()
    count_frame.columns = ['cluster', 'counts']
    count_frame['x'] = 1  # constant position: all bars are stacked in one row
    count_frame['cluster'] = pd.Categorical(count_frame['cluster'], text_cluster_order)

    cluster_composition_count_frames.append(count_frame)



# Print out full report on the formula-clustering

In [None]:
# The 100 sampled formulas closest to every cluster centre; row i belongs to
# the i-th entry of np.unique(clustering_solution).
neighbors = index.query(cluster_centers, epsilon=0.2, k=100)

In [None]:
# For every formula cluster: show the ten formulas closest to its centre,
# followed by a stacked bar of the text-cluster composition.
for cluster, cluster_center_nn in enumerate(neighbors[0]):
    print('Cluster: ', cluster)
    for ix in cluster_center_nn[0:10]:
        formula = w_in_collection[ix]['formula']
        # Strip the environment wrappers so the formula renders as inline math.
        for wrapper in (r'\begin{equation}', r'\end{equation}',
                        r'\begin{eqnarray}', r'\end{eqnarray}'):
            formula = formula.replace(wrapper, '')
        formula = '$' + formula + '$'
        # Labels carry no visual content; remove them.
        formula = re.sub(r"\\label{.*?}", '', formula)

        display(Math(formula))
        print(formula)

    fig, ax = plt.subplots(figsize=(6,1))
    composition = cluster_composition_count_frames[cluster]
    ax = sns.histplot(composition, y='x', hue='cluster', weights='counts',
                      multiple='stack', palette=palette, shrink=4)
    plt.legend([],[], frameon=False)
    ax.axis('off')
    plt.show()
    print('\n\n')


# Test model by querying for the normal-distribution-formula


In [None]:

# Density of the normal distribution. The exponent's opening brace (`e^{`) was
# never closed in the earlier version; the final `}` makes the LaTeX
# well-formed.
query = r"f(x)={\frac {1}{\sigma {\sqrt {2\pi }}}}e^{-{\frac {1}{2}}\left({\frac {x-\mu }{\sigma }}\right)^{2}}"
display(Math(query))



print('Turning query into tuples.')
query_tuples = convert_latex_formula_to_tuple_list(query)

print('Encoding tupels.')
encoded_query = get_mean_vectors(model,query_tuples) 

# Search the index for the 200 sampled formulas closest to the query.
print('Searching closest matches...')
neighbors = index.query(encoded_query.reshape(1,-1), epsilon=0.9,k=200)

In [None]:
# Show the 30 closest matches: distance, rendered formula, number of tuples,
# the tuples themselves and the raw LaTeX.
closest_indices = neighbors[0][0][0:30]
closest_distances = neighbors[1][0]
for this_neighbor, distance in zip(closest_indices, closest_distances):
    match = w_in_collection[this_neighbor]
    print(np.round(distance, decimals=3))

    # Strip the equation wrapper and any label before rendering.
    formula = match['formula'].replace(r'\begin{equation}', '').replace(r'\end{equation}', '')
    formula = re.sub(r"\\label{.*?}", '', formula)
    display(Math('$' + formula + '$'))

    print(len(match['tuples']))
    print(match['tuples']) # look at embeddings!
    print(match['formula'])

    print('\n')
