In [1]:
import scvelo as scv
import dynamo as dyn
import numpy as np
from anndata import AnnData
# import loompy
from matplotlib import pyplot as plt
from sklearn.preprocessing import StandardScaler,MinMaxScaler
from scipy.cluster.hierarchy import fcluster,leaders
from sklearn.decomposition import PCA
from scipy.linalg import inv
from scipy.cluster.hierarchy import dendrogram, linkage
from sklearn.mixture import GaussianMixture
from RKHS import SparseVFC
from RKHS import Jacobian_rkhs_gaussian
from scipy.stats import multivariate_normal
from scipy.sparse import csr_matrix
import pandas as pd

In [2]:
import argparse
import random

import scipy.sparse as sp
import scipy.sparse.csgraph
import sklearn.linear_model as sklm
import sklearn.metrics as skm
import sklearn.model_selection as skms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

# Name prefix for model checkpoints (presumably consumed by ignite's ModelCheckpoint; not referenced elsewhere in this notebook).
CHECKPOINT_PREFIX = 'g2g'

from g2g_model_Fisher import *
from utils import *
# from minepy import MINE
from sklearn.preprocessing import MinMaxScaler

In [3]:
# Report the torch / CUDA environment this notebook runs under.
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())
print(torch.cuda.device_count())
# current_device() raises on CPU-only builds/machines, which would abort
# "Restart & Run All"; only query it when a GPU is actually usable.
if torch.cuda.is_available():
    print(torch.cuda.current_device())

1.10.0
11.3


True

In [4]:
# Input / output directories, relative to the working directory.
data_path, result_path = 'data/', 'results/'

# NOTE(review): data_path is not used here; the dataset is read from the working
# directory. Confirm whether data_path + 'DG_bin.h5ad' was intended.
adata0 = scv.read('DG_bin.h5ad', cache=True)

In [None]:
# Pull the arrays used throughout the notebook out of the raw AnnData object.
gene_arr = adata0.var.index.values
X_pca, X_umap = adata0.obsm['X_pca'], adata0.obsm['X_umap']
# Per-cell pseudotime (alternative: adata0.obs['latent_time'].values).
cell_vpt = adata0.obs['velocity_pseudotime'].values
# 'Ms' layer — presumably scVelo's smoothed spliced abundances (alternative: adata.X.A).
Xs = adata0.layers['Ms']
X_pca.shape

In [None]:
#-------data preprocessing
# Recompute neighbours, PCA and moments on a copy so adata0 stays untouched.
k_nei = 10
adata = adata0.copy()

scv.pp.neighbors(adata, n_neighbors=k_nei)
# NOTE(review): PCA is recomputed *after* the neighbour graph is built, so the graph
# presumably reflects whatever representation adata already carried — confirm intended order.
scv.pp.pca(adata, n_comps=50)
scv.pp.moments(
    adata,
    n_pcs=50,
    n_neighbors=k_nei,
)

In [None]:
# Precomputed per-cell neighbour indices and matching weights (row ci describes cell ci; used by smooth_func).
cell_nei, nei_w = np.load('results/cell_nei.npy'), np.load('results/nei_w.npy')

def smooth_func(X_val, cell_nei=cell_nei, nei_w=nei_w):
    """Replace each cell's value by the weighted sum of its neighbours' values.

    Parameters
    ----------
    X_val : array_like
        Per-cell values; row ``ci`` belongs to cell ``ci``. Usually 1-D; for an
        N-D array each row is smoothed as a whole.
    cell_nei : ndarray, shape (n_cells, k)
        Integer neighbour indices; row ``ci`` lists the neighbours of cell ``ci``.
    nei_w : ndarray, shape (n_cells, k)
        Weights matching ``cell_nei``.

    Returns
    -------
    ndarray
        Same shape as ``X_val``; entry ``ci`` is
        ``sum_j nei_w[ci, j] * X_val[cell_nei[ci, j]]``. The dtype is the promoted
        dtype of the inputs (integer inputs are no longer truncated).
    """
    X_val = np.asarray(X_val)
    n = len(X_val)
    # Only the first n rows of the neighbour tables are used, as in a per-cell loop over X_val.
    return np.einsum('ij...,ij->i...', X_val[cell_nei[:n]], nei_w[:n])

In [None]:
# Load the trained encoder. map_location='cpu' keeps it on the CPU: it is applied to
# CPU tensors further down, and a checkpoint saved from a GPU would otherwise be
# restored onto the GPU and fail with a device mismatch.
# NOTE: torch.load unpickles arbitrary objects — only load trusted checkpoint files.
encoder = torch.load('results/encoder.pt', map_location='cpu')
# encoder = torch.load('results/encoder L=4,K=1.pt')

In [114]:
# Encode every cell and build the per-cell diagonal Fisher metric of its latent Gaussian.
# X=Xs
# X=adata.X.A
# Scale each gene by its mean absolute value before encoding.
gene_scale = np.mean(np.abs(Xs), axis=0)
# All-zero genes would give 0/0 = NaN, which the encoder would spread into every latent coordinate.
gene_scale = np.where(gene_scale > 0, gene_scale, 1)
X = Xs / gene_scale

with torch.no_grad():  # inference only; no autograd graph is needed
    mu, sigma = encoder(torch.tensor(X))
mu_learned = mu.detach().numpy()
sigma_learned = sigma.detach().numpy()
# Column layout: (mu_0 .. mu_{L-1}, sigma_0 .. sigma_{L-1}).
latent_z = np.hstack((mu_learned, sigma_learned))

# Plain int so it can be used directly in shapes, ranges and slices.
L = int(np.load('results/latent_dim.npy'))

# Fisher metric of N(mu, sigma^2) per latent dim: 1/sigma^2 for mu, 2/sigma^2 for sigma,
# stored interleaved as (mu_0, sigma_0, mu_1, sigma_1, ...).
# NOTE(review): latent_z uses block ordering (all mu, then all sigma), not this interleaved
# ordering — confirm which layout consumers of Fisher_g expect.
Fisher_g = np.zeros((X.shape[0], L * 2, L * 2))
latent_dims = np.arange(L)
Fisher_g[:, 2 * latent_dims, 2 * latent_dims] = 1 / sigma_learned[:, :L] ** 2
Fisher_g[:, 2 * latent_dims + 1, 2 * latent_dims + 1] = 2 / sigma_learned[:, :L] ** 2

In [None]:
def _load_result(stem):
    """Load a saved NumPy array from ``stem``, falling back to ``stem + '.npy'``.

    np.save appends '.npy' to file names that lack it, but np.load does not, so an
    array written with np.save('results/crc', ...) lives at 'results/crc.npy'.
    """
    try:
        return np.load(stem)
    except FileNotFoundError:
        return np.load(stem + '.npy')


crc = _load_result('results/crc')
crc_eu = _load_result('results/crc_eu')
# Previously np.save('results/crc_smooth'): no array argument (TypeError), and np.save returns None anyway.
crc_smooth = _load_result('results/crc_smooth')

In [None]:
#--------------eigengene analysis------------ 
# -----here we use the first two PCs of each module as coordinates; compared with using only the PC1 eigengene,
#--------these coordinates give the correct distribution of information velocity

In [None]:
# Gene-module discovery: hierarchically cluster genes by expression correlation.
# NOTE(review): Xs and X are reassigned here (Xs now comes from `adata`, not `adata0`;
# X is now standardized). Re-running the encoder cell after this one feeds it different inputs.
Xs=adata.layers['Ms']# alternative: adata.X.A
Xu=adata.layers['Mu']
# Standardize each gene (column); MinMaxScaler() is the commented-out alternative.
scaler=StandardScaler()# alternative: MinMaxScaler()
X=scaler.fit_transform(Xs)
velo=np.array(adata.layers['velocity'])

# Gene-by-gene Pearson correlation (rowvar=False: columns are the variables).
# NOTE(review): a constant gene yields NaN correlations, and linkage() rejects non-finite distances.
X_corr=np.corrcoef(X, rowvar=False)
plt.imshow(X_corr)
plt.colorbar()
plt.show()

# Distance = 1 - correlation. The strict upper triangle (k=1), read in row-major order,
# is SciPy's condensed distance-vector layout expected by linkage().
Z=linkage((1-X_corr)[np.triu_indices(X_corr.shape[0],k=1)],method='weighted')
# dendrogram() draws on the current axes and returns the leaf order used to reorder genes.
dg=dendrogram(Z)
X_re= X[:,dg['leaves']]# X_re: columns (genes) of X in dendrogram leaf order

X_corr_re=np.corrcoef(X_re, rowvar=False)

# Reordered correlation heatmap: co-expressed gene modules appear as diagonal blocks.
plt.figure(figsize=(9,6))
plt.imshow(X_corr_re, aspect='auto', cmap=plt.cm.coolwarm, interpolation='nearest',origin='lower')

plt.xlabel('Gene',fontsize=16,fontweight='bold')
plt.ylabel('Gene',fontsize=16,fontweight='bold')

plt.xticks(fontsize=12,fontweight='bold')
plt.yticks(fontsize=12,fontweight='bold')
plt.colorbar()
plt.savefig(result_path+'EG_heatmap.png',dpi=300)
plt.show()

In [None]:
# Cut the gene dendrogram at a fixed cophenetic distance to obtain gene modules.
td = 0.99
T = fcluster(Z, t=td, criterion='distance')
# Alternative: a fixed number of modules via criterion='maxclust'.
T_re = T[dg['leaves']]  # module label of each gene, in dendrogram leaf order

# Label image: entry (r, c) holds the module label of gene max(r, c). This is exactly what
# overwriting column i and then row i with T_re[i], for i = 0..n-1, leaves behind.
gene_pos = np.arange(X_corr_re.shape[0])
X_corr_label = T_re[np.maximum.outer(gene_pos, gene_pos)].astype(X_corr_re.dtype)

plt.imshow(X_corr_label, aspect='auto', cmap=plt.cm.coolwarm, interpolation='nearest', origin='lower')
plt.colorbar()
plt.show()

In [None]:
# eigen gene is the value of first principal component of each community
def eigen_gene(X_re, T_re, n_components=10):
    """Return the PC1 and PC2 loading vectors ("eigengene" weights) of every module.

    Parameters
    ----------
    X_re : ndarray, shape (n_samples, n_features)
        Data matrix whose columns are ordered like ``T_re``.
    T_re : ndarray, shape (n_features,)
        Module label of each column of ``X_re``.
    n_components : int, default 10
        Number of PCs fitted per module, capped at the module's size (a PCA
        cannot have more components than genes). Only the printed
        explained-variance ratios use more than two of them.

    Returns
    -------
    eigen_X_w1, eigen_X_w2 : list of ndarray
        One entry per module, in ``np.unique(T_re)`` order: the PC1 and PC2
        loadings (length = number of genes in that module). A single-gene
        module has no second PC, so its PC2 loading is a zero vector.

    Prints each module's ``explained_variance_ratio_`` as a side effect.
    """
    eigen_X_w1 = []
    eigen_X_w2 = []
    for label in np.unique(T_re):
        X_mod = X_re[:, T_re == label]
        # Modules with fewer than n_components genes would make PCA raise.
        n_comp = min(n_components, *X_mod.shape)
        pca = PCA(n_components=n_comp).fit(X_mod)
        print(pca.explained_variance_ratio_)
        eigen_X_w1.append(pca.components_[0, :])
        if pca.components_.shape[0] > 1:
            eigen_X_w2.append(pca.components_[1, :])
        else:
            eigen_X_w2.append(np.zeros(X_mod.shape[1]))
    return eigen_X_w1, eigen_X_w2

In [None]:
# PC1 / PC2 loadings per module, then the number of modules (one per distinct label).
eigen_X_w1, eigen_X_w2 = eigen_gene(X_re, T_re)
eigen_dim = np.unique(T_re).size
print(eigen_dim)

In [None]:
#---------eigen gene of each single cell----------------
# Project each cell onto every module's PC1 / PC2 loading vector. Module k is the k-th
# entry of np.unique(T_re) — the same order eigen_gene() iterates in — which for
# fcluster's 1..K labels is label k+1.
cell_eigen_X1 = np.zeros((X_re.shape[0], eigen_dim))
cell_eigen_X2 = np.zeros((X_re.shape[0], eigen_dim))
for k, (label, w1, w2) in enumerate(zip(np.unique(T_re), eigen_X_w1, eigen_X_w2)):
    # The gene mask is the same for every cell: build it once per module, not once per cell.
    X_mod = X_re[:, T_re == label]
    cell_eigen_X1[:, k] = X_mod @ w1
    cell_eigen_X2[:, k] = X_mod @ w2

In [None]:
# for ei in range(eigen_dim):
#     plt.scatter(X_pca[:,0],X_pca[:,1],s=10,c=cell_eigen_X2[:,ei],cmap=plt.cm.jet)
#     plt.xlabel('PC1',fontsize=14)
#     plt.ylabel('PC2',fontsize=14)
#     clb=plt.colorbar()
#     clb.ax.set_ylabel('Eigen-gene'+str(ei+1),fontsize=14)
#     plt.savefig(result_path+'DG_eigengene'+str(ei+1)+'.png')
#     plt.show()

In [None]:
# for ci in range(eigen_dim):
#     plt.scatter(cell_vpt,cell_eigen_X1[:,ci],s=10,c=cell_vpt,cmap=plt.cm.jet)
#     plt.xlabel('pseudotime',fontsize=14)
#     plt.ylabel('Eigengene'+str(ci+1),fontsize=14)

#     # clb=plt.colorbar()
#     # clb.ax.set_ylabel('Eigen-gene',fontsize=14)
#     plt.savefig(result_path+'DG_eigengene'+str(ci+1)+'_vpt.png')
#     plt.show()

# scaler=MinMaxScaler()
# norm_cell_eigen_X=scaler.fit_transform(cell_eigen_X)

# plt.scatter(cell_vpt,norm_cell_eigen_X[:,0]-norm_cell_eigen_X[:,1],s=10,cmap=plt.cm.jet)
# plt.xlabel('pseudotime',fontsize=14)
# plt.show()

# Eigengene-1 and eigengene-2 of (up to) the first three modules against pseudotime.
# min() avoids an IndexError when fewer than three modules were found.
n_show = min(3, cell_eigen_X1.shape[1])
for eigen_mat, eigen_ylabel in ((cell_eigen_X1, 'eigengene1'), (cell_eigen_X2, 'eigengene2')):
    for k in range(n_show):
        plt.scatter(cell_vpt, eigen_mat[:, k], label='module ' + str(k + 1))
    plt.xlabel('velocity pseudotime', fontsize=14, weight='bold')
    plt.ylabel(eigen_ylabel, fontsize=14, weight='bold')
    plt.legend()
    plt.show()

In [None]:
# Eigengene coordinates per cell: all PC1 columns, then all PC2 columns.
cell_eigen_X = np.concatenate([cell_eigen_X1, cell_eigen_X2], axis=1)

In [None]:
# Fit an MLP mapping eigengene coordinates -> latent (mu, sigma), then differentiate it to get
# pZ_pEg[i, j, p] = d z_j / d eigengene_p at cell i.
n_latent = int(L)
torch.manual_seed(0)  # reproducible weight initialisation

model = nn.Sequential(
    nn.Linear(cell_eigen_X.shape[1], 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 2 * n_latent),
)

loss_fn = nn.MSELoss()
# torch.optim is imported as `optim`; a bare `SGD` is not explicitly imported in this notebook.
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.01, momentum=0.8)

x_in = torch.tensor(cell_eigen_X.astype(np.float32))
x_out = torch.tensor(latent_z.astype(np.float32))

# Full-batch training.
for epoch in range(500):
    output = model(x_in)
    loss = loss_fn(output, x_out)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print(f'final training MSE: {loss.item():.4g}')

# Jacobian of the fitted map at every cell. Each output row depends only on its own input row
# (Linear / ReLU act row-wise), so the gradient of the column sum z[:, j].sum() w.r.t. the batch
# gives every cell's d z_j / d x in a single backward pass instead of one pass per cell.
x_grad = torch.tensor(cell_eigen_X.astype(np.float32), requires_grad=True)
z_all = model(x_grad)
pZ_pEg = np.zeros([cell_eigen_X.shape[0], 2 * n_latent, cell_eigen_X.shape[1]])
for j in range(2 * n_latent):
    (grad_j,) = torch.autograd.grad(z_all[:, j].sum(), x_grad, retain_graph=True)
    pZ_pEg[:, j, :] = grad_j.numpy()
print(pZ_pEg.shape)

In [None]:
# Pull the latent Fisher metric back to eigengene coordinates:
#   G_l[p, q] = sum_i (1/sigma_i^2) dmu_i/dp dmu_i/dq + (2/sigma_i^2) dsigma_i/dp dsigma_i/dq
# latent_z columns are (mu_0..mu_{L-1}, sigma_0..sigma_{L-1}); pZ_pEg rows follow the same order.
n_latent = int(L)
sigma_z = latent_z[:, n_latent:]
w_mu = 1 / sigma_z ** 2
w_sigma = 2 / sigma_z ** 2
J_mu = pZ_pEg[:, :n_latent, :]
J_sigma = pZ_pEg[:, n_latent:, :]
eigen_gij = (np.einsum('nl,nlp,nlq->npq', w_mu, J_mu, J_mu)
             + np.einsum('nl,nlp,nlq->npq', w_sigma, J_sigma, J_sigma))

plt.imshow(np.mean(eigen_gij, axis=0))
plt.colorbar()
plt.show()

print(np.mean(eigen_gij, axis=0).flatten())

In [None]:
# #-----relation between the eigh
# cs=['r','b','k','g','y']
# for j in range(eigen_dim*2):

#     plt.scatter(cell_vpt,eigen_gij[:,j,j],c=cell_vpt,cmap=plt.cm.jet)
# #         plt.yscale('log')
#     plt.show()

In [None]:
# Project gene velocities onto the module PC1 / PC2 loadings, in the same module order
# (np.unique(T_re)) that eigen_gene() used.
# NOTE(review): eigengene coordinates were computed on StandardScaler-transformed data, whose
# time derivative would be velo / scaler.scale_; raw velocities are projected here — confirm intended.
# NOTE(review): if the velocity layer contains NaN for some genes, they propagate into velo_eigen.
velo_re = velo[:, dg['leaves']]

velo_eigen1 = np.zeros([X.shape[0], eigen_dim])
velo_eigen2 = np.zeros([X.shape[0], eigen_dim])
for k, (label, w1, w2) in enumerate(zip(np.unique(T_re), eigen_X_w1, eigen_X_w2)):
    # One gene mask per module (it does not depend on the cell).
    velo_mod = velo_re[:, T_re == label]
    velo_eigen1[:, k] = velo_mod @ w1
    velo_eigen2[:, k] = velo_mod @ w2

velo_eigen = np.hstack((velo_eigen1, velo_eigen2))

In [None]:
z_velo = []  # NOTE(review): not used elsewhere in this notebook
# Squared information velocity per cell: v_l^T G_l v_l.
zv2 = np.einsum('np,npq,nq->n', velo_eigen, eigen_gij, velo_eigen)
# G_l is a positively weighted sum of outer products (positive semi-definite), so zv2 >= 0
# up to round-off; clip tiny negatives so sqrt never returns NaN.
zv1 = np.sqrt(np.clip(zv2, 0, None))
zv1_smooth = smooth_func(zv1)

In [None]:
# Raw and neighbour-smoothed information velocity on the first two PCs.
for iv_values, fig_name in ((zv1, 'iv_eigen.png'), (zv1_smooth, 'iv_smooth_eigen.png')):
    pcs = adata.obsm['X_pca']
    plt.scatter(pcs[:, 0], pcs[:, 1], s=5, c=iv_values, cmap=plt.cm.jet)
    plt.xlabel('PC1', fontsize=14)
    plt.ylabel('PC2', fontsize=14)
    clb = plt.colorbar()
    clb.ax.set_ylabel('Information velocity', fontsize=14)
    plt.savefig(result_path + fig_name)
    plt.show()

# plt.scatter(X_sigma[:,0],X_sigma[:,1], c=zv1, s=2, alpha=0.8,cmap=plt.cm.jet)
# plt.colorbar()
# plt.show()

In [None]:
# Curvature (left axis) and information velocity (right axis) against pseudotime.
plt.figure(dpi=600)
ln1 = plt.scatter(cell_vpt, crc_smooth, c='blue', label='curvature')
plt.xlabel('velocity pseudotime', fontsize=14, weight='bold')
plt.ylabel('curvature', fontsize=14, weight='bold')
plt.twinx()
ln2 = plt.scatter(cell_vpt, zv1_smooth, c='orange', label='info velo')
plt.ylabel('info velo', fontsize=14, weight='bold')
# Handles from both axes so one legend lists both series.
plt.legend(handles=[ln1, ln2])
plt.savefig(result_path + 'cviv_eigen.png')
plt.show()

In [None]:
# KDE-filtered scatter plus LOWESS trend vs pseudotime, for curvature and information velocity.
# The curvature figure is not saved (its savefig was commented out originally).
trend_panels = (
    (crc_smooth, 'curvature', None),
    (zv1_smooth, 'information velocity', 'vpt_iv_kde_eigen.png'),
)
for y_vals, y_label, fig_name in trend_panels:
    mask, x_ls, y_ls = kde_lowess(cell_vpt, y_vals)
    plt.scatter(cell_vpt[mask], y_vals[mask], s=10, c=cell_vpt[mask], cmap=plt.cm.jet)
    plt.plot(x_ls, y_ls, linewidth=4, color='red')
    plt.xlabel('pseudotime', fontsize=14)
    plt.ylabel(y_label, fontsize=14)
    clb = plt.colorbar()
    clb.ax.set_ylabel('pseudotime', fontsize=14)
    if fig_name is not None:
        plt.savefig(result_path + fig_name)
    plt.show()

In [None]:
# Curvature and information velocity (KDE-filtered, with LOWESS trends) on one figure.
# The legend uses handles created in THIS figure; it previously reused ln1/ln2 from an
# earlier cell (stale artists, and a NameError on a fresh kernel).
# NOTE(review): both series share one y-axis, which is labelled 'curvature' only.
plt.figure(dpi=600)
mask, x_ls, y_ls = kde_lowess(cell_vpt, crc_smooth)
h_curv = plt.scatter(cell_vpt[mask], crc_smooth[mask], s=10, c='blue', label='curvature')
plt.plot(x_ls, y_ls, linewidth=4, color='red')
plt.xlabel('velocity pseudotime', fontsize=14, weight='bold')
plt.ylabel('curvature', fontsize=14, weight='bold')
mask, x_ls, y_ls = kde_lowess(cell_vpt, zv1_smooth)
h_iv = plt.scatter(cell_vpt[mask], zv1_smooth[mask], s=10, c='orange', label='info velo')
plt.plot(x_ls, y_ls, linewidth=4, color='red')
plt.legend(handles=[h_curv, h_iv])
plt.savefig(result_path + 'cviv_kde_eigen.png')
plt.show()

In [None]:
# Information velocity against curvature: raw pair first, then the smoothed pair.
for curv_vals, iv_vals in ((crc_eu, zv1), (crc_smooth, zv1_smooth)):
    plt.scatter(curv_vals, iv_vals, s=10)
    plt.xlabel('curvature', fontsize=14)
    plt.ylabel('information velocity', fontsize=14)
    plt.show()

In [None]:
# fig, axes = plt.subplots(nrows=2,ncols=1,sharex=True)#,figsize=(8,6)) 
# fig.suptitle('Title of this figure')

# #subplot1
# axes1 = axes[0]
# ln1 = axes1.scatter(cell_vpt,smooth_func(orc), color='#1f77b4', ls='--', label='curvature')
# axes11 = axes1.twinx()
# ln2 = axes11.scatter(cell_vpt,smooth_func(zv1), color='#ff7f0e', ls='--', label='info velo')
# axes1.legend(handles=[ln1,ln2],loc=2)

# #subplot2
# axes2 = axes[1]
# axes2.scatter(cell_vpt,cell_eigen_X[:,0], ls='--')
# axes2.scatter(cell_vpt,cell_eigen_X[:,1], ls='--')
# axes2.scatter(cell_vpt,cell_eigen_X[:,2], ls='--')
# axes2.legend()
# axes2.set_xlabel('velovity pseudotime')

# # show the figure
# plt.show()