In [1]:
# Imports

import os
import glob
import sys
import datetime

import re
from itertools import combinations, product
from functools import lru_cache
from statistics import mean

from natsort import natsorted

import pandas as pd
import numpy as np
import scipy.stats as ss
import Levenshtein

from gensim.models import Word2Vec

from cltk.stem.latin.j_v import JVReplacer
from cltk.lemmatize.latin.backoff import BackoffLatinLemmatizer

from pprint import pprint
import urllib
from tqdm.notebook import tqdm
import pickle

In [2]:
# Constants

# Folder holding the .tess text files that are searched
TEXTS_FOLDER = "../data/texts/process"

# Every .tess file found in TEXTS_FOLDER, passed through natsorted
files = natsorted(glob.glob(f'{TEXTS_FOLDER}/*.tess'))
# Path of the saved Word2Vec model; only its word vectors (.wv) are kept in VECTORS
MODEL = '../models/latin_w2v_bamman_lemma300_100_1'
VECTORS = Word2Vec.load(MODEL).wv
# Read by the (commented-out) search loop to decide whether lem_tess is applied
lemmatize = True

In [3]:
# Set up NLP tools

# Module-level lemmatizer; lem_tess reads this name as a global
lemmatizer = BackoffLatinLemmatizer()
# Module-level JVReplacer; note that preprocess builds its own instance
# instead of reading this global
replacer = JVReplacer()

In [4]:
# Get comparison dataset

comps = pd.read_csv('../data/datasets/comps.csv', index_col=0)

# Both lemma columns hold whitespace-separated strings; turn each cell into a list of lemmas
for lemma_column in ('query_lemma', 'target_lemma'):
    comps[lemma_column] = comps[lemma_column].apply(lambda cell: cell.split())

In [5]:
# Text functions

def preprocess(text):
    """Normalize a stretch of text into lowercase, space-separated tokens.

    Steps, in order: lowercase, JVReplacer normalization (u/v and i/j),
    repair of one known Tesserae text issue, replacement of punctuation,
    digits and em dashes by spaces, and collapsing of whitespace runs.

    Args:
        text: the raw text (str).

    Returns:
        str with tokens separated by single spaces and no leading or
        trailing whitespace.
    """
    replacer = JVReplacer()
    text = text.lower()
    text = replacer.replace(text)  # Normalize u/v & i/j

    text = text.replace('ego---sed', 'egosed')  # handle a tesserae text issue

    # The backslash is written as "\\": the previous "[\]" spelling relied on
    # an invalid escape sequence, which newer Python versions warn about.
    # The resulting set of characters is unchanged.
    punctuation = "\"#$%&'()*+,-/:;<=>@[\\]^_`{|}~.?!«»"
    digits = '0123456789'
    # One translation table covers punctuation and digits alike
    translator = str.maketrans({key: " " for key in punctuation + digits})
    text = text.translate(translator)

    # Em dash: the mojibake form (its UTF-8 bytes decoded as Latin-1) and the
    # correctly decoded character, so the result does not depend on how the
    # file was read
    text = text.replace('â\x80\x94', ' ').replace('\u2014', ' ')

    return " ".join(text.split())

def index_tess(text):
    """Index the lines of a .tess text by their leading tag.

    Every non-blank line is split at its FIRST '>' only: the part up to and
    including that '>' becomes the key, the remainder (unstripped) becomes
    the value. A later '>' inside the line therefore stays in the value
    instead of breaking the key/value pairing.

    Args:
        text: full contents of a .tess file (str).

    Returns:
        dict mapping tag (including the closing '>') to the rest of the line.
        When a tag occurs more than once, the last occurrence wins.

    Raises:
        ValueError: if a non-blank line contains no '>'.
    """
    tess_dict = {}
    for line in text.strip().split('\n'):
        # Skip empty and whitespace-only lines (e.g. a bare '\r' from CRLF files)
        if not line.strip():
            continue
        tag, delimiter, rest = line.partition('>')
        if not delimiter:
            raise ValueError(f'No ">" tag delimiter in line: {line!r}')
        tess_dict[tag + delimiter] = rest
    return tess_dict

def pp_tess(tess_dict):
    """Return a new dict with the same keys and every value run through preprocess."""
    cleaned = {}
    for tag, raw_line in tess_dict.items():
        cleaned[tag] = preprocess(raw_line)
    return cleaned

def text_lemmatize(lemma_pairs):
    """Join the second item of every 2-item pair in `lemma_pairs` with single spaces."""
    lemmas = []
    for _token, lemma in lemma_pairs:
        lemmas.append(lemma)
    return " ".join(lemmas)

def lem_tess(tess_dict):
    """Return a new dict whose values are the lemmatized form of the input values.

    Each value is split on whitespace, passed to the module-level
    `lemmatizer`, and the resulting pairs are joined by text_lemmatize.
    """
    lemmatized = {}
    for tag, line in tess_dict.items():
        tokens = line.split()
        lemmatized[tag] = text_lemmatize(lemmatizer.lemmatize(tokens))
    return lemmatized

def make_tess_file(str):
    """Turn a citation string into a .tess file name.

    The citation is lowercased, 'lucan' is expanded to
    'lucan bellum_civile', everything from the first '.' on is dropped,
    the first space becomes '.', the next space becomes '.part.', and
    '.tess' is appended.
    """
    name = str.lower().replace('lucan', 'lucan bellum_civile')
    stem = name.split('.')[0]
    # The two replacements are order-sensitive: each consumes one space
    stem = stem.replace(' ', '.', 1).replace(' ', '.part.', 1)
    return stem + '.tess'

def make_tess_index(str):
    """Turn a citation string into a tag of the form '<abbr. ref>'.

    Author and work names are replaced by their abbreviations, anything
    from the first '-' on is dropped, and the result is wrapped in '<' and '>'.
    """
    abbreviations = (
        ('Lucan', 'luc.'),
        ('Ovid', 'ov.'),
        ('Statius', 'stat.'),
        ('Vergil', 'verg.'),
        ('Metamorphoses', 'met.'),
        ('Thebaid', 'theb.'),
        ('Aeneid', 'aen.'),
    )
    tag = str
    for full_name, short_name in abbreviations:
        tag = tag.replace(full_name, short_name)
    first_ref = tag.split('-')[0]
    return '<' + first_ref + '>'

def get_next_tess_index(index, n):
    """Return the tag whose line number is `n` greater than that of `index`.

    The last whitespace-separated token of `index` is read as
    '<book>.<line>'; `n` is added to the line part. Returns None when the
    resulting tag is in the hard-coded set of tags treated as missing data.
    """
    tokens = index.replace('>', '').split()
    work_prefix = " ".join(tokens[:-1])
    ref_parts = tokens[-1].split('.')
    candidate = f'{work_prefix} {ref_parts[0]}.{int(ref_parts[1]) + n}>'

    # Tags treated as missing data
    missing = {
        '<luc. 1.419>',
        '<luc. 7.855>', '<luc. 7.856>', '<luc. 7.857>', '<luc. 7.858>', '<luc. 7.859>',
        '<luc. 7.860>', '<luc. 7.861>', '<luc. 7.862>', '<luc. 7.863>', '<luc. 7.864>',
        '<luc. 9.414>',
        '<ov. met. 4.769>',
        '<stat. theb. 6.184>',
        '<stat. theb. 6.227>', '<stat. theb. 6.228>', '<stat. theb. 6.229>', '<stat. theb. 6.230>',
        '<stat. theb. 6.231>', '<stat. theb. 6.232>', '<stat. theb. 6.233>',
        '<stat. theb. 9.760>',
    }
    return None if candidate in missing else candidate

In [6]:
# Get ngrams

def generate_ngrams(words_list, n):
    """Return every run of `n` consecutive words as a list of words.

    Each start position yields the words from there on, joined and re-split
    on whitespace; only windows that end up with exactly `n` tokens are kept,
    so the short windows at the end of the list are dropped.
    """
    # Cf. https://www.techcoil.com/blog/how-to-generate-n-grams-in-python-without-using-any-external-libraries/
    windows = (
        ' '.join(words_list[start:start + n]).split()
        for start in range(len(words_list))
    )
    return [tokens for tokens in windows if len(tokens) == n]

def generate_ngrams_interval(words_list, index, n, tess_dict, interval):
    """Build the n-grams of one line, looking ahead into the following lines.

    The words of the line are extended with the words of subsequent lines
    until `interval + 1` extra words are available, then cut to that length
    and passed to generate_ngrams.

    Lookahead stops early when get_next_tess_index reports the next line as
    missing (None) or when the next tag is not a key of `tess_dict` (end of
    the text).

    Args:
        words_list: words of the current line; not modified.
        index: tag of the current line, as used for the keys of `tess_dict`.
        n: n-gram size.
        tess_dict: dict mapping tag to the line's text (str).
        interval: number of extra lookahead words, on top of the one that is
            always added.

    Returns:
        list of n-grams, each a list of `n` words.
    """
    limit = len(words_list) + interval + 1  # add one to avoid interval-based fencepost problem

    words = list(words_list)  # work on a copy so the caller's list is left alone
    current_index = index
    while len(words) < limit:
        # Advance line by line. Reusing the starting index would append the
        # same following line again and again, and never terminate if that
        # line were empty.
        current_index = get_next_tess_index(current_index, 1)
        if current_index is None or current_index not in tess_dict:
            break
        words += tess_dict[current_index].split()

    return generate_ngrams(words[:limit], n)

def ngram_tess(tess_dict, n=2, interval=0):
    """Map every tag except the last to the n-grams built by generate_ngrams_interval.

    Each value is split on whitespace and handed over together with its tag,
    `n`, the full `tess_dict` and `interval`.
    """
    # Stop short of last item because of ngram lookahead
    lines_with_lookahead = list(tess_dict.items())[:-1]
    ngrams_by_tag = {}
    for tag, line in lines_with_lookahead:
        ngrams_by_tag[tag] = generate_ngrams_interval(line.split(), tag, n, tess_dict, interval)
    return ngrams_by_tag

In [7]:
# Similarity functions

@lru_cache(maxsize=10000)
def get_similarities(terms, model):
    """Score every word pair in `terms` with `model.similarity`.

    Args:
        terms: iterable of pairs of words. Calls are memoized with lru_cache,
            so it must be hashable (e.g. a tuple of tuples).
        model: hashable object supporting `word in model` and
            `model.similarity(a, b)`; the commented-out search loop in this
            notebook passes VECTORS.

    Returns:
        list with one score per pair, in input order. A pair gets -1 when
        either of its first two words, or any other word in `terms`, is not
        known to the model.
        The returned list is the cached object itself and is shared by all
        calls with equal arguments, so callers must not mutate it.
    """
    pairs = list(terms)
    # Membership is tested with `in model` rather than `model.vocab`: the
    # `vocab` attribute no longer exists in gensim 4, while `in` works on
    # both 3.x and 4.x vectors.
    known = {word: word in model for pair in pairs for word in pair}

    sims = []
    for pair in pairs:
        first, second = pair[0], pair[1]
        if known[first] and known[second]:
            sims.append(model.similarity(first, second))
        else:
            sims.append(-1)  # sentinel for out-of-vocabulary pairs

    return sims

# pair-aware mean
def mean_similarities(terms, sims):
    """Average the best score with the score at the mirrored position.

    The highest score in `sims` is found (first occurrence on ties) and
    averaged with the score at position `3 - best`, i.e. 0<->3 and 1<->2.
    For four scores laid out like product((q0, q1), (t0, t1)) — which is
    how the commented-out search loop builds them — the mirrored position
    holds the pair made of the other query word and the other target word.

    Args:
        terms: the pairs the scores belong to; accepted for interface
            compatibility but not read.
        sims: list of scores; the best score must sit in one of the first
            four positions.

    Returns:
        The mean of the two selected scores.

    Raises:
        ValueError: if `sims` is empty, or if its maximum lies beyond
            index 3 (there is no defined partner for such a position).
        IndexError: if `sims` has fewer elements than the partner position
            requires.
    """
    best = sims.index(max(sims))
    if best > 3:
        # Previously this fell through every branch and failed with an
        # UnboundLocalError on pair_index.
        raise ValueError(
            f'mean_similarities expects the best score among the first four '
            f'entries, got it at index {best} of {len(sims)}'
        )
    partner = 3 - best
    return (sims[best] + sims[partner]) / 2

In [8]:
# # Uncomment code to run full intertext search over texts

# # Get intertext search results

# results_ = []

# for i, row in tqdm(comps.iterrows(), total=comps.shape[0]):    

#     search_files = natsorted([file for file in files if row['intertext_author'] in file])
    
#     if np.isnan(row['interval']):
#         interval = 0
#     else:
#         interval = int(row['interval'])
    
#     n = row['query_length'] + interval
    
#     results = []
    
#     for file in search_files:
#         with open(file, 'r') as f:
#             contents = f.read()
#             tess_dict = index_tess(contents)    
#             tess_dict = pp_tess(tess_dict)
#             if lemmatize:
#                 tess_dict = lem_tess(tess_dict)
#         tess_dict = ngram_tess(tess_dict, n, interval)           

#         for item in list(tess_dict.items()):            
#             index = item[0]
#             ngrams = item[1]
#             for ngram in ngrams:                
#                 orderfree = row['orderfree']
#                 if orderfree:
#                     combs = list(combinations(ngram, 2))
#                 else:
#                     combs = [ngram]

#                 for comb in combs:
#                     pairs = tuple(product(row["query_lemma"], comb))
#                     dists = get_similarities(pairs, VECTORS)
#                     dists_sum = mean_similarities(pairs, dists)
#                     if dists_sum >= row["similarity"]:
#                         results.append((index, dists_sum, row['query_lemma'], list(comb)))
#     results_.append((row['index'], results))

In [9]:
# Create time-stamped file for results; cf. https://stackoverflow.com/a/14115286
# output_path = f"{os.path.join('temp', datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))}-results.p"
# pickle.dump(results_, open(output_path, 'wb'))

# For paper results, uncomment to download and run remaining cells
# url = 'https://utexas.box.com/shared/static/2m2e09dijfxgqg3nnxei2cuttc7qu7kd.p'
# urllib.request.urlretrieve (url, 'temp/results_naacl2021_search.p')

# output_path = 'temp/results_naacl2021_search.p'
# results_ = pickle.load(open(output_path, 'rb'))
# print(f'Temp results saved at {output_path}')

Temp results saved at temp/results_naacl2021_search.p


In [10]:
# Get ranks

# Each entry of results_ carries its list of hits in position 1; the length of
# that list is used as the entry's rank, and entries without hits are left out
ranks = [
    len(hits)
    for hits in (entry[1] for entry in results_)
    if len(hits) != 0
]

In [11]:
# Compute recall & precision at k; compute MRR

# Cutoff values of k that are reported by the loop at the end of this cell
ks = [1, 3, 5, 10, 25, 50, 75, 100, 250]

def recall_at_k(ranks, k):
    """Return the share of `ranks` that are less than or equal to `k`.

    Raises ZeroDivisionError when `ranks` is empty.
    """
    hits = sum(1 for rank in ranks if rank <= k)
    return hits / len(ranks)

def precision_at_k(ranks, k):
    """Return (number of ranks <= k) divided by the sum of the ranks capped at k.

    Every rank above `k` contributes `k` to the denominator; every other
    rank contributes itself. Raises ZeroDivisionError when that sum is 0
    (e.g. for empty `ranks`).
    """
    hits = 0
    capped_total = 0
    for rank in ranks:
        if rank <= k:
            hits += 1
            capped_total += rank
        else:
            capped_total += k
    return hits / capped_total

def mrr(ranks):
    """Return the mean of the reciprocals 1/rank over `ranks` (mean reciprocal rank)."""
    reciprocal_ranks = [1 / rank for rank in ranks]
    return mean(reciprocal_ranks)

# Report MRR, followed by a blank line
print(f'MRR: {mrr(ranks)}', end='\n\n')

# Report recall and precision for every cutoff in ks, one blank line after each pair
print(f'Checking the following values for k {ks}\n')
for k in ks:
    print(
        f'\tRecall at k={k}: {recall_at_k(ranks, k)}',
        f'\tPrecision at k={k}: {precision_at_k(ranks, k)}',
        sep='\n',
        end='\n\n',
    )

MRR: 0.3911290607923982

Checking the following values for k [1, 3, 5, 10, 25, 50, 75, 100, 250]

	Recall at k=1: 0.2913135593220339
	Precision at k=1: 0.2913135593220339

	Recall at k=3: 0.4385593220338983
	Precision at k=3: 0.18809631985461153

	Recall at k=5: 0.5180084745762712
	Precision at k=5: 0.15219421101774042

	Recall at k=10: 0.5953389830508474
	Precision at k=10: 0.10575837410613474

	Recall at k=25: 0.673728813559322
	Precision at k=25: 0.0608554205339202

	Recall at k=50: 0.7245762711864406
	Precision at k=50: 0.038876889848812095

	Recall at k=75: 0.75
	Precision at k=75: 0.029737903225806453

	Recall at k=100: 0.7711864406779662
	Precision at k=100: 0.024760220393170534

	Recall at k=250: 0.8241525423728814
	Precision at k=250: 0.013574107999651051

