In [1]:
import ast
import itertools
import os
import re

import pandas as pd
import seaborn as sns
from sklearn.metrics import f1_score

In [2]:
def barplot_metrics(dfs, metric_names):
    """Draw a grouped bar plot of scoring metrics, one hue per LLM.

    Args:
        dfs: dictionary in the form of {llm_name: results_df, ...}.
        metric_names: list of scoring-metric column names to be plotted.

    Returns:
        The object returned by ``sns.catplot`` for the long-form data.
    """
    llm_names = sorted(dfs.keys())
    per_llm_frames = []
    for llm_name in llm_names:
        per_llm_frames.append(dfs[llm_name][metric_names])

    # The concat keys become the outer index level, which reset_index()
    # exposes as the "level_0" column used as the melt identifier below.
    stacked = pd.concat(per_llm_frames, keys=llm_names).reset_index()
    long_form = stacked.melt(id_vars=["level_0"], value_vars=metric_names)
    long_form.columns = ["LLM", "Metric", "Value"]
    return sns.catplot(long_form, kind="bar", x="Metric", y="Value", hue="LLM")

In [3]:
# PhenoToGuideline results for each model, read as tab-separated text with a header row.
pheno2guide_gpt4o = pd.read_csv(
    "../../results/PhenoToGuideline_gpt-4o_results.txt",
    sep="\t",
    header=0,
)
pheno2guide_gpt4 = pd.read_csv(
    "../../results/PhenoToGuideline_gpt-4-turbo_results.txt",
    sep="\t",
    header=0,
)
pheno2guide_gpt3_5 = pd.read_csv(
    "../../results/PhenoToGuideline_gpt-3.5-turbo_results.txt",
    sep="\t",
    header=0,
)

In [4]:
# Display the gpt-4o results frame as the cell output.
pheno2guide_gpt4o

Unnamed: 0,drug,genes,question,answer,incorrect_recommendations,concurring_recommendation,llm_answer,bert_score_precision_llm_vs_ref,bert_score_recall_llm_vs_ref,bert_score_f1_llm_vs_ref,...,base_mpnet_llm_vs_ref,base_mpnet_llm_vs_concurring,base_mpnet_llm_vs_discordant,roberta_llm_vs_ref,roberta_llm_vs_concurring,roberta_llm_vs_discordant,gte_llm_vs_ref,gte_llm_vs_concurring,gte_llm_vs_discordant,gpt4_llm_vs_ref
0,abacavir,HLA-B,What would be the clinical guidance for someon...,Abacavir is not recommended,['Proceed taking abacavir as normal. There is ...,Avoid abacavir. Take an alternate therapy.,Individuals who are HLA-B*57:01 positive shoul...,-0.212475,0.180389,-0.048489,...,0.632196,0.579843,0.622965,0.536523,0.503806,0.460370,0.636677,0.644710,0.610551,1
1,abacavir,HLA-B,What would be the clinical guidance for someon...,Use abacavir per standard dosing guidelines,"['Avoid abacavir. Take an alternate therapy.',...",Proceed taking abacavir as normal. There is no...,For individuals who test negative for the HLA-...,-0.183660,0.091850,-0.060207,...,0.570103,0.604261,0.584354,0.524361,0.458540,0.542389,0.570296,0.537630,0.611667,1
2,allopurinol,HLA-B,What would be the clinical guidance for someon...,Allopurinol is contraindicated,['Proceed taking allopurinol as normal. There ...,Avoid allopurinol. Take an alternate therapy.,Individuals who are positive for the HLA-B*58:...,-0.170247,0.279745,0.013652,...,0.678736,0.636821,0.572698,0.440805,0.530582,0.448720,0.679917,0.634294,0.619782,1
3,allopurinol,HLA-B,What would be the clinical guidance for someon...,Use allopurinol per standard dosing guidelines,['Avoid allopurinol. Take an alternate therapy...,Proceed taking allopurinol as normal. There is...,If an individual tests negative for the HLA-B*...,-0.135312,0.412093,0.080035,...,0.625489,0.543427,0.585378,0.508970,0.440977,0.505617,0.659205,0.591414,0.585272,1
4,amikacin,MT-RNR1,What would be the clinical guidance for someon...,Avoid aminoglycoside anitbiotics unless the hi...,['Proceed taking amikacin as normal. There is ...,Avoid amikacin. Take an alternate therapy.,For individuals at increased risk of aminoglyc...,0.152064,0.217485,0.187191,...,0.797551,0.606724,0.581711,0.828169,0.530026,0.489887,0.676611,0.630893,0.566329,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
406,voriconazole,CYP2C19,What would be the clinical guidance for someon...,Choose an alternative agent that is not depend...,['Proceed taking voriconazole as normal. There...,Avoid voriconazole. Take an alternate therapy.,For a patient identified as a Poor Metabolizer...,0.291962,0.199882,0.247360,...,0.827334,0.641831,0.641043,0.811770,0.584432,0.552902,0.770611,0.705650,0.641404,1
407,vortioxetine,CYP2D6,What would be the clinical guidance for someon...,Initiate therapy with recommended starting dose.,['Avoid vortioxetine. Take an alternate therap...,Proceed taking vortioxetine as normal. There i...,For individuals who are intermediate metaboliz...,-0.085096,0.188625,0.038901,...,0.462844,0.711183,0.664401,0.438295,0.662196,0.618286,0.571327,0.654111,0.674231,0
408,vortioxetine,CYP2D6,What would be the clinical guidance for someon...,Select alternative drug not predominantly meta...,['Proceed taking vortioxetine as normal. There...,Avoid vortioxetine. Take an alternate therapy.,For individuals identified as Ultrarapid Metab...,0.110128,0.249666,0.179294,...,0.742795,0.637764,0.632960,0.699349,0.520659,0.504228,0.598524,0.583614,0.600740,0
409,vortioxetine,CYP2D6,What would be the clinical guidance for someon...,Initiate therapy with recommended starting dose.,['Avoid vortioxetine. Take an alternate therap...,Proceed taking vortioxetine as normal. There i...,For an individual who is an Intermediate Metab...,-0.057155,0.158047,0.043935,...,0.421226,0.702351,0.680508,0.444056,0.627440,0.606462,0.453617,0.610785,0.605827,0


In [5]:
# Scorer names (embedding functions from text_embeddings.py). The order here is
# the order in which the per-scorer results are printed further down.
embedding_funcs = [
    'oai_embedding',
    'negation_mpnet',
    'base_mpnet',
    'roberta',
    'gte',
    'bert_score_f1',
    'bert_score_precision',
    'bert_score_recall',
]

# Disabled: bar plot of the "llm_vs_ref" and "llm_vs_discordant" scores for every
# scorer, across the three models.
# metrics = [[f"{func_name}_llm_vs_ref", f"{func_name}_llm_vs_discordant"] for func_name in embedding_funcs]
# metrics = sum(metrics, [])

# barplot_metrics(
#     dfs={
#         "gpt-4o": pheno2guide_gpt4o,
#         "gpt-4-turbo": pheno2guide_gpt4, 
#         "gpt-3.5-turbo": pheno2guide_gpt3_5,
#     }, 
#     metric_names=metrics
# )

In [6]:
def calc_ref_win_rate(scorer_name, results_df, col_1="llm_vs_ref", col_2="llm_vs_discordant", gold_match_col=None):
    """Fraction of rows in which a scorer rates ``col_1`` strictly above ``col_2``.

    The compared columns are ``f"{scorer_name}_{col_1}"`` and
    ``f"{scorer_name}_{col_2}"``. The comparison is strict, so ties do not
    count as wins.

    Args:
        scorer_name: prefix of the score columns (e.g. an embedding function name).
        results_df: DataFrame containing both score columns and, when requested,
            the gold column.
        col_1: suffix of the column expected to score higher.
        col_2: suffix of the column it is compared against.
        gold_match_col: optional name of a boolean/0-1 column holding the gold
            (human) judgement per row.

    Returns:
        ``llm_win_rate`` alone when ``gold_match_col`` is not given; otherwise
        the tuple ``(llm_win_rate, gold_win_rate, metric_gold_f1)`` where
        ``metric_gold_f1`` is the F1 of the scorer's wins against the gold labels.

    Raises:
        KeyError: if a required column is missing from ``results_df``.
    """
    wins = results_df[f"{scorer_name}_{col_1}"] > results_df[f"{scorer_name}_{col_2}"]
    llm_win_rate = wins.mean()

    if not gold_match_col:
        return llm_win_rate

    gold_matches = results_df[gold_match_col]
    gold_win_rate = gold_matches.mean()
    # The gold labels are the ground truth and the scorer's wins are the
    # predictions; pass them by keyword so the roles cannot be swapped.
    metric_gold_f1 = f1_score(y_true=gold_matches, y_pred=wins)
    return llm_win_rate, gold_win_rate, metric_gold_f1

In [7]:
# Reference-vs-discordant win rate of every scorer on the gpt-4o results.
for scorer in embedding_funcs:
    win_rate = calc_ref_win_rate(scorer, pheno2guide_gpt4o)
    print(f"{scorer}| win rate: {win_rate}")

oai_embedding| win rate: 0.5304136253041363
negation_mpnet| win rate: 0.5036496350364964
base_mpnet| win rate: 0.45742092457420924
roberta| win rate: 0.4233576642335766
gte| win rate: 0.38686131386861317
bert_score_f1| win rate: 0.49635036496350365
bert_score_precision| win rate: 0.7055961070559611
bert_score_recall| win rate: 0.0340632603406326


In [8]:
# Reference-vs-discordant win rate of every scorer on the gpt-4-turbo results.
for scorer in embedding_funcs:
    win_rate = calc_ref_win_rate(scorer, pheno2guide_gpt4)
    print(f"{scorer}| win rate: {win_rate}")

oai_embedding| win rate: 0.49635036496350365
negation_mpnet| win rate: 0.4306569343065693
base_mpnet| win rate: 0.43795620437956206
roberta| win rate: 0.45012165450121655
gte| win rate: 0.36982968369829683
bert_score_f1| win rate: 0.5620437956204379
bert_score_precision| win rate: 0.7250608272506083
bert_score_recall| win rate: 0.072992700729927


In [9]:
# Reference-vs-discordant win rate of every scorer on the gpt-3.5-turbo results.
for scorer in embedding_funcs:
    win_rate = calc_ref_win_rate(scorer, pheno2guide_gpt3_5)
    print(f"{scorer}| win rate: {win_rate}")

oai_embedding| win rate: 0.5036496350364964
negation_mpnet| win rate: 0.41605839416058393
base_mpnet| win rate: 0.44282238442822386
roberta| win rate: 0.44768856447688565
gte| win rate: 0.3236009732360097
bert_score_f1| win rate: 0.5255474452554745
bert_score_precision| win rate: 0.7177615571776156
bert_score_recall| win rate: 0.0340632603406326


# How often does the LLM answer have the highest similarity to the reference answer?

In [10]:
# Human annotations; the `document` column holds the raw "$llm_answer: ... $cand_N: ..." text.
annot = pd.read_csv('./human_annot.csv')

In [11]:
def parse_string(input_str):
    """Split an annotation document into its LLM answer and candidate answers.

    A section starts at a ``$llm_answer:`` or ``$cand_<digits>:`` tag and runs
    until the next ``$<word>:`` tag or the end of the string.

    Args:
        input_str: raw document text.

    Returns:
        dict with two keys: "llm_answer" (the stripped text of the first
        ``$llm_answer:`` section, or "" when there is none) and "cands" (the
        stripped text of every ``$cand_<digits>:`` section, in order of
        appearance).
    """
    # DOTALL lets a section span several lines.
    llm_pattern = re.compile(r"\$llm_answer:\s*(.*?)(?=\$\w+:|$)", re.DOTALL)
    cand_pattern = re.compile(r"(\$cand_\d+):\s*(.*?)(?=\$\w+:|$)", re.DOTALL)

    answer_hit = llm_pattern.search(input_str)
    answer_text = answer_hit.group(1).strip() if answer_hit else ""

    # findall yields (tag, text) pairs; only the text is kept.
    candidates = [body.strip() for _tag, body in cand_pattern.findall(input_str)]

    return {"llm_answer": answer_text, "cands": candidates}

In [12]:
# Parse every annotated document once and reuse the result below.
parsed_docs = [parse_string(doc) for doc in annot.document]

# Reference answer is always the last of the candidates
ref_answer = [len(parsed['cands']) - 1 for parsed in parsed_docs]
# LLM answer text of each annotated document, in annotation order
docs = [parsed['llm_answer'] for parsed in parsed_docs]

annot['ref_answer'] = ref_answer

In [13]:
# Display the annotation frame, now including the derived `ref_answer` column.
annot

Unnamed: 0,document,cand 0,cand 1,cand 2,cand 3,cand 4,cand 5,None of the above,ref_answer
0,$llm_answer:\nFor individuals who test negativ...,0,0,0,1,0,0,0,3
1,$llm_answer:\nFor individuals who have a genet...,0,0,0,1,0,0,0,3
2,$llm_answer:\nClinical guidance for an individ...,0,0,0,1,0,0,0,3
3,$llm_answer:\nIn individuals with an indetermi...,0,0,0,1,0,0,0,3
4,$llm_answer:\nFor a patient identified as a CY...,0,0,0,0,0,0,1,2
...,...,...,...,...,...,...,...,...,...
72,$llm_answer:\nFor a patient with an indetermin...,0,0,0,0,0,0,1,3
73,$llm_answer:\nFor someone identified as an Ult...,0,0,0,1,0,0,0,3
74,$llm_answer:\nFor an intermediate metabolizer ...,1,0,0,0,0,0,0,3
75,$llm_answer:\nFor an individual who is an ultr...,1,0,0,0,0,0,0,3


In [14]:
# True if human-rated best match for the LLM answer was the reference answer
llm_match = []
for row_pos, ref_idx in enumerate(annot['ref_answer']):
    human_votes = annot.iloc[row_pos]
    # The reference answer sits in the column "cand <ref_idx>" of this row.
    llm_match.append(human_votes[f'cand {ref_idx}'] > 0)

# Share of documents where the annotators picked the reference answer.
sum(llm_match) / len(llm_match)

0.6233766233766234

In [15]:
# Rescore old version
# (`os` is already imported in the notebook's first cell.)
from TestUtils import PhenoToGuidelineTestRunner
import openai

SYSTEM_PROMPT = "You are an AI assistant that provides evidence-based responses to pharmacogenomics questions. Please respond to the following query."

# Fail fast with a KeyError when the Helicone key is missing, instead of
# silently sending the header value "Bearer None".
helicone_api_key = os.environ["HELICONE_API_KEY"]

# OpenAI client that sends its requests to the Helicone base URL, with the
# Helicone cache header enabled.
gpt_client = openai.OpenAI(
    organization=os.environ.get("KIMLAB_OAI_ID"),
    api_key=os.environ.get("OPENAI_API_KEY"),
    base_url="https://oai.hconeai.com/v1",
    default_headers={
        "Helicone-Auth": f"Bearer {helicone_api_key}",
        "Helicone-Cache-Enabled": "true",
    },
)

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
llm_answers = pd.read_csv("./old_version.tsv", sep="\t", header=0)
# Rows without a reference answer are dropped.
llm_answers = llm_answers.dropna(axis=0, subset=['answer'])
# The column is expected to hold Python list literals stored as text.
# ast.literal_eval parses literals only, so content of the TSV file can never
# be executed as code (which a bare eval would allow).
llm_answers['incorrect_recommendations'] = llm_answers['incorrect_recommendations'].map(ast.literal_eval)
runner = PhenoToGuidelineTestRunner(gpt_client, 'gpt-4o', SYSTEM_PROMPT, bert_score_model = 'microsoft/deberta-xlarge-mnli')

In [17]:
# score_answers returns a (scores, summary) pair; the scores frame is then
# indexed by the LLM answer text so rows can be looked up by answer.
scores_raw, summary = runner.score_answers(llm_answers)
scores = scores_raw.set_index('llm_answer')

100%|████████████████████████████████████████████████████████████████████████| 410/410 [01:59<00:00,  3.42it/s]

                                        bert_score_precision_llm_vs_ref  \
bert_score_precision_llm_vs_ref                                1.000000   
bert_score_recall_llm_vs_ref                                   0.606471   
bert_score_f1_llm_vs_ref                                       0.952876   
bert_score_precision_llm_vs_concurring                         0.054360   
bert_score_recall_llm_vs_concurring                            0.205143   
bert_score_f1_llm_vs_concurring                                0.160513   
bert_score_precision_llm_vs_discordant                         0.060230   
bert_score_recall_llm_vs_discordant                            0.019495   
bert_score_f1_llm_vs_discordant                                0.070926   
oai_embedding_llm_vs_ref                                       0.756904   
oai_embedding_llm_vs_concurring                               -0.074818   
oai_embedding_llm_vs_discordant                               -0.030229   
negation_mpnet_llm_vs_ref




In [21]:
# Select the scored rows for the annotated LLM answers (in annotation order)
# and attach the human verdict as a new column.
human_graded = (
    scores.loc[docs]
    .reset_index()
    .assign(human_matched_ref=llm_match)
)

In [22]:
# Display the human verdict column as the cell output.
human_graded['human_matched_ref']

0      True
1      True
2      True
3      True
4     False
      ...  
72    False
73     True
74    False
75    False
76     True
Name: human_matched_ref, Length: 77, dtype: bool

In [24]:
# Compare each scorer's win decisions with the human judgement on the annotated subset.
gold_rate = human_graded['human_matched_ref'].mean()
print(f"Gold standard win rate: {gold_rate}")
for scorer in embedding_funcs:
    # The middle element (gold win rate) is the same for every scorer and is printed above.
    llm_win_rate, _gold_win_rate, metric_gold_f1 = calc_ref_win_rate(
        scorer, human_graded, gold_match_col='human_matched_ref'
    )
    print(f"{scorer} \t| LLM win rate: {llm_win_rate} \t| F1 with human judge: {metric_gold_f1}")

Gold standard win rate: 0.6233766233766234
oai_embedding 	| LLM win rate: 0.5324675324675324 	| F1 with human judge: 0.6067415730337079
negation_mpnet 	| LLM win rate: 0.45454545454545453 	| F1 with human judge: 0.5542168674698795
base_mpnet 	| LLM win rate: 0.42857142857142855 	| F1 with human judge: 0.49382716049382713
roberta 	| LLM win rate: 0.45454545454545453 	| F1 with human judge: 0.5060240963855421
gte 	| LLM win rate: 0.36363636363636365 	| F1 with human judge: 0.5526315789473685
bert_score_f1 	| LLM win rate: 0.5064935064935064 	| F1 with human judge: 0.5517241379310345
bert_score_precision 	| LLM win rate: 0.7012987012987013 	| F1 with human judge: 0.6274509803921569
bert_score_recall 	| LLM win rate: 0.025974025974025976 	| F1 with human judge: 0.08
