In [21]:
import functools
import pickle

import plotly.express as px
from plotly.subplots import make_subplots

# Base directory holding one sub-directory per run; the path is relative,
# so it is resolved against the notebook's current working directory.
RUNDIRS = '../logs/rundirs'

In [4]:
@functools.cache
def get_key2df(runname):
    """Load and memoize the ``key2df.pickle`` object stored for a run.

    Args:
        runname: Name of the run's sub-directory under ``RUNDIRS``.

    Returns:
        Whatever object was pickled into ``<RUNDIRS>/<runname>/key2df.pickle``.
        Because of ``functools.cache`` the same object is handed back on
        repeated calls with the same ``runname``.

    Note:
        Unpickling can execute arbitrary code; only use this on pickle
        files that come from a trusted source.
    """
    pickle_path = f'{RUNDIRS}/{runname}/key2df.pickle'
    with open(pickle_path, 'rb') as pickle_file:
        key2df = pickle.load(pickle_file)
    return key2df
    
    
# Display what iterating over the loaded object yields for this run
# (presumably the keys of a dict keyed by 2-tuples — confirm against the pickle).
list(get_key2df('20241229_131008'))

[(True, None),
 (True, True),
 (True, False),
 (False, None),
 (False, True),
 (False, False)]

In [31]:
def plot_df(*, runname, df, col_strategy, strategy, col_i, col_j, col_data):
    """Build a heatmap of one data column for a single strategy.

    Args:
        runname: Run name; only used in the figure title.
        df: DataFrame whose index contains the levels ``col_strategy``,
            ``col_i`` and ``col_j`` and which has a column ``col_data``.
        col_strategy: Name of the index level to filter on.
        strategy: Value of ``col_strategy`` whose rows are kept.
        col_i: Index level that becomes the heatmap rows (y axis).
        col_j: Index level that becomes the heatmap columns (x axis).
        col_data: Column whose values fill the heatmap cells.

    Returns:
        A ``(heatmap_data, fig)`` tuple: the pivoted DataFrame and the
        Plotly Express figure drawn from it.
    """
    # Keep only the rows that belong to the requested strategy.
    strategy_mask = df.index.get_level_values(col_strategy) == strategy
    strategy_rows = df[strategy_mask]

    # Turn the index levels into ordinary columns so they can be pivoted
    # into a (col_i x col_j) grid of col_data values.
    flat = strategy_rows[[col_data]].reset_index()
    heatmap_data = flat.pivot(index=col_i, columns=col_j, values=col_data)

    axis_labels = {"x": col_j, "y": col_i, "color": col_data}
    fig = px.imshow(
        heatmap_data,
        labels=axis_labels,
        title=f"{runname}: {strategy}: {col_data}",
    )

    # Custom hover text; the empty <extra></extra> tag hides the secondary hover box.
    hover_text = "".join((
        col_i, ": %{y}<br>",
        col_j, ": %{x}<br>",
        col_data, ": %{z}<extra></extra>",
    ))
    fig.update_traces(hovertemplate=hover_text)

    return heatmap_data, fig


def plot_df_all(runname, are_bridges, label_are_bridges):
    """Show one heatmap per coordination strategy, side by side, for a run.

    Loads the run's ``key2df`` mapping via ``get_key2df``, takes the frame
    stored under ``(are_bridges, True)`` and, for every distinct value of the
    'Coordination strategy' index level, draws a heatmap of
    'No. of completed missions' with 'i_map' as rows and 'Positions variant'
    as columns. All heatmaps share a single "Greens" color axis.

    Args:
        runname: Name of the run directory to load.
        are_bridges: First component of the ``key2df`` key.
        label_are_bridges: Human-readable description of ``are_bridges``;
            only used in the figure title.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    key2df = get_key2df(runname)
    # NOTE(review): the second key component is fixed to True; its meaning
    # is not visible from this notebook section.
    df = key2df[are_bridges, True]

    col_strategy = 'Coordination strategy'
    col_i = 'i_map'
    col_j = 'Positions variant'
    col_data = 'No. of completed missions'

    strategies = df.index.get_level_values(col_strategy).unique()

    # One subplot column per strategy, all in a single row.
    fig = make_subplots(
        rows=1,
        cols=len(strategies),
        subplot_titles=[f"{strategy}" for strategy in strategies],
    )

    for subplot_col, strategy in enumerate(strategies, start=1):
        heatmap_data, heatmap_fig = plot_df(
            runname=runname, df=df,
            col_strategy=col_strategy, strategy=strategy,
            col_i=col_i, col_j=col_j, col_data=col_data,
        )

        # Move the heatmap trace(s) into the subplot and attach them to the
        # shared color axis so every panel uses the same scale.
        for trace in heatmap_fig.data:
            trace.update(coloraxis="coloraxis")
            fig.add_trace(trace, row=1, col=subplot_col)

        # Show a tick for every pivot row/column; addressing the axes by
        # row/col avoids building the "xaxisN"/"yaxisN" layout keys by hand.
        x_ticks = heatmap_data.columns.tolist()
        y_ticks = heatmap_data.index.tolist()
        fig.update_xaxes(
            title_text=col_j,
            tickmode="array",
            tickvals=x_ticks,
            ticktext=x_ticks,
            row=1, col=subplot_col,
        )
        fig.update_yaxes(
            title_text=col_i,
            tickmode="array",
            tickvals=y_ticks,
            ticktext=y_ticks,
            autorange="reversed",  # first pivot row at the top
            row=1, col=subplot_col,
        )

    fig.update_layout(
        title=f"{col_data} ({runname}, {col_strategy}, {label_are_bridges})",
        coloraxis=dict(
            colorscale="Greens",
            colorbar=dict(
                # Nested title properties replace the deprecated
                # `titleside` / `titlefont`, which newer plotly rejects.
                title=dict(
                    text=col_data,
                    side="right",  # title runs along the right side of the colorbar
                    font=dict(size=12),
                ),
            ),
        ),
    )

    fig.show()



# Render the per-strategy heatmaps for the same run twice: once for
# are_bridges=False and once for are_bridges=True, each with its own title label.
plot_df_all('20241229_131008', are_bridges=False, label_are_bridges='maps with low connectivity')
plot_df_all('20241229_131008', are_bridges=True, label_are_bridges='maps with high connectivity')
