In [None]:
import pandas as pd
import numpy as np
from scipy.stats import ranksums

In [None]:
def zy_pvalue(dt, sample_map, group_col, id_col):
    """
    Calculate p-values, fold changes, and enrichment between group pairs.

    For every row (species) of `dt` and every unordered pair of groups found
    in `sample_map[group_col]`, the two groups' values are compared with
    `scipy.stats.ranksums` and summarised.

    Args:
        dt (pd.DataFrame): Data matrix where rows are species and columns are samples.
        sample_map (pd.DataFrame): Mapping file with group and ID information.
        group_col (str): Column name in `sample_map` indicating the group variable.
        id_col (str): Column name in `sample_map` indicating sample IDs.

    Returns:
        pd.DataFrame: One row per (species, group pair) with the columns
            name, g1, g2            -- species label and the two groups compared
            Avg.g1, Avg.g2          -- mean of each group
            fold_change             -- Avg.g1 / Avg.g2, or inf when Avg.g2 is 0
            enriched                -- g1 if Avg.g1 > Avg.g2, otherwise g2
            Avg.weighted.g1/.g2     -- sum divided by the number of non-zero
                                       values (0 when there are none)
            all.avg, all.var        -- mean and variance (np.var) of both
                                       groups pooled together
            pvalue                  -- p-value returned by ranksums
            count1, count2          -- number of non-zero values per group
            rank1.avg, rank2.avg    -- mean rank of each group within the
                                       pooled samples
        The frame always carries these columns, even when there is nothing
        to compare (fewer than two groups, or no species rows).

    Raises:
        KeyError: If a sample ID from `sample_map` is not a column of `dt`.
    """
    # Fixed schema so an empty result still exposes every column.
    columns = [
        "name", "g1", "g2", "Avg.g1", "Avg.g2", "fold_change", "enriched",
        "Avg.weighted.g1", "Avg.weighted.g2", "all.avg", "all.var", "pvalue",
        "count1", "count2", "rank1.avg", "rank2.avg",
    ]

    # Subset dt to include only the relevant sample columns
    dt = dt[sample_map[id_col]]
    groups = sample_map[group_col].unique()
    species_names = dt.index

    # The sample IDs of a group do not depend on the species, so resolve
    # them once here instead of once per (species, pair) iteration.
    ids_by_group = [sample_map[sample_map[group_col] == g][id_col] for g in groups]
    comparisons = []
    for i, g1 in enumerate(groups):
        for j in range(i + 1, len(groups)):
            g2 = groups[j]
            pooled_ids = sample_map[sample_map[group_col].isin([g1, g2])][id_col]
            comparisons.append((g1, g2, ids_by_group[i], ids_by_group[j], pooled_ids))

    results = []

    for n in range(dt.shape[0]):
        temp_dt = dt.iloc[n, :]
        for g1, g2, g1_samples, g2_samples, pooled_ids in comparisons:
            dt1 = temp_dt[g1_samples].values
            dt2 = temp_dt[g2_samples].values
            pooled = np.concatenate([dt1, dt2])

            c1 = np.sum(dt1 != 0)
            c2 = np.sum(dt2 != 0)
            m1 = np.mean(dt1)
            m2 = np.mean(dt2)

            # Rank the pooled samples together, then average per group.
            all_ranks = pd.Series(temp_dt[pooled_ids].rank())

            results.append({
                "name": species_names[n],
                "g1": g1,
                "g2": g2,
                "Avg.g1": m1,
                "Avg.g2": m2,
                # Division by a zero mean is reported as inf rather than raising.
                "fold_change": m1 / m2 if m2 != 0 else np.inf,
                # Ties (equal means) are attributed to g2.
                "enriched": g1 if m1 > m2 else g2,
                # Mean over the non-zero entries only; 0 when all are zero.
                "Avg.weighted.g1": np.sum(dt1) / c1 if c1 > 0 else 0,
                "Avg.weighted.g2": np.sum(dt2) / c2 if c2 > 0 else 0,
                "all.avg": np.mean(pooled),
                "all.var": np.var(pooled),
                "pvalue": ranksums(dt1, dt2).pvalue,
                "count1": c1,
                "count2": c2,
                "rank1.avg": all_ranks[g1_samples].mean(),
                "rank2.avg": all_ranks[g2_samples].mean()
            })

    return pd.DataFrame(results, columns=columns)

def zy_qvalue(dt, sample_map, group_col, id_col, method="fdr_bh", min_count=0, min_avg=0, min_fd=0):
    """
    Filter results and calculate q-values for significance testing.

    Runs `zy_pvalue`, then applies a multiple-testing correction to the
    p-values of the rows that pass all thresholds. Rows that do not pass keep
    `qvalue = NaN` and are excluded from the correction.

    Args:
        dt (pd.DataFrame): Data matrix where rows are species and columns are samples.
        sample_map (pd.DataFrame): Mapping file with group and ID information.
        group_col (str): Column name in `sample_map` indicating the group variable.
        id_col (str): Column name in `sample_map` indicating sample IDs.
        method (str): Method for multiple testing correction, passed to
            `statsmodels.stats.multitest.multipletests` (default: "fdr_bh").
        min_count (int): Minimum non-zero sample count; a row qualifies when
            either `count1` or `count2` reaches it.
        min_avg (float): Minimum pooled average (`all.avg`) threshold.
        min_fd (float): Minimum fold-change threshold.

    Returns:
        pd.DataFrame: The `zy_pvalue` table plus a `qvalue` column. Rows that
        failed the thresholds come first, followed by the rows that passed;
        the index is reset.
    """
    from statsmodels.stats.multitest import multipletests

    result = zy_pvalue(dt, sample_map, group_col, id_col)

    if result.empty:
        # Nothing to compare (no species or fewer than two groups): the frame
        # may not even have the statistic columns, so skip casting/correction.
        result = result.copy()
        result["qvalue"] = np.nan
        return result

    result = result.astype({
        "Avg.g1": float, "Avg.g2": float, "fold_change": float,
        "Avg.weighted.g1": float, "Avg.weighted.g2": float, "all.avg": float,
        "all.var": float, "pvalue": float, "count1": int, "count2": int,
        "rank1.avg": float, "rank2.avg": float
    })

    result["qvalue"] = np.nan

    # Rows meeting every criterion take part in the q-value calculation.
    # Computed once and reused for both halves of the output.
    qualifies = (
        ((result["count1"] >= min_count) | (result["count2"] >= min_count)) &
        (result["fold_change"] >= min_fd) &
        (result["all.avg"] >= min_avg)
    )

    if qualifies.any():
        _, qvalues, _, _ = multipletests(result.loc[qualifies, "pvalue"], method=method)
        # Write through .loc on the parent frame; assigning into a boolean
        # slice would trigger pandas' SettingWithCopyWarning.
        result.loc[qualifies, "qvalue"] = qvalues

    # Keep the established output order: non-qualifying rows, then qualifying.
    final_result = pd.concat([result[~qualifies], result[qualifies]], ignore_index=True)

    return final_result