In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist
from scipy.cluster.hierarchy import linkage, leaves_list

In [None]:
def zy_compositions(dt, top_N=10, title="Composition", taxo_colors=None, width=0.9,
                    order_func="order", order_n=1, label_order=None):
    """
    Draw a stacked composition bar plot: one bar per sample, one segment per taxon.

    Taxa whose abundance sums to zero are dropped, the rest are ranked by mean
    abundance, and every taxon below the first `top_N` ranks is summed into a
    single row labelled 'other'. Segments are stacked in rank order (highest
    mean abundance at the bottom, 'other' on top).

    Args:
        dt (pd.DataFrame): Dataframe with rows as taxa and columns as samples.
        top_N (int): Number of top taxa to display; remaining taxa grouped as 'other'.
        title (str): Plot title.
        taxo_colors (list | str | dict | None): Colors for the plotted rows. A list
            or palette name is expanded with `sns.color_palette` to one color per
            plotted row; a dict is looked up by row label (including 'other' when
            that row exists). `None` uses the 'tab20' palette.
        width (float): Bar width, in x-axis units (samples are 1 unit apart).
        order_func (str): Sample ordering method ('order', 'cluster', 'specific').
            'order' sorts samples by ascending abundance of one plotted row,
            'cluster' uses average-linkage clustering on Bray-Curtis distances,
            'specific' uses `label_order`. Any other value, or 'specific' without
            a `label_order`, keeps the column order of `dt`.
        order_n (int): 1-based rank of the plotted row used when `order_func='order'`.
        label_order (list): Specific sample order when `order_func='specific'`.

    Returns:
        None. The figure is rendered with `plt.show()`.

    Raises:
        IndexError: If `order_func='order'` and `order_n` is not between 1 and
            the number of plotted rows.
    """
    # Drop taxa absent from every sample, then rank the rest by mean abundance.
    dt = dt.loc[dt.sum(axis=1) != 0]
    dt = dt.loc[dt.mean(axis=1).sort_values(ascending=False).index]

    # Collapse everything below the top_N ranks into one 'other' row.
    if len(dt) > top_N:
        others = dt.iloc[top_N:].sum(axis=0)
        # Explicit copy: adding a row to an iloc slice triggers SettingWithCopy.
        dt = dt.iloc[:top_N].copy()
        dt.loc["other"] = others

    # Decide the left-to-right order of the samples.
    if order_func == "order":
        if not 1 <= order_n <= len(dt):
            raise IndexError(
                f"order_n must be between 1 and {len(dt)} (number of plotted rows), got {order_n}"
            )
        label_order = dt.iloc[order_n - 1].sort_values().index
    elif order_func == "cluster":
        if dt.shape[1] > 1:
            dist_matrix = pdist(dt.T.to_numpy(dtype=float), metric="braycurtis")
            cluster_order = leaves_list(linkage(dist_matrix, method="average"))
            label_order = dt.columns[cluster_order]
        else:
            # linkage() rejects the empty distance vector produced by < 2 samples.
            label_order = dt.columns
    elif not (order_func == "specific" and label_order is not None):
        label_order = dt.columns
    dt = dt.loc[:, list(label_order)]

    # One color per plotted row; 'other' needs its own color beyond the top_N.
    n_rows = len(dt)
    if isinstance(taxo_colors, dict):
        colors = [taxo_colors[taxon] for taxon in dt.index]
    else:
        base_palette = "tab20" if taxo_colors is None else taxo_colors
        colors = sns.color_palette(base_palette, n_colors=n_rows)

    # Stack the rows: each new segment starts where the previous ones ended.
    fig, ax = plt.subplots(figsize=(12, 6))
    positions = np.arange(dt.shape[1])
    bottoms = np.zeros(dt.shape[1])
    for (taxon, row), color in zip(dt.iterrows(), colors):
        # Treat missing abundances as zero so one NaN does not poison the stack.
        heights = np.nan_to_num(row.to_numpy(dtype=float))
        ax.bar(positions, heights, width=width, bottom=bottoms, color=color, label=taxon)
        bottoms = bottoms + heights

    ax.set_xticks(positions)
    ax.set_xticklabels(list(dt.columns), rotation=45, ha="right")
    ax.set_title(title, fontsize=16)
    ax.set_xlabel("Samples", fontsize=12)
    ax.set_ylabel("Abundance", fontsize=12)
    ax.legend(title="Taxa", bbox_to_anchor=(1.05, 1), loc="upper left")
    fig.tight_layout()
    plt.show()


In [None]:
def zy_group_compositions(dt, sample_map, ID, group, top_N=10, title="Composition",
                          taxo_colors=None, width=0.9, label_order=None,
                          order_func="order", order_n=1):
    """
    Draw a stacked composition bar plot with the samples arranged by group.

    Only samples listed in `sample_map[ID]` are plotted. Taxa whose abundance
    sums to zero are dropped, the rest are ranked by mean abundance, and every
    taxon below the first `top_N` ranks is summed into a row labelled 'other'.
    Samples are ordered with `order_func`, then gathered group by group (groups
    in order of first appearance in `sample_map`, keeping the chosen order
    inside each group). Groups are separated by dashed lines and captioned
    above the bars.

    Args:
        dt (pd.DataFrame): Dataframe with rows as taxa and columns as samples.
        sample_map (pd.DataFrame): Metadata for grouping samples. If an ID is
            repeated, only its first row is used.
        ID (str): Column in `sample_map` corresponding to sample IDs.
        group (str): Column in `sample_map` corresponding to group labels.
            Missing labels are shown as 'NA'.
        top_N (int): Number of top taxa to display; remaining taxa grouped as 'other'.
        title (str): Plot title.
        taxo_colors (list | str | dict | None): Colors for the plotted rows. A list
            or palette name is expanded with `sns.color_palette` to one color per
            plotted row; a dict is looked up by row label (including 'other' when
            that row exists). `None` uses the 'tab20' palette.
        width (float): Bar width, in x-axis units (samples are 1 unit apart).
        label_order (list): Specific sample order when `order_func='specific'`.
        order_func (str): Sample ordering method ('order', 'cluster', 'specific').
            'order' sorts samples by ascending abundance of one plotted row,
            'cluster' uses average-linkage clustering on Bray-Curtis distances,
            'specific' uses `label_order`. Any other value, or 'specific' without
            a `label_order`, keeps the order of `sample_map`.
        order_n (int): 1-based rank of the plotted row used when `order_func='order'`.

    Returns:
        None. The figure is rendered with `plt.show()`.

    Raises:
        KeyError: If a sample ID from `sample_map` is not a column of `dt`, or
            `label_order` names a sample that is not in `sample_map`.
        IndexError: If `order_func='order'` and `order_n` is not between 1 and
            the number of plotted rows.
    """
    # One metadata row per sample; map each sample ID to a hashable group label.
    meta = sample_map.drop_duplicates(subset=ID)
    group_of = {
        sample: ("NA" if pd.isna(label) else label)
        for sample, label in zip(meta[ID], meta[group])
    }

    # Keep only the mapped samples, drop all-zero taxa, rank by mean abundance.
    dt = dt.loc[:, list(group_of)]
    dt = dt.loc[dt.sum(axis=1) != 0]
    dt = dt.loc[dt.mean(axis=1).sort_values(ascending=False).index]

    # Collapse everything below the top_N ranks into one 'other' row.
    if len(dt) > top_N:
        others = dt.iloc[top_N:].sum(axis=0)
        # Explicit copy: adding a row to an iloc slice triggers SettingWithCopy.
        dt = dt.iloc[:top_N].copy()
        dt.loc["other"] = others

    # Decide the base order of the samples.
    if order_func == "order":
        if not 1 <= order_n <= len(dt):
            raise IndexError(
                f"order_n must be between 1 and {len(dt)} (number of plotted rows), got {order_n}"
            )
        label_order = dt.iloc[order_n - 1].sort_values().index
    elif order_func == "cluster":
        if dt.shape[1] > 1:
            dist_matrix = pdist(dt.T.to_numpy(dtype=float), metric="braycurtis")
            cluster_order = leaves_list(linkage(dist_matrix, method="average"))
            label_order = dt.columns[cluster_order]
        else:
            # linkage() rejects the empty distance vector produced by < 2 samples.
            label_order = dt.columns
    elif not (order_func == "specific" and label_order is not None):
        label_order = dt.columns

    # Gather samples group by group, preserving the base order inside each group.
    members_by_group = {label: [] for label in group_of.values()}
    for sample in label_order:
        members_by_group[group_of[sample]].append(sample)
    members_by_group = {label: members for label, members in members_by_group.items() if members}
    arranged = [sample for members in members_by_group.values() for sample in members]
    dt = dt.loc[:, arranged]

    # One color per plotted row; 'other' needs its own color beyond the top_N.
    n_rows = len(dt)
    if isinstance(taxo_colors, dict):
        colors = [taxo_colors[taxon] for taxon in dt.index]
    else:
        base_palette = "tab20" if taxo_colors is None else taxo_colors
        colors = sns.color_palette(base_palette, n_colors=n_rows)

    # Stack the rows: each new segment starts where the previous ones ended.
    fig, ax = plt.subplots(figsize=(12, 6))
    positions = np.arange(dt.shape[1])
    bottoms = np.zeros(dt.shape[1])
    for (taxon, row), color in zip(dt.iterrows(), colors):
        # Treat missing abundances as zero so one NaN does not poison the stack.
        heights = np.nan_to_num(row.to_numpy(dtype=float))
        ax.bar(positions, heights, width=width, bottom=bottoms, color=color, label=taxon)
        bottoms = bottoms + heights

    # Mark group boundaries and caption each group above the plotting area.
    start = 0
    for label, members in members_by_group.items():
        end = start + len(members)
        if start > 0:
            ax.axvline(start - 0.5, color="black", linestyle="--", linewidth=0.8)
        ax.text((start + end - 1) / 2, 1.01, str(label), transform=ax.get_xaxis_transform(),
                ha="center", va="bottom", fontsize=11)
        start = end

    ax.set_xticks(positions)
    ax.set_xticklabels(arranged, rotation=90)
    # Extra padding keeps the title clear of the group captions.
    ax.set_title(title, fontsize=16, pad=24)
    ax.set_xlabel("Samples", fontsize=12)
    ax.set_ylabel("Abundance", fontsize=12)
    ax.legend(title="Taxa", bbox_to_anchor=(1.05, 1), loc="upper left")
    fig.tight_layout()
    plt.show()
