In [None]:
import json
from peft import PeftModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Kaggle input files (retrieved contexts and benchmark questions) and the
# output file that will receive the answerability labels.
retrieved_contexts_file_path = "/kaggle/input/retrieved-contexts/retrieved_contexts_45.json"
benchmark_file_path = "/kaggle/input/benchmark-techqa/benchmark_query_rewriting.json"
file_path_out = "/kaggle/working/answerable_questions.json"

# Load both inputs.  The encoding is pinned to UTF-8 so that parsing does not
# depend on the platform's locale default.
with open(retrieved_contexts_file_path, "r", encoding="utf-8") as contexts_file:
    retrieved_contexts = json.load(contexts_file)

with open(benchmark_file_path, "r", encoding="utf-8") as benchmark_file:
    benchmark_instances = json.load(benchmark_file)

# Role tag that asks the LoRA adapter for an answerability prediction.
ANSWERABILITY_PROMPT = "<|start_of_role|>answerability<|end_of_role|>"

# Identifiers of the base model and of the answerability LoRA adapter.
model_name = "ibm-granite/granite-3.2-8b-instruct"
LORA_NAME = "ibm-granite/granite-3.2-8b-lora-rag-answerability-prediction"
    
# Choose where tensors and model weights live.
# `device_map` has to be bound on every path because the model loader below
# reads it unconditionally; binding it only in the multi-GPU branch makes
# single-GPU and CPU runs fail with a NameError.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print(f"Usando {gpu_count} GPU")

if torch.cuda.is_available():
    device = torch.device("cuda")
    # "auto" lets the loader spread the weights over every visible GPU
    # (it also works with a single one).
    device_map = "auto"
else:
    device = torch.device("cpu")
    # No accelerator: fall back to the loader's default placement.
    device_map = None

# 8-bit quantization to reduce the memory footprint of the model.
# NOTE(review): the documented BitsAndBytesConfig options for compute dtype,
# double quantization and quant type are the `bnb_4bit_*` ones; the
# `bnb_8bit_*` names below are presumably swallowed as unused kwargs and have
# no effect -- confirm against the installed transformers version.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_compute_dtype=torch.float16,
    bnb_8bit_use_double_quant=True,
    bnb_8bit_quant_type="nf8",
)

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    padding_side="left",
    trust_remote_code=True,
)

# Load the quantized base model.
model_base = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    quantization_config=quantization_config,
    device_map=device_map,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Wrap the base model with the answerability LoRA adapter.
model_answerability = PeftModel.from_pretrained(model_base, LORA_NAME)

# Predict, for every benchmark question, whether it is answerable from its
# retrieved documents.  The result is a list of labels in input order:
# 1 = answerable, 0 = unanswerable.

# Questions and retrieved contexts are paired by position.  zip() stops at the
# shorter input, so make a length mismatch visible instead of silently
# dropping the unpaired tail.
if len(benchmark_instances) != len(retrieved_contexts):
    print(
        f"Warning: {len(benchmark_instances)} benchmark instances but "
        f"{len(retrieved_contexts)} retrieved context lists; "
        "only the common prefix will be evaluated"
    )

answerable_questions = []
for question_index, (benchmark_instance, retrieved_documents) in enumerate(
    zip(benchmark_instances, retrieved_contexts)
):
    # Progress indicator every 50 questions.
    if question_index % 50 == 0:
        print(str(question_index))

    # Rebuild each document's full text by concatenating its sections.
    documents = [
        {
            "text": "".join(
                section["section_text"] for section in document["sections"]
            ),
            "document_id": document["document_id"],
        }
        for document in retrieved_documents
    ]

    messages = [
        {"role": "user", "content": benchmark_instance["question"]}
    ]

    # Chat-formatted prompt (question + documents) followed by the
    # answerability role tag that triggers the adapter's prediction.
    string = tokenizer.apply_chat_template(
        messages,
        documents=documents,
        tokenize=False,
        add_generation_prompt=False,
    )
    inputs = string + ANSWERABILITY_PROMPT

    # NOTE(review): if truncation drops tokens from the end of the sequence
    # (the usual tokenizer default), a prompt longer than max_length would
    # lose the trailing answerability tag -- confirm prompts stay below the
    # limit or that the tokenizer truncates from the left.
    inputs_tokenized = tokenizer(
        inputs,
        return_tensors="pt",
        truncation=True,
        max_length=32768,
    )

    input_ids = inputs_tokenized["input_ids"].to(device)
    attention_mask = inputs_tokenized["attention_mask"].to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # A few new tokens are enough for the short answerability label.
        output = model_answerability.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=3,
            use_cache=True,
        )

    # Keep only the tokens generated after the prompt.
    generated_only_ids = output[0][input_ids.shape[-1]:]
    answerability = tokenizer.decode(
        generated_only_ids, skip_special_tokens=True
    ).strip()

    # Any output that does not contain "unanswerable" counts as answerable.
    answerable_questions.append(0 if "unanswerable" in answerability else 1)

print("Ended answerability evaluation")

with open(file_path_out, "w", encoding="utf-8") as out_file:
    json.dump(answerable_questions, out_file, indent=4)