# In [None]:
import json
import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Input/output locations: Kaggle dataset mounts (read-only inputs) and the
# Kaggle working directory (where the generated answers are written).
retrieved_contexts_file_path = "/kaggle/input/retrieved-contexts-45-no-pruning/retrieved_contexts_45.json"
benchmark_file_path = "/kaggle/input/benchmark-techqa/benchmark_query_rewriting.json"
generated_answers_file_path = "/kaggle/working/generated_answers.json"

# JSON text is UTF-8 by specification; read it explicitly as UTF-8 instead of
# relying on the platform's locale-dependent default encoding.
with open(retrieved_contexts_file_path, "r", encoding="utf-8") as file:
    retrieved_contexts = json.load(file)

with open(benchmark_file_path, "r", encoding="utf-8") as benchmark_file:
    benchmark_instances = json.load(benchmark_file)

# System prompt for the LLM. It instructs the model to answer only from the
# retrieved documents, without speculation, without mentioning the documents,
# and without extra elaboration. The text is sent verbatim as the first system
# message of every chat, so editing it changes every generated prompt.
prompt = """
You are Granite, an AI developed by IBM. You are a helpful RAG (Retrieval-Augmented Generation) system designed to answer user queries based only on the content of the retrieved documents.

### Key Instructions:
1. **Only Use Retrieved Documents for Answers**: Provide an answer using specific, direct information in the retrieved documents. 
2. **No Speculation**: Do not try to make inferences or use general knowledge.
3. **Do Not Mention Documents**: Never refer to, mention, or include any details about the documents in your response. Do not say things like "According to the document," or "The document indicates...". Simply provide the answer.
4. **No Extra Information**: Do not elaborate or provide additional context.
"""  

# Model, tokenizer and vLLM engine used to generate the answers.
model_name = "ibm-granite/granite-3.2-8b-instruct"
max_len = 16384  # context window handed to vLLM as max_model_len

tokenizer = AutoTokenizer.from_pretrained(model_name)
llm = LLM(
    model=model_name,
    tensor_parallel_size=2,
    trust_remote_code=True,
    max_model_len=max_len,
    gpu_memory_utilization=0.90,
)

# Low-temperature sampling; generation also stops when the tokenizer's EOS
# token id is produced.
sampling_params = SamplingParams(
    temperature=0.05,
    max_tokens=8192,
    stop_token_ids=[tokenizer.eos_token_id],
)


# Build one chat prompt per answerable benchmark question. Benchmark instances
# and retrieved-context lists are paired by position in their files.
quest_count = 0
generated_answers = []  # filled once the answers have been generated
all_messages = []
all_document_texts = []
answerable_queries = []
correct_answers = []

# zip() silently stops at the shorter input; a length mismatch usually means
# the two files come from different runs and the pairing is misaligned.
if len(benchmark_instances) != len(retrieved_contexts):
    print(
        f"Warning: {len(benchmark_instances)} benchmark instances but "
        f"{len(retrieved_contexts)} retrieved-context lists; only the first "
        f"{min(len(benchmark_instances), len(retrieved_contexts))} are paired."
    )

for benchmark_instance, retrieved_documents in zip(benchmark_instances, retrieved_contexts):
    # Only answerable questions are evaluated.
    if benchmark_instance["is_impossible"]:
        continue

    if quest_count % 25 == 0:
        print(quest_count)  # progress indicator
    quest_count += 1

    answerable_queries.append(benchmark_instance["question"])
    correct_answers.append(benchmark_instance["answer"])

    # Each document's text is the concatenation of its sections' texts, with
    # no separator inserted between sections.
    document_texts = [
        "".join(section["section_text"] for section in document["sections"])
        for document in retrieved_documents
    ]

    # NOTE(review): str(document_texts) embeds the Python list repr (quotes,
    # escaped newlines) in the prompt. Kept as-is so prompts stay identical to
    # earlier runs; consider joining the texts with blank lines instead.
    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": benchmark_instance["question"]},
        {"role": "system", "content": "Retrieved Documents:" + str(document_texts)},
    ]

    all_messages.append(messages)
    all_document_texts.append(document_texts)

print("Created all messages")

# Generate all the answers for answerable questions.
#
# vLLM caps prompt + output at max_model_len, so the prompt must leave room for
# the answer. Slicing the templated ids with [:max_len] was wrong for two reasons:
#   1. It cut off the trailing assistant header added by add_generation_prompt,
#      so the model continued the document text instead of answering.
#   2. It left no room for output tokens.
# Instead, reserve sampling_params.max_tokens for the answer and trim only the
# content of the last message (the retrieved documents), re-applying the chat
# template so all role markers survive. Lower sampling_params.max_tokens if
# more of the context should reach the model.
max_prompt_len = max_len - sampling_params.max_tokens


def _fit_prompt(messages, budget=max_prompt_len):
    """Return chat-templated token ids for *messages*, at most *budget* long.

    If the full prompt is too long, the last message's content is truncated
    token-wise from its end and the template is re-applied.

    Raises:
        ValueError: if the prompt exceeds *budget* even with that content empty.
    """
    ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    if len(ids) <= budget:
        return ids

    content_ids = tokenizer.encode(messages[-1]["content"], add_special_tokens=False)
    # Decoding and re-encoding can shift token counts slightly, so iterate
    # until the templated prompt actually fits; each pass strictly shrinks it.
    while len(ids) > budget and content_ids:
        overflow = len(ids) - budget
        content_ids = content_ids[: max(len(content_ids) - overflow, 0)]
        trimmed = messages[:-1] + [
            {**messages[-1], "content": tokenizer.decode(content_ids)}
        ]
        ids = tokenizer.apply_chat_template(trimmed, add_generation_prompt=True)

    if len(ids) > budget:
        raise ValueError(
            f"Prompt needs {len(ids)} tokens even with the last message emptied; "
            f"the budget is {budget}."
        )
    return ids


prompt_token_ids = [_fit_prompt(messages) for messages in all_messages]

with torch.no_grad():
    # Token-id prompts are passed as {"prompt_token_ids": ...} dicts; the
    # prompt_token_ids= keyword of LLM.generate is deprecated.
    outputs = llm.generate(
        [{"prompt_token_ids": ids} for ids in prompt_token_ids],
        sampling_params=sampling_params,
    )

# Take the first completion of every request, pair it with the question, the
# retrieved document texts and the reference answer, then save the records.
llm_answers = [output.outputs[0].text for output in outputs]

generated_answers.extend(
    {
        "user_input": question,
        "retrieved_contexts": doc_texts,
        "answer": answer_text,
        "reference": reference,
    }
    for answer_text, question, doc_texts, reference in zip(
        llm_answers, answerable_queries, all_document_texts, correct_answers
    )
)

with open(generated_answers_file_path, "w") as out_file:
    json.dump(generated_answers, out_file, indent=4)