# Process datasets

This notebook takes as input a file with sentences containing the target query and:
   
1. Performs linguistic-based filtering of the sentences,
2. Uses BERT to predict the most likely words for the masked target query.
    
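The outputs are two additional files: `{dataset}_{query}_synparsed.pkl` (the filtered, syntactically parsed sentences) and `{dataset}_{query}_synparsed_pred_bert.pkl` (the same sentences with the BERT predictions added).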

In [None]:
from pathlib import Path
import pandas as pd
import spacy
from utils import explore_preds, prepare_sents
import pandas as pd
from gensim.models import Word2Vec
from scipy.spatial.distance import cosine
import string
from nltk.corpus import stopwords

In [None]:
# Names of the datasets to process. In our experiments these were
# "jsa", "hmd" and "blb".
datasets = [
    "example",
]

In [None]:
# Load the small English spaCy pipeline with its named-entity recogniser
# disabled; `nlp` is used to parse the sentences in the linguistic-filtering step.
nlp = spacy.load(name="en_core_web_sm", disable=["ner"])

In [None]:
# Target word to analyse; change it to see the results for a different
# target word. The year bounds presumably delimit the period of interest.
# NOTE(review): min_year / max_year are not used in the cells shown here.
query = "machine"
min_year, max_year = 1783, 1908

In [None]:
# Display name used for each query in the output file.
generic = dict(machine="machine")

In [None]:
# Surface forms each query is expanded to (here: singular and plural).
query_tokens = {
    "machine": ["machine", "machines"],
}

In [None]:
#### ----------------------------------
#### Linguistic filtering
# For each dataset: read the sentences containing the query, optionally
# downsample them, normalise and parse them with spaCy, keep only those that
# pass the syntactic filter, record the dependency label assigned to the
# query, and save the result as `{dataset}_{query}_synparsed.pkl`.
print("Linguistic filtering")

# Maximum number of sentences kept per dataset. The cap is not applied to
# the "machine" query, whose sentences are always kept in full.
max_sents = 65000

for dataset in datasets:
    print("*", dataset)
    processed_dir = Path("data") / (dataset + "_processed")
    syndf = pd.read_csv(processed_dir / (dataset + "_" + query + ".tsv"), sep="\t")
    # If we have more than `max_sents` sentences (and the query is not
    # "machine"), downsample to `max_sents`; the fixed seed keeps it reproducible.
    if query != "machine" and syndf.shape[0] > max_sents:
        syndf = syndf.sample(n=max_sents, random_state=42)
    # Use the (original) row index as the sentence ID:
    syndf["sentId"] = list(syndf.index.values)
    # Normalise the sentences, then parse them. Column-wise `Series.apply`
    # (unlike row-wise `DataFrame.apply(axis=1)`) also works on an empty frame.
    syndf["currentSentence"] = syndf["currentSentence"].apply(prepare_sents.remove_punctspaces)
    syndf["synt"] = prepare_sents.preprocess_pipe(syndf["currentSentence"], nlp)
    # Keep only the sentences that pass the syntactic filter. An explicit
    # boolean mask aligned on the index plus `.loc` works for empty frames too,
    # and `.copy()` avoids pandas' SettingWithCopyWarning on the assignment below.
    keep = pd.Series(
        [
            prepare_sents.filter_sents_synt(synt, masked, current, target)
            for synt, masked, current, target in zip(
                syndf["synt"],
                syndf["maskedSentence"],
                syndf["currentSentence"],
                syndf["targetExpression"],
            )
        ],
        index=syndf.index,
        dtype=bool,
    )
    syndf = syndf.loc[keep].copy()
    # Dependency label of the query in each remaining sentence:
    syndf["query_label"] = [
        prepare_sents.find_query_deplabel(synt, masked, target)
        for synt, masked, target in zip(
            syndf["synt"], syndf["maskedSentence"], syndf["targetExpression"]
        )
    ]
    syndf.to_pickle(processed_dir / (dataset + "_" + query + "_synparsed.pkl"))

In [None]:
#### ----------------------------------
#### BERT masking
# For every dataset, load the syntactically filtered sentences saved by the
# linguistic-filtering step and, for each period-specific BERT model, add a
# column `pred_bert_{epoch}` with the model's predictions for the masked query.
# The result is saved next to the input as `..._synparsed_pred_bert.pkl`;
# datasets whose output file already exists are skipped.

print("BERT masking")

# Periods whose BERT models are queried, and the number of expected
# predictions per mask. Both are the same for every dataset, so they are set
# once here rather than inside the loops.
bert_epochs = ["1760_1850", "1890_1900"]
pred_toks = 20

for dataset in datasets:

    dataset_path = "data/" + dataset + "_processed/" + dataset + "_" + query + "_synparsed.pkl"
    # Derive the output path with pathlib rather than `str.split(".pkl")[0]`,
    # which would truncate the path at the first ".pkl" occurring anywhere in it.
    output_path = Path(dataset_path).with_name(Path(dataset_path).stem + "_pred_bert.pkl")

    if output_path.is_file():
        print("*", dataset, "skipped: output already exists at", output_path)
        continue

    print("*", dataset)
    # Load dataframe where to apply this:
    pred_df = pd.read_pickle(dataset_path)
    for epoch in bert_epochs:

        print("*", epoch)

        # Create pipeline depending on the BERT model of the specified period
        # and the number of expected predictions:
        model_rd = explore_preds.create_mask_pipeline(epoch, pred_toks)

        # Use BERT to find most likely predictions for a mask. `apply` runs
        # immediately, so the lambda always uses this iteration's model_rd.
        pred_df["pred_bert_" + epoch] = pred_df.apply(
            lambda x: explore_preds.bert_masking(x, model_rd), axis=1
        )

    pred_df.to_pickle(output_path)