In [1]:
import os
import sys

# These environment variables must be set BEFORE the libraries that read them are
# imported: `datasets` resolves HF_HOME into its cache paths at import time, so
# assigning it afterwards has no effect on where the cache lives. Pinning
# CUDA_VISIBLE_DEVICES up front likewise guarantees it is in place before any
# library initialises CUDA.
os.environ["HF_HOME"] = "/projects/ogma2/users/andrewsi/cache/huggingface/"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Root of the project checkout; the cells below build their file paths from it.
root = "/projects/ogma2/users/andrewsi/control-data2text"

# Make the project-local retriever modules importable.
sys.path.append("../retrievers")

import re

import datasets
import numpy as np
from nltk import word_tokenize
from tqdm import tqdm

from totto_retriever import TottoRetriever
from retriever import Retriever

In [3]:
# Locations of the ToTTo retrieval datasets and the validation JSON, all under `root`.
# `val_dataset` and `totto_val_json` are not referenced by any cell visible in this notebook.
train_dataset = f"{root}/prototype-retrieval/datasets/totto/totto_train_v2"
val_dataset = f"{root}/prototype-retrieval/datasets/totto/totto_validation"
totto_val_json = f"{root}/transformers/examples/seq2seq/test_data/totto/val_with_headers_only.json"

# Build the retriever from the training dataset path.
# NOTE(review): several later cells show progress bars with max=7700 rather than the
# 120761 rows reported for this dataset, so `retriever` was probably re-created from a
# different path in a cell that is no longer present — confirm before re-running.
retriever = TottoRetriever(train_dataset)

In [6]:
# Display the dataset held by the retriever (feature names and row count).
retriever.dataset

Dataset({
    features: ['clean_source', 'clean_source_embed', 'edit_dist_map', 'edit_dist_map_no_mask', 'masked_target', 'masked_target_embed', 'source', 'source_headers_only', 'source_headers_only_embed', 'split_masked_target', 'split_target', 'target', 'target_embed'],
    num_rows: 120761
})

In [2]:
# Second retriever built from the "trim" training set, addressed relative to the
# notebook's working directory rather than through `root`.
trim_retriever = TottoRetriever("../datasets/totto/totto_train_trim")

In [3]:
# Display the trimmed dataset (feature names and row count).
trim_retriever.dataset

Dataset({
    features: ['clean_source_embed', 'source', 'target'],
    num_rows: 120761
})

In [4]:
# Write the training JSON with retrieval_k=5, using the 'clean_source_embed' column for retrieval.
# The recorded run took about 4h39m for 120761 examples, so do not re-run this casually.
# NOTE(review): the meaning of retrieval_map=None is defined inside TottoRetriever.write_train_set — confirm there.
trim_retriever.write_train_set(f"{root}/transformers/examples/seq2seq/test_data/totto_proto/train_k5_clean_source_embed.json", retrieval_map=None, retrieval_embed="clean_source_embed", retrieval_k=5)

HBox(children=(FloatProgress(value=0.0, max=121.0), HTML(value='')))

  0%|          | 0/120761 [00:00<?, ?it/s]
100%|██████████| 120761/120761 [4:39:05<00:00,  7.21it/s]
Wrote training data to file: /projects/ogma2/users/andrewsi/control-data2text/transformers/examples/seq2seq/test_data/totto_proto/train_k5_clean_source_embed.json


In [4]:
# Predictions file produced by the baseline model checkpoint.
proto_txt = f"{root}/transformers/examples/seq2seq/exp/totto/totto_t5_small_new_parent/validation_results/checkpoint-64158/preds.txt"

# Read one whitespace-stripped prediction per line. The context manager closes the
# handle once the lines are loaded instead of leaving it open for the kernel's lifetime.
with open(proto_txt, "r") as proto_f:
    proto_lines = [line.strip() for line in proto_f]

In [9]:
def add_baseline_pred(example, idx):
    """Return the new-column dict holding the baseline prediction for row ``idx``.

    The prediction is looked up by position in the notebook-level ``proto_lines``
    list; ``example`` itself is not inspected.
    """
    prediction = proto_lines[idx]
    return dict(baseline_pred=prediction)


In [12]:
# Add a 'baseline_pred' column to the retriever's dataset.
# with_indices=True makes map() pass the row index as the second argument of the callback.
# NOTE(review): assumes proto_lines has one line per dataset row, in the same order — confirm.
retriever.dataset = retriever.dataset.map(add_baseline_pred, with_indices=True)

HBox(children=(FloatProgress(value=0.0, max=7700.0), HTML(value='')))




In [22]:
# Compute embeddings for the 'baseline_pred' column, passing gpu=0.
# NOTE(review): presumably stored as 'baseline_pred_embed', following the '<column>_embed'
# names seen in the dataset repr — confirm in TottoRetriever.add_embeds.
retriever.add_embeds("baseline_pred", gpu=0)

HBox(children=(FloatProgress(value=0.0, max=7700.0), HTML(value='')))




In [5]:
# Write an eval JSON with query_embed="random", retrieving from the training dataset with eval_k=10.
# NOTE(review): "random" presumably selects prototypes at random as a baseline — confirm in TottoRetriever.write_eval_set.
retriever.write_eval_set(f"{root}/transformers/examples/seq2seq/test_data/totto_proto/val_from_train_random.json", query_embed="random", retrieval_path=train_dataset, eval_k=10)

100%|██████████| 7700/7700 [00:37<00:00, 203.09it/s]
Wrote eval data to file: /projects/ogma2/users/andrewsi/control-data2text/transformers/examples/seq2seq/test_data/totto_proto/val_from_train_random.json


In [6]:
# Build an eval JSON from three paths: the baseline model's preds.txt, the ToTTo
# validation JSON, and the output file.
# NOTE(review): presumably each prediction line becomes the prototype for the matching
# validation example — confirm in TottoRetriever.write_eval_set_from_protos.
retriever.write_eval_set_from_protos(f"{root}/transformers/examples/seq2seq/exp/totto/totto_t5_small_new_parent/validation_results/checkpoint-64158/preds.txt", f"{root}/transformers/examples/seq2seq/test_data/totto/validation.json", f"{root}/transformers/examples/seq2seq/test_data/totto_proto/val_from_baseline_preds.json")

Wrote eval data to file: /projects/ogma2/users/andrewsi/control-data2text/transformers/examples/seq2seq/test_data/totto_proto/val_from_baseline_preds.json


In [5]:
# Grab the 'source' field of the first row as a sample for the regex experiments below.
source = retriever.dataset[0]["source"]

In [21]:
# Display the sample source string with its <tag> ... </tag> markup.
source

'<page_title> List of 8/9 PM telenovelas of Rede Globo </page_title> <section_title> 2000s </section_title> <table> <cell> A Favorita <col_header> Title </col_header> </cell> </table>'

In [20]:
# Extract every text span that sits between "> " and " <", i.e. the content between tags.
re.findall("> ([^<>]+?) <", source)

['List of 8/9 PM telenovelas of Rede Globo', '2000s', 'A Favorita', 'Title']

In [10]:
def clean_source(example):
    """Flatten the tagged ``source`` string of ``example`` into comma-separated text.

    Every run of text found between ``"> "`` and ``" <"`` is kept and the tags are
    dropped. Returns a one-key dict, ``{"clean_source": <joined text>}``.
    """
    tag_free_spans = re.findall("> ([^<>]+?) <", example["source"])
    return {"clean_source": ", ".join(tag_free_spans)}

In [6]:
def clean_headers(example):
    """Flatten the tagged ``source_headers_only`` string into comma-separated text.

    Every run of text found between ``"> "`` and ``" <"`` is kept and the tags are
    dropped. Returns a one-key dict, ``{"clean_source_headers_only": <joined text>}``.
    """
    header_spans = re.findall("> ([^<>]+?) <", example["source_headers_only"])
    return {"clean_source_headers_only": ", ".join(header_spans)}

In [7]:
# Add a 'clean_source_headers_only' column by applying clean_headers to every row.
retriever.dataset = retriever.dataset.map(clean_headers)

HBox(children=(FloatProgress(value=0.0, max=7700.0), HTML(value='')))




In [13]:
# Compute embeddings for the 'target' column, passing gpu=0.
# NOTE(review): presumably stored as 'target_embed', the column name a later
# write_eval_set call passes as retrieval_embed — confirm in TottoRetriever.add_embeds.
retriever.add_embeds("target", gpu=0)

HBox(children=(FloatProgress(value=0.0, max=7700.0), HTML(value='')))




In [9]:
# Pull the cleaned source of the second row for a quick look.
src = retriever.dataset[1]["clean_source"]

In [34]:
# Persist the retriever's dataset.
# NOTE(review): the destination is decided inside TottoRetriever.save_dataset — confirm
# it will not overwrite a dataset you still need.
retriever.save_dataset()

In [None]:
# Write an eval JSON using write_eval_set's default arguments.
# The path is built from `root`, like the other cells, instead of repeating the
# absolute prefix; the resulting path is unchanged.
retriever.write_eval_set(f"{root}/transformers/examples/seq2seq/test_data/totto_proto/val_from_train_headers_only.json")

In [5]:
# Pull the raw source and the target of the first row for inspection.
src = retriever.dataset[0]["source"]
tgt = retriever.dataset[0]["target"]

In [None]:
# Settings for this training-file variant; they are passed to write_train_set and also
# encoded in the output file name.
retrieval_k = 10
weighted = False
max_edit_dist = 5

# The file name gets a "_max<N>" suffix when max_edit_dist is truthy and a "_weighted"
# suffix when weighted is True. The path is built from `root`, like the other cells,
# instead of repeating the absolute prefix; the resulting path is unchanged.
retriever.write_train_set(
    f"{root}/transformers/examples/seq2seq/test_data/totto_proto/train_k{retrieval_k}{'_max' + str(max_edit_dist) if max_edit_dist else ''}{'_weighted' if weighted else ''}.json",
    retrieval_k=retrieval_k,
    max_edit_dist=max_edit_dist,
    weighted=weighted,
)

In [14]:
# Write an eval JSON that retrieves from the training dataset via the 'target_embed'
# column, with eval_k=10. The recorded run took about 21 minutes for 7700 examples.
# The path is built from `root`, like the other cells, instead of repeating the
# absolute prefix; the resulting path is unchanged.
retriever.write_eval_set(f"{root}/transformers/examples/seq2seq/test_data/totto_proto/val_from_train_target_embed.json", retrieval_embed="target_embed", retrieval_path=train_dataset, eval_k=10)

HBox(children=(FloatProgress(value=0.0, max=121.0), HTML(value='')))

  0%|          | 0/7700 [00:00<?, ?it/s]
100%|██████████| 7700/7700 [21:23<00:00,  6.00it/s]
Wrote eval data to file: /projects/ogma2/users/andrewsi/control-data2text/transformers/examples/seq2seq/test_data/totto_proto/val_from_train_target_embed.json


In [None]:
# Row index of the query example to inspect.
idx = 1502

# Fetch the 30 nearest neighbours of the query row's 'source_headers_only_embed' vector.
# NOTE(review): `val_data` is not defined in any cell visible here; presumably a
# datasets.Dataset with an index on this column — confirm, as this fails on a fresh kernel.
results = val_data.get_nearest_examples("source_headers_only_embed", np.array(val_data[idx]["source_headers_only_embed"], dtype=np.float32), k=30)

# Print the query's own target, then a separator line.
print(val_data[idx]["target"])
print("===================================")

# Print each neighbour's target, preceded in parentheses by results[0][i]
# (presumably the retrieval score — confirm).
for i, target in enumerate(results[1]["target"]):
        print("({}) {}".format(results[0][i], target))

In [None]:
def add_protos(proto_path, val_file, out_json=None):
    """Prepend a prototype line to each validation example and write JSON lines.

    The i-th line of ``proto_path`` is paired with the i-th example of ``val_file``.
    Each output record is ``{"source": "<proto> [SEP] <source>", "target": <target>}``.

    Args:
        proto_path: Path to a text file with one prototype per line. Lines are
            whitespace-stripped; empty lines are kept so positions stay aligned.
        val_file: Either a path to a JSON-lines file or an iterable of JSON-encoded
            lines (for example an open file handle). Every record needs "source"
            and "target" keys. Whitespace-only lines are skipped.
        out_json: Path of the JSON-lines file to write.

    Raises:
        ValueError: If ``out_json`` is not given, or if the number of prototypes
            differs from the number of validation examples.
    """
    import json  # local import: the notebook's import cell does not provide json

    if out_json is None:
        raise ValueError("out_json must be given: the path of the JSON-lines file to write")

    # Close the prototype file as soon as its lines are read.
    with open(proto_path, "r") as proto_file:
        proto_lines = [line.strip() for line in proto_file]

    # Accept a path as well as an already-open handle or list of lines.
    if isinstance(val_file, (str, os.PathLike)):
        with open(val_file, "r") as val_handle:
            val_lines = [line for line in val_handle if line.strip()]
    else:
        val_lines = [line for line in val_file if line.strip()]

    # Pairing is positional, so a count mismatch would silently misalign the data.
    if len(proto_lines) != len(val_lines):
        raise ValueError(
            f"Prototype/validation count mismatch: {len(proto_lines)} prototypes "
            f"vs {len(val_lines)} validation examples"
        )

    examples = []
    for proto, line in zip(proto_lines, val_lines):
        val_example = json.loads(line)
        examples.append({
            "source": proto + " [SEP] " + val_example["source"],
            "target": val_example["target"],
        })

    with open(out_json, "w+") as f:
        for example in examples:
            f.write(json.dumps(example) + "\n")
    print(f"Wrote training data to file: {out_json}")