In [2]:
import MarsGT 
from MarsGT.conv import *
from MarsGT.egrn import *
from MarsGT.marsgt_model import *
from MarsGT.utils import *
import anndata as ad
from collections import Counter
import copy
import dill
from functools import partial
import json
import math
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import multiprocessing as mp
import numpy as np
import os
import pandas as pd
from operator import itemgetter
import random
import scipy.sparse as sp
from scipy.io import mmread
from scipy.sparse import hstack, vstack, coo_matrix
import seaborn as sb
from sklearn import metrics
from sklearn.cluster import KMeans
from sklearn.decomposition import IncrementalPCA
from sklearn.decomposition import SparsePCA
from sklearn.metrics import accuracy_score
from sklearn.metrics.cluster import normalized_mutual_info_score
import time
import torch
import torch.cuda as cuda
from torch import nn
from torch.autograd import Variable
import torch.distributions as D
import torch.nn.functional as F
import torch_geometric.data as Data
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, uniform
from torch_geometric.utils import softmax as Softmax
from torchmetrics.functional import pairwise_cosine_similarity
import warnings
from warnings import filterwarnings
import xlwt
import argparse
from tqdm import tqdm
import scanpy as sc

In [2]:
# ---------------------------------------------------------------------------
# Reproducibility: silence warnings and seed every RNG used by the pipeline.
# ---------------------------------------------------------------------------
filterwarnings("ignore")
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Exported for child processes; the already-running interpreter only reads
# PYTHONHASHSEED at start-up, so this does not change hashing in this kernel.
os.environ['PYTHONHASHSEED'] = str(seed)
# Prefer deterministic cuDNN kernels and disable the cuDNN auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def _str2bool(value):
    """Convert a command-line style boolean string into a real ``bool``.

    ``argparse`` with ``type=bool`` evaluates ``bool("False")``, which is
    ``True`` because any non-empty string is truthy, so the flag could never
    be switched off by value. This converter accepts the usual spellings and
    rejects everything else.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"yes"``, ``"0"``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# ---------------------------------------------------------------------------
# Hyper-parameters. Declared through argparse so the same cell can be reused
# in a script; inside the notebook the defaults are used (parse_args([])).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Training GNN on gene cell graph')
parser.add_argument('--fi', type=int, default=0) # This parameter is used for the benchmark to specify the starting sequence number of the created files
parser.add_argument('--labsm', type=float, default=0.1) # The rate of LabelSmoothing
parser.add_argument('--wd', type=float, default=0.1) # The 'weight_decay' parameter is used to specify the strength of L2 regularization
parser.add_argument('--lr', type=float, default=0.0005) # learning rate
parser.add_argument('--n_hid', type=int, default=104) # Hidden size; when modifying it, keep it a multiple of 'nheads' (default 104 = 8 * 13)
parser.add_argument('--nheads', type=int, default=8) # The 'heads' parameter represents the number of attention heads in the attention mechanism
parser.add_argument('--nlayers', type=int, default=3) # The number of layers in network
parser.add_argument('--cell_size', type=int, default=30) # The number of cells per subgraph (batch)
parser.add_argument('--neighbor', type=int, default=20) # The number of neighboring nodes to be selected for each cell in the subgraph
parser.add_argument('--egrn', type=_str2bool, default=True) # Whether to output the Enhancer-Gene regulatory network
parser.add_argument('--output_file', type=str, default='PBMCs/output')
# An explicit empty argv keeps argparse from reading the Jupyter kernel's own
# command line.
args = parser.parse_args([])

# Unpack the parsed options into module-level names used by the later cells.
output_file = args.output_file
fi = args.fi
labsm = args.labsm
lr = args.lr
wd = args.wd
n_hid = args.n_hid
nheads = args.nheads
nlayers = args.nlayers
cell_size = args.cell_size
neighbor = args.neighbor
egrn = args.egrn

In [3]:
# Move into the dataset directory so the bare file names below resolve.
# The guard makes this cell safe to re-run: os.chdir('PBMCs/') is relative, so
# executing it a second time in the same kernel would look for PBMCs/PBMCs/
# and raise FileNotFoundError. If the data files are already reachable from
# the current working directory, the directory change is skipped.
if not os.path.exists('Cell_names.tsv'):
    os.chdir('PBMCs/')

# Name tables: one identifier per row, no header line.
gene_names = pd.read_csv('Gene_names.tsv', sep='\t', header=None)
cell_names = pd.read_csv('Cell_names.tsv', sep='\t', header=None)
peak_names = pd.read_csv('Peak_names.tsv', sep='\t', header=None)
# Same content as cell_names; copied rather than re-read from disk so the two
# frames stay independent objects.
RNA_cell_label = cell_names.copy()

# Matrices in Matrix Market format, loaded as AnnData objects.
gene_peak = ad.read_mtx('Gene_Peak.mtx')
gene_cell = ad.read_mtx('Gene_Cell.mtx')
peak_cell = ad.read_mtx('Peak_Cell.mtx')
print('Files read successfully')

# Attach row (obs) and column (var) labels: rows are features, columns are
# cells (or peaks for the gene-by-peak matrix).
peak_cell.obs_names = peak_names[0]
peak_cell.var_names = cell_names[0]
gene_cell.obs_names = gene_names[0]
gene_cell.var_names = cell_names[0]
gene_peak.obs_names = gene_names[0]
gene_peak.var_names = peak_names[0]

#gene cell
RNA_matrix = gene_cell.X
#peak cell
ATAC_matrix = peak_cell.X
#gene peak
RP_matrix = gene_peak.X
Gene_Peak = gene_peak.X  # second name for the same matrix as RP_matrix

# eGRN = hstack((RNA_matrix.transpose(), ATAC_matrix.transpose()*(RP_matrix.transpose())))
# Dimensions, read from the genes-by-cells and peaks-by-cells matrices.
cell_num = RNA_matrix.shape[1]
gene_num = RNA_matrix.shape[0]
peak_num = ATAC_matrix.shape[0]

# Table loaded from cell_emb10.csv; its values are handed to
# initial_clustering() as `use_rep` in the training cell.
data = pd.read_csv("cell_emb10.csv", index_col='Unnamed: 0')
use_rep = data.values

In [4]:
if __name__ == "__main__":
    device = torch.device("cuda" if cuda.is_available() else "cpu")
    print('You will use : ',device)

    # np.save does not create missing directories. Create the output directory
    # before the long-running training steps so a bad path fails immediately
    # instead of discarding the trained results at the very end.
    os.makedirs(output_file, exist_ok=True)

    # clustering result by scanpy
    initial_pre = initial_clustering(RNA_matrix, custom_n_neighbors=30, n_pcs=40, custom_resolution=0.2, use_rep=use_rep)
    # number of distinct initial clusters
    cluster_ini_num = len(set(initial_pre))
    # initial cluster labels as plain Python ints
    ini_p1 = [int(i) for i in initial_pre]
    # partition the data into batches (subgraphs)
    indices, Node_Ids, dic = batch_select_whole (RNA_matrix, ATAC_matrix, neighbor = [neighbor], cell_size=cell_size)
    n_batch = len(indices)

    # Reduce the dimensionality of features for cell, gene, and peak data.
    node_model = NodeDimensionReduction(RNA_matrix, ATAC_matrix, indices, ini_p1, n_hid=n_hid, n_heads=nheads,
                                        n_layers=nlayers,labsm=labsm, lr=lr, wd=wd, device=device, num_types=3, num_relations=2, epochs=1)
    gnn,cell_emb,gene_emb,peak_emb,h = node_model.train_model(n_batch=n_batch)

    # Instantiate the MarsGT_model
    MarsGT_model = MarsGT(gnn=gnn, h=h, labsm=labsm, n_hid=n_hid, n_batch=n_batch, device=device,lr=lr,wd=wd, num_epochs=1)
    # Train the model
    MarsGT_gnn = MarsGT_model.train_model(indices=indices,RNA_matrix=RNA_matrix, ATAC_matrix=ATAC_matrix, Gene_Peak=RP_matrix, ini_p1=ini_p1)
    # The result of MarsGT
    MarsGT_result = MarsGT_pred(RNA_matrix, ATAC_matrix, RP_matrix, egrn=egrn, MarsGT_gnn=MarsGT_gnn, indices=indices,
                        nodes_id=Node_Ids, cell_size=cell_size, device=device, gene_names=gene_names, peak_names=peak_names)

    # Save numpy arrays to files
    np.save(os.path.join(output_file, "Node_Ids.npy"), Node_Ids)
    np.save(os.path.join(output_file, "pred.npy"), MarsGT_result['pred_label'])
    np.save(os.path.join(output_file, "cell_embedding.npy"), MarsGT_result['cell_embedding'])

In [5]:
# Tally how many times each value occurs in the predicted labels.
Counter(MarsGT_result['pred_label'])

In [6]:
# Display the 'egrn' entry of the MarsGT result.
# NOTE(review): presumably the enhancer-gene regulatory network requested via
# the --egrn option; confirm whether the key exists when egrn is False.
MarsGT_result['egrn']