In [32]:
import numpy as np
import scanpy as sc
import pandas as pd
from functools import partial

## Let's define the core statistic functions

In [None]:
from scipy.stats import chi2
from scipy.stats._mannwhitneyu import _broadcast_concatenate, _rankdata

def van_elteren_core(x, y, axis=0):
    x, y, xy = _broadcast_concatenate(x, y, axis)
    M, N = x.shape[-1], y.shape[-1]

    ranks, t = _rankdata(xy, 'average', return_ties=True)  
    R1 = ranks[..., :M].sum(axis=-1)
    U = R1 - M * (M + 1) / 2 # Mann-Whitney U statistic for X, symmetric to U stat for Y so works the same
    
    mu = M * N / 2  # Mean U
    
    n = M + N
    tie_term = (t**3 - t).sum(axis=-1)
    V = M * N / 12 * ((n + 1) - tie_term / (n * (n - 1))) # Tie corrected variance
    
    U = U - mu # Center U about zero
    U -= np.sign(U) * 0.5 # Continuity correction 

    return U, V, M, N

def van_elteren_test(X, Y, X_strata, Y_strata, return_effect_size=False):
    
    assert X.shape[-1] == Y.shape[-1], 'Same number of features expected.'
    
    unique_strata = np.unique(X_strata)
    n_strata = len(unique_strata)
    
    Us = np.full(shape=(n_strata, X.shape[1]), fill_value=np.nan)
    Vs = np.full(shape=(n_strata, X.shape[1]), fill_value=np.nan)
    Ms = np.full(shape=(n_strata, 1), fill_value=np.nan)
    Ns = np.full(shape=(n_strata, 1), fill_value=np.nan)
    
    for i, stratum in enumerate(unique_strata):
        X_stratum_mask = X_strata == stratum
        Y_stratum_mask = Y_strata == stratum
        
        if not (sum(X_stratum_mask) >= 10 and sum(Y_stratum_mask) >= 10):
            continue
        
        sX = X[X_stratum_mask]
        sY = Y[Y_stratum_mask]

        Us[i], Vs[i], Ms[i], Ns[i] = van_elteren_core(sX, sY)
    
    # Remove skipped strata from formulation
    valid_strata_mask = ~np.isnan(Ms[:,0])
    Us = Us[valid_strata_mask]
    Vs = Vs[valid_strata_mask]
    Ms = Ms[valid_strata_mask]
    Ns = Ns[valid_strata_mask]
    
    w = 1 / (Ms + Ns + 1) # locally best weighting
    # w = 1 / (Ms*Ns) # design free weighting
    with np.errstate(divide='ignore', invalid='ignore'):
        VEs = np.sum(w * Us, axis=0)**2 / np.sum(w**2 * Vs, axis=0) # Van Elteren Statistic
    
    pvalues = chi2.sf(VEs, df=1) # Van Elteren statistic can be approximated by a chi-squared distribution with 1 degree of freedom
    # valid_pvalues = ~np.isnan(pvalues) # If the variance is 0, VEs is nan therefore pvalue is nan
    # pvalues[valid_pvalues] = false_discovery_control(pvalues[valid_pvalues])
    
    if return_effect_size:
        with np.errstate(divide='ignore', invalid='ignore'):
            effect_size = np.sum(w * Us, axis=0) / np.sum(w * (Ms + Ns), axis=0) # Effect Size
        
        return effect_size, pvalues
    else:
        return pvalues

### Lets define the run function

In [34]:
def subsample_strata(adata, strata_key, max_size, min_size=0, random_state=0):
    selected = []
    for stratum in set(adata.obs[strata_key]):
        indices = np.where(adata.obs[strata_key] == stratum)[0]
        if len(indices) > max_size:
            np.random.seed(random_state)
            selected.extend(np.random.choice(indices, max_size).tolist())
        elif len(indices) < min_size:
            pass
        else:
            selected.extend(indices.tolist())
    
    return adata[selected]

In [35]:
pct_expression = lambda X: X.getnnz(axis=0)/X.shape[0]

In [36]:
def run_markers(cluster: sc.AnnData, 
               other: sc.AnnData,
               strata_key: str,
               max_stratum_size: int = 50,
               min_stratum_size: int = 10,
               min_pct_expression: float = 0.1,
               min_specificity: float = 0.3) -> bool:
    
    # print(f'Checking {clusterA.name} and {clusterB.name}')
    n_genes = cluster.shape[1]
    genes = cluster.var_names
    
    # Get cluster AnnData subsets
    subset_A = cluster
    subset_B = other
    
    # Subsample data to reduced strata proportions
    subset_A = subsample_strata(subset_A, strata_key, max_size=max_stratum_size, min_size=min_stratum_size)
    subset_B = subsample_strata(subset_B, strata_key, max_size=max_stratum_size, min_size=min_stratum_size)
    
    # Identify common strata between clusters
    common_strata = set(subset_A.obs[strata_key]).intersection(set(subset_B.obs[strata_key]))
    
    if not common_strata:
        print('No overlapping strata')
        return None
    
    # Update cluster subsets to contain only common strata
    subset_A = subset_A[subset_A.obs[strata_key].isin(common_strata)]
    subset_B = subset_B[subset_B.obs[strata_key].isin(common_strata)]
    
    # Get percentage of cells expressing each gene for both clusters    
    pct_A = pct_expression(subset_A.X)
    pct_B = pct_expression(subset_B.X)
    
    # Create mask for genes that have pct expression > min_pct_expression    
    pct_mask = (pct_A >= min_pct_expression)
    
    # Calculate absolute specificity for all genes between both clusters

    relative_pct_diff = np.abs(pct_A - pct_B)/(pct_A+pct_B+1e-10)
    avgFC = (subset_A.X.mean(axis=0) - subset_B.X.mean(axis=0)).A1
    specificity = relative_pct_diff * np.abs(avgFC)
    
    # Create mask for genes that have specificity > min_specificity   
    specificity_mask = specificity >= min_specificity
    
    # Any gene with higher enough expression and specificity is parsed to statistical test
    valid_gene_indices = np.where(pct_mask & specificity_mask)[0]
    
    print(f'{len(valid_gene_indices)} valid genes found!')

    # Update cluster subsets to contain only valid genes
    subset_A = subset_A[:, valid_gene_indices]
    subset_B = subset_B[:, valid_gene_indices]

    # Convert to dense for statistics
    exp_A = subset_A.X.toarray()
    exp_B = subset_B.X.toarray()

    # Run Van Elteren Test
    fs, ps = van_elteren_test(exp_A, exp_B, subset_A.obs[strata_key], subset_B.obs[strata_key], return_effect_size=True)
    
    effect_size = np.zeros(n_genes)
    effect_size[valid_gene_indices] = fs
    
    pvalues = np.ones(n_genes)
    pvalues[valid_gene_indices] = ps
    
    padj = np.clip(pvalues*len(valid_gene_indices), 0, 1) # Bonferroni false positive correction
    
    stat_res = pd.DataFrame({
        'specificity': specificity,
        'avgFC': avgFC,
        'effect_size': effect_size,
        'padj': padj,
        'pvalue': pvalues,
        'pctA': pct_A,
        'pctB': pct_B,
        'pct_diff': relative_pct_diff
    }, index=genes)
    
    # If all checks pass then return True, clusters are distinct
    # return sum(significant_mask) >= min_marker_genes
    return stat_res

partial_run_markers = partial(run_markers,
                            strata_key='Sample',
                            max_stratum_size=250,
                            min_stratum_size=10,
                            min_pct_expression=0.1,
                            min_specificity=0.1)

### Lets load the data

In [37]:
adata = sc.read_h5ad('../data/raw.h5ad')
adata.obsm['X_embed'] = sc.read_h5ad('../data/embed_only.h5ad').obsm['X_embed']

### Lets log2 normalize the data

In [38]:
sc.pp.normalize_total(adata, target_sum=1e5)
sc.pp.log1p(adata, base=2)

### Lets create some example clusters

In [39]:
sc.pp.neighbors(adata, use_rep='X_embed')

In [40]:
sc.tl.leiden(adata, resolution=1, flavor='igraph', n_iterations=2)

### Lets setup the data for marker gene detection

In [41]:
mask = adata.obs['leiden'] == '1' # boolean mask - true where cell is in cluster '1'
cluster = adata[mask] # AnnData view of cluster only
other = adata[~mask] # AnnData view of all cells no in cluster

### Lets run some markers!

In [43]:
ranking = partial_run_markers(cluster, other)
ranking = ranking[ranking['padj'] < 0.01] # Subset to significant genes
ranking = ranking[ranking['avgFC'] > 0] # Upregulated gene only
ranking.sort_values('specificity', ascending=False) # Sort by specificity

2825 valid genes found!


Unnamed: 0,specificity,avgFC,effect_size,padj,pvalue,pctA,pctB,pct_diff
Gabra6,2.745632,4.027257,24.725745,0.000000e+00,0.000000e+00,0.839900,0.158933,0.681762
Reln,2.577249,4.508045,24.851045,0.000000e+00,0.000000e+00,0.927681,0.252800,0.571700
Svep1,2.511865,2.894591,21.781222,0.000000e+00,0.000000e+00,0.606484,0.042933,0.867779
Grm4,2.440551,3.661247,26.057079,0.000000e+00,0.000000e+00,0.749127,0.149867,0.666590
Cadps2,2.383733,5.059105,27.797805,0.000000e+00,0.000000e+00,0.974065,0.350133,0.471177
...,...,...,...,...,...,...,...,...
St3gal5,0.103440,0.988544,8.422328,2.956989e-40,1.046722e-43,0.561596,0.455200,0.104639
Adcy8,0.101861,0.906037,4.871752,5.644051e-12,1.997894e-15,0.613965,0.489867,0.112425
Cpe,0.101639,1.251732,10.397718,2.209904e-55,7.822669e-59,0.856359,0.727733,0.081198
Tbc1d8,0.101077,0.713104,5.162637,3.649711e-19,1.291933e-22,0.352618,0.265067,0.141742
