In [13]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.colors as plt_colors
import math


# Shared qualitative palette (hex strings); figures below pick colors by index
colors = [
    "#0173B2",
    "#DE8F05",
    "#029E73",
    "#D55E00",
    "#CC78BC",
    "#CA9161",
    "#FBAFE4",
    "#949494",
    "#ECE133",
    "#56B4E9",
]

In [14]:
# Read the per-configuration statistics (one CSV per prompt / few-shot-examples setup)
negative_prompt_negative_examples_stats_df = pd.read_csv("negative_prompt_negative_examples_stats.csv")
negative_prompt_mixed_examples_stats_df = pd.read_csv("negative_prompt_mixed_examples_stats.csv")
critical_prompt_negative_examples_stats_df = pd.read_csv("critical_prompt_negative_examples_stats.csv")
critical_prompt_mixed_examples_stats_df = pd.read_csv("critical_prompt_mixed_examples_stats.csv")

# Stack them into one long frame. ignore_index=True yields a unique RangeIndex; without it
# every CSV contributes labels 0..n-1 and total_df carries duplicate index labels.
total_df = pd.concat(
    [
        negative_prompt_negative_examples_stats_df,
        negative_prompt_mixed_examples_stats_df,
        critical_prompt_negative_examples_stats_df,
        critical_prompt_mixed_examples_stats_df,
    ],
    ignore_index=True,
)

def filter_and_rename(df, method, correct_state, incorrect_state):
    """Select one method's rows for a given pair of state labels and relabel them.

    Rows are kept when ``method`` matches and ``result_type`` is either
    ``correct_state`` or ``incorrect_state``. Those two labels are rewritten to
    ``'correct'`` / ``'incorrect'`` in a new frame; ``df`` itself is not modified.
    """
    relabel = {correct_state: 'correct', incorrect_state: 'incorrect'}
    keep = df['method'].eq(method) & df['result_type'].isin(list(relabel))
    return df.loc[keep].assign(result_type=lambda frame: frame['result_type'].map(relabel))

# cot_1 rows are kept unchanged; every reflective method is reduced to its final-trial
# state labels, relabelled to plain 'correct' / 'incorrect'.
filtered_total_df_list = [total_df[total_df['method'] == 'cot_1']] + [
    filter_and_rename(total_df, method_name, correct_label, incorrect_label)
    for method_name, correct_label, incorrect_label in (
        ('selfreflection_cot_2', 'correct_state_2', 'incorrect_state_2'),
        ('reflexion_cot_2', 'correct_state_2', 'incorrect_state_2'),
        ('selfreflection_cot_3', 'correct_state_3', 'incorrect_state_3'),
        ('reflexion_cot_3', 'correct_state_3', 'incorrect_state_3'),
    )
]

filtered_total_df = pd.concat(filtered_total_df_list, ignore_index=True)


In [15]:
# Total 'value' per (prompt_examples, method, result_type)
new_total_df = (
    total_df
    .groupby(['prompt_examples', 'method', 'result_type'])
    .agg({'value': 'sum'})
    .reset_index()
)

# Keep only the per-trial state rows: correct_state_1..3 and incorrect_state_1..3
new_total_df = new_total_df.loc[
    new_total_df['result_type'].isin([
        'correct_state_1', 'correct_state_2', 'correct_state_3',
        'incorrect_state_1', 'incorrect_state_2', 'incorrect_state_3',
    ])
]

In [16]:
# # Define a color palette for methods
# color_palette = [
#     "rgba(31, 119, 180, 0.8)",   # Blue for selfreflection_cot_2
#     "rgba(255, 127, 14, 0.8)",   # Orange for reflexion_cot_2
#     "rgba(44, 160, 44, 0.8)",    # Green for selfreflection_cot_3
#     "rgba(214, 39, 40, 0.8)",    # Red for reflexion_cot_3
#     "rgba(148, 103, 189, 0.8)",  # Purple for additional methods if any
#     "rgba(140, 86, 75, 0.8)",    # Brown
#     "rgba(227, 119, 194, 0.8)",  # Pink
#     "rgba(127, 127, 127, 0.8)",  # Gray
#     "rgba(188, 189, 34, 0.8)",   # Olive
#     "rgba(23, 190, 207, 0.8)"     # Cyan
# ]


# import pandas as pd
# import plotly.graph_objects as go
# from plotly.subplots import make_subplots
# import math

# def create_sankey_for_prompt_method(df, prompt, method):
#     """
#     Creates a Sankey diagram for a specific prompt_examples category and method.

#     Parameters:
#     - df: pandas DataFrame containing the data.
#     - prompt: The specific prompt_examples category to visualize.
#     - method: The specific method to visualize.

#     Returns:
#     - plotly.graph_objects.Figure object representing the Sankey diagram.
#     """
#     # Filter the dataframe for the specific prompt and method
#     if pd.isna(prompt):
#         prompt_df = df[(df['prompt_examples'].isna()) & (df['method'] == method)]
#         prompt_label = 'NaN'
#     else:
#         prompt_df = df[(df['prompt_examples'] == prompt) & (df['method'] == method)]
#         prompt_label = prompt

#     # Initialize nodes
#     nodes = []

#     # Determine the number of states based on the method
#     if method.endswith('_3'):
#         max_state = 3
#     elif method.endswith('_2'):
#         max_state = 2
#     else:
#         max_state = 1

#     # Add state nodes
#     for state in range(1, max_state + 1):
#         nodes.append(f'State {state} Correct')
#         nodes.append(f'State {state} Incorrect')

#     # Create a mapping from node to index
#     node_indices = {node: idx for idx, node in enumerate(nodes)}

#     # Initialize link lists
#     source = []
#     target = []
#     value = []
#     link_colors = []

#     # Define colors
#     color_correct = "rgba(34, 139, 34, 0.8)"      # Forest Green
#     color_incorrect = "rgba(220, 20, 60, 0.8)"   # Crimson

#     if max_state >= 2:
#         correct_s1 = prompt_df[prompt_df['result_type'] == 'correct_state_1']['value'].sum()
#         incorrect_s1 = prompt_df[prompt_df['result_type'] == 'incorrect_state_1']['value'].sum()
#         correct_s2 = prompt_df[prompt_df['result_type'] == 'correct_state_2']['value'].sum()
#         incorrect_s2 = prompt_df[prompt_df['result_type'] == 'incorrect_state_2']['value'].sum()

#         # State 1 Correct -> State 2 Correct
#         source.append(node_indices['State 1 Correct'])
#         target.append(node_indices['State 2 Correct'])
#         value.append(correct_s2)
#         link_colors.append(color_correct)

#         # State 1 Correct -> State 2 Incorrect
#         source.append(node_indices['State 1 Correct'])
#         target.append(node_indices['State 2 Incorrect'])
#         value.append(correct_s1 - correct_s2)
#         link_colors.append(color_incorrect)

#         # State 1 Incorrect -> State 2 Correct
#         source.append(node_indices['State 1 Incorrect'])
#         target.append(node_indices['State 2 Correct'])
#         value.append(correct_s2 - correct_s1)
#         link_colors.append(color_correct)

#         # State 1 Incorrect -> State 2 Incorrect
#         source.append(node_indices['State 1 Incorrect'])
#         target.append(node_indices['State 2 Incorrect'])
#         value.append(incorrect_s2)
#         link_colors.append(color_incorrect)

#     if max_state == 3:
#         correct_s3 = prompt_df[prompt_df['result_type'] == 'correct_state_3']['value'].sum()
#         incorrect_s3 = prompt_df[prompt_df['result_type'] == 'incorrect_state_3']['value'].sum()

#         # State 2 Correct -> State 3 Correct
#         source.append(node_indices['State 2 Correct'])
#         target.append(node_indices['State 3 Correct'])
#         value.append(correct_s3)
#         link_colors.append(color_correct)

#         # State 2 Correct -> State 3 Incorrect
#         source.append(node_indices['State 2 Correct'])
#         target.append(node_indices['State 3 Incorrect'])
#         value.append(correct_s2 - correct_s3)
#         link_colors.append(color_incorrect)

#         # State 2 Incorrect -> State 3 Correct
#         source.append(node_indices['State 2 Incorrect'])
#         target.append(node_indices['State 3 Correct'])
#         value.append(correct_s3 - correct_s2)
#         link_colors.append(color_correct)

#         # State 2 Incorrect -> State 3 Incorrect
#         source.append(node_indices['State 2 Incorrect'])
#         target.append(node_indices['State 3 Incorrect'])
#         value.append(incorrect_s3)
#         link_colors.append(color_incorrect)

#     # Create the Sankey diagram
#     sankey = go.Sankey(
#         node=dict(
#             pad=15,
#             thickness=20,
#             line=dict(color="black", width=0.5),
#             label=nodes,
#             color="lightblue"
#         ),
#         link=dict(
#             source=source,
#             target=target,
#             value=value,
#             color=link_colors,
#             hovertemplate='Value: %{value}<br>From: %{source.label}<br>To: %{target.label}<extra></extra>'
#         )
#     )

#     # Create the figure
#     fig = go.Figure(data=[sankey])
#     fig.update_layout(
#         title_text=f"Sankey Diagram for Prompt: {prompt_label}, Method: {method}",
#         font_size=10
#     )

#     return fig

# # Identify unique prompt_examples and method combinations
# prompt_method_combinations = new_total_df.groupby(['prompt_examples', 'method']).size().reset_index()[['prompt_examples', 'method']]
# num_combinations = len(prompt_method_combinations)

# # Define the number of columns for the subplot grid
# cols = 2  # Adjust based on preference
# rows = math.ceil(num_combinations / cols)

# # Create subplots
# fig = make_subplots(
#     rows=rows, 
#     cols=cols, 
#     specs=[[{'type': 'sankey'} for _ in range(cols)] for _ in range(rows)],
#     subplot_titles=[f"{str(row.prompt_examples) if pd.notna(row.prompt_examples) else 'NaN'} - {row.method}" 
#                     for _, row in prompt_method_combinations.iterrows()]
# )

# # Iterate and add Sankey diagrams to subplots
# for i, (_, row) in enumerate(prompt_method_combinations.iterrows()):
#     row_idx = (i // cols) + 1
#     col_idx = (i % cols) + 1
#     sankey_fig = create_sankey_for_prompt_method(new_total_df, row.prompt_examples, row.method)
#     sankey = sankey_fig['data'][0]
#     fig.add_trace(sankey, row=row_idx, col=col_idx)

# # Update layout
# fig.update_layout(
#     title_text="Sankey Diagrams for Each Prompt-Method Combination",
#     font_size=10,
#     height=600 * rows  # Adjust height based on number of rows
# )

# fig.show()


In [17]:
# result_type labels that count as a correct / an incorrect answer
correct_types = ['correct', 'correct_state_1', 'correct_state_2', 'correct_state_3']
incorrect_types = ['incorrect', 'incorrect_state_1', 'incorrect_state_2', 'incorrect_state_3']

def categorize_result(result_type):
    """Collapse a raw result_type label into 'correct', 'incorrect' or 'other'.

    The correct labels are checked before the incorrect ones; anything in
    neither list (including NaN) is reported as 'other'.
    """
    for category, members in (('correct', correct_types), ('incorrect', incorrect_types)):
        if result_type in members:
            return category
    return 'other'

# Collapse every result_type into correct / incorrect / other. `.assign` builds a new frame
# instead of writing into `filtered_total_df`, which is a filtered slice after the first run
# of this cell (avoids SettingWithCopyWarning when the cell is re-run).
filtered_total_df = filtered_total_df.assign(
    result_category=filtered_total_df['result_type'].apply(categorize_result)
)

# Filter out any 'other' categories if present
filtered_total_df = filtered_total_df[filtered_total_df['result_category'].isin(['correct', 'incorrect'])]

# Separate the single-trial 'cot_1' baseline from the reflective methods
cot1_df = filtered_total_df[filtered_total_df['method'] == 'cot_1']
other_methods_df = filtered_total_df[filtered_total_df['method'] != 'cot_1']

# Summed correct / incorrect counts. pivot_table only creates columns for categories that
# actually occur, so reindex guarantees both columns exist (filled with 0) before the
# accuracy computation below; otherwise a missing category raises KeyError.
cot1_pivot = (
    cot1_df.pivot_table(
        index='method',
        columns='result_category',
        values='value',
        aggfunc='sum',
        fill_value=0
    )
    .reindex(columns=['correct', 'incorrect'], fill_value=0)
    .reset_index()
)

other_accuracy_pivot = (
    other_methods_df.pivot_table(
        index=['prompt_examples', 'method'],
        columns='result_category',
        values='value',
        aggfunc='sum',
        fill_value=0
    )
    .reindex(columns=['correct', 'incorrect'], fill_value=0)
    .reset_index()
)

# Accuracy = correct / (correct + incorrect); a 0/0 group yields NaN and is reported as 0
cot1_pivot['accuracy'] = cot1_pivot['correct'] / (cot1_pivot['correct'] + cot1_pivot['incorrect'])
cot1_pivot['accuracy'] = cot1_pivot['accuracy'].fillna(0)

other_accuracy_pivot['accuracy'] = other_accuracy_pivot['correct'] / (other_accuracy_pivot['correct'] + other_accuracy_pivot['incorrect'])
other_accuracy_pivot['accuracy'] = other_accuracy_pivot['accuracy'].fillna(0)

In [18]:
# Display labels for the reflective methods, in left-to-right bar order
# (raw ids: 'selfreflection_cot_2', 'selfreflection_cot_3', 'reflexion_cot_2', 'reflexion_cot_3')
desired_order = ['Self-Correction + CoT<br>2 trials', 'Self-Correction + CoT<br>3 trials', 'Reflexion + CoT<br>Max. 2 trials', 'Reflexion + CoT<br>Max. 3 trials']

method_name_mapping = {
    'selfreflection_cot_2': 'Self-Correction + CoT<br>2 trials',
    'selfreflection_cot_3': 'Self-Correction + CoT<br>3 trials',
    'reflexion_cot_2': 'Reflexion + CoT<br>Max. 2 trials',
    'reflexion_cot_3': 'Reflexion + CoT<br>Max. 3 trials',
}

# Replace raw method ids with display labels. Values that are not mapping keys (e.g. labels
# left by an earlier run of this cell) are kept as-is, so re-running the cell no longer turns
# every method into NaN as a plain `.map(dict)` did.
other_accuracy_pivot['method'] = (
    other_accuracy_pivot['method'].astype(object).map(lambda m: method_name_mapping.get(m, m))
)

# Ordered categorical so sorting follows desired_order (labels outside it become NaN)
other_accuracy_pivot['method'] = pd.Categorical(other_accuracy_pivot['method'], categories=desired_order, ordered=True)

# Sort the DataFrame based on the categorical order
other_accuracy_pivot = other_accuracy_pivot.sort_values('method')

# One color per method, skipping colors[0] (used for the CoT baseline line). Parenthesized:
# the original `i+1 % len(colors)` parsed as `i + (1 % len(colors))` and never wrapped.
method_colors = {method: colors[(i + 1) % len(colors)] for i, method in enumerate(desired_order)}

In [21]:
# Display labels for each prompt / few-shot-examples configuration (x-axis categories)
prompt_example_mapping = {
    'negative_prompt_negative_examples': 'Negative Prompt<br>Negative Examples',
    'negative_prompt_mixed_examples': 'Negative Prompt<br>Mixed Examples',
    'critical_prompt_negative_examples': 'Critical Prompt<br>Negative Examples',
    'critical_prompt_mixed_examples': 'Critical Prompt<br>Mixed Examples',
    # Add other prompt examples if necessary
}

# Left-to-right order of the x-axis categories
desired_prompt_order = [
    'Negative Prompt<br>Negative Examples',
    'Negative Prompt<br>Mixed Examples',
    'Critical Prompt<br>Negative Examples',
    'Critical Prompt<br>Mixed Examples'
]

# Single-trial CoT baseline accuracy, as a percentage (cot1_pivot is indexed by method and
# built only from 'cot_1' rows, so it holds a single row)
cot1_accuracy = cot1_pivot['accuracy'].iloc[0]
cot1_accuracy_percent = cot1_accuracy * 100

fig = go.Figure()

# Tallest value drawn so far; used below to size the y-axis so nothing is clipped
peak_percent = cot1_accuracy_percent

# One bar series per reflective method, accuracy on a 0-100 (%) scale
for method in desired_order:
    method_data = other_accuracy_pivot[other_accuracy_pivot['method'] == method].copy()
    # Map the prompt examples to their display labels
    method_data['prompt_examples'] = method_data['prompt_examples'].map(prompt_example_mapping)
    method_data['accuracy_percent'] = method_data['accuracy'] * 100
    if not method_data.empty:
        peak_percent = max(peak_percent, method_data['accuracy_percent'].max())
    fig.add_trace(go.Bar(
        x=method_data['prompt_examples'],
        y=method_data['accuracy_percent'],
        # desired_order already holds display labels, so it doubles as the legend name
        name=method,
        text=[f"{acc:.1f}" for acc in method_data['accuracy_percent']],  # no % sign; the axis title carries the unit
        textposition='outside',
        marker_color=method_colors.get(method, '#000000'),
    ))

# Keep the original 0-40 % window, but widen it in 5-point steps whenever a bar (plus
# head-room for its outside label) or the baseline would otherwise be clipped.
y_upper = max(40, 5 * math.ceil((peak_percent + 5) / 5))

# Horizontal dashed line for the 'cot_1' baseline, spanning all categories
fig.add_shape(
    type="line",
    xref='x',
    yref='y',
    x0=-0.5,  # Start just before the first bar
    y0=cot1_accuracy_percent,
    x1=len(desired_prompt_order) - 0.5,  # End just after the last bar
    y1=cot1_accuracy_percent,
    line=dict(color=colors[0], width=4, dash="dash"),
)

# Label for the 'cot_1' baseline
fig.add_annotation(
    x=0.54,  # Relative position along the x-axis in paper coordinates
    y=cot1_accuracy_percent,
    xref="paper",
    yref="y",
    text=f"<b>CoT, 1 Trial: {cot1_accuracy_percent:.1f}%</b>",
    showarrow=False,
    font=dict(color="black", size=12),
    align="left",
    yshift=-10
)

# Grouped bars with black axis lines and light horizontal grid lines
fig.update_layout(
    barmode='group',
    title={
        'text': '<b>Accuracy - Trip Planning (Gemini 1.5 Flash, Few-Shot Reflective Prompting)</b>',
        'y': 0.95,
        'x': 0.5,
        'xanchor': 'center',
        'yanchor': 'top'
    },
    xaxis_title='Reflective Prompt and Few-Shot Reflective Examples',
    yaxis_title='Accuracy (%)',
    template='plotly_white',
    font=dict(family="Arial, sans-serif", size=12),
    width=1000,
    height=600,
    legend_title_text='Method',
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="center",
        x=0.5
    ),
    xaxis=dict(
        categoryorder='array',
        categoryarray=desired_prompt_order,  # fixed left-to-right category order
        linecolor='black',             # Set x-axis line color to black
        showline=True,                 # Show x-axis line
        tickfont=dict(color='black'),  # Set x-axis tick labels to black
        showgrid=False,                # Remove vertical grid lines
        zeroline=False,                # Remove zero line
        mirror='all'                   # Mirror x-axis line on the opposite side (adds top line)
    ),
    yaxis=dict(
        tickformat=None,               # Show numbers without the percent symbol
        linecolor='black',             # Set y-axis line color to black
        showline=True,                 # Show y-axis line
        tickfont=dict(color='black'),  # Set y-axis tick labels to black
        gridcolor='lightgrey',         # Set grid line color to light grey
        gridwidth=1,                   # 1 px horizontal grid lines
        zeroline=False,                # Remove zero line
        mirror='all',                  # Mirror y-axis line on the opposite side (adds right line)
        range=[0, y_upper],            # 0 to 40 %, widened only when data would be clipped
        tickvals=np.arange(0, y_upper + 1, 5)
    )
)

fig.show()
fig.write_image("accuracy_by_method_prompt_examples.pdf")


In [8]:
from typing import List, Tuple, Dict

def process_data(total_df: pd.DataFrame, prompt_example: str) -> Tuple[List[str], List[float]]:
    """Summarize start-to-final transitions of every reflective method for one prompt example.

    Only rows with ``prompt_examples == prompt_example`` and
    ``transition == 'start_to_final'`` are used. For each method, the summed
    ``value`` per ``result_type`` is turned into three percentages of that
    method's total: no change (correct_to_correct + incorrect_to_incorrect),
    incorrect_to_correct, and correct_to_incorrect. A method with a zero
    total contributes ``[0, 0, 0]``.

    Args:
    total_df (pd.DataFrame): The complete dataset.
    prompt_example (str): The prompt-example key to summarize.

    Returns:
    Tuple[List[str], List[float]]: The method display labels, and a flat list
    holding three percentages per method in the same order as the labels.
    """
    label_by_method = {
        'reflexion_cot_3': 'Reflexion + CoT<br>Max. 3 trials',
        'reflexion_cot_2': 'Reflexion + CoT<br>Max. 2 trials',
        'selfreflection_cot_3': 'Self-Reflection + CoT<br>3 trials',
        'selfreflection_cot_2': 'Self-Reflection + CoT<br>2 trials',
    }

    rows_for_prompt = total_df[total_df['prompt_examples'] == prompt_example]
    flat_percentages = []

    for method_key in label_by_method:
        selected = rows_for_prompt[
            (rows_for_prompt['method'] == method_key)
            & (rows_for_prompt['transition'] == 'start_to_final')
        ]
        counts = selected.groupby('result_type')['value'].sum()

        grand_total = sum(counts)
        if grand_total == 0:
            flat_percentages.extend([0, 0, 0])
            continue

        unchanged = counts.get('correct_to_correct', 0) + counts.get('incorrect_to_incorrect', 0)
        flat_percentages.extend([
            unchanged / grand_total * 100,
            counts.get('incorrect_to_correct', 0) / grand_total * 100,
            counts.get('correct_to_incorrect', 0) / grand_total * 100,
        ])

    return list(label_by_method.values()), flat_percentages

def create_heatmap(methods: List[str], values: List[float], show_colorbar: bool = False) -> go.Heatmap:
    """Build one transition heatmap trace from a flat list of percentages.

    ``values`` is consumed three at a time, giving one heatmap row per entry of
    ``methods``. The three columns are labelled No Change, Incorrect ⇒ Correct and
    Correct ⇒ Incorrect. The color scale is pinned to 0-100, and each cell is
    annotated with its value to one decimal place.

    Args:
    methods (List[str]): Row labels (y-axis).
    values (List[float]): Flat list of percentages, three per method.
    show_colorbar (bool): Whether this trace draws the (horizontal) colorbar.

    Returns:
    go.Heatmap: The configured heatmap trace.
    """
    grid = [values[offset:offset + 3] for offset in range(0, len(values), 3)]
    cell_labels = [[f"{val:.1f}%" for val in cells] for cells in grid]

    colorbar_style = dict(
        title="Percentage (%)",
        orientation="h",
        y=1.1,
        x=0.5,
        yanchor="bottom",
        thickness=20,
        title_side="top",
    )

    return go.Heatmap(
        z=grid,
        x=['No Change', 'Incorrect ⇒ Correct', 'Correct ⇒ Incorrect'],
        y=methods,
        colorscale='RdBu',
        zmin=0,
        zmax=100,
        showscale=show_colorbar,
        colorbar=colorbar_style,
        text=cell_labels,
        texttemplate="%{text}",
        textfont={"size": 12},
        hovertemplate='Method: %{y}<br>Transition: %{x}<br>Percentage: %{z:.1f}%<extra></extra>',
    )

def create_visualization(total_df: pd.DataFrame,
                         output_path: str = "transition_distribution_prompt_examples.pdf") -> None:
    """Draw a 2x2 grid of transition heatmaps (one per prompt example), show it and save it.

    Each panel is built by ``process_data`` + ``create_heatmap``. Only the top-right
    panel draws the colorbar; the right-hand column hides its y tick labels.

    NOTE(review): this calls the global ``process_data`` at call time. A later cell
    rebinds that name to a 4-tuple-returning version, so calling this function after
    that cell has run fails when unpacking the result.

    Args:
    total_df (pd.DataFrame): The complete dataset.
    output_path (str): Where ``write_image`` saves the figure (defaults to the
        original PDF file name).
    """
    titles = [
        '<b>Negative Prompt<br>Negative Examples</b>',
        '<b>Negative Prompt<br>Mixed Examples</b>',
        '<b>Critical Prompt<br>Negative Examples</b>',
        '<b>Critical Prompt<br>Mixed Examples</b>'
    ]

    prompt_examples = [
        'negative_prompt_negative_examples',
        'negative_prompt_mixed_examples',
        'critical_prompt_negative_examples',
        'critical_prompt_mixed_examples'
    ]

    fig = make_subplots(rows=2,
                        cols=2,
                        subplot_titles=titles,
                        horizontal_spacing=0.03,  # narrow gap (previously 0.1); right column has no y tick labels
                        vertical_spacing=0.12,    # previously 0.2
                        shared_xaxes=True,
                        )

    for i, prompt_example in enumerate(prompt_examples):
        row = i // 2 + 1
        col = i % 2 + 1

        methods, values = process_data(total_df, prompt_example)
        # One shared colorbar is enough; attach it to the top-right panel
        heatmap = create_heatmap(methods, values, show_colorbar=(i == 1))
        fig.add_trace(heatmap, row=row, col=col)

        if col == 1:
            fig.update_yaxes(
                title_text='Method',
                row=row,
                col=col,
                title_standoff=25
            )

        if row == 2:
            fig.update_xaxes(
                title_text='Transition',
                row=row,
                col=col,
                title_standoff=60  # push the title below the tick labels
            )

    fig.update_layout(
        title={
            'text': '<b>Transition Distribution - Trip Planning (Gemini 1.5 Flash, Few-Shot Reflective Prompting)</b>',
            'y': 0.94,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
        },
        template='plotly_white',
        font=dict(family="Arial, sans-serif", size=12),
        width=1000,
        height=700,
        margin=dict(t=150, b=0)
    )

    # Grid lines on, zero lines off (the original toggled showgrid off and then on again)
    fig.update_xaxes(showgrid=True, zeroline=False)
    fig.update_yaxes(showgrid=True, zeroline=False)

    # The right-hand column shares the method labels of the left column
    for panel_row in (1, 2):
        fig.update_yaxes(showticklabels=False, row=panel_row, col=2)

    # Subplot titles are layout annotations; give them the figure font
    for annotation in fig['layout']['annotations']:
        annotation['font'] = dict(family="Arial, sans-serif", size=12)

    fig.show()
    fig.write_image(output_path)


# Render and save the transition-distribution heatmaps for the full dataset
create_visualization(total_df=total_df)

In [9]:
# Final-trial correct/incorrect counts per method (cot_1 rows are kept unchanged)
df = total_df

filtered_df_list = [
    df[df['method'] == 'cot_1'],
    filter_and_rename(df, 'selfreflection_cot_2', 'correct_state_2', 'incorrect_state_2'),
    filter_and_rename(df, 'reflexion_cot_2', 'correct_state_2', 'incorrect_state_2'),
    filter_and_rename(df, 'selfreflection_cot_3', 'correct_state_3', 'incorrect_state_3'),
    filter_and_rename(df, 'reflexion_cot_3', 'correct_state_3', 'incorrect_state_3')
]

filtered_df = pd.concat(filtered_df_list, ignore_index=True)

# NOTE(review): this rebinds the global `desired_order`, which the accuracy bar-chart cell
# expects to hold display labels; re-running that cell after this one draws no bars.
desired_order = ['cot_1', 'selfreflection_cot_2', 'selfreflection_cot_3', 'reflexion_cot_2', 'reflexion_cot_3']
filtered_df['method'] = pd.Categorical(filtered_df['method'], categories=desired_order, ordered=True)

# observed=False pins the current behaviour for the categorical 'method' key (every
# category is kept) and silences the FutureWarning pandas >= 2.1 emits when it is implicit.
groupped_df = (
    filtered_df
    .groupby(["method", "result_type"], observed=False)
    .agg({"value": "sum"})
    .reset_index()
)

In [11]:
# Colors and display labels for the four reflective methods, keyed by raw method id.
# NOTE(review): both names are module-level and replace the same-named dicts from the
# accuracy bar-chart cell (which are keyed / labelled differently); neither is read
# anywhere else in this cell.
method_colors = {
    method_key: colors[idx]
    for idx, method_key in enumerate(
        ['selfreflection_cot_2', 'selfreflection_cot_3', 'reflexion_cot_2', 'reflexion_cot_3']
    )
}

method_name_mapping = {
    'selfreflection_cot_2': 'Self-Reflection + CoT<br>2 Trials',
    'selfreflection_cot_3': 'Self-Reflection + CoT<br>3 Trials',
    'reflexion_cot_2': 'Reflexion + CoT<br>Max. 2 Trials',
    'reflexion_cot_3': 'Reflexion + CoT<br>Max. 3 Trials',
}

# Function to process data from the new dataframe
def process_data(total_df, prompt_example):
    """Per-method transition percentages for one prompt example (stacked-bar version).

    Uses rows with ``prompt_examples == prompt_example`` and
    ``transition == 'start_to_final'``. For each reflective method, the summed
    ``value`` per ``result_type`` is split into no change
    (correct_to_correct + incorrect_to_incorrect), incorrect_to_correct and
    correct_to_incorrect, each as a percentage of their sum. A method with no
    matching rows gets 0 for all three.

    NOTE(review): this shares its name with the heatmap-cell ``process_data`` (which
    returns a 2-tuple) and replaces it globally once this cell runs.

    Args:
        total_df: The complete long-format dataset.
        prompt_example: The prompt-example key to summarize.

    Returns:
        (categories, no_change_values, incorrect_to_correct_values,
        correct_to_incorrect_values): four lists aligned index-by-index, where
        categories are the x-axis labels of the methods.
    """
    df_prompt = total_df[total_df['prompt_examples'] == prompt_example]

    # x-axis labels, aligned position-by-position with method_keys below
    categories = ['Self-Reflection<br>+ CoT<br>2 trials',
                  'Self-Reflection<br>+ CoT<br>3 trials',
                  'Reflexion<br>+ CoT<br>Max. 2 trials',
                  'Reflexion<br>+ CoT<br>Max. 3 trials']
    method_keys = ['selfreflection_cot_2', 'selfreflection_cot_3', 'reflexion_cot_2', 'reflexion_cot_3']

    no_change_values = []
    incorrect_to_correct_values = []
    correct_to_incorrect_values = []

    for method_key in method_keys:
        df_transitions = df_prompt[
            (df_prompt['method'] == method_key) & (df_prompt['transition'] == 'start_to_final')
        ]

        # Sum 'value' for each transition result_type
        transition_counts = df_transitions.groupby('result_type')['value'].sum()

        correct_to_correct = transition_counts.get('correct_to_correct', 0)
        incorrect_to_correct = transition_counts.get('incorrect_to_correct', 0)
        correct_to_incorrect = transition_counts.get('correct_to_incorrect', 0)
        incorrect_to_incorrect = transition_counts.get('incorrect_to_incorrect', 0)

        no_change = correct_to_correct + incorrect_to_incorrect
        total = no_change + incorrect_to_correct + correct_to_incorrect
        if total == 0:
            no_change_perc = incorrect_to_correct_perc = correct_to_incorrect_perc = 0
        else:
            no_change_perc = (no_change / total) * 100
            incorrect_to_correct_perc = (incorrect_to_correct / total) * 100
            correct_to_incorrect_perc = (correct_to_incorrect / total) * 100

        no_change_values.append(no_change_perc)
        incorrect_to_correct_values.append(incorrect_to_correct_perc)
        correct_to_incorrect_values.append(correct_to_incorrect_perc)

    return categories, no_change_values, incorrect_to_correct_values, correct_to_incorrect_values

# Now, collect the data for each of the four plots
# Panel titles and their matching prompt-example keys, in 2x2 reading order
titles = [
    '<b>Negative Prompt<br>Negative Examples</b>',
    '<b>Negative Prompt<br>Mixed Examples</b>',
    '<b>Critical Prompt<br>Negative Examples</b>',
    '<b>Critical Prompt<br>Mixed Examples</b>'
]

prompt_examples = [
    'negative_prompt_negative_examples',
    'negative_prompt_mixed_examples',
    'critical_prompt_negative_examples',
    'critical_prompt_mixed_examples'
]

fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=titles,
    horizontal_spacing=0.1,
    vertical_spacing=0.175,
)

# Colors for the stacked transition segments
transition_colors = {
    'No Change': colors[0],
    'Incorrect ⇒ Correct': colors[2],
    'Correct ⇒ Incorrect': colors[3]
}

# One stacked-bar panel per prompt example; x = method, stack = transition type
for i, prompt_example in enumerate(prompt_examples):
    categories, no_change_values, incorrect_to_correct_values, correct_to_incorrect_values = process_data(total_df, prompt_example)

    row = i // 2 + 1
    col = i % 2 + 1

    # The three transition series repeat in every panel, so only the first adds legend entries
    show_legend = i == 0

    for transition_name, percentages in (
        ('No Change', no_change_values),
        ('Incorrect ⇒ Correct', incorrect_to_correct_values),
        ('Correct ⇒ Incorrect', correct_to_incorrect_values),
    ):
        fig.add_trace(go.Bar(
            x=categories,
            y=percentages,
            name=transition_name,
            marker_color=transition_colors[transition_name],
            text=[f'{val:.1f}%' for val in percentages],
            textposition='inside',
            showlegend=show_legend
        ), row=row, col=col)

    fig.update_xaxes(
        categoryorder='array',
        categoryarray=categories,
        linecolor='black',             # Set x-axis line color to black
        showline=True,                 # Show x-axis line
        tickfont=dict(color='black'),  # Set x-axis tick labels to black
        showgrid=False,                # Remove vertical grid lines
        zeroline=False,                # Remove zero line
        mirror='all',                  # Mirror x-axis line on the opposite side (adds top line)
        row=row,
        col=col
    )

    fig.update_yaxes(
        tickformat=None,               # Show numbers without the percent symbol
        linecolor='black',             # Set y-axis line color to black
        showline=True,                 # Show y-axis line
        tickfont=dict(color='black'),  # Set y-axis tick labels to black
        gridcolor='lightgrey',         # Set grid line color to light grey
        gridwidth=1,                   # 1 px horizontal grid lines
        zeroline=False,                # Remove zero line
        mirror='all',                  # Mirror y-axis line on the opposite side (adds right line)
        range=[0, 100],                # Stacked percentages span 0 to 100
        tickvals=np.arange(0, 101, 20),  # y-axis ticks every 20%
        row=row,
        col=col
    )

    # y-axis title only on the left-hand column
    if col == 1:
        fig.update_yaxes(
            title_text='Percentage (%)',
            row=row,
            col=col
        )

    # x-axis title only on the bottom row. The x categories are the methods
    # (the transition types are the stacked segments, titled in the legend).
    if row == 2:
        fig.update_xaxes(
            title_text='Method',
            row=row,
            col=col,
        )

fig.update_layout(
    barmode='stack',
    title=dict(
        text='<b>Transition Outcomes - Trip Planning (Gemini 1.5 Flash, Few-Shot Reflective Prompting)</b>',
        x=0.5,
        y=0.97,
        xanchor='center',
        yanchor='top',
    ),
    template="plotly_white",
    font=dict(family="Arial, sans-serif", size=12),
    width=1000,
    height=800,
    legend_title="Transition",
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.08,
        xanchor="center",
        x=0.5
    ),
    showlegend=True,
    margin=dict(
        t=120  # Increased top margin to accommodate the main title and legends
    )
)

# Subplot titles are layout annotations; enlarge them slightly
for annotation in fig['layout']['annotations']:
    annotation['font'] = dict(family="Arial, sans-serif", size=14)

fig.show()
fig.write_image("transition_by_method_prompt_examples.pdf")
