In [247]:
from collections import Counter
from math import log10
import re

import datasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pylatex
import seaborn as sns

from src import Icsr

# Use seaborn's white-grid look for every figure in this notebook.
sns.set_style(style='whitegrid')


In [249]:
# Load BioDEX-ICSR from the Hugging Face Hub (a cached copy is reused when present)
# and keep only its validation split for evaluation.
dataset = datasets.load_dataset(path="FAERS-PubMed/BioDEX-ICSR")
val = dataset["validation"]


Using custom data configuration FAERS-PubMed--BioDEX-ICSR-40aa49fec6af4868
Found cached dataset parquet (/Users/kldooste/.cache/huggingface/datasets/FAERS-PubMed___parquet/FAERS-PubMed--BioDEX-ICSR-40aa49fec6af4868/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/3 [00:00<?, ?it/s]

In [250]:
# post-process prediction
def postprocess(prediction):
    """Normalise a generated or target ICSR string before parsing/comparison.

    Every newline becomes a single space, the text is lower-cased, and
    leading/trailing whitespace is removed.
    """
    flattened = prediction.replace('\n', ' ')
    lowered = flattened.lower()
    return lowered.strip()

In [251]:
# load predictions
# One (model name, prediction file) pair per model. Each file is read line by
# line; line i is later aligned with row i of the validation split.
models_and_files = [
    ("flan-t5-large", "./predictions/generated-eval-predictions-flan-t5-large-s(2048)-t(258).txt"),
    ("gpt-4", "./predictions/generated-eval-predictions-gpt-4-0314-run01.txt")
]

models = [name for name, _ in models_and_files]


def read_predictions(path):
    """Return the lines of a prediction file, each normalised with ``postprocess``.

    The file is decoded explicitly as UTF-8 so the result does not depend on the
    platform's default encoding (model outputs may contain non-ASCII text).
    """
    with open(path, 'r', encoding='utf-8') as fp:
        return [postprocess(line) for line in fp]


predictions = [read_predictions(path) for _, path in models_and_files]

# parse icsrs
predictions_icsrs = [[Icsr.from_string(p) for p in output] for output in predictions]

In [252]:
# load targets and inputs
# Raw columns taken straight from the validation split.
inputs = val['fulltext_processed']
pmids = val['pmid']
# Gold ICSRs get the same normalisation as the predictions before being parsed.
targets = list(map(postprocess, val['target']))
targets_icsr = list(map(Icsr.from_string, targets))

In [253]:
# create dataframe
# Truncate every column to the shortest source so all columns have equal length.
# The validation columns are included because a prediction file could also be
# longer than the split (pd.DataFrame would otherwise raise on unequal lengths).
min_length = min(
    len(inputs),
    len(pmids),
    len(targets_icsr),
    *(len(p) for p in predictions),
)

data = {
    "input": inputs[:min_length],
    "pmid": pmids[:min_length],
    "target_icsr": targets_icsr[:min_length],
}
# One raw-output column and one parsed-ICSR column per model, in `models` order.
for model, outputs, icsrs in zip(models, predictions, predictions_icsrs):
    data[f"{model}_output"] = outputs[:min_length]
    data[f"{model}_icsr"] = icsrs[:min_length]

df = pd.DataFrame(data=data)

# Expand each parsed Icsr (target and every model) into one column per field.
icsr_fields = ['serious', 'patientsex', 'drugs', 'reactions']
for model in models + ['target']:
    for field in icsr_fields:
        # `field=field` binds the current field name to this lambda explicitly.
        df[f'{model}_{field}'] = df[f'{model}_icsr'].apply(
            lambda icsr, field=field: getattr(icsr, field)
        )


In [254]:
# Side-by-side view: for each ICSR field, the target followed by every model.
view_fields = ['serious', 'patientsex', 'drugs', 'reactions']
col_view = ['input', 'pmid'] + [
    f'{source}_{field}' for field in view_fields for source in ['target'] + models
]

df_view = df.reindex(columns=col_view)
# Flatten paragraph/line breaks so each article renders as one table cell.
df_view['input'] = df_view['input'].apply(lambda x: x.replace('\n\n', ' ').replace('\n', ' '))


def handle_list(element):
    """Join a list cell (e.g. drugs, reactions) with ', '; return other values unchanged."""
    if not isinstance(element, list):
        return element
    return ', '.join(element)


# Element-wise mapping done column by column; DataFrame.applymap is deprecated
# since pandas 2.1.
df_view = df_view.apply(lambda column: column.map(handle_list))

pd.set_option('display.max_colwidth', 1000)
# Let pandas escape every LaTeX special character (&, %, $, #, _, {, }, ...) in
# headers and cells. Escaping only underscores by hand left characters like '&'
# and '%', which are common in article text, to break the table.
latex = df_view.to_latex(index=False, longtable=True, escape=True)

with open('./predictions/prediction_comparison.tex', 'w', encoding='utf-8') as fp:
    fp.write(latex)

  latex = df_view.to_latex(index=False, longtable=True, escape=False)


In [261]:
# Settings for the single-column LaTeX comparison table.
input_cutoff = 2500  # max characters of article text shown per example
n_examples = 10  # number of validation examples rendered

# LaTeX snippet for one validation example: "(PMID: ...)" plus the article text,
# then a 3-column tabular ({lll}) comparing the target with each model on
# serious / patientsex, followed by rows spanning the last two columns for drugs
# and reactions. Literal braces are doubled ({{ }}) because the template is filled
# with str.format. The trailing backslash on each source line is a Python line
# continuation, so the resulting string has no embedded newlines. Placeholders such
# as {flan-t5-large_serious} are looked up by key in the `subtable.format(**...)`
# call inside row_to_text (hyphens are fine for keyword lookup).
subtable = '''(PMID: {pmid}) {input} \\\\ \\\\ \\begin{{tabular}}{{lll}} \
\\hline \
              & serious                                       & patientsex                                      \\\\ \\hline \
target        & {target_serious}                                             & {target_patientsex}                                                         \\\\ \
flan-t5-large & {flan-t5-large_serious}                                      & {flan-t5-large_patientsex}                                                  \\\\ \
gpt-4         & {gpt-4_serious}                                              & {gpt-4_patientsex}                                                          \\\\ \\hline \
              & \\multicolumn{{2}}{{l}}{{drugs}}                                                                       \\\\ \\hline \
target        & \\multicolumn{{2}}{{p{{13.2cm}}}}{{{target_drugs}}} \\\\ \
flan-t5-large & \\multicolumn{{2}}{{p{{13.2cm}}}}{{{flan-t5-large_drugs}}} \\\\ \
gpt-4         & \\multicolumn{{2}}{{p{{13.2cm}}}}{{{gpt-4_drugs}}} \\\\ \\hline \
              & \\multicolumn{{2}}{{l}}{{reactions}}                                                                   \\\\ \\hline \
target        & \\multicolumn{{2}}{{p{{13.2cm}}}}{{{target_reactions}}}                                                 \\\\ \
flan-t5-large & \\multicolumn{{2}}{{p{{13.2cm}}}}{{{flan-t5-large_reactions}}} \\\\ \
gpt-4         & \\multicolumn{{2}}{{p{{13.2cm}}}}{{{gpt-4_reactions}}} \\\\ \\hline \
\\end{{tabular}} \\\\'''

# Define the function to convert a row to text
def row_to_text(row):
    """Render one row of ``df_view`` as a LaTeX snippet using the ``subtable`` template.

    Every value is stringified and escaped with ``pylatex.escape_latex``.

    The article text is cut to ``input_cutoff`` characters *before* escaping.
    Cutting after escaping could split an escape sequence (for example, leave
    a dangling backslash from an escaped '&') and break LaTeX compilation.

    The "... [Truncated]" marker is appended only when text was actually removed.
    """
    cells = {key: str(value) for key, value in dict(row).items()}
    full_input = cells['input']
    cells['input'] = full_input[:input_cutoff]

    escaped = {key: str(pylatex.escape_latex(value)) for key, value in cells.items()}
    if len(full_input) > input_cutoff:
        escaped['input'] += '... [Truncated]'

    return subtable.format(**escaped)

# Render only the examples that will be shown; formatting the whole split first
# and slicing afterwards is wasted work.
df_singlecol = df_view.iloc[:n_examples].apply(row_to_text, axis=1)

# One longtable row per example, separated by a LaTeX line break. Joining the
# snippets directly (instead of going through Series.to_string) avoids pandas'
# alignment padding and any display.max_colwidth truncation of long rows.
rows = ' \\\\\n'.join(df_singlecol)

# Wrap the rows in a single full-width column longtable with booktabs rules.
latex = (
    "\\begin{longtable}{@{}p{1.00\\textwidth}@{}}\n"
    "\\toprule\n"
    "\\textbf{input} \\tabularnewline\n"
    "\\midrule\n"
    "\\endhead\n"
    "\\bottomrule\n"
    "\\endfoot\n"
    "\\endlastfoot\n"
    + rows
    + "\n\\end{longtable}\n"
)

# Collapse runs of horizontal whitespace (e.g. the template's column-alignment
# padding) into a single space; newlines are kept so each example stays on its
# own source line.
latex = re.sub(r'[^\S\n]+', ' ', latex)

# Write the LaTeX table to a file
with open('./predictions/prediction_comparison_singlecol.tex', 'w', encoding='utf-8') as fp:
    fp.write(latex)
