In [3]:
# Standard library
import os
from pathlib import Path
import re
import sys

# Third party
import hydra
import rich

# First Party
BASE_PATH = Path("/home/mila/g/gagnonju/DPR/")
CONF_PATH = BASE_PATH/"conf"

os.chdir(BASE_PATH)
import dense_retriever
import jules_validate_dense_retriever
from dense_retriever import *
import common_retriever

Importing torch...
Done importing torch.
Importing transformers...
Done importing transformers.
Importing tqdm...
Done importing tqdm.
Importing spacy...
Done importing spacy.
Importing hydra...
Done importing hydra.
Importing omegaconf...
Done importing omegaconf.


In [4]:
###########################################################################
# Resets the passage cache, which is the longest load time of the 
# script at 7 min
###########################################################################
all_passages = None

In [5]:
index = None
retriever = None

In [8]:
@hydra.main(config_path=CONF_PATH, config_name="dense_retriever")
def main(cfg):
    ###########################################################################
    # Complete and validate CFG
    ###########################################################################
    jules_validate_dense_retriever.validate(
        {k: getattr(cfg, k) for k in dir(cfg)}, 
        dense_retriever.SCHEMA_PATH,
    )
    cfg = dense_retriever.setup_cfg_gpu(cfg)

    assert cfg.out_file, cfg.out_file
    assert Path(cfg.out_file).parent.exists(), cfg.out_file

    logger.info("CFG (after gpu  configuration):")
    logger.info("%s", OmegaConf.to_yaml(cfg))


    ###########################################################################
    # Prepare models
    ###########################################################################
#     saved_state = load_states_from_checkpoint(cfg.model_file)
#     set_cfg_params_from_state(saved_state.encoder_params, cfg)

#     tensorizer, encoder, _ = init_biencoder_components(
#         cfg.encoder.encoder_model_type, cfg, inference_only=True
#     )

#     encoder_path = cfg.encoder_path
#     if encoder_path:
#         logger.info("Selecting encoder: %s", encoder_path)
#         encoder = getattr(encoder, encoder_path)
#     else:
#         logger.info("Selecting standard question encoder")
#         encoder = encoder.question_model

#     encoder, _ = setup_for_distributed_mode(
#         encoder, 
#         None, 
#         cfg.device, 
#         cfg.n_gpu, 
#         cfg.local_rank, 
#         cfg.fp16
#     )
#     encoder.eval()

#     # load weights from the model file
#     model_to_load = get_model_obj(encoder)
#     logger.info("Loading saved model state ...")

#     encoder_prefix = (
#         encoder_path if encoder_path else "question_model") + "."
#     prefix_len = len(encoder_prefix)

#     logger.info("Encoder state prefix %s", encoder_prefix)
#     question_encoder_state = {
#         key[prefix_len:]: value
#         for (key, value) in saved_state.model_dict.items()
#         if key.startswith(encoder_prefix)
#     }
#     # TODO: long term HF state compatibility fix
#     model_to_load.load_state_dict(question_encoder_state, strict=False)
#     vector_size = model_to_load.get_out_size()
#     logger.info("Encoder vector_size=%d", vector_size)


    ###########################################################################
    # Prepare sources
    ###########################################################################
#     rich.print("[red bold]Starting the thing.")
#     id_prefixes = []
#     ctx_sources = []
#     for ctx_src in cfg.ctx_datatsets:
#         ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src])
#         id_prefixes.append(ctx_src.id_prefix)
#         ctx_sources.append(ctx_src)
    
#     rich.print(ctx_sources)
#     rich.print("[red bold]Second part.")
    
#     global all_passages
#     if all_passages is None:
#         all_passages = {}
#         for ctx_src in ctx_sources:
#             ctx_src.load_data_to(all_passages)
#             rich.print("[green]Done loading passages.")
#         print(len(all_passages))

    ###########################################################################
    # Load Index & Prepare retriever
    ###########################################################################

#     index_path = cfg.index_path
#     #------------
#     ## Instantiate the index and create a retriever
#     #------------
#     global index
#     global retriever
#     if index is None or retriever is None:
#         rich.print(f"[bold blue]Loading index.")
#         index = hydra.utils.instantiate(cfg.indexers[cfg.indexer])
#         logger.info("Index class %s ", type(index))
#         index_buffer_sz = index.buffer_size
#         index.init_index(vector_size)
#         rich.print(f"[bold blue]Done loading index.")
#         rich.print(f"[bold blue]Loading retriever.")
#         retriever = LocalFaissRetriever(
#             encoder, 
#             cfg.batch_size, 
#             tensorizer, 
#             index,
#         )
#         rich.print(f"[bold blue]Loaded retriever.")
#         if index_path and index.index_exists(index_path):
#             logger.info("Index path: %s", index_path)
#             retriever.index.deserialize(index_path)
#         else:
#             logger.info("Reading all passages data from files: %s", input_paths)
#             retriever.index_encoded_data(
#                 input_paths, 
#                 index_buffer_sz, 
#                 path_id_prefixes=path_id_prefixes,
#             )
#             if index_path:
#                 retriever.index.serialize(index_path)
#     else:
#         index_buffer_sz = index.buffer_size
#         rich.print(f"[bold green]Using cached index.")
#         rich.print(f"[bold green]Using cached retriever.")
        

    #------------
    ## Index all passages
    #------------
#     ctx_files_patterns = cfg.encoded_ctx_files

#     logger.info("ctx_files_patterns: %s", ctx_files_patterns)
#     if ctx_files_patterns:
#         assert len(ctx_files_patterns) == len(
#             id_prefixes
#         ), "ctx len={} pref leb={}".format(
#             len(ctx_files_patterns), 
#             len(id_prefixes),
#         )
#     else:
#         assert (
#             index_path
#         ), "Either encoded_ctx_files or index_path parameter should be set."

#     input_paths = []
#     path_id_prefixes = []
#     for i, pattern in enumerate(ctx_files_patterns):
#         pattern_files = glob.glob(pattern)
#         pattern_id_prefix = id_prefixes[i]
#         input_paths.extend(pattern_files)
#         path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files))

#     logger.info("Embeddings files id prefixes: %s", path_id_prefixes)
    
       
    ###########################################################################
    # Prepare questions and answers
    ###########################################################################
    questions = []
    question_answers = []

    if not cfg.qa_dataset:
        logger.warning("Please specify qa_dataset to use")
        return

    ds_key = cfg.qa_dataset
    logger.info("qa_dataset: %s", ds_key)

    qa_src = hydra.utils.instantiate(cfg.datasets[ds_key])
    qa_src.load_data()
    assert not qa_src.selector, qa_src.selector
    logger.info("Using custom representation token selector")
    retriever.selector = qa_src.selector

    logger.info("id_prefixes per dataset: %s", id_prefixes)

    for ds_item in qa_src.data:
        question, answers = ds_item.query, ds_item.answers
        questions.append(question)
        question_answers.append(answers)

    
    ###########################################################################
    # Get top k results.
    ###########################################################################
    logger.info(
        "Using special token %s", 
        qa_src.special_query_token,
    )
    questions_tensor = retriever.generate_question_vectors(
        questions, 
        query_token=qa_src.special_query_token,
    )
    
    rich.print(
        f"[bold]get_top_docs: Starting. Approx 7 min. {timestamp()}"
    )
    top_ids_and_scores = retriever.get_top_docs(
        questions_tensor.numpy(), 
        cfg.n_docs,
    )
    rich.print("[bold green]get_top_docs: Done.")

    # we no longer need the index
    retriever = None

    if len(all_passages) == 0:
        raise RuntimeError(
            "No passages data found. Please specify "
            "ctx_file param properly."
        )

    rich.print("[green bold]All done.")
    
    
    
sys.argv = [
    "fake.py", 
    "out_file=/home/mila/g/gagnonju/DPR/outputs/integrated_script_attempt.py",
]
main()

[2021-09-08 01:04:25,047][root][INFO] - args.local_rank -1
[2021-09-08 01:04:25,048][root][INFO] - WORLD_SIZE None
[2021-09-08 01:04:25,049][root][INFO] - Initialized host cn-a002 as d.rank -1 on device=cuda, n_gpu=1, world size=1
[2021-09-08 01:04:25,050][root][INFO] - 16-bits training: False 
[2021-09-08 01:04:25,051][root][INFO] - CFG (after gpu  configuration):
[2021-09-08 01:04:25,059][root][INFO] - indexers:
  flat:
    _target_: dpr.indexer.faiss_indexers.DenseFlatIndexer
  hnsw:
    _target_: dpr.indexer.faiss_indexers.DenseHNSWFlatIndexer
  hnsw_sq:
    _target_: dpr.indexer.faiss_indexers.DenseHNSWSQIndexer
out_file: /home/mila/g/gagnonju/DPR/outputs/integrated_script_attempt.py
validation_workers: 15
n_gpu: 1
qa_dataset: nq_test
ctx_datatsets:
- dpr_wiki
encoded_ctx_files:
- /home/mila/g/gagnonju/DPR/downloads/data/wikipedia_split/psgs_w100.tsv
match: string
n_docs: 100
batch_size: 128
do_lower_case: true
encoder_path: null
index_path: /home/mila/g/gagnonju/DPR/dpr/downloads

[2021-09-08 01:04:32,633][root][INFO] - ctx_files_patterns: ['/home/mila/g/gagnonju/DPR/downloads/data/wikipedia_split/psgs_w100.tsv']
[2021-09-08 01:04:32,635][root][INFO] - Embeddings files id prefixes: ['']
[2021-09-08 01:04:32,636][root][INFO] - qa_dataset: nq_test
[2021-09-08 01:04:32,651][dpr.data.download_data][INFO] - Requested resource from https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-test.qa.csv
[2021-09-08 01:04:32,652][dpr.data.download_data][INFO] - Download root_dir /home/mila/g/gagnonju/DPR
[2021-09-08 01:04:32,653][dpr.data.download_data][INFO] - File to be downloaded as /home/mila/g/gagnonju/DPR/downloads/data/retriever/qas/nq-test.csv
[2021-09-08 01:04:32,654][dpr.data.download_data][INFO] - File already exist /home/mila/g/gagnonju/DPR/downloads/data/retriever/qas/nq-test.csv
[2021-09-08 01:04:32,655][dpr.data.download_data][INFO] - Loading from https://dl.fbaipublicfiles.com/dpr/nq_license/LICENSE
[2021-09-08 01:04:32,656][dpr.data.download_data][INFO] - File

[2021-09-08 01:07:40,207][root][INFO] - index search time: 169.900069 sec.


In [None]:
CONCATENATION_TECHNIQUES = dict()
MAX_LOOP_N = 15
MODEL_PATH = None
CONCATENATION_TECHNIQUE = None
BATCH_SIZE_QUESTIONS = 10
OUTPUT_FILES_ROOT = None

make_batches = more_itertools.ichunked


def write_contexts(
    all_contexts, 
    context_ids, 
    output_files_root, 
    loop_i,
):
    with open(
        output_files_root/f"contexts_{loop_i}.txt", 
        "a",
    ) as f_out:
        for context_id in context_ids:
            f_out.write(all_contexts[context_id].strip() + "\n")

def write_generations(
    generated_text,
    output_files_root, 
    loop_i,
):
    with open(
        output_files_root/f"generated_{loop_i}.txt", 
        "a",
    ) as f_out:
        text = "\n".join((x.strip() for x in generated_text))
        f_out.write(text)

def inference(
    all_contexts,
    max_loop_n=MAX_LOOP_N, 
    model_path=MODEL_PATH, 
    concatenation_technique=CONCATENATION_TECHNIQUE,
    output_files_root=OUTPUT_FILES_ROOT,
):
    """
    We currently do 
    
    for batch_questions in questions:
        for loop_i in range(max_loop_n)
            encode_context_to_gen
            (load generator to faster memory)
            generate(batch)
            decode_text_from_gen
            
            encode_text_to_retriever
            (load retriever to faster memory)
            retrieve(batch)
            decode_text_from_retriever
    
    We could do
    
    for loop_i in range(max_loop_n)
        # Parallelize as needed if helpful
        # num_questions x num_beams to do
        for batch_questions in zip(
            retrieved_contexts
        ):
            encode_contexts_to_gen
        
        (imaginary barrier)
        (load generator to faster memory)    
        for batch_questions in questions:
            top_beam_ids, top_beam_ppls = generate(batch)
            
        # Parallelize as needed if helpful
        for batch_questions in questions:
            decode_text_from_gen
            
        # Parallelize as needed if helpful
        for batch_questions in questions:
            encode_text_to_retriever

        (imaginary barrier)
        # Parallelize as needed if helpful
        for batch_questions in questions:
            concatenate_contexts

        (imaginary barrier)
        (load retriever to faster memory)
        for batch_questions in questions:            
            retrieve(batch)
        
        # Parallelize as needed if helpful
        for batch_questions in questions:            
            decode_text_from_retriever
    
    Much better for GPU memory locality,
    worse for total memory use. Would maybe allow
    for slightly larger batches.
    
    I think that the fact that we don't need the 
    GPU results right away makes it async and faster
    
    """
    output_files_root = pathlib.Path(output_files_root)
    model = load_model(model_path)
    concatenation_technique = CONCATENATION_TECHNIQUES[
        concatenation_technique
    ]
    question_batches = make_batches(questions)
    
    context_ids_accum = [[] for _ in range(max_loop_n)]
    generation_text_accum = [[] for _ in range(max_loop_n)]
    
    for question_batch in question_batches:
        retrieval_query = question
        all_generations = []
        contexts = retrieve(retrieval_query)

        for loop_i in range(max_loop_n):
            top_beams_ids, top_beams_ppl = generate(
                question, contexts,
            )
            
            # Detokenize top_beams_ids.
            # We do this here because we don't have the choice?
            # Maybe we do though.
            # batch_size x num_beams x (variable num_words)
            top_beams_text = []
            for unit in top_beams_ids:
                top_beams_text.append([
                    tokenizer_generator.decode(ids_beam) 
                    for ids_beam in unit
                ])
            
            # We likely have to decode and re-encode here
            retrieval_query = concatenation_technique(
                question, 
                top_beams_text,
                top_beams_ppl,
                all_results, 
                loop_i,
            )
            
            context_ids = retrieve(retrieval_query)            
            generation_text_accum[loop_i].append(top_beams_text)
            context_ids_accum[loop_i].append(contexts)
    
    
    # The following are bad, 
    context_text_accum = [[] for _ in range(max_loop_n)]

    
    for i, batch in enumerate(context_ids_accum):
        for id_ in batch:
            context_text_accum[i].append(all_contexts[id_])
    
    # [Validate somewhere]
    
#             write_contexts(
#                 all_contexts, 
#                 context_ids, 
#                 output_files_root, 
#                 loop_i,
#             )
            
#             write_generations(
#                 generated_text,
#                 output_files_root,
#                 loop_i,
#             )
            
        