In [15]:
import json

In [None]:
from glob import glob
from tqdm import tqdm 
from math_verify import verify, parse


def do_verify(nsol, b):
    res = 0
    try:
        a = parse(nsol)
        if len(a)==0: res = -4 # empty return when parsing prediction
        else: res = int(verify(a, b))
    except: 
        res = -5 # exception when parsing prediction
    return res 

def normalize_answer(answer):
    if 'dfrac' in answer: answer = answer.replace("dfrac", "frac")
    if '%' in answer: answer = answer.replace(r'\%',"").replace('%',"")
    if 'text' in answer: answer = answer.replace("\\text","")
    if "\\varnothing" in answer: answer = answer.replace("\\varnothing","\\emptyset")
    if "minutes" in answer: answer = answer.replace("minutes","")
    if "cm" in answer: answer = answer.replace("cm","")
    # if "a.m." in answer: answer = answer.replace("a.m.","")
    return answer 
    
def handle_boxed(sol, gt):
    flag = True 
    # try parsing gt 
    gt = normalize_answer(gt)
    try:
        b = parse(gt)
    except:
        res = -6 # exception when parsing gt
        flag = False 
        
    if flag:
        if len(b)==0: 
            b = parse(f"\\boxed{{{gt}}}") # second try to parse gt 
            
        if len(b)==0:
            res = -2
        else: 
            # if boxed, take the final boxed, otherwise parse() takes the first
            if '\\boxed' in sol: 
                boxed_index = sol.rindex("\\boxed")
                nsol = sol[boxed_index:]
                nsol = normalize_answer(nsol)
                res = do_verify(nsol, b)
                if res<0.5: 
                    if nsol.startswith('\\boxed{'+gt): res = 1.0
            else: # second try ?
                res = -3 # no boxed found in solution 
    return res


def extract_dpsk_query_and_response(input_text):
    # Split the input text by the assistant's start token
    # print(input_text)
    parts = input_text.split("<｜Assistant｜>")
    
    # The first part contains the system and user messages
    if len(parts)==0:
        print('!!!! warning extraction', input_text)
    user_part = parts[0]
    
    # The second part contains the assistant's response
    if len(parts)==1: assistant_response = ""
    else: assistant_response = parts[1]
    
    # Extract the user query by splitting the user part
    user_query = user_part.split("<｜User｜>")[1]
    
    # Return the user query and the assistant's response
    return user_query, assistant_response


def extract_qwen_query_and_response(input_text):
    # Split the input text by the assistant's start token
    parts = input_text.split("<|im_start|>assistant\n")
    
    # The first part contains the system and user messages
    user_part = parts[0]
    
    # The second part contains the assistant's response
    if len(parts)==1: assistant_response = ""
    else: assistant_response = parts[1]
    
    # Extract the user query by splitting the user part
    user_query = user_part.split("<|im_start|>user\n")[1].split('<|im_end|>')[0]
    
    # Return the user query and the assistant's response
    return user_query, assistant_response

In [36]:

fp = "/home/ma-user/work/haozhe/workspace/deepscaler2/inference_out/eval-medqweninsbase-ml4096-n1"
is_qwen = True 
data = []
for f in tqdm(glob(fp+'/*.json')):
    tmp = json.load(open(f))
    tmp2 = []
    for k in ['text','answers','bench']:
        tmp2.append(tmp[k])
    for a,b,c in zip(*tmp2): 
        data.append(dict(query=a, answer=b, bench=c))

# texts = get_data(fp)
cnt = 0
results = []
if is_qwen: extract = extract_qwen_query_and_response
else: extract = extract_dpsk_query_and_response


100%|██████████| 22/22 [00:00<00:00, 565.24it/s]


In [37]:
data[0]

{'query': "<|im_start|>system\nPlease reason step by step and put the final answer within \\boxed{}.<|im_end|>\n<|im_start|>user\nA 26-year-old nullipara presents to her physician for a routine check-up at 18 weeks gestation. She has no co-morbidities. Her only complaints are fatigability and a depressed mood for the past 2 weeks. Her vital signs are as follows: blood pressure, 125/80 mm Hg; heart rate, 87/min; respiratory rate, 14/min; and temperature, 36.7℃ (98℉). The physical examination is unremarkable and the gynecologic examination is consistent with 18 weeks gestation. A thyroid profile s ordered to check for a possible cause of the fatigability and decreased mood:\nThyroid stimulating hormone (TSH) 0.3 mU/L\nTotal T4 160 nmol/L\nFree T4 13 pmol/L\nCorresponding to the obtained results, how should the patient be managed? A: Prescribe levothyroxine 50 mcg daily B: No specific management required C: Recommend additional anti-TPO test D: Recommend additional T3 assessment<|im_end|>

In [42]:
for item in tqdm(data):
    query = item['query']
    q, rsp = extract(query)
    gold = item['answer']
    
    match = handle_boxed(rsp, gold)
    results.append(dict(q=q, gold=gold, rsp=rsp, match=match, bench=item['bench']))


100%|██████████| 1408/1408 [00:06<00:00, 209.14it/s]


In [44]:
import pandas as pd

def print_stats_by_bench(df, bench_col='bench'):
    """
    Prints statistics of other columns in a DataFrame, grouped by unique values in a specified 'bench' column.

    Args:
        df (pd.DataFrame): The input DataFrame.
        bench_col (str): The name of the column to group by (default: 'bench').
    """
    if bench_col not in df.columns:
        print(f"Error: Column '{bench_col}' not found in the DataFrame.")
        return

    for bench_value in df[bench_col].unique():
        print(f"\nStatistics for bench value: {bench_value}")
        subset_df = df[df[bench_col] == bench_value]

        # Exclude the bench column itself from statistics
        other_cols = [col for col in ['match'] if col != bench_col]

        if other_cols:
            print(subset_df[other_cols].describe()) # describe() is the most useful default
        else:
            print("No other columns to describe.")
print_stats_by_bench(pd.DataFrame(results))



Statistics for bench value: MedQA-USLME-Test
             match
count  1275.000000
mean      0.401569
std       0.923442
min      -3.000000
25%       0.000000
50%       1.000000
75%       1.000000
max       1.000000

Statistics for bench value: GPQA-Medical-Test
            match
count  135.000000
mean     0.355556
std      0.706755
min     -3.000000
25%      0.000000
50%      0.000000
75%      1.000000
max      1.000000


In [43]:
idx = 0
for item in results:
    if item['match']<0.5:
        print(item['rsp'])
        print('gold=', item['gold'])
        print('==========')
        if idx>10: break 
        idx += 1

To determine the most likely novel drug that would benefit this patient in addition to valsartan, let's analyze the patient's symptoms, medical history, and findings step by step:

1. **Symptoms and History:**
   - Easy fatigability and breathlessness on climbing stairs for 2 weeks.
   - Occasional night-time cough relieved by sitting upright.
   - No shortness of breath at rest, palpitations, or loss of consciousness.
   - Hypertension for 20 years, on antihypertensive medications.
   - Physical examination: temperature 36.9°C, pulse 104/min, blood pressure 122/82 mm Hg, respirations 18/min.
   - Chest auscultation reveals crackles over the lung bases bilaterally.
   - Abdominal examination reveals mildly tender hepatomegaly.
   - Lab results: hemoglobin 14.8 g/dL, elevated serum B-type natriuretic peptide (BNP).
   - Echocardiogram: enlarged left atrium, ejection fraction 55%.

2. **Interpretation of Findings:**
   - The patient's symptoms (breathlessness, night-time cough, crackles 

In [8]:
!pip install math-verify

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: http://repo.myhuaweicloud.com/repository/pypi/simple, https://pypi.ngc.nvidia.com

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [None]:
# from datasets import load_dataset
# def regularize(q):
#     return q.strip().replace(' ','')
# data = load_dataset("/home/ma-user/work/haozhe/workspace/OpenRLHF/data/0209_eval_amcaime_queries")['train']
# q2src = dict()
# q2gold = dict()
# for item in data:
#     q = item['question']
#     s = item['source'] 
#     nq = regularize(q)
#     q2src[nq] = s
#     q2gold[nq] = item['answer']