In [1]:
# data process
import os
import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
import matplotlib.pyplot as plt
from torch_geometric.data import Data
from torch_geometric.utils.convert import to_networkx
import networkx as nx
from sklearn.preprocessing import MinMaxScaler

data_dir='/data/NK/'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_adj=pd.read_csv(os.path.join(data_dir, f'ppi_of_NK.csv'),sep=",") 
signaling=pd.read_csv(os.path.join(data_dir, f'signalingLayer_of_NK.csv'),sep=",",index_col='gene_id') 
data_adj.head(10)

Unnamed: 0,1,368,8754,5290,3172,3164,204,14,847,183,...,375611,90139,9331,283358,8708,9227,124961,2529,117156,5251
1,0,1,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
368,1,0,0,0,1,1,1,0,0,0,...,0,0,0,0,0,0,0,0,0,0
8754,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5290,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3172,0,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3164,0,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
204,0,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
14,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
847,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
183,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [3]:
ed=sp.coo_matrix(data_adj) #Convert the adjacency matrix into a sparse matrix of the coo form
indices=np.vstack((ed.row,ed.col))
index=torch.LongTensor(indices)
values=torch.FloatTensor(ed.data)
edge_index=torch.sparse_coo_tensor(index,values,ed.shape)
edge_index=edge_index.coalesce().indices() #Extract indices
edge_index

tensor([[   0,    0,    0,  ..., 8508, 8509, 8510],
        [   1,    2,    3,  ..., 3769, 8474, 6051]])

In [4]:
information_of_cells=pd.read_csv(os.path.join(data_dir, f'information_of_NK.csv'),sep=",")#
information_of_cells.head(10)

Unnamed: 0,donor_id,age,sex,cell_type,orig.ident
meta100,689_690,59,male,natural killer cell,onek1k
meta101,689_690,59,male,natural killer cell,onek1k
meta102,689_690,59,male,natural killer cell,onek1k
meta103,689_690,59,male,natural killer cell,onek1k
meta104,689_690,59,male,natural killer cell,onek1k
meta105,689_690,59,male,natural killer cell,onek1k
meta106,689_690,59,male,natural killer cell,onek1k
meta107,689_690,59,male,natural killer cell,onek1k
meta108,689_690,59,male,natural killer cell,onek1k
meta109,689_690,59,male,natural killer cell,onek1k


In [5]:
def read_single_csv(input_path):
    df_chunk=pd.read_csv(input_path,sep=",",chunksize=3000)  #The hunksize parameter enables batch reads (this parameter is used to set how many rows of data are read into each batch)
    res_chunk=[]
    for chunk in df_chunk:
        res_chunk.append(chunk)
    res_df=pd.concat(res_chunk)
    return res_df

In [6]:
Log_normalized_matrix_of_naive_cd4=read_single_csv(os.path.join(data_dir, f'expression_of_NK.csv'))
Log_normalized_matrix_of_naive_cd4.head(5)

Unnamed: 0,1,368,8754,5290,3172,3164,204,14,847,183,...,375611,90139,9331,283358,8708,9227,124961,2529,117156,5251
meta100,0.0,0.0,0.0,0.0,0.0,0.0,0.096977,0.204885,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
meta101,0.190978,0.0,0.0,0.206233,0.0,0.101229,0.188599,0.297983,0.101229,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
meta102,0.10591,0.0,0.0,0.0,0.0,0.087403,0.096477,0.206901,0.088311,0.0,...,0.0,0.0,0.103732,0.0,0.0,0.0,0.0,0.0,0.0,0.0
meta103,0.190978,0.0,0.0,0.0,0.0,0.060198,0.112436,0.302726,0.088311,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
meta104,0.090321,0.0,0.0,0.0,0.0,0.0,0.133303,0.195215,0.247647,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [7]:
transfer = MinMaxScaler(feature_range=(0, 1))
data_i = transfer.fit_transform(Log_normalized_matrix_of_naive_cd4)
Log_normalized_matrix_of_naive_cd4=pd.DataFrame(data_i,index=Log_normalized_matrix_of_naive_cd4.index,columns=Log_normalized_matrix_of_naive_cd4.columns)
Log_normalized_matrix_of_naive_cd4.head(5)

Unnamed: 0,1,368,8754,5290,3172,3164,204,14,847,183,...,375611,90139,9331,283358,8708,9227,124961,2529,117156,5251
meta100,0.0,0.0,0.0,0.0,0.0,0.0,0.140142,0.233978,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
meta101,0.294112,0.0,0.0,0.356121,0.0,0.262787,0.272547,0.340296,0.138441,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
meta102,0.163104,0.0,0.0,0.0,0.0,0.226895,0.139419,0.236281,0.120774,0.0,...,0.0,0.0,0.346153,0.0,0.0,0.0,0.0,0.0,0.0,0.0
meta103,0.294112,0.0,0.0,0.0,0.0,0.156272,0.162482,0.345712,0.120774,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
meta104,0.139097,0.0,0.0,0.0,0.0,0.0,0.192637,0.222935,0.338683,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [8]:
len(Log_normalized_matrix_of_naive_cd4)

11452

In [9]:
cells=pd.concat([Log_normalized_matrix_of_naive_cd4,information_of_cells.iloc[:,1]],axis=1)
cells.head(5)

Unnamed: 0,1,368,8754,5290,3172,3164,204,14,847,183,...,90139,9331,283358,8708,9227,124961,2529,117156,5251,age
meta100,0.0,0.0,0.0,0.0,0.0,0.0,0.140142,0.233978,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,59
meta101,0.294112,0.0,0.0,0.356121,0.0,0.262787,0.272547,0.340296,0.138441,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,59
meta102,0.163104,0.0,0.0,0.0,0.0,0.226895,0.139419,0.236281,0.120774,0.0,...,0.0,0.346153,0.0,0.0,0.0,0.0,0.0,0.0,0.0,59
meta103,0.294112,0.0,0.0,0.0,0.0,0.156272,0.162482,0.345712,0.120774,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,59
meta104,0.139097,0.0,0.0,0.0,0.0,0.0,0.192637,0.222935,0.338683,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,59


In [10]:
cells.describe()

Unnamed: 0,1,368,8754,5290,3172,3164,204,14,847,183,...,90139,9331,283358,8708,9227,124961,2529,117156,5251,age
count,11452.0,11452.0,11452.0,11452.0,11452.0,11452.0,11452.0,11452.0,11452.0,11452.0,...,11452.0,11452.0,11452.0,11452.0,11452.0,11452.0,11452.0,11452.0,11452.0,11452.0
mean,0.100811,0.001269,0.030494,0.098143,0.001499,0.060871,0.180552,0.181779,0.164394,0.000742,...,0.008377,0.046287,0.011506,0.000165,0.00393,0.014115,0.049619,0.000606,0.00191,66.457475
std,0.131556,0.031028,0.106064,0.138048,0.033351,0.133637,0.165444,0.149131,0.154798,0.025336,...,0.062114,0.131219,0.059529,0.012521,0.051987,0.077618,0.102299,0.022086,0.037966,15.866244
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,19.0
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.086913,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,60.0
50%,0.0,0.0,0.0,0.0,0.0,0.0,0.161024,0.150697,0.146433,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,69.0
75%,0.175879,0.0,0.0,0.187482,0.0,0.0,0.28487,0.273721,0.26408,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,78.0
max,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,97.0


In [11]:
data_list=[]
age_list=[]
#type(Data) = <class 'torch_geometric.data.data.Data'>
#Iterate over each row
for index, row in cells.iterrows(): 
    data=pd.DataFrame(row).iloc[:-1,:]
    matrix=data.to_numpy()
    x=torch.tensor(matrix,dtype=torch.float)
    y= torch.as_tensor(row.iloc[-1]).type(torch.int64)
    data_pyg=Data(x=x,y=y,edge_index=edge_index)#Convert to pyg data format
    data_list.append(data_pyg)


In [12]:
torch.save(data_list,'NK_pyg.pt') #Save the processed data file

In [13]:
# train
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import torch_geometric.nn as pyg_nn
import os
import pandas as pd
from scipy.stats import pearsonr
from torchmetrics import MeanAbsoluteError,PearsonCorrCoef
import logging
import shutil

data_dir='/data/NK/'

## setting logging
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger()
logger.addHandler(logging.FileHandler('NK.log', 'a'))

print = logger.info

data_load = torch.load('./NK_pyg.pt')
# print((len(data_load)))

data_adj=pd.read_csv(os.path.join(data_dir, f'ppi_of_NK.csv'),sep=",")
in_dim=len(data_adj)

train_dataset = data_load[0:3487] # train data
test_dataset = data_load[3487:] # test data
# print((len(train_dataset),len(test_dataset)))
class DeepRNAGenConv(nn.Module):
    
    def __init__(self,in_dim, node_features_dim=None, node_embedding_dim=None, num_layers=None, node_output_features_dim=None, convolution_dropout=0.1, dense_dropout=0.0):
        super(DeepRNAGenConv, self).__init__()
        
        self.node_encoder = torch.nn.Linear(node_features_dim, node_embedding_dim)
        
        self.gcn_layers = torch.nn.ModuleList()
        self.in_dim=in_dim
        self.hidden1 = nn.Linear(in_features=in_dim, out_features=1000, bias=True)
        self.hidden2 = nn.Linear(1000, 100)

        for i in range(num_layers):
            convolution =  pyg_nn.GENConv(in_channels=node_embedding_dim, out_channels=node_embedding_dim)
            norm = torch.nn.LayerNorm(node_embedding_dim)
            activation = torch.nn.ReLU()
            layer = pyg_nn.DeepGCNLayer(conv=convolution, norm=norm, act=activation, dropout=convolution_dropout)
            self.gcn_layers.append(layer)

        self.dropout = torch.nn.Dropout(p=dense_dropout)
        self.decoder = torch.nn.Linear(node_embedding_dim, node_output_features_dim)
    
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        xs=x.view(-1,self.in_dim)
        xs= F.leaky_relu(self.hidden1(xs))
        xs = F.leaky_relu(self.hidden2(xs))
        x = self.node_encoder(x)
        for layer in self.gcn_layers:
            x = layer(x, edge_index)
        x = self.dropout(x)
        x = pyg_nn.global_mean_pool(x, data.batch)
        x=0.5*xs+0.5*x
        x = self.decoder(x)
        return x[:, 0]



device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  
epochs = 10  
lr = 0.001  
num_node_features = 1 
train_loader = torch_geometric.loader.DataListLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch_geometric.loader.DataListLoader(test_dataset, batch_size=64, shuffle=False)
model = pyg_nn.DataParallel( DeepRNAGenConv(
    in_dim=in_dim,
    node_features_dim=1,
    node_embedding_dim=100,
    num_layers=2,
    node_output_features_dim=1,
    convolution_dropout=0.2,
    dense_dropout=0.2)).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                       factor=0.7, patience=3,
                                                       min_lr=0.0001)
loss_function =  torch.nn.MSELoss(reduction='mean')

mean_absolute_error = MeanAbsoluteError().to(device)
pearson = PearsonCorrCoef().to(device)

folder_path="./models"

if os.path.exists(folder_path):
    shutil.rmtree(folder_path)
    os.mkdir(folder_path)
else:
    os.mkdir(folder_path)
def run():
  for epoch in range(epochs):
    lr = scheduler.optimizer.param_groups[0]['lr']
    model.train()
    loss_train = 0  
    mae_total=0
    pearson_total=0
    count=0
    for data in train_loader:
        optimizer.zero_grad()
        pred = model(data)
        target = [data.y.float().unsqueeze(-1) for data in data]
        target = torch.cat((target)).to(pred.device)
        loss = loss_function(pred, target)
        loss_train += loss.item()
        mae=mean_absolute_error(pred, target)
        result=pearson(pred, target )
        mae_total+=mae.item()
        pearson_total+=result.item()

        loss.backward()
        optimizer.step()
        count=count+1
    loss_train /= count
    mae_total /= count
    pearson_total /= count
    print(("【EPOCH: 】%s" % str(epoch)))
    print(('Train  Loss',loss_train, 'Train MAE: ',mae_total,'Train PCC: ',pearson_total))
    test(epoch)

  print(('【Finished Training!】'))

best_test_loss=0
def test(i):
  global best_test_loss
  model.eval()
  with torch.no_grad():
    loss_train = 0
    mae_train=0
    pearson_train=0
    count=0
    for data in train_loader:
        pred = model(data)
        target = [data.y.float().unsqueeze(-1) for data in data]
        target = torch.cat((target)).to(pred.device)
        loss = loss_function(pred, target)
        loss_train += loss.item()
        mae=mean_absolute_error(pred, target)
        mae_train += mae.item()
        result=pearson(pred, target )
        pearson_train += result.item()
        count=count+1

    mae_train/= count
    pearson_train/= count
    loss_train /= count

    print(('Val Loss: {:.4f}'.format(loss_train),'Val MAE: {:.4f}'.format(mae_train),' Val PCC: ',pearson_train))


    loss_test = 0  
    mae_test=0
    count=0
    total_target=torch.zeros(1, 1).to(device)
    total_pred=torch.zeros(1, 1).to(device)
    for data in test_loader:
        pred = model(data)
        target = [data.y.float().unsqueeze(-1) for data in data]
        target = torch.cat((target)).to(pred.device)

        loss = loss_function(pred, target)
        loss_test += loss.item()
        mae=mean_absolute_error(pred, target)
        mae_test += mae.item()
        count=count+1
        total_pred=torch.cat((total_pred,pred.unsqueeze(-1)),dim=0)
        total_target=torch.cat((total_target,target.unsqueeze(-1)),dim=0)
    pred_target=torch.cat((total_pred,total_target),dim=1)
    pred_target=pred_target.cpu().numpy()[1:,:]
    df_pred_target=pd.DataFrame(pred_target,columns=['pred','target'])
    loss_test /= count
    mae_test/= count
    scheduler.step(mae_test)
    print(('Test  Loss: {:.4f}'.format(loss_test), 'Test MAE:{:.4f}'.format(mae_test),'Test pcc',pearsonr(df_pred_target['pred'],df_pred_target['target'])))
    # torch.save(model.state_dict(), folder_path+'/'+'model_epoch{}.pth'.format(i))
    if i==0:
        best_test_loss=loss_test 
    if best_test_loss>loss_test:
        best_test_loss=loss_test 
        torch.save(model.state_dict(), folder_path+'/'+'NK_best_model.pth')
if __name__ == '__main__':
   avg_mse = []
   avg_mae=[]
   run()

【EPOCH: 】0
('Train  Loss', 807.6242703524503, 'Train MAE: ', 21.949475929953834, 'Train PCC: ', 0.22604991651394152)
('Val Loss: 326.6633', 'Val MAE: 14.9447', ' Val PCC: ', 0.26700470374304464)
('Test  Loss: 312.2223', 'Test MAE:14.9399', 'Test pcc', (0.09131019590415981, 3.2162466167879396e-16))
【EPOCH: 】1
('Train  Loss', 321.2649547230114, 'Train MAE: ', 14.768905761025168, 'Train PCC: ', 0.30188410332934423)
('Val Loss: 325.5215', 'Val MAE: 14.9597', ' Val PCC: ', 0.3820777101950212)
('Test  Loss: 368.1000', 'Test MAE:16.7388', 'Test pcc', (0.13466393095090756, 1.4944937456130078e-33))
【EPOCH: 】2
('Train  Loss', 296.67744556773795, 'Train MAE: ', 14.198814027959651, 'Train PCC: ', 0.4316959540952336)
('Val Loss: 272.1084', 'Val MAE: 13.6207', ' Val PCC: ', 0.5288097804242914)
('Test  Loss: 263.7893', 'Test MAE:13.5660', 'Test pcc', (0.19069113585100045, 4.1773110035904336e-66))
【EPOCH: 】3
('Train  Loss', 255.7736591685902, 'Train MAE: ', 13.180619499900125, 'Train PCC: ', 0.6010794

In [14]:
###test
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import torch_geometric.nn as pyg_nn
import os
import numpy as np
import pandas as pd
from torchmetrics import MeanAbsoluteError,PearsonCorrCoef
from scipy.stats import pearsonr
import torch.nn.functional as F
import os
data_dir='/data/NK/'
def read_single_csv(input_path):
    df_chunk=pd.read_csv(input_path,sep=",",chunksize=3000)  #The hunksize parameter enables batch reads (this parameter is used to set how many rows of data are read into each batch)
    res_chunk=[]
    for chunk in df_chunk:
        res_chunk.append(chunk)
    res_df=pd.concat(res_chunk)
    return res_df

data_load = torch.load('./NK_pyg.pt')

Log_normalized_matrix_of_naive_cd4=read_single_csv(os.path.join(data_dir, f'expression_of_NK.csv'))
id=Log_normalized_matrix_of_naive_cd4.iloc[3487:,:].index
data_adj=pd.read_csv(os.path.join(data_dir, f'ppi_of_NK.csv'),sep=",")
in_dim=len(data_adj)
test_dataset = data_load[3487:] # test datasets
print(len(test_dataset))

class DeepRNAGenConv(nn.Module):
    
    def __init__(self, in_dim,node_features_dim=None, node_embedding_dim=None, num_layers=None, node_output_features_dim=None, convolution_dropout=0.1, dense_dropout=0.0):
        super(DeepRNAGenConv, self).__init__()
        
        self.node_encoder = torch.nn.Linear(node_features_dim, node_embedding_dim)
        
        self.gcn_layers = torch.nn.ModuleList()
        self.in_dim=in_dim
        self.hidden1 = nn.Linear(in_features=in_dim, out_features=1000)
        
        self.drop1 = nn.Dropout(0.0)
        self.drop2 = nn.Dropout(0.0)
        
        self.hidden2 = nn.Linear(1000, 100)

        for i in range(num_layers):
            convolution =  pyg_nn.GENConv(in_channels=node_embedding_dim, out_channels=node_embedding_dim)
            norm = torch.nn.LayerNorm(node_embedding_dim)
            activation = torch.nn.ReLU()
            layer = pyg_nn.DeepGCNLayer(conv=convolution, norm=norm, act=activation, dropout=convolution_dropout)
            self.gcn_layers.append(layer)

        self.dropout = torch.nn.Dropout(p=dense_dropout)
        self.decoder = torch.nn.Linear(node_embedding_dim, node_output_features_dim)
    
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        xs=x.view(-1,self.in_dim)
        xs=F.leaky_relu( self.drop1(self.hidden1(xs)))
        xs=F.leaky_relu( self.drop2(self.hidden2(xs)))
        x = self.node_encoder(x)
        for layer in self.gcn_layers:
            x = layer(x, edge_index)
        x = self.dropout(x)
        x = pyg_nn.global_mean_pool(x, data.batch)
        x=0.5*xs+0.5*x
        x= self.decoder (x)
        return x[:, 0]
    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight.data)



device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
epochs = 20  
lr = 0.001  # LR
num_node_features = 1 
num_classes =1 
test_loader = torch_geometric.loader.DataListLoader(test_dataset, batch_size=64, shuffle=False)
net=DeepRNAGenConv(
    in_dim=in_dim,
    node_features_dim=1,
    node_embedding_dim=100,
    num_layers=2,
    node_output_features_dim=1,
    convolution_dropout=0.2,
    dense_dropout=0.2)
net.initialize()
model = pyg_nn.DataParallel(net).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr) 
loss_function =  torch.nn.MSELoss(reduction='mean')
model.load_state_dict(torch.load('models/NK_best_model.pth')) #load the best model

mean_squared_error = MeanAbsoluteError().to(device)
pearson = PearsonCorrCoef().to(device)

def run():
 model.eval()
 with torch.no_grad():

    loss_test = 0 
    mae_test=0
    pearson_test=0
    count=0
    total_target=torch.zeros(1, 1).to(device)
    total_pred=torch.zeros(1, 1).to(device)
    for data in test_loader:
        pred = model(data)
        target = [data.y.float().unsqueeze(-1) for data in data]
        target = torch.cat((target)).to(pred.device)
        loss = loss_function(pred, target)
        loss_test += loss.item()
        mae=mean_squared_error(pred, target)
        mae_test += mae.item()
        count=count+1
        total_pred=torch.cat((total_pred,pred.unsqueeze(-1)),dim=0)
        total_target=torch.cat((total_target,target.unsqueeze(-1)),dim=0)
    pred_target=torch.cat((total_pred,total_target),dim=1)
    pred_target=pred_target.cpu().numpy()[1:,:]
    df_pred_target=pd.DataFrame(pred_target,index=id,columns=['pred','target'])
    df_pred_target.to_csv("NK_agePrediction.csv")
    print(( df_pred_target.head(10),len(pred_target)))
    print(("Test PCC:",pearsonr(df_pred_target['pred'],df_pred_target['target'])))

    loss_test /= count
    mae_test/= count
    pearson_test/= count

    print(('Test  Loss: {:.4f}'.format(loss_test), 'Test MAE{:.4f}'.format(mae_test),'Test PCC:',pearsonr(df_pred_target['pred'],df_pred_target['target'])))
    
    return loss_test, mae_test
if __name__ == '__main__':
   loss_test, mae_test=run()
   print(('RMSE_test: {:.4f} , mae_test {:.4f}'.format( np.sqrt(loss_test), mae_test)))

7965
(             pred  target
meta1   76.007973    88.0
meta2   75.887177    88.0
meta3   84.207802    88.0
meta4   72.617401    88.0
meta5   72.927551    88.0
meta6   73.741524    88.0
meta7   80.561523    88.0
meta8   70.731133    88.0
meta9   71.290337    88.0
meta10  77.289612    88.0, 7965)
('Test PCC:', (0.3241779692134997, 2.7992905773537834e-194))
('Test  Loss: 196.2529', 'Test MAE11.1311', 'Test PCC:', (0.3241779692134997, 2.7992905773537834e-194))
RMSE_test: 14.0090 , mae_test 11.1311


In [15]:
#xai
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import torch_geometric.nn as pyg_nn
import tqdm
import os
import pandas as pd
from torchmetrics import MeanAbsoluteError,PearsonCorrCoef
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.loader import DataLoader
import torch_geometric.nn as pyg_nn
from explain import Explainer, PGExplainer
import torch.nn.functional as F
from tqdm import tqdm
import os

data_dir='/data/NK/'
data_adj=pd.read_csv(os.path.join(data_dir, f'ppi_of_NK.csv'),sep=",")
in_dim=len(data_adj)

data_load = torch.load('./NK_pyg.pt')

test_dataset = data_load[0:3487] 
print((len(test_dataset)))

class DeepRNAGenConv(nn.Module):
    
    def __init__(self,in_dim, node_features_dim=None, node_embedding_dim=None, num_layers=None, node_output_features_dim=None, convolution_dropout=0.1, dense_dropout=0.0):
        super(DeepRNAGenConv, self).__init__()
        
        self.node_encoder = torch.nn.Linear(node_features_dim, node_embedding_dim)
        self.in_dim=in_dim
        self.gcn_layers = torch.nn.ModuleList()
        self.hidden1 = nn.Linear(in_features=in_dim, out_features=1000, bias=True)
        self.hidden2 = nn.Linear(1000, 100)

        for i in range(num_layers):
            convolution =  pyg_nn.GENConv(in_channels=node_embedding_dim, out_channels=node_embedding_dim)
            norm = torch.nn.LayerNorm(node_embedding_dim)
            activation = torch.nn.ReLU()
            layer = pyg_nn.DeepGCNLayer(conv=convolution, norm=norm, act=activation, dropout=convolution_dropout)
            self.gcn_layers.append(layer)

        self.dropout = torch.nn.Dropout(p=dense_dropout)
        self.decoder = torch.nn.Linear(node_embedding_dim, node_output_features_dim)
    
    def forward(self, x, edge_index, batch):
        xs=x.view(-1,self.in_dim)
        xs= F.leaky_relu(self.hidden1(xs))
        xs = F.leaky_relu(self.hidden2(xs))
        x = self.node_encoder(x)
        for layer in self.gcn_layers:
            x = layer(x, edge_index)
        x = self.dropout(x)
        x = pyg_nn.global_mean_pool(x,batch)
        x=0.5*xs+0.5*x
        x = self.decoder(x)
        return x[:, 0]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 50
num_node_features = 1 
test_loader = torch_geometric.loader.DataLoader(test_dataset, batch_size=64, shuffle=False)
model = pyg_nn.DataParallel( DeepRNAGenConv(
    in_dim=in_dim,
    node_features_dim=1,
    node_embedding_dim=100,
    num_layers=2,
    node_output_features_dim=1,
    convolution_dropout=0.2,
    dense_dropout=0.2)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                       factor=0.7, patience=3,
                                                       min_lr=0.0001)
loss_function =  torch.nn.MSELoss(reduction='mean')# loss function
model.load_state_dict(torch.load('models/NK_best_model.pth', map_location=device))
model.to(device)
if isinstance(model, pyg_nn.DataParallel):
    model = model.module


mean_squared_error = MeanAbsoluteError().to(device)
pearson = PearsonCorrCoef().to(device)

def run():
 lr = scheduler.optimizer.param_groups[0]['lr']
 model.eval()
 with torch.no_grad():

    loss_test = 0
    mae_test=0
    pearson_test=0
    count=0
    print(model)
    for data in tqdm(test_loader):
        pred = model(data.x.to(device), data.edge_index.to(device), data.batch.to(device))
        target=data.y.float().to(pred.device)
        loss = loss_function(pred, target)
        loss_test += loss.item()
        mae=mean_squared_error(pred, target )
        mae_test += mae.item()
        result=pearson(pred, target )
        pearson_test += result.item()
        count=count+1
    
    loss_test /= count
    mae_test/= count
    scheduler.step(mae_test)
    pearson_test/= count
    print(('Test  Loss: {:.4f}'.format(loss_test), 'Test MAE:{:.4f}'.format(mae_test),'Test PCC:{:.4f}'.format(pearson_test)))

explainer = Explainer(
    model=model,
    algorithm=PGExplainer(epochs=5, lr=0.003).to(device),
    explanation_type='phenomenon',
    edge_mask_type='object',
    model_config=dict(
        mode='regression',
        task_level='graph',
        return_type='raw',
    ),
    threshold_config=dict(threshold_type='topk', value=100),
)
total_list=[]
total_edges=[]
def age_explain():
    data_loader = DataLoader(test_dataset,batch_size=1)
    data_loader1 = DataLoader(test_dataset,batch_size=1,shuffle=True)

    for epoch in range(5):
            t_loss=0
            count=0
            for data in data_loader:  # Indices to train against.
                data=data.to(device)
                loss = explainer.algorithm.train(epoch, model, data.x, data.edge_index, batch=data.batch,target=data.y.float())
                t_loss += loss
                count=count+1
            t_loss/= count
            print(('PGExlianer  Loss: {:.4f}'.format(t_loss)))
    f=open('NK_explain_nodes_index.txt','w',encoding='utf-8')
    f1=open('NK_explain_edges_index.txt','w',encoding='utf-8')
    for data in tqdm(data_loader):
        data=data.to(device)
        explanation = explainer(data.x, data.edge_index,batch=data.batch,target=data.y.float())
        path = 'subgraph_EAT_node11.png'
        node_list,edge_list=explanation.visualize_graph(path,'networkx')
        f.writelines(str(node_list).strip("[]")) #nodes
        f.write('\n')
        f1.writelines(str(edge_list).strip("[]")) #edges
        f1.write('\n')
    f.close()
    f1.close()
    print("explain end......")

if __name__ == '__main__':
    run()
    age_explain()


3487
DeepRNAGenConv(
  (node_encoder): Linear(in_features=1, out_features=100, bias=True)
  (gcn_layers): ModuleList(
    (0): DeepGCNLayer(block=res+)
    (1): DeepGCNLayer(block=res+)
  )
  (hidden1): Linear(in_features=8511, out_features=1000, bias=True)
  (hidden2): Linear(in_features=1000, out_features=100, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (decoder): Linear(in_features=100, out_features=1, bias=True)
)
100%|██████████| 55/55 [00:13<00:00,  3.98it/s]
('Test  Loss: 125.0052', 'Test MAE:8.8510', 'Test PCC:0.8129')
PGExlianer  Loss: 257.9779
PGExlianer  Loss: 190.6352
PGExlianer  Loss: 190.6016
PGExlianer  Loss: 190.6001
PGExlianer  Loss: 190.6001
100%|██████████| 3487/3487 [00:33<00:00, 105.24it/s]
explain end......


In [16]:
# 
id=Log_normalized_matrix_of_naive_cd4.iloc[:3487,:].index
id

Index(['meta100', 'meta101', 'meta102', 'meta103', 'meta104', 'meta105',
       'meta106', 'meta107', 'meta108', 'meta109',
       ...
       'meta11443', 'meta11444', 'meta11445', 'meta11446', 'meta11447',
       'meta11448', 'meta11449', 'meta11450', 'meta11451', 'meta11452'],
      dtype='object', length=3487)

In [17]:
node_exaction=pd.read_csv(f'NK_explain_edges_index.txt',header=None,sep='\n')
node_exaction.head(5)

Unnamed: 0,0
0,"6, 1827, 1.0], [6, 2700, 1.0], [6, 4458, 1.0],..."
1,"0, 3, 1.0], [3, 0, 1.0], [3, 83, 1.0], [3, 156..."
2,"5, 2195, 1.0], [5, 3206, 1.0], [5, 3516, 1.0],..."
3,"5, 383, 1.0], [5, 2195, 1.0], [5, 3206, 1.0], ..."
4,"7, 1160, 1.0], [11, 50, 1.0], [11, 83, 1.0], [..."


In [18]:
node_exaction = node_exaction[0].str.split(',',expand=True) 
node_exaction.head(5)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,290,291,292,293,294,295,296,297,298,299
0,6,1827,1.0],[6,2700,1.0],[6,4458,1.0],[6,...,1.0],[11,5615,1.0],[29,63,1.0],[29,1014,1.0
1,0,3,1.0],[3,0,1.0],[3,83,1.0],[3,...,1.0],[11,895,1.0],[11,900,1.0],[11,901,1.0
2,5,2195,1.0],[5,3206,1.0],[5,3516,1.0],[5,...,1.0],[11,1780,1.0],[11,1799,1.0],[11,1956,1.0
3,5,383,1.0],[5,2195,1.0],[5,3206,1.0],[5,...,1.0],[11,1363,1.0],[11,1383,1.0],[11,1630,1.0
4,7,1160,1.0],[11,50,1.0],[11,83,1.0],[11,...,1.0],[19,2158,1.0],[19,3852,1.0],[19,4329,1.0


In [19]:
node_exaction.index=id
node_exaction.head(5)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,290,291,292,293,294,295,296,297,298,299
meta100,6,1827,1.0],[6,2700,1.0],[6,4458,1.0],[6,...,1.0],[11,5615,1.0],[29,63,1.0],[29,1014,1.0
meta101,0,3,1.0],[3,0,1.0],[3,83,1.0],[3,...,1.0],[11,895,1.0],[11,900,1.0],[11,901,1.0
meta102,5,2195,1.0],[5,3206,1.0],[5,3516,1.0],[5,...,1.0],[11,1780,1.0],[11,1799,1.0],[11,1956,1.0
meta103,5,383,1.0],[5,2195,1.0],[5,3206,1.0],[5,...,1.0],[11,1363,1.0],[11,1383,1.0],[11,1630,1.0
meta104,7,1160,1.0],[11,50,1.0],[11,83,1.0],[11,...,1.0],[19,2158,1.0],[19,3852,1.0],[19,4329,1.0


In [20]:
node_exaction.dtypes

0      object
1      object
2      object
3      object
4      object
        ...  
295    object
296    object
297    object
298    object
299    object
Length: 300, dtype: object

In [21]:
data_adj["index"]=data_adj.index
data_adj=data_adj.reset_index()
data_adj.head(5)

Unnamed: 0,level_0,1,368,8754,5290,3172,3164,204,14,847,...,90139,9331,283358,8708,9227,124961,2529,117156,5251,index
0,1,0,1,1,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
1,368,1,0,0,0,1,1,1,0,0,...,0,0,0,0,0,0,0,0,0,368
2,8754,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,8754
3,5290,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,5290
4,3172,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,3172


In [22]:
data_adj["key"]=data_adj.index
data_adj.head(5)

Unnamed: 0,level_0,1,368,8754,5290,3172,3164,204,14,847,...,9331,283358,8708,9227,124961,2529,117156,5251,index,key
0,1,0,1,1,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,0
1,368,1,0,0,0,1,1,1,0,0,...,0,0,0,0,0,0,0,0,368,1
2,8754,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,8754,2
3,5290,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,5290,3
4,3172,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,3172,4


In [23]:
node_key_value=data_adj.set_index(["key"])["index"].to_dict()

In [24]:
import json
with open('node_key_value.json', 'w') as f:
    json.dump(node_key_value, f)

In [25]:
import json
with open('node_key_value.json' ) as f:
    df_test = json.load(f)

In [26]:
#df_test

In [27]:
node_exaction_clone2=node_exaction.copy(deep=True)
node_exaction_clone2.head(5)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,290,291,292,293,294,295,296,297,298,299
meta100,6,1827,1.0],[6,2700,1.0],[6,4458,1.0],[6,...,1.0],[11,5615,1.0],[29,63,1.0],[29,1014,1.0
meta101,0,3,1.0],[3,0,1.0],[3,83,1.0],[3,...,1.0],[11,895,1.0],[11,900,1.0],[11,901,1.0
meta102,5,2195,1.0],[5,3206,1.0],[5,3516,1.0],[5,...,1.0],[11,1780,1.0],[11,1799,1.0],[11,1956,1.0
meta103,5,383,1.0],[5,2195,1.0],[5,3206,1.0],[5,...,1.0],[11,1363,1.0],[11,1383,1.0],[11,1630,1.0
meta104,7,1160,1.0],[11,50,1.0],[11,83,1.0],[11,...,1.0],[19,2158,1.0],[19,3852,1.0],[19,4329,1.0


In [28]:
node_exaction_clone2.to_csv("edges_exaction_index.csv")

In [29]:
node_exaction_index=pd.read_csv(f'edges_exaction_index.csv',header=0,index_col=0,sep=',')
node_exaction_index.head(5)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,290,291,292,293,294,295,296,297,298,299
meta100,6,1827,1.0],[6,2700,1.0],[6,4458,1.0],[6,...,1.0],[11,5615,1.0],[29,63,1.0],[29,1014,1.0
meta101,0,3,1.0],[3,0,1.0],[3,83,1.0],[3,...,1.0],[11,895,1.0],[11,900,1.0],[11,901,1.0
meta102,5,2195,1.0],[5,3206,1.0],[5,3516,1.0],[5,...,1.0],[11,1780,1.0],[11,1799,1.0],[11,1956,1.0
meta103,5,383,1.0],[5,2195,1.0],[5,3206,1.0],[5,...,1.0],[11,1363,1.0],[11,1383,1.0],[11,1630,1.0
meta104,7,1160,1.0],[11,50,1.0],[11,83,1.0],[11,...,1.0],[19,2158,1.0],[19,3852,1.0],[19,4329,1.0


In [30]:
node_exaction_index2=node_exaction_index.copy(deep=True)
node_exaction_index2.head(5)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,290,291,292,293,294,295,296,297,298,299
meta100,6,1827,1.0],[6,2700,1.0],[6,4458,1.0],[6,...,1.0],[11,5615,1.0],[29,63,1.0],[29,1014,1.0
meta101,0,3,1.0],[3,0,1.0],[3,83,1.0],[3,...,1.0],[11,895,1.0],[11,900,1.0],[11,901,1.0
meta102,5,2195,1.0],[5,3206,1.0],[5,3516,1.0],[5,...,1.0],[11,1780,1.0],[11,1799,1.0],[11,1956,1.0
meta103,5,383,1.0],[5,2195,1.0],[5,3206,1.0],[5,...,1.0],[11,1363,1.0],[11,1383,1.0],[11,1630,1.0
meta104,7,1160,1.0],[11,50,1.0],[11,83,1.0],[11,...,1.0],[19,2158,1.0],[19,3852,1.0],[19,4329,1.0


In [31]:
for i in node_exaction_index2.columns:
  node_exaction_index2[i]=node_exaction_index2[i].astype(str).str.replace('[','')
node_exaction_index2.head(5)

  


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,290,291,292,293,294,295,296,297,298,299
meta100,6,1827,1.0],6,2700,1.0],6,4458,1.0],6,...,1.0],11,5615,1.0],29,63,1.0],29,1014,1.0
meta101,0,3,1.0],3,0,1.0],3,83,1.0],3,...,1.0],11,895,1.0],11,900,1.0],11,901,1.0
meta102,5,2195,1.0],5,3206,1.0],5,3516,1.0],5,...,1.0],11,1780,1.0],11,1799,1.0],11,1956,1.0
meta103,5,383,1.0],5,2195,1.0],5,3206,1.0],5,...,1.0],11,1363,1.0],11,1383,1.0],11,1630,1.0
meta104,7,1160,1.0],11,50,1.0],11,83,1.0],11,...,1.0],19,2158,1.0],19,3852,1.0],19,4329,1.0


In [32]:
for i in node_exaction_index2.columns:
    if (int(i)+1)% 3!=0:
       pass
    else:
       print(i)

2
5
8
11
14
17
20
23
26
29
32
35
38
41
44
47
50
53
56
59
62
65
68
71
74
77
80
83
86
89
92
95
98
101
104
107
110
113
116
119
122
125
128
131
134
137
140
143
146
149
152
155
158
161
164
167
170
173
176
179
182
185
188
191
194
197
200
203
206
209
212
215
218
221
224
227
230
233
236
239
242
245
248
251
254
257
260
263
266
269
272
275
278
281
284
287
290
293
296
299


In [33]:
node_exaction_index2.to_csv("edges_exaction_index1.csv")

In [34]:
node_exaction_index=pd.read_csv(f'edges_exaction_index1.csv',header=0,index_col=0,sep=',')
node_exaction_index.head(5)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,290,291,292,293,294,295,296,297,298,299
meta100,6,1827,1.0],6,2700,1.0],6,4458,1.0],6,...,1.0],11,5615,1.0],29,63,1.0],29,1014,1.0
meta101,0,3,1.0],3,0,1.0],3,83,1.0],3,...,1.0],11,895,1.0],11,900,1.0],11,901,1.0
meta102,5,2195,1.0],5,3206,1.0],5,3516,1.0],5,...,1.0],11,1780,1.0],11,1799,1.0],11,1956,1.0
meta103,5,383,1.0],5,2195,1.0],5,3206,1.0],5,...,1.0],11,1363,1.0],11,1383,1.0],11,1630,1.0
meta104,7,1160,1.0],11,50,1.0],11,83,1.0],11,...,1.0],19,2158,1.0],19,3852,1.0],19,4329,1.0


In [35]:
for i in node_exaction_index.columns:
  if (int(i)+1)% 3!=0:
     df=node_exaction_index[i]
     node_exaction_index[i]=node_exaction_index[i].map(node_key_value)
     df.update(node_exaction_index[i])
node_exaction_index.head(5)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,290,291,292,293,294,295,296,297,298,299
meta100,204,1854,1.0],204,84987,1.0],204,1396,1.0],204,...,1.0],93185,2495,1.0],22921,1994,1.0],22921,4738,1.0
meta101,1,5290,1.0],5290,1,1.0],5290,1655,1.0],5290,...,1.0],93185,10521,1.0],93185,220988,1.0],93185,3184,1.0
meta102,3164,2746,1.0],3164,51596,1.0],3164,9588,1.0],3164,...,1.0],93185,2197,1.0],93185,6192,1.0],93185,10541,1.0
meta103,3164,1059,1.0],3164,2746,1.0],3164,51596,1.0],3164,...,1.0],93185,3428,1.0],93185,6209,1.0],93185,8653,1.0
meta104,14,81631,1.0],93185,25962,1.0],93185,1655,1.0],93185,...,1.0],51296,9761,1.0],51296,64434,1.0],51296,483,1.0


In [36]:
node_exaction_index.to_csv("NK_explain_edges.csv")

In [37]:
len(node_exaction_index)

3487