In [43]:
# Standard library
import logging

import os
from pathlib import Path
import shlex
import time
import tqdm

# Third party
import hydra
import jsonlines
import rich
import transformers

# Setup Logging
# Module-level logger for this notebook; the "common_retriever" logger is
# restricted to WARNING and above.
LOGGER = logging.getLogger(__name__)
logging.getLogger("common_retriever").setLevel(logging.WARNING)

# Paths (absolute and machine-specific).
# BASE_PATH: directory the notebook changes into below.
# CONF_PATH: config directory handed to hydra.initialize_config_dir.
# INPUTS: directory holding the nq_retrievals_*.jsonl files; the
#         data_with_context/ output files are written beneath it as well.
# OUT_FILE: value passed to hydra as the `out_file` override.
BASE_PATH = Path("/home/mila/g/gagnonju/DPR/")
CONF_PATH = BASE_PATH/"conf"
INPUTS = Path("/home/mila/g/gagnonju/IteratedDecoding/outputs/")
OUT_FILE = Path("/home/mila/g/gagnonju/DPR/outputs/integrated_script_attempt/")

# First Party
# NOTE(review): the import is deliberately placed after the chdir, presumably
# so that common_retriever resolves from BASE_PATH via the working directory
# -- TODO confirm. Keep these two statements in this order.
os.chdir(BASE_PATH)
import common_retriever

In [32]:
# Command-line style overrides applied on top of the "dense_retriever" config.
hydra_overrides = [
    f"out_file={shlex.quote(str(OUT_FILE))}",
    "batch_size=1024",
]

with hydra.initialize_config_dir(config_dir=str(CONF_PATH)):
    cfg = hydra.compose(config_name="dense_retriever", overrides=hydra_overrides)

    # Passages are loaded while the hydra context is still active.
    all_passages, id_prefixes = common_retriever.load_passages(cfg)

See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
[140592103552832] 2021-09-10 00:10:13,298 [INFO] root: args.local_rank -1
[140592103552832] 2021-09-10 00:10:13,299 [INFO] root: WORLD_SIZE None
[140592103552832] 2021-09-10 00:10:13,426 [INFO] root: Initialized host cn-c005 as d.rank -1 on device=cuda, n_gpu=1, world size=1
[140592103552832] 2021-09-10 00:10:13,427 [INFO] root: 16-bits training: False 
[140592103552832] 2021-09-10 00:10:13,438 [INFO] dpr.data.download_data: Requested resource from https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
[140592103552832] 2021-09-10 00:10:13,439 [INFO] dpr.data.download_data: Download root_dir /home/mila/g/gagnonju/DPR
[140592103552832] 2021-09-10 00:10:13,441 [INFO] dpr.data.download_d

In [82]:
# Dataset splits to look for under INPUTS.
SETS = ["nq_test", "nq_dev", "nq_train"]

# Maps split name -> list of records read from its retrieval file.
# Splits whose file is absent are skipped without error.
data = {}
for set_name in SETS:
    retrievals_path = INPUTS/f"nq_retrievals_{set_name}.jsonl"
    if not os.path.exists(retrievals_path):
        continue
    with jsonlines.open(retrievals_path) as reader:
        data[set_name] = [record for record in reader]

# Report how many records were loaded for each split that was found.
for loaded_name, records in data.items():
    print(f"'{loaded_name}': {len(records)}")

'nq_test': 3610
'nq_dev': 8757
'nq_train': 79168


In [83]:
# Number of retrieved passages appended to each question.
USE_N_CONTEXTS = 5

# Only CRITICAL messages from transformers.configuration_utils are let through.
logging.getLogger("transformers.configuration_utils").setLevel(logging.CRITICAL)
tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/bart-large")

# The separator token is identical for every example, so look it up once
# instead of on every loop iteration.
SEP = tokenizer.sep_token

# open() does not create missing parent directories; make sure the target
# directory exists so the cell also works on a fresh output tree.
out_dir = INPUTS/"data_with_context"
out_dir.mkdir(parents=True, exist_ok=True)

# For each loaded split, write one line per entry:
#   question <sep> context_1 <sep> ... <sep> context_N
# where the contexts are the texts of the first USE_N_CONTEXTS passages
# referenced by the entry's "indices" field.
for set_name, set_ in data.items():
    path = out_dir/f"with_context_{set_name}_all.source"
    # Explicit UTF-8 so the written bytes do not depend on the process locale.
    with open(path, "w", encoding="utf-8") as fout:
        for entry in tqdm.tqdm(set_, desc=set_name):
            question = entry["question"]
            contexts = [
                all_passages[index].text
                for index in entry["indices"][:USE_N_CONTEXTS]
            ]
            fout.write(question + SEP + SEP.join(contexts) + "\n")

[140592103552832] 2021-09-10 01:29:19,027 [INFO] transformers.tokenization_utils_base: loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json from cache at /home/mila/g/gagnonju/.cache/torch/transformers/1ae1f5b6e2b22b25ccc04c000bb79ca847aa226d0761536b011cf7e5868f0655.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b
[140592103552832] 2021-09-10 01:29:19,030 [INFO] transformers.tokenization_utils_base: loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt from cache at /home/mila/g/gagnonju/.cache/torch/transformers/f8f83199a6270d582d6245dc100e99c4155de81c9745c6248077018fe01abcfb.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda
100%|██████████| 3610/3610 [00:13<00:00, 266.90it/s]
100%|██████████| 8757/8757 [00:26<00:00, 326.85it/s]
100%|██████████| 79168/79168 [03:45<00:00, 350.98it/s]


In [95]:
# Shell escape: list everything under the data_with_context output directory.
!find $HOME/IteratedDecoding/outputs/data_with_context/

/home/mila/g/gagnonju/IteratedDecoding/outputs/data_with_context/
/home/mila/g/gagnonju/IteratedDecoding/outputs/data_with_context/with_context_nq_dev_all.source
/home/mila/g/gagnonju/IteratedDecoding/outputs/data_with_context/with_context_nq_test_all.source
/home/mila/g/gagnonju/IteratedDecoding/outputs/data_with_context/with_context_nq_train_all.source
