In [1]:
from ipywidgets import interact
import ipywidgets as widgets
import asyncio
import json
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import subprocess
from IPython.display import FileLink
import os
def display_images(original_path, generated_path1, generated_path2):
    """Render the original figure and two generated figures in one 15x10 figure.

    Layout is a 2x2 grid: the original goes in the top-left cell, the first
    generated image in the bottom-left cell and the second generated image in
    the bottom-right cell. The top-right cell is intentionally never created.
    Axes are hidden on every panel and each panel gets a title.

    Args:
        original_path: path of the reference image.
        generated_path1: path of the first generated image (bottom-left).
        generated_path2: path of the second generated image (bottom-right).
    """
    # (grid slot, image path, panel title); slots are 1-based, row-major.
    panels = (
        (1, original_path, 'Original'),
        (3, generated_path1, 'Generated 1'),
        (4, generated_path2, 'Generated 2'),
    )

    plt.figure(figsize=(15, 10))
    for slot, image_path, panel_title in panels:
        plt.subplot(2, 2, slot)
        plt.imshow(mpimg.imread(image_path))
        plt.axis('off')
        plt.title(panel_title)

    plt.show()


def load_data(filepath, encoding='utf-8'):
    """Load and return the JSON document stored at ``filepath``.

    Args:
        filepath: path to a JSON file.
        encoding: text encoding used to read the file. Defaults to UTF-8, the
            encoding JSON interchange requires. Without an explicit encoding,
            ``open`` falls back to the platform locale encoding, which can
            mis-decode (or fail on) non-ASCII text such as figure captions.

    Returns:
        The parsed JSON value. In this notebook it is expected to be the list
        of data-point dicts.
    """
    with open(filepath, 'r', encoding=encoding) as f:
        return json.load(f)

def wait_for_change(func_button):
    """Return an asyncio Future resolved with the button's description on its next click.

    The click handler detaches itself on the first click, so the future is
    resolved at most once and later clicks are ignored.

    Args:
        func_button: expected to be an ipywidgets ``Button``. Anything with
            ``on_click(callback, remove=False)`` and a ``description``
            attribute works.

    Returns:
        An ``asyncio.Future`` whose result is the clicked button's ``description``.
    """
    future = asyncio.Future()

    def evaluate(button):
        # Buttons do not support ``unobserve``, so unregister this handler with
        # ``remove=True``. Do it first, so the handler is detached even if the
        # future can no longer accept a result.
        func_button.on_click(evaluate, remove=True)
        # The awaiting code may have cancelled the future. ``set_result`` on a
        # done future raises InvalidStateError inside the widget callback.
        if not future.done():
            future.set_result(button.description)

    func_button.on_click(evaluate)
    return future

In [2]:
# Input produced by the generation step, and the file that accumulates the
# human comparisons made in this notebook.
merged_filepath = 'merged_output.json'
output_filepath = "compare_output.json"

# Resume from earlier comparisons when they were saved; otherwise start fresh
# from the merged input.
list_file = load_data(output_filepath if os.path.exists(output_filepath) else merged_filepath)


In [3]:
out = widgets.Output()  # Create an Output widget to capture and display items

# Printed into ``out`` once nothing is left to compare. This also replaces the
# last comparison UI, which ``clear_output(wait=True)`` would otherwise leave
# on screen with a still-clickable Submit button.
_DONE_MESSAGE = "All data points have been compared."

# Image-path keys every data point must provide, checked in this order.
_IMAGE_PATH_KEYS = ("direct_output_figure_path", "figure_path", "cot_output_figure_path")


def _needs_comparison(dic_data):
    """Return True if ``dic_data`` has no recorded comparison and all three image files exist."""
    if dic_data.get('Comparison', False):
        return False
    return all(os.path.exists(dic_data[key]) for key in _IMAGE_PATH_KEYS)


def _save_progress():
    """Write ``list_file`` to ``output_filepath`` without risking a truncated file.

    The loading cell resumes from ``output_filepath``, so an interrupted write
    must not corrupt it. Dump to a temporary sibling file first, then swap it
    into place with ``os.replace``.
    """
    tmp_path = output_filepath + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as file:
        json.dump(list_file, file, indent=4)
    os.replace(tmp_path, output_filepath)


def _show_done():
    """Show the completion message in ``out``."""
    with out:
        print(_DONE_MESSAGE)


def display_data_point(i, dic_data):
    """Show the first data point, from index ``i`` onward, that still needs a human comparison.

    Entries are skipped when they already carry a truthy ``'Comparison'`` or
    when any of their three image files is missing. For the entry that is
    shown, this renders inside ``out``:

    - the three image paths and the caption;
    - the images: original top-left, direct output bottom-left ("Left"),
      CoT output bottom-right ("Right");
    - a Left/Right choice, a Submit button and a status line.

    On submit, ``'Direct'`` (Left) or ``'CoT'`` (Right) is stored under
    ``'Comparison'``, ``list_file`` is saved to ``output_filepath``, and the
    next entry is shown. When no entries remain, a completion message is shown.

    Args:
        i: index of ``dic_data`` in ``list_file``.
        dic_data: the data point at ``list_file[i]``.
    """
    # Skip with a loop rather than recursion. One recursive call per skipped
    # entry could exceed Python's recursion limit when resuming a long run.
    while not _needs_comparison(dic_data):
        i += 1
        if i >= len(list_file):
            _show_done()
            return
        dic_data = list_file[i]

    with out:
        figure_path = dic_data['figure_path']
        output_path_direct = dic_data["direct_output_figure_path"]
        output_path_cot = dic_data["cot_output_figure_path"]
        caption = dic_data['caption']
        print(figure_path)
        print(output_path_direct)
        print(output_path_cot)
        caption_display = widgets.Label(value="Figure Caption:" + caption)
        display(caption_display)
        display_images(figure_path, output_path_direct, output_path_cot)
        custom_layout = widgets.Layout(width='80%')
        compare = widgets.RadioButtons(
            options=["Left", "Right"],
            value=None,
            description="Which figure, left or right, more closely resembles the original, and do you think its code can more easily reproduce the original figure?",
            disabled=False,
            layout=custom_layout,
        )
        button = widgets.Button(description="Submit",
                                disabled=False,
                                button_style="")
        # Feedback for submissions that cannot be recorded.
        status = widgets.Label(value="")
        display(widgets.VBox([compare, button, status]))

    def on_submit(btn):
        choice = compare.value
        if choice is None:
            # Nothing selected yet. Looking up None would raise KeyError inside
            # the callback, which the user never sees.
            status.value = "Please select Left or Right before submitting."
            return
        # Block a second click from recording and advancing twice.
        button.disabled = True
        dic_data["Comparison"] = {"Left": "Direct", "Right": "CoT"}[choice]
        # Also written under this capitalised key, as before. Downstream readers
        # of the output JSON may rely on it (TODO confirm before renaming).
        dic_data["CoT_output_figure_path"] = output_path_cot
        list_file[i] = dic_data
        try:
            _save_progress()
        except OSError as exc:
            # Keep this entry on screen so the answer can be resubmitted.
            button.disabled = False
            status.value = f"Could not save progress: {exc}"
            return
        out.clear_output(wait=True)
        if i + 1 < len(list_file):
            display_data_point(i + 1, list_file[i + 1])
        else:
            _show_done()

    button.on_click(on_submit)


display(out)  # Display the Output widget initially

# Start the review at the first data point. Guard against an empty input file,
# where indexing list_file[0] would raise IndexError.
if list_file:
    display_data_point(0, list_file[0])
else:
    with out:
        print("No data points to compare.")

Output()