In [1]:
import json
import random
import sys
import os

# Make the directory three levels above the current working directory
# importable, so the `src.*` imports that follow can be resolved.
project_root = os.path.abspath(os.path.join(os.getcwd(), "..", "..", ".."))
if sys.path.count(project_root) == 0:
    # Prepend rather than append so this checkout wins over any other
    # `src` package that might already be on the path.
    sys.path.insert(0, project_root)

from src.agents.context_qa_correction_agent import ContextQACorrectionAgent
from src.utils.data_generation import split_dataset, concurrent_data_postprocessing
# Seed Python's global `random` generator with a fixed value so that any use of
# it during this run is repeatable (no direct use is visible in this notebook;
# presumably the imported helpers rely on it — TODO confirm).
random.seed(2)

In [2]:
# --- Run configuration (listed in the order the pipeline consumes them) ---
# Dataset split name; used when building dataset file paths.
SPLIT = "train"
# Number of top-level chunks the loaded dataset is split into.
NUM_MAJOR_CHUNKS = 10
# Number of sub-splits made from each major chunk before post-processing.
PROC_NUM = 2
# Forwarded to ContextQACorrectionAgent(error_detection_only=...).
ERROR_DETECTION_ONLY = False

In [3]:
# Load the raw model responses for the configured split. The path is built
# from SPLIT (instead of a hard-coded "train") so the input always matches the
# split that the output file is written under.
# NOTE(review): assumes every split follows the same
# `<split>/<split>_rajpurkar_squad.json` layout — confirm for non-train splits.
raw_responses_path = f"../../../dataset/raw_model_responses/{SPLIT}/{SPLIT}_rajpurkar_squad.json"
# Explicit UTF-8: the platform default text encoding is not UTF-8 everywhere,
# and JSON text is UTF-8 by specification.
with open(raw_responses_path, "r", encoding="utf-8") as f:
    json_data = json.load(f)

In [4]:
# Debug aid: uncomment the next line to run on only the first 500 records.
# json_data = json_data[:500]

In [5]:
# Split the loaded records into NUM_MAJOR_CHUNKS major chunks (the exact
# partitioning is defined by split_dataset in src.utils.data_generation); the
# processing loop further down handles one major chunk per iteration.
data_chunks = split_dataset(json_data, NUM_MAJOR_CHUNKS)

In [6]:
# Build the agent that is handed to concurrent_data_postprocessing; the
# error_detection_only flag comes straight from the configuration constant.
agent = ContextQACorrectionAgent(error_detection_only=ERROR_DETECTION_ONLY)

In [7]:
def extract_args(item: dict) -> dict:
    """Flatten one raw-response record into the argument dict for the agent run.

    This function is passed as the ``extract_args`` callback to
    ``concurrent_data_postprocessing``.

    Args:
        item: A record with top-level ``"input"`` and ``"response"`` keys and
            an ``"additional_info"`` mapping that holds ``"question"``,
            ``"answer"`` and ``"context"``.

    Returns:
        A dict with the keys ``"input"``, ``"question"``, ``"answer"``,
        ``"context"`` and ``"responses"``. Note the rename: the record's
        singular ``"response"`` value is returned under the plural
        ``"responses"`` key, otherwise unchanged.

    Raises:
        KeyError: If any of the expected keys is missing from ``item``.
    """
    return {
        "input": item["input"],
        "question": item["additional_info"]["question"],
        "answer": item["additional_info"]["answer"],
        "context": item["additional_info"]["context"],
        "responses": item["response"],
    }

In [None]:
final_result = []

for index, chunk in enumerate(data_chunks):
    print(f"Processing chunk {index + 1} of {len(data_chunks)}")
    proc_data_split = split_dataset(chunk, PROC_NUM)
    data = await concurrent_data_postprocessing(
        agent=agent, 
        data_chunks=proc_data_split, 
        extract_args=extract_args,
        max_concurrency=2,
    )
    final_result.extend(data)

In [9]:
# Sanity check: total number of items accumulated across all major chunks.
len(final_result)

87557

In [10]:
# Persist the processed records under the configured split's output folder.
output_path = f"../../../dataset/processed_data/{SPLIT}/rajpurkar_squad_processed.json"
# Create the target directory first so a missing folder cannot discard the
# results of the processing run at the very last step.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(final_result, f, indent=4)