In [34]:
from pathlib import Path

# Run configuration: which dataset / model / dated run to post-process.
dataset = 'medmcqa'
model = 'vicuna-33b'
run_date = '2023-07-06'

# Single root for one (dataset, model, date) run, so the two derived paths
# below cannot drift apart when one of the settings above is changed.
run_dir = Path('/Users/andreasmotz/Research/meluxina/medical-reasoning/self-consistency') / dataset / model / run_date

path_to_outputs = run_dir / 'outputs'        # directory with one JSON file per question
path_to_results = run_dir / 'results.json'   # aggregated results of the run

In [35]:
import json
import os


# os.listdir returns entries in arbitrary order: sort them so that the lists
# derived from `outputs` (y / y_hat) come out in a reproducible order.
# Hidden entries (e.g. macOS ".DS_Store") are not model outputs and would make
# the later json.load fail, so they are skipped.
outputs = sorted(name for name in os.listdir(path_to_outputs) if not name.startswith('.'))

# Previously computed results of the run; extended further down in the notebook.
with open(path_to_results, 'r') as f:
    results_file = json.load(f)

In [30]:
from dataclasses import dataclass
import re
from typing import List, Optional

# Answer letters in option order: the position of a letter in this list is used
# as the index of the corresponding answer option.
SYMBOLS = ["A", "B", "C", "D"]

@dataclass(frozen=True)
class Sample:
    """
    One multiple-choice question together with its reference answer.

    Instances are frozen: attributes cannot be reassigned after construction.
    """
    # Identifier of the question (format_sample passes its `idx` argument here).
    id: str
    # Text of the question.
    question: str
    # Answer option texts, in their original order.
    options: List[str]
    # Text of the reference answer (format_sample sets it to options[answer_idx]).
    answer: str
    # Index of the reference answer within `options`.
    answer_idx: int

@dataclass
class Prediction:
    """
    Result of parsing one generated model output.

    Only `answer_str`, `answer_idx` and `flows` are populated by
    parse_generated_output; the remaining fields default to None.
    """
    # Text of the predicted option, or "N/A" when no option could be inferred.
    answer_str: str
    # Index of the predicted option, or -1 when no option could be inferred.
    answer_idx: int
    # Not set by the code in this notebook (the outcome is stored inside `flows`).
    outcome: Optional[bool] = None
    # Not set by the code in this notebook.
    # NOTE(review): annotated as a list of ints - confirm whether floats are intended.
    probs: Optional[List[int]] = None
    # One dict per parsed output with the keys 'outcome', 'explanation' and 'prediction'.
    flows: Optional[List[dict]] = None
    # Not set by the code in this notebook.
    sample: Optional[Sample] = None

def format_sample(idx: str, question: str, choices: list, target: int)->Sample:
    """
    Wraps the raw fields of one question into a Sample.
    :param idx: Identifier of the question.
    :param question: The question text.
    :param choices: The answer options.
    :param target: Index of the reference answer within `choices`.
    :return: A Sample whose `answer` is the option text found at `target`.
    """
    # Fields in declaration order: id, question, options, answer, answer_idx.
    return Sample(idx, question, choices, choices[target], target)

def extract_answer_idx(answer_str: str, options: list[str])->int:
    """
    Extracts the index of the selected answer option from a given answer string.

    Three strategies are tried in order; each one only succeeds when it points
    to exactly one option:

    1. a case-sensitive occurrence of an option text,
    2. a case-insensitive occurrence of an option text,
    3. an option symbol (A-D) located at the start of a line or right after "(",
       and followed by whitespace, one of ``. , : )`` or the end of a line.

    :param answer_str: The answer string to extract the answer index from.
    :param options: The list of answer options to match against.
    :return: The index of the selected answer option, or -1 if the answer index couldn't be inferred.
    """
    symbols_pattern = r'(?:^|\()([A-D])(?:[\s.,:\)]|$)'
    lowered_options = [option.lower() for option in options]

    # Regex alternation takes the first alternative that matches, so an option
    # that is a prefix of another one ("Vitamin B" / "Vitamin B12") would shadow
    # the longer option: try longer options first. Empty options are dropped
    # because an empty alternative matches at every position.
    candidates = sorted((option for option in options if option), key=len, reverse=True)

    matches_exact = []
    matches_answer = []
    if candidates:
        exact_answers = '|'.join(re.escape(option) for option in candidates)
        answers = '|'.join(re.escape(option.lower()) for option in candidates)
        matches_exact = re.findall(exact_answers, answer_str, re.MULTILINE)
        matches_answer = re.findall(answers, answer_str, re.MULTILINE|re.IGNORECASE)

    if len(set(matches_exact)) == 1:
        return options.index(matches_exact[0])

    # The matches were found ignoring case, so fold case before counting:
    # "ASPIRIN" and "aspirin" refer to the same option.
    if len({match.lower() for match in matches_answer}) == 1:
        return lowered_options.index(matches_answer[0].lower())

    matches_symbol = re.findall(symbols_pattern, answer_str, re.MULTILINE)
    if len(set(matches_symbol)) == 1:
        symbol_idx = SYMBOLS.index(matches_symbol[0])
        # A symbol beyond the available options does not identify an option.
        if symbol_idx < len(options):
            return symbol_idx

    return -1

def parse_generated_output(output_str: str, eg: dict)->Prediction:
    """
    Parses the generated string into a Prediction instance to extract reasoning path and prediction.

    The strategies below are tried in order; the first one that applies decides
    the result (even if it ends up inferring no option):

    1. the text between the word "answer" and the next period, resolved with
       extract_answer_idx;
    2. exactly one distinct option text mentioned in the output (ignoring case);
    3. exactly one distinct standalone option symbol, resolved with
       extract_answer_idx;
    4. a line starting with one of A-D;
    5. a line ending with one of A-D;
    6. extract_answer_idx applied to the first paragraph (or first sentence when
       there is no blank line) and to the whole output, keeping the larger index.

    :param output_str: The generated string to be parsed.
    :param eg: A dictionary containing the sample data (keys "uid", "question", "choices", "target").
    :return: A Prediction instance containing the parsed data: `answer_idx` is the inferred
        option index (-1 if none), `answer_str` the option text (or "N/A"), and `flows` a
        one-element list holding a dict with the keys 'outcome', 'explanation', 'prediction'.
    """
    sample = format_sample(eg["uid"], eg["question"], [e.strip() for e in eg["choices"]], eg["target"])
    lowered_options = [option.lower() for option in sample.options]
    flags = re.MULTILINE|re.IGNORECASE

    first_letter = r'(^[A-D])'
    last_letter = r'[A-D]$'
    simple_pattern = r'answer([^.]*)\.'
    # Longer options first so that an option which is a prefix of another one
    # cannot shadow it; empty options would match at every position.
    candidates = sorted((option for option in sample.options if option), key=len, reverse=True)
    option_pattern = '|'.join(re.escape(option) for option in candidates)
    # The lookbehind mirrors the existing lookahead: without it the last letter
    # of ordinary words ("and" -> "d") was counted as an option symbol.
    symbol_pattern = r'(?<![A-Za-z0-9])\(?[A-D](?![A-Za-z0-9])[\)\.\:]?'

    first_letter_match = re.search(first_letter, output_str, re.MULTILINE)
    last_letter_match = re.search(last_letter, output_str, re.MULTILINE)
    answer_matches = re.search(simple_pattern, output_str, flags)
    option_matches = re.findall(option_pattern, output_str, flags) if option_pattern else []
    symbol_matches = re.findall(symbol_pattern, output_str, flags)

    # `start_index` is where the answer part begins inside output_str. It is
    # taken from the match position: searching for the matched text again could
    # hit an earlier, unrelated occurrence of the same text.
    answer_str = output_str
    start_index = 0
    if answer_matches:
        answer_str = answer_matches.group(1)
        start_index = answer_matches.start(1)
        predicted_idx = extract_answer_idx(answer_str, sample.options)
    elif len({match.lower() for match in option_matches}) == 1:
        # Matches were found ignoring case, so they are compared ignoring case too.
        predicted_idx = lowered_options.index(option_matches[0].lower())
    elif len(set(symbol_matches)) == 1:
        answer_str = symbol_matches[0]
        start_index = re.search(symbol_pattern, output_str, flags).start()
        predicted_idx = extract_answer_idx(answer_str, sample.options)
    elif first_letter_match:
        predicted_idx = SYMBOLS.index(first_letter_match.group())
    elif last_letter_match:
        predicted_idx = SYMBOLS.index(last_letter_match.group())
    else:
        sep = "\n\n" if "\n\n" in output_str else "."
        answer_str = output_str.split(sep)[0].strip()
        start_index = output_str.index(answer_str)
        first_sentence_check = extract_answer_idx(answer_str, sample.options)
        whole_sentence_check = extract_answer_idx(output_str, sample.options)
        predicted_idx = max(first_sentence_check,whole_sentence_check)

    # A symbol may point past the available options (fewer than four choices);
    # treat that as "no prediction" instead of failing on the lookup below.
    if not 0 <= predicted_idx < len(sample.options):
        predicted_idx = -1

    if predicted_idx == -1:
        explanation_str = ""
        prediction_str = "N/A"
    else:
        # Everything before the answer part is kept as the explanation.
        explanation_str = output_str[:start_index]
        answer_str = output_str[start_index:]
        prediction_str = sample.options[predicted_idx]

    outcome = predicted_idx == sample.answer_idx
    output = {
        'outcome' : outcome,
        'explanation' : explanation_str.strip(),
        'prediction' : answer_str.strip()
    }

    return Prediction(
        answer_str=prediction_str,
        answer_idx=predicted_idx,
        flows=[output],
    )

In [36]:
# y: for every output file, the index of its 'answer' within its 'choices'.
# y_hat: for every output file, the parsed option index of each entry in 'reasoning'.
y, y_hat = [], []
for file in outputs:
    with open(path_to_outputs/file, 'r') as f:
        record = json.load(f)

    choices = record['choices']
    target = choices.index(record['answer'])
    y.append(target)

    # One parsed prediction per generated response of this question.
    y_hat.append([
        parse_generated_output(
            response["prediction"],
            {
                "uid": record["qst_locator"],
                "question": record["question"],
                "choices": [e.strip() for e in choices],
                "target": target,
            },
        ).answer_idx
        for response in record['reasoning']
    ])

In [37]:
# Attach the target indices and the parsed predictions to the loaded results.
results_file.update(y=y, y_hat=y_hat)

# Save the extended results as a new JSON file (the original results file is left untouched).
with open(f"/Users/andreasmotz/Downloads/{model}-{dataset}.json", "w") as out_file:
    json.dump(results_file, out_file, indent=2)