##### 构建蛋白的图模型，以GNN为核心，执行二分类任务（蛋白的氨基酸是否有被标记，或蛋白是否被标记）

In [None]:
import sys
sys.path.append('E:/Proteomics/PhD_script/1. Dizco/')
sys.path.append('D:/All_for_paper/1. PhD Work Program/3. Research project/1. Dizco/Ligand Discovery/fragment-embedding/')
import torch
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.utils import subgraph
from torch_geometric.data import Data
from fragmentembedding import FragmentEmbedder
from Protein2Graph import prot2graph
from StructureInformationIntegration import data_integration
from os import listdir
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score,precision_score,recall_score,f1_score,accuracy_score,matthews_corrcoef,average_precision_score,brier_score_loss,auc,precision_recall_curve
from models import dizco_GNN
from torch.optim import lr_scheduler
from EarlyStopping import EarlyStopping
from tqdm import tqdm
from multiprocessing import Pool
from rdkit import RDLogger
import pickle
import copy
import warnings
warnings.filterwarnings('ignore')

In [None]:
# For every residue of each protein sequence, record whether it was labelled by the probe.
def ResLabel(table,thre=0.85):
    """Build a per-protein, per-residue label dictionary from one PSM table.

    Rows with ``lg_probs`` above ``thre`` and a non-missing ``label_site`` are
    parsed as labelled peptides; rows with ``lg_probs`` at or below ``thre``
    are parsed as unlabelled peptides. When the same residue is reported more
    than once, the highest label wins.

    Args:
        table: PSM DataFrame with the columns ``lg_probs`` and ``label_site``
            plus whatever ``ParseRes`` reads.
        thre: probability threshold separating labelled from unlabelled rows.

    Returns:
        dict mapping accession -> {"<residue letter><1-based position>": label},
        where residues that were never observed keep ``np.nan``.

    Note:
        Sequences are read from the module-level ``uniprot_infor`` table
        (last column of the row whose ``Entry`` equals the accession).
    """
    confident = table[table['lg_probs'] > thre]
    confident = confident[confident['label_site'].notna()]
    remainder = table[table['lg_probs'] <= thre]

    merged = pd.concat(
        [ParseRes(confident,state='labeled'), ParseRes(remainder,state='unlabeled')],
        axis=0,
    )
    # Sort so that label 1 precedes label 0; de-duplication then keeps the 1.
    merged = merged.sort_values(by='label',ascending=False).reset_index(drop=True)
    merged = merged.drop_duplicates(subset=['Accession','res_site']).reset_index(drop=True)

    # Start every residue of every observed protein as "not observed" (NaN).
    prot_label = {}
    for accession in merged['Accession'].unique():
        sequence = uniprot_infor[uniprot_infor['Entry']==accession].iloc[0,-1]
        prot_label[accession] = {
            f'{residue}{position}': np.nan
            for position,residue in enumerate(sequence, start=1)
        }

    # Overwrite the observed residues with their labels.
    for accession,rows in merged.groupby(by='Accession'):
        prot_label[accession].update(
            zip(rows['res_site'].to_list(), rows['label'].to_list())
        )

    return prot_label

# Decide from the dizco data whether each residue was labelled: 1 if so, otherwise 0.
# The residue labels can be lifted to protein level: a protein is 1 if any
# of its residues is 1, otherwise 0.
def ParseRes(data,state='labeled',uniprot_table=None):
    """Expand peptide-level PSM rows into one row per protein residue.

    Each (protein, peptide) group is located inside the protein sequence and
    every residue of the peptide becomes a record
    ``(Accession, "<residue letter><1-based protein position>", label)``.

    Args:
        data: DataFrame with the columns ``Master Protein Accessions``,
            ``Upper_Seq`` and ``label_site``. ``label_site`` values are
            compared against ``"<residue letter><1-based peptide position>"``.
        state: ``'labeled'`` marks the residues listed in ``label_site`` with 1
            and all other residues with 0; ``'unlabeled'`` marks every residue
            with 0. Any other value produces no records.
        uniprot_table: table with an ``Entry`` column whose last column holds
            the sequence. Defaults to the module-level ``uniprot_infor``.

    Returns:
        DataFrame with columns ``Accession``, ``res_site`` and ``label``,
        de-duplicated per residue so that a label of 1 wins over 0. The frame
        is empty (but keeps its columns) when nothing could be parsed.

    Peptide groups are skipped when the protein is missing from the table,
    when its sequence is not a string, or when the peptide does not occur in
    the sequence (previously such peptides were numbered from position 0 and
    produced wrong residue sites).
    """
    if uniprot_table is None:
        uniprot_table = uniprot_infor

    records = []
    for (prot,pep),table in data.groupby(by=['Master Protein Accessions','Upper_Seq']):
        entry = uniprot_table[uniprot_table['Entry']==prot]
        if entry.empty:
            continue
        seq = entry.iloc[0,-1]
        if not isinstance(seq,str):
            continue

        offset = seq.find(pep)
        if offset < 0:
            # The peptide is not part of this sequence; numbering would be meaningless.
            continue
        start = offset+1

        label_site = set(table['label_site'].unique())
        for i,p in enumerate(pep):
            if state == 'labeled':
                label = 1 if f'{p}{i+1}' in label_site else 0
            elif state == 'unlabeled':
                label = 0
            else:
                continue
            records.append((prot,f'{p}{start+i}',label))

    # Build the frame once; explicit columns keep the empty case sortable.
    pep_label_result = pd.DataFrame(records,columns=['Accession','res_site','label'])
    pep_label_result = pep_label_result.astype({'label':'int64'})
    pep_label_result = pep_label_result.sort_values(by='label',ascending=False).reset_index(drop=True)

    return pep_label_result.drop_duplicates(subset=['Accession','res_site'])

In [None]:
# Map protein accession names to the AlphaFold PDB file names found in a directory.
def prot2pdb(prot_list,pdb_path):
    """Return one PDB file name per requested accession.

    The accession of a file is taken to be the second ``-``-separated field of
    its name (presumably the ``AF-<accession>-F1-...`` naming — confirm against
    the directory contents).

    Args:
        prot_list: iterable of accession strings to look for.
        pdb_path: directory whose file names are scanned.

    Returns:
        List of file names, at most one per accession, in sorted file-name
        order so the result does not depend on the directory listing order.

    Accessions are matched exactly. The previous substring test
    (``prot in name``) also selected files whose accession merely contained a
    requested accession, and raised ``IndexError`` for matching names without
    a ``-``; such names are now ignored.
    """
    wanted = set(prot_list)
    taken = set()
    prot_pdb_list = []
    for name in sorted(listdir(pdb_path)):
        fields = name.split('-')
        if len(fields) < 2:
            continue
        accession = fields[1]
        if accession in wanted and accession not in taken:
            taken.add(accession)
            prot_pdb_list.append(name)
    return prot_pdb_list

In [None]:
# Split the data set 8:1:1 into training, validation and test parts.
def DataSetSplit(dataSet):
    """Split ``dataSet`` into train/validation/test subsets and combine each.

    10% of the samples are held out for testing; one ninth of the remainder
    becomes the validation set. Both splits are shuffled with
    ``random_state=42``.

    Args:
        dataSet: indexable collection of samples accepted by ``DataCombine``.

    Returns:
        ``(train_set, val_set, test_set)`` as produced by ``DataCombine``; the
        test set is combined with ``sample_type='test_set'``.
    """
    all_positions = np.array(range(len(dataSet)))

    trainval_positions, test_positions = train_test_split(
        all_positions, test_size=0.1, random_state=42, shuffle=True
    )
    train_positions, val_positions = train_test_split(
        trainval_positions, test_size=1/9, random_state=42, shuffle=True
    )

    def pick(positions):
        # Materialise the samples that sit at the given positions.
        return [dataSet[pos] for pos in positions]

    train_samples = pick(train_positions)
    val_samples = pick(val_positions)
    test_samples = pick(test_positions)

    return (
        DataCombine(train_samples),
        DataCombine(val_samples),
        DataCombine(test_samples,sample_type='test_set'),
    )

In [None]:
# Combine the protein graphs, the probe embeddings and the labels.
def DataCombine(data,sample_type=None,batch_size=8):
    """Regroup ``(graph, probe_embedding, labels)`` samples into three parallel lists.

    Args:
        data: iterable of ``(prot_graph, probeEmbedding, labels)`` triples.
        sample_type: when ``None`` the graphs are wrapped in a shuffled
            ``DataLoader`` that yields ``(sample_index, graph)`` pairs and
            drops the last incomplete batch; any other value (e.g.
            ``'test_set'``) leaves the graphs as a plain list.
        batch_size: batch size of that ``DataLoader`` (previously hard-coded
            to 8, which remains the default).

    Returns:
        ``(graph_set, probe_emd_set, label_set)`` where the last two are lists
        of float32 tensors, indexable by the sample index carried in each
        loader batch.
    """
    graph_set,probe_emd_set,label_set = [],[],[]
    for prot_graph,probeEmbedding,labels in data:
        graph_set.append(prot_graph)
        probe_emd_set.append(torch.tensor(np.asarray(probeEmbedding).astype(np.float32)))
        label_set.append(torch.tensor(np.asarray(labels).astype(np.float32)))

    if sample_type is None:
        # Keep the sample index next to each graph so that the matching
        # embedding and label tensors can be looked up per batch.
        indexed_graphs = list(enumerate(graph_set))
        graph_set = DataLoader(indexed_graphs, batch_size=batch_size, drop_last=True, shuffle=True)

    return (graph_set,probe_emd_set,label_set)

In [None]:
# The graph loader and the probe embeddings are loaded separately, so the batch
# indices are kept to fetch the data that belongs to each batch.
# They are separate because, for residue-level prediction, the probe embedding
# and label tensors have a different length per protein and cannot be stacked.
# Protein-level prediction would not need this.
def model_train(model,train_set,val_set,save_best_model=True,model_path=None,save_path=None,earlyStop=True,epochs=100,lr=0.0001,thre=0.5):
    """Train ``model`` and evaluate it on ``val_set`` after every epoch.

    Args:
        model: network called as ``model(graph_batch, probe_values)``.
        train_set: ``(loader, probe, labels)`` where the loader yields
            ``(batch_idx, graph_batch)`` and ``probe``/``labels`` are lists of
            tensors indexed by the entries of ``batch_idx``.
        val_set: validation data in the layout expected by ``model_eval``.
        save_best_model: save the state dict of the epoch with the highest
            validation MCC as ``best_model.pth``.
        model_path: path prefix for ``best_model.pth``; ``None`` now means the
            current working directory instead of failing after training.
        save_path: passed to ``EarlyStopping``; when ``None`` early stopping
            is disabled regardless of ``earlyStop``.
        earlyStop: enable early stopping on the validation loss.
        epochs: maximum number of epochs.
        lr: base learning rate; the cyclic schedule runs from ``lr/10`` to ``lr*10``.
        thre: probability threshold used to decide which training nodes were
            misclassified.

    Returns:
        DataFrame with one row per completed epoch and the columns
        ``train_loss, val_loss, acc, precision, recall, f1, ap, bsl, mcc,
        auc_score, prc_score``.

    Tensors are moved to the module-level ``device``.
    """
    if save_path is not None: early_stopping = EarlyStopping(save_path)
    else: earlyStop = False
    metrics_name = ['train_loss','val_loss','acc','precision','recall','f1','ap','bsl','mcc','auc_score','prc_score']

    optimizer = torch.optim.Adam(model.parameters(),lr=lr,weight_decay=1e-3)
    scheduler = lr_scheduler.CyclicLR(optimizer, base_lr=lr/10, max_lr=lr*10,
                                      cycle_momentum=False,step_size_up=len(train_set[0]))

    prot_graphs,probe,labels = train_set
    # Per-node loss weights: 1 for negatives, 10 for positives.
    weights = [torch.ones_like(label) + 9 * (label == 1).to(label.dtype) for label in labels]
    # Added to the weight of every node that was misclassified in a step.
    weight_increase = 1

    best_mcc,best_model,best_metrics = None,None,None
    metrics_result = []
    for epoch in tqdm(range(epochs)):
        model.train()
        total_loss = 0

        for batch_idx,graph_values in prot_graphs:
            sample_ids = [int(i) for i in batch_idx]
            probe_values = torch.cat([probe[i] for i in sample_ids],dim=0)
            label_values = torch.cat([labels[i] for i in sample_ids],dim=0)
            weight_values = torch.cat([weights[i] for i in sample_ids],dim=0)
            graph_values, probe_values, label_values, weight_values = graph_values.to(device), probe_values.to(device), label_values.to(device), weight_values.to(device)

            optimizer.zero_grad()
            output = model(graph_values, probe_values)
            # One logit per node; reshape also copes with a single-node batch.
            logits = output.reshape(-1)
            loss = F.binary_cross_entropy_with_logits(logits, label_values, weight=weight_values)
            total_loss += loss.item()
            # NOTE(review): retain_graph=True is kept from the original; it is
            # only needed if the model reuses graph state across steps — confirm.
            loss.backward(retain_graph=True)
            optimizer.step()
            scheduler.step()

            # Raise the weight of every node this step got wrong, slicing the
            # flat batch back into its per-protein segments.
            predicted = (torch.sigmoid(logits.detach()) > thre).to(label_values.dtype)
            wrong = (predicted != label_values).cpu()
            start = 0
            for i in sample_ids:
                end = start + len(labels[i])
                weights[i][wrong[start:end]] += weight_increase
                start = end

        # NOTE(review): the sum of per-batch mean losses is divided by the
        # number of proteins, not batches; model_eval scales its loss the same way.
        train_loss = total_loss / len(labels)
        eval_matrics = model_eval(model,val_set)
        metrics_result.append(tuple((train_loss,))+eval_matrics)

        if earlyStop:
            early_stopping(eval_matrics[0], model)
            if early_stopping.early_stop:
                print("Early stopping")
                return pd.DataFrame(metrics_result,columns=metrics_name)
            # A stray `break` used to sit here and ended training after the
            # first epoch whenever early stopping was enabled.

        # eval_matrics[7] is the validation MCC.
        if best_mcc is None or eval_matrics[7] > best_mcc:
            best_mcc = eval_matrics[7]
            best_model = copy.deepcopy(model)
            best_metrics = tuple((epoch,train_loss))+eval_matrics

    if save_best_model and best_model is not None:
        torch.save(best_model.state_dict(),(model_path or '')+'best_model.pth')
        print('The metrics of best model are:')
        print(f'epoch: {best_metrics[0]}, train_loss: {best_metrics[1]}, val_loss: {best_metrics[2]}')
        print(f'acc: {best_metrics[3]}, precision: {best_metrics[4]}, recall: {best_metrics[5]}')
        print(f'f1: {best_metrics[6]}, ap: {best_metrics[7]}, bsl: {best_metrics[8]}')
        print(f'mcc: {best_metrics[9]}, AUC: {best_metrics[10]}, PRC: {best_metrics[11]}')

    return pd.DataFrame(metrics_result,columns=metrics_name)
        
def model_eval(model,val_set,thre=0.5):
    """Evaluate ``model`` on ``val_set`` and return the loss plus metrics.

    Args:
        model: network called as ``model(graph_batch, probe_values)``.
        val_set: ``(loader, probe, labels)`` where the loader yields
            ``(batch_idx, graph_batch)`` and ``probe``/``labels`` are lists of
            tensors indexed by the entries of ``batch_idx``.
        thre: probability threshold (inclusive) for predicting the positive class.

    Returns:
        ``(val_loss,) + metrics_calculation(...)``, i.e.
        ``(val_loss, acc, precision, recall, f1, ap, bsl, mcc, auc_score, prc_score)``.

    Tensors are moved to the module-level ``device``.

    The loss uses fixed per-node weights (1 for negatives, 10 for positives).
    The previous per-node weight update was removed: each sample is seen once
    per call, so the update never influenced the returned values, and its
    mask was broadcast to an N x N matrix.
    """
    model.eval()
    total_loss = 0
    y_true,y_pred,y_prob = [],[],[]
    prot_graphs,probe,labels = val_set

    weights = [torch.ones_like(label) + 9 * (label == 1).to(label.dtype) for label in labels]

    with torch.no_grad():
        for batch_idx,graph_values in prot_graphs:
            sample_ids = [int(i) for i in batch_idx]
            probe_values = torch.cat([probe[i] for i in sample_ids],dim=0)
            label_values = torch.cat([labels[i] for i in sample_ids],dim=0)
            weight_values = torch.cat([weights[i] for i in sample_ids],dim=0)
            graph_values, probe_values, label_values, weight_values = graph_values.to(device), probe_values.to(device), label_values.to(device), weight_values.to(device)

            output = model(graph_values, probe_values)
            # One logit per node; reshape keeps a 1-D tensor even for a
            # single-node batch, where squeeze() would produce a scalar.
            logits = output.reshape(-1)
            loss = F.binary_cross_entropy_with_logits(logits, label_values, weight=weight_values)
            total_loss += loss.item()
            prob = torch.sigmoid(logits)
            pred = (prob >= thre).int()
            y_true.extend(label_values.tolist())
            y_pred.extend(pred.tolist())
            y_prob.extend(prob.tolist())

        # NOTE(review): the sum of per-batch mean losses is divided by the
        # number of proteins, not batches; model_train scales its loss the same way.
        val_loss = total_loss/len(labels)
        eval_matrics = metrics_calculation(y_true,y_pred,y_prob)

        return tuple((val_loss,))+eval_matrics

def metrics_calculation(y_true,y_pred,y_prob):
    """Compute the binary-classification metrics reported during training.

    Args:
        y_true: ground-truth labels.
        y_pred: hard predictions.
        y_prob: predicted probabilities of the positive class.

    Returns:
        ``(acc, precision, recall, f1, ap, bsl, mcc, auc_score, prc_score)``.
        ``precision`` is the second entry of the per-class precision array and
        ``prc_score`` is the area under the precision-recall curve.
    """
    truth = np.asarray(y_true, dtype=np.float64)
    hard = np.asarray(y_pred, dtype=np.float64)
    soft = np.asarray(y_prob, dtype=np.float64)

    # Threshold-dependent metrics, computed from the hard predictions.
    acc = accuracy_score(truth,hard)
    per_class_precision = precision_score(truth,hard,average=None)
    precision = per_class_precision[1]
    recall = recall_score(truth,hard)
    f1 = f1_score(truth,hard)
    mcc = matthews_corrcoef(truth,hard)

    # Threshold-free metrics, computed from the probabilities.
    ap = average_precision_score(truth,soft,average=None)
    bsl = brier_score_loss(truth,soft)
    auc_score = roc_auc_score(truth,soft)
    # precision_recall_curve returns (precision, recall, thresholds).
    pr_precision,pr_recall,_ = precision_recall_curve(truth,soft)
    prc_score = auc(pr_recall,pr_precision)

    return (acc,precision,recall,f1,ap,bsl,mcc,auc_score,prc_score)

def metrics_plots(metrics,thre_col='mcc'):
    """Plot the per-epoch metrics and losses, marking the best epoch.

    Args:
        metrics: DataFrame indexed by epoch whose first two columns are the
            losses and whose remaining columns are evaluation scores.
        thre_col: column whose largest value defines the best epoch, drawn as
            a dashed vertical line in both figures.
    """
    ranked = metrics.sort_values(by=thre_col,ascending=False)
    best_epoch = ranked.index[0]
    n_columns = metrics.shape[1]

    def draw(column_positions,y_label):
        # One figure with a line per selected column and the best-epoch marker.
        plt.figure(figsize=(6,4),dpi=100)
        for pos in column_positions:
            plt.plot(metrics.index,metrics.iloc[:,pos],label=metrics.columns[pos])
        plt.axvline(x=best_epoch,color='black',linestyle='--')
        plt.legend(bbox_to_anchor=(1,1))
        plt.xlabel('Epoch',fontsize=12)
        plt.ylabel(y_label,fontsize=12)
        plt.show()

    # Evaluation scores per epoch (every column after the two loss columns).
    draw(range(2,n_columns),'Score')
    # Losses per epoch (the first two columns).
    draw(range(min(2,n_columns)),'Loss')
    
def model_test(model,test_set):
    """Run ``model`` on every test sample, one protein at a time.

    Args:
        model: network called as ``model(graph, probe_values)``.
        test_set: ``(graphs, probe, labels)`` as three parallel sequences.

    Returns:
        ``(output, labels)`` where ``output`` holds the raw model output of
        each sample as nested lists (no sigmoid is applied here) and
        ``labels`` holds each label tensor converted to a list.

    Inference now runs under ``torch.no_grad()`` so that no autograd graph is
    built for the test samples. Tensors are moved to the module-level ``device``.
    """
    model.eval()
    prot_graphs,probe,labels = test_set
    output = []
    with torch.no_grad():
        # labels stays in the zip so iteration still stops at the shortest sequence.
        for graph_values, probe_values, _ in zip(prot_graphs,probe,labels):
            graph_values, probe_values = graph_values.to(device), probe_values.to(device)
            test_output = model(graph_values,probe_values)
            output.append(test_output.tolist())
    return output,[i.tolist() for i in labels]

In [None]:
# 1. Load the merged PSM data of the eight probes
# `path` is the project root reused by the later cells; `file_path` holds the input tables.
path = 'D:/All_for_paper/1. PhD Work Program/3. Research project/1. Dizco/'
file_path = f'{path}Test files/'
merge_PSM = pd.read_csv(f'{file_path}merge_isoTOP_PSM_data.csv')
# UniProt export (tab-separated); other cells read its `Entry` column and take
# the protein sequence from its last column.
uniprot_infor = pd.read_csv(f'{file_path}uniprotkb_human_AND_reviewed_true_AND_m_2024_09_12.tsv',sep='\t')
# Keep only the first of the space-separated gene names.
uniprot_infor['Gene Names'] = uniprot_infor['Gene Names'].str.split(' ',expand=True)[0]

In [None]:
# 1.1 Build a graph for every protein identified by the eight probes
pdb_path = f'{path}AlphaFold_pdbFiles/'
processed_path = f'{path}/prot_raw_graph_coord/'
# Keep only the accessions present in the UniProt table. The set is built once
# so each lookup is O(1) instead of filtering the whole table per protein.
known_entries = set(uniprot_infor['Entry'].dropna())
prot_list = merge_PSM['Master Protein Accessions'].unique()
prot_list = [prot for prot in prot_list if prot in known_entries]
prot_pdb_list = prot2pdb(prot_list,pdb_path)

# Create the coordinate graphs in parallel, one PDB file per task.
pg = prot2graph(processed_path)
with Pool(12) as p:
    p.map(pg.CreateCoordGraph,tqdm(prot_pdb_list))

In [None]:
# 1.2 Summarise the structure information of every protein identified by the eight probes
with Pool(8) as p:
    structure_data = p.map(data_integration,prot_pdb_list)
# Each worker returns a dict; merge them into a single dictionary.
structure_data_dic = {}
for data in structure_data: structure_data_dic.update(data)

# Write through a context manager so the file is flushed and closed
# deterministically instead of relying on garbage collection.
with open(f'{path}pkl_files/structure_infor_summary.pkl','wb') as pkl_file:
    pickle.dump(structure_data_dic,pkl_file)

In [None]:
# 1.3 Generate one embedding per probe
# Probe name -> SMILES string of the probe.
probe_smile = {
    
    'AJ5': 'C#CCCC1(N=N1)CCC(NC)=O',
    'AJ8': 'C#CCCC1(N=N1)CCC(NC[C@H]2[C@@H](C)CCCN2CC3=CC=C(OC)C=C3)=O',
    'AJ12': 'C#CCCC1(N=N1)CCNC(/C(C2=CC=CC=C2)=C/C3=CC=CC=C3)=O',
    'AJ14': 'C#CCCC1(N=N1)CCC(NC(C2(C[C@H](C3)C4)C[C@H]4C[C@H]3C2)C)=O',
    'AJ22': 'C#CCCC1(N=N1)CCNC(/C(CC)=C/C2=CC=CC([N+]([O-])=O)=C2)=O',
    'AJ32': 'C#CCCC1(N=N1)CCNC(C2(CC2)C3=CC(OC(F)(F)O4)=C4C=C3)=O',
    'AJ39': 'C#CCCC1(N=N1)CCNC(CCC2=NC(C3=CC=CC=C3)=C(C4=CC=CC=C4)O2)=O',
    'CP78': 'C#CCCC1(N=N1)CCC(N2C(CC3=CC=CC=C3)CCCC2)=O'
    
    }

# Raise the RDKit logger level to CRITICAL to silence lower-severity messages.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
fe = FragmentEmbedder()
# One embedding per probe, computed from a single-element SMILES list.
probe_emd = {probe:fe.transform([smile]) for probe,smile in probe_smile.items()}

In [None]:
# 1.4 Redefine the graph for each probe according to the peptides and binding
# sites that probe found, then assemble the data set.
# Some proteins are missing from AlphaFold, possibly because UniProt changed their IDs.
with open(f'{path}pkl_files/structure_infor_summary.pkl','rb') as pkl_file:
    str_infor_dic = pickle.load(pkl_file)

dataSet = []
for probe,table in tqdm(merge_PSM.groupby(by='probe')):
    if probe == 'AJ5': continue
    prot_label = ResLabel(table,thre=0.85)
    for prot,labels in prot_label.items():
        # `except Exception` keeps the best-effort skip but no longer swallows
        # KeyboardInterrupt/SystemExit the way a bare `except:` did.
        try: prot_graph = torch.load(f'{processed_path}{prot}.pt')
        except Exception:
            print(f'The graph of {prot} could not be found, please check whether there is a PDB file of it')
            continue
        # (0-based residue index, label) for every residue that was observed.
        node_update = np.array([tuple((int(res[1:])-1,label)) for res,label in labels.items() if not np.isnan(label)])
        
        try:
            # Explicit integer dtype: the indices are used to index tensors.
            node_index = torch.tensor(node_update[:,0],dtype=torch.long)
            new_edge,_ = subgraph(subset=node_index,edge_index=prot_graph.edge_index,relabel_nodes=True)
        except Exception:
            print(f'{prot}')
            print(f'No. res in PDB: {prot_graph.x.shape[0]}')
            print(f'No. res in Uniprot: {len(labels)}')
            continue
            
        new_X = prot_graph.x[node_index]
        # Residue-level target: one label per retained node.
        new_Y = node_update[:,1]
        new_graph = Data(x=new_X,edge_index=new_edge)
        
        # Repeat the probe embedding once per retained node.
        probeEmbedding = probe_emd[probe]
        probeEmbedding = probeEmbedding.repeat(len(node_update),axis=0)
        
        dataSet.append(tuple((new_graph,probeEmbedding,new_Y)))

In [None]:
# 2. Split the data into training, validation and test sets
train_set,val_set,test_set = DataSetSplit(dataSet)

In [None]:
# 3. Model training and evaluation
# Feature widths are read from the first test sample: columns of the node
# feature matrix and columns of the probe embedding.
graph_dim,probe_dim = test_set[0][0].x.shape[1],test_set[1][0].shape[1]
# `device` is read as a module-level name by model_train, model_eval and model_test.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model_path = 'E:/Proteomics/PhD_script/1. Dizco/'

dizco_gnn = dizco_GNN(graph_dim,probe_dim).to(device)
# save_path is not passed, so model_train disables early stopping for this run.
metrics_result = model_train(dizco_gnn,train_set,val_set,model_path=model_path,lr=0.0001)
metrics_plots(metrics_result,thre_col='mcc')

# NOTE(review): the file name mentions lr(1e-5) and batch(16), whereas this cell
# passes lr=0.0001 and DataCombine builds its loader with batch_size=8 — confirm the name.
metrics_result.to_csv('D:/All_for_paper/1. PhD Work Program/3. Research project/1. Dizco/Model_Metrics_Result/Adam_lr(1e-5)_batch(16)(GNN).csv',index=False)