In [13]:
import html
import json
from pathlib import Path

import ipywidgets as widgets
from IPython.display import display, Image, HTML

In [14]:
# TODO: adapt to your experiment path
exp_base_path = Path("XXX")

# Adapt the following variables to your needs
dataset = "SLAKE"
# TODO: put your name here
rater_name = "XXX"
open_closed = "open"  # "open" (scores 1-5) or "closed" (scores 0/1)
train_split = "all"
val_split = "all"
balanced = False

finetune_method = "ia3"
hyperparam_str = "lr3e-2_seed123"

# TODO: adapt to your selected questions path
selected_questions_path = Path("XXX") / dataset

# Any other value would silently fall through to the closed-ended prompt below.
assert open_closed in ("open", "closed"), f"open_closed must be 'open' or 'closed', got {open_closed!r}"

exp_path = exp_base_path / f"{dataset}/{finetune_method}/llava-{dataset}_train_{train_split}-finetune_{finetune_method}_{hyperparam_str}/eval/{dataset}_val_{val_split}"
results_file = exp_path / "test_results.json"
# Balanced selections add a "_balanced" infix to both the input and the output file name.
balanced_infix = "_balanced" if balanced else ""
selected_questions_file = selected_questions_path / f"rater_study_{open_closed}{balanced_infix}.json"
instructions_file = "../prompts/mistral_prompt.txt" if open_closed == "open" else "../prompts/mistral_prompt_closed.txt"
ratings_save_file = exp_path / f"human_metrics_{open_closed}{balanced_infix}_{rater_name.lower()}.json"

# JSON text is UTF-8 (RFC 8259); do not depend on the platform's default
# locale encoding, which breaks on non-ASCII question/answer text.
with open(selected_questions_file, 'r', encoding='utf-8') as f:
    selected_questions = json.load(f)

with open(results_file, 'r', encoding='utf-8') as f:
    results = json.load(f)

# Rating state shared with the widget callbacks below.
selected_question_ids = [element["qid"] for element in selected_questions]
# Fail early with a clear message instead of an IndexError in the UI code.
assert selected_question_ids, f"no questions found in {selected_questions_file}"
current_index = 0  # position in selected_question_ids of the question being shown
max_index = len(selected_question_ids)  # number of selected questions (used for wrap-around)
ratings = {}  # qid -> human score


def display_current_question():
    """Render the question at ``current_index`` into the question and rating widgets.

    Looks up the entry of ``results`` whose ``qid`` equals
    ``selected_question_ids[current_index]`` and shows its question, ground
    truth, prediction and qid in ``question_widget``. A rating already given
    for that qid is pre-filled in ``rate_text`` (empty string otherwise).

    If the qid is missing from ``results`` (or matches more than one entry),
    a notice is printed and also shown in ``question_widget``. This ensures
    the previously displayed question is never left on screen while the
    rating box refers to a different qid.
    """
    selected_question_id = selected_question_ids[current_index]
    matches = [element for element in results if element["qid"] == selected_question_id]
    if len(matches) == 1:
        element = matches[0]
        # Escape the text fields: questions / predictions may contain "<" or "&",
        # which the HTML widget would otherwise interpret as markup.
        question = html.escape(str(element['question']))
        gt = html.escape(str(element['gt']))
        pred = html.escape(str(element['pred']))
        qid = html.escape(str(element['qid']))
        question_widget.value = f"<b>Question:</b> {question}<br><b>GT:</b> {gt}<br><b>Pred:</b> {pred}<br><b>QID:</b> {qid}"
    else:
        message = f"Question with qid {selected_question_id} not found"
        print(message)
        # Replace the stale question so it cannot be rated under the wrong qid.
        question_widget.value = f"<b>{html.escape(message)}</b>"
    rate_text.value = str(ratings.get(selected_question_id, ""))
    
def on_next_button_clicked(b):
    """"Next" button handler: move to the following question.

    Wraps from the last selected question back to the first, then re-renders.
    The button argument ``b`` is unused.
    """
    global current_index
    next_position = current_index + 1
    current_index = next_position % max_index
    display_current_question()

def on_previous_button_clicked(b):
    """"Previous" button handler: move back one question.

    Unlike "Next", this does not wrap around: on the first question it does
    nothing (no re-render). The button argument ``b`` is unused.
    """
    global current_index
    if current_index != 0:
        current_index = (current_index - 1) % max_index
        display_current_question()

def on_rate_button_clicked(b):
    """"Rate" button handler: store the score typed in ``rate_text``.

    Valid scores are integers 1-5 when ``open_closed == "open"`` and 0 or 1
    when ``open_closed == "closed"``. A valid score is stored in ``ratings``
    under the currently shown qid (overwriting an earlier score), and the
    rated-questions list is refreshed via ``display_ratings()``. Invalid
    input leaves ``ratings`` unchanged and prints a hint. The button argument
    ``b`` is unused.
    """
    selected_question_id = selected_question_ids[current_index]
    try:
        rating_score = int(rate_text.value)
        if open_closed == "open" and rating_score not in range(1, 6):
            raise ValueError(f"open-ended rating out of range: {rating_score}")
        if open_closed == "closed" and rating_score not in (0, 1):
            raise ValueError(f"closed-ended rating out of range: {rating_score}")
    except ValueError:
        # Only bad input lands here; errors raised while refreshing the display
        # below are no longer masked as an input-format hint.
        print("Enter an integer between 1 and 5 for open ended and between 0 and 1 for closed ended questions")
        return
    ratings[selected_question_id] = rating_score
    display_ratings()


def on_save_clicked(b):
    """"Save" button handler: write all ratings collected so far to JSON.

    Each record copies ``qid``, ``question``, ``gt``, ``pred`` and
    ``answer_type`` from the matching entry of ``results`` and adds the
    rater's ``human_score``. Records follow the order of ``results``. The
    file at ``ratings_save_file`` is overwritten on every save. The button
    argument ``b`` is unused.
    """
    ratings_formatted = [
        {
            "qid": element['qid'],
            "question": element['question'],
            "gt": element['gt'],
            "pred": element['pred'],
            "answer_type": element['answer_type'],
            "human_score": ratings[element['qid']],
        }
        for element in results
        if element['qid'] in ratings
    ]
    # json.dump returns None, so there is nothing to keep from the call.
    with open(ratings_save_file, 'w', encoding='utf-8') as f:
        json.dump(ratings_formatted, f, indent=4)
    print("saved")


# Load the evaluator prompt and keep only the part that explains the rating
# task; everything from the input/output-format section onwards is dropped.
with open(instructions_file, 'r') as prompt_file:
    raw_prompt = prompt_file.read()
# NOTE(review): the first three characters of the prompt file are skipped --
# presumably a leading token that should not be rendered; confirm against the file.
prompt_body = raw_prompt[3:].replace("\n", "<br>")
instructions = prompt_body.split("Here are some instructions on the input and output format:")[0]
instructions = instructions + "========================================<br>"

# Widgets: prompt text, current question, rating input, navigation and save.
instruction_widget = widgets.HTML(value=instructions)
question_widget = widgets.HTML()
previous_button = widgets.Button(description="Previous")
next_button = widgets.Button(description="Next")
rate_text = widgets.Text(value='', placeholder='', disabled=False)
rate_button = widgets.Button(description="Rate")
save_button = widgets.Button(description="Save ratings to file")

# Connect each button to its handler.
previous_button.on_click(on_previous_button_clicked)
next_button.on_click(on_next_button_clicked)
rate_button.on_click(on_rate_button_clicked)
save_button.on_click(on_save_clicked)

# Populate the question and rating widgets for the first question before showing them.
display_current_question()

rate_widget = widgets.HBox([rate_text, rate_button])
buttons_previous_next = widgets.HBox([previous_button, next_button])
ui = widgets.VBox([instruction_widget, question_widget, rate_widget, buttons_previous_next, save_button])
display(ui)

# Output area for the running list of rated questions (filled by display_ratings()).
rated_questions_widget = widgets.HTML()

def display_ratings():
    """Show every rated question, with its score, in ``rated_questions_widget``.

    Entries are listed in ``results`` order and separated by rule lines. The
    header reports how many qids have been rated. Text fields are
    HTML-escaped so that predictions containing ``<`` or ``&`` are shown
    verbatim rather than interpreted as markup.
    """
    entries = []
    for element in results:
        if element['qid'] not in ratings:
            continue
        entries.append(
            f"<b>Question:</b> {html.escape(str(element['question']))}"
            f"<br><b>GT:</b> {html.escape(str(element['gt']))}"
            f"<br><b>Pred:</b> {html.escape(str(element['pred']))}"
            f"<br><b>Rating:</b> {ratings[element['qid']]}"
            f"<br><b>QID:</b> {html.escape(str(element['qid']))}"
        )
    header = f"================<br><b>RATED QUESTIONS (total {len(ratings)})</b><br>"
    rated_questions_widget.value = header + "<br>================<br>".join(entries)

# Show the rated-questions list below the rating UI; display_ratings() replaces
# its contents after each valid rating.
display(rated_questions_widget)

VBox(children=(HTML(value='[INST] You are a helpful evaluator to evaluate answers to questions about biomedica…

HTML(value='')