In [None]:
import scanpy as sc
import numpy as np
import pandas as pd
import seaborn as sns

from adpbulk import ADPBulk
from scipy.stats import binom, percentileofscore
import matplotlib.pyplot as plt

import time

def SCherlock(adata, column_ctype, column_patient, k_values=[1,10,25], scoring='diff', aggregation_method='mean', parameter_estimation='patient_median', max_genes_kept=100, min_pat=3, min_read=10):
    """
    Run the full marker-selection pipeline and print progress and timings to stdout.

    The pipeline chains, in order: theoretical (binomial) scoring, multi-category
    correction, aggregation over k, sorting, filtering, empirical re-scoring of the
    retained genes, a second aggregation/sorting pass, and selection of one top marker
    per cell type (aggregated empirical score cutoff of 0.5).

    Args:
        adata (AnnData object): Annotated single-cell gene expression data.
        column_ctype (str): Column in `adata.obs` specifying cell types.
        column_patient (str): Column in `adata.obs` specifying patient identifiers.
        k_values (list of int): Values of k for which scores are computed
                                (default: [1, 10, 25]). The list is only read, never modified.
        scoring (str): Scoring strategy forwarded to the theoretical and empirical scorers.
        aggregation_method (str): How scores are combined across k ('mean' or 'max').
        parameter_estimation (str): Strategy for estimating the binomial parameters.
        max_genes_kept (int): Maximum number of candidate genes kept per cell type by `filtering`.
        min_pat (int): Minimum number of patients threshold, forwarded to `filtering`.
        min_read (int): Minimum number of reads threshold, forwarded to `filtering`.

    Returns:
        tuple: `(top_gene_list, sorted_emp_table)` -- the result of `construct_top_list`
               and the sorted empirical score tables it was built from.
    """
    def _describe(result):
        # The pipeline stages hand back dicts keyed by cell type, which have no `.shape`;
        # report the shape of every entry instead of an uninformative 'N/A'.
        if hasattr(result, 'shape'):
            return result.shape
        if isinstance(result, dict):
            return {key: getattr(value, 'shape', 'N/A') for key, value in result.items()}
        return 'N/A'

    start_time = time.time()
    print(f"Starting SCherlock with parameters: k_values={k_values}, scoring={scoring}, aggregation={aggregation_method}")

    print("Step 1/7: Calculating theoretical scores...")
    th_scores, expr_proportions = theoretical_scores_binomial(adata, column_ctype, column_patient, k_values, scoring, parameter_estimation)
    print(f"Theoretical scores shape: {_describe(th_scores)}")
    step1_time = time.time()
    print(f"Time for step 1: {step1_time - start_time:.2f} seconds")

    print("Step 2/7: Performing multi-category correction...")
    proc_scores = multi_cat_correction(th_scores)
    print(f"Processed scores shape: {_describe(proc_scores)}")
    step2_time = time.time()
    print(f"Time for step 2: {step2_time - step1_time:.2f} seconds")

    print("Step 3/7: Aggregating scores...")
    agg_scores = aggregation(proc_scores, aggregation_method)
    print(f"Aggregated scores shape: {_describe(agg_scores)}")
    step3_time = time.time()
    print(f"Time for step 3: {step3_time - step2_time:.2f} seconds")

    print("Step 4/7: Sorting scores...")
    sorted_table = score_sorting(agg_scores, proc_scores, expr_proportions)
    print(f"Sorted table shape: {_describe(sorted_table)}")
    step4_time = time.time()
    print(f"Time for step 4: {step4_time - step3_time:.2f} seconds")

    print("Step 5/7: Filtering scores...")
    # Forward the user-supplied thresholds; previously `filtering` always fell back
    # to its own defaults regardless of what the caller passed here.
    filtered_scores = filtering(sorted_table, column_patient, column_ctype, max_genes_kept,
                                min_pat=min_pat, min_read=min_read)
    print(f"Filtered scores shape: {_describe(filtered_scores)}")
    if isinstance(filtered_scores, dict):
        # Count distinct genes retained over all cell types.
        kept_genes = set()
        for table in filtered_scores.values():
            if hasattr(table, 'index'):
                kept_genes.update(table.index)
        n_kept = len(kept_genes)
    elif hasattr(filtered_scores, 'index'):
        n_kept = len(filtered_scores.index.unique())
    else:
        n_kept = 'N/A'
    print(f"Number of genes after filtering: {n_kept}")
    step5_time = time.time()
    print(f"Time for step 5: {step5_time - step4_time:.2f} seconds")

    print("Step 6/7: Calculating empirical scores...")
    emp_scores = empirical_scores_v0(filtered_scores, adata, column_ctype, column_patient, k_values, scoring)
    print(f"Empirical scores shape: {_describe(emp_scores)}")
    step6_time = time.time()
    print(f"Time for step 6: {step6_time - step5_time:.2f} seconds")

    print("Step 7/7: Final aggregation and sorting...")
    agg_emp_scores = aggregation(emp_scores, aggregation_method)
    sorted_emp_table = score_sorting(agg_emp_scores, emp_scores, expr_proportions)
    print(f"Final sorted table shape: {_describe(sorted_emp_table)}")
    step7_time = time.time()
    print(f"Time for step 7: {step7_time - step6_time:.2f} seconds")

    print("Constructing top gene list...")
    top_gene_list = construct_top_list(sorted_emp_table, 0.5)
    # `construct_top_list` returns a dict (cell type -> gene), so count its entries
    # rather than testing for a list.
    print(f"Number of genes in top list: {len(top_gene_list) if hasattr(top_gene_list, '__len__') else 'N/A'}")

    end_time = time.time()
    total_time = end_time - start_time
    print(f"Total execution time: {total_time:.2f} seconds")
    print("SCherlock completed successfully")

    return top_gene_list, sorted_emp_table

# def theoretical_scores_binomial_v0(adata, column_ctype, column_patient, k_values=[1,10,25], scoring='diff', parameter_estimation='patient_median'):
#     parameters = estimate_binomial(adata, column_ctype, column_patient, parameter_estimation)
#     scores = {}
#     for ctype in adata.obs[column_ctype].unique():
#         scores_ctype = np.zeros(shape=(adata.shape[1], len(k_values)))
#         scores_ctype = pd.DataFrame(scores_ctype, index=adata.var.index, columns=k_values)
#         for k in k_values:
#             cutoff_k=np.max([binom.ppf(.99, int(parameters[0][ctype]*k), parameters[2][ctype]).max(),
#                              binom.ppf(.99, int(parameters[1][ctype]*k), parameters[3][ctype]).max()])
#             alpha = np.zeros(shape=(adata.shape[1], int(cutoff_k)))
#             for l in np.arange(int(cutoff_k)):
#                 alpha[:,l]=binom.cdf(l, int(parameters[0][ctype]*k), parameters[2][ctype])
#             beta = np.zeros(shape=(adata.shape[1], int(cutoff_k)))
#             for l in np.arange(int(cutoff_k)):
#                 beta[:,l]=binom.cdf(l, int(parameters[1][ctype]*k), parameters[3][ctype])
#             if scoring=='diff':
#                 scores_ctype[k]=(beta-alpha)[np.arange(adata.shape[1]), np.argmax(beta-alpha,axis=1)]
#             elif scoring=='sensFPRzero':
#                 scores_ctype[k]=(1-alpha)[np.arange(adata.shape[1]), np.argmax(beta,axis=1)]
#             elif scoring=='sensPPV99':
#                 ppv = np.nan_to_num((1-alpha)/(2-alpha-beta))
#                 scores_ctype[k]=((1-alpha)[np.arange(adata.shape[1]), np.argmax(ppv>0.99,axis=1)])*(np.sum(ppv>0.99,axis=1)>0)
#         scores[ctype] = pd.DataFrame(scores_ctype)
#     return scores, parameters[2]

def theoretical_scores_binomial(adata, column_ctype, column_patient, k_values=[1,10,25], scoring='diff', parameter_estimation='patient_median'):
    """
    Score every gene as a marker of every cell type under a binomial expression model.

    Args:
        adata (AnnData object): Annotated single-cell gene expression data.
        column_ctype (str): Column in `adata.obs` specifying cell types.
        column_patient (str): Column in `adata.obs` specifying patient identifiers.
        k_values (list of int): Values of k for which to compute scores (default: [1, 10, 25]).
        scoring (str): Scoring strategy. Options are:
                        - 'diff': Maximised difference between sensitivity and false positive rate (FPR).
                        - 'sensFPRzero': Sensitivity at the first threshold where the complement CDF peaks.
                        - 'sensPPV99': Sensitivity at the first threshold where PPV exceeds 99%
                                       (0 if no threshold reaches it).
                       Default is 'diff'. Any other value leaves the scores at zero.
        parameter_estimation (str): Strategy for estimating binomial parameters, forwarded to
                       `estimate_binomial` ('patient_median' or 'mean'). Default is 'patient_median'.

    Returns:
        dict: Cell type -> Pandas DataFrame of theoretical scores (rows: genes, columns: k values).
        pandas.DataFrame: The third value returned by `estimate_binomial`, i.e. the estimated
              binomial parameter p of each gene within each cell type.
    """
    # n (expected counts per cell) and p (per-gene proportion), inside the cell type and in its complement.
    depth_in, depth_out, prop_in, prop_out = estimate_binomial(adata, column_ctype, column_patient, parameter_estimation)
    n_genes = adata.shape[1]
    gene_rows = np.arange(n_genes)
    grid_steps = np.arange(101)
    scores = {}
    for ctype in adata.obs[column_ctype].unique():
        # One row per gene, one column per k, filled in below.
        table = pd.DataFrame(
            np.zeros(shape=(n_genes, len(k_values))),
            index=adata.var.index,
            columns=k_values,
        )
        p_in = prop_in[ctype]
        p_out = prop_out[ctype]
        for k in k_values:
            n_in = int(depth_in[ctype] * k)
            n_out = int(depth_out[ctype] * k)
            # Per-gene upper bound of the evaluation grid: the larger of the two 99% quantiles,
            # but never below 100.
            upper = pd.DataFrame(
                [binom.ppf(.99, n_in, p_in), binom.ppf(.99, n_out, p_out)],
                columns=table.index,
            ).max().clip(lower=100)
            # 101 thresholds per gene, spread between 0 and that gene's upper bound.
            thresholds = upper.values[:, None] * grid_steps // 100
            alpha = np.empty_like(thresholds)  # CDF within the cell type
            beta = np.empty_like(thresholds)   # CDF within the complement
            for step in grid_steps:
                alpha[:, step] = binom.cdf(thresholds[:, step], n_in, p_in)
                beta[:, step] = binom.cdf(thresholds[:, step], n_out, p_out)
            # Sensitivity = 1 - alpha; FPR = 1 - beta.
            if scoring == 'diff':
                gap = beta - alpha
                table[k] = gap[gene_rows, np.argmax(gap, axis=1)]
            elif scoring == 'sensFPRzero':
                sensitivity = 1 - alpha
                table[k] = sensitivity[gene_rows, np.argmax(beta, axis=1)]
            elif scoring == 'sensPPV99':
                sensitivity = 1 - alpha
                ppv = np.nan_to_num(sensitivity / (2 - alpha - beta))
                reaches = ppv > 0.99
                # Zero out genes for which no threshold reaches the PPV target.
                table[k] = sensitivity[gene_rows, np.argmax(reaches, axis=1)] * reaches.any(axis=1)
        scores[ctype] = pd.DataFrame(table)
    return scores, prop_in


def construct_top_list(sorted_emp_table, cutoff):
    """
    Pick the top marker of every cell type that has at least one acceptable marker.

    Args:
        sorted_emp_table (dict): Cell type -> Pandas DataFrame of aggregated empirical
                                 scores, expected to be sorted in descending order and to
                                 carry an 'aggregated' column.
        cutoff (float): Minimum aggregated score for a gene to qualify as a marker.

    Returns:
        dict: Cell type -> index label of the first row of its table. Cell types whose
              table is empty, or whose best 'aggregated' score is below `cutoff`, are left out.
    """
    top_markers = {}
    for ctype, table in sorted_emp_table.items():
        # Nothing to pick from.
        if table.empty:
            continue
        # Keep the cell type only if its best score clears the cutoff.
        if table['aggregated'].max() >= cutoff:
            top_markers[ctype] = table.index[0]
    return top_markers

def empirical_cdf_v0(df, value):
    """
    Evaluate the empirical CDF of every column of `df` at one shared point.

    Args:
        df (pandas.DataFrame): One column per gene, one row per observation.
        value (float): The single point at which every column is evaluated.

    Returns:
        pandas.Series: For each column, `percentileofscore(column, value, kind='rank')`
                       rescaled from a percentage to a fraction.
    """
    def _rank_fraction(column):
        return percentileofscore(column, value, kind='rank') / 100

    return df.apply(_rank_fraction)

def empirical_scores_v0(filtered_scores, adata, column_ctype, column_patient, k_values=[1,10,25], scoring='diff', seed=0, n_sim=1000):
    """
    Compute empirical marker scores for the genes that survived filtering.

    Original, stable (but slow) implementation: for each cell type and each k the
    empirical CDFs are evaluated at every integer threshold from 0 up to (but excluding)
    the largest 99th percentile found over all retained genes.

    Args:
        filtered_scores (dict): Cell type -> DataFrame whose index lists the retained genes
                                (may be empty).
        adata (AnnData object): Annotated single-cell gene expression data.
        column_ctype (str): Column in `adata.obs` specifying cell types.
        column_patient (str): Unused; kept for a signature parallel to the other scorers.
        k_values (list of int): Numbers of cells whose expression is summed per simulated draw.
        scoring (str): 'diff', 'sensFPRzero' or 'sensPPV99'. Any other value leaves scores at zero.
        seed (int): Seed passed to `np.random.seed` (this resets NumPy's global RNG state).
        n_sim (int): Number of simulated k-cell sums drawn (with replacement) when k > 1.

    Returns:
        dict: Cell type -> DataFrame of corrected scores (clipped to at most 1) for the genes
              retained for that cell type, or an empty DataFrame if none were retained.
    """
    np.random.seed(seed)
    ctypes = adata.obs[column_ctype].unique()
    # Pool the candidate genes of all cell types.
    gene_list = []
    for ctype in ctypes:
        if not filtered_scores[ctype].empty:
            gene_list += filtered_scores[ctype].index.tolist()
    # Drop repeats (order preserved) so a gene listed for several cell types does not
    # yield duplicate columns in the subset and duplicate rows in the final `.loc`.
    gene_list = list(dict.fromkeys(gene_list))
    if not gene_list:
        # Nothing survived filtering: without this guard the threshold cutoff below
        # evaluates to NaN and `int(cutoff_k)` fails with an unrelated-looking ValueError.
        return {ctype: pd.DataFrame([]) for ctype in ctypes}
    adata_subset = adata[:, gene_list].copy()
    genes = adata_subset.var.index
    n_genes = adata_subset.shape[1]
    gene_rows = np.arange(n_genes)
    scores_all = {}
    for ctype in ctypes:
        scores_ctype = pd.DataFrame(np.zeros(shape=(n_genes, len(k_values))), index=genes, columns=k_values)
        # Expression inside the cell type (A) and in its complement (nA).
        expr_A = sc.get.obs_df(adata_subset[adata_subset.obs[column_ctype] == ctype], genes.tolist())
        expr_nA = sc.get.obs_df(adata_subset[adata_subset.obs[column_ctype] != ctype], genes.tolist())
        for k in k_values:
            if k > 1:
                # Simulate sums over k cells; A is sampled before nA, which fixes the RNG stream.
                sampled_A = np.random.choice(expr_A.shape[0], (n_sim, k), replace=True)
                k_sums_A = pd.DataFrame(expr_A.values[sampled_A].sum(axis=1), columns=expr_A.columns)
                sampled_nA = np.random.choice(expr_nA.shape[0], (n_sim, k), replace=True)
                k_sums_nA = pd.DataFrame(expr_nA.values[sampled_nA].sum(axis=1), columns=expr_nA.columns)
            else:
                k_sums_A = expr_A
                k_sums_nA = expr_nA
            # Shared number of thresholds: largest 99th percentile over all genes, at least 1.
            cutoff_k = np.max([k_sums_A.quantile(0.99).max(),
                               k_sums_nA.quantile(0.99).max(), 1])
            n_thresholds = int(cutoff_k)
            alpha = np.zeros(shape=(n_genes, n_thresholds))  # empirical CDF within A
            beta = np.zeros(shape=(n_genes, n_thresholds))   # empirical CDF within nA
            for l in range(n_thresholds):
                alpha[:, l] = empirical_cdf_v0(k_sums_A, l).values
                beta[:, l] = empirical_cdf_v0(k_sums_nA, l).values
            # Sensitivity = 1 - alpha; FPR = 1 - beta.
            if scoring == 'diff':
                scores_ctype[k] = (beta - alpha)[gene_rows, np.argmax(beta - alpha, axis=1)]
            elif scoring == 'sensFPRzero':
                scores_ctype[k] = (1 - alpha)[gene_rows, np.argmax(beta, axis=1)]
            elif scoring == 'sensPPV99':
                ppv = np.nan_to_num((1 - alpha) / (2 - alpha - beta))
                scores_ctype[k] = ((1 - alpha)[gene_rows, np.argmax(ppv > 0.99, axis=1)]) * (np.sum(ppv > 0.99, axis=1) > 0)
        scores_all[ctype] = pd.DataFrame(scores_ctype)
    corr_scores = multi_cat_correction(scores_all)
    scores = {}
    for ctype in ctypes:
        if filtered_scores[ctype].empty:
            scores[ctype] = pd.DataFrame([])
        else:
            # Report only the genes that were retained for this particular cell type.
            scores[ctype] = corr_scores[ctype].loc[filtered_scores[ctype].index].clip(upper=1)
    return scores




# def empirical_scores_v0_v1(filtered_scores, adata, column_ctype, column_patient, k_values=[1,10,25], scoring='diff', seed=0, n_sim=1000):
#     # code for comparison with the same seed
#     gene_list = []
#     np.random.seed(seed)
#     scores_all_v0 = {}
#     scores_all_v1 = {}
#     for ctype in adata.obs[column_ctype].unique():
#         if not filtered_scores[ctype].empty:
#             gene_list += filtered_scores[ctype].index.tolist()
#     adata_subset=adata[:,gene_list].copy()
#     for ctype in adata.obs[column_ctype].unique():
#         scores_ctype_v0 = np.zeros(shape=(adata_subset.shape[1], len(k_values)))
#         scores_ctype_v0 = pd.DataFrame(scores_ctype_v0, index=adata_subset.var.index, columns=k_values)
#         scores_ctype_v1 = scores_ctype_v0.copy()
#         expr_A = sc.get.obs_df(adata_subset[adata_subset.obs[column_ctype]==ctype], adata_subset.var.index.tolist())
#         expr_nA = sc.get.obs_df(adata_subset[adata_subset.obs[column_ctype]!=ctype], adata_subset.var.index.tolist())
#         for k in k_values:
#             if k > 1:
#                 sampled_A= np.random.choice(expr_A.shape[0], (n_sim, k), replace=True)
#                 k_sums_A = pd.DataFrame(expr_A.values[sampled_A].sum(axis=1), columns=expr_A.columns)
#                 sampled_nA= np.random.choice(expr_nA.shape[0], (n_sim, k), replace=True)
#                 k_sums_nA = pd.DataFrame(expr_nA.values[sampled_nA].sum(axis=1), columns=expr_nA.columns)
#             else:
#                 k_sums_A = expr_A
#                 k_sums_nA = expr_nA
#             cutoff_k=np.max([k_sums_A.quantile(0.99).max(),
#                              k_sums_nA.quantile(0.99).max(), 1])
#             alpha = np.zeros(shape=(adata_subset.shape[1], int(cutoff_k)))
#             for l in np.arange(int(cutoff_k)):
#                 alpha[:,l]=empirical_cdf_v0(k_sums_A, l).values
#             beta = np.zeros(shape=(adata_subset.shape[1], int(cutoff_k)))
#             for l in np.arange(int(cutoff_k)):
#                 beta[:,l]=empirical_cdf_v0(k_sums_nA, l).values
#             if scoring=='diff':
#                 scores_ctype_v0[k]=(beta-alpha)[np.arange(adata_subset.shape[1]), np.argmax(beta-alpha,axis=1)]
#             elif scoring=='sensFPRzero':
#                 scores_ctype_v0[k]=(1-alpha)[np.arange(adata_subset.shape[1]), np.argmax(beta,axis=1)]
#             elif scoring=='sensPPV99':
#                 ppv = np.nan_to_num((1-alpha)/(2-alpha-beta))
#                 scores_ctype_v0[k]=((1-alpha)[np.arange(adata_subset.shape[1]), np.argmax(ppv>0.99,axis=1)])*(np.sum(ppv>0.99,axis=1)>0)
#             cutoffs_k=pd.DataFrame([k_sums_A.quantile(0.99),k_sums_nA.quantile(0.99)]).max().clip(lower=100)
#             #idx_k = cutoffs_k.apply(lambda x: np.arange(x + 1))
#             idx_k = cutoffs_k.apply(lambda x: (x * np.arange(101)) // max(100, x//50))
#             alpha = empirical_cdf(k_sums_A,idx_k.to_dict())
#             beta = empirical_cdf(k_sums_nA,idx_k.to_dict())
#             if scoring=='diff':
#                 scores_ctype_v1[k]=(beta-alpha).apply(lambda x: x.max())
#             elif scoring=='sensFPRzero':
#                 scores_ctype_v1[k]=[na[bmax] for na, bmax in zip(1-alpha, beta.apply(lambda x: x.argmax()))]
#             elif scoring=='sensPPV99':
#                 ppv = ((1-alpha)/(2-alpha-beta)).apply(lambda x: np.nan_to_num(x))
#                 scores_ctype_v1[k]= pd.Series([na[ppv] for na, ppv in zip(1-alpha, ppv.apply(lambda x: (x>0.99).argmax()))], index=alpha.index) * ppv.apply(lambda x: np.sum(x>0.99)>0)
#         scores_all_v0[ctype] = pd.DataFrame(scores_ctype_v0)
#         scores_all_v1[ctype] = pd.DataFrame(scores_ctype_v1)
#     corr_scores_v0 = multi_cat_correction(scores_all_v0)
#     corr_scores_v1 = multi_cat_correction(scores_all_v1)
#     scores_v0 = {}
#     scores_v1 = {}
#     for ctype in adata.obs[column_ctype].unique():
#         if filtered_scores[ctype].empty:
#             scores_v0[ctype]=pd.DataFrame([])
#             scores_v1[ctype]=pd.DataFrame([])
#         else:
#             scores_v0[ctype]=corr_scores_v0[ctype].loc[filtered_scores[ctype].index].clip(upper=1)
#             scores_v1[ctype]=corr_scores_v1[ctype].loc[filtered_scores[ctype].index].clip(upper=1)
#     return scores_v0, scores_v1


def empirical_cdf_v1(df, values):
    """
    Evaluate the empirical CDF of every column of `df` at that column's own set of points.

    Unlike `empirical_cdf_v0`, the columns (genes) need not share the same number of
    evaluation points.

    Args:
        df (pandas.DataFrame): One column per gene, one row per observation.
        values (dict): Column name -> iterable of points at which that column is evaluated.

    Returns:
        pandas.Series: Indexed by column name; each entry is a NumPy array holding
                       `percentileofscore(column, v, kind='rank') / 100` for every point v.
    """
    cdf_by_column = {}
    for col in df.columns:
        observations = df[col]
        cdf_by_column[col] = np.array(
            [percentileofscore(observations, point, kind='rank') / 100 for point in values[col]]
        )
    return pd.Series(cdf_by_column)

def empirical_scores_v1(filtered_scores, adata, column_ctype, column_patient, k_values=[1,10,25], scoring='diff', seed=0, n_sim=1000):
    """
    Compute empirical marker scores on a per-gene grid of 101 thresholds.

    UNSTABLE (author's note): for every gene the empirical CDFs are evaluated on 101 points
    derived from max(99th percentile in A, 99th percentile in the complement of A, 100).

    Args:
        filtered_scores (dict): Cell type -> DataFrame whose index lists the retained genes
                                (may be empty).
        adata (AnnData object): Annotated single-cell gene expression data.
        column_ctype (str): Column in `adata.obs` specifying cell types.
        column_patient (str): Unused; kept for a signature parallel to the other scorers.
        k_values (list of int): Numbers of cells whose expression is summed per simulated draw.
        scoring (str): 'diff', 'sensFPRzero' or 'sensPPV99'. Any other value leaves scores at zero.
        seed (int): Seed passed to `np.random.seed` (this resets NumPy's global RNG state).
        n_sim (int): Number of simulated k-cell sums drawn (with replacement) when k > 1.

    Returns:
        dict: Cell type -> DataFrame of corrected scores (clipped to at most 1) for the genes
              retained for that cell type, or an empty DataFrame if none were retained.
    """
    np.random.seed(seed)
    ctypes = adata.obs[column_ctype].unique()
    # Pool the candidate genes of all cell types.
    gene_list = []
    for ctype in ctypes:
        if not filtered_scores[ctype].empty:
            gene_list += filtered_scores[ctype].index.tolist()
    # Drop repeats (order preserved) so a gene listed for several cell types does not
    # yield duplicate columns in the subset and duplicate rows in the final `.loc`.
    gene_list = list(dict.fromkeys(gene_list))
    if not gene_list:
        # Nothing survived filtering, so there is nothing to score.
        return {ctype: pd.DataFrame([]) for ctype in ctypes}
    adata_subset = adata[:, gene_list].copy()
    genes = adata_subset.var.index
    scores_all = {}
    for ctype in ctypes:
        scores_ctype = pd.DataFrame(np.zeros(shape=(adata_subset.shape[1], len(k_values))), index=genes, columns=k_values)
        # Expression inside the cell type (A) and in its complement (nA).
        expr_A = sc.get.obs_df(adata_subset[adata_subset.obs[column_ctype] == ctype], genes.tolist())
        expr_nA = sc.get.obs_df(adata_subset[adata_subset.obs[column_ctype] != ctype], genes.tolist())
        for k in k_values:
            if k > 1:
                # Simulate sums over k cells; A is sampled before nA, which fixes the RNG stream.
                sampled_A = np.random.choice(expr_A.shape[0], (n_sim, k), replace=True)
                k_sums_A = pd.DataFrame(expr_A.values[sampled_A].sum(axis=1), columns=expr_A.columns)
                sampled_nA = np.random.choice(expr_nA.shape[0], (n_sim, k), replace=True)
                k_sums_nA = pd.DataFrame(expr_nA.values[sampled_nA].sum(axis=1), columns=expr_nA.columns)
            else:
                k_sums_A = expr_A
                k_sums_nA = expr_nA
            # Per-gene cutoff: the larger of the two 99th percentiles, never below 100.
            cutoffs_k = pd.DataFrame([k_sums_A.quantile(0.99), k_sums_nA.quantile(0.99)]).max().clip(lower=100)
            # Per-gene thresholds, built as an explicit {gene: array} mapping so the result
            # does not depend on how `Series.apply` packs array-valued results.
            idx_k = {
                gene: (cutoff * np.arange(101)) // max(100, cutoff // 50)
                for gene, cutoff in cutoffs_k.items()
            }
            # `empirical_cdf_v1` is the dict-aware helper defined in this file; the bare
            # name `empirical_cdf` used previously is not defined in the visible code.
            alpha = empirical_cdf_v1(k_sums_A, idx_k)   # empirical CDF within A
            beta = empirical_cdf_v1(k_sums_nA, idx_k)   # empirical CDF within nA
            # Sensitivity = 1 - alpha; FPR = 1 - beta. alpha/beta are Series of arrays.
            if scoring == 'diff':
                scores_ctype[k] = (beta - alpha).apply(lambda x: x.max())
            elif scoring == 'sensFPRzero':
                first_peak = beta.apply(lambda x: x.argmax())
                scores_ctype[k] = [sens[pos] for sens, pos in zip(1 - alpha, first_peak)]
            elif scoring == 'sensPPV99':
                ppv = ((1 - alpha) / (2 - alpha - beta)).apply(lambda x: np.nan_to_num(x))
                first_hit = ppv.apply(lambda x: (x > 0.99).argmax())
                has_hit = ppv.apply(lambda x: np.sum(x > 0.99) > 0)
                # Zero out genes for which no threshold reaches the PPV target.
                scores_ctype[k] = pd.Series([sens[pos] for sens, pos in zip(1 - alpha, first_hit)], index=alpha.index) * has_hit
        scores_all[ctype] = pd.DataFrame(scores_ctype)
    corr_scores = multi_cat_correction(scores_all)
    scores = {}
    for ctype in ctypes:
        if filtered_scores[ctype].empty:
            scores[ctype] = pd.DataFrame([])
        else:
            # Report only the genes that were retained for this particular cell type.
            scores[ctype] = corr_scores[ctype].loc[filtered_scores[ctype].index].clip(upper=1)
    return scores

def empirical_scores_v2(filtered_scores, adata, column_ctype, column_patient, k_values=[1,10,25], scoring='diff', seed=0, n_sim=1000):
    """
    Compute empirical marker scores at quantile points of the two distributions.

    VERY UNSTABLE (author's note): for every gene the empirical CDFs are evaluated at the
    distinct quantile values of its distribution in A and in the complement of A.
    TODO (author's note): exclude zeros from the computation of the quantiles.

    Args:
        filtered_scores (dict): Cell type -> DataFrame whose index lists the retained genes
                                (may be empty).
        adata (AnnData object): Annotated single-cell gene expression data.
        column_ctype (str): Column in `adata.obs` specifying cell types.
        column_patient (str): Unused; kept for a signature parallel to the other scorers.
        k_values (list of int): Numbers of cells whose expression is summed per simulated draw.
        scoring (str): 'diff', 'sensFPRzero' or 'sensPPV99'. Any other value leaves scores at zero.
        seed (int): Seed passed to `np.random.seed` (this resets NumPy's global RNG state).
        n_sim (int): Number of simulated k-cell sums drawn (with replacement) when k > 1.

    Returns:
        dict: Cell type -> DataFrame of corrected scores (clipped to at most 1) for the genes
              retained for that cell type, or an empty DataFrame if none were retained.
    """
    np.random.seed(seed)
    ctypes = adata.obs[column_ctype].unique()
    # Pool the candidate genes of all cell types.
    gene_list = []
    for ctype in ctypes:
        if not filtered_scores[ctype].empty:
            gene_list += filtered_scores[ctype].index.tolist()
    # Drop repeats (order preserved) so a gene listed for several cell types does not
    # yield duplicate columns in the subset and duplicate rows in the final `.loc`.
    gene_list = list(dict.fromkeys(gene_list))
    if not gene_list:
        # Nothing survived filtering, so there is nothing to score.
        return {ctype: pd.DataFrame([]) for ctype in ctypes}
    adata_subset = adata[:, gene_list].copy()
    genes = adata_subset.var.index
    # Quantile levels: 0, 0.01, 0.02 .. 0.98 in steps of 0.02, 0.99 and 1.
    quantile_pts = np.concatenate([np.array([0, 0.01]), np.arange(1, 50) / 50, np.array([0.99, 1])])
    scores_all = {}
    for ctype in ctypes:
        scores_ctype = pd.DataFrame(np.zeros(shape=(adata_subset.shape[1], len(k_values))), index=genes, columns=k_values)
        # Expression inside the cell type (A) and in its complement (nA).
        expr_A = sc.get.obs_df(adata_subset[adata_subset.obs[column_ctype] == ctype], genes.tolist())
        expr_nA = sc.get.obs_df(adata_subset[adata_subset.obs[column_ctype] != ctype], genes.tolist())
        for k in k_values:
            if k > 1:
                # Simulate sums over k cells; A is sampled before nA, which fixes the RNG stream.
                sampled_A = np.random.choice(expr_A.shape[0], (n_sim, k), replace=True)
                k_sums_A = pd.DataFrame(expr_A.values[sampled_A].sum(axis=1), columns=expr_A.columns)
                sampled_nA = np.random.choice(expr_nA.shape[0], (n_sim, k), replace=True)
                k_sums_nA = pd.DataFrame(expr_nA.values[sampled_nA].sum(axis=1), columns=expr_nA.columns)
            else:
                k_sums_A = expr_A
                k_sums_nA = expr_nA
            quantiles_A = k_sums_A.quantile(quantile_pts)
            quantiles_nA = k_sums_nA.quantile(quantile_pts)
            # Per-gene thresholds: the distinct quantile values of both groups, ascending.
            # Built as an explicit dict because `DataFrame.apply(lambda x: x.unique())`
            # collapses into a DataFrame whenever all genes happen to have the same number
            # of unique values, which made `.to_dict()` return row labels instead of
            # thresholds. Ascending order is what the "first threshold where ..." argmax
            # logic of 'sensFPRzero' / 'sensPPV99' relies on.
            idx_k = {
                gene: np.unique(np.concatenate([quantiles_A[gene].values, quantiles_nA[gene].values]))
                for gene in k_sums_A.columns
            }
            # `empirical_cdf_v1` is the dict-aware helper defined in this file; the bare
            # name `empirical_cdf` used previously is not defined in the visible code.
            alpha = empirical_cdf_v1(k_sums_A, idx_k)   # empirical CDF within A
            beta = empirical_cdf_v1(k_sums_nA, idx_k)   # empirical CDF within nA
            # Sensitivity = 1 - alpha; FPR = 1 - beta. alpha/beta are Series of arrays.
            if scoring == 'diff':
                scores_ctype[k] = (beta - alpha).apply(lambda x: x.max())
            elif scoring == 'sensFPRzero':
                first_peak = beta.apply(lambda x: x.argmax())
                scores_ctype[k] = [sens[pos] for sens, pos in zip(1 - alpha, first_peak)]
            elif scoring == 'sensPPV99':
                ppv = ((1 - alpha) / (2 - alpha - beta)).apply(lambda x: np.nan_to_num(x))
                first_hit = ppv.apply(lambda x: (x > 0.99).argmax())
                has_hit = ppv.apply(lambda x: np.sum(x > 0.99) > 0)
                # Zero out genes for which no threshold reaches the PPV target.
                scores_ctype[k] = pd.Series([sens[pos] for sens, pos in zip(1 - alpha, first_hit)], index=alpha.index) * has_hit
        scores_all[ctype] = pd.DataFrame(scores_ctype)
    corr_scores = multi_cat_correction(scores_all)
    scores = {}
    for ctype in ctypes:
        if filtered_scores[ctype].empty:
            scores[ctype] = pd.DataFrame([])
        else:
            # Report only the genes that were retained for this particular cell type.
            scores[ctype] = corr_scores[ctype].loc[filtered_scores[ctype].index].clip(upper=1)
    return scores


def estimate_binomial(adata, column_ctype, column_patient, parameter_estimation='patient_median'):
    """
    Estimates parameters of the binomial distribution on the single-cell gene expression dataset.

    Note: `adata.obs[column_ctype]` is converted to a categorical column in place.

    Args:
        adata (AnnData object): Annotated single-cell gene expression data.
        column_ctype (str): Column in `adata.obs` specifying cell types.
        column_patient (str): Column in `adata.obs` specifying patient identifiers
                              (only used by 'patient_median').
        parameter_estimation (str): Strategy for estimating binomial parameters. Options are:
                                     - 'patient_median': Estimated on each patient independently,
                                        median of estimated values taken.
                                     - 'mean': Estimated on the entire dataset.
                                    Default is 'patient_median'.

    Returns:
        pandas.Series: Total counts per cell, for each cell type.
        pandas.Series: Total counts per cell, for the complement of each cell type
                       (all other cell types except the given one).
        pandas.DataFrame: Proportion of the counts attributed to each gene (rows) in each
                          cell type (columns).
        pandas.DataFrame: Proportion of the counts attributed to each gene (rows) in the
                          complement of each cell type (columns).

    Raises:
        ValueError: If an unsupported estimation method is specified.
    """
    # Fail fast, before touching `adata`. (The message previously referenced an undefined
    # name, which turned this error into a NameError.)
    if parameter_estimation not in ('patient_median', 'mean'):
        raise ValueError("Unsupported estimation method: " + str(parameter_estimation))
    adata.obs[column_ctype] = adata.obs[column_ctype].astype('category')
    cat_list = adata.obs[column_ctype].cat.categories
    if parameter_estimation == 'patient_median':
        # Generate pseudobulk data by aggregating counts based on patient and cell type.
        adata_agg = ADPBulk(adata, [column_patient, column_ctype])
        pseudobulk_matrix = adata_agg.fit_transform()  # Aggregated expression matrix.
        sample_meta = adata_agg.get_meta()  # Metadata corresponding to the pseudobulk samples.
        # Create a multi-index combining cell type and patient for easier data manipulation.
        tuples = list(zip(sample_meta[column_ctype], sample_meta[column_patient]))
        index = pd.MultiIndex.from_tuples(tuples, names=[column_ctype, column_patient])
        pseudobulk_matrix.set_index(index, inplace=True)
        # Calculate total counts per cell type and patient.
        total_counts = pseudobulk_matrix.sum(axis=1)
        # Calculate the number of cells per cell type and patient.
        n_cells = adata.obs.groupby(by=[column_ctype, column_patient]).size()
        # Calculate median counts per cell, per cell type, across patients.
        counts_per_ctype = (total_counts / n_cells).groupby(level=column_ctype).median()
        # Calculate median counts per cell for the complement of each cell type.
        counts_per_ctype_complement = pd.Series({
            ctype: ((total_counts.drop(ctype).groupby(level=column_patient).sum()) /
                    (n_cells.drop(ctype).groupby(level=column_patient).sum())).median()
            for ctype in cat_list
        })
        # Compute median normalized proportions of counts for each cell type.
        count_proportions_per_ctype = (
            pseudobulk_matrix.div(total_counts.values, axis=0)
            .groupby(level=column_ctype)
            .median()
            .T
        )
        # Median normalized proportions for the complement of each cell type.
        count_proportions_per_ctype_complement = pd.DataFrame({
            ctype: pseudobulk_matrix.drop(ctype)
                   .groupby(level=column_patient)
                   .sum()
                   .div(total_counts.drop(ctype).groupby(level=column_patient).sum(), axis=0)
                   .median()
            for ctype in cat_list
        })
    else:  # parameter_estimation == 'mean'
        def _pooled_stats(subset):
            # Per-gene sums flattened to 1-D, so that sparse `X` (whose column sum is a
            # 1 x genes matrix) and dense `X` (whose column sum is already 1-D) are
            # handled identically.
            gene_sums = np.asarray(subset.X.sum(axis=0)).ravel()
            total = gene_sums.sum()
            # (counts per cell, per-gene share of the counts)
            return total / subset.shape[0], gene_sums / total

        counts_per_ctype, counts_per_ctype_complement = {}, {}
        count_proportions_per_ctype, count_proportions_per_ctype_complement = {}, {}
        for ctype in cat_list:
            # The subsets are only read, so no copy of the data is made.
            n_A, norm_sums_A = _pooled_stats(adata[adata.obs[column_ctype] == ctype])
            n_nA, norm_sums_nA = _pooled_stats(adata[adata.obs[column_ctype] != ctype])
            counts_per_ctype[ctype] = n_A
            counts_per_ctype_complement[ctype] = n_nA
            count_proportions_per_ctype[ctype] = norm_sums_A
            count_proportions_per_ctype_complement[ctype] = norm_sums_nA
        counts_per_ctype = pd.Series(counts_per_ctype)
        counts_per_ctype_complement = pd.Series(counts_per_ctype_complement)
        count_proportions_per_ctype = pd.DataFrame(count_proportions_per_ctype, index=adata.var.index)
        count_proportions_per_ctype_complement = pd.DataFrame(count_proportions_per_ctype_complement, index=adata.var.index)
    return counts_per_ctype, counts_per_ctype_complement, count_proportions_per_ctype, count_proportions_per_ctype_complement


def multi_cat_correction(computed_scores):
    """
    Rescale marker scores so that, for each gene and k, they sum to 1 over the cell types.

    Args:
        computed_scores (dict): Cell type -> Pandas DataFrame of marker scores.

    Returns:
        dict: Same keys as `computed_scores`; each value is that cell type's scores divided
              element-wise by the sum of the scores over all cell types, with NaN results
              (e.g. 0 / 0) replaced by 0.
    """
    # Element-wise total over every cell type (stays the integer 0 for an empty input).
    total = 0
    for table in computed_scores.values():
        total = total + table
    normalised = {}
    for ctype, table in computed_scores.items():
        share = table / total
        normalised[ctype] = share.fillna(0)
    return normalised


def aggregation(processed_scores, aggregation_method='mean'):
    """
    Collapse each cell type's per-k scores into a single score per gene.

    Args:
        processed_scores (dict): Cell type -> Pandas DataFrame of marker scores with one
                                 column per value of k.
        aggregation_method (str): How the columns are combined for each row:
            - 'mean': row-wise mean.
            - 'max': row-wise maximum.

    Returns:
        dict: Same keys as `processed_scores`; each value is the Pandas Series obtained by
              the row-wise aggregation, or an empty DataFrame when the input table is empty.

    Raises:
        ValueError: If `aggregation_method` is neither 'mean' nor 'max'.
    """
    supported = {'mean', 'max'}
    if aggregation_method not in supported:
        raise ValueError(f"Unsupported aggregation method: {aggregation_method}")
    aggregated_scores = {}
    for ctype, scores in processed_scores.items():
        if scores.empty:
            # Nothing to aggregate for this cell type.
            aggregated_scores[ctype] = pd.DataFrame([])
        else:
            # The method name doubles as the DataFrame reduction to call (mean / max).
            aggregated_scores[ctype] = getattr(scores, aggregation_method)(axis=1)
    return aggregated_scores



def filtering(aggregated_scores, column_patient, column_ctype, max_genes_kept=100, min_pat=3, min_read=10, min_score=0.5):
    """
    Filters the sorted gene score tables of each cell type on score, patient support and read support.

    The expression data is read from the module-level `adata` object; it is not a
    parameter of this function, so `adata` must be defined before calling it.

    Args:
        aggregated_scores (dict): Dictionary where keys are cell types (ctypes) and values are dataframes of gene scores.
                                  Non-empty dataframes must contain a column 'aggregated' with aggregated scores.
        column_patient (str): Column in `adata.obs` specifying patient identifiers.
        column_ctype (str): Column in `adata.obs` specifying cell types.
        max_genes_kept (int): Maximum number of candidate marker genes to retain per cell type after filtering.
                              Rows are kept in their existing order, so the input is expected to be sorted already.
        min_pat (int): Minimum value a gene must reach in the per-cell-type patient count table.
        min_read (int): Minimum value a gene must reach in the per-cell-type aggregate returned by ADPBulk
                        (presumably summed reads; confirm against the ADPBulk default method).
        min_score (float): Exclusive lower bound on the 'aggregated' score. Defaults to 0.5.

    Returns:
        dict: A dictionary where keys are the cell types from `aggregated_scores`, and
              values are Pandas DataFrames containing scores of only the retained genes.
              Cell types whose input table is empty are passed through unchanged.
    """
    # no. patients
    patient_agg = ADPBulk(adata, [column_patient, column_ctype], name_delim="--", group_delim="::")
    patient_matrix = patient_agg.fit_transform()
    # Split the combined sample names on the name delimiter into a MultiIndex.
    patient_matrix.index = patient_matrix.index.get_level_values(0).str.split('--', expand=True)
    # NOTE(review): grouping on level 0 only yields per-cell-type patient counts if the
    # cell-type label is the first component of the ADPBulk sample name; confirm the
    # component order produced for groupby=[column_patient, column_ctype].
    ctype_n_patients = (patient_matrix > 0).groupby(level=0).sum()
    # Keep only the value part of each "<group>::<value>" label.
    ctype_n_patients.index = ctype_n_patients.index.str.split('::').str[-1]
    # no. reads
    reads_agg = ADPBulk(adata, column_ctype, group_delim="::")
    ctype_n_reads = reads_agg.fit_transform()
    ctype_n_reads.index = ctype_n_reads.index.str.split('::').str[-1]
    # Filter scores
    filtered_scores = {}
    for ctype, scores in aggregated_scores.items():
        if scores.empty:
            # Cell types without markers arrive as empty frames that have no
            # 'aggregated' column; pass them through instead of raising KeyError.
            filtered_scores[ctype] = scores
            continue
        # Combine all filtering conditions
        valid_genes = (
            (scores['aggregated'] > min_score) &
            (ctype_n_patients.loc[ctype, scores.index] >= min_pat) &
            (ctype_n_reads.loc[ctype, scores.index] >= min_read)
        )
        # Slicing past the end is safe, so no explicit length check is needed.
        filtered_scores[ctype] = scores[valid_genes].iloc[:max_genes_kept]
    return filtered_scores


def score_sorting(aggregated_scores, computed_scores, expr_proportions):
    """
    Builds, for every category, a score table ranked by aggregated score and expression proportion.

    Args:
        aggregated_scores (dict): Maps each category to a Pandas Series of aggregated scores
                                  (or an empty object when the category has no markers).
        computed_scores (dict): Maps each category to a Pandas DataFrame of computed scores.
        expr_proportions (dict): Maps each category to a Pandas Series of expression proportions.

    Returns:
        dict: Maps each category to its computed-score DataFrame extended with the columns
              'aggregated' and 'exp_prop' and sorted by those two columns in descending order.
              Categories whose aggregated scores are empty map to an empty DataFrame.
    """
    ranked_tables = {}
    for ctype, aggregated in aggregated_scores.items():
        if aggregated.empty:
            # No markers for this category: emit the empty placeholder.
            ranked_tables[ctype] = pd.DataFrame([])
            continue
        # Attach the ranking columns to the per-k scores, then rank.
        table = computed_scores[ctype].assign(
            aggregated=aggregated,
            exp_prop=expr_proportions[ctype],
        )
        ranked_tables[ctype] = table.sort_values(by=['aggregated', 'exp_prop'], ascending=False)
    return ranked_tables










# Additional code for testing more stringent filtering on top of already calculated results
def filtering_emp(aggregated_scores, min_pat=3, min_read=10, column_patient=None, column_ctype=None):
    """
    Applies patient-count and read-count filters on top of already computed score tables.

    The expression data is read from the module-level `adata` object.

    Args:
        aggregated_scores (dict): Maps each cell type to a Pandas object of gene scores indexed by gene.
        min_pat (int): Minimum value a gene must reach in the per-cell-type patient count table.
        min_read (int): Minimum value a gene must reach in the per-cell-type aggregate returned by ADPBulk
                        (presumably summed reads; confirm against the ADPBulk default method).
        column_patient (str, optional): Column in `adata.obs` specifying patient identifiers.
                                        Defaults to the notebook-level `patient_column`.
        column_ctype (str, optional): Column in `adata.obs` specifying cell types.
                                      Defaults to the notebook-level `cell_type_column`.

    Returns:
        dict: Same keys as `aggregated_scores`; each value holds only the genes passing both filters.
              Empty inputs are passed through unchanged. The input dictionary is not modified.
    """
    # The notebook defines `patient_column` / `cell_type_column`, not `column_patient` /
    # `column_ctype`, so resolve the column names explicitly instead of relying on
    # globals that do not exist.
    if column_patient is None:
        column_patient = patient_column
    if column_ctype is None:
        column_ctype = cell_type_column
    # no. patients (same delimiters as `filtering`, so labels containing '-' are not split apart)
    patient_agg = ADPBulk(adata, [column_patient, column_ctype], name_delim="--", group_delim="::")
    patient_matrix = patient_agg.fit_transform()
    patient_matrix.index = patient_matrix.index.get_level_values(0).str.split('--', expand=True)
    # NOTE(review): grouping on level 0 only yields per-cell-type patient counts if the
    # cell-type label is the first component of the ADPBulk sample name; confirm.
    ctype_n_patients = (patient_matrix > 0).groupby(level=0).sum()
    ctype_n_patients.index = ctype_n_patients.index.str.split('::').str[-1]
    # no. reads
    reads_agg = ADPBulk(adata, column_ctype, group_delim="::")
    ctype_n_reads = reads_agg.fit_transform()
    ctype_n_reads.index = ctype_n_reads.index.str.split('::').str[-1]
    # Apply the two filters one after the other, skipping tables that are (or become) empty.
    filtered_scores = {}
    for ctype, scores in aggregated_scores.items():
        if len(scores) > 0:
            scores = scores[ctype_n_patients.loc[ctype, scores.index] >= min_pat]
        if len(scores) > 0:
            scores = scores[ctype_n_reads.loc[ctype, scores.index] >= min_read]
        filtered_scores[ctype] = scores
    return filtered_scores

# Boxen plot code
def boxenplot_gene(adata, gene, column, ctype):
    """
    Draws a boxen plot of one gene's expression values across the categories of an obs column.

    Args:
        adata: Annotated data object; `adata[:, gene].X` supplies the values and
               `adata.obs[column]` the grouping.
        gene (str): Name of the gene (variable) to plot.
        column (str): Column of `adata.obs` used for the x-axis categories.
        ctype (str): Cell type name; used only in the plot title.

    Returns:
        None. The plot is drawn on a newly created matplotlib figure.
    """
    expression = adata[:, gene].X
    # Sparse matrices must be densified first; dense arrays have no `toarray`
    # and are used as they are.
    if hasattr(expression, 'toarray'):
        expression = expression.toarray()
    # Long-format table for seaborn: one row per cell.
    data_df = pd.DataFrame({
        'Gene': np.asarray(expression).flatten(),
        'Category': adata.obs[column],
    })
    # Create the boxen plot
    plt.figure(figsize=(10, 6))
    sns.boxenplot(x='Category', y='Gene', data=data_df)
    plt.xticks(rotation=90)  # Category labels are long; rotate them to avoid overlap.
    plt.xlabel('Cell type')
    plt.ylabel('Counts')
    plt.title(f'Top scoring marker for {ctype} cells: {gene}')
    plt.tight_layout()


# Load the dataset. Several functions above read it through the global name `adata`,
# so this variable must keep that name.
# NOTE(review): absolute, machine-specific path.
adata = sc.read_h5ad('/home/croizer/Documents/02_Analysis/02_scSherlock/hao_2021_SCT.h5ad')

# Names of the `adata.obs` columns passed to SCherlock as patient and cell-type columns.
patient_column = "donor"
cell_type_column = "celltype.l3"

# Run the full pipeline.
top_markers = SCherlock(
    adata,
    cell_type_column,
    patient_column,
    k_values=[1, 10, 25],
    scoring='diff',
    aggregation_method='mean',
    parameter_estimation='patient_median',
    max_genes_kept=100,
    min_pat=3,
    min_read=10,
)

print(top_markers)


In [15]:
top_markers[1]['CD14 Mono']

Unnamed: 0,1,10,25,aggregated,exp_prop
IGSF23,-0.020758,1.000000,1.000000,0.659747,5.076480e-07
LINC02218,-0.023623,1.000000,1.000000,0.658792,6.776120e-07
ST6GALNAC2,-0.040529,1.000000,0.969697,0.643056,1.041262e-06
AC009974.1,-0.059418,1.000000,0.958506,0.633029,2.287647e-06
PPARGC1A,-0.024449,0.833333,1.000000,0.602961,5.161545e-07
...,...,...,...,...,...
AC006547.2,-0.023560,0.030516,0.032070,0.013009,6.295021e-07
AC007216.4,-0.013819,0.014925,0.037487,0.012865,5.316669e-07
CRX,-0.034259,0.028436,0.029025,0.007734,8.524916e-07
SNORC,-0.029168,0.024184,0.022093,0.005703,5.959651e-07


In [14]:
top_markers

({'CD14 Mono': 'IGSF23',
  'CD16 Mono': 'DOC2B',
  'MAIT': 'PRSS35',
  'dnT_2': 'OGN',
  'pDC': 'MYMX',
  'gdT_1': 'TRDV2',
  'Plasma': 'LINC02362',
  'Treg Memory': 'FANK1',
  'cDC2_1': 'CD1B',
  'NK Proliferating': 'SMC1B',
  'HSPC': 'PROM1',
  'cDC2_2': 'CD1E',
  'Platelet': 'LY6G6F',
  'NK_CD56bright': 'ADGRG3',
  'CD4 Proliferating': 'KIF20A',
  'Eryth': 'AHSP',
  'Plasmablast': 'GLDC',
  'cDC1': 'ERICH5',
  'ASDC_mDC': 'HAMP',
  'ILC': 'TTLL10'},
 {'CD14 Mono':                    1        10        25  aggregated      exp_prop
  IGSF23     -0.020758  1.000000  1.000000    0.659747  5.076480e-07
  LINC02218  -0.023623  1.000000  1.000000    0.658792  6.776120e-07
  ST6GALNAC2 -0.040529  1.000000  0.969697    0.643056  1.041262e-06
  AC009974.1 -0.059418  1.000000  0.958506    0.633029  2.287647e-06
  PPARGC1A   -0.024449  0.833333  1.000000    0.602961  5.161545e-07
  ...              ...       ...       ...         ...           ...
  AC006547.2 -0.023560  0.030516  0.032070    0