In [None]:
import numpy as np
import pandas as pd
from skbio.diversity.alpha import shannon
from skbio.stats.composition import multiplicative_replacement
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import wilcoxon, mannwhitneyu
import itertools

In [None]:
def get_pan_data(data):
    """
    Build the table behind a species accumulation curve.

    Samples are pooled cumulatively in column order: step k uses the first k columns.
    Args:
        data: A Pandas DataFrame where rows represent species and columns represent samples.
    Returns:
        A DataFrame with one row per step and the columns "Sites" (number of
        pooled samples), "Richness" (number of species whose pooled total is
        above zero) and "SD" (np.std of that per-species presence/absence
        indicator).
    """
    n_samples = data.shape[1]
    steps = range(1, n_samples + 1)

    # One boolean Series per step: True where a species' total over the
    # first k samples is greater than zero.
    pooled_presence = [data.iloc[:, :k].sum(axis=1).gt(0) for k in steps]

    return pd.DataFrame({
        "Sites": steps,
        "Richness": [present.sum() for present in pooled_presence],
        "SD": [np.std(present) for present in pooled_presence],
    })

In [None]:
def zy_nspecies(dt, sample_map, group="Group", ID="Sample", sample_color=None, title="Rarefaction Curve Analysis"):
    """
    Generate group-specific species accumulation curves.

    For every group in sample_map the samples are pooled cumulatively (in the
    order they appear in sample_map) via get_pan_data, and the resulting
    richness is drawn as one line per group with a +/- SD band in the same
    colour as the line.
    Args:
        dt: A DataFrame where rows represent species and columns represent samples.
        sample_map: A DataFrame mapping samples to groups.
        group: Column name in sample_map specifying groups.
        ID: Column name in sample_map specifying sample IDs.
        sample_color: List of colors for the groups (matched to groups in order
            of first appearance in sample_map), a seaborn palette name, or a
            dict mapping group -> color. Defaults to a "husl" palette.
        title: Title of the plot.
    Returns:
        The Matplotlib Axes the curves were drawn on.
    Raises:
        ValueError: If sample_map contains no groups.
        KeyError: If a sample listed in sample_map is not a column of dt.
    """
    # Groups in order of first appearance; this single ordering drives the
    # line colours, the legend order and the band colours.
    groups = list(sample_map[group].unique())
    if not groups:
        raise ValueError("sample_map contains no groups to plot")

    if isinstance(sample_color, dict):
        color_by_group = sample_color
    else:
        palette_spec = "husl" if sample_color is None else sample_color
        # color_palette normalises names/lists to exactly len(groups) colours.
        color_by_group = dict(zip(groups, sns.color_palette(palette_spec, len(groups))))

    # Restrict (and reorder) the abundance table to the mapped samples.
    dt = dt[sample_map[ID]]

    curves = []
    for g in groups:
        member_ids = sample_map.loc[sample_map[group] == g, ID]
        pan_data = get_pan_data(dt[member_ids])
        pan_data[group] = g
        curves.append(pan_data)

    # ignore_index avoids duplicate index labels, which some seaborn versions
    # cannot handle when plotting from a DataFrame.
    results = pd.concat(curves, ignore_index=True)

    # Plotting
    _, ax = plt.subplots(figsize=(10, 6))
    sns.lineplot(data=results, x="Sites", y="Richness", hue=group, hue_order=groups,
                 palette=color_by_group, linewidth=2, ax=ax)
    for g, curve in zip(groups, curves):
        # Explicit colour so each band matches its own line instead of the
        # default Matplotlib colour cycle.
        ax.fill_between(curve["Sites"], curve["Richness"] - curve["SD"], curve["Richness"] + curve["SD"],
                        color=color_by_group[g], alpha=0.2)

    ax.set_xlabel("Number of Samples")
    ax.set_ylabel("Number of Species")
    ax.set_title(title)
    ax.legend(title=group)
    ax.grid(False)
    plt.show()
    return ax

In [None]:
def sig_func(p_value):
    """
    Map a p-value to its significance label.

    Returns "***" below 0.001, "**" below 0.01, "*" below 0.05 and "ns"
    otherwise (all bounds are strict, so e.g. exactly 0.05 is "ns").
    """
    # Thresholds are checked from strictest to loosest; the first match wins.
    thresholds = ((0.001, "***"), (0.01, "**"), (0.05, "*"))
    for cutoff, label in thresholds:
        if p_value < cutoff:
            return label
    return "ns"

In [None]:
def zy_alpha(dt, sample_map, group="Group", ID="Sample", index="shannon", sample_color=None,
             box_width=0.5, title="Alpha Diversity", violin=False):
    """
    Plot alpha diversity boxplots and perform statistical tests.

    Alpha diversity is computed per sample (column of dt). Every pair of
    groups is compared with a two-sided Mann-Whitney U test and the result is
    drawn as a bracket spanning the two boxes it refers to. P-values are not
    corrected for multiple comparisons.
    Args:
        dt: A DataFrame where rows represent species and columns represent samples.
        sample_map: A DataFrame mapping samples to groups.
        group: Column name in sample_map specifying groups.
        ID: Column name in sample_map specifying sample IDs.
        index: Alpha diversity index to use: the name of a function in
            skbio.diversity.alpha that accepts a 1-D counts vector
            (e.g., "shannon"), or a callable with that signature.
        sample_color: List of colors for the groups (matched to groups in order
            of first appearance in sample_map).
        box_width: Width of the box plot.
        title: Title of the plot.
        violin: Whether to include violin plots.
    Returns:
        The Matplotlib Axes the plot was drawn on.
    Raises:
        ValueError: If index is neither a callable nor the name of a function
            in skbio.diversity.alpha.
    """
    # Resolve the requested metric instead of silently always using Shannon.
    if callable(index):
        metric = index
    elif index == "shannon":
        metric = shannon
    else:
        import skbio.diversity.alpha as alpha_metrics
        metric = getattr(alpha_metrics, index, None) if isinstance(index, str) else None
        if not callable(metric):
            raise ValueError(f"Unknown alpha diversity index: {index!r}")

    if sample_color is None:
        sample_color = sns.color_palette("husl", len(sample_map[group].unique()))

    # Align data: keep mapped samples only, then drop species absent from all of them.
    dt = dt[sample_map[ID]]
    dt = dt.loc[dt.sum(axis=1) != 0]

    # Calculate alpha diversity (one value per sample column).
    alpha = dt.apply(lambda counts: metric(counts.values), axis=0)
    # Plain arrays keep the frame on a default index, so a named column index
    # on dt cannot clash with the ID column during the merge.
    alpha_df = pd.DataFrame({"alpha": alpha.to_numpy(), ID: alpha.index.to_numpy()}).merge(sample_map, on=ID)

    # Fixed group order shared by the boxes, the palette and the brackets.
    groups = list(alpha_df[group].unique())

    # Pairwise comparisons
    p_values = []
    for g1, g2 in itertools.combinations(groups, 2):
        values1 = alpha_df.loc[alpha_df[group] == g1, "alpha"]
        values2 = alpha_df.loc[alpha_df[group] == g2, "alpha"]
        _, p_value = mannwhitneyu(values1, values2, alternative="two-sided")
        p_values.append((g1, g2, sig_func(p_value)))

    # Plotting (style is set first so the new axes pick it up).
    sns.set_style("whitegrid")
    _, ax = plt.subplots(figsize=(10, 6))
    if violin:
        sns.violinplot(data=alpha_df, x=group, y="alpha", order=groups, palette=sample_color, inner=None, ax=ax)
    sns.boxplot(data=alpha_df, x=group, y="alpha", order=groups, width=box_width, palette=sample_color,
                fliersize=2, linewidth=1.5, ax=ax)

    if p_values:
        # Brackets are stacked above everything already drawn (violins
        # included), each spanning the two boxes being compared.
        y_low, y_high = ax.get_ylim()
        step = (y_high - y_low) * 0.08
        tick = step * 0.25
        x_of = {g: pos for pos, g in enumerate(groups)}
        for level, (g1, g2, sig) in enumerate(p_values):
            x1, x2 = x_of[g1], x_of[g2]
            y = y_high + step * level
            ax.plot([x1, x1, x2, x2], [y, y + tick, y + tick, y], color="black", linewidth=1)
            ax.text((x1 + x2) / 2, y + tick, sig, ha="center", va="bottom", fontsize=12)
        # Leave head-room for the top label, which autoscaling ignores.
        ax.set_ylim(y_low, y_high + step * (len(p_values) + 1))

    ax.set_title(title)
    ax.set_xlabel("Group")
    ax.set_ylabel("Alpha Diversity")
    plt.show()
    return ax