In [None]:
import json
import nltk
import torch
from transformers import AutoModel

# Download the NLTK "punkt_tab" resource
# (presumably required by the pruning model's remote code -- TODO confirm).
nltk.download("punkt_tab")

# Input / output locations
retrieved_contexts_file_path = "/kaggle/input/retrieved-contexts/retrieved_contexts_45.json"
benchmark_file_path = "/kaggle/input/benchmark-techqa/benchmark_query_rewriting.json"
output_file_path = "retrieved_contexts_45_pruned.json"

# Read both JSON inputs into memory
with open(retrieved_contexts_file_path, "r") as contexts_handle:
    retrieved_contexts = json.load(contexts_handle)

with open(benchmark_file_path, "r") as benchmark_handle:
    benchmark_instances = json.load(benchmark_handle)

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pruning model, then move it onto the selected device
pruning_model = AutoModel.from_pretrained(
    "naver/provence-reranker-debertav3-v1",
    trust_remote_code=True,
)
pruning_model = pruning_model.to(device)

# Prune the sections of every retrieved document against the rewritten
# question of the matching benchmark instance. Each section dict gains a
# "pruned_context" key; documents are updated in place and collected, per
# question, into new_retrieved_contexts.
# NOTE: zip() stops at the shorter input, so any surplus benchmark
# instances or retrieved contexts are silently ignored.
new_retrieved_contexts = []
quest_count = 0
for benchmark_instance, retrieved_documents in zip(benchmark_instances, retrieved_contexts):

    # Progress indicator: print the running count every 50 questions
    if quest_count % 50 == 0:
        print(str(quest_count))
    quest_count += 1

    new_question = benchmark_instance["rewrited_question"]

    new_context = []
    for document in retrieved_documents:
        # Split sections into those that have text to prune and those that
        # do not. The section objects themselves are remembered (not only
        # their text) so every pruned output can be written back to the
        # exact section it was computed from.
        sections_with_text = []
        for section in document["sections"]:
            if section["section_text"] != "":
                sections_with_text.append(section)
            else:
                # Nothing to prune: keep the section, with an empty result,
                # so the document structure is preserved.
                section["pruned_context"] = ""

        # Prune sections text (skip the model call when there is no text)
        if sections_with_text:
            section_texts = [section["section_text"] for section in sections_with_text]
            pruning_model_outputs = pruning_model.process(new_question, [section_texts])["pruned_context"][0]
            # Both sequences come from the same filtered list, so the pairs
            # stay aligned even when the document has empty sections.
            for section, pruned_text in zip(sections_with_text, pruning_model_outputs):
                # Add pruned text to section
                section["pruned_context"] = pruned_text

        # All sections are kept, so the document (mutated in place) is the
        # new document.
        new_context.append(document)

    new_retrieved_contexts.append(new_context)

# Write the pruned contexts to the output file as indented JSON
with open(output_file_path, "w") as output_handle:
    json.dump(obj=new_retrieved_contexts, fp=output_handle, indent=4)