In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
###############################################################################
# Imports
###############################################################################
# Standard library
import argparse
import collections
import contextlib
import copy
import dataclasses
import importlib
import json
import logging
import numpy as np
import os
from pathlib import Path
import re
import shlex
import shutil
import sys
import time
import timeit
from typing import *

# Third party
import beartype
import colorama
import faiss
import hydra
import more_itertools
import jsonlines
import omegaconf
import rich
import torch
import tqdm
import transformers
print(transformers.__version__)

# First Party
import iterated_utils as utils
import iterated_retrieval as ir
import iterated_retrieval
import common_retriever


# Absolute path of the IteratedDecoding checkout.
ROOT_PATH = Path("/home/mila/g/gagnonju/IteratedDecoding/")
# Switch the working directory to the DPR checkout before importing
# `dense_retriever`. NOTE(review): presumably required because that module
# (and DPR's relative paths, e.g. its download root, which the captured log
# reports as the DPR directory) resolve against the current directory —
# confirm.
os.chdir(ROOT_PATH / "DPR")
import dense_retriever

# Put the GAR generator sources first on sys.path so that `train_generator`
# and `utils_gen` import as top-level modules.
GAR_PATH = ROOT_PATH / "GAR" / "gar"
sys.path.insert(0, str(GAR_PATH))
import train_generator
import utils_gen
# Fail fast unless the notebook runs on the expected interpreter (its path
# must contain "condaless").
assert "condaless" in sys.executable, sys.executable


###############################################################################
# Logging
###############################################################################
# Logger for the notebook cells below.
LOGGER = logging.getLogger(__name__)

# Header printed before every message: level, timestamp and call site. It
# ends with a newline so the message body starts on its own line.
format_info = "[%(levelname)s] (%(asctime)s) {%(name)s.%(funcName)s:%(lineno)d}:\n"

# Cyan header, then the color is reset before the message body.
logging_format = (
    f"{colorama.Fore.CYAN}{format_info}{colorama.Style.RESET_ALL}%(message)s"
)

# `force=True` replaces any handlers already on the root logger, so
# re-running this cell does not stack duplicate handlers.
logging.basicConfig(format=logging_format, level=logging.INFO, force=True)

# Quiet the transformers loggers; keep the retriever modules at INFO.
logging.getLogger("transformers.configuration_utils").setLevel(logging.WARNING)
logging.getLogger("transformers.tokenization_utils").setLevel(logging.WARNING)
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARNING)
logging.getLogger("common_retriever").setLevel(logging.INFO)
logging.getLogger("dense_retriever").setLevel(logging.INFO)


###############################################################################
# CONFIG
###############################################################################
def build_args(root_path):
    """Build the run settings and the DPR hydra config.

    All experiment hyper-parameters are hard-coded below. The run's output
    directory ``<root_path>/jobs/iterated_decoding_output/<RUN_NAME>`` is
    deleted if it already exists and then recreated. Both configurations are
    dumped into it as ``args.json`` and ``config.json``.

    Args:
        root_path: Root of the IteratedDecoding checkout (``Path`` or str).
            Every data, model, config and output path is derived from it.

    Returns:
        ``(args, dpr_cfg)``: ``args`` is an ``argparse.Namespace`` of the run
        settings and ``dpr_cfg`` is the composed ``dense_retriever`` hydra
        config.

    Raises:
        AssertionError: if the output root directory does not exist.
    """
    root_path = Path(root_path)
    RUN_NAME = "first_test"

    SENTENCE_DATA_DIR = root_path / "GAR/data/nq-sentence"
    SENTENCE_MODEL = root_path / "GAR/gar/outputs/sentence_with_context/last.ckpt"

    DATA_DIR = SENTENCE_DATA_DIR
    # Every path comes from the `root_path` argument, not the module-level
    # ROOT_PATH, so that all of them share the same root.
    DPR_CONF_PATH = root_path / "DPR/conf"
    QUERY_AUG_MODEL_PATH = SENTENCE_MODEL
    READER_MODEL_PATH = root_path / "GAR/gar/outputs/answer_with_context/last.ckpt"

    DATALOADER_MAX_TARGET_LEN = 0
    DATALOADER_MAX_SOURCE_LEN = 30

    GENERATION_BATCH_SIZE = 10
    NUM_RETURN_SEQUENCES_QUERY_AUG = 3
    # `inference` issues NUM_RETURN_SEQUENCES_QUERY_AUG queries per question
    # (after the first loop), so one retrieval call sees ~15 queries.
    RETRIEVER_BATCH_SIZE = 15 // NUM_RETURN_SEQUENCES_QUERY_AUG

    AUG_METHOD = "RETRIEVE_ALL_INDIVIDUALLY"
    MAX_LOOP_N = 15
    N_DOCS = 5
    MAX_TARGET_LEN = 160
    MAX_SOURCE_LEN = 768
    FINAL_NUM_CONTEXTS = 5
    QUERY_AUG_INPUT_MAX_LEN = 768
    DECODING_CONF_QUERY_AUG = iterated_retrieval.DecoderConf(
        max_length=MAX_TARGET_LEN,
        num_beams=NUM_RETURN_SEQUENCES_QUERY_AUG,
        # repetition_penalty=2.5,
        # length_penalty=1.0,
        # NOTE(review): with beam search and no sampling flag, `temperature`
        # may have no effect — confirm against DecoderConf / generate().
        temperature=0.5,
        num_return_sequences=NUM_RETURN_SEQUENCES_QUERY_AUG,
        early_stopping=True,
    )
    DECODING_CONF_READER = iterated_retrieval.DecoderConf(
        num_beams=1,
        max_length=MAX_TARGET_LEN,
        # repetition_penalty=2.5,
        # length_penalty=1.0,
        num_return_sequences=1,
        early_stopping=True,
    )
    OUTPUT_ROOT = root_path / "jobs/iterated_decoding_output/"
    assert OUTPUT_ROOT.exists(), OUTPUT_ROOT

    # Start every run from an empty output directory.
    out_path = OUTPUT_ROOT / RUN_NAME
    if out_path.exists():
        shutil.rmtree(out_path)
    out_path.mkdir()

    try:
        hydra.initialize_config_dir(config_dir=str(DPR_CONF_PATH))
    except ValueError as err:
        # Re-running the cell initializes hydra a second time. That specific
        # error is harmless; anything else is re-raised unchanged.
        message = (
            "GlobalHydra is already initialized, call "
            "GlobalHydra.instance().clear() if you want to re-initialize"
        )
        if message not in str(err):
            raise

    dpr_cfg = hydra.compose(
        config_name="dense_retriever",
        overrides=["out_file=/tmp/"],
    )

    args = dict(
        conf_path=DPR_CONF_PATH,
        data_dir=DATA_DIR,
        query_aug_model_path=QUERY_AUG_MODEL_PATH,
        reader_model_path=READER_MODEL_PATH,
        dataloader_max_target_len=DATALOADER_MAX_TARGET_LEN,
        dataloader_max_source_len=DATALOADER_MAX_SOURCE_LEN,
        generation_batch_size=GENERATION_BATCH_SIZE,
        max_loop_n=MAX_LOOP_N,
        n_docs=N_DOCS,
        max_source_len=MAX_SOURCE_LEN,
        max_target_len=MAX_TARGET_LEN,
        query_aug_input_max_len=QUERY_AUG_INPUT_MAX_LEN,
        decoding_conf_reader=DECODING_CONF_READER,
        decoding_conf_query_aug=DECODING_CONF_QUERY_AUG,
        out_path=out_path,
        retriever_batch_size=RETRIEVER_BATCH_SIZE,
        aug_method=AUG_METHOD,
        final_num_contexts=FINAL_NUM_CONTEXTS,
    )

    json_output_config = dict(
        indent=2,
        default=utils.json_default,
        sort_keys=True,
    )

    # Keep a record of the exact configuration next to the run outputs.
    utils.save_json(
        args,
        out_path / "args.json",
        **json_output_config
    )
    utils.save_json(
        omegaconf.OmegaConf.to_container(dpr_cfg),
        out_path / "config.json",
        **json_output_config
    )

    return argparse.Namespace(**args), dpr_cfg


# Build the run settings, then the question dataloader and the two
# tokenizers (BART for generation inputs, BERT for retrieval queries).
args, dpr_cfg = build_args(ROOT_PATH)

dataloader, tokenizer_bart, tokenizer_bert = (
    iterated_retrieval.build_tokenizers_and_datasets(
        generation_batch_size=args.generation_batch_size,
        data_dir=args.data_dir,
        max_target_len=args.dataloader_max_target_len,
        max_source_len=args.dataloader_max_source_len,
    )
)

LOGGER.info("Done.")

(Re)/Loading iterated_utils.py
(Re)/Loading iterated_retrieval.py
(Re)/Loading common_retriever.py


See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
[36m[INFO] (2021-10-18 18:10:13,191) {iterated_utils.time_this:124}:
[0m[34mStarting:[0m Build dataloader


loading from /home/mila/g/gagnonju/IteratedDecoding/GAR/data/nq-sentence/train.target.processed (pkl)... make sure data is what you need


[36m[INFO] (2021-10-18 18:10:44,664) {iterated_utils.time_this:127}:
[0m[32mDone:[0m Build dataloader, 31.47s
[36m[INFO] (2021-10-18 18:10:44,665) {__main__.<module>:214}:
[0mDone.


In [3]:
# Build the DPR retriever from `dpr_cfg`. Besides the retriever, this returns
# the passage collection (indexed by passage id in `inference`) and the
# special token passed along with queries to `common_retriever.retrieve`.
# Judging by the captured log, this also loads the passages (~6 min).
retriever, all_passages, special_query_token = common_retriever.build_retriever(
    dpr_cfg
)

[36m[INFO] (2021-10-18 18:10:47,723) {iterated_utils.time_this:124}:
[0m[34mStarting:[0m common_retriever.load_passages (~6 min)
[36m[INFO] (2021-10-18 18:10:47,778) {dpr.options.setup_cfg_gpu:70}:
[0margs.local_rank -1
[36m[INFO] (2021-10-18 18:10:47,779) {dpr.options.setup_cfg_gpu:73}:
[0mWORLD_SIZE None
[36m[INFO] (2021-10-18 18:10:47,780) {dpr.options.setup_cfg_gpu:89}:
[0mInitialized host cn-d002 as d.rank -1 on device=cuda, n_gpu=4, world size=1
[36m[INFO] (2021-10-18 18:10:47,780) {dpr.options.setup_cfg_gpu:97}:
[0m16-bits training: False 
[36m[INFO] (2021-10-18 18:10:47,794) {dpr.data.download_data.download_resource:412}:
[0mRequested resource from https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
[36m[INFO] (2021-10-18 18:10:47,794) {dpr.data.download_data.download_resource:424}:
[0mDownload root_dir /home/mila/g/gagnonju/IteratedDecoding/DPR
[36m[INFO] (2021-10-18 18:10:47,798) {dpr.data.download_data.download_resource:435}:
[0mFile to be d

ModuleNotFoundError: No module named 'transformers.modeling_bert'

In [None]:
# Replace the retriever's FAISS index with the result of `faiss_to_gpu`.
# NOTE(review): presumably a GPU copy of the same index — confirm.
retriever.index.index = common_retriever.faiss_to_gpu(retriever.index.index)

In [None]:
# Load the two BART checkpoints: the query-augmentation generator and the
# reader. They are moved to CUDA when `inference` is called.
query_aug_model, reader_model = ir.build_models(
    reader_model_path=args.reader_model_path,
    query_aug_model_path=args.query_aug_model_path,
)

In [None]:
###############################################################################
# Specific to selection technique
###############################################################################

def topk_w_torch(stuff_np: np.ndarray, k, dim):
    """Return the indices of the ``k`` largest entries of ``stuff_np`` along ``dim``.

    The array is wrapped in a (float32) ``torch.Tensor`` and handed to
    ``torch.topk``. The indices come back as a NumPy array, ordered from the
    largest value to the smallest. If ``torch.topk`` raises a
    ``RuntimeError``, the tensor shape, ``dim`` and ``k`` are appended to the
    error before it is re-raised.
    """
    stuff_pt = torch.Tensor(stuff_np)
    try:
        top_indices = torch.topk(stuff_pt, k=k, dim=dim).indices
        return top_indices.numpy()
    except RuntimeError as err:
        context = (
            f"{stuff_pt.shape = }\n"
            f"{dim = }\n"
            f"{k = }\n"
        )
        raise utils.add_to_err(err, context)


def top_k_sum(
    scores: np.ndarray, indices: np.ndarray, final_qty: int
):
    """Select, per batch item, the ids with the highest summed score.

    For every batch item, the scores of each retrieved id are summed over
    all of that item's queries, so an id retrieved by several queries
    accumulates all of its scores. The ``final_qty`` ids with the largest
    sums are kept, highest first. Ids with equal sums keep their first-seen
    order.

    Args:
        scores: 3-D array of retrieval scores, same shape as ``indices``
            (batch, queries per question, docs per query).
        indices: Retrieved passage ids aligned with ``scores``.
        final_qty: Number of ids to keep per batch item.

    Returns:
        ``np.ndarray`` whose first dimension is the batch size. Row ``b``
        holds the selected ids of batch item ``b``. If an item has fewer
        than ``final_qty`` distinct ids, its row is shorter, and the rows
        are then ragged.
    """
    utils.check_shape(scores.shape, indices.shape)
    utils.check_equal(scores.ndim, 3)
    utils.check_equal(indices.ndim, 3)

    output = []
    # TODO: inner loop in pure python
    for batch_scores, batch_indices in zip(scores, indices):
        per_id = collections.defaultdict(int)
        # ravel() walks the (query, doc) grid in row-major order, i.e. the
        # same order as a nested loop over queries then docs.
        for index, score in zip(batch_indices.ravel(), batch_scores.ravel()):
            per_id[index] += score

        # Highest summed score first; keep the top `final_qty`. `sorted` is
        # stable (also with reverse=True), so ties keep first-seen order.
        # TODO: sort the whole list when we don't need to
        ranked = sorted(
            per_id.items(), key=lambda key_value: key_value[1], reverse=True,
        )
        output.append(tuple(key for key, _ in ranked[:final_qty]))

    utils.check_equal(len(output), scores.shape[0])
    output = np.asarray(output)
    utils.check_equal(output.shape[0], scores.shape[0])
    return output


def topk_w_numpy(stuff_np: np.ndarray, k, dim):
    """Return the indices of the ``k`` largest entries of ``stuff_np`` along ``dim``.

    NumPy counterpart of ``topk_w_torch``. The result has the shape of
    ``stuff_np`` with axis ``dim`` reduced to ``k``, and its indices are
    ordered from the largest value to the smallest.

    Args:
        stuff_np: Array to search.
        k: Number of indices to keep (at most ``stuff_np.shape[dim]``).
        dim: Axis along which to select.
    """
    stuff_np = np.asarray(stuff_np)
    size = stuff_np.shape[dim]
    # argpartition moves the k largest entries into the last k slots of
    # `dim`, in arbitrary order; keep only those slots...
    partitioned = np.argpartition(stuff_np, -k, axis=dim)
    top_unordered = np.take(partitioned, np.arange(size - k, size), axis=dim)
    # ...then order them by decreasing value.
    top_values = np.take_along_axis(stuff_np, top_unordered, axis=dim)
    ascending = np.argsort(top_values, axis=dim, kind="stable")
    descending = np.flip(ascending, axis=dim)
    return np.take_along_axis(top_unordered, descending, axis=dim)


def get_reference(arr, indices):
    """Gather ``arr[b, indices[b]]`` for every batch row ``b``.

    Same per-row gather as ``get_torch`` / ``get_numpy``: row ``b`` of the
    result is ``arr[b]`` indexed by ``indices[b]``, so the result has
    ``indices.shape[1]`` columns. ``arr`` is left unmodified.

    Args:
        arr: Array whose first axis is the batch axis.
        indices: 2-D integer array, one row of positions per batch row.

    Raises:
        AssertionError: if an argument is None or the first dimensions of
            ``arr`` and ``indices`` differ.
    """
    assert arr is not None
    assert indices is not None
    assert arr.shape[0] == indices.shape[0], (arr.shape[0], indices.shape[0])
    # Build a new array instead of writing back into `arr`. Assigning k
    # selected values into a row of length m cannot broadcast when k != m,
    # and writing in place would also clobber the caller's data.
    rows = np.arange(arr.shape[0])[:, None]
    return arr[rows, np.asarray(indices)]


def get_torch(arr, indices):
    """Per-row gather along axis 1 with ``torch.gather``.

    For 2-D inputs: ``out[b, j] = arr[b, indices[b, j]]``. The dtype of
    ``arr`` is preserved. Converting through ``torch.Tensor`` would cast to
    float32 and corrupt large integer ids (anything above 2**24).

    Returns:
        The gathered values as a NumPy array.
    """
    source = torch.as_tensor(np.ascontiguousarray(arr))
    positions = torch.as_tensor(np.ascontiguousarray(indices)).long()
    return torch.gather(input=source, index=positions, dim=1).numpy()


def get_numpy(arr, indices):
    """Per-row gather along axis 1; for 2-D inputs ``out[b, j] = arr[b, indices[b, j]]``."""
    gathered = np.take_along_axis(arr, indices, axis=1)
    return gathered


@utils.class_checker
@dataclasses.dataclass
class SelectionTechniqueChecksInfo:
    """Expected dimensions passed to `selection_technique` for its checks.

    The values are compared against the retrieval arrays' shapes and are
    echoed in the error message when the output shape check fails.
    """

    # Number of questions in the current retrieval batch.
    batch_size: int
    # Number of queries issued per question.
    num_sequences: int
    # Number of passages retrieved per query.
    n_docs: int
    # Index of the current decoding loop iteration.
    loop_i: int


@beartype.beartype
def selection_technique(
    top_ids_np: np.ndarray,
    scores_retr_np: np.ndarray,
    final_num_contexts: int,
    query_scores_batch: np.ndarray,
    checks_info: SelectionTechniqueChecksInfo,
    mode: str = "ADDITIVE_TOP_K",
) -> np.ndarray:
    """Reduce each question's retrievals to ``final_num_contexts`` passage ids.

    Args:
        top_ids_np: Retrieved ids, shape
            ``(batch_size, num_sequences, n_docs)`` from ``checks_info``.
        scores_retr_np: Retrieval scores with the same shape.
        final_num_contexts: Number of ids to keep per question.
        query_scores_batch: Generation scores of the queries (currently
            unused).
        checks_info: Expected dimensions, used for the shape checks and
            their error messages.
        mode: Selection strategy:
            - ``"ADDITIVE_TOP_K"`` (default): sum the scores of ids
              retrieved by several queries and keep the best sums.
            - ``"DUMB_TOP_K"``: pool every (query, doc) pair of a question
              and keep the best-scored ones; the same id may be kept more
              than once.
            ``"MARGINAL_TOP_K"`` is reserved but not implemented.

    Returns:
        Array of shape ``(batch_size, final_num_contexts)``.

    Raises:
        NotImplementedError: for an unknown or unimplemented ``mode``.
    """
    DUMB_TOP_K = "DUMB_TOP_K"
    # UNIQUE_DUMB_TOP_K = "UNIQUE_DUMB_TOP_K"
    ADDITIVE_TOP_K = "ADDITIVE_TOP_K"

    # Shape verifications
    expected_shape = (
        checks_info.batch_size, checks_info.num_sequences, checks_info.n_docs
    )
    utils.check_shape(top_ids_np.shape, expected_shape)
    utils.check_shape(scores_retr_np.shape, expected_shape)
    effective_batch_size, queries_per_question, n_docs = top_ids_np.shape

    if mode == DUMB_TOP_K:
        # Flatten the (query, doc) grid of each question, then keep the
        # `final_num_contexts` highest-scored entries.
        flat_shape = (effective_batch_size, queries_per_question * n_docs)
        top_ids_np = top_ids_np.reshape(flat_shape)
        scores_retr_np = scores_retr_np.reshape(flat_shape)

        indices_w_torch = topk_w_torch(
            scores_retr_np, final_num_contexts, dim=1,
        )
        assert indices_w_torch is not None
        output = get_reference(top_ids_np, indices_w_torch)

    elif mode == ADDITIVE_TOP_K:
        output = top_k_sum(
            scores=scores_retr_np,
            indices=top_ids_np,
            final_qty=final_num_contexts,
        )
    else:
        raise NotImplementedError(f"mode {mode} not implemented or invalid")

    # Shape Verification
    try:
        utils.check_shape(
            output.shape, (checks_info.batch_size, final_num_contexts)
        )

    except ValueError as err:
        raise utils.add_to_err(
            err,
            f"\t- {checks_info.batch_size = }\n"
            f"\t- {checks_info.num_sequences = }\n"
            f"\t- {checks_info.n_docs = }\n"
            f"\t- {checks_info.loop_i = }\n"
        )

    return output


###############################################################################
# Inference
###############################################################################
@beartype.beartype
def inference(
    all_passages: Dict[str, str],
    query_aug_model: train_generator.SummarizationTrainer,
    reader_model: train_generator.SummarizationTrainer,
    special_query_token: Optional[str],
    retriever: dense_retriever.LocalFaissRetriever,
    selection_technique: Callable,
    question_dataloader: torch.utils.data.DataLoader,
    max_loop_n: int,
    decoding_conf_reader: ir.DecoderConf,
    decoding_conf_query_aug: ir.DecoderConf,
    query_aug_input_max_length: int,
    n_docs: int,
    out_path: Union[str, Path],
    retriever_batch_size: int,
    aug_method: str,
    final_num_contexts: int,
    generation_batch_size: int,
) -> None:
    """Run ``max_loop_n`` rounds of retrieval followed by query augmentation.

    Each round does three things:

    1. Build the retrieval queries. In round 0 the query is the question
       itself. Afterwards, each question is paired (with the BERT separator
       token) with every query augmentation generated in the previous round.
    2. Retrieve ``n_docs`` passages per query. ``selection_technique`` then
       reduces them to ``final_num_contexts`` passage ids per question.
    3. Build the generation inputs: the question followed by the selected
       passages, joined with the BART separator token. Run the
       query-augmentation model on them to produce the next round's
       augmentations.

    Reader decoding is currently disabled (commented out), so
    ``reader_model`` and ``decoding_conf_reader`` are unused.

    Per-round outputs are written to ``out_path`` as
    ``retr_inputs_<i>.jsonl``, ``retr_outs_<i>.jsonl``,
    ``q_aug_outs_<i>.jsonl`` and ``gen_inputs_<i>.jsonl``. Existing files
    with these (and the ``reader_outs_``) prefixes are deleted first.

    Also reads the notebook globals ``tokenizer_bart`` and
    ``tokenizer_bert``.

    Raises:
        ValueError: for an ``aug_method`` other than
            ``"RETRIEVE_ALL_INDIVIDUALLY"``.
    """
    # `import tqdm` alone does not load the `tqdm.notebook` submodule used
    # below, so import it explicitly.
    import tqdm.notebook

    out_path = Path(out_path)

    # Prepare the output files
    prefixes = dict(
        retr_outs="retr_outs_",
        reader_outs="reader_outs_",
        q_aug_outs="q_aug_outs_",
        gen_inputs="gen_inputs_",
        retr_inputs="retr_inputs_",
    )

    # Delete the outputs of any previous run.
    for prefix in prefixes.values():
        for path in out_path.glob(f"{prefix}*.jsonl"):
            LOGGER.info(f"Deleting path: {path}")
            os.remove(path)

    with torch.inference_mode(True):
        # One entry per round: per-question query augmentations (text) and
        # their generation scores.
        query_aug_text_all_loops = []
        query_aug_score_all_loops = []

        for loop_i in range(max_loop_n):
            LOGGER.info(f"{loop_i = }")
            output_paths = {}

            for name, prefix in prefixes.items():
                output_paths[name] = out_path / f"{prefix}{loop_i}.jsonl"

            ###################################################################
            # PREPARE THE RETRIEVAL QUERIES
            ###################################################################
            all_queries_this_loop = []
            all_queries_scores_this_loop = []
            questions_batching_generator = ir.question_generator(
                question_dataloader,
                tokenizer_bart,
                f"[{loop_i = }] Preparing the retrieval queries :: ",
            )

            if loop_i == 0:
                # No augmentations exist yet: feed placeholders so that the
                # zip below still lines up with the question batches.
                query_batch_generator = (
                    None for _ in range(len(question_dataloader))
                )
                query_batch_scores_generator = (
                    None for _ in range(len(question_dataloader))
                )

            else:
                query_batch_generator = more_itertools.chunked(
                    query_aug_text_all_loops[-1],
                    question_dataloader.batch_size,
                )
                query_batch_scores_generator = more_itertools.chunked(
                    query_aug_score_all_loops[-1],
                    question_dataloader.batch_size,
                )

            for batch_i, (
                questions_batch, query_aug_batch, query_aug_batch_scores
            ) in enumerate(
                more_itertools.zip_equal(
                    questions_batching_generator,
                    query_batch_generator,
                    query_batch_scores_generator,
                )
            ):

                if loop_i == 0:
                    assert query_aug_batch is None
                    # The questions are our queries.
                    all_queries_this_loop.extend(
                        [[x] for x in questions_batch]
                    )
                else:
                    query_aug_batch = np.array(query_aug_batch, dtype="object")

                    # Use the query augs to augment the question.
                    if aug_method == "RETRIEVE_ALL_INDIVIDUALLY":
                        # If we retrieve all queries individually, then
                        # we keep the 1:1 relationship between the score
                        # qty and the query qty

                        # Check the rank before indexing `shape[1]`.
                        utils.check_equal(query_aug_batch.ndim, 2)
                        utils.check_equal(
                            query_aug_batch.shape[1],
                            decoding_conf_query_aug.num_return_sequences,
                        )
                        for question, query_set in more_itertools.zip_equal(
                            questions_batch,
                            query_aug_batch,
                        ):
                            per_question = [
                                question + tokenizer_bert.sep_token + gen
                                for gen in query_set
                            ]
                            all_queries_this_loop.append(per_question)
                    else:
                        raise ValueError(aug_method)

            iterated_retrieval.write_generations(
                all_queries_this_loop, output_paths["retr_inputs"],
            )

            ###################################################################
            # RETRIEVE
            ###################################################################
            retrieved_this_loop = []
            # Defined up front so that the `del` after retrieval is safe even
            # when no retrieval batch runs.
            selected_contexts_ids_np = None
            with utils.time_this("retrieve", no_start=True):
                # If we are at loop_i == 0, the number of queries is 1
                # so the number of retrievals is batch_size * 1, which is
                # num_augs times smaller than it is for loop_i > 0. To
                # compensate, we make the batches larger by a factor of
                # num_augs.
                if loop_i == 0:
                    effective_batch_size = (
                        retriever_batch_size *
                        decoding_conf_query_aug.num_return_sequences
                    )
                    queries_per_question = 1
                    all_queries_scores_this_loop = [
                        None for _ in range(len(all_queries_this_loop))
                    ]
                else:
                    effective_batch_size = retriever_batch_size
                    queries_per_question = (
                        decoding_conf_query_aug.num_return_sequences
                    )
                    all_queries_scores_this_loop = (
                        query_aug_score_all_loops[-1]
                    )

                # Make sure we have as many scores as we have queries.
                # This should always be true.
                try:
                    utils.check_equal(
                        len(all_queries_this_loop),
                        len(all_queries_scores_this_loop),
                    )

                except ValueError as err:
                    raise utils.add_to_err(
                        err, (
                            f"{len(all_queries_this_loop) = }\n"
                            f"{np.array(all_queries_this_loop).shape = }\n"
                            f"{len(all_queries_scores_this_loop) = }\n"
                            f"{np.array(all_queries_scores_this_loop).shape = }\n"
                            f"{loop_i = }\n"
                        )
                    )

                for batch_i, (query_batch, query_scores_batch) in enumerate(
                    more_itertools.zip_equal(
                        more_itertools.chunked(
                            tqdm.notebook.tqdm(
                                all_queries_this_loop,
                                desc="retrieval all_queries_this_loop",
                            ),
                            effective_batch_size
                        ),
                        more_itertools.chunked(
                            all_queries_scores_this_loop,
                            effective_batch_size,
                        ),
                    )
                ):

                    # Retrieve.
                    if aug_method == "RETRIEVE_ALL_INDIVIDUALLY":
                        # Flatten to one query per entry: every query of
                        # every question is retrieved individually.
                        query_batch_np = np.array(
                            query_batch, dtype="object",
                        ).reshape(-1)

                        real_batch_size = len(query_batch)

                        # TODO: Make sure the reshaping makes sense
                        top_ids_and_scores = common_retriever.retrieve(
                            retriever=retriever,
                            all_passages=all_passages,
                            questions=query_batch_np,
                            special_query_token=special_query_token,
                            n_docs=n_docs,
                        )

                        ################################################
                        # Deal with contexts
                        ################################################
                        top_ids, scores_retr = more_itertools.zip_equal(
                            *top_ids_and_scores
                        )

                        try:
                            top_ids_np = np.array(top_ids).reshape(
                                real_batch_size,
                                queries_per_question,
                                n_docs
                            )
                            scores_retr_np = np.array(scores_retr).reshape(
                                real_batch_size,
                                queries_per_question,
                                n_docs,
                            )
                        except ValueError as err:
                            # Attach the shape context to the error itself
                            # instead of discarding it.
                            raise utils.add_to_err(
                                err,
                                f"\t- {top_ids = }\n"
                                f"\t- {real_batch_size = }\n"
                                f"\t- {effective_batch_size = }\n"
                                f"\t- {queries_per_question = }\n"
                                f"\t- {n_docs = }\n"
                                f"\t- {loop_i = }"
                            )

                        selected_contexts_ids_np = selection_technique(
                            top_ids_np,
                            scores_retr_np,
                            final_num_contexts,
                            np.array(query_scores_batch),
                            SelectionTechniqueChecksInfo(
                                batch_size=real_batch_size,
                                num_sequences=queries_per_question,
                                n_docs=n_docs,
                                loop_i=loop_i,
                            )
                        )

                        utils.check_shape(
                            selected_contexts_ids_np.shape,
                            (
                                real_batch_size,
                                final_num_contexts
                            )
                        )

                        retrieved_this_loop.extend(selected_contexts_ids_np)

                    else:
                        raise ValueError(aug_method)

                    iterated_retrieval.write_contexts(
                        all_contexts=all_passages,
                        context_ids=selected_contexts_ids_np.tolist(),
                        out_path=output_paths["retr_outs"],
                    )

            del selected_contexts_ids_np
            del all_queries_this_loop
            del all_queries_scores_this_loop

            ###################################################################
            # Generation with the Barts
            ###################################################################
            LOGGER.info(f"[{loop_i = }] Starting generation.")
            query_aug_text_all_loops.append([])
            query_aug_score_all_loops.append([])

            utils.check_equal(len(query_aug_text_all_loops), loop_i + 1)
            utils.check_equal(len(query_aug_score_all_loops), loop_i + 1)
            tqdm_info = f"[{loop_i = }] Generating with BART models :: "
            # Integer ceiling division (np.ceil would yield a float).
            num_batchs_generation = -(
                -len(retrieved_this_loop) // question_dataloader.batch_size
            )
            utils.check_equal(
                question_dataloader.batch_size,
                generation_batch_size,
            )
            try:
                utils.check_equal(
                    num_batchs_generation,
                    len(question_dataloader),
                )
            except ValueError as err:
                raise utils.add_to_err(
                    err,
                    f"\t- {loop_i = }\n",
                )

            for batch_i, (questions_batch, context_batch) in enumerate(
                more_itertools.zip_equal(
                    iterated_retrieval.question_generator(
                        question_dataloader, tokenizer_bart, tqdm_info,
                    ),
                    more_itertools.chunked(
                        retrieved_this_loop, generation_batch_size,
                    ),
                )
            ):
                utils.check_equal(len(questions_batch), len(context_batch))

                try:
                    utils.check_batch_size(
                        len(questions_batch),
                        generation_batch_size,
                        len(question_dataloader.dataset),
                    )
                    utils.check_batch_size(
                        len(context_batch),
                        generation_batch_size,
                        len(question_dataloader.dataset),
                    )
                except RuntimeError as err:
                    raise utils.add_to_err(
                        err,
                        f"\t- {loop_i = }\n"
                        f"\t- {batch_i = }\n"
                    )

                ###############################################################
                # PREPARE GENERATION INPUTS
                ###############################################################
                # Take the contexts and append them to the questions
                gen_inputs_text = []
                for question, selected_ids in more_itertools.zip_equal(
                    questions_batch,
                    context_batch,
                ):
                    contexts = [
                        all_passages[ids_].text
                        for ids_ in selected_ids
                    ]

                    generation_input = (
                        question + tokenizer_bart.sep_token +
                        tokenizer_bart.sep_token.join(contexts)
                    )

                    gen_inputs_text.append(generation_input)

                utils.check_batch_size(
                    len(gen_inputs_text),
                    generation_batch_size,
                    len(question_dataloader.dataset),
                )

                gen_inputs = tokenizer_bart.batch_encode_plus(
                    gen_inputs_text,
                    return_tensors="pt",
                    pad_to_max_length=True,
                    max_length=query_aug_input_max_length,
                )

                ###############################################################
                # QUERY_AUG INFERENCE
                ###############################################################

                utils.check_batch_size(
                    gen_inputs["input_ids"].shape[0],
                    generation_batch_size,
                    len(question_dataloader.dataset),
                )

                query_aug_ids_batch, query_aug_scores_batch = (
                    iterated_retrieval.decode(
                        model=query_aug_model,
                        batch=gen_inputs,
                        tokenizer=tokenizer_bart,
                        decoding_conf=decoding_conf_query_aug,
                    )
                )

                try:
                    utils.check_batch_size(
                        query_aug_ids_batch.shape[0],
                        generation_batch_size,
                        len(question_dataloader.dataset),
                    )
                    utils.check_batch_size(
                        query_aug_scores_batch.shape[0],
                        generation_batch_size,
                        len(question_dataloader.dataset),
                    )

                except RuntimeError as err:
                    raise utils.add_to_err(err,
                        f"{query_aug_ids_batch.shape = }\n" +
                        f"{query_aug_scores_batch.shape = }\n" +
                        f"{gen_inputs['input_ids'].shape = }\n"
                    )

                # Decode every generated sequence of every question.
                query_aug_text_batch = []
                for (
                    question, query_aug_input, query_aug_ids_per_question
                ) in more_itertools.zip_equal(
                    questions_batch,
                    gen_inputs["input_ids"],
                    query_aug_ids_batch,
                ):
                    texts_per_question = []
                    for generation in query_aug_ids_per_question:
                        gen = tokenizer_bart.decode(generation)
                        cleaned = ir.clean_bart_decode(gen, tokenizer_bart)
                        texts_per_question.append(cleaned)
                    query_aug_text_batch.append(texts_per_question)

                assert len(query_aug_text_all_loops) == loop_i + 1, (
                    len(query_aug_text_all_loops), loop_i + 1
                )
                assert len(query_aug_score_all_loops) == loop_i + 1, (
                    len(query_aug_score_all_loops), loop_i + 1
                )

                query_aug_text_batch = np.array(
                    query_aug_text_batch, dtype="object"
                )

                # Make sure that query_aug_text_batch are of the expected shape
                utils.check_shape(
                    query_aug_text_batch.shape,
                    (
                        query_aug_ids_batch.shape[0],
                        decoding_conf_query_aug.num_return_sequences
                    )
                )

                # The quantity of query aug to query score should be 1:1
                utils.check_equal(
                    query_aug_text_batch.shape[0],
                    query_aug_scores_batch.shape[0],
                )

                # Accumulate the query augmentation text by loop
                query_aug_text_all_loops[loop_i].extend(
                    query_aug_text_batch
                )

                # Accumulate the query augmentation generation score per loop
                query_aug_score_all_loops[loop_i].extend(
                    query_aug_scores_batch
                )

                ###############################################################
                # READER INFERENCE (currently disabled)
                ###############################################################
#                 reader_batch_ids, reader_batch_scores = (
#                     iterated_retrieval.decode(
#                         model=reader_model,
#                         batch=gen_inputs,
#                         tokenizer=tokenizer_bart,
#                         decoding_conf=decoding_conf_reader,
#                     )
#                 )
#                 reader_batch_ids = (
#                     reader_batch_ids.squeeze(1)
#                 )
#                 # Decode the tokens of the batch
#                 reader_text_batch = []
#                 for (
#                     question, query_aug_input, generations_ids, scores
#                 ) in more_itertools.zip_equal(
#                     questions_batch,
#                     gen_inputs["input_ids"],
#                     reader_batch_ids,
#                     reader_batch_scores,
#                 ):
#                     reader_text_batch.append(
#                         ir.clean_bart_decode(
#                             tokenizer_bart.decode(generations_ids),
#                             tokenizer_bart
#                         )
#                     )

                ###############################################################
                # Deal with the generated text: reader inference
                ###############################################################
                # iterated_retrieval.write_generations(
                #     reader_text_batch,
                #     output_paths["reader_outs"],
                # )
                iterated_retrieval.write_generations(
                    query_aug_text_batch.tolist(),
                    output_paths["q_aug_outs"],
                )
                iterated_retrieval.write_generations(
                    gen_inputs_text,
                    output_paths["gen_inputs"],
                )


# Run the iterated retrieval / query-augmentation loop, with both BART
# models moved to the GPU.
inference(
    all_passages=all_passages,
    query_aug_model=query_aug_model.cuda(),
    reader_model=reader_model.cuda(),
    special_query_token=special_query_token,
    retriever=retriever,
    selection_technique=selection_technique,
    question_dataloader=dataloader,
    max_loop_n=args.max_loop_n,
    # `build_args` stores a dedicated setting for the length of the
    # generation inputs; use it rather than `max_source_len`.
    query_aug_input_max_length=args.query_aug_input_max_len,
    decoding_conf_query_aug=args.decoding_conf_query_aug,
    decoding_conf_reader=args.decoding_conf_reader,
    n_docs=args.n_docs,
    out_path=args.out_path,
    retriever_batch_size=args.retriever_batch_size,
    aug_method=args.aug_method,
    final_num_contexts=args.final_num_contexts,
    generation_batch_size=args.generation_batch_size,
)

In [None]:
# Materialize every batch yielded by the question generator into a list.
batches = [
    *iterated_retrieval.question_generator(dataloader, tokenizer_bart, "asd")
]