In [1]:
import os
os.environ['OPENAI_API_KEY'] = 'your_openai_api_key_here'
os.environ['MISTRAL_API_KEY'] = 'your_mistral_api_key_here'

## Prepare QA

In [26]:
# ...existing code...
import json
def prepare_qa(file_path: str):
    """
    Prepares QA pairs from a JSONL file.
    """
    with open(file_path, 'r', encoding='utf-8') as file:  # Specify UTF-8 encoding
        data = [json.loads(line) for line in file]

    qa = []
    for item in data:
        qa.append({'question': item['question'], 'options': item['options'], 'answer': item['answer_idx']})
    
    return qa

# ...existing code...

In [27]:
QA = prepare_qa(r"E:\Git_clone\RAG\qa_dataset\data_clean\questions\US\4_options\phrases_no_exclude_test.jsonl")

## Evaluation

In [28]:
import chatbot, importlib
importlib.reload(chatbot)
from chatbot import Chatbot
chatbot = Chatbot("mistral")

### Vector retriever

In [29]:
from vectordb import create_retriever

In [30]:
vretriever = create_retriever(r"/workspaces/YuE/faiss_index")

### Graph retriever

In [31]:
from graphdb import gretriever

In [32]:
def retrieve(query, rag_type = None, k = 5):
    if rag_type == None:
        return ""
    elif rag_type == "rag":
        contexts = vretriever.get_relevant_documents(query, k=k)
        return "\n\n".join([context.page_content for context in contexts])
    elif rag_type == "grag":
        contexts = vretriever.get_relevant_documents(query, k=k)
        contexts = gretriever("\n".join([context.page_content for context in contexts]), "openai")
        return "\n".join([context for context in contexts])

In [45]:
def process_qa(qa, rag_type = None):
    context = retrieve(qa['question'], rag_type)
    prompt = f"""
    You are a medical expert. Answer the question by coorperate the provided context with your knowledgement.
    Document: {context}
    Question: {qa['question']}
    Options:
    A: {qa['options']['A']}
    B: {qa['options']['B']}
    C: {qa['options']['C']}
    D: {qa['options']['D']}
    Answer: (Only return A:, B:, C:, or D:)
    """
    response = chatbot.chat(prompt)
    # print("Prompt: ", prompt, '\n')
    if "A:" in response:
        answer = "A"
    elif "B:" in response:
        answer = "B"
    elif "C:" in response:
        answer = "C"
    elif "D:" in response:
        answer = "D"
    else:
        answer = None

    print(response," ", answer, "  ", qa['answer'])
    return answer == qa['answer']

In [46]:
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import time

def process_qa_with_retry(qa, rag_type=None, retries=5, delay=1):
    """
    Processes a single QA pair with retry logic for rate-limiting errors.

    Parameters:
        qa (dict): A single QA pair to process.
        rag_type (str): Type of retriever to use (e.g., "rag", "grag").
        retries (int): Number of retries for rate-limiting errors.
        delay (int): Delay in seconds between retries.

    Returns:
        bool: Whether the processed answer matches the expected answer.
    """
    for attempt in range(retries):
        try:
            return process_qa(qa, rag_type)
        except Exception as e:
            if "rate limit" in str(e).lower():
                # print(f"Rate limit encountered. Retrying in {delay} seconds... (Attempt {attempt + 1}/{retries})")
                time.sleep(delay)
            else:
                print(f"Error processing QA: {e}")
                break
    return False  # Return False if all retries fail

def process_qa_parallel(QA, n_workers=4, rag_type=None):
    """
    Processes QA pairs in parallel and calculates accuracy.

    Parameters:
        QA (list): List of QA pairs to process.
        n_workers (int): Number of worker threads to use.
        rag_type (str): Type of retriever to use (e.g., "rag", "grag").

    Returns:
        float: Accuracy of the processed QA pairs.
    """
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        # Use tqdm for progress tracking
        results = list(tqdm(
            executor.map(lambda qa: process_qa_with_retry(qa, rag_type), QA),
            total=len(QA),
            desc="Processing QA"
        ))

    # Calculate and return accuracy
    total_correct = sum(results)
    accuracy = total_correct / len(QA)
    return accuracy

In [47]:
accuracy = process_qa_parallel(QA, n_workers=1, rag_type=None)
print(f"Accuracy: {accuracy}")

# 1273/1273
# Accuracy: 0.6197957580518461 (no rag)
# Accuracy: 0.6975648075412412 (rag)
# 

# 100/1273
# Acc: 0.71
# Acc: 0.72
# Acc: 0.72

Processing QA:   0%|          | 1/1273 [00:00<09:06,  2.33it/s]

A:   A    B


Processing QA:   0%|          | 2/1273 [00:00<08:17,  2.56it/s]

B:   B    D


Processing QA:   0%|          | 3/1273 [00:02<21:06,  1.00it/s]

B:   B    B


Processing QA:   0%|          | 4/1273 [00:05<37:27,  1.77s/it]

B:   B    D


Processing QA:   0%|          | 5/1273 [00:08<46:01,  2.18s/it]

B: Ketotifen eye drops   B    B


Processing QA:   0%|          | 6/1273 [00:08<32:55,  1.56s/it]

A:   A    D


Processing QA:   1%|          | 7/1273 [00:12<45:40,  2.16s/it]

B:   B    C


Processing QA:   1%|          | 8/1273 [00:13<42:51,  2.03s/it]

C:   C    C


Processing QA:   1%|          | 9/1273 [00:14<32:41,  1.55s/it]

B:   B    B


Processing QA:   1%|          | 10/1273 [00:15<31:44,  1.51s/it]

Based on the information provided, the most likely additional finding in a patient with nail changes is:

C: Erosions of the dental enamel

This is because nail changes can often be associated with conditions that also affect dental enamel, such as eating disorders or certain systemic conditions.   C    A


Processing QA:   1%|          | 11/1273 [00:17<34:31,  1.64s/it]

D:   D    D


Processing QA:   1%|          | 12/1273 [00:18<26:48,  1.28s/it]

D: Ruxolitinib   D    D


Processing QA:   1%|          | 13/1273 [00:18<20:57,  1.00it/s]

B:   B    B


Processing QA:   1%|          | 14/1273 [00:20<24:59,  1.19s/it]

A   None    D


Processing QA:   1%|          | 15/1273 [00:20<20:10,  1.04it/s]

C:   C    C


Processing QA:   1%|▏         | 16/1273 [00:21<18:27,  1.13it/s]

B:   B    B


Processing QA:   1%|▏         | 17/1273 [00:21<15:26,  1.36it/s]

D:   D    D


Processing QA:   1%|▏         | 18/1273 [00:24<30:49,  1.47s/it]

D:   D    D


Processing QA:   1%|▏         | 19/1273 [00:25<27:34,  1.32s/it]

A:   A    B


Processing QA:   2%|▏         | 20/1273 [00:27<30:23,  1.46s/it]

B:   B    D


Processing QA:   2%|▏         | 21/1273 [00:30<40:00,  1.92s/it]

C:   C    C


Processing QA:   2%|▏         | 22/1273 [00:31<30:51,  1.48s/it]

A:   A    A


Processing QA:   2%|▏         | 23/1273 [00:32<32:42,  1.57s/it]

C:   C    C


Processing QA:   2%|▏         | 24/1273 [00:33<25:40,  1.23s/it]

A:   A    D


Processing QA:   2%|▏         | 25/1273 [00:34<28:26,  1.37s/it]

A:   A    A


Processing QA:   2%|▏         | 26/1273 [00:38<39:36,  1.91s/it]

D:   D    D


Processing QA:   2%|▏         | 27/1273 [00:43<1:02:02,  2.99s/it]

D:   D    D


Processing QA:   2%|▏         | 28/1273 [00:44<45:55,  2.21s/it]  

A:   A    A


Processing QA:   2%|▏         | 29/1273 [00:45<38:48,  1.87s/it]

D:   D    D


Processing QA:   2%|▏         | 29/1273 [00:46<33:26,  1.61s/it]

C:   C    C





D:   D    D


KeyboardInterrupt: 

In [None]:
accuracy

0.7234878240377062