# In [4]:
import plotly.graph_objects as go

# Distance-to-optimal score for each configuration, keyed by
# "<feature count> feature <balanced|unbalanced>".
dist_to_optimal = {
    '2 feature unbalanced': -2.45,
    '2 feature balanced': -1,
    '4 feature unbalanced': -4.88,
    '4 feature balanced': -1.4,
}

# Per-bar "n" values, keyed exactly like dist_to_optimal.
n_values = {
    '2 feature unbalanced': -2.45,
    '2 feature balanced': -0.54,
    '4 feature unbalanced': -4.88,
    '4 feature balanced': -0.63,
}

# Feature counts shown on the x-axis, in plotting order.
_FEATURE_COUNTS = (2, 4)

# Category labels for the x-axis, one per feature count.
x_labels = [f'{count} features' for count in _FEATURE_COUNTS]

# Dictionary keys for each balance condition, ordered like x_labels.
_unbalanced_keys = [f'{count} feature unbalanced' for count in _FEATURE_COUNTS]
_balanced_keys = [f'{count} feature balanced' for count in _FEATURE_COUNTS]

# "n" values split by balance condition.
n_unbalanced = [n_values[key] for key in _unbalanced_keys]
n_balanced = [n_values[key] for key in _balanced_keys]

# Bar heights (distance to optimal) split by balance condition.
y_unbalanced = [dist_to_optimal[key] for key in _unbalanced_keys]
y_balanced = [dist_to_optimal[key] for key in _balanced_keys]

# Build the figure and add one Bar trace per balance condition; each trace
# spans both x-axis categories.
fig = go.Figure()

# (trace name, bar color, bar heights, text labels) per condition, in drawing
# order. The position in this tuple is also used as the trace's offsetgroup.
bar_series = (
    ('Unbalanced', 'darkorange', y_unbalanced, n_unbalanced),
    ('Balanced', 'darkcyan', y_balanced, n_balanced),
)

for group_index, (series_name, bar_color, heights, labels) in enumerate(bar_series):
    bar_trace = go.Bar(
        name=series_name,
        x=x_labels,
        y=heights,
        marker=dict(color=bar_color),
        width=0.2,
        offsetgroup=group_index,
        text=[f"{n}" for n in labels],  # n values written on the bars
        textposition='inside',  # draw the text inside each bar
        # textangle=-90,  # left disabled
        insidetextanchor='middle',  # anchor the text at the middle of the bar
        showlegend=False,  # bars themselves get no legend entry
    )
    fig.add_trace(bar_trace)

# # Define the x-coordinates for where the arrows should start and end (middle of the bars)
# x_2_unbalanced = -0.13  # Adjusting position for '2 feature unbalanced'
# x_2_balanced = 0.12     # Adjusting position for '2 feature balanced'
# x_4_unbalanced = 0.78   # Adjusting position for '4 feature unbalanced'
# x_4_balanced = 1.05     # Adjusting position for '4 feature balanced'

# # Add arrows from 4 feature unbalanced to 2 feature unbalanced and balanced
# fig.add_annotation(x=x_4_unbalanced, y=dist_to_optimal['4 feature unbalanced'],
#                    ax=x_2_unbalanced, ay=dist_to_optimal['2 feature unbalanced']+0.05,
#                    xref="x", yref="y", axref="x", ayref="y",
#                    showarrow=True, arrowhead=3, arrowsize=2, arrowwidth=4)

# fig.add_annotation(x=x_4_balanced, y=dist_to_optimal['4 feature balanced'],
#                    ax=x_2_balanced, ay=dist_to_optimal['2 feature balanced'],
#                    xref="x", yref="y", axref="x", ayref="y",
#                    showarrow=True, arrowhead=3, arrowsize=2, arrowwidth=4)

# Legend entries are provided by placeholder scatter traces rather than by the
# bars: each trace has a single empty point, so nothing is plotted, but its
# large square marker appears in the legend.
# Each entry is (legend name, marker color), matching the bar colors.
legend_swatches = (
    ('Unbalanced', 'darkorange'),
    ('Balanced', 'darkcyan'),
)

for swatch_name, swatch_color in legend_swatches:
    swatch_marker = {
        'size': 70,  # marker size used for the legend square
        'symbol': 'square',
        'color': swatch_color,
    }
    fig.add_trace(go.Scatter(
        x=[None],  # no actual data points
        y=[None],
        mode='markers',
        marker=swatch_marker,
        name=swatch_name,
        showlegend=True,
    ))

# Legend placement: a horizontal row centred just above the plot area.
legend_options = {
    'orientation': 'h',
    'yanchor': 'bottom',  # y below refers to the legend's bottom edge
    'y': 1.02,  # slightly above the top of the plot area
    'xanchor': 'center',  # x below refers to the legend's centre
    'x': 0.5,
    'itemsizing': 'trace',  # legend symbols follow each trace's marker size
    'font': {'size': 40},
    'itemwidth': 40,
}

# Figure-wide layout settings, passed to update_layout as keyword arguments.
layout_options = {
    'font': {'size': 35},  # base font size for the whole figure
    'yaxis_title': 'Distance to Optimal Score',
    # NOTE(review): a size-1 tick font presumably exists to make the y tick
    # labels effectively invisible - confirm this is intended.
    'yaxis': {'tickfont': {'size': 1}, 'showgrid': False},
    'xaxis': {'tickvals': ['2 features', '4 features']},
    'width': 800,
    'height': 800,
    'bargap': 0.5,
    'barmode': 'group',
    'legend': legend_options,
    'template': 'presentation',
}
fig.update_layout(**layout_options)

# Render the finished figure.
fig.show()