In [1]:
#####################################################################################
# Make a few basic checks first
#####################################################################################
# Check that we are using the correct venv
import sys
assert "condaless" in sys.executable, sys.executable

# Check that we are using the correct version of transformers
print("importing transformers")
import transformers
v = tuple(transformers.__version__.strip().split(".")[:2])
assert v == ("2", "11"), v

importing transformers


In [2]:
# Standard library
import argparse
import collections
import contextlib
import copy
import importlib
import json
import logging
import numpy as np
import os
from pathlib import Path
import re
import shlex
import sys
import time
from typing import *

# Third party
import colorama
import faiss
import hydra
import more_itertools
import jsonlines
import omegaconf
import rich
import torch
import transformers
import tqdm

# First Party
ROOT_PATH = Path("/home/mila/g/gagnonju/IteratedDecoding/")
GAR_PATH = ROOT_PATH/"GAR/gar"
sys.path.insert(0, str(GAR_PATH))

import common_retriever
import dense_retriever
import train_generator
import utils_gen

LOGGER = logging.getLogger(__name__)

def format_dict_default(obj):
    if isinstance(obj, omegaconf.DictConfig):
        obj = omegaconf.OmegaConf.to_container(obj)
        return obj
    
    if isinstance(obj, Path):
        return str(obj)
    
    else:
        rich.print(f"[blue bold]Failed with {type(obj)}")
        return str(obj)

def convert_path(obj):
    assert isinstance(obj, Path)
    return str(obj)

def format_dict(d: dict):
    return json.dumps(d, indent=2, sort_keys=True, default=convert_path)

def save_json(obj, path, *args, **kwargs):
    with open(path, "w") as fout:
        json.dump(obj, fout, *args, **kwargs)
        
def load_json(path, *args, **kwargs):
    with open(path) as fin:
        return json.load(fin, *args, **kwargs)        


[139995926800000] 2021-09-30 05:05:33,586 [INFO] common_retriever: Checking versions...
[139995926800000] 2021-09-30 05:05:33,587 [INFO] common_retriever: All version checks passed.


In [3]:
format_info = "[%(levelname)s] (%(asctime)s) {%(module)s.%(funcName)s:%(lineno)d}:\n"
format = (
    colorama.Fore.CYAN + 
    format_info + 
    colorama.Style.RESET_ALL + 
    "%(message)s"
)
logging.basicConfig(
    format=format,
    level=logging.INFO,
    force=True,
)
LOGGER.info("test")
common_retriever = importlib.reload(common_retriever)

[36m[INFO] (2021-09-30 05:05:37,329) {3464324012.<module>:13}:
[0mtest
[36m[INFO] (2021-09-30 05:05:37,338) {common_retriever.<module>:58}:
[0mChecking versions...
[36m[INFO] (2021-09-30 05:05:37,339) {common_retriever.<module>:65}:
[0mAll version checks passed.


In [4]:
##############################################################################
# CONFIG
##############################################################################
def build_args(root_path):
    RUN_NAME = "first_test"
    
    SENTENCE_DATA_DIR = root_path / "GAR/data/nq-sentence"
    SENTENCE_MODEL = root_path / "GAR/gar/outputs/sentence_with_context/last.ckpt"

    DATA_DIR = SENTENCE_DATA_DIR
    DPR_CONF_PATH = ROOT_PATH / "DPR/conf"
    QUERY_AUG_MODEL_PATH = SENTENCE_MODEL
    READER_MODEL_PATH = "/home/mila/g/gagnonju/IteratedDecoding/GAR/gar/outputs/answer_with_context/last.ckpt"

    DATALOADER_MAX_TARGET_LEN = 0
    DATALOADER_MAX_SOURCE_LEN = 60
    BATCH_SIZE =  10
    MAX_LOOP_N = 15
    N_DOCS = 5
    MAX_TARGET_LEN = 160
    MAX_SOURCE_LEN = 768
    QUERY_AUG_INPUT_MAX_LEN  = 768
    DECODING_CONF = dict(
        num_beams=1,
        max_length=MAX_TARGET_LEN,
        # repetition_penalty=2.5,
        # length_penalty=1.0,
        early_stopping=True,
    )
    OUTPUT_ROOT = ROOT_PATH / "jobs/iterated_decoding_output/"
    assert OUTPUT_ROOT.exists(), OUTPUT_ROOT
    
    OUTPUT_PATH = OUTPUT_ROOT / RUN_NAME
    OUTPUT_PATH.mkdir(exist_ok=True)
    
    try:
        hydra.initialize_config_dir(config_dir=str(DPR_CONF_PATH))
    except ValueError as err:
        # LOGGER.info(err)
        pass
        
    dpr_cfg = hydra.compose(
        config_name="dense_retriever", overrides=[f"out_file=/tmp/"]
    )
    LOGGER.info("DPR_CFG:\n" + format_dict(omegaconf.OmegaConf.to_container(dpr_cfg)))

    args = dict(
        conf_path=DPR_CONF_PATH,
        data_dir= DATA_DIR,
        query_aug_model_path=QUERY_AUG_MODEL_PATH,
        reader_model_path=READER_MODEL_PATH,
        dataloader_max_target_len=DATALOADER_MAX_TARGET_LEN,
        dataloader_max_source_len=DATALOADER_MAX_SOURCE_LEN,
        batch_size=BATCH_SIZE,
        max_loop_n=MAX_LOOP_N,
        n_docs=N_DOCS,
        max_source_len=MAX_SOURCE_LEN,
        max_target_len=MAX_TARGET_LEN,
        query_aug_input_max_len=QUERY_AUG_INPUT_MAX_LEN,
        decoding_conf=DECODING_CONF,
        output_path=OUTPUT_PATH,
    )
    LOGGER.info("Args:\n" + format_dict(args))
    

    json_output_config = dict(
        indent=2,
        default=convert_path,
        sort_keys=True,
    )
    
    save_json(
        args, 
        OUTPUT_PATH / "args.json", 
        **json_output_config
    )
    save_json(
        omegaconf.OmegaConf.to_container(dpr_cfg), 
        OUTPUT_PATH / "config.json",
        **json_output_config
    )
    
    return argparse.Namespace(**args), dpr_cfg

args, dpr_cfg = build_args(ROOT_PATH)

See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
[36m[INFO] (2021-09-30 05:05:37,670) {1519350923.build_args:45}:
[0mDPR_CFG:
{
  "batch_size": 128,
  "ctx_datatsets": [
    "dpr_wiki"
  ],
  "ctx_sources": {
    "dpr_wiki": {
      "_target_": "dpr.data.retriever_data.CsvCtxSrc",
      "file": "data.wikipedia_split.psgs_w100",
      "id_prefix": ""
    }
  },
  "datasets": {
    "curatedtrec_test": {
      "_target_": "dpr.data.retriever_data.CsvQASrc",
      "file": "data.retriever.qas.curatedtrec-test"
    },
    "nq_dev": {
      "_target_": "dpr.data.retriever_data.CsvQASrc",
      "file": "data.retriever.qas.nq-dev"
    },
    "nq_test": {
      "_target_": "dpr.data.retriever_data.CsvQASrc",
      "file": "data.retriever.qas.nq-test"
  

In [5]:
@contextlib.contextmanager
def time_this(title):
    start = time.monotonic()
    bleu = colorama.Fore.BLUE
    green = colorama.Fore.GREEN
    reset = colorama.Style.RESET_ALL
    LOGGER.info(f"{bleu}Starting:{reset} {title}")
    yield "pizza"
    now = time.monotonic() - start
    LOGGER.info(f"{green}Done:{reset} {title}, {now:0.2f}s")    


def build_retriever(cfg):
    with time_this("common_retriever.load_passages (~6 min)"):
        all_passages, id_prefixes = common_retriever.load_passages(
            cfg,
        )

    with time_this("hydra.compose"):
        cfg = hydra.compose(
            config_name="dense_retriever",
            overrides=[f"out_file=/tmp/"]
        )

    with time_this("common_retriever.make_retriever (~11 min.)"):
        retriever = (
            common_retriever.make_retriever(cfg, id_prefixes)
        )

    with time_this("common_retriever.load_data"):
        questions, question_answers, special_query_token = (
            common_retriever.load_data(cfg)
        )

        n_docs = cfg.n_docs
    return retriever, all_passages, special_query_token

retriever, all_passages, special_query_token = build_retriever(
    dpr_cfg
)

[36m[INFO] (2021-09-30 05:05:37,683) {2289476895.time_this:7}:
[0m[34mStarting:[0m common_retriever.load_passages (~6 min)
[36m[INFO] (2021-09-30 05:05:37,732) {options.setup_cfg_gpu:70}:
[0margs.local_rank -1
[36m[INFO] (2021-09-30 05:05:37,732) {options.setup_cfg_gpu:73}:
[0mWORLD_SIZE None
[36m[INFO] (2021-09-30 05:05:37,733) {options.setup_cfg_gpu:89}:
[0mInitialized host cn-d002 as d.rank -1 on device=cuda, n_gpu=4, world size=1
[36m[INFO] (2021-09-30 05:05:37,733) {options.setup_cfg_gpu:97}:
[0m16-bits training: False 
[36m[INFO] (2021-09-30 05:05:37,734) {common_retriever.load_passages:136}:
[0mload_passages: hydra.utils.instantiate
[36m[INFO] (2021-09-30 05:05:37,735) {common_retriever.load_passages:145}:
[0mload_passages: ctx_src.load_data_to
[36m[INFO] (2021-09-30 05:05:37,745) {download_data.download_resource:412}:
[0mRequested resource from https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
[36m[INFO] (2021-09-30 05:05:37,746) {download_d

In [6]:
def to_gpu(index):
    
    co = faiss.GpuMultipleClonerOptions()
    co.shard = True
    
    LOGGER.info(f"{torch.cuda.device_count()} GPUs")
    
    index = faiss.index_cpu_to_all_gpus(index, co=co)
    return index
    
retriever.index.index = to_gpu(retriever.index.index)

[36m[INFO] (2021-09-30 05:23:11,151) {120009568.to_gpu:6}:
[0m4 GPUs


In [7]:
def build_models(reader_model_path, query_aug_model_path):
    ###############################################################################
    # Load query model
    ###############################################################################
    with time_this("query_aug_model.load_from_checkpoint"):
        query_aug_model = train_generator.SummarizationTrainer.load_from_checkpoint(
            str(query_aug_model_path)
        )

        
    ###############################################################################
    # Load inference model
    ###############################################################################
    with time_this("reader_inference_model.load_from_checkpoint"):
        reader_inference_model = train_generator.SummarizationTrainer.load_from_checkpoint(
            str(reader_model_path)
        )
    return query_aug_model, reader_inference_model

query_aug_model, reader_inference_model = build_models(
    reader_model_path=args.reader_model_path, 
    query_aug_model_path=args.query_aug_model_path,
)

[36m[INFO] (2021-09-30 05:23:32,774) {2289476895.time_this:7}:
[0m[34mStarting:[0m query_aug_model.load_from_checkpoint
[36m[INFO] (2021-09-30 05:24:20,373) {configuration_utils.get_config_dict:265}:
[0mloading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json from cache at /home/mila/g/gagnonju/.cache/torch/transformers/7f6632e580b7d9fd4f611dd96dab877cccfc319867b53b8b72fddca7fd64de5c.8b65d3b9a47e96c1909d807f7e7f41dd1ed95092b139965be7b914aa4fb5fd08
[36m[INFO] (2021-09-30 05:24:20,374) {configuration_utils.from_dict:301}:
[0mModel config BartConfig {
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel",
    "BartForConditionalGeneration",
    "BartForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim

In [8]:
def build_stuff(
    batch_size,
    data_dir,
    max_target_len,
    max_source_len,
):
    SUBSET = "val"

    ###############################################################################
    # Tokenizers
    ###############################################################################
    tokenizer_bart = transformers.AutoTokenizer.from_pretrained(
        "facebook/bart-large"
    )
    tokenizer_bert = transformers.AutoTokenizer.from_pretrained(
        "bert-base-uncased"
    )

    ###############################################################################
    # Build dataloader
    ###############################################################################
    with time_this("Build dataloader"):
        dataset = utils_gen.SummarizationDataset(
            tokenizer_bart, 
            type_path=SUBSET, 
            data_dir=data_dir,
            max_source_length=max_source_len,
            max_target_length=max_target_len,
        )
        dataloader = torch.utils.data.DataLoader(
            dataset, 
            batch_size=batch_size, 
            collate_fn=dataset.collate_fn, 
            shuffle=False, # DO NOT CHANGE THIS!
            num_workers=0,
        )
        
    return     (
        dataloader, 
        tokenizer_bart, 
        tokenizer_bert, 
    )

(
    dataloader, 
    tokenizer_bart, 
    tokenizer_bert, 
) = build_stuff(    
    batch_size=args.batch_size,
    data_dir=args.data_dir,
    max_target_len=args.dataloader_max_target_len,
    max_source_len=args.dataloader_max_source_len,
)

[36m[INFO] (2021-09-30 05:25:50,704) {configuration_utils.get_config_dict:265}:
[0mloading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json from cache at /home/mila/g/gagnonju/.cache/torch/transformers/7f6632e580b7d9fd4f611dd96dab877cccfc319867b53b8b72fddca7fd64de5c.8b65d3b9a47e96c1909d807f7e7f41dd1ed95092b139965be7b914aa4fb5fd08
[36m[INFO] (2021-09-30 05:25:50,706) {configuration_utils.from_dict:301}:
[0mModel config BartConfig {
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel",
    "BartForConditionalGeneration",
    "BartForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopp

loading from /home/mila/g/gagnonju/IteratedDecoding/GAR/data/nq-sentence/val.target.processed (pkl)... make sure data is what you need


[36m[INFO] (2021-09-30 05:25:54,622) {2289476895.time_this:10}:
[0m[32mDone:[0m Build dataloader, 3.29s


In [9]:
# def marginal(contexts, scores, k):
#     total_scores = collections.defaultdict(lambda: 0)
#     for context, score in zip(contexts, scores):
#         total_scores[context] += score
#     unique_contexts, unique_scores = zip(
#         *total_scores.items()
#     )   
#     return unique_contexts.sort(key=unique_scores)[:k]

def decode(
    model, 
    batch, 
    tokenizer, 
    decoding_conf,
):
    batch_ids = batch["input_ids"]
    source_mask = batch["attention_mask"]
    
    source_ids, source_mask = utils_gen.trim_batch(
        batch_ids, tokenizer.pad_token_id, attention_mask=batch["attention_mask"]
    )
    
    args = dict(
        input_ids=source_ids.cuda(),
        attention_mask=source_mask.cuda(),
        **decoding_conf
    )
    generated_ids = model.model.generate(
        **args
    )
    
    return generated_ids


def question_generator(dataloader_, tokenizer_bart):
    for data_dict in tqdm.tqdm(dataloader_):
        question_ids = data_dict["source_ids"]
        question_text = [
            tokenizer_bart.decode(question_id).replace(
                 tokenizer_bart.pad_token, ""   
            )
            for question_id in question_ids
        ]
        yield question_text       


In [10]:
def selection_technique(ids, score):
    return ids

def write_contexts(
    all_contexts: Dict[str, str], 
    context_ids: List[str], 
    output_path: str, 
):
    
    text = []
    for ids_per_retrieval in context_ids:
        text.append([all_contexts[ids] for ids in ids_per_retrieval])
        
    retrieved = dict(
        text=text,
        ids=context_ids
    )          
        
    with jsonlines.open(output_path, "a") as f_out:
        f_out.write(retrieved)

def write_generations(
    generated_text: List[str],
    path: str, 
):
    with jsonlines.open(path, "a") as f_out:
        f_out.write(generated_text)


def inference(
    all_passages: Dict[str, str],
    query_aug_model: train_generator.SummarizationTrainer,
    reader_inference_model: train_generator.SummarizationTrainer,
    special_query_token: str,
    retriever,
    selection_technique: Callable,
    question_dataloader: torch.utils.data.DataLoader,
    max_loop_n: int,
    decoding_conf: dict,
    query_aug_input_max_length: int,
    n_docs: int,
    output_path: str,
): 
    output_path = Path(output_path)
    LOGGER.info("started.")
    LOGGER.info(f"output_path: {output_path}")
    
    context_output_prefix = "contexts_"
    reader_output_prefix = "reader_"
    query_aug_output_prefix = "query_aug_"
    
    for path in output_path.glob(f"{context_output_prefix}*.jsonl"):
        LOGGER.info(f"Deleting path: {path}")
        os.remove(path)
    
    for path in output_path.glob(f"{reader_output_prefix}*.jsonl"):
        LOGGER.info(f"Deleting path: {path}")
        os.remove(path)
        
    for path in output_path.glob(f"{query_aug_output_prefix}*.jsonl"):
        LOGGER.info(f"Deleting path: {path}")
        os.remove(path)        
        
    with torch.inference_mode(True):
        query_aug_text_accum = []
        sep_counter = re.compile(tokenizer_bart.sep_token)

        for loop_i in range(max_loop_n):
            context_output_path = output_path / f"{context_output_prefix}{loop_i}.jsonl"
            reader_output_path = output_path / f"{reader_output_prefix}{loop_i}.jsonl"
            query_aug_output_path = output_path / f"{query_aug_output_prefix}{loop_i}.jsonl"
            
            for batch_i, question_batch in enumerate(question_generator(
                question_dataloader, tokenizer_bart
            )):

                count = len(sep_counter.findall(question_batch[0]))
                assert count <= 1, (
                    count, len(question_batch[0].split())
                )

                ################################################
                # Prepare retrieval and retrieve
                ################################################
                if loop_i == 0:
                    # If we are at loop zero, just retrieve.
                    # We may want to generate then retrieve.
                    retrieval_query = question_batch
                else:
                    # Otherwise, add query augmentation generations to the question.
                    retrieval_query = []
                    for i, question in enumerate(question_batch):
                        # This is super sketchy. 
                        retrieval_query.append(
                            question + tokenizer_bert.sep_token + 
                            query_aug_text_accum[batch_i][-1][i]
                        )
                    rich.print("[blue bold]question_batch:", question_batch)
                    rich.print("[blue bold]retrieval_query:", retrieval_query)
                    sys.exit()
                # Retrieve.
                with time_this("retrieve"):
                    top_ids_and_scores = common_retriever.retrieve(
                        retriever,
                        all_passages=all_passages, 
                        questions=retrieval_query,
                        special_query_token=special_query_token,
                        n_docs=n_docs,
                    )

                ################################################
                # Deal with contexts
                ################################################
                top_ids, scores = zip(*top_ids_and_scores)            
                # We likely have to decode and re-encode here
                selected_contexts_ids = selection_technique(
                    top_ids, scores,
                )

                generation_input_batch = []
                for question, selected_ids in zip(
                    question_batch, selected_contexts_ids
                ):
                    contexts = [
                        all_passages[ids_].text for ids_ in selected_ids
                    ]    
                    generation_input = (
                        question + tokenizer_bart.sep_token + 
                        tokenizer_bart.sep_token.join(
                            contexts
                        )
                    )

                    generation_input_batch.append(
                        generation_input
                    )

                ################################################
                # Prepare query augmentation generation
                # and generate
                ################################################
                tokenized_query_aug_inputs = tokenizer_bart.batch_encode_plus(
                    generation_input_batch,
                    return_tensors="pt",
                    pad_to_max_length=True,
                    max_length=query_aug_input_max_length,
                )

                with time_this("query aug generation"):
                    generations_ids = decode(
                        model=query_aug_model, 
                        batch=tokenized_query_aug_inputs, 
                        tokenizer=tokenizer_bart, 
                        decoding_conf=decoding_conf,
                    )

                    query_aug_text_batch = []
                    for question, tokenized_query_aug_input, generations_id in zip(
                        question_batch, 
                        tokenized_query_aug_inputs["input_ids"], 
                        generations_ids,
                    ):
                        generation_text = tokenizer_bart.decode(generations_id)
                        input_text = tokenizer_bart.decode(tokenized_query_aug_input)

                        # rich.print(f"[bold blue]Question:[/] {question}")
                        # rich.print(f"[bold blue]Output:[/] {generation_text}")
                        # rich.print(
                        #     f"[bold blue]Input:[/] {input_text}"
                        # )

                        query_aug_text_batch.append(generation_text)

                ################################################
                # Deal with the generated text: query aug
                ################################################
                if loop_i == 0:
                    query_aug_text_accum.append([])
                query_aug_text_accum[batch_i].append(query_aug_text_batch)
                    
                ################################################
                # Generate for reader model inference
                ################################################
                with time_this("reader generation"):
                    reader_output_ids = decode(
                        model=reader_inference_model, 
                        batch=tokenized_query_aug_inputs, 
                        tokenizer=tokenizer_bart, 
                        decoding_conf=decoding_conf,
                    )

                    # Decode the tokens of the batch
                    reader_text_batch = []
                    for question, tokenized_query_aug_input, generations_id in zip(
                        question_batch, 
                        tokenized_query_aug_inputs["input_ids"], 
                        reader_output_ids,
                    ):
                        generation_text = tokenizer_bart.decode(generations_id)
                        inputs_text = tokenizer_bart.decode(tokenized_query_aug_input)

                        # rich.print(f"[bold blue]Question:[/] {question}")
                        # rich.print(f"[bold blue]Output:[/] {generation_text}")
                        # rich.print(
                        #     f"[bold blue]Input:[/] {tokenized_query_aug_input['input_ids']}"
                        # )

                        reader_text_batch.append(generation_text)

                ################################################
                # Deal with the generated text: reader inference
                ################################################
                with time_this("writing contexts"):
                    write_contexts(
                        all_contexts=all_passages, 
                        context_ids=selected_contexts_ids, 
                        output_path=context_output_path, 
                    )
                with time_this("writing reader generations"):
                    write_generations(
                        reader_text_batch,
                        reader_output_path,
                    )
                with time_this("writing query aug generations"):
                    write_generations(
                        query_aug_text_batch,
                        query_aug_output_prefix,
                    )
    
inference(
    all_passages=all_passages,
    query_aug_model=query_aug_model.cuda(),
    reader_inference_model=reader_inference_model.cuda(),
    special_query_token=special_query_token,
    retriever=retriever,
    selection_technique=selection_technique,
    question_dataloader=dataloader,
    max_loop_n=args.max_loop_n,
    query_aug_input_max_length=args.max_source_len,
    decoding_conf=args.decoding_conf,
    n_docs=args.n_docs,
    output_path=args.output_path
)

[36m[INFO] (2021-09-30 05:25:55,785) {2800190409.inference:45}:
[0mstarted.
[36m[INFO] (2021-09-30 05:25:55,787) {2800190409.inference:46}:
[0moutput_path: /home/mila/g/gagnonju/IteratedDecoding/jobs/iterated_decoding_output/first_test
[36m[INFO] (2021-09-30 05:25:55,793) {2800190409.inference:53}:
[0mDeleting path: /home/mila/g/gagnonju/IteratedDecoding/jobs/iterated_decoding_output/first_test/contexts_0.jsonl
[36m[INFO] (2021-09-30 05:25:55,794) {2800190409.inference:53}:
[0mDeleting path: /home/mila/g/gagnonju/IteratedDecoding/jobs/iterated_decoding_output/first_test/contexts_11.jsonl
[36m[INFO] (2021-09-30 05:25:55,796) {2800190409.inference:53}:
[0mDeleting path: /home/mila/g/gagnonju/IteratedDecoding/jobs/iterated_decoding_output/first_test/contexts_13.jsonl
[36m[INFO] (2021-09-30 05:25:55,797) {2800190409.inference:53}:
[0mDeleting path: /home/mila/g/gagnonju/IteratedDecoding/jobs/iterated_decoding_output/first_test/contexts_6.jsonl
[36m[INFO] (2021-09-30 05:25:55,7

  0%|          | 0/876 [00:00<?, ?it/s]


SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
 """
    We currently do 

    for batch_questions in questions:
        for loop_i in range(max_loop_n)
            encode_context_to_gen
            (load generator to faster memory)
            generate(batch)
            decode_text_from_gen

            encode_text_to_retriever
            (load retriever to faster memory)
            retrieve(batch)
            decode_text_from_retriever

    We could do

    for loop_i in range(max_loop_n)
        # Parallelize as needed if helpful
        # num_questions x num_beams to do
        for batch_questions in zip(
            retrieved_contexts
        ):
            encode_contexts_to_gen

        (imaginary barrier)
        (load generator to faster memory)    
        for batch_questions in questions:
            top_beam_ids, top_beam_ppls = generate(batch)

        # Parallelize as needed if helpful
        for batch_questions in questions:
            decode_text_from_gen

        # Parallelize as needed if helpful
        for batch_questions in questions:
            encode_text_to_retriever

        (imaginary barrier)
        # Parallelize as needed if helpful
        for batch_questions in questions:
            concatenate_contexts

        (imaginary barrier)
        (load retriever to faster memory)
        for batch_questions in questions:            
            retrieve(batch)

        # Parallelize as needed if helpful
        for batch_questions in questions:            
            decode_text_from_retriever

    Much better for GPU memory locality,
    worse for total memory use. Would maybe allow
    for slightly larger batches.

    I think that the fact that we don't need the 
    GPU results right away makes it async and faster

    """
    