In [None]:
from utils import calculate_gap_closed
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots



def gap_closed_heatmap(metrics, approaches=None, results_dir="results", seed=42,
                       output_path="./figures/heatmap.pdf"):
    """Plot (and save) a heatmap of gap-closed values per metric and approach.

    For every metric, ``calculate_gap_closed(metric, results_dir, seed)`` is
    called and its return value becomes one row of a DataFrame whose columns
    are ``approaches``. The raw table is printed. A ``'Mean'`` row
    (column-wise mean over the metric rows) is then appended and metric keys
    are mapped to display labels. The table is rendered as a plotly heatmap
    with a dashed red line separating the mean row. The figure is shown and
    written to ``output_path``.

    Parameters
    ----------
    metrics : iterable of str
        Metric keys forwarded to ``calculate_gap_closed``; also the row labels.
        Materialized once, so generators are accepted.
    approaches : list of str, optional
        Column labels. Defaults to the eight ``AS-*`` approaches. Presumably
        must match the length and order of the values returned by
        ``calculate_gap_closed`` — verify against ``utils``.
    results_dir : str, default "results"
        Second argument forwarded to ``calculate_gap_closed``.
    seed : int, default 42
        Third argument forwarded to ``calculate_gap_closed`` (assumed to be a
        random seed — TODO confirm in ``utils``).
    output_path : str, default "./figures/heatmap.pdf"
        File the figure is written to; missing parent directories are created.
    """
    from pathlib import Path

    if approaches is None:
        approaches = ['AS-R-MO', 'AS-R-SO', 'AS-PR-MO', 'AS-PR-SO',
                      'AS-C-SO', 'AS-PC-MO', 'AS-PC-SO', 'AS-CS-PC-SO']

    # Materialize once: metrics is used both for the loop and as the index,
    # which would silently break for a one-shot iterator.
    metrics = list(metrics)

    data = [calculate_gap_closed(metric, results_dir, seed) for metric in metrics]
    df = pd.DataFrame(data, index=metrics, columns=approaches)
    print(df)

    # Mean over the metric rows only (computed before the row is added).
    df.loc['Mean'] = df.mean()

    # Map raw metric keys to the labels shown on the figure.
    index_mapping = {
        'HAMMING LOSS example based': 'Hamming loss example-based',
        'MACRO F1': 'F1 macro',
        'MICRO F1': 'F1 micro',
        'AUCROC MICRO': 'AUROC micro',
        'F1 example based': 'F1 example-based',
    }
    df = df.rename(index=index_mapping)

    # Plotting: annotated cells, no colour bar, column labels on top.
    fig = px.imshow(df.round(2), color_continuous_scale='pubu', text_auto=True)
    fig.update_coloraxes(showscale=False)
    fig.update_xaxes(side="top")

    fig.update_layout(
        autosize=True,
        width=720,
        margin=dict(l=0, r=0, b=0, t=0),
        font=dict(size=14),
    )

    # Cell centres sit at integer coordinates, so the boundary between the
    # last metric row and the 'Mean' row is at y = len(metrics) - 0.5.
    num_metrics = len(metrics)
    fig.add_shape(type="line",
                  x0=-0.5, y0=num_metrics - 0.5,
                  x1=len(approaches) - 0.5, y1=num_metrics - 0.5,
                  line=dict(color="Red", width=4, dash="dash"),
                  layer="above")

    fig.show()
    # write_image raises if the target directory is missing; create it first.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    fig.write_image(output_path)

# Render the heatmap for the five evaluation metrics; keys must match what
# calculate_gap_closed expects (display names are mapped inside the function).
gap_closed_heatmap(
    metrics=[
        'HAMMING LOSS example based',
        'MACRO F1',
        'MICRO F1',
        'AUCROC MICRO',
        'F1 example based',
    ]
)

                              AS-R-MO    AS-R-SO   AS-PR-MO   AS-PR-SO  \
HAMMING LOSS example based  58.651600  47.234061  47.929908  54.133231   
MACRO F1                    43.846904  21.520622  31.499747  25.651397   
MICRO F1                    56.716828  40.464245  52.870783  50.502393   
AUCROC MICRO                55.848361  44.104398  68.171020  64.172446   
F1 example based            49.031361  39.208647  65.443961  53.860864   

                              AS-C-SO   AS-PC-MO   AS-PC-SO  AS-CS-PC-SO  
HAMMING LOSS example based   4.464963  12.791298  22.122259    11.867585  
MACRO F1                    40.292221  50.666330  47.645083    44.924300  
MICRO F1                    37.132444  44.159213  42.242695    46.780219  
AUCROC MICRO                52.137152  66.734666  52.388384    59.266039  
F1 example based            59.045385  62.004466  52.844830    52.356961  
