In [6]:
# Load plot data
import os
import glob

from utils.general_plotter import PlotData


import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output


# Load every serialized PlotData result (one JSON file each) found anywhere
# below the IWSLT results directory.
RESULTS_GLOB = "local/translation-results-iwslt/**/*.json"
# Alternative dataset:
# RESULTS_GLOB = "local/gpt-plot-data/full_run_24032025/*.json"

# glob() returns paths in arbitrary filesystem order; sort them so curves are
# drawn (and coloured) in the same order on every run and machine.
file_paths = sorted(glob.glob(RESULTS_GLOB, recursive=True))


def _read_text(path: str) -> str:
    """Return the whole UTF-8 text of *path*, closing the file afterwards."""
    with open(path, "r", encoding="utf-8") as handle:
        return handle.read()


file_contents = [_read_text(file_path) for file_path in file_paths]
plot_data: list[PlotData] = [
    PlotData.model_validate_json(file_content) for file_content in file_contents
]


# Distinct values of every filterable field across all loaded results, sorted
# so the selector widgets list them in a stable order.
eval_methods = sorted({data.eval_method for data in plot_data})
search_method_types = sorted({data.search_method_type for data in plot_data})
enable_mcdo_options = sorted({data.enable_mcdo for data in plot_data})
model_names = sorted({data.model_name for data in plot_data})
benchmarks = sorted({data.benchmark for data in plot_data})

# Union of the UQ method (acquisition function) names over all results.
all_uq_methods = sorted(
    {uq_method for data in plot_data for uq_method in data.aq_func_names}
)

# ---- Filter widgets ---------------------------------------------------------
def _make_multi_select(options, description):
    """Build an enabled SelectMultiple that preselects only the first option.

    An empty *options* list yields an empty initial selection.
    """
    initial_selection = options[:1] if options else []
    return widgets.SelectMultiple(
        options=options,
        value=initial_selection,
        description=description,
        disabled=False,
    )


# Single-valued metric picker; starts at the first metric, or None if no
# results were loaded.
eval_method_selector = widgets.Dropdown(
    options=eval_methods,
    value=(eval_methods or [None])[0],
    description="Eval Method",
    disabled=False,
)

search_method_selector = _make_multi_select(search_method_types, "Search Method")
mcdo_selector = _make_multi_select(enable_mcdo_options, "MCDO")
model_selector = _make_multi_select(model_names, "Model")
benchmark_selector = _make_multi_select(benchmarks, "Benchmark")
uq_method_selector = _make_multi_select(all_uq_methods, "UQ Methods")

# ---- Title and export controls ----------------------------------------------
title_input = widgets.Text(
    description="Title:",
    value="Benchmark Evaluation Results",
    placeholder="Enter plot title",
    disabled=False,
)

export_button = widgets.Button(
    description="Export to SVG",
    tooltip="Export current plot to SVG file",
    button_style="success",
    disabled=False,
)

output_filename = widgets.Text(
    description="Filename:",
    value="plot_export.svg",
    placeholder="filename.svg",
    disabled=False,
)

# Area into which the plot and its per-curve toggles are rendered.
plot_output = widgets.Output()

# Curve label -> True when the user has hidden that curve.
muted_plots = {}

def uppercase_first(string: str) -> str:
    """Return *string* with only its first character upper-cased.

    Unlike ``str.capitalize``, the remaining characters are left as-is
    (``"bEAM"`` -> ``"BEAM"``). An empty string is returned unchanged instead
    of raising ``IndexError``.
    """
    # Slicing ([:1]) is safe on "", whereas indexing ([0]) is not.
    return string[:1].upper() + string[1:]

def create_plot(
    selected_eval_methods,
    selected_search_methods,
    selected_mcdo,
    selected_models,
    selected_benchmarks,
    selected_uq_methods,
    plot_title,
):
    """Draw one retention curve per matching (result, UQ method) pair.

    Args:
        selected_eval_methods: Metric name(s) to show. The "Eval Method"
            widget is a Dropdown, so this is normally a single string; a
            collection of names is accepted as well.
        selected_search_methods: Allowed ``search_method_type`` values.
        selected_mcdo: Allowed ``enable_mcdo`` values.
        selected_models: Allowed ``model_name`` values.
        selected_benchmarks: Allowed ``benchmark`` values.
        selected_uq_methods: UQ method names (from ``aq_func_names``) to draw.
        plot_title: Title shown above the axes.

    Returns:
        The newly created matplotlib Figure. Curves whose label is marked in
        ``muted_plots`` are skipped.
    """
    # Wrap a bare string so ``in`` below is an exact-match test; on a str it
    # would be a substring test ("BLEU" in "SacreBLEU" is True).
    if isinstance(selected_eval_methods, str):
        selected_eval_methods = (selected_eval_methods,)

    fig, ax = plt.subplots(figsize=(10, 6))

    for data in plot_data:
        if not (
            data.eval_method in selected_eval_methods
            and data.search_method_type in selected_search_methods
            and data.enable_mcdo in selected_mcdo
            and data.model_name in selected_models
            and data.benchmark in selected_benchmarks
        ):
            continue

        dropout_text = "Dropout" if data.enable_mcdo else "No dropout"
        # eval_scores[i] holds the curve for aq_func_names[i]; names without a
        # corresponding score list are skipped.
        for uq_index, uq_method in enumerate(data.aq_func_names):
            if uq_method not in selected_uq_methods:
                continue
            if uq_index >= len(data.eval_scores):
                continue

            label = (
                f"{data.model_name} - {uppercase_first(data.search_method_type)} search"
                f" - {dropout_text} - {uq_method}"
            )
            if muted_plots.get(label, False):
                continue
            ax.plot(data.x_points, data.eval_scores[uq_index], label=label)

    ax.set_xlabel("Retention rate")
    ax.set_ylabel(", ".join(str(method) for method in selected_eval_methods))
    ax.set_title(plot_title)
    # Legend sits outside the axes, to the right, so it never hides curves.
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    ax.grid(True, linestyle="--", alpha=0.7)
    fig.tight_layout()
    return fig

def update_plot(
    selected_eval_methods,
    selected_search_methods,
    selected_mcdo,
    selected_models,
    selected_benchmarks,
    selected_uq_methods,
    plot_title,
):
    """Re-render the filtered plot and its visibility toggles into ``plot_output``.

    Takes the same filter arguments as ``create_plot``. After the figure is
    shown, one ToggleButton is displayed per distinct curve label matching the
    filters. Toggling it stores the choice in ``muted_plots`` and re-renders
    with the same selections.
    """
    # Wrap a bare Dropdown string so ``in`` is an exact-match test rather than
    # a substring test ("BLEU" in "SacreBLEU" is True).
    if isinstance(selected_eval_methods, str):
        eval_filter = (selected_eval_methods,)
    else:
        eval_filter = selected_eval_methods

    def make_toggle_handler(label):
        # Factory so each handler binds its own label; a closure created
        # directly in the loop below would see only the last label.
        def toggle_handler(change):
            muted_plots[label] = not change.new
            update_plot(
                selected_eval_methods,
                selected_search_methods,
                selected_mcdo,
                selected_models,
                selected_benchmarks,
                selected_uq_methods,
                plot_title,
            )

        return toggle_handler

    with plot_output:
        clear_output(wait=True)
        create_plot(
            selected_eval_methods,
            selected_search_methods,
            selected_mcdo,
            selected_models,
            selected_benchmarks,
            selected_uq_methods,
            plot_title,
        )
        plt.show()

        mute_toggles = []
        seen_labels = set()
        for data in plot_data:
            if not (
                data.eval_method in eval_filter
                and data.search_method_type in selected_search_methods
                and data.enable_mcdo in selected_mcdo
                and data.model_name in selected_models
                and data.benchmark in selected_benchmarks
            ):
                continue

            dropout_text = "Dropout" if data.enable_mcdo else "No dropout"
            for uq_index, uq_method in enumerate(data.aq_func_names):
                if uq_method not in selected_uq_methods:
                    continue
                if uq_index >= len(data.eval_scores):
                    continue

                label = (
                    f"{data.model_name} - {uppercase_first(data.search_method_type)} search"
                    f" - {dropout_text} - {uq_method}"
                )
                # The label omits benchmark and eval method, so several results
                # can share one mute key; offer a single toggle per key.
                if label in seen_labels:
                    continue
                seen_labels.add(label)

                toggle = widgets.ToggleButton(
                    value=not muted_plots.get(label, False),
                    description=f"Show {label}",
                    tooltip=f"Toggle visibility of {label}",
                    layout=widgets.Layout(width="auto"),
                )
                toggle.observe(make_toggle_handler(label), names="value")
                mute_toggles.append(toggle)

        # Lay the toggle buttons out in a wrapping flex row.
        if mute_toggles:
            display(
                widgets.Box(
                    mute_toggles,
                    layout=widgets.Layout(
                        display="flex",
                        flex_flow="row wrap",
                        align_items="stretch",
                    ),
                )
            )

# Export function
def export_svg(b):
    """Save the plot for the current widget selections as an SVG file.

    Used as the ``on_click`` callback of ``export_button``; the button
    argument *b* is unused. The target name comes from ``output_filename``,
    and ``.svg`` is appended unless the name already ends with it
    (case-insensitively).
    """
    fig = create_plot(
        eval_method_selector.value,
        search_method_selector.value,
        mcdo_selector.value,
        model_selector.value,
        benchmark_selector.value,
        uq_method_selector.value,
        title_input.value,
    )
    filename = output_filename.value
    if not filename.lower().endswith(".svg"):
        filename += ".svg"
    try:
        fig.savefig(filename, format="svg", bbox_inches="tight")
    finally:
        # Close even if saving fails (bad path, permissions) so off-screen
        # export figures do not pile up in pyplot.
        plt.close(fig)
    print(f"Plot exported to {filename}")


# Connect the button to the export function; the callback receives one
# argument (the clicked button), which export_svg does not use.
export_button.on_click(export_svg)

# Display the export controls (filename box + button) above the filters.
export_controls = widgets.HBox([output_filename, export_button])
display(export_controls)

# Use interactive to update the plot when any selection changes: each keyword
# maps an update_plot parameter to the widget whose value feeds it.
# NOTE(review): interactive_plot itself is never displayed; it is kept only
# for the change observers it registers on these widgets. Presumably that is
# also why the first render is triggered explicitly at the end of this cell -
# verify against the installed ipywidgets version.
interactive_plot = widgets.interactive(
    update_plot,
    selected_eval_methods=eval_method_selector,
    selected_search_methods=search_method_selector,
    selected_mcdo=mcdo_selector,
    selected_models=model_selector,
    selected_benchmarks=benchmark_selector,
    selected_uq_methods=uq_method_selector,
    plot_title=title_input,
)

# Use VBox to organize the controls in a more compact way:
# row 1 = eval method / search method / MCDO, row 2 = model / benchmark /
# UQ methods, then the title box.
controls = widgets.VBox(
    [
        widgets.HBox([eval_method_selector, search_method_selector, mcdo_selector]),
        widgets.HBox([model_selector, benchmark_selector, uq_method_selector]),
        title_input,
    ]
)

display(controls)
# plot_output is where update_plot renders the figure and its mute toggles.
display(plot_output)

# Trigger initial plot update using the widgets' starting values.
update_plot(
    eval_method_selector.value,
    search_method_selector.value,
    mcdo_selector.value,
    model_selector.value,
    benchmark_selector.value,
    uq_method_selector.value,
    title_input.value,
)

HBox(children=(Text(value='plot_export.svg', description='Filename:', placeholder='filename.svg'), Button(butt…

VBox(children=(HBox(children=(Dropdown(description='Eval Method', options=('BLEU',), value='BLEU'), SelectMult…

Output()

Plot exported to iwslt14_retention_curve.svg


Plot exported to iwslt14_retention_curve.svg
