In [31]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.colors

In [121]:
def plot_exp_result(result_file, baseline_name, propose_model=None,
                    x='AC3_3', y='total_accuracy', width=400, height=400):
    """Draw one scatter plot per ``test_data`` value from an experiment summary sheet.

    Each figure plots column ``x`` against column ``y`` for the runs in one
    ``test_data`` group. Runs whose ``name`` equals ``baseline_name`` (and
    ``propose_model``, if given) are drawn as separate traces with the text
    "baseline" / "proposed" printed above their markers; all other runs are
    coloured by ``name``. Figures are displayed via ``fig.show()``; nothing is
    returned.

    Args:
        result_file: Path to a workbook readable by ``pd.read_excel``. The sheet
            must contain ``name``, ``test_data`` and the ``x`` / ``y`` columns.
        baseline_name: Value of the ``name`` column identifying the baseline run.
        propose_model: Optional value of the ``name`` column identifying the
            proposed run.
        x: Column plotted on the x-axis (default ``'AC3_3'``).
        y: Column plotted on the y-axis (default ``'total_accuracy'``).
        width: Figure width in pixels.
        height: Figure height in pixels.
    """
    df = pd.read_excel(result_file)

    # Runs drawn as their own traces, mapped to the text label shown next to them.
    highlight_names = {baseline_name: "baseline"}
    if propose_model is not None:
        highlight_names[propose_model] = "proposed"

    # groupby skips rows whose test_data is NaN; filtering with `== value` would
    # instead yield an empty "Scatter Plot for nan" figure because NaN != NaN.
    # sort=False keeps the order in which test_data values first appear.
    for test_data_value, subset in df.groupby('test_data', sort=False):
        is_highlight = subset['name'].isin(list(highlight_names))
        fig = px.scatter(subset[~is_highlight], x=x, y=y, color='name',
                         title=f'Scatter Plot for {test_data_value}',
                         labels={'total_accuracy': 'Total Accuracy'})

        # Add each highlighted run as its own trace with its label above the markers.
        for name, label in highlight_names.items():
            highlighted = subset[subset['name'] == name]
            if highlighted.empty:
                continue
            fig.add_trace(
                go.Scatter(x=highlighted[x], y=highlighted[y],
                           text=[label] * len(highlighted),
                           mode='markers+text', name=name,
                           textposition='top center')
            )

        # The legend is hidden; highlighted runs are identified by their text labels.
        fig.update_layout(width=width, height=height,
                          margin=dict(t=50, b=50, l=50, r=50),
                          showlegend=False)
        # Tie the x-axis scale to the y-axis at a 1:1 ratio.
        fig.update_xaxes(scaleanchor="y", scaleratio=1)
        fig.show()

In [122]:
# Plot the Gemma-2B summary sheet, highlighting the baseline and the proposed run.
plot_exp_result(
    './summary/gemma-2b-test-20240227.xlsx',
    baseline_name='gemma-2b.platypus-alpaca-comb.0.16.1e-4.1.gemma',
    propose_model='gemma-2b.platypus-alpaca-comb.0-1-3.16.1e-4.1.gemma',
)