## Prepare QA

In [2]:
# ...existing code...
import json
def prepare_qa(file_path: str):
    """
    Prepares QA pairs from a JSONL file.
    """
    with open(file_path, 'r', encoding='utf-8') as file:  # Specify UTF-8 encoding
        data = [json.loads(line) for line in file]

    qa = []
    for item in data:
        qa.append({'question': item['question'], 'options': item['options'], 'answer': item['answer_idx']})
    
    return qa

# ...existing code...

In [3]:
QA = prepare_qa(r"E:\Git_clone\RAG\qa_dataset\data_clean\questions\US\4_options\phrases_no_exclude_test.jsonl")

## Evaluation

In [4]:
import chatbot, importlib
importlib.reload(chatbot)
from chatbot import Chatbot
chatbot = Chatbot("mistral")

### Vector retriever

In [5]:
from vectordb import create_retriever

In [6]:
vretriever = create_retriever(r"/workspaces/YuE/faiss_index")

### Graph retriever

In [7]:
from graphdb import gretriever

In [8]:
def retrieve(query, rag_type = None, k = 5):
    if rag_type == None:
        return ""
    elif rag_type == "rag":
        contexts = vretriever.get_relevant_documents(query, k=k)
        return "\n\n".join([context.page_content for context in contexts])
    elif rag_type == "grag":
        contexts = vretriever.get_relevant_documents(query, k=k)
        contexts = gretriever("\n".join([context.page_content for context in contexts]), extract_model="mistral")
        return "\n".join([context for context in contexts])

In [9]:
def process_qa(qa, rag_type = None, k = 5):
    context = retrieve(qa['question'], rag_type, k=k)
    prompt = f"""
    You are a medical expert. Answer the question by coorperate the provided context with your knowledgement.
    Document: {context}
    Question: {qa['question']}
    Options:
    A: {qa['options']['A']}
    B: {qa['options']['B']}
    C: {qa['options']['C']}
    D: {qa['options']['D']}
    Answer: (Only return A:, B:, C:, or D: without any explanation or other text. Do not include the context or the question in your answer.) 
    """
    response = chatbot.chat(prompt)
    # print("Prompt: ", prompt, '\n')
    if "A:" in response:
        answer = "A"
    elif "B:" in response:
        answer = "B"
    elif "C:" in response:
        answer = "C"
    elif "D:" in response:
        answer = "D"
    else:
        answer = None

    # print(response," ", answer, "  ", qa['answer'])
    return answer == qa['answer']

In [10]:
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import time

def process_qa_with_retry(qa, rag_type=None, retries=8, delay=1):
    """
    Processes a single QA pair with retry logic for rate-limiting errors.

    Parameters:
        qa (dict): A single QA pair to process.
        rag_type (str): Type of retriever to use (e.g., "rag", "grag").
        retries (int): Number of retries for rate-limiting errors.
        delay (int): Delay in seconds between retries.

    Returns:
        bool: Whether the processed answer matches the expected answer.
    """
    for attempt in range(retries):
        try:
            return process_qa(qa, rag_type)
        except Exception as e:
            if "rate limit" in str(e).lower():
                # print(f"Rate limit encountered. Retrying in {delay} seconds... (Attempt {attempt + 1}/{retries})")
                time.sleep(delay)
            else:
                print(f"Error processing QA: {e}")
                break
    return False  # Return False if all retries fail

def process_qa_parallel(QA, n_workers=4, rag_type=None, k=5):
    """
    Processes QA pairs in parallel and calculates accuracy.

    Parameters:
        QA (list): List of QA pairs to process.
        n_workers (int): Number of worker threads to use.
        rag_type (str): Type of retriever to use (e.g., "rag", "grag").

    Returns:
        float: Accuracy of the processed QA pairs.
    """
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        # Use tqdm for progress tracking
        results = list(tqdm(
            executor.map(lambda qa: process_qa_with_retry(qa, rag_type, k), QA),
            total=len(QA),
            desc="Processing QA"
        ))

    # Calculate and return accuracy
    total_correct = sum(results)
    accuracy = total_correct / len(QA)
    return accuracy

In [1]:
accuracy = process_qa_parallel(QA, n_workers=1, rag_type="grag", k=8)
print(f"Accuracy: {accuracy}")


NameError: name 'process_qa_parallel' is not defined

In [None]:
accuracy

0.7250589159465829