##### 构建蛋白的图模型，以GNN为核心，执行多分类任务（未被标记的肽段上的残基记为0，被标记的肽段上的非标记残基记为1，被标记的肽段上的标记残基记为2）

In [None]:
import sys
sys.path.append('E:/Proteomics/PhD_script/1. Dizco/')
sys.path.append('D:/All_for_paper/1. PhD Work Program/3. Research project/1. Dizco/Ligand Discovery/fragment-embedding/')
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.utils import subgraph
from torch_geometric.data import Data
from fragmentembedding import FragmentEmbedder
from Protein2Graph import prot2graph
from StructureInformationIntegration import data_integration
from os import listdir
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler,label_binarize
from sklearn.metrics import roc_auc_score,precision_score,recall_score,f1_score,accuracy_score,matthews_corrcoef,average_precision_score,auc,precision_recall_curve
from models import dizco_GNN
from torch.optim import lr_scheduler
from EarlyStopping import EarlyStopping
from tqdm import tqdm
from multiprocessing import Pool
from rdkit import RDLogger
import pickle
import copy
import warnings
warnings.filterwarnings('ignore')

In [None]:
#取蛋白的结构数据为特征一，将每个蛋白都构建成一张图进行GNN训练，同时加入probe的embedding数据，以及每一个氨基酸的结构信息
#对每个蛋白的氨基酸位点进行标记
#质谱有打到但预测评分小于0.85的肽段(不论是否有修饰)上的氨基酸全部标记为0
#对于有修饰的肽段，修饰位点标为2，其余标为1
#没有修饰但评分高于0.85的肽段单拎出来作为最后的预测数据
def ResLabel(table,thre=0.85):
    """Build a per-residue label map for every protein found in one PSM table.

    Rows with ``lg_probs`` above ``thre`` and a non-null ``label_site`` are
    parsed as labelled peptides (residue labels 1 and 2).  Rows with
    ``lg_probs`` at or below ``thre`` are parsed as unlabelled peptides
    (label 0).  Rows above ``thre`` without a ``label_site`` are not used here.
    When a residue receives several labels, the highest one is kept.

    Args:
        table: PSM table with the ``lg_probs`` and ``label_site`` columns plus
            the columns read by ``ParseRes``.
        thre: Score cut-off separating labelled from unlabelled peptides.

    Returns:
        dict mapping accession -> ``{f'{residue}{1-based position}': label}``.
        Residues that received no label hold ``np.nan``.

    Note:
        Reads the module-level ``uniprot_infor`` table and takes the sequence
        from its last column (presumably the UniProt sequence column --
        TODO confirm).
    """
    confident = table[table['lg_probs']>thre]
    label_table = confident[~confident['label_site'].isna()]
    unlabel_table = table[table['lg_probs']<=thre]

    labeledPEP = ParseRes(label_table,state='labeled')
    unlabeledPEP = ParseRes(unlabel_table,state='unlabeled')
    res_table = pd.concat([labeledPEP,unlabeledPEP],axis=0)
    # Highest label first, so drop_duplicates keeps the strongest label per residue.
    res_table = res_table.sort_values(by='label',ascending=False).reset_index(drop=True)
    res_table = res_table.drop_duplicates(subset=['Accession','res_site']).reset_index(drop=True)

    # One entry per sequence position, initialised to "no label".
    prot_label = {}
    for prot in res_table['Accession'].unique():
        seq = uniprot_infor[uniprot_infor['Entry']==prot].iloc[0,-1]
        prot_label[prot] = {f'{p}{i+1}':np.nan for i,p in enumerate(seq)}

    # A separate loop name is used so the ``table`` argument is not shadowed.
    for prot,prot_rows in res_table.groupby(by='Accession'):
        site_map = prot_label[prot]
        for site,label in zip(prot_rows['res_site'].to_list(),prot_rows['label'].to_list()):
            # Ignore sites that are not real sequence positions (e.g. 'A0' from a
            # peptide that could not be located); writing them would add bogus
            # keys that later become invalid graph node indices.
            if site in site_map:
                site_map[site] = label

    return prot_label
    
def ParseRes(data,state='labeled'):
    """Expand peptide-level PSM rows into one row per residue.

    For every (protein, peptide) pair the peptide is located in the protein
    sequence and each of its residues is emitted as
    ``(accession, f'{residue}{1-based protein position}', label)``.

    Args:
        data: PSM table with the ``Master Protein Accessions``, ``Upper_Seq``
            and ``label_site`` columns.
        state: ``'labeled'`` gives label 2 to residues whose peptide-relative
            key ``f'{residue}{position in peptide}'`` is listed in
            ``label_site`` and label 1 to the rest.  ``'unlabeled'`` gives
            label 0 to every residue.  Any other value yields no rows.

    Returns:
        DataFrame with columns ``Accession``, ``res_site`` and ``label``,
        sorted by label (descending) and de-duplicated on
        (``Accession``, ``res_site``).  It is empty, but still carries the
        three columns, when nothing could be parsed.

    Note:
        Reads the module-level ``uniprot_infor`` table and takes the sequence
        from its last column (presumably the UniProt sequence column --
        TODO confirm).  Only the first occurrence of a peptide in the
        sequence is used.
    """
    columns = ['Accession','res_site','label']
    records = []
    for (prot,pep),table in data.groupby(by=['Master Protein Accessions','Upper_Seq']):
        entry = uniprot_infor[uniprot_infor['Entry']==prot]
        if entry.empty:
            # Accession unknown to the UniProt table: nothing to map onto.
            continue
        seq = entry.iloc[0,-1]

        offset = seq.find(pep)
        if offset < 0:
            # Peptide not present in the sequence; without this guard the
            # residues would be numbered from 0 and mislabel the protein.
            continue
        start = offset+1

        label_site = set(table['label_site'].unique())
        for i,p in enumerate(pep):
            if state == 'labeled':
                label = 2 if f'{p}{i+1}' in label_site else 1
            elif state == 'unlabeled':
                label = 0
            else:
                continue
            records.append((prot,f'{p}{start+i}',label))

    # Build the frame once (a concat per peptide is quadratic) and keep the
    # columns even when empty, so the sort below cannot raise KeyError.
    pep_label_result = pd.DataFrame(records,columns=columns).astype({'label':'int64'})
    pep_label_result = pep_label_result.sort_values(by='label',ascending=False).reset_index(drop=True)

    return pep_label_result.drop_duplicates(subset=['Accession','res_site'])
    
def prot2pdb(prot_list,pdb_path):
    """Pick one structure file per requested protein from ``pdb_path``.

    File names are split on ``'-'`` and the second field is taken as the
    accession (this matches AlphaFold-style names such as
    ``AF-<accession>-F1-...`` -- TODO confirm the naming used in the folder).
    Names without a second field are ignored.

    Args:
        prot_list: Iterable of accessions to look for.
        pdb_path: Directory containing the structure files.

    Returns:
        List of file names, at most one per accession.  Directory entries are
        visited in sorted order, so the selected file is deterministic.
    """
    wanted = set(prot_list)
    selected,seen = [],set()
    for name in sorted(listdir(pdb_path)):
        parts = name.split('-')
        if len(parts) < 2:
            continue
        accession = parts[1]
        # Exact match on the accession field: substring matching would also
        # accept files of longer accessions that merely contain a requested id.
        if accession in wanted and accession not in seen:
            seen.add(accession)
            selected.append(name)
    return selected

def DataSetSplit(dataSet):
    """Split the samples into train, validation and test subsets.

    10% of the samples are held out for testing, then 1/9 of the remainder for
    validation; both splits are shuffled with ``random_state=42``.  Each subset
    is passed through ``DataCombine``; the test subset is combined with
    ``sample_type='test_set'``.

    Args:
        dataSet: Indexable collection of samples accepted by ``DataCombine``.

    Returns:
        Tuple ``(train_set, val_set, test_set)`` of ``DataCombine`` outputs.
    """
    all_idx = np.arange(len(dataSet))

    remaining_idx, held_out_idx = train_test_split(all_idx,test_size=0.1,random_state=42,shuffle=True)
    fit_idx, valid_idx = train_test_split(remaining_idx,test_size=1/9,random_state=42,shuffle=True)

    def _pick(indices):
        # Gather the samples addressed by an index array, in that order.
        return [dataSet[k] for k in indices]

    train_samples = _pick(fit_idx)
    val_samples = _pick(valid_idx)
    test_samples = _pick(held_out_idx)

    train_set = DataCombine(train_samples)
    val_set = DataCombine(val_samples)
    test_set = DataCombine(test_samples,sample_type='test_set')
    return train_set,val_set,test_set
    
def DataCombine(data,sample_type=None):
    """Convert a list of samples into parallel lists of model inputs.

    Each sample is ``(graph, structure_table, probe_embedding, labels)``.  The
    structure tables are stacked so that ``'Secondary structure'`` is one-hot
    encoded with the same columns for every sample of this call.  Each
    sample's feature columns are then standardised on their own.

    Args:
        data: Sequence of samples as described above.
        sample_type: When ``None`` the graphs are wrapped, together with their
            sample index, in a shuffling ``DataLoader`` with batch size 16.
            Any other value leaves the graphs as a plain list.

    Returns:
        ``(graph_set, str_set, probe_emd_set, label_set)``: the graph loader or
        list, plus lists of float32 structure tensors, float32 probe tensors
        and int64 label tensors, all indexed by sample position.

    Raises:
        ValueError: If ``data`` is empty.

    Note:
        The one-hot columns depend on the categories present in ``data``, so
        two calls can produce different feature widths -- NOTE(review): verify
        that train/val/test all contain every secondary-structure class.
        The first three columns after tagging (``sample_index`` plus the first
        two table columns) are excluded from the features; presumably those
        are identifiers -- TODO confirm.
    """
    if len(data) == 0:
        raise ValueError('DataCombine received no samples')

    graph_set,probe_emd_set,label_set,str_frames = [],[],[],[]
    for i,(prot_graph,str_table,probeEmbedding,labels) in enumerate(data):
        tagged = str_table.copy()
        tagged.insert(0,'sample_index',i)
        str_frames.append(tagged)

        graph_set.append(prot_graph)
        probe_emd_set.append(torch.tensor(np.asarray(probeEmbedding).astype(np.float32)))
        # Labels are class ids: convert straight to int64 instead of going
        # through float32 first.
        label_set.append(torch.tensor(np.asarray(labels).astype(np.int64),dtype=torch.long))

    # A single concat; concatenating inside the loop is quadratic.
    str_df = pd.concat(str_frames,axis=0)
    str_df = pd.get_dummies(str_df, columns = ['Secondary structure'],dtype=int)

    sample_ids = str_df['sample_index'].to_numpy()
    str_set = []
    for i in range(len(data)):
        features = str_df[sample_ids==i].iloc[:,3:]
        scaled = StandardScaler().fit_transform(features)
        str_set.append(torch.tensor(np.asarray(scaled).astype(np.float32)))

    if sample_type is None:
        # Keep the sample index next to each graph so the training loop can
        # look up the matching structure/probe/label tensors per batch.
        graph_set = [tuple((i,graph)) for i,graph in enumerate(graph_set)]
        graph_set = DataLoader(graph_set, batch_size=16, drop_last=False, shuffle=True)

    return (graph_set,str_set,probe_emd_set,label_set)


In [None]:
#因为进行多分类任务，所以对应的损失函数也有相应改变，模型最后一层的输出需要用的softmax
def model_train(model,train_set,val_set,save_best_model=True,model_path=None,save_path=None,earlyStop=True,epochs=100,lr=0.0001):
    """Train ``model`` with cross-entropy loss, evaluating on ``val_set`` after every epoch.

    Args:
        model: Network called as ``model(graph_batch, struct_features, probe_features)``.
        train_set: 4-tuple ``(graph_loader, struct_list, probe_list, label_list)``;
            the loader yields ``(sample_indices, graph_batch)`` pairs.
        val_set: Same layout as ``train_set``; forwarded to ``model_eval``.
        save_best_model: Report the epoch with the highest validation MCC and,
            when ``model_path`` is given, save its weights.
        model_path: Prefix to which ``'best_model.pth'`` is appended.  When
            ``None`` nothing is written.
        save_path: Passed to ``EarlyStopping``.  When ``None`` early stopping
            is disabled regardless of ``earlyStop``.
        earlyStop: Enable early stopping on the validation loss.
        epochs: Maximum number of epochs.
        lr: Adam learning rate; the cyclic schedule runs between ``lr/10`` and
            ``lr*10``.

    Returns:
        DataFrame with one row per completed epoch: the mean training loss
        followed by the values returned by ``model_eval``.  When early stopping
        triggers, the function returns immediately without the best-model
        save/report step.

    Note:
        Uses the module-level ``device``.
    """
    if save_path is not None: early_stopping = EarlyStopping(save_path)
    else: earlyStop = False
    metrics_name = ['train_loss','val_loss','acc','pre0', 'pre1', 'pre2','recall','f1','ap','mcc','auc_score','prc_score_0','prc_score_1','prc_score_2']
    # model_eval returns the metrics named in metrics_name[1:], so the MCC
    # position is derived from the names instead of a hard-coded index.
    mcc_idx = metrics_name.index('mcc') - 1

    prot_graphs,prot_struct,probe,labels = train_set
    n_batches = max(len(prot_graphs),1)

    # Loss function, optimiser and learning-rate schedule.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=lr,weight_decay=1e-3)
    scheduler = lr_scheduler.CyclicLR(optimizer, base_lr=lr/10, max_lr=lr*10,
                                      cycle_momentum=False,step_size_up=len(prot_graphs))

    metrics_result = []
    best_mcc,best_state,best_metrics = None,None,None
    for epoch in tqdm(range(epochs)):
        model.train()
        total_loss = 0

        for batch_idx,graph_values in prot_graphs:
            # Gather the per-sample tensors that belong to the graphs in this batch.
            struct_values = torch.cat([prot_struct[i] for i in batch_idx],axis=0)
            probe_values = torch.cat([probe[i] for i in batch_idx],axis=0)
            label_values = torch.cat([labels[i] for i in batch_idx],axis=0)
            graph_values, struct_values, probe_values, label_values = graph_values.to(device), struct_values.to(device), probe_values.to(device), label_values.to(device)

            optimizer.zero_grad()
            output = model(graph_values, struct_values, probe_values)
            loss = criterion(output, label_values)
            total_loss += loss.item()
            # NOTE(review): retain_graph=True is kept from the original; it is
            # normally unnecessary and costs memory -- confirm the model needs it.
            loss.backward(retain_graph=True)
            optimizer.step()
            scheduler.step()

        # Mean loss per batch (train_set itself is a 4-tuple, so its length is not the batch count).
        train_loss = total_loss / n_batches
        eval_matrics = model_eval(model,val_set,criterion)
        metrics_result.append(tuple((train_loss,))+eval_matrics)

        if earlyStop:
            early_stopping(eval_matrics[0], model)
            if early_stopping.early_stop:
                print("Early stopping")
                return pd.DataFrame(metrics_result,columns=metrics_name)

        if best_mcc is None or eval_matrics[mcc_idx] > best_mcc:
            best_mcc = eval_matrics[mcc_idx]
            # Only the weights are saved later, so snapshot the state dict.
            best_state = copy.deepcopy(model.state_dict())
            best_metrics = tuple((epoch,train_loss))+eval_matrics

    if save_best_model and best_metrics is not None:
        if model_path is not None:
            torch.save(best_state,model_path+'best_model.pth')
        else:
            print('model_path is None, the best model was not saved')
        print('The metrics of best model are:')
        print(f'epoch: {best_metrics[0]}, train_loss: {best_metrics[1]}, val_loss: {best_metrics[2]}')
        print(f'acc: {best_metrics[3]}, pre0: {best_metrics[4]}, pre1: {best_metrics[5]}, pre2: {best_metrics[6]}')
        print(f'recall: {best_metrics[7]}, f1: {best_metrics[8]}, ap: {best_metrics[9]}')
        print(f'mcc: {best_metrics[10]}, AUC: {best_metrics[11]}, PRC0: {best_metrics[12]}, PRC1: {best_metrics[13]}, PRC2: {best_metrics[14]}')

    return pd.DataFrame(metrics_result,columns=metrics_name)
        
def model_eval(model,val_set,criterion,thre=0.5):
    """Evaluate ``model`` on ``val_set`` and return the loss followed by the metrics.

    Args:
        model: Network called as ``model(graph_batch, struct_features, probe_features)``.
        val_set: 4-tuple ``(graph_loader, struct_list, probe_list, label_list)``;
            the loader yields ``(sample_indices, graph_batch)`` pairs.
        criterion: Loss applied to the raw model output and the labels.
        thre: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(val_loss, *metrics)`` where ``val_loss`` is the mean loss per
        batch and ``metrics`` is the tuple returned by ``metrics_calculation``.

    Note:
        Uses the module-level ``device``.  Class probabilities are obtained by
        applying softmax to the model output -- NOTE(review): confirm the
        model returns logits rather than already-normalised probabilities.
    """
    model.eval()
    total_loss = 0
    y_true,y_pred,y_prob = [],[],[]
    prot_graphs,prot_struct,probe,labels = val_set

    with torch.no_grad():
        for batch_idx,graph_values in prot_graphs:
            # Gather the per-sample tensors that belong to the graphs in this batch.
            struct_values = torch.cat([prot_struct[i] for i in batch_idx],axis=0)
            probe_values = torch.cat([probe[i] for i in batch_idx],axis=0)
            label_values = torch.cat([labels[i] for i in batch_idx],axis=0)
            graph_values, struct_values, probe_values, label_values = graph_values.to(device), struct_values.to(device), probe_values.to(device), label_values.to(device)

            output = model(graph_values, struct_values, probe_values)
            loss = criterion(output, label_values)
            total_loss += loss.item()
            prob = F.softmax(output, dim=1).detach().cpu().numpy()
            y_true.extend(label_values.detach().cpu().tolist())
            y_pred.extend(output.argmax(dim=1).detach().cpu().tolist())
            y_prob.extend(prob)

    # Mean loss per batch (val_set itself is a 4-tuple, so its length is not the batch count).
    val_loss = total_loss/max(len(prot_graphs),1)
    eval_matrics = metrics_calculation(y_true,y_pred,y_prob)

    return tuple((val_loss,))+eval_matrics

def metrics_calculation(y_true,y_pred,y_prob):
    """Compute the three-class evaluation metrics.

    Args:
        y_true: True class ids (0, 1 or 2), one per residue.
        y_pred: Predicted class ids, one per residue.
        y_prob: Per-class probabilities, one row of three values per residue.

    Returns:
        Tuple ``(acc, pre0, pre1, pre2, recall, f1, ap, mcc, auc_score,
        prc_score_0, prc_score_1, prc_score_2)``: accuracy, per-class
        precision, macro recall, macro F1, macro average precision, Matthews
        correlation, one-vs-rest ROC AUC and the per-class area under the
        precision-recall curve.
    """
    y_true = np.asarray(y_true,dtype=np.float64)
    y_pred = np.asarray(y_pred,dtype=np.float64)
    y_prob = np.asarray(y_prob,dtype=np.float64)
    class_ids = [0, 1, 2]
    y_true_binarized = label_binarize(y_true, classes=class_ids)

    auc_score = roc_auc_score(y_true,y_prob,multi_class='ovr')
    acc = accuracy_score(y_true,y_pred)
    # Pin the label set so three values always come back, even when a class is
    # absent from both y_true and y_pred (its precision is then reported as 0).
    pre0,pre1,pre2 = precision_score(y_true,y_pred,labels=class_ids,average=None,zero_division=0)
    recall = recall_score(y_true,y_pred,average='macro')
    f1 = f1_score(y_true,y_pred,average='macro')
    mcc = matthews_corrcoef(y_true,y_pred)
    # The one-hot targets are accepted by every scikit-learn version; 1-D
    # multiclass targets are rejected by older releases.
    ap = average_precision_score(y_true_binarized,y_prob,average='macro')

    prc_score_lt = []
    for i in range(3):
        precision,recall_points,_ = precision_recall_curve(y_true_binarized[:,i],y_prob[:,i])
        prc_score_lt.append(auc(recall_points,precision))

    return tuple((acc,pre0,pre1,pre2,recall,f1,ap,mcc,auc_score,prc_score_lt[0],prc_score_lt[1],prc_score_lt[2]))


In [None]:
#1. Load the merged PSM data of the eight probes and the UniProt annotation table
# Root folder of the project; every other path in the notebook is built from it.
path = 'D:/All_for_paper/1. PhD Work Program/3. Research project/1. Dizco/'
file_path = f'{path}Test files/'
merge_PSM = pd.read_csv(f'{file_path}merge_isoTOP_PSM_data.csv')
# Tab-separated UniProt export; ResLabel/ParseRes read the sequence from its last column.
uniprot_infor = pd.read_csv(f'{file_path}uniprotkb_human_AND_reviewed_true_AND_m_2024_09_12.tsv',sep='\t')
# Keep only the first space-separated gene name of each entry.
uniprot_infor['Gene Names'] = uniprot_infor['Gene Names'].str.split(' ',expand=True)[0]

In [None]:
#1.1 Build a graph for every protein identified by the eight probes
pdb_path = f'{path}AlphaFold_pdbFiles/'
# NOTE(review): ``path`` already ends with '/', so this contains a double slash;
# left unchanged because the same string is reused when the graphs are loaded.
processed_path = f'{path}/prot_raw_graph_coord/'
# Keep only accessions present in the UniProt table; a set lookup avoids
# scanning the whole table once per protein.
known_entries = set(uniprot_infor['Entry'].dropna())
prot_list = merge_PSM['Master Protein Accessions'].unique()
prot_list = [prot for prot in prot_list if prot in known_entries]
prot_pdb_list = prot2pdb(prot_list,pdb_path)

pg = prot2graph(processed_path)
with Pool(12) as p:
    # Iterate over the results so the progress bar advances as graphs finish;
    # wrapping the input of ``map`` only measures how fast the task list is read.
    for _ in tqdm(p.imap(pg.CreateCoordGraph,prot_pdb_list),total=len(prot_pdb_list)):
        pass

In [None]:
#1.2 Summarise the structural information of every protein identified by the eight probes
with Pool(8) as p:
    structure_data = p.map(data_integration,prot_pdb_list)
# Each worker result is a dict; merge them into a single lookup.
structure_data_dic = {}
for data in structure_data: structure_data_dic.update(data)

# Use a context manager so the file is flushed and closed before it is read back.
with open(f'{path}pkl_files/structure_infor_summary.pkl','wb') as pkl_file:
    pickle.dump(structure_data_dic,pkl_file)

In [None]:
#1.3 Generate one embedding per probe
# Probe name -> SMILES string of the probe.
probe_smile = {
    
    'AJ5': 'C#CCCC1(N=N1)CCC(NC)=O',
    'AJ8': 'C#CCCC1(N=N1)CCC(NC[C@H]2[C@@H](C)CCCN2CC3=CC=C(OC)C=C3)=O',
    'AJ12': 'C#CCCC1(N=N1)CCNC(/C(C2=CC=CC=C2)=C/C3=CC=CC=C3)=O',
    'AJ14': 'C#CCCC1(N=N1)CCC(NC(C2(C[C@H](C3)C4)C[C@H]4C[C@H]3C2)C)=O',
    'AJ22': 'C#CCCC1(N=N1)CCNC(/C(CC)=C/C2=CC=CC([N+]([O-])=O)=C2)=O',
    'AJ32': 'C#CCCC1(N=N1)CCNC(C2(CC2)C3=CC(OC(F)(F)O4)=C4C=C3)=O',
    'AJ39': 'C#CCCC1(N=N1)CCNC(CCC2=NC(C3=CC=CC=C3)=C(C4=CC=CC=C4)O2)=O',
    'CP78': 'C#CCCC1(N=N1)CCC(N2C(CC3=CC=CC=C3)CCCC2)=O'
    
    }

# Raise the RDKit log level to CRITICAL so lower-severity messages are hidden.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
fe = FragmentEmbedder()
# One embedding per probe; each SMILES is passed as a single-element list.
probe_emd = {probe:fe.transform([smile]) for probe,smile in probe_smile.items()}

In [None]:
#1.4 Re-define each protein graph per probe, according to the peptides and binding sites
#    that probe identified, and assemble the samples
# Some proteins are missing from AlphaFold, possibly because UniProt updated their IDs
# NOTE: pickle can execute code while loading; only load files written by this notebook.
with open(f'{path}pkl_files/structure_infor_summary.pkl','rb') as pkl_file:
    str_infor_dic = pickle.load(pkl_file)

dataSet = []
for probe,table in tqdm(merge_PSM.groupby(by='probe')):
    prot_label = ResLabel(table,thre=0.85)
    for prot,labels in prot_label.items():
        # Only a missing graph file is skipped; other load errors should surface
        # instead of being reported as "not found".
        try: prot_graph = torch.load(f'{processed_path}{prot}.pt')
        except FileNotFoundError:
            print(f'The graph of {prot} could not be found, please check whether there is a PDB file of it')
            continue
        # (0-based residue index, label) for every residue that received a label.
        node_update = np.array([tuple((int(res[1:])-1,label)) for res,label in labels.items() if not np.isnan(label)])

        # Best-effort step: any failure here (e.g. a residue index outside the
        # graph) is reported with the residue counts and the protein is skipped.
        try:
            node_index = torch.tensor(node_update[:,0])
            new_edge,_ = subgraph(subset=node_index,edge_index=prot_graph.edge_index,relabel_nodes=True)
        except Exception:
            print(f'{prot}')
            print(f'No. res in PDB: {prot_graph.x.shape[0]}')
            print(f'No. res in Uniprot: {len(labels)}')
            continue

        # Keep only the labelled residues as nodes of the new graph.
        new_X = prot_graph.x[node_index]
        new_Y = node_update[:,1]
        new_graph = Data(x=new_X,edge_index=new_edge)

        str_table = str_infor_dic[prot].copy()
        str_table = str_table.iloc[node_update[:,0],:]

        # Repeat the probe embedding once per retained residue.
        probeEmbedding = probe_emd[probe]
        probeEmbedding = probeEmbedding.repeat(len(node_update),axis=0)

        dataSet.append(tuple((new_graph,str_table,probeEmbedding,new_Y)))

In [None]:
#2. Split the samples into training, validation and test sets
# Each returned set is the 4-tuple (graphs, structure tensors, probe tensors, label tensors).
train_set,val_set,test_set = DataSetSplit(dataSet)

In [None]:
#3. Model training and evaluation
# Input widths are read from the first test sample: graph node features,
# structure features and probe embedding.
graph_dim,struct_dim,probe_dim = test_set[0][0].x.shape[1],test_set[1][0].shape[1],test_set[2][0].shape[1]
# ``device`` is read as a global by model_train and model_eval.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model_path = 'E:/Proteomics/PhD_script/1. Dizco/'

dizco_gnn = dizco_GNN(graph_dim,struct_dim,probe_dim).to(device)
# No save_path is passed, so model_train disables early stopping for this run.
metrics_result = model_train(dizco_gnn,train_set,val_set,model_path=model_path,lr=0.0001)

# NOTE(review): the file name says lr(1e-5) while lr=0.0001 is passed above;
# 1e-5 matches the scheduler's base_lr (lr/10) -- confirm which one is meant.
metrics_result.to_csv('D:/All_for_paper/1. PhD Work Program/3. Research project/1. Dizco/Model_Metrics_Result/Adam_lr(1e-5)_batch(16)(GNN).csv',index=False)