In [None]:
# Imports

from itertools import product
from statistics import mean

import pandas as pd
import numpy as np
import Levenshtein

from gensim.models import Word2Vec

from cltk.stem.latin.j_v import JVReplacer
from cltk.lemmatize.latin.backoff import BackoffLatinLemmatizer

In [None]:
# Constants

# Folder containing the .tess text files that get_interval opens.
TEXTS_FOLDER = "../data/texts/process"
# Path of the saved gensim Word2Vec model.
MODEL = '../models/latin_w2v_bamman_lemma300_100_1'
# Only the word vectors (.wv) of the loaded model are kept; they are passed to
# get_similarities for vocabulary checks and similarity lookups.
VECTORS = Word2Vec.load(MODEL).wv

In [None]:
# Set up NLP tools

# Shared cltk instances used by the cells below: the lemmatizer turns the
# query/result phrases into lemmata, the replacer normalizes u/v & i/j spelling.
lemmatizer = BackoffLatinLemmatizer()
replacer = JVReplacer()

In [None]:
# Get comparison dataset

# Load the intertext comparison dataset; the processing cell further down
# filters this frame and adds derived columns to it.
comps_csv = '../data/datasets/vf_intertext_dataset_1_0.csv'
comps = pd.read_csv(comps_csv)

In [None]:
# Similarity functions

def get_similarities(terms, model):
    """Return one similarity score per term pair.

    Parameters
    ----------
    terms : iterable of pairs
        Each element provides two hashable terms at positions 0 and 1.
    model : keyed-vectors-like object
        Must offer ``similarity(a, b)`` and a vocabulary mapping, exposed as
        ``key_to_index`` (gensim >= 4) or ``vocab`` (gensim < 4).

    Returns
    -------
    list
        ``model.similarity(a, b)`` for each pair, in input order; ``-1`` is
        used as a sentinel when either term is out of vocabulary.
    """
    pairs = list(terms)

    # gensim 4 removed KeyedVectors.vocab in favour of key_to_index; accept
    # either so the notebook runs against old and new installations.
    vocabulary = getattr(model, 'key_to_index', None)
    if vocabulary is None:
        vocabulary = model.vocab

    sims = []
    for pair in pairs:
        if pair[0] not in vocabulary or pair[1] not in vocabulary:
            sim = -1  # out-of-vocabulary sentinel
        else:
            sim = model.similarity(pair[0], pair[1])
        sims.append(sim)

    return sims

# pair-aware mean
def mean_similarities(terms, sims):
    """Average the best-scoring term pair with its complementary pair.

    ``sims`` holds the four scores of a 2x2 cross product in the order
    (q0,t0), (q0,t1), (q1,t0), (q1,t1).  The highest score is paired with the
    score of the two remaining terms (index 0 <-> 3, index 1 <-> 2) and the
    mean of those two values is returned.

    Parameters
    ----------
    terms : any
        Unused; kept so existing callers keep working.
    sims : list
        Exactly four similarity scores.

    Returns
    -------
    number
        Mean of the maximum score and its complementary score.

    Raises
    ------
    ValueError
        If ``sims`` does not contain exactly four scores.
    """
    if len(sims) != 4:
        # Previously this surfaced as an UnboundLocalError or IndexError.
        raise ValueError(
            f'expected 4 similarities (2x2 term pairs), got {len(sims)}'
        )

    max_sim_index = sims.index(max(sims))
    pair_index = 3 - max_sim_index  # complementary pair: 0<->3, 1<->2
    return (sims[max_sim_index] + sims[pair_index]) / 2

In [None]:
# Text functions

def preprocess(text):
    """Normalize a line of text for token matching.

    Lowercases the text, normalizes u/v & i/j with ``JVReplacer``, replaces
    punctuation and digits with spaces, and collapses all whitespace runs to
    single spaces.

    Parameters
    ----------
    text : str
        Raw text.

    Returns
    -------
    str
        The normalized text, tokens separated by single spaces.
    """
    replacer = JVReplacer()
    text = text.lower()
    text = replacer.replace(text)  # Normalize u/v & i/j

    text = text.replace('ego---sed', 'egosed')  # handle a tesserae text issue

    # The backslash is written as "\\": the former "\]" was an invalid escape
    # sequence (a warning today, an error in future Python versions).  The set
    # of characters is unchanged.
    punctuation = "\"#$%&'()*+,-/:;<=>@[\\]^_`{|}~.?!«»"
    translator = str.maketrans({key: " " for key in punctuation})
    text = text.translate(translator)

    translator = str.maketrans({key: " " for key in '0123456789'})
    text = text.translate(translator)

    # U+00E2 U+0080 U+0094: the bytes of a UTF-8 em dash read as Latin-1.
    text = text.replace('â\x80\x94', ' ')

    return " ".join(text.split())

def index_tess(text):
    """Index the contents of a .tess file by line tag.

    Every non-blank line is split at its first ``>``: the part up to and
    including that ``>`` becomes the key, the remainder the value.

    Parameters
    ----------
    text : str
        Full contents of the file.

    Returns
    -------
    dict
        Mapping of tag (e.g. ``'<verg. aen. 1.1>'``) to the rest of the line.
        If a tag occurs more than once the last line wins.

    Raises
    ------
    ValueError
        If a non-blank line contains no ``>``.
    """
    tess_dict = {}
    for line in text.strip().split('\n'):
        if not line.strip():
            # Blank or whitespace-only line (e.g. a stray '\r'): nothing to index.
            continue
        # Split on the first '>' only, so a '>' inside the text itself does
        # not break the key/value structure.
        tag, delimiter, body = line.partition('>')
        if not delimiter:
            raise ValueError(f'no ">" tag delimiter in line: {line!r}')
        tess_dict[tag + delimiter] = body
    return tess_dict

def pp_tess(tess_dict):
    """Return a copy of ``tess_dict`` whose values have been run through ``preprocess``.

    Keys are kept as they are; only the values are normalized.
    """
    cleaned = {}
    for tag, raw_line in tess_dict.items():
        cleaned[tag] = preprocess(raw_line)
    return cleaned

def text_lemmatize(lemma_pairs):
    """Join the second element of each two-item pair into one space-separated string."""
    lemmata = []
    for _token, lemma in lemma_pairs:
        lemmata.append(lemma)
    return " ".join(lemmata)

def make_ref(author, work, book, line):
    """Build a reference string of the form ``'<author> <work> <book>.<line>'``.

    The work title is left out when the author is 'Lucan'.  When ``book`` is
    NaN only ``line`` is used, formatted as given; otherwise book and line are
    both converted to integers.  Surplus whitespace is collapsed.
    """
    work_name = '' if author == 'Lucan' else work
    book_line = f'{line}' if np.isnan(book) else f'{int(book)}.{int(line)}'
    pieces = f'{author} {work_name} {book_line}'.split()
    return " ".join(pieces)

def make_tess_file(str):
    """Derive the .tess file name for a reference string.

    Example: 'Vergil Aeneid 1.1' -> 'vergil.aeneid.part.1.tess'.
    References to Lucan carry no work title, so 'bellum_civile' is inserted.
    """
    # Everything before the first '.' is "<author> <work> <book>".
    stem = str.lower().replace('lucan', 'lucan bellum_civile').split('.')[0]
    # First space separates author and work, second space introduces the book part.
    stem = stem.replace(' ', '.', 1).replace(' ', '.part.', 1)
    return stem + '.tess'

def make_tess_index(str):
    """Turn a reference string into a tag such as '<verg. aen. 1.1>'.

    Author and work names are abbreviated, anything from the first '-' onwards
    is dropped, and the result is wrapped in angle brackets.
    """
    abbreviations = (
        ('Lucan', 'luc.'),
        ('Ovid', 'ov.'),
        ('Statius', 'stat.'),
        ('Vergil', 'verg.'),
        ('Metamorphoses', 'met.'),
        ('Thebaid', 'theb.'),
        ('Aeneid', 'aen.'),
    )
    ref = str
    for full_name, short_name in abbreviations:
        ref = ref.replace(full_name, short_name)
    start = ref.split('-')[0]
    return f'<{start}>'

def get_next_tess_index(index, n):
    """Return the tag ``n`` lines after ``index`` within the same book.

    ``index`` looks like '<verg. aen. 1.1>'; the line number after the first
    '.' of its last token is increased by ``n``.  ``None`` is returned when
    the resulting tag is one of the listed tags with missing data.
    """
    # Tags with missing data
    missing = (
        '<luc. 1.419>', '<luc. 7.855>', '<luc. 7.856>', '<luc. 7.857>',
        '<luc. 7.858>', '<luc. 7.859>', '<luc. 7.860>', '<luc. 7.861>',
        '<luc. 7.862>', '<luc. 7.863>', '<luc. 7.864>', '<luc. 9.414>',
        '<ov. met. 4.769>',
        '<stat. theb. 6.184>', '<stat. theb. 6.227>', '<stat. theb. 6.228>',
        '<stat. theb. 6.229>', '<stat. theb. 6.230>', '<stat. theb. 6.231>',
        '<stat. theb. 6.232>', '<stat. theb. 6.233>', '<stat. theb. 9.760>',
    )

    tokens = index.replace('>', '').split()
    prefix = " ".join(tokens[:-1])
    ref_parts = tokens[-1].split('.')
    shifted_line = int(ref_parts[1]) + n
    candidate = f'{prefix} {ref_parts[0]}.{shifted_line}>'

    return None if candidate in missing else candidate

In [None]:
# Interval/order functions

def _token_interval(tokens, first, second):
    """Return the number of tokens between the first occurrences of two words, or None if either is absent."""
    if first in tokens and second in tokens:
        return abs(tokens.index(first) - tokens.index(second)) - 1
    return None


def get_interval(file, index, result, orderfree, verbose=False):
    """Count the tokens separating the two words of ``result`` in the source text.

    The line tagged ``index`` in ``file`` is searched first.  If the two words
    are not both found there, up to five following lines are appended one at
    a time and the search is repeated on the growing passage.

    Parameters
    ----------
    file : str
        Name of the .tess file inside ``TEXTS_FOLDER``.
    index : str
        Tag of the starting line, e.g. '<verg. aen. 1.1>'.
    result : str
        Two-word phrase to look for; lowercased and u/v, i/j normalized here.
    orderfree : any
        Unused; kept so existing callers keep working.
    verbose : bool, optional
        Print the phrase and every line examined (this debug output used to
        be printed unconditionally).

    Returns
    -------
    int
        ``0`` if the phrase occurs verbatim in the starting line, otherwise
        the number of tokens between the two words, or ``-1`` if they are not
        found within the examined lines.

    Raises
    ------
    KeyError
        If ``index`` is not a tag in ``file``.
    """
    result = replacer.replace(result.lower())

    # Read and index the file, then release the handle before searching.
    with open(f'{TEXTS_FOLDER}/{file}') as f:
        tess_dict = pp_tess(index_tess(f.read()))

    item = tess_dict[index]

    if verbose:
        print(result)
        print(item)

    # Check for adjacent words in ref
    if result in item:
        return 0

    words = result.split()

    # Check for non-adjacent words in ref
    interval = _token_interval(item.split(), words[0], words[1])
    if interval is not None:
        return interval

    # Add up to 5 lines to check for words in ref
    for i in range(1, 6):  # 5 line context sufficient?
        # .get(): the next tag may be None (known gap) or absent from the
        # file (end of a book); such lines are skipped instead of raising.
        item_extend = tess_dict.get(get_next_tess_index(index, i))
        if verbose:
            print(item_extend)
        if item_extend:
            item = " ".join([item, item_extend])
            interval = _token_interval(item.split(), words[0], words[1])
            if interval is not None:
                return interval

    return -1


def get_orderfree(query, result):
    """Report whether ``result`` is closer to ``query`` with its words reversed.

    Returns True when the Levenshtein distance from ``query`` to ``result``
    is strictly greater than the distance to ``result`` in reversed word
    order, False otherwise.
    """
    reversed_result = " ".join(reversed(result.split()))
    distance_as_given = Levenshtein.distance(query, result)
    distance_reversed = Levenshtein.distance(query, reversed_result)
    return distance_as_given > distance_reversed
    
def update_orderfree(orderfree, interval):
    """Return True for any truthy ``interval``; otherwise return ``orderfree`` unchanged."""
    return True if interval else orderfree

In [None]:
# Keep only rows that have both phrases.  The explicit .copy() makes the
# filtered frame independent, so the column assignments below do not trigger
# pandas' SettingWithCopyWarning.
has_phrases = comps['Query Phrase'].notna() & comps['Result Phrase'].notna()
comps = comps[has_phrases].copy()

comps['query_length'] = comps['Query Phrase'].apply(lambda x: len(x.split()))
comps['result_length'] = comps['Result Phrase'].apply(lambda x: len(x.split()))

# Restrict to two-word phrases on both sides.
is_two_word = (comps['query_length'] == 2) & (comps['result_length'] == 2)
comps = comps[is_two_word].copy()

# Reference string, tag and file name of the intertext passage.
comps['ref'] = comps.apply(lambda x: make_ref(x['Intertext: Author'], x['Intertext: Work'], x['Intertext: Book'], x['Intertext: Line Start']), axis=1)
comps['index'] = comps['ref'].apply(make_tess_index)
comps['file'] = comps['ref'].apply(make_tess_file)

# Drop intertexts by Seneca and Valerius.
comps = comps[~comps['Intertext: Author'].isin(['Seneca', 'Valerius'])].copy()

comps['intertext_author'] = comps['ref'].apply(lambda x: x.lower().split()[0])
comps['intertext_book'] = comps['ref'].apply(lambda x: x.split()[-1].split('.')[0])

# Lowercased, u/v & i/j normalized phrases.
comps['Query'] = comps['Query Phrase'].apply(lambda x: replacer.replace(x.lower()))
comps['Result'] = comps['Result Phrase'].apply(lambda x: replacer.replace(x.lower()))

comps['orderfree'] = comps['Order Free']
comps['interval'] = comps['Interval']

# Lemmatize both phrases, score every query-lemma/target-lemma pair, and
# reduce the scores to a single pair-aware mean per row.
comps['query_lemma'] = comps['Query'].apply(lambda x: text_lemmatize(lemmatizer.lemmatize(x.lower().split())))
comps['target_lemma'] = comps['Result'].apply(lambda x: text_lemmatize(lemmatizer.lemmatize(x.lower().split())))
comps['pairs'] = comps.apply(lambda x: tuple(product(x['query_lemma'].split(), x['target_lemma'].split())), axis=1)
comps['similarities'] = comps.apply(lambda x: get_similarities(x['pairs'], VECTORS), axis=1)
comps['similarity'] = comps.apply(lambda x: mean_similarities(x['pairs'], x['similarities']), axis=1)

In [None]:
# Set aside rows whose interval is -1 (the value get_interval returns when the
# words are not found) and save them for inspection.
# NOTE(review): this cell was headed "Exclude NaN intervals", but the filter
# matches -1 only; missing (NaN) values in 'interval' are not selected here.
# Presumably the dataset's 'Interval' column uses -1 the same way; confirm.

excluded_rows = comps[comps['interval'] == -1]
excluded_rows.to_csv('../data/datasets/comps_excluded.csv')

In [None]:
# Export preprocessed data
# Rows with the -1 interval value are dropped before writing; note that this
# rebinds `comps`, so later cells see the filtered frame.

comps = comps[comps['interval'] != -1]
# comps['interval'] = comps['interval'].astype(int)
comps.to_csv('../data/datasets/comps.csv')

In [None]:
# Report how many rows were kept and how many were set aside in excluded_rows.
print(f'Comparison dataset processed...\n\nThere are {comps.shape[0]} rows in comparison dataset.\n{excluded_rows.shape[0]} rows have been excluded.')