In [1]:
import functools
import pickle

import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
from IPython.display import display, HTML

# Root directory (relative to the notebook's working directory) holding one
# sub-directory per experiment run; each run directory is expected to contain
# a ``key2df.pickle`` file, which is what get_key2df loads.
RUNDIRS = '../logs/rundirs'

In [2]:
@functools.cache
def get_key2df(runname):
    """Load the pickled ``key2df`` object of one run (memoised per run name).

    The object is read from ``{RUNDIRS}/{runname}/key2df.pickle``. Because of
    ``functools.cache`` every call with the same ``runname`` returns the very
    same object, so callers should treat it as read-only.

    NOTE(review): presumably a mapping from tuple keys to DataFrames (see how
    plot_df_all indexes it) -- confirm against the code that writes it.

    Security: ``pickle.load`` can execute arbitrary code; only point this at
    run directories produced by this project.
    """
    pickle_path = f'{RUNDIRS}/{runname}/key2df.pickle'
    with open(pickle_path, 'rb') as stream:
        loaded = pickle.load(stream)
    return loaded

In [6]:
def calculate_heatmap_data(*, df, col_strategy, strategy, col_i, col_j, col_data):
    """Pivot one strategy's values into a ``col_i`` x ``col_j`` grid.

    Keeps the rows whose ``col_strategy`` equals ``strategy``, then drops the
    rows flagged in the ``passhum`` column, and finally pivots ``col_data``
    with ``col_i`` as the index and ``col_j`` as the columns (missing
    combinations become NaN; duplicate combinations make ``pivot`` raise).
    """
    by_strategy = df.loc[df[col_strategy] == strategy]
    # Filter in two steps (as opposed to one combined mask) so the ``~`` on
    # ``passhum`` only ever sees the rows of the selected strategy.
    kept = by_strategy.loc[~by_strategy['passhum'], [col_i, col_j, col_data]]
    return kept.pivot(index=col_i, columns=col_j, values=col_data)


def rank_dataframes(dataframes, is_the_more_the_better=True):
    """Rank corresponding cells across several DataFrames.

    For every (row, column) cell, the values the frames hold there are ranked
    against each other with ``method='min'`` (ties share the lowest rank).
    Rank 1 is the best value: the largest one when ``is_the_more_the_better``
    is True, the smallest one otherwise. NaN cells get a NaN rank.

    Parameters
    ----------
    dataframes : sequence of pandas.DataFrame
        Frames to compare. Cells are matched by index/column *label*, using
        the first frame's labels as the grid; a label missing from a frame
        counts as NaN there.
    is_the_more_the_better : bool, default True
        Ranking direction.

    Returns
    -------
    list of pandas.DataFrame
        One float rank frame per input frame, labelled like that input frame.
        An empty input yields an empty list.
    """
    frames = list(dataframes)
    if not frames:
        return []

    reference = frames[0]

    def _same_labels(df):
        # True when ``df`` already uses the reference grid (same labels, same order).
        return df.index.equals(reference.index) and df.columns.equals(reference.columns)

    def _aligned(df):
        # Match cells by label rather than by position; skip reindexing when the
        # labels already agree so frames with duplicate labels keep working.
        if _same_labels(df):
            return df
        return df.reindex(index=reference.index, columns=reference.columns)

    # One row per cell (row-major order), one column per input frame: ranking
    # along axis=1 ranks each cell's values across all frames in one call.
    cells = pd.DataFrame(
        {k: _aligned(df).to_numpy().ravel() for k, df in enumerate(frames)}
    )
    cell_ranks = cells.rank(axis=1, method='min', ascending=not is_the_more_the_better)

    ranked_dfs = []
    for k, df in enumerate(frames):
        ranked = pd.DataFrame(
            cell_ranks[k].to_numpy().reshape(reference.shape),
            index=reference.index,
            columns=reference.columns,
        )
        if not _same_labels(df):
            # Give each result the same labels/order as its own input frame.
            ranked = ranked.reindex(index=df.index, columns=df.columns)
        ranked_dfs.append(ranked)
    return ranked_dfs


def compare_ranked_dfs(df_a, df_b, label_a, label_b, color_a, color_b, color_ab):
    """Cell-by-cell head-to-head of two rank DataFrames (lower rank wins).

    Cells are paired by position (``iloc``). Each output cell holds:

    * ``label_a`` / ``color_a`` when ``df_a`` has the lower rank,
    * ``label_b`` / ``color_b`` when ``df_b`` has the lower rank,
    * ``label_a + label_b`` / ``color_ab`` otherwise -- i.e. on ties *and*
      whenever either rank is NaN (both comparisons are then False).

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        The label frame and the colour frame, both labelled like ``df_a``.
    """
    def _verdict(rank_a, rank_b):
        # Pick the (label, colour) pair describing which side wins this cell.
        if rank_a < rank_b:
            return label_a, color_a
        if rank_a > rank_b:
            return label_b, color_b
        return label_a + label_b, color_ab

    labels = pd.DataFrame(index=df_a.index, columns=df_a.columns)
    colors = pd.DataFrame(index=df_a.index, columns=df_a.columns)

    n_rows, n_cols = df_a.shape
    for r in range(n_rows):
        for c in range(n_cols):
            labels.iloc[r, c], colors.iloc[r, c] = _verdict(df_a.iloc[r, c], df_b.iloc[r, c])

    return labels, colors


def plot_df(*, runname, title, col_i, col_j, col_data, heatmap_data, df_ranks):
    """Build (but do not show) one annotated Plotly heatmap figure.

    ``heatmap_data`` supplies the colour values (its rows map to the y axis,
    labelled ``col_i``; its columns to the x axis, labelled ``col_j``).
    ``df_ranks`` supplies the text printed in each cell. The hover label
    shows the y, x and z values under the given column names.
    """
    hover_template = "".join((
        col_i, ": %{y}<br>",
        col_j, ": %{x}<br>",
        col_data, ": %{z}<extra></extra>",
    ))

    figure = px.imshow(
        heatmap_data,
        labels={"x": col_j, "y": col_i, "color": col_data},
        title=f"{runname}: {title}: {col_data}",
    )

    # Print the per-cell text (e.g. ranks) on top of the colour values.
    heatmap_trace = figure.data[0]
    heatmap_trace.update(text=df_ranks.values, texttemplate="%{text}")

    figure.update_traces(hovertemplate=hover_template)
    return figure


def plot_df_all(runname, *, are_bridges, slowness):
    """Plot one heatmap per coordination strategy plus a comparison panel.

    Loads the run's ``key2df`` object and keeps the rows for the requested
    connectivity family (``are_bridges``) and ``slowness`` scenario whose
    ``passhum`` flag is False. It then pivots ``No. of completed missions``
    into an ``i_map`` x ``position`` grid per ``forcing`` strategy and shows
    one figure with:

    * one heatmap per strategy, annotated with that strategy's per-cell rank
      among all strategies (1 = most completed missions);
    * a final "Strategies Comparison" panel that, per cell, shows the initial
      of whichever of the 2nd and 3rd strategies ranks better (both initials
      on a tie).

    Returns None. When no rows match, nothing is plotted; this is expected
    only for low-connectivity maps with ``slowness == 'with rerouting'``.

    Raises
    ------
    AssertionError
        If the "no rows" situation occurs for any other combination, or does
        not occur for the expected one.
    ValueError
        If fewer than three strategies are present (the comparison panel
        needs the 2nd and 3rd).
    """
    key2df = get_key2df(runname)
    # NOTE(review): the meaning of the second key component (True) is not
    # visible here -- confirm against the code that builds key2df.
    df_run = key2df[are_bridges, True]

    df = df_run[~df_run['passhum'] & (df_run['slowness'] == slowness)]
    assert df.empty == (not are_bridges and slowness == 'with rerouting')
    if df.empty:
        return

    col_strategy = 'forcing'
    col_i = 'i_map'
    col_j = 'position'
    col_data = 'No. of completed missions'
    is_the_more_the_better = True

    # Order of first appearance in the data; the comparison panel below uses
    # the 2nd and 3rd entries of this list.
    strategies = list(df[col_strategy].unique())
    if len(strategies) < 3:
        raise ValueError(
            f'expected at least 3 values in {col_strategy!r}, got {strategies!r}'
        )

    title_cmp = 'Strategies Comparison'
    titles = strategies + [title_cmp]

    dfs_heatmap_data = [
        calculate_heatmap_data(
            df=df,
            col_strategy=col_strategy,
            strategy=strategy,
            col_i=col_i,
            col_j=col_j,
            col_data=col_data,
        )
        for strategy in strategies
    ]
    dfs_ranks = rank_dataframes(dfs_heatmap_data, is_the_more_the_better=is_the_more_the_better)

    # The comparison panel encodes the winner as a number; coloraxis2 below
    # maps 0.0 -> light yellow, 0.5 -> light gray, 1.0 -> light blue.
    df_cmp, df_color = compare_ranked_dfs(
        dfs_ranks[1], dfs_ranks[2],
        strategies[1].split()[-1][0].upper(),
        strategies[2].split()[-1][0].upper(),
        0.0,  # strategies[1] ranks better -> light yellow
        1.0,  # strategies[2] ranks better -> light blue
        0.5,  # tie -> light gray
    )

    fig = make_subplots(rows=1, cols=len(titles), subplot_titles=titles)

    # Every panel shares the first strategy's grid, so the axis settings are
    # built once and applied to each subplot.
    grid = dfs_heatmap_data[0]
    xaxis_settings = dict(
        title=col_j,
        tickmode="array",
        tickvals=list(grid.columns),
        ticktext=grid.columns.tolist(),
        title_standoff=7,  # pull the x-axis title closer to the ticks
        automargin=True,
    )
    yaxis_settings = dict(
        title=col_i,
        tickmode="array",
        tickvals=list(grid.index),
        ticktext=grid.index.tolist(),
        autorange="reversed",  # first grid row at the top
        title_standoff=0,  # pull the y-axis title closer to the ticks
        automargin=True,
    )

    for idx, title in enumerate(titles):
        is_cmp = title == title_cmp
        heatmap_fig = plot_df(
            runname=runname,
            title=title,
            col_i=col_i,
            col_j=col_j,
            col_data=col_data,
            heatmap_data=df_color if is_cmp else dfs_heatmap_data[idx],
            df_ranks=df_cmp if is_cmp else dfs_ranks[idx],
        )

        for trace in heatmap_fig.data:
            # Strategy panels share the "Greens" axis (coloraxis1); the
            # comparison panel uses its own categorical axis (coloraxis2).
            # Colorbar visibility is governed by layout.coloraxisN.showscale.
            trace.update(
                coloraxis="coloraxis2" if is_cmp else "coloraxis1",
                showscale=not is_cmp,
            )
            fig.add_trace(trace, row=1, col=idx + 1)

        # The first subplot's axes are "xaxis"/"yaxis", later ones are numbered.
        suffix = idx + 1 if idx > 0 else ''
        fig.update_layout(**{
            f"xaxis{suffix}": xaxis_settings,
            f"yaxis{suffix}": yaxis_settings,
        })

    fig.update_layout(
        title=f"{col_data}<br>(slowness: {slowness}; coordination strategies)",
        coloraxis1=dict(
            colorscale="Greens",
            colorbar=dict(
                # ``title`` as a dict replaces the deprecated titlefont/titleside.
                title=dict(text=col_data, font=dict(size=12), side="right"),
                x=-0.08,
                thickness=10,
            ),
        ),
        coloraxis2=dict(
            colorscale=[
                [0.0, '#FFFFE0'],  # light yellow: strategies[1] ranks better
                [0.5, '#D3D3D3'],  # light gray: tie
                [1.0, '#ADD8E6'],  # light blue: strategies[2] ranks better
            ],
            # Pin the range: with autorange, a panel missing one outcome would
            # rescale (e.g. a tie at 0.5 would be drawn in the 1.0 colour).
            cmin=0.0,
            cmax=1.0,
            # The numeric encoding is not a meaningful scale; hide its colorbar.
            showscale=False,
        ),
    )

    fig.show()


"""
- No. of completed missions
- Average mission length
- Average CS density score

Hypotheses:
- less Average mission length -> more No. of completed
- less Average CS density score -> more No. of completed missions (& less collisions, etc.) 
"""

def plot_runname(runname):
    """Render every connectivity x slowness comparison figure for one run.

    For each map family (low connectivity first, then high) an ``<h2>``
    header is displayed, followed by one ``plot_df_all`` call per slowness
    scenario. ``plot_df_all`` returns without plotting when it finds no rows
    for a combination.
    """
    slowness_scenarios = ('baseline', 'without rerouting', 'with rerouting')
    for are_bridges in (False, True):
        if are_bridges:
            header = 'Maps with high connectivity'
        else:
            header = 'Maps with low connectivity'
        display(HTML(f'<h2>{header}</h2>'))

        for slowness in slowness_scenarios:
            plot_df_all(runname, are_bridges=are_bridges, slowness=slowness)
            
# Render all figures for one run; the argument is the run's directory name
# under RUNDIRS.
plot_runname('20241230_173555')