In [8]:
import json
import re

from utils import MultiChoiceFilter

In [15]:
class MedQAPrediction:
    """One model prediction on a MedQA question, paired with its gold answer.

    The predicted label is extracted from the model's free-text response with
    ``MultiChoiceFilter.extract_answer``; correctness is left to the caller
    (compare ``pred_label`` against ``gold``).

    Attributes:
        id: Question identifier, copied from the results record.
        gold: Gold answer as a letter 'A'-'D' (converted from a 0-based index).
        generated_response: Raw text produced by the model.
        filter: The ``MultiChoiceFilter`` instance used for extraction.
        pred_label: First element of the pair returned by ``extract_answer``;
            presumably the predicted option letter (or a sentinel when no answer
            is found) -- TODO confirm against ``utils.MultiChoiceFilter``.
    """

    def __init__(self, id: str, gold: int, generated_response: str):
        """Store the record fields and extract the predicted label.

        Args:
            id: Question identifier.
            gold: 0-based index of the correct option (0 -> 'A', ..., 3 -> 'D').
            generated_response: Model output text to parse.

        Raises:
            AssertionError: if ``gold`` does not map to one of 'A'-'D'.
        """
        self.id = id
        # Keep gold as a letter so it compares directly with pred_label.
        self.gold = chr(gold + ord('A'))
        self.generated_response = generated_response

        assert self.gold in ['A', 'B', 'C', 'D'], f"gold index out of range: {gold!r}"

        self.filter = MultiChoiceFilter()
        # extract_answer returns a pair; only the first element is used here.
        self.pred_label, _ = self.filter.extract_answer(self.generated_response)

    def __str__(self):
        """Readable one-line summary: id, gold letter and predicted label."""
        return f"MedQAPrediction(id={self.id}, gold={self.gold}, pred_label={self.pred_label})"



In [13]:
baseline_result_path = "../data/mistral-7B-Instruct-v0.2_medqa-official_results.json"
finetuned_result_path = "../data/meerkat-7b-v1.0_medqa-official_results.json"


def load_results(path):
    """Read a results JSON file as UTF-8 and return the parsed object.

    Later cells index the result as a list of records with "id", "gold" and
    "generated_response" keys.
    """
    # Explicit encoding: the platform default codec (e.g. cp1252 on Windows)
    # can mis-decode non-ASCII characters in model-generated text.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


baseline_results = load_results(baseline_result_path)
finetuned_results = load_results(finetuned_result_path)

print(len(baseline_results))
print(len(finetuned_results))

1273
1273


In [26]:
# Pair each baseline record with the finetuned record for the same question and
# collect the ids where exactly one of the two models answers correctly.
#
# NOTE(review): `mismatch_ids` holds questions the finetuned model gets WRONG
# while the baseline gets RIGHT (regressions) -- this is what is saved to disk
# below. The opposite direction (finetuned right, baseline wrong) is collected
# in `improved_ids`; confirm which set the downstream analysis expects.


def to_prediction(record):
    """Wrap one results-file record in a MedQAPrediction."""
    return MedQAPrediction(
        id=record["id"],
        gold=record["gold"],
        generated_response=record["generated_response"],
    )


# Records are paired by position, so both files must cover the same questions.
assert len(baseline_results) == len(finetuned_results), (
    f"result files differ in length: {len(baseline_results)} vs {len(finetuned_results)}"
)

mismatch_ids = []   # finetuned incorrect, baseline correct
improved_ids = []   # finetuned correct, baseline incorrect

for baseline_record, finetuned_record in zip(baseline_results, finetuned_results):
    baseline_pred = to_prediction(baseline_record)
    finetuned_pred = to_prediction(finetuned_record)

    # Positional pairing is only valid if both files list questions in the same order.
    assert baseline_pred.id == finetuned_pred.id, (baseline_pred.id, finetuned_pred.id)

    baseline_correct = baseline_pred.pred_label == baseline_pred.gold
    finetuned_correct = finetuned_pred.pred_label == finetuned_pred.gold

    if baseline_correct and not finetuned_correct:
        mismatch_ids.append(baseline_pred.id)
    elif finetuned_correct and not baseline_correct:
        improved_ids.append(baseline_pred.id)

print(len(mismatch_ids))

139


In [22]:
# Preview only: the full list has one id per mismatched question, and the
# previous cell already reports the count.
print(mismatch_ids[:10])

['test-00001']


In [25]:
# Save the selected question ids as a JSON list.
mismatch_ids_path = "../data/medqa_official_mismatch_ids.json"

serialized_ids = json.dumps(mismatch_ids)
with open(mismatch_ids_path, "w") as out_file:
    out_file.write(serialized_ids)
