In [None]:
import os
from huggingface_hub import login
# Authenticate with the Hugging Face Hub using the token stored in the
# HUGGINGFACE_ACCESS_TOKEN environment variable. Indexing os.environ directly
# raises KeyError when the variable is not set, so a missing token fails here.
login(os.environ['HUGGINGFACE_ACCESS_TOKEN'])
from transformers import LlamaForCausalLM, AutoTokenizer

# Lost in the middle

In [None]:
import dataclasses
import json
import logging
import math
import pathlib
import random
import sys
from copy import deepcopy
import torch
from tqdm import tqdm
from xopen import xopen
from pydantic.dataclasses import dataclass
from typing import List, Optional, Tuple, Type, TypeVar
T = TypeVar("T")


@dataclass(frozen=True)
class Document:
    """Immutable record for one document shown to the model.

    Only ``title`` and ``text`` are required; every other field defaults to
    ``None``.
    """

    title: str
    text: str
    id: Optional[str] = None
    score: Optional[float] = None
    hasanswer: Optional[bool] = None
    isgold: Optional[bool] = None
    original_retrieval_index: Optional[int] = None

    @classmethod
    def from_dict(cls: Type[T], data: dict) -> T:
        """Build an instance from a mapping of field names to values.

        The mapping is copied first, so the caller's object is never mutated.
        ``id`` and ``score`` are optional keys; a ``score`` that is present is
        converted with ``float``.

        Raises:
            ValueError: if ``data`` is empty (or otherwise falsy).
        """
        payload = deepcopy(data)
        if not payload:
            raise ValueError("Must provide data for creation of Document from dict.")

        doc_id = payload.pop("id", None)
        raw_score = payload.pop("score", None)
        # A missing score stays None; anything else is normalised to float.
        parsed_score = None if raw_score is None else float(raw_score)

        return cls(**dict(payload, id=doc_id, score=parsed_score))


def get_qa_prompt(
    question: str, documents: List["Document"], file_name: Optional[str] = None
):
    """Render the QA prompt for one question and its documents.

    The template file is read on every call (so edits to it take effect
    immediately), trailing newlines are stripped, and the template's
    ``{question}`` and ``{search_results}`` placeholders are filled in.

    Args:
        question: Question text substituted for ``{question}``.
        documents: Objects exposing ``title`` and ``text``; each becomes one
            line ``Document [i](Title: <title>) <text>`` numbered from 1.
        file_name: Path of the prompt template. Required, despite the
            ``None`` default kept for signature compatibility.

    Returns:
        The formatted prompt string.

    Raises:
        TypeError: if ``file_name`` is not provided.
    """
    if file_name is None:
        # open(None) would fail with an opaque message; name the argument instead.
        raise TypeError("get_qa_prompt() requires file_name, the path of the prompt template.")

    # Pin the encoding so the template reads the same on every platform.
    with open(file_name, encoding="utf-8") as f:
        prompt_template = f.read().rstrip("\n")

    # Format the documents into strings
    formatted_documents = [
        f"Document [{document_number}](Title: {document.title}) {document.text}"
        for document_number, document in enumerate(documents, start=1)
    ]
    return prompt_template.format(question=question, search_results="\n".join(formatted_documents))

## Load model

### Llama 2 7B Chat

In [None]:
# Short label for this model; it is embedded in the prediction file name.
model_label = 'llama2-7b'

# Hub checkpoint from which both the tokenizer and the weights are loaded.
model_name = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = LlamaForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_4bit=True,
)

## Collect predictions

In [None]:
# Dataset variant: both values are interpolated into the input file path
# ("<total_doc_num>_total_documents" and "gold_at_<position>").
total_doc_num, position = 20, 9

# Prompt template name; it is part of the output file name.
prompt_file = 'qa.prompt'

# Window of input line indices to process: start_idx inclusive, end_idx exclusive.
start_idx, end_idx = 100, 200

# Number of lines at the start of the window to skip (0 processes the whole window).
last_break_idx = 0

In [None]:
# Resolve input, output and prompt-template locations from the run configuration.
input_path = 'data/lost-in-the-middle/qa_data/%d_total_documents/nq-open-%d_total_documents_gold_at_%d.jsonl.gz' % (total_doc_num, total_doc_num, position)
output_path = 'data/lost-in-the-middle/qa_predictions/%s-prediction-%d-%d-%d-%d-%s.jsonl.gz' % (model_label, total_doc_num, position, start_idx, end_idx, prompt_file)
# Derive the template path from prompt_file so the template actually used
# always matches the name recorded in the output file name.
prompt_path = 'data/lost-in-the-middle/prompts/%s' % prompt_file

# Create directory for output path if it doesn't exist.
pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)

# Fetch the examples whose line index lies in [start_idx + last_break_idx, end_idx).
# last_break_idx skips the leading part of the window; the output file below is
# opened in append mode, so new rows are added after any existing ones.
first_idx = start_idx + last_break_idx
samples = []
with xopen(input_path) as fin:
    for idx, line in enumerate(fin):
        if idx >= end_idx:
            break
        if idx < first_idx:
            continue
        samples.append(json.loads(line))

# Build one prompt per example, keeping the parsed documents alongside it.
examples = []
prompts = []
all_model_documents = []
for input_example in samples:
    # Document.from_dict copies its argument, so no extra copy of ctxs is needed.
    documents = [Document.from_dict(ctx) for ctx in input_example["ctxs"]]

    prompt = get_qa_prompt(
        input_example["question"],
        documents,
        file_name=prompt_path,
    )

    prompts.append(prompt)
    examples.append(deepcopy(input_example))
    all_model_documents.append(documents)

# NOTE(review): generation length and sampling are left to the model's default
# generation config; set them explicitly if reproducible answers are required.
with xopen(output_path, "a") as f:
    for example, model_documents, prompt in tqdm(zip(examples, all_model_documents, prompts), total=len(prompts)):

        # Follow the model's own placement (it is loaded with device_map="auto")
        # instead of assuming a fixed GPU.
        model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        output = model.generate(**model_inputs)

        # The returned sequence begins with the prompt tokens. Decode only what
        # follows them rather than slicing the decoded text by len(prompt),
        # which breaks whenever decoding does not reproduce the prompt exactly.
        prompt_token_count = model_inputs["input_ids"].shape[1]
        response = tokenizer.decode(output[0][prompt_token_count:], skip_special_tokens=True)

        output_example = deepcopy(example)
        # Add some extra metadata to the output example
        output_example["model_prompt"] = prompt
        output_example["model_documents"] = [dataclasses.asdict(document) for document in model_documents]
        output_example["model_answer"] = response
        f.write(json.dumps(output_example) + "\n")