In [10]:
import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from cinemaot import sinkhorn_knopp as skp
#from . import utils
from scipy.sparse import issparse
from sklearn.neighbors import NearestNeighbors
import scipy.stats as ss

# In this newer version we use the Python implementation of xicor
# import rpy2.robjects as ro
# import rpy2.robjects.numpy2ri
# import rpy2.robjects.pandas2ri
# from rpy2.robjects.packages import importr
# rpy2.robjects.numpy2ri.activate()
# rpy2.robjects.pandas2ri.activate()


# Instead of projecting the whole count matrix, we use the PCA result of the projected ICA components to stabilize the noise,
# returning an AnnData object.
# Detecting differentially expressed genes: G = A + Z + AZ + e by NB (negative binomial) regression. A significant coefficient on AZ indicates a condition-specific effect.
# False positives may be further excluded by permutation (as in PseudotimeDE).

#import ot

import statsmodels.api as sm
from sklearn.linear_model import LinearRegression

from sklearn.decomposition import FastICA
import sklearn.metrics

In [11]:
def weighted_quantile(values, num, sample_weight=None, 
                      values_sorted=False):
    """
    Estimate weighted quantile for robust estimation of gene expression change given the OT map.

    For every column ``j`` of ``values`` this returns the weighted empirical CDF of that
    column evaluated at ``num[j]``: the total normalized weight of the rows ``i`` with
    ``values[i, j] <= num[j]``. The computation is vectorized over columns.

    Parameters
    ----------
    values : array-like, shape (n_samples, n_features)
        Samples (rows) for each feature (columns).
    num : array-like, shape (n_features,)
        Query point for each column.
    sample_weight : array-like, shape (n_samples,), optional
        Row weights; normalized to sum to one. Uniform when None.
    values_sorted : bool
        Unused; kept only for backward compatibility of the signature.

    Returns
    -------
    numpy.ndarray, shape (n_features,)
        Weighted CDF value of each column at its query point.
    """
    values = np.array(values)
    n_features = values.shape[1]
    if sample_weight is None:
        sample_weight = np.ones(len(values))
    weights = np.asarray(sample_weight, dtype=float)
    weights = weights / np.sum(weights)

    sorter = np.argsort(values, axis=0)
    sorted_values = np.take_along_axis(values, sorter, axis=0)
    # Row weights permuted independently for each column to follow that column's sort order.
    sorted_weights = weights[sorter]
    # Leading zero row: a count of 0 rows <= num maps to a CDF value of 0.
    cdf = np.vstack((np.zeros(n_features), np.cumsum(sorted_weights, axis=0)))
    n_le = np.sum(sorted_values <= np.asarray(num).reshape(1, -1), axis=0)
    # Pick cdf[n_le[j], j] for each column directly (no (n_features, n_features) intermediate).
    return cdf[n_le, np.arange(n_features)]

In [12]:
class Xi:
    """
    Chatterjee's xi correlation of y on x.

    x and y are the data vectors (same length n). Sort the samples by x. Let r_i be the
    number of j with y[j] <= y[i], and let l_i be the number of j with y[j] >= y[i]. Then

        xi = 1 - n * sum_i |r_{i+1} - r_i| / (2 * sum_i l_i * (n - l_i))

    Ties in x are broken uniformly at random with NumPy's global random state
    (``np.random``). Seed it for reproducible results when x has ties.
    The coefficient is undefined (division by zero) when y is constant.

    Reference: S. Chatterjee (2021), "A new coefficient of correlation", JASA.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    @property
    def sample_size(self):
        """Number of samples n (length of x)."""
        return len(self.x)

    def _random_tiebreak_x_ranks(self):
        """Ordinal ranks 1..n of x, ties broken by a fresh random shuffle on every call."""
        x = np.asarray(self.x)
        n = len(x)
        shuffled_positions = np.random.choice(np.arange(n), n, replace=False)
        shuffled_ranks = ss.rankdata(x[shuffled_positions], method="ordinal")
        ranks = np.empty(n, dtype=shuffled_ranks.dtype)
        # Undo the shuffle: the value ranked at slot k came from index shuffled_positions[k].
        ranks[shuffled_positions] = shuffled_ranks
        return ranks

    def _y_rank_max_by_x(self):
        """y_rank_max rearranged in increasing order of x, using a single tie-breaking draw."""
        return self.y_rank_max[np.argsort(self._random_tiebreak_x_ranks())]

    @property
    def x_ordered_rank(self):
        """Rank vector PI of x (values 1..n), ties broken at random, as a list."""
        return list(self._random_tiebreak_x_ranks())

    @property
    def y_rank_max(self):
        """f[i]: number of j such that y[j] <= y[i], divided by n."""
        return ss.rankdata(self.y, method="max") / self.sample_size

    @property
    def g(self):
        """g[i]: number of j such that y[j] >= y[i], divided by n."""
        # count(y) - count(y < y[i]); avoids negating y (unary minus is not defined for
        # boolean arrays).
        n_strictly_below = ss.rankdata(self.y, method="min") - 1
        return (len(self.y) - n_strictly_below) / self.sample_size

    @property
    def x_ordered(self):
        """Indices that sort x, ties broken at random."""
        return np.argsort(self._random_tiebreak_x_ranks())

    @property
    def x_rank_max_ordered(self):
        """f rearranged in increasing order of x (ties in x broken at random), as a list."""
        return list(self._y_rank_max_by_x())

    @property
    def mean_absolute(self):
        """sum_i |f_{i+1} - f_i| / (2n) with f ordered by x: the numerator term of xi."""
        n = self.sample_size
        # A single tie-breaking draw feeds both neighbours. Two separate draws would pair
        # values from different orderings whenever x has ties.
        f_by_x = self._y_rank_max_by_x()
        return np.mean(np.abs(np.diff(f_by_x))) * (n - 1) / (2 * n)

    @property
    def inverse_g_mean(self):
        """mean(g * (1 - g)): the denominator term of xi."""
        g = self.g
        return np.mean(g * (1 - g))

    @property
    def correlation(self):
        """xi correlation"""
        return 1 - self.mean_absolute / self.inverse_g_mean

    @classmethod
    def xi(cls, x, y):
        """Alternate constructor, equivalent to ``Xi(x, y)``."""
        return cls(x, y)

    def pval_asymptotic(self, ties=False, nperm=1000):
        """
        Returns the one-sided asymptotic p value of the correlation (H0: x and y independent).

        Args:
            ties: boolean
                If ties is true, the algorithm assumes that the data has ties
                and employs the more elaborated theory for calculating
                the P-value. Otherwise, it uses the simpler theory. There is
                no harm in setting ties True, even if there are no ties.
            nperm: int
                Unused; kept for API compatibility (no permutation test is run).
        Returns:
            p value
        """
        n = self.sample_size
        xi = self.correlation

        if not ties:
            # No ties: sqrt(n) * xi is compared against a normal with variance 2/5.
            return 1 - ss.norm.cdf(np.sqrt(n) * xi / np.sqrt(2 / 5))

        # Ties: estimate the asymptotic variance from the data. The sorted f values do not
        # depend on the x ordering, so no tie-breaking draw is needed here.
        q = np.sort(self.y_rank_max)
        idx = np.arange(1, n + 1)
        w = 2 * n - 2 * idx + 1
        a = np.mean(w * q * q) / n
        c = np.mean(w * q) / n
        m = (np.cumsum(q) + (n - idx) * q) / n
        b = np.mean(m * m)
        v = (a - 2 * b + np.square(c)) / np.square(self.inverse_g_mean)

        return 1 - ss.norm.cdf(np.sqrt(n) * xi / np.sqrt(v))

In [13]:
def _estimate_ica_dim(adata):
    """Choose the number of ICA components for cinemaot_unweighted when ``dim`` is None.

    Counts the singular values of a Sinkhorn-scaled transform of ``adata.X`` that exceed
    sqrt(n_cells) + sqrt(n_genes), capped at the number of columns of ``adata.obsm['X_pca']``.
    """
    # Densify first: for scipy sparse *matrix* objects `*` is a matrix product and adding a
    # nonzero scalar is unsupported, so the elementwise transform below needs a dense array.
    data = adata.X.toarray() if issparse(adata.X) else np.asarray(adata.X)
    sk = skp.SinkhornKnopp()
    c = 0.5
    vm = (1e-3 + data + c * data * data) / (1 + c)
    sk.fit(vm)  # only the fitted scalings sk._D1 / sk._D2 are used below
    wm = np.dot(np.dot(np.sqrt(sk._D1), vm), np.sqrt(sk._D2))
    s = np.linalg.svd(wm, compute_uv=False)
    n_above = int(np.sum(s > np.sqrt(data.shape[0]) + np.sqrt(data.shape[1])))
    # Builtin min: np.min(a, b) would interpret b as an axis argument.
    return min(n_above, adata.obsm['X_pca'].shape[1])


def cinemaot_unweighted(adata,obs_label,ref_label,expr_label,dim=20,thres=0.15,smoothness=1e-4,eps=1e-3,mode='parametric',marker=None,preweight_label=None):
    """
    Match control and experimental cells by optimal transport on confounder components (CINEMA-OT, unweighted).

    Parameters
    ----------
    adata: 'AnnData'
        An anndata object containing the whole gene count matrix and an observation index for treatments.
        It should be preprocessed before input, and adata.obsm['X_pca'] must be present.
    obs_label: 'str'
        A string for indicating the treatment column name in adata.obs.
    ref_label: 'str'
        A string for indicating the control group in adata.obs.values.
    expr_label: 'str'
        A string for indicating the experiment group in adata.obs.values.
    dim: 'int' or None
        The number of independent components. If None, it is estimated from the spectrum of adata.X.
    thres: 'float'
        The threshold on the Chatterjee coefficient for confounder separation: independent components
        whose xi (of the control indicator on the component) is below thres are used as confounders.
    smoothness: 'float'
        The parameter for setting the smoothness of entropy-regularized optimal transport. Should be set as a small value above zero!
    eps: 'float'
        The parameter for the stop condition of OT convergence.
    mode: 'str'
        'parametric': for each control cell, its expression minus the OT-weighted mean expression of the
        experimental cells (all genes).
        'non_parametric': for each control cell and each gene in `marker`, the OT-weighted fraction of
        experimental cells whose expression is <= that control cell's expression (see weighted_quantile).
    marker: list of str, optional
        Gene names (matched against adata.var_names). Required when mode='non_parametric'.
    preweight_label: 'str', optional
        Column of adata.obs (e.g. cell type). If given, each experimental group's total OT mass is made
        proportional to that group's cell count in the control condition; otherwise experimental cells
        are weighted uniformly.

    Return
    ----------
    cf: 'numpy.ndarray'
        Confounder components, of shape (n_cells, n_confounders).
    ot: 'numpy.ndarray'
        Transport map across conditions, of shape (n_refcells, n_exprcells).
    TE: 'AnnData'
        One observation per control cell (obs copied from adata). TE.X holds the differential matrix
        described under `mode`; TE.obsm['X_embedding'] holds the same difference in component space.

    Raises
    ------
    ValueError
        If mode is unsupported or no confounder component passes `thres`.
    TypeError
        If mode='non_parametric' and marker is None.
    """
    # Validate cheap arguments before the expensive ICA / OT steps.
    if mode not in ('parametric', 'non_parametric'):
        raise ValueError("We do not support other methods for DE now.")
    if mode == 'non_parametric' and marker is None:
        raise TypeError("marker must be a list of gene names when mode='non_parametric'.")

    conditions = adata.obs[obs_label]
    ref_mask = np.asarray(conditions == ref_label)    # control cells
    expr_mask = np.asarray(conditions == expr_label)  # experimental cells

    if dim is None:
        dim = _estimate_ica_dim(adata)

    transformer = FastICA(n_components=dim, random_state=0, whiten="arbitrary-variance")
    X_transformed = transformer.fit_transform(adata.obsm['X_pca'][:, :dim])

    # Chatterjee xi of the control indicator on each independent component; components
    # below `thres` are kept as confounders and drive the cell-cell matching.
    groupvec = ref_mask.astype(int)
    xi = np.array([Xi(component, groupvec).correlation for component in X_transformed.T])
    is_confounder = xi < thres
    n_confounders = int(np.sum(is_confounder))
    if n_confounders == 0:
        raise ValueError("No confounder components identified. Please try a higher threshold.")
    cf = X_transformed[:, is_confounder]  # always 2-D, even for a single confounder
    cf1 = cf[expr_mask, :]  # expr
    cf2 = cf[ref_mask, :]   # control

    # Gaussian affinity in confounder space; the bandwidth scales with the number of confounders.
    dis = sklearn.metrics.pairwise_distances(cf1, cf2)
    e = smoothness * n_confounders
    af = np.exp(-dis * dis / e)

    # Marginals: c is uniform over control cells; r is uniform over experimental cells
    # unless preweight_label is given, in which case r is reweighted per group.
    r = np.zeros([cf1.shape[0], 1])
    c = np.full([cf2.shape[0], 1], 1 / cf2.shape[0])
    if preweight_label is None:
        r[:, 0] = 1 / cf1.shape[0]
    else:
        groups = np.asarray(adata.obs[preweight_label])
        expr_groups = groups[expr_mask]
        ref_groups = groups[ref_mask]
        for ct in set(expr_groups.tolist()):
            in_ct = expr_groups == ct
            r[in_ct, 0] = np.sum(ref_groups == ct) / np.sum(in_ct)
        r[:, 0] = r[:, 0] / np.sum(r[:, 0])

    sk = skp.SinkhornKnopp(setr=r, setc=c, epsilon=eps)
    ot_matrix = sk.fit(af).T
    # Row-normalized plan: each control cell's weights over experimental cells sum to 1.
    transport = ot_matrix / np.sum(ot_matrix, axis=1)[:, None]

    embedding = X_transformed[ref_mask, :] - np.matmul(transport, X_transformed[expr_mask, :])

    ref_X = adata.X[ref_mask, :]
    expr_X = adata.X[expr_mask, :]
    if issparse(adata.X):
        # Densify only the selected rows rather than the whole matrix.
        ref_X = ref_X.toarray()
        expr_X = expr_X.toarray()

    if mode == 'parametric':
        te2 = ref_X - np.matmul(transport, expr_X)
        var = adata.var.copy()
    else:
        gene_mask = np.asarray(adata.var_names.isin(marker))
        ref_X = ref_X[:, gene_mask]
        expr_X = expr_X[:, gene_mask]
        # Float output: CDF values in [0, 1] would be truncated in an integer array.
        te2 = np.zeros(ref_X.shape, dtype=float)
        for i in range(te2.shape[0]):
            te2[i, :] = weighted_quantile(expr_X, ref_X[i, :], sample_weight=ot_matrix[i, :])
        # var must describe exactly the marker columns kept in te2.
        var = adata.var.loc[gene_mask].copy()

    TE = sc.AnnData(te2, obs=adata[ref_mask, :].obs.copy(), var=var)
    TE.obsm['X_embedding'] = embedding
    return cf, ot_matrix, TE

In [26]:
# Load the preprocessed, integrated dataset.
INTEGRATED_H5AD = 'data/Integrated_subset.h5ad'
adata = sc.read_h5ad(INTEGRATED_H5AD)

In [27]:
# Restrict to the two conditions being compared. `.copy()` gives an independent AnnData,
# so sc.pp.pca stores X_pca on it directly. Without it, PCA implicitly turns a view into
# an actual object (the source of the "adata.obsm['X_pca'] = X_pca" warning).
adata_ = adata[adata.obs["perturbation"].isin(["No stimulation", "IFNb"])].copy()
sc.pp.pca(adata_)
# Xi breaks ties at random with NumPy's global RNG; seed it so reruns are identical.
np.random.seed(0)
cf, ot, de = cinemaot_unweighted(
    adata_,
    obs_label="perturbation",
    ref_label="No stimulation",
    expr_label="IFNb",
    mode="parametric",
    thres=0.5,
    smoothness=1e-5,
    eps=1e-3,
    preweight_label="cell_type0528",
)

  adata.obsm['X_pca'] = X_pca


In [28]:
# Transport plan: one row per control cell, one column per IFNb cell.
ot

array([[1.04288645e-17, 4.39323647e-14, 8.65123580e-15, ...,
        2.71220876e-15, 8.25183199e-19, 1.14311443e-12],
       [1.19102212e-32, 4.26832062e-22, 1.02291188e-23, ...,
        3.00446440e-23, 2.77416264e-22, 5.05143041e-28],
       [3.02779461e-11, 2.92344022e-08, 2.91644610e-08, ...,
        6.01648001e-15, 4.85040164e-08, 4.25144831e-11],
       ...,
       [2.11124543e-23, 2.35593573e-22, 6.76845508e-23, ...,
        5.14545000e-26, 8.24019216e-21, 5.89943866e-16],
       [1.40929510e-12, 1.24397978e-14, 1.86974088e-09, ...,
        9.03585371e-18, 3.75657794e-14, 1.29347429e-17],
       [1.47324371e-15, 4.47141112e-11, 4.68498025e-17, ...,
        3.43331379e-15, 8.08898688e-17, 1.00301695e-16]])

In [24]:
# Observation table of the control ("No stimulation") cells in the two-condition subset.
adata_.obs.loc[adata_.obs['perturbation'] == 'No stimulation']

Unnamed: 0,perturbation,n_genes,n_genes_by_counts,total_counts,total_counts_mt,pct_counts_mt,batch,leiden,cell_type0528
AAACCTGAGGCTAGCA-1,No stimulation,938,938,1921.0,78.0,4.060385,H3D2,7,CD8 T
AAACCTGAGTGGAGTC-1,No stimulation,4288,4287,16910.0,276.0,1.632170,H3D2,18,Monocyte
AAACCTGCACATCTTT-1,No stimulation,691,691,1199.0,40.0,3.336113,H3D2,14,CD4 T
AAACCTGCATTATCTC-1,No stimulation,1072,1072,2590.0,77.0,2.972973,H3D2,0,CD4 T
AAACGGGAGACTAGGC-1,No stimulation,781,781,1345.0,27.0,2.007435,H3D2,2,CD8 T
...,...,...,...,...,...,...,...,...,...
TTTGTCACACAACTGT-1,No stimulation,863,863,1654.0,46.0,2.781137,H3D2,7,CD8 T
TTTGTCACACGAAATA-1,No stimulation,3239,3238,10030.0,259.0,2.582253,H3D2,12,Monocyte
TTTGTCAGTCTGATCA-1,No stimulation,1426,1425,2727.0,70.0,2.566923,H3D2,22,NK
TTTGTCATCCTATTCA-1,No stimulation,486,486,646.0,26.0,4.024768,H3D2,23,CD4 T


In [32]:
# Observation table of the IFNb-stimulated cells in the two-condition subset.
adata_.obs.loc[adata_.obs['perturbation'] == 'IFNb']

Unnamed: 0,perturbation,n_genes,n_genes_by_counts,total_counts,total_counts_mt,pct_counts_mt,batch,leiden,cell_type0528
AAACCTGAGAATTGTG-1,IFNb,2124,2124,5351.0,66.0,1.233414,H3D2,10,Monocyte
AAACCTGAGATGCCAG-1,IFNb,1650,1650,4109.0,71.0,1.727914,H3D2,0,CD4 T
AAACCTGCACTGCCAG-1,IFNb,720,720,1104.0,15.0,1.358696,H3D2,5,NK
AAACCTGGTACCGGCT-1,IFNb,897,897,1746.0,39.0,2.233677,H3D2,3,CD4 T
AAACCTGGTCAGAAGC-1,IFNb,1203,1203,2351.0,32.0,1.361123,H3D2,3,CD4 T
...,...,...,...,...,...,...,...,...,...
TTTGGTTTCGTGGACC-1,IFNb,797,797,1322.0,13.0,0.983359,H3D2,20,CD4 T
TTTGGTTTCTTGTTTG-1,IFNb,864,864,1473.0,28.0,1.900882,H3D2,1,CD4 T
TTTGTCAAGCGCTTAT-1,IFNb,2964,2964,7918.0,118.0,1.490275,H3D2,10,Monocyte
TTTGTCAGTAAGAGGA-1,IFNb,1237,1237,2447.0,77.0,3.146710,H3D2,3,CD4 T


In [33]:
# The CINEMA-OT output has one row per control cell, in the same order as in `adata_`.
assert de.obs_names.equals(
    adata_.obs_names[(adata_.obs['perturbation'] == 'No stimulation').to_numpy()]
)
de.obs

Unnamed: 0,perturbation,n_genes,n_genes_by_counts,total_counts,total_counts_mt,pct_counts_mt,batch,leiden,cell_type0528
AAACCTGAGGCTAGCA-1,No stimulation,938,938,1921.0,78.0,4.060385,H3D2,7,CD8 T
AAACCTGAGTGGAGTC-1,No stimulation,4288,4287,16910.0,276.0,1.632170,H3D2,18,Monocyte
AAACCTGCACATCTTT-1,No stimulation,691,691,1199.0,40.0,3.336113,H3D2,14,CD4 T
AAACCTGCATTATCTC-1,No stimulation,1072,1072,2590.0,77.0,2.972973,H3D2,0,CD4 T
AAACGGGAGACTAGGC-1,No stimulation,781,781,1345.0,27.0,2.007435,H3D2,2,CD8 T
...,...,...,...,...,...,...,...,...,...
TTTGTCACACAACTGT-1,No stimulation,863,863,1654.0,46.0,2.781137,H3D2,7,CD8 T
TTTGTCACACGAAATA-1,No stimulation,3239,3238,10030.0,259.0,2.582253,H3D2,12,Monocyte
TTTGTCAGTCTGATCA-1,No stimulation,1426,1425,2727.0,70.0,2.566923,H3D2,22,NK
TTTGTCATCCTATTCA-1,No stimulation,486,486,646.0,26.0,4.024768,H3D2,23,CD4 T


In [52]:
# Control cells of the full dataset (an AnnData view; not copied).
is_control = adata.obs['perturbation'] == 'No stimulation'
control = adata[is_control]

In [54]:
# Expression matrix of the control cells (an ArrayView into `adata`) and its shape.
control.X, control.X.shape

(ArrayView([[-1.5649903e-01,  4.6627760e+00, -1.3260117e-01, ...,
             -1.9857515e-01, -1.6133374e-01, -2.3858619e-01],
            [ 2.5749395e-02,  6.0544107e-02,  7.2225355e-02, ...,
             -2.3227680e-02,  9.0107147e-04,  2.1258122e-01],
            [-1.6670781e-01, -2.8013307e-01, -1.5954842e-01, ...,
             -2.1336092e-01, -1.6441849e-01, -2.6941788e-01],
            ...,
            [-1.4901598e-01, -2.4170154e-01, -1.4926013e-01, ...,
             -1.9941691e-01, -1.4494702e-01,  4.4739728e+00],
            [-1.7241320e-01, -2.9764402e-01, -1.5494134e-01, ...,
             -2.1531560e-01, -1.7377223e-01, -2.7958304e-01],
            [-1.5589504e-01, -2.5722858e-01, -1.5235630e-01, ...,
             -2.0454866e-01, -1.5286881e-01, -2.4443221e-01]],
           dtype=float32),
 (2268, 773))

In [46]:
# Full expression matrix and its shape, for comparison with the control subset.
adata.X, adata.X.shape

(array([[-0.11865562, -0.17559336, -0.13184807, ..., -0.17556609,
         -0.11143828, -0.1599121 ],
        [-0.13325824, -0.20568384, -0.14286487, ..., -0.18788525,
         -0.12653024, -0.19412361],
        [-0.15649903,  4.662776  , -0.13260117, ..., -0.19857515,
         -0.16133374, -0.23858619],
        ...,
        [-0.12380419, -0.19445087, -0.12295869, ..., -0.17581214,
         -0.12171488, -0.1673852 ],
        [-0.1724132 , -0.29764402, -0.15494134, ..., -0.2153156 ,
         -0.17377223, -0.27958304],
        [-0.15589504, -0.25722858, -0.1523563 , ..., -0.20454866,
         -0.1528688 , -0.24443221]], dtype=float32),
 (9209, 773))

In [47]:
# Summary of the full, unsubsetted AnnData object.
adata

AnnData object with n_obs × n_vars = 9209 × 773
    obs: 'perturbation', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'batch', 'leiden', 'cell_type0528'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean', 'std'
    uns: 'batch_colors', 'cell_type0528_colors', 'hvg', 'leiden', 'leiden_colors', 'log1p', 'neighbors', 'pca', 'perturbation_colors', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'

In [35]:
# CINEMA-OT differential matrix: one row per control cell, one column per gene.
de.X, de.X.shape

(array([[-5.36350218e-02,  4.79445373e+00, -2.50692672e+00, ...,
         -1.75557896e-04, -8.15729347e-03,  3.74461862e-02],
        [ 2.13709679e-02, -1.05344196e-02, -4.70492095e-02, ...,
         -2.19506462e-01, -3.78152651e-02,  4.16946664e-02],
        [-6.19409904e-02, -1.91294310e-01, -8.45040081e-03, ...,
         -7.44984073e-02, -8.68032782e-02, -5.29177999e-02],
        ...,
        [-3.89115187e-01, -1.30004209e+00, -1.24121167e+00, ...,
         -2.06752127e-01, -4.69588964e-01,  2.47418985e+00],
        [-3.11096077e-03,  6.06325101e-04, -1.24358385e-02, ...,
         -3.39462535e-01,  3.99766515e-04, -9.66897429e-03],
        [-3.42868360e-02, -2.52850823e-01, -7.84349447e-01, ...,
         -2.72645833e-02, -1.42368016e-02,  1.18387235e-02]]),
 (2268, 773))