In [None]:
# --- Project Setup ---
from setup_notebook import setup_project_root
setup_project_root()

# --- Imports ---
from src.project_config import COLORS_MODELS, PROCESSED_DIR, PROTEIN_IDS_CSV, get_paths_protein, get_paths



### Plotting


In [None]:
# Define Plotting Function
# Materialise the palette as a list and cut it down to at most one colour per
# distinct value in df["pred_class"]. A palette that is already shorter than
# that is left at its current length.
unique_classes = df["pred_class"].nunique()
COLORS = [*COLORS][:unique_classes]


def _resolve_save_dir(plot_function, save_dir=None):
    """Return the directory the figure files are written to.

    An explicit ``save_dir`` always wins. Otherwise the historical per-plot-type
    default is used; a plot function without a default raises ``ValueError``.
    """
    if save_dir is not None:
        return save_dir
    if plot_function is sns.swarmplot:
        return '/Users/doma/Documents/Bachelor_Arbeit/Code/results/images/Swarmplots/'
    if plot_function is sns.boxplot:
        return '/Users/doma/Documents/Bachelor_Arbeit/Code/results/images/Boxplots/'
    raise ValueError(
        "No default output directory for plot function "
        f"{getattr(plot_function, '__name__', plot_function)!r}; "
        "pass save_dir=... explicitly."
    )


def _values_in_tick_order(per_group, tick_labels):
    """Return the per-group values of ``per_group`` ordered like ``tick_labels``.

    Tick labels are always strings, while the group keys keep the dtype of the
    grouping column, so the lookup is done on the stringified keys. A label
    without a matching group raises ``KeyError``.
    """
    by_label = {str(key): value for key, value in per_group.items()}
    return [by_label[label] for label in tick_labels]


def make_plot_with_means(
    plot_function,
    title,
    file_name,
    invert,
    llr,
    means,
    x,
    y,
    data,
    invert_y=False,
    subtitle=None,
    background_color="white",
    palette=None,
    n=False,
    model=None,
    save_dir=None,
    **kwargs
):
    """Draw a categorical plot, annotate it, and save it as PNG and SVG.

    The plot is drawn on the current pyplot figure (``plt.gcf()``), which is
    resized to 14x8 inches. Gridlines are switched off.

    Args:
        plot_function: Callable invoked as
            ``plot_function(x=x, y=y, data=data, palette=palette, **kwargs)``
            that must return the Axes it drew on (e.g. ``sns.boxplot``,
            ``sns.swarmplot``).
        title: Axes title.
        file_name: File name without extension; ``.png`` and ``.svg`` files
            are written.
        invert: Vertical offset added to each group mean to position its text
            label (used only when ``means`` is truthy).
        llr: Text appended to the y-axis label ``"Average Pathogenicity "``.
        means: If truthy, write each group's mean of ``y`` above its category.
        x: Column of ``data`` holding the categories.
        y: Column of ``data`` holding the plotted values.
        data: Object supporting ``data.groupby(x)[y]`` (a pandas DataFrame in
            practice).
        invert_y: If true, invert the y-axis.
        subtitle: Optional figure-level subtitle.
        background_color: Face colour of the figure.
        palette: Palette forwarded to ``plot_function``.
        n: If truthy, write ``n=<count>`` for each category near the bottom
            of the axes.
        model: When equal to ``"AlphaMissense"``, the ``n=`` labels are placed
            at the fixed y position -1.65 instead of relative to the y-limits.
        save_dir: Output directory. Defaults to the built-in Swarmplots or
            Boxplots directory for ``sns.swarmplot`` / ``sns.boxplot``.
        **kwargs: Forwarded unchanged to ``plot_function``.

    Returns:
        The Axes returned by ``plot_function``.

    Raises:
        ValueError: If ``save_dir`` is not given and ``plot_function`` has no
            built-in default directory.
        KeyError: If an x tick label has no matching group in ``data``.
    """
    # Function-scope import so this block does not depend on the notebook's
    # import cell.
    from pathlib import Path

    # Resolve the target first so an unsupported plot function fails with a
    # clear message before anything is drawn (previously this surfaced as an
    # UnboundLocalError only at save time).
    output_dir = Path(_resolve_save_dir(plot_function, save_dir))

    # Size and colour the current figure before plotting.
    fig = plt.gcf()
    fig.set_size_inches(14, 8)
    fig.patch.set_facecolor(background_color)

    ax = plot_function(x=x, y=y, data=data, palette=palette, **kwargs)

    # Label the returned axes directly rather than whatever pyplot considers
    # the current axes.
    ax.set_title(title, y=1.03, size=24, fontweight='bold')
    ax.set_ylabel(f"Average Pathogenicity {llr}", fontsize=16, fontweight='bold')
    ax.set_xlabel("Prediction Class", fontsize=16, fontweight='bold')

    if subtitle:
        plt.suptitle(subtitle, y=0.98, fontsize=18, fontweight='medium', color='gray')

    ax.tick_params(axis='x', labelsize=14, rotation=0)
    ax.tick_params(axis='y', labelsize=14)

    if invert_y:
        ax.invert_yaxis()

    # The plot is meant to have no gridlines at all. The earlier version first
    # styled major/minor gridlines and then disabled only the major ones.
    ax.grid(False, which='both')

    # Category order as actually drawn, so annotations line up with the plot.
    tick_labels = [tick.get_text() for tick in ax.get_xticklabels()]

    if means:
        group_means = data.groupby(x)[y].mean()
        for position, mean in enumerate(_values_in_tick_order(group_means, tick_labels)):
            ax.text(
                position,
                mean + invert,  # caller-supplied offset so the label clears the marker
                f'{mean:.3f}',
                color='white',
                ha='center',
                va='top',
                fontsize=16,
                fontweight='bold',
                alpha=1,
                # Thin black outline keeps the white text readable on any colour.
                path_effects=[path_effects.withStroke(linewidth=1, foreground="black")]
            )

    if n:
        group_counts = data.groupby(x)[y].count()

        # The label height is the same for every category, so compute it once.
        if model == "AlphaMissense":
            # Fixed position; presumably hand-tuned for that model's score
            # range -- confirm before changing.
            count_label_y = -1.65
        else:
            ymin, ymax = ax.get_ylim()
            # 7% of the axis range beyond the lower limit.
            count_label_y = ymin - (ymax - ymin) * 0.07

        for position, count in enumerate(_values_in_tick_order(group_counts, tick_labels)):
            ax.text(
                position,
                count_label_y,
                f'n={count}',
                color='black',
                ha='center',
                va='bottom',
                fontsize=12,
                fontweight='bold',
                alpha=0.8
            )

    # Make sure the target directory exists, then write both formats.
    output_dir.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(output_dir / f"{file_name}.png", dpi=300, bbox_inches='tight')
    plt.savefig(output_dir / f"{file_name}.svg", dpi=300, bbox_inches='tight')

    return ax