This notebook builds the database used in the survey.

In [1]:
import embeddings
import faiss
import json
import Levenshtein
import numpy as np
import random

from process_words import *
from process_sentences import *
from similarity_search import *
import utils

In [8]:
def clean_sentence(sen):
    """Remove tokenization spacing around punctuation, contractions and quotes.

    The input is expected to be a space-tokenized sentence (for example
    ``"I do n't know , do you ?"``). Each substitution below is applied to the
    whole string, in the listed order.

    Quote handling: the two-character tokens `` and '' are converted to the
    opening (“) and closing (”) curly quotes respectively. The space after an
    opening quote and the space before a closing quote are removed; the space
    *before* an opening quote is kept so the quote is not glued to the
    preceding word.

    Args:
        sen: The tokenized sentence.

    Returns:
        The sentence with the spacing normalized.
    """
    replacements = (
        # Two-character quote tokens -> curly quotes.
        ("``", "“"),
        ("''", "”"),
        # Sentence punctuation attaches to the preceding token.
        (" .", "."),
        (" ,", ","),
        (" ?", "?"),
        (" !", "!"),
        # Contractions, in both curly- and straight-apostrophe spellings.
        (" ’ s", "’s"),
        (" ’ t", "’t"),
        (" n't", "n't"),
        (" ’ ll", "’ll"),
        (" 've", "'ve"),
        (" 'll", "'ll"),
        (" ’ m", "’m"),
        (" 's", "'s"),
        (" ’ ve", "’ve"),
        (" 'm", "'m"),
        (" ’ d", "’d"),
        (" 're", "'re"),
        (" 'd", "'d"),
        (" ;", ";"),
        (" :", ":"),
        # Brackets and quotes hug the text they enclose.
        ("( ", "("),
        (" )", ")"),
        (" ”", "”"),
        ("“ ", "“"),
    )
    for old, new in replacements:
        sen = sen.replace(old, new)
    return sen

def _clean_and_filter(query, candidates, edits_allowed, limit):
    """Clean candidates, drop those within edits_allowed edits of query, keep `limit`."""
    cleaned = [clean_sentence(x) for x in candidates]
    distinct = [x for x in cleaned if Levenshtein.distance(query, x) > edits_allowed]
    return distinct[:limit]


def _by_rank(sentences, k):
    """Map rank (0-based) to sentence, using at most k entries."""
    return dict(zip(range(k), sentences))


def get_results(query, s):
    """Collect the similar sentences returned by every search method for a query.

    Args:
        query: The query sentence.
        s: The searcher object. This function reads ``s.data_name`` and ``s.k``
            and calls ``get_k_set_cover``, ``get_k_avg_embed``,
            ``get_k_non_index`` and ``update_search_params`` on it.

    Returns:
        A dict with the keys ``"query"``, ``"difficulty"`` (the second
        ``_``-separated field of ``s.data_name``), and one entry per method:
        ``"set_cover"``, ``"weighted_set_cover"``, ``"embedding_average"``,
        ``"word_movers_distance"``, ``"jaccard"`` and ``"edit_distance"``.
        Each method entry maps a 0-based rank to a cleaned sentence and holds
        at most five sentences.
    """
    edits_allowed = 5  # This accounts for manual adjustments to the query sentence
    max_results = 5

    res = {
        "query": query,
        "difficulty": s.data_name.split("_")[1],
    }

    set_cover = _clean_and_filter(
        query, s.get_k_set_cover(query), edits_allowed, max_results
    )
    res["set_cover"] = _by_rank(set_cover, s.k)

    # The searcher is shared between queries, so the weighted flags must be
    # switched back off even if the weighted search raises.
    s.update_search_params(use_wt=True, use_dis=True)
    try:
        weighted_set_cover = _clean_and_filter(
            query, s.get_k_set_cover(query), edits_allowed, max_results
        )
    finally:
        s.update_search_params(use_wt=False, use_dis=False)
    res["weighted_set_cover"] = _by_rank(weighted_set_cover, s.k)

    # NOTE(review): this method only drops exact matches of the query, whereas
    # every other method drops anything within `edits_allowed` edits. Confirm
    # whether that difference is intentional.
    avg = [clean_sentence(x) for x in s.get_k_avg_embed(query)]
    avg = [x for x in avg if x != query][:max_results]
    res["embedding_average"] = _by_rank(avg, s.k)

    for metric in ["word_movers_distance", "jaccard", "edit_distance"]:
        non_ind = _clean_and_filter(
            query, s.get_k_non_index(query, metric), edits_allowed, max_results
        )
        res[metric] = _by_rank(non_ind, s.k)

    return res

def build_db(query_list=None, sim_list=None):
    """Run every query through its searcher and write the results to disk.

    Args:
        query_list: A sequence of query lists, one per searcher. When omitted,
            the ``"middle"``, ``"high"`` and ``"college"`` lists are read from
            ``./queries.json``.
        sim_list: A sequence of searcher objects, paired positionally with
            ``query_list``. When omitted, one ``SimilarSentences`` searcher per
            difficulty level is built and indexed here.

    Side effects:
        Writes ``./database.json``, mapping a running string index
        (``"0"``, ``"1"``, ...) to the result of ``get_results`` for each query.
    """
    if query_list is None:
        # JSON text is UTF-8; do not rely on the platform's default encoding.
        with open("./queries.json", "r", encoding="utf-8") as f:
            queries = json.load(f)
        query_list = [queries["middle"], queries["high"], queries["college"]]

    if sim_list is None:
        # NOTE(review): the notebook cell that pre-builds sim_list also passes
        # use_wmd_estimate=True; confirm whether this default should match it.
        embedder = embeddings.FastTextEmbedding()
        sim_list = []
        for key in ["middle", "high", "college"]:
            sim = SimilarSentences(f'books_{key}', embedder)
            # Fetch 6 so that 5 results remain after get_results removes
            # candidates that are (nearly) identical to the query.
            sim.update_search_params(k=6)
            sim.set_stopwords()
            sim.embed_sentences()
            sim.build_sentence_index()
            sim.embed_words()
            sim.build_word_index()
            sim_list.append(sim)

    # Flatten to (query, searcher) pairs so the running index is just enumerate().
    pairs = ((q, s) for q_list, s in zip(query_list, sim_list) for q in q_list)
    result = {str(ind): get_results(q, s) for ind, (q, s) in enumerate(pairs)}

    with open("./database.json", "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)

def sample_sentences(n, sim):
    """Return n cleaned sentences drawn at random, without replacement.

    Args:
        n: Number of sentences to draw.
        sim: Object whose ``sen_2_ind`` mapping has the candidate sentences as
            its keys.

    Returns:
        A list of n sentences, each passed through ``clean_sentence``.

    Raises:
        ValueError: If n is larger than the number of keys in ``sim.sen_2_ind``
            (raised by ``random.sample``).
    """
    # random.sample requires a sequence: passing a dict keys view is deprecated
    # since Python 3.9 and raises TypeError from 3.11, so materialize it first.
    population = list(sim.sen_2_ind)
    return [clean_sentence(x) for x in random.sample(population, n)]


In [None]:
# Build one SimilarSentences searcher per difficulty level, all sharing a
# single FastText embedder, and collect them in sim_list in the order
# middle, high, college (the order build_db uses for its default query lists).
embedder = embeddings.FastTextEmbedding()
sim_list = []
for key in ["middle", "high", "college"]:
    sim = SimilarSentences(f'books_{key}', embedder)
    # k=6: get_results keeps at most 5 sentences per method after dropping
    # candidates that are (nearly) identical to the query, so one extra is
    # requested. use_wmd_estimate presumably enables an approximate word
    # mover's distance -- confirm in SimilarSentences.
    sim.update_search_params(k=6, use_wmd_estimate=True) # to avoid repetition
    sim.set_stopwords()
    # Sentence-level embeddings and index, then word-level embeddings and index.
    sim.embed_sentences()
    sim.build_sentence_index()
    sim.embed_words()
    sim.build_word_index()
    sim_list.append(sim)

In [9]:
# Build the survey database with the pre-built searchers. query_list is left
# at its default, so the queries are read from ./queries.json and the results
# are written to ./database.json.
build_db(sim_list=sim_list)