In [None]:
import os
import random

import numpy as np
import pandas as pd
import scanpy as sc
import torch
import torch.nn as nn
import torch_geometric.data as data
import torch_geometric.nn
from torch_geometric.utils.convert import to_networkx

def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), evaluated with NumPy."""
    exp_neg = np.exp(-x)
    return 1/(1 + exp_neg)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from torch_geometric.nn import TransformerConv

In [None]:
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for deterministic kernels.

    Parameters
    ----------
    seed : int
        Value passed to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    # The variable is PYTHONHASHSEED (the previous spelling dropped an "S" and
    # was silently ignored). Python reads it at interpreter start-up, so this
    # assignment only reaches child processes, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Disable cuDNN autotuning and request deterministic algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
set_seed(0)

# Here we use the heart dataset as an example

In [None]:
# for specific encoder/decoder
# tissue_list = { 
#                "heart":[233, 676, 783, 947,266, 223, 233, 978, 928, 852, 839, 733]}


# Sample identifiers to load for each file-name prefix under ./heart_atlas/.
tissue_list = {
    "scrna_heart": [
        'D4', 'H2', 'H3', 'D6', 'D2',
        'H7', 'D11', 'D3', 'D1', 'D5',
        'H4', 'D7', 'H6', 'H5', 'G19',
    ],
}

# construct graph batch
# based on simulation results
graph_list = []   # one torch_geometric Data object per sample
cor_list = []     # the matrix read from each "<prefix>_<sample>_pvalue.csv"
label_list = []   # the tissue_list key each graph was loaded under

for tissue, sample_ids in tissue_list.items():
    for i in sample_ids:
        print(i)
        pathway_count = f"./heart_atlas/{tissue}_" + i + "_rna_expression" + ".csv"
        pathway_matrix = f"./heart_atlas/{tissue}_" + i + "_pvalue" + ".csv"

        pd_adata_new = pd.read_csv(pathway_count, index_col=0)
        correlation = pd.read_csv(pathway_matrix, index_col=0)
        cor_list.append(correlation)

        print(correlation.shape)
        print(pd_adata_new.shape)
        adata = sc.AnnData(pd_adata_new)

        # Every nonzero entry (row, col) of the matrix becomes a directed edge.
        # np.nonzero is evaluated once and the integer indices go straight to
        # an int64 tensor instead of being routed through float32.
        edges_new = np.stack(np.nonzero(correlation.values))
        graph = data.Data(
            x=torch.FloatTensor(adata.X.copy()),
            edge_index=torch.from_numpy(edges_new).long(),
        )

        # Rows of the expression table are the graph nodes; keep their labels
        # and a "<tissue>__<sample>" tag that later selects per-graph layers.
        graph.gene_list = pd_adata_new.index
        graph.show_index = tissue +"__" + str(i)

        graph_list.append(graph)
        label_list.append(tissue)

In [None]:
# with open("heart_rna_graph_list", "wb") as fp:
#     pickle.dump(graph_list, fp)
# with open("heart_rna_cor_list", "wb") as fp:
#     pickle.dump(cor_list, fp)
# with open("heart_rna_label_list", "wb") as fp:
#     pickle.dump(label_list, fp)

In [None]:
class MLPEncoder_Multiinput(torch.nn.Module):
    """Graph-specific input projection.

    Holds one ``nn.Linear`` per entry of ``graph_list``, keyed by the graph's
    ``show_index``. Each layer maps that graph's feature width
    (``graph.x.shape[1]``) to ``out_channels * 4`` and is followed by Mish.

    ``label_list`` and the ``edge_index`` argument of ``forward`` are accepted
    but not used.
    """

    def __init__(self, out_channels, graph_list, label_list):
        super().__init__()
        self.activ = nn.Mish()

        hidden_width = out_channels * 4
        self.convl1 = nn.ModuleDict({
            graph.show_index: nn.Linear(graph.x.shape[1], hidden_width)
            for graph in graph_list
        })

    def forward(self, x, edge_index, show_index):
        """Apply the layer registered under ``show_index`` to ``x``, then Mish."""
        projection = self.convl1[show_index]
        return self.activ(projection(x))

In [None]:
class MLPEncoder_Commoninput(torch.nn.Module):
    """Label-shared encoder tail: two linear layers per distinct label.

    For every distinct value in ``label_list`` the module owns
    ``Linear(out_channels*4 -> out_channels*2)`` and
    ``Linear(out_channels*2 -> out_channels)``. Graphs select their pair by
    the part of ``show_index`` before the first ``"__"``.

    ``graph_list`` and the ``edge_index`` argument of ``forward`` are accepted
    but not used.
    """

    def __init__(self, out_channels, graph_list, label_list):
        super(MLPEncoder_Commoninput, self).__init__()
        self.activ = nn.Mish()

        conv_dict_l2 = {}
        conv_dict_l3 = {}
        # Iterate labels in sorted order. Plain set iteration order for strings
        # can differ between interpreter runs, which changed the order in which
        # layers drew their initial weights and made seeded runs irreproducible
        # as soon as there was more than one label.
        for tissue in sorted(set(label_list)):
            conv_dict_l2[tissue] = nn.Linear(out_channels*4, out_channels*2)
            conv_dict_l3[tissue] = nn.Linear(out_channels*2, out_channels)
        self.convl2 = nn.ModuleDict(conv_dict_l2)
        self.convl3 = nn.ModuleDict(conv_dict_l3)

    @staticmethod
    def _tissue_key(show_index):
        """Return the label part of a ``"<label>__<sample>"`` identifier."""
        return show_index.split('__')[0]

    def get_weight(self, show_index):
        """Return ``(state_dict, state_dict)`` of the two layers used for ``show_index``."""
        key = self._tissue_key(show_index)
        return self.convl2[key].state_dict(), self.convl3[key].state_dict()

    def forward(self, x, edge_index, show_index):
        """Apply Linear -> Mish -> Linear using the layers selected by ``show_index``."""
        key = self._tissue_key(show_index)
        x = self.convl2[key](x)
        x = self.activ(x)
        return self.convl3[key](x)
    

In [None]:
class MLP_edge_Decoder(torch.nn.Module):
    """Per-graph decoder: a three-layer Mish MLP followed by a sigmoid.

    One ``nn.Sequential`` (Linear, Mish, Linear, Mish, Linear) is created for
    every entry of ``graph_list`` and stored under that graph's ``show_index``.
    The first layer maps ``in_channels -> out_channels``; the remaining two
    keep the width at ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, graph_list):
        super().__init__()

        def build_mlp():
            # Layer order matters: it fixes the state-dict indices 0, 2 and 4.
            return torch.nn.Sequential(
                nn.Linear(in_channels, out_channels),
                nn.Mish(),
                nn.Linear(out_channels, out_channels),
                nn.Mish(),
                nn.Linear(out_channels, out_channels),
            )

        self.MLP = nn.ModuleDict(
            {graph.show_index: build_mlp() for graph in graph_list}
        )

    def forward(self, x, show_index):
        """Run ``x`` through the MLP stored for ``show_index`` and squash to (0, 1)."""
        logits = self.MLP[show_index](x)
        return torch.sigmoid(logits)

In [None]:
gene_encoder_is = MLPEncoder_Multiinput(32, graph_list, label_list).to(device)
gene_encoder_com = MLPEncoder_Commoninput(32, graph_list, label_list).to(device)

In [None]:
# The training loop feeds the decoder z @ z.T (nodes x nodes) and compares its
# output with the matching matrix in cor_list, so both widths have to equal
# the number of nodes per graph -- assumed to be 1000 here; TODO confirm
# against the loaded data.
gene_decoder = MLP_edge_Decoder(1000,1000,graph_list).to(device)

In [None]:
optimizer_enc_is = torch.optim.Adam(gene_encoder_is.parameters(), lr=1e-4)
optimizer_enc_com = torch.optim.Adam(gene_encoder_com.parameters(), lr=1e-4)

In [None]:
optimizer_enc_com

In [None]:
optimizer_dec2 = torch.optim.Adam(gene_decoder.parameters(), lr=1e-3)

In [None]:
loss_f = nn.BCELoss()

In [None]:
Z = np.load("graph_sim_heartsctransform_new.npy")

In [None]:
# The reconstruction targets never change between epochs, so convert them and
# move them to the device once instead of on every step of every epoch.
edge_targets = [torch.FloatTensor(cor.values).to(device) for cor in cor_list]

for epoch in range(2000):

    # One optimisation step per graph; the three optimizers cover the
    # graph-specific encoder, the shared encoder and the decoder.
    for i in range(0,len(graph_list)):

        optimizer_enc_is.zero_grad(set_to_none=True)
        optimizer_enc_com.zero_grad(set_to_none=True)
        optimizer_dec2.zero_grad(set_to_none=True)

        graph = graph_list[i].to(device)

        x = graph.x
        train_pos_edge_index = graph.edge_index.long()

        x = gene_encoder_is(x, train_pos_edge_index, graph.show_index)
        z = gene_encoder_com(x, train_pos_edge_index, graph.show_index)

        edge_adj = edge_targets[i]

        # Inner-product similarity between node embeddings, refined by the
        # per-graph decoder into values in (0, 1) for the BCE loss.
        adj = torch.matmul(z, z.t())
        edge_reconstruct = gene_decoder(adj, graph.show_index)

        loss = loss_f(edge_reconstruct.flatten(), edge_adj.flatten())

        if epoch % 200 ==0:
            print(loss)

        loss.backward()
        optimizer_enc_is.step()
        optimizer_enc_com.step()
        optimizer_dec2.step()
    print("epoch finish")

In [None]:
emb_list = []
gene_list = []
tissue_list = []

In [None]:
graph.show_index

In [None]:
# Inference only: push every graph through both encoder stages and collect
# the embeddings together with per-node gene and graph labels.
with torch.no_grad():
    for graph in graph_list:
        graph = graph.to(device)
        x = graph.x
        train_pos_edge_index = graph.edge_index.long()

        x = gene_encoder_is(x, train_pos_edge_index, graph.show_index)
        z = gene_encoder_com(x, train_pos_edge_index, graph.show_index)

        emb_list.append(z.cpu().numpy())
        gene_list.append(graph.gene_list)
        # One copy of the graph identifier per node.
        tissue_list.append([graph.show_index] * len(x))

In [None]:
adata = sc.AnnData(np.concatenate(emb_list))

In [None]:
adata

In [None]:
adata.obs['gene'] = np.concatenate(gene_list)
adata.obs['tissue'] = np.concatenate(tissue_list)

In [None]:
adata.obs['tissue']

In [None]:
sc.pp.neighbors(adata, use_rep='X')
sc.tl.umap(adata)

In [None]:
sc.pl.umap(adata, color='tissue')

In [None]:
sc.tl.leiden(adata)

In [None]:
sc.pl.umap(adata, color='leiden')

In [None]:
adata

In [None]:
adata.write_h5ad("heart_global/heart_umi_sharedAutoencoder.h5ad")

In [None]:
class GCNEncoder_Multiinput(torch.nn.Module):
    """Graph-specific first message-passing layer.

    For every entry of ``graph_list`` a ``TransformerConv`` (4 heads,
    ``out_channels`` per head) followed by ``GraphNorm(out_channels*4)`` is
    registered under the graph's ``show_index``. ``forward`` applies the
    selected pair and then Mish. ``label_list`` is accepted but not used.
    """

    def __init__(self, out_channels, graph_list, label_list):
        super().__init__()
        self.activ = nn.Mish()

        conv_dict = {}
        for graph in graph_list:
            attention = TransformerConv(graph.x.shape[1], out_channels, heads = 4)
            norm = torch_geometric.nn.GraphNorm(out_channels*4)
            conv_dict[graph.show_index] = torch_geometric.nn.Sequential(
                'x, edge_index',
                [
                    (attention, 'x, edge_index -> x'),
                    (norm, 'x -> x'),
                ],
            )
        self.convl1 = nn.ModuleDict(conv_dict)

    def forward(self, x, edge_index, show_index):
        """Run the conv+norm block stored for ``show_index``, then Mish."""
        block = self.convl1[show_index]
        return self.activ(block(x, edge_index))

In [None]:
class GCNEncoder_Commoninput(torch.nn.Module):
    """Label-shared message-passing tail of the encoder.

    For every distinct value in ``label_list`` the module owns
    ``TransformerConv(out_channels*4 -> out_channels, heads=2)`` followed by
    ``GraphNorm(out_channels*2)``, and a final
    ``TransformerConv(out_channels*2 -> out_channels)``. Graphs select their
    layers by the part of ``show_index`` before the first ``"__"``.
    ``graph_list`` is accepted but not used.
    """

    def __init__(self, out_channels, graph_list, label_list):
        super(GCNEncoder_Commoninput, self).__init__()
        self.activ = nn.Mish()

        conv_dict_l2 = {}
        conv_dict_l3 = {}
        # Iterate labels in sorted order. Plain set iteration order for strings
        # can differ between interpreter runs, which changed the order in which
        # layers drew their initial weights and made seeded runs irreproducible
        # as soon as there was more than one label.
        for tissue in sorted(set(label_list)):
            conv_dict_l2[tissue] = torch_geometric.nn.Sequential('x, edge_index', [(TransformerConv(out_channels*4, out_channels, heads = 2),'x, edge_index -> x'),
                                                     (torch_geometric.nn.GraphNorm(out_channels*2), 'x -> x')])
            conv_dict_l3[tissue] = TransformerConv(out_channels*2, out_channels)
        self.convl2 = nn.ModuleDict(conv_dict_l2)
        self.convl3 = nn.ModuleDict(conv_dict_l3)

        # NOTE(review): not referenced by forward(); kept so that parameter
        # lists and state-dict keys stay compatible with existing checkpoints.
        self.gn = torch_geometric.nn.GraphNorm(out_channels*2)

    @staticmethod
    def _tissue_key(show_index):
        """Return the label part of a ``"<label>__<sample>"`` identifier."""
        return show_index.split('__')[0]

    def get_weight(self, show_index):
        """Return ``(state_dict, state_dict)`` of the two layers used for ``show_index``."""
        key = self._tissue_key(show_index)
        return self.convl2[key].state_dict(), self.convl3[key].state_dict()

    def forward(self, x, edge_index, show_index):
        """Apply conv+norm -> Mish -> conv using the layers selected by ``show_index``."""
        key = self._tissue_key(show_index)
        x = self.convl2[key](x, edge_index)
        x = self.activ(x)
        return self.convl3[key](x, edge_index)

In [None]:
# NOTE(review): this class repeats the MLP_edge_Decoder definition from the
# MLP-autoencoder section earlier in the notebook; running this cell rebinds
# the name to an equivalent class.
class MLP_edge_Decoder(torch.nn.Module):
    """Per-graph decoder: Linear, Mish, Linear, Mish, Linear, then sigmoid.

    A separate stack is built for each entry of ``graph_list`` and looked up
    by ``show_index`` at call time.
    """

    def __init__(self, in_channels, out_channels, graph_list):
        super().__init__()

        stacks = {}
        for graph in graph_list:
            first = nn.Linear(in_channels, out_channels)
            second = nn.Linear(out_channels, out_channels)
            third = nn.Linear(out_channels, out_channels)
            stacks[graph.show_index] = torch.nn.Sequential(
                first, nn.Mish(), second, nn.Mish(), third
            )
        self.MLP = nn.ModuleDict(stacks)

    def forward(self, x, show_index):
        """Decode ``x`` with the stack for ``show_index`` and map to (0, 1)."""
        stack = self.MLP[show_index]
        return torch.sigmoid(stack(x))

In [None]:
from torch_geometric.nn import DataParallel

In [None]:
gene_encoder_is = GCNEncoder_Multiinput(32, graph_list, label_list).to(device)
gene_encoder_com = GCNEncoder_Commoninput(32, graph_list, label_list).to(device)

In [None]:
gene_decoder = MLP_edge_Decoder(1000,1000,graph_list).to(device)

In [None]:
print(f"Let's use {torch.cuda.device_count()} GPUs!")

In [None]:
optimizer_enc_is = torch.optim.Adam(gene_encoder_is.parameters(), lr=1e-4)
optimizer_enc_com = torch.optim.Adam(gene_encoder_com.parameters(), lr=1e-4)

In [None]:
optimizer_enc_com

In [None]:
optimizer_dec2 = torch.optim.Adam(gene_decoder.parameters(), lr=1e-3)

In [None]:
loss_f = nn.BCELoss()

In [None]:
# Z = np.load("graph_sim_cscore_global.npy")
# Z = np.load("graph_sim_cscore_global_withrna.npy")
# Z = np.load("graph_sim_cscore_global_withrna_withspatial.npy")

In [None]:
Z

In [None]:
!nvidia-smi

In [None]:
# The reconstruction targets never change between epochs, so convert them and
# move them to the device once instead of on every step of every epoch.
edge_targets = [torch.FloatTensor(cor.values).to(device) for cor in cor_list]

for epoch in range(2000):

    # One optimisation step per graph; the three optimizers cover the
    # graph-specific encoder, the shared encoder and the decoder.
    for i in range(0,len(graph_list)):

        optimizer_enc_is.zero_grad(set_to_none=True)
        optimizer_enc_com.zero_grad(set_to_none=True)
        optimizer_dec2.zero_grad(set_to_none=True)

        graph = graph_list[i].to(device)

        x = graph.x
        train_pos_edge_index = graph.edge_index

        x = gene_encoder_is(x, train_pos_edge_index, graph.show_index)
        z = gene_encoder_com(x, train_pos_edge_index, graph.show_index)

        edge_adj = edge_targets[i]

        # Inner-product similarity between node embeddings, refined by the
        # per-graph decoder into values in (0, 1) for the BCE loss.
        adj = torch.matmul(z, z.t())
        edge_reconstruct = gene_decoder(adj, graph.show_index)

        loss = loss_f(edge_reconstruct.flatten(), edge_adj.flatten())

        if epoch % 200 ==0:
            print(loss)

        loss.backward()

        optimizer_enc_is.step()
        optimizer_enc_com.step()
        optimizer_dec2.step()
    print("epoch finish")

In [None]:
emb_list = []
gene_list = []
tissue_list = []

In [None]:
graph.show_index

In [None]:
# Inference only: encode every graph with the trained graph encoders and
# collect embeddings plus per-node gene and graph labels.
with torch.no_grad():
    for graph in graph_list:
        graph = graph.to(device)
        x = graph.x
        train_pos_edge_index = graph.edge_index.long()

        x = gene_encoder_is(x, train_pos_edge_index, graph.show_index)
        z = gene_encoder_com(x, train_pos_edge_index, graph.show_index)

        emb_list.append(z.cpu().numpy())
        gene_list.append(graph.gene_list)
        # One copy of the graph identifier per node.
        tissue_list.append([graph.show_index] * len(x))

In [None]:
gene_list

In [None]:
adata = sc.AnnData(np.concatenate(emb_list))

In [None]:
adata

In [None]:
adata.obs['gene'] = np.concatenate(gene_list)
adata.obs['tissue'] = np.concatenate(tissue_list)

In [None]:
adata.obs['tissue']

In [None]:
sc.pp.neighbors(adata, use_rep='X')
sc.tl.umap(adata)

In [None]:
sc.pl.umap(adata, color='tissue')

In [None]:
sc.tl.leiden(adata)

In [None]:
sc.pl.umap(adata, color='leiden')

In [None]:
adata.obs['tissue_new'] = [i.split("__")[0] for i in adata.obs['tissue']]

In [None]:
sc.pl.umap(adata, color='tissue_new')

In [None]:
# NOTE(review): unlike the other write_h5ad calls in this notebook, this path
# has no ".h5ad" suffix -- confirm whether downstream readers expect it as is.
adata.write_h5ad("heart_global/heart_umi_shareGAE")

# PCA

In [None]:
graph_list

In [None]:
emb_list = []
gene_list = []
tissue_list = []

In [None]:
len(tissue_list)

In [None]:
# PCA baseline: scale each graph's node-feature matrix and keep the PCA
# coordinates as that graph's embedding.
for graph in graph_list:
    adata = sc.AnnData(graph.x.cpu().numpy())
    sc.pp.scale(adata)
    sc.tl.pca(adata, 32)

    emb_list.append(adata.obsm['X_pca'])
    gene_list.append(graph.gene_list)
    # One copy of the graph identifier per node.
    tissue_list.append([graph.show_index] * len(graph.x))

In [None]:
adata = sc.AnnData(np.concatenate(emb_list))

In [None]:
adata

In [None]:
adata.obs['gene'] = np.concatenate(gene_list)
adata.obs['tissue'] = np.concatenate(tissue_list)

In [None]:
adata.obs['tissue']

In [None]:
sc.pp.neighbors(adata, use_rep='X')
sc.tl.umap(adata)

In [None]:
sc.pl.umap(adata, color='tissue')

In [None]:
sc.tl.leiden(adata)

In [None]:
sc.pl.umap(adata, color='leiden')

In [None]:
adata.obs['tissue_new'] = [i.split("__")[0] for i in adata.obs['tissue']]

In [None]:
sc.pl.umap(adata, color='tissue_new')

In [None]:
adata.write_h5ad("heart_global/heart_umi_PCA.h5ad")

# Gene2vec

In [None]:
import numpy as np
from torch_geometric.utils.convert import to_networkx

import numpy as np

import scanpy as sc
import numpy as np


# NOTE(review): same definition as the sigmoid helper in the first cell.
def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), evaluated with NumPy."""
    exp_neg = np.exp(-x)
    return 1/(1 + exp_neg)

In [None]:
import gensim

In [None]:
####training parameters########
dimension = 32  # dimension of the embedding
num_workers = 32  # number of worker threads
sg = 1  # sg =1, skip-gram, sg =0, CBOW
max_iter = 10  # number of iterations
window_size = 1  # The maximum distance between the gene and predicted gene within a gene list
txtOutput = True

In [None]:
# model = gensim.models.Word2Vec(gene_pairs, vector_size=dimension, window=window_size, min_count=1, workers=num_workers,sg=sg, )
# model.train(gene_pairs,total_examples=model.corpus_count,epochs=max_iter)

In [None]:
# vector = model.wv['ENSG00000158747.15'] 

In [None]:
# vector_list = np.zeros((1000,32))
# gene_list = []
# for num,i in enumerate(edge_list.index):
#     vector_list[num] = model.wv[i] 
#     gene_list.append(i)

In [None]:
# gene_list

In [None]:
def generate_list(model, edge_list):
    """Collect the trained embedding of every gene in ``edge_list.index``.

    Parameters
    ----------
    model
        Object exposing ``model.wv[gene]`` (embedding lookup) and
        ``model.wv.vector_size``, e.g. a trained gensim ``Word2Vec`` model.
    edge_list
        Table whose ``index`` holds the gene identifiers to look up.

    Returns
    -------
    vector_list : numpy.ndarray
        Array of shape ``(len(edge_list.index), model.wv.vector_size)``; row
        ``k`` is the embedding of the ``k``-th gene.
    gene_list : list
        Gene identifiers in the same order as the rows of ``vector_list``.
    """
    gene_list = list(edge_list.index)
    # Size the output from the data rather than a fixed (1000, 32), so other
    # gene counts or embedding widths neither overflow nor leave zero rows.
    vector_list = np.zeros((len(gene_list), model.wv.vector_size))
    for row, gene in enumerate(gene_list):
        vector_list[row] = model.wv[gene]

    return vector_list, gene_list

In [None]:
vec_list = []
gene_list_final = []

# Train one gene2vec (Word2Vec) model per graph. Loop variables have distinct
# names so the inner iterations no longer overwrite the outer index.
for graph_idx in range(len(graph_list)):
    edge_list = cor_list[graph_idx]

    row_idx, col_idx = np.nonzero(edge_list.values)

    # Every nonzero matrix entry becomes a two-token "sentence" (row gene,
    # column gene) for Word2Vec.
    gene_pairs = [
        [edge_list.index[r], edge_list.columns[c]]
        for r, c in zip(row_idx, col_idx)
    ]

    model = gensim.models.Word2Vec(gene_pairs, vector_size=dimension, window=window_size, min_count=1, workers=num_workers,sg=sg, )
    model.train(gene_pairs,total_examples=model.corpus_count,epochs=max_iter)

    print('finish gene2vec training')
    # Reuse the helper defined above instead of repeating its body inline.
    vector_list, gene_list = generate_list(model, edge_list)

    vec_list.append(vector_list)
    gene_list_final.append(gene_list)
    

In [None]:
# One list of repeated graph identifiers per graph (one entry per node).
# The per-graph list is built inline: binding it to `label_list` used to
# overwrite the global label_list that the encoder constructors consume.
tissue_list = []
for graph in graph_list:
    tissue_list.append([graph.show_index] * len(graph.x))

In [None]:
# Concatenate the per-graph lists directly; wrapping them in np.array first
# fails when graphs have different node counts.
np.concatenate(tissue_list)

In [None]:
# Stack the per-graph embedding matrices row-wise; np.concatenate accepts the
# list directly, so no intermediate np.array is needed.
adata = sc.AnnData(np.concatenate(vec_list))

In [None]:
# Concatenate the per-graph label lists directly; wrapping them in np.array
# first fails when graphs have different node counts.
adata.obs['tissue'] = np.concatenate(tissue_list)

In [None]:
# Concatenate the per-graph gene lists directly; wrapping them in np.array
# first fails when graphs have different gene counts.
adata.obs['gene'] = np.concatenate(gene_list_final)

In [None]:
sc.pp.neighbors(adata, use_rep='X')
sc.tl.umap(adata)

In [None]:
sc.tl.leiden(adata)

In [None]:
sc.pl.umap(adata, color='leiden')

In [None]:
sc.pl.umap(adata, color='tissue')

In [None]:
adata.write_h5ad('heart_global/heart_umi_gene2vec.h5ad')

# scBERT

In [None]:
# Please see the codes of scBERT
# https://github.com/TencentAILabHealthcare/scBERT

# GIANT

In [None]:
# Please see the codes of GIANT
# https://github.com/chenhcs/GIANT

# GAE/VGAE

In [None]:
# Please see the separate file for VGAE/GAE

# SUGRL

In [None]:
# Please see the codes of SUGRL
# https://github.com/YujieMo/SUGRL

# GPS

In [None]:
# Please see the codes of GPS
# https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GPSConv.html#torch_geometric.nn.conv.GPSConv

# Graphormer

In [None]:
# Please see the codes of Graphormer
# https://github.com/microsoft/Graphormer