# MedQA Relabelling Analysis

In [None]:
import ast
import functools

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

In [None]:
# Named matplotlib colors; entries are picked by index when drawing the bars.
COLORS = [
    'salmon',
    'orange',
    'mediumseagreen',
    'cornflowerblue',
]

In [None]:
# Path of the CSV export holding the rater annotations.
input_file = 'medqa_relabelling.csv'
# Parse the CSV from an open text handle so the file is closed right after.
with open(input_file) as csv_handle:
  df = pd.read_csv(csv_handle)

In [None]:
# Preview the first five rows to sanity-check the loaded columns.
df.head(n=5)

## Voting

In [None]:
def compute_blind_errors(row):
  """Flags a blind-phase miss, i.e. the GT label is not among the rater's picks.

  Args:
    row: Mapping with 'answer_idx' (the ground-truth option) and
      'blind_answers' (the options selected before the GT was revealed).

  Returns:
    True if 'blind_answers' is not a list (no answer given) or does not contain
    'answer_idx'; False otherwise.
  """
  truth, picked = row['answer_idx'], row['blind_answers']
  if isinstance(picked, list):
    return truth not in picked
  # A non-list value means no answer was given, which counts as an error.
  return True


def compute_seen_errors(row):
  """Flags a seen-phase miss, i.e. the GT label is not among the rater's picks.

  Args:
    row: Mapping with 'answer_idx' (the ground-truth option) and
      'seen_answers' (the options selected after the GT was revealed).

  Returns:
    True if 'seen_answers' is not a list or does not contain 'answer_idx';
    False otherwise.
  """
  truth = row['answer_idx']
  picked = row['seen_answers']
  covered = isinstance(picked, list) and truth in picked
  return not covered


def select_combined_answers(row):
  """Picks the seen answers when present, falling back to the blind answers.

  Args:
    row: Mapping with 'seen_answers' and 'blind_answers'.

  Returns:
    row['seen_answers'] if it is a list; otherwise row['blind_answers'] (a
    non-list seen value is taken to mean the rater kept the blind answer).
  """
  seen = row['seen_answers']
  return seen if isinstance(seen, list) else row['blind_answers']


def select_combined_answerable(row):
  """Picks the seen answerable flag when set, else the blind answerable flag.

  Args:
    row: Mapping with 'seen_answerable' and 'blind_answerable'.

  Returns:
    row['blind_answerable'] if 'seen_answerable' is null (None/NaN), otherwise
    row['seen_answerable'].
  """
  seen_flag = row['seen_answerable']
  return row['blind_answerable'] if pd.isnull(seen_flag) else seen_flag


def compute_combined_errors(row):
  """Flags a miss on the combined answers (GT label not among them).

  Args:
    row: Mapping with 'answer_idx' (the ground-truth option) and
      'combined_answers' (the seen answers if given, else the blind answers).

  Returns:
    True if 'combined_answers' is not a list or does not contain 'answer_idx';
    False otherwise.
  """
  truth = row['answer_idx']
  picked = row['combined_answers']
  # Guard clause: without a list of answers the GT cannot be covered.
  if not isinstance(picked, list):
    return True
  return all(option != truth for option in picked)


def compute_combined_size(row):
  """Counts the options in the combined answers.

  Args:
    row: Mapping with 'combined_answers'.

  Returns:
    len(row['combined_answers']) if it is a list, otherwise 0.
  """
  chosen = row['combined_answers']
  return len(chosen) if isinstance(chosen, list) else 0


def _parse_answers(value):
  """Parses a stringified answer list from the CSV; non-strings pass through.

  Uses ast.literal_eval instead of eval: it only accepts Python literals, so a
  corrupted or malicious CSV cell cannot execute arbitrary code.

  Args:
    value: A cell of an answers column, either a string such as "['A', 'B']" or
      a non-string value (e.g. NaN for a missing answer).

  Returns:
    The parsed literal for strings, or `value` unchanged otherwise.

  Raises:
    ValueError, SyntaxError: If a string cell is not a valid Python literal.
  """
  return ast.literal_eval(value) if isinstance(value, str) else value


# Convert answers to lists.
df['blind_answers'] = df['blind_answers'].apply(_parse_answers)
df['seen_answers'] = df['seen_answers'].apply(_parse_answers)
# Compute rater errors before and after revealing the GT.
df['blind_errors'] = df.apply(compute_blind_errors, axis=1)
df['seen_errors'] = df.apply(compute_seen_errors, axis=1)
# Compute combined answer, answer size = # of selected options, errors.
df['combined_answers'] = df.apply(select_combined_answers, axis=1)
df['combined_size'] = df.apply(compute_combined_size, axis=1)
# A rater selecting more than one option marks the question as ambiguous.
df['combined_ambiguous'] = df['combined_size'] > 1
df['combined_answerable'] = df.apply(select_combined_answerable, axis=1)
df['combined_errors'] = df.apply(compute_combined_errors, axis=1)
# If info_missing is False, then important_info_missing should also be False.
# Anything that is not equal to True (e.g., a missing value) becomes False.
df['important_info_missing'] = df['important_info_missing'].apply(
    lambda x: x == True)

In [None]:
# Question-level columns, taken once per qid.
same_keys = [
    'qid', 'question', 'A', 'B', 'C', 'D', 'answer_idx',
]
# Rater-level columns that are voted over (i.e., rater opinions are aggregated).
vote_keys = [
    'blind_answerable', 'seen_answerable', 'combined_answerable',
    'important_info_missing', 'info_missing',
    'blind_errors', 'seen_errors', 'combined_errors',
    'seen_change', 'combined_ambiguous',
]
# Rater-level columns that are only collected per question, not voted over.
keep_keys = ['blind_answers', 'seen_answers', 'combined_answers']
vote_df = df[['qid'] + vote_keys + keep_keys]
core_df = df[same_keys].drop_duplicates(['qid'])
# For every rater-level column, gather all rater values of a question into one
# list, giving one frame (qid + column) per key.
ratings_by_question = vote_df.groupby('qid')
vote_dfs = [core_df] + [
    ratings_by_question[key].apply(list).reset_index()
    for key in vote_keys + keep_keys
]
# Join the per-column frames back together on the question id.
vote_df = functools.reduce(
    lambda merged, column_df: pd.merge(merged, column_df, on=['qid']),
    vote_dfs)


def _count_true_votes(votes):
  """Counts the entries of `votes` that are the bool True; non-bools count as 0."""
  return np.sum([isinstance(vote, bool) and vote for vote in votes])


# Vote with thresholds of at least 1, 2 or 3 raters agreeing.
for vote_key in vote_keys:
  vote_df[f'sum_{vote_key}'] = vote_df[vote_key].apply(_count_true_votes)
  for threshold in (1, 2, 3):
    vote_df[f'vote{threshold}_{vote_key}'] = (
        vote_df[f'sum_{vote_key}'] >= threshold)

## Simulate model predictions

In [None]:
# We target an accuracy of 0.911 (in expectation), but fix an accuracy of 0.98
# on examples that are likely unfiltered; this creates a scenario similar to the
# paper where the model makes more mistakes on examples that are about to be
# filtered out.
# This is the place where you can load your model predictions!
np.random.seed(42)
# Questions that at least one rater flagged as a label error.
is_filtered = vote_df['vote1_combined_errors'] == True
num_filtered = int(np.sum(is_filtered))
filter_rate = num_filtered / vote_df.shape[0]
accuracy_on_unfiltered = 0.98
# Solve 0.911 = acc_unfiltered * (1 - rate) + acc_filtered * rate.
accuracy_on_filtered = (
    (0.911 - accuracy_on_unfiltered * (1 - filter_rate)) / filter_rate
)
# Copy the slice so that adding the 'error' column below writes to an
# independent frame instead of a view of vote_df (SettingWithCopyWarning).
error_df = vote_df[['vote1_combined_errors', 'qid']].copy()
error_df['error'] = False
# Draw order (filtered first, then unfiltered) is kept fixed so that results
# stay reproducible under the seed above.
error_df.loc[is_filtered, 'error'] = (
    np.random.uniform(0, 1, (num_filtered,)) > accuracy_on_filtered
)
error_df.loc[~is_filtered, 'error'] = (
    np.random.uniform(0, 1, (error_df.shape[0] - num_filtered,))
    > accuracy_on_unfiltered
)

In [None]:
# Attach the per-question model error flag to every rater row of that question.
df = df.merge(error_df[['qid', 'error']], on=['qid'])

## Evaluation results

In [None]:
k = 3  # 3 = unanimous voting (main paper), 2 = majority voting (appendix).
n_trials = 10  # 1000 in the paper.
num_questions = []
error_rates = []
# One row per question, with every rater-level column collected into a list.
agg_df = df.groupby(['qid', 'error']).agg(list).reset_index()


def _resample_vote(ratings):
  """Draws 3 ratings with replacement and checks whether their sum reaches k."""
  return np.sum(np.random.choice(ratings, 3, replace=True)) >= k


# (source column, resampled vote column), in the order the votes are drawn.
sampled_columns = (
    ('info_missing', 'sample_info_missing'),
    ('combined_errors', 'sample_combined_errors'),
    ('combined_ambiguous', 'sample_combined_ambiguous'),
)

for t in range(n_trials):
  # Bootstrap the rater votes for each filtering criterion.
  for source_column, sample_column in sampled_columns:
    agg_df[sample_column] = agg_df[source_column].apply(_resample_vote)

  # Cumulative keep-masks; each stage drops one more kind of question:
  # missing information, then label errors, then ambiguous questions.
  keep_info = agg_df['sample_info_missing'].eq(False)
  keep_label = keep_info & agg_df['sample_combined_errors'].eq(False)
  keep_unambiguous = keep_label & agg_df['sample_combined_ambiguous'].eq(False)
  stage_masks = [keep_info, keep_label, keep_unambiguous]

  model_wrong = agg_df['error']
  # Number of questions: all of them first, then after each filtering stage.
  qs = [agg_df.shape[0]] + [np.sum(mask) for mask in stage_masks]
  # Model error rate on the questions remaining at each stage.
  es = [np.sum(model_wrong) / qs[0]] + [
      np.sum(model_wrong & mask) / remaining
      for mask, remaining in zip(stage_masks, qs[1:])
  ]
  num_questions.append(qs)
  error_rates.append(es)
  print(t)
num_questions = np.array(num_questions)
error_rates = np.array(error_rates)

# Grouped bar chart: accuracy (left axis) and share of questions kept (right
# axis), before filtering and after each cumulative filtering stage.
groups = ['Before', 'w/o\nmissing info', 'w/o\nlabel errors', 'w/o\nambiguous']
x = np.arange(len(groups))  # the label locations
width = 0.45  # the width of the bars
# Denominator for the share of questions; presumably the size of the MedQA
# test set -- TODO confirm.
total_questions = 1273

fig, ax1 = plt.subplots(layout='constrained')
ax2 = ax1.twinx()

# (axis, color, legend label, bar heights, error bars), all in percent.
bar_series = [
    (ax1, COLORS[3], 'Accuracy',
     100 * (1 - np.mean(error_rates, axis=0)),
     np.std(100 * error_rates, axis=0)),
    (ax2, COLORS[0], 'Questions',
     100 * np.mean(num_questions / total_questions, axis=0),
     100 * np.std(num_questions / total_questions, axis=0)),
]
for position, (ax, color, label, mean, std) in enumerate(bar_series):
  # Shift each series by one bar width so the two bars sit side by side.
  offset = width * position
  rects = ax.bar(
      x + offset, mean, yerr=std, width=width, label=label, capsize=5,
      color=color)
  ax.bar_label(rects, padding=3, fmt='%.1f', label_type='edge')

for axis in (ax1, ax2):
  axis.spines['top'].set_visible(False)
  # Disable the grid on both axes explicitly rather than relying on the loop
  # variable above, which only referred to the last axis.
  axis.grid(False)
ax1.set_ylabel('MedQA Accuracy', fontsize=12)
ax2.set_ylabel('Fraction of questions', fontsize=12)
# Center the tick labels between the two bars of each group.
ax1.set_xticks(x + width/2, groups, rotation=30, fontsize=12)
ax1.legend(loc='upper right', bbox_to_anchor=(1, 1), ncols=1)
ax2.legend(loc='upper right', bbox_to_anchor=(0.68, 1), ncols=1)
ax1.set_ylim(90, 100)
ax2.set_ylim(90, 100)
fig.set_size_inches((5, 3))
plt.savefig('medqa_filtering.pdf', dpi=300, format='pdf')
plt.show()