##### 此部分代码仅使用简单的单层transformer encoder对蛋白embedding进行编码，不涉及额外信息

In [None]:
import sys
sys.path.append('E:/Proteomics/PhD_script/1. Dizco/')
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from DataTransform import CustomDataset
from torch.utils.data import DataLoader
from EarlyStopping import EarlyStopping
from models import FFFTrans
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score,precision_score,recall_score,f1_score,accuracy_score,matthews_corrcoef,average_precision_score,brier_score_loss,auc,precision_recall_curve
from statannotations.Annotator import Annotator
import copy
import warnings
warnings.filterwarnings('ignore')

In [None]:
def ResLabel(table,thre=0.85):
    """Build per-residue binary labels for every protein seen in one probe's PSM table.

    Rows with ``lg_probs > thre`` and a non-null ``label_site`` are parsed as labelled
    peptides; rows with ``lg_probs <= thre`` contribute residues labelled 0. Both sets
    are mapped to protein coordinates by ``ParseRes``. When the same residue appears
    with label 1 and label 0, the 1 wins (rows are sorted by label, descending, before
    de-duplication keeps the first occurrence).

    Parameters
    ----------
    table : pandas.DataFrame
        PSM rows for a single probe. This function reads ``lg_probs`` and
        ``label_site``; ``ParseRes`` reads further columns.
    thre : float
        Threshold on ``lg_probs`` separating labelled from unlabelled PSMs.

    Returns
    -------
    dict
        ``{accession: {f'{residue}{position}': label}}``. Each inner dict has one
        entry per residue of the protein's sequence (last column of the module-level
        ``uniprot_infor`` table), with 1-based positions and default label 0.
    """
    label_table = table[(table['lg_probs'] > thre) & table['label_site'].notna()]
    unlabel_table = table[table['lg_probs'] <= thre]

    labeledPEP = ParseRes(label_table,state='labeled')
    unlabeledPEP = ParseRes(unlabel_table,state='unlabeled')
    res_table = pd.concat([labeledPEP,unlabeledPEP],axis=0)
    # Sort 1s first so drop_duplicates keeps the positive label for shared residues.
    res_table = res_table.sort_values(by='label',ascending=False).reset_index(drop=True)
    res_table = res_table.drop_duplicates(subset=['Accession','res_site']).reset_index(drop=True)

    prot_label = {}
    for prot in res_table['Accession'].unique():
        seq = uniprot_infor[uniprot_infor['Entry']==prot].iloc[0,-1]
        prot_label.setdefault(prot, {f'{p}{i+1}':0 for i,p in enumerate(seq)})

    for prot, prot_table in res_table.groupby(by='Accession'):
        site_labels = prot_label[prot]
        for site, label in zip(prot_table['res_site'].to_list(), prot_table['label'].to_list()):
            # Only update residues that exist in the reference sequence. An unknown key
            # (e.g. a position computed from a peptide that was not found) would be
            # appended to the dict and corrupt the sequence later rebuilt from its keys.
            if site in site_labels:
                site_labels[site] = label

    return prot_label
    
def ParseRes(data,state='labeled',uniprot=None):
    """Map identified peptides onto their parent proteins and emit per-residue labels.

    For every (protein, peptide) group, the peptide is located at its first occurrence
    in the protein sequence, and each peptide residue is emitted with its 1-based
    protein position.

    Parameters
    ----------
    data : pandas.DataFrame
        PSM rows with columns ``Master Protein Accessions``, ``Upper_Seq`` and
        ``label_site``. ``label_site`` values are compared against
        ``f'{residue}{position_in_peptide}'`` (1-based within the peptide).
    state : {'labeled', 'unlabeled'}
        ``'labeled'``: a residue gets 1 if its peptide-relative site appears in the
        group's ``label_site`` values, else 0. ``'unlabeled'``: every residue gets 0.
        Any other value produces no rows.
    uniprot : pandas.DataFrame, optional
        Reference table with an ``Entry`` column and the sequence in its last column.
        Defaults to the module-level ``uniprot_infor``.

    Returns
    -------
    pandas.DataFrame
        Columns ``Accession``, ``res_site`` (e.g. ``'C123'``) and ``label``, with one row
        per (Accession, res_site). Label-1 rows take precedence over label-0 rows.
        Proteins absent from ``uniprot`` and peptides not found in the sequence are skipped.
    """
    if uniprot is None:
        uniprot = uniprot_infor
    records = []
    for (prot,pep),table in data.groupby(by=['Master Protein Accessions','Upper_Seq']):
        match = uniprot[uniprot['Entry']==prot]
        if match.empty:
            continue  # protein not in the reference table
        seq = match.iloc[0,-1]

        offset = seq.find(pep)
        if offset < 0:
            # Peptide not found in the reference sequence: emitting positions from
            # start=0 would produce wrong / non-existent residue sites.
            continue
        start = offset + 1
        label_site = set(table['label_site'].unique())
        for i,p in enumerate(pep):
            if state == 'labeled':
                label = 1 if f'{p}{i+1}' in label_site else 0
            elif state == 'unlabeled':
                label = 0
            else:
                continue
            records.append((prot,f'{p}{start+i}',label))

    # Build with explicit columns so empty input still yields a sortable frame.
    pep_label_result = pd.DataFrame(records,columns=['Accession','res_site','label']).astype({'label':'int64'})
    # Sort 1s first so drop_duplicates keeps the positive label for overlapping peptides.
    pep_label_result = pep_label_result.sort_values(by='label',ascending=False).reset_index(drop=True)

    return pep_label_result.drop_duplicates(subset=['Accession','res_site'])

def collate_fn(batch):
    """Collate a list of ``(probe_emd, prot_emd, labels, att_mask)`` samples into a batch.

    Each of the four fields is stacked along a new leading dimension with
    ``torch.stack``, so every sample must provide tensors of identical shape per field.

    Returns
    -------
    tuple of 4 torch.Tensor
        ``(probe_emd, prot_emd, labels, att_mask)``, each with a leading batch dimension.
    """
    probes, prots, targets, masks = zip(*batch)
    return tuple(torch.stack(field) for field in (probes, prots, targets, masks))

def model_train(model,train_set,val_set,save_best_model=True,model_path=None,save_path=None,earlyStop=True,epochs=100,lr=0.0001):
    """Train ``model`` with BCE loss, evaluating on ``val_set`` after every epoch.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(probe_emd, prot_emd, att_mask)``. Its output at positions where
        ``att_mask`` is 0 is fed to ``nn.BCELoss``, so it must be a probability.
    train_set, val_set : iterable of batches
        Yield ``(probe_emd, prot_emd, labels, att_mask)`` (see ``collate_fn``).
        ``len(train_set)`` must be the number of batches per epoch.
    save_best_model : bool
        If True, save the state dict of the epoch with the highest validation MCC to
        ``model_path + 'best_model.pth'`` and print that epoch's metrics.
    model_path : str or None
        Path prefix for the best-model file. Required when ``save_best_model`` is True.
    save_path : str or None
        Passed to ``EarlyStopping``. When None, early stopping is disabled.
    earlyStop : bool
        Stop when ``EarlyStopping`` (fed the validation loss) sets ``early_stop``.
        Only effective when ``save_path`` is given.
    epochs : int
        Maximum number of training epochs.
    lr : float
        Adam learning rate; CyclicLR oscillates between ``lr/10`` and ``lr*10``.

    Returns
    -------
    pandas.DataFrame
        One row per completed epoch with columns ``metrics_name`` below.

    Raises
    ------
    ValueError
        If ``save_best_model`` is True but ``model_path`` is None (checked before training).
    """
    if save_best_model and model_path is None:
        raise ValueError('model_path is required when save_best_model=True')
    if save_path is not None:
        early_stopping = EarlyStopping(save_path)
    else:
        # EarlyStopping is constructed from save_path; without it, early stopping is off.
        earlyStop = False
    metrics_name = ['train_loss','val_loss','acc','precision','recall','f1','ap','bsl','mcc','auc_score','prc_score']
    # Loss function and optimizer.
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=lr,weight_decay=1e-4)
    # The scheduler steps once per batch, so step_size_up=len(train_set) makes the
    # rising half of each LR cycle last one epoch.
    scheduler = lr_scheduler.CyclicLR(optimizer, base_lr=lr/10, max_lr=lr*10,
                                      cycle_momentum=False,step_size_up=len(train_set))

    metrics_result = []
    best_mcc, best_state, best_metrics = None, None, None
    for epoch in tqdm(range(epochs)):
        model.train()
        total_loss = 0
        n_batches = 0

        for probe_emd,prot_emd,labels,att_mask in train_set:
            probe_emd,prot_emd,labels,att_mask = probe_emd.to(device), prot_emd.to(device), labels.to(device), att_mask.to(device)
            optimizer.zero_grad()
            output = model(probe_emd,prot_emd,att_mask)
            # Score only positions where att_mask == 0 (presumably non-padding residues).
            mask = ~att_mask.bool()
            # reshape(-1) instead of squeeze(): squeeze() collapses a single scored residue
            # to a 0-d tensor, whose shape then no longer matches the 1-d target.
            output_masked = output[mask].reshape(-1)
            labels_masked = labels[mask].reshape(-1)
            loss = criterion(output_masked, labels_masked.float())
            total_loss += loss.item()
            n_batches += 1
            # NOTE(review): retain_graph=True is kept from the original; it is only needed
            # if FFFTrans reuses autograd state across iterations — confirm and drop if not.
            loss.backward(retain_graph=True)
            optimizer.step()
            scheduler.step()

        # Mean of per-batch mean losses, the same scale model_eval uses for val_loss.
        train_loss = total_loss / n_batches
        eval_matrics = model_eval(model,val_set,criterion)
        metrics_result.append((train_loss,) + eval_matrics)

        # eval_matrics = (val_loss, acc, precision, recall, f1, ap, bsl, mcc, auc, prc);
        # index 7 is MCC, the model-selection criterion.
        if best_mcc is None or eval_matrics[7] > best_mcc:
            best_mcc = eval_matrics[7]
            best_state = copy.deepcopy(model.state_dict())
            best_metrics = (epoch, train_loss) + eval_matrics

        if earlyStop:
            early_stopping(eval_matrics[0], model)
            if early_stopping.early_stop:
                print("Early stopping")
                # Break (not return) so the best model found so far is still saved below.
                break

    if save_best_model and best_state is not None:
        torch.save(best_state,model_path+'best_model.pth')
        print('The metrics of best model are:\n')
        print(f'epoch: {best_metrics[0]}, train_loss: {best_metrics[1]}, val_loss: {best_metrics[2]}\n')
        print(f'acc: {best_metrics[3]}, precision: {best_metrics[4]}, recall: {best_metrics[5]}\n')
        print(f'f1: {best_metrics[6]}, ap: {best_metrics[7]}, bsl: {best_metrics[8]}\n')
        print(f'mcc: {best_metrics[9]}, AUC: {best_metrics[10]}, PRC: {best_metrics[11]}')

    return pd.DataFrame(metrics_result,columns=metrics_name)


def model_eval(model,val_set,criterion,thre=0.5):
    """Evaluate ``model`` on ``val_set``; return the loss followed by classification metrics.

    Only positions where ``att_mask`` is 0 are scored. A residue is predicted positive
    when its output probability is ``>= thre``.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(probe_emd, prot_emd, att_mask)``.
    val_set : iterable of batches
        Yields ``(probe_emd, prot_emd, labels, att_mask)``. ``len(val_set)`` must be
        the number of batches.
    criterion : callable
        Loss applied to ``(probabilities, float labels)``.
    thre : float
        Decision threshold for hard predictions.

    Returns
    -------
    tuple
        ``(val_loss, acc, precision, recall, f1, ap, bsl, mcc, auc_score, prc_score)``,
        where ``val_loss`` is the mean of the per-batch losses.
    """
    model.eval()
    total_loss = 0
    y_true,y_pred,y_prob = [],[],[]

    with torch.no_grad():
        for probe_emd,prot_emd,labels,att_mask in val_set:
            probe_emd,prot_emd,labels,att_mask = probe_emd.to(device), prot_emd.to(device), labels.to(device), att_mask.to(device)
            output = model(probe_emd,prot_emd,att_mask)
            mask = ~att_mask.bool()
            # reshape(-1) keeps a 1-d tensor even when only one residue is scored; with
            # squeeze() that case becomes 0-d, .tolist() returns a bare float and
            # list.extend() raises TypeError.
            prob = output[mask].reshape(-1)
            labels_masked = labels[mask].reshape(-1)

            loss = criterion(prob, labels_masked.float())
            total_loss += loss.item()
            pred = (prob >= thre).int()
            y_true.extend(labels_masked.tolist())
            y_pred.extend(pred.tolist())
            y_prob.extend(prob.tolist())

    val_loss = total_loss/len(val_set)
    eval_matrics = metrics_calculation(y_true,y_pred,y_prob)

    return (val_loss,) + eval_matrics
    
def metrics_calculation(y_true,y_pred,y_prob):
    """Compute binary-classification metrics for per-residue predictions.

    Parameters
    ----------
    y_true : sequence of {0, 1}
        Ground-truth labels.
    y_pred : sequence of {0, 1}
        Thresholded predictions.
    y_prob : sequence of float
        Predicted probability of the positive class.

    Returns
    -------
    tuple
        ``(acc, precision, recall, f1, ap, bsl, mcc, auc_score, prc_score)``:
        precision/recall/F1 for the positive class, ``ap`` = average precision,
        ``bsl`` = Brier score, ``auc_score`` = ROC AUC, and ``prc_score`` = trapezoidal
        area under the precision-recall curve.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    y_prob = np.asarray(y_prob, dtype=np.float64)

    auc_score = roc_auc_score(y_true,y_prob)
    acc = accuracy_score(y_true,y_pred)
    # Binary precision for the positive class, consistent with recall_score/f1_score
    # below. This replaces indexing the per-class array (average=None)[1], which assumed
    # both classes were present.
    precision = precision_score(y_true,y_pred)
    recall = recall_score(y_true,y_pred)
    f1 = f1_score(y_true,y_pred)
    mcc = matthews_corrcoef(y_true,y_pred)
    ap = average_precision_score(y_true,y_prob,average=None)
    bsl = brier_score_loss(y_true,y_prob)
    # precision_recall_curve returns (precision, recall, thresholds); the PR-AUC integrates
    # precision over recall.
    pr_precision, pr_recall, _ = precision_recall_curve(y_true,y_prob)
    prc_score = auc(pr_recall, pr_precision)

    return (acc,precision,recall,f1,ap,bsl,mcc,auc_score,prc_score)

def metrics_plots(metrics,thre_col='mcc'):
    """Plot per-epoch training curves from the frame returned by ``model_train``.

    Draws two figures, each with a dashed vertical line at the epoch (index label) that
    has the highest ``thre_col`` value:

    * evaluation metrics: every column from position 2 onward, y-axis 'Score';
    * losses: the first two columns, y-axis 'Loss'.
    """
    best_epoch = metrics.sort_values(by=thre_col,ascending=False).index[0]

    def _draw(positions, ylabel):
        # One figure with one line per selected column position.
        plt.figure(figsize=(6,4),dpi=100)
        for pos in positions:
            plt.plot(metrics.index,metrics.iloc[:,pos],label=metrics.columns[pos])
        plt.axvline(x=best_epoch,color='black',linestyle='--')
        plt.legend(bbox_to_anchor=(1,1))
        plt.xlabel('Epoch',fontsize=12)
        plt.ylabel(ylabel,fontsize=12)
        plt.show()

    n_cols = metrics.shape[1]
    # Evaluation metrics per epoch.
    _draw(range(2, n_cols), 'Score')
    # Losses per epoch.
    _draw(range(min(2, n_cols)), 'Loss')
    
def model_test(model,test_set,batch_size=8):
    """Run ``model`` over ``test_set``; collect per-residue outputs and labels.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(probe_emd, prot_emd, att_mask)``.
    test_set : DataLoader or Dataset
        A DataLoader yielding ``(probe_emd, prot_emd, labels, att_mask)`` batches, or a
        dataset of such samples. A dataset is wrapped in an unshuffled DataLoader using
        ``collate_fn``, without dropping any samples.
    batch_size : int
        Batch size used only when ``test_set`` has to be wrapped.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Flattened float64 model outputs and labels at positions where ``att_mask`` is 0.
    """
    model.eval()
    if isinstance(test_set, DataLoader):
        loader = test_set
    else:
        loader = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                            drop_last=False, collate_fn=collate_fn)

    out_chunks, label_chunks = [], []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for probe_emd,prot_emd,labels,att_mask in loader:
            probe_emd,prot_emd,labels,att_mask = probe_emd.to(device), prot_emd.to(device), labels.to(device), att_mask.to(device)
            output = model(probe_emd,prot_emd,att_mask)
            mask = ~att_mask.bool()
            out_chunks.append(output[mask].reshape(-1).cpu().double().numpy())
            label_chunks.append(labels[mask].reshape(-1).cpu().double().numpy())

    if not out_chunks:
        return np.array([]), np.array([])
    # Concatenate once instead of np.append per batch (which copies every time).
    return np.concatenate(out_chunks), np.concatenate(label_chunks)


path = 'D:/All_for_paper/1. PhD Work Program/3. Research project/1. Dizco/'
merge_PSM = pd.read_csv(f'{path}Test files/merge_isoTOP_PSM_data.csv')
uniprot_infor = pd.read_csv(f'{path}Test files/uniprotkb_human_AND_reviewed_true_AND_m_2024_09_12.tsv',sep='\t')
# Keep only the first space-separated gene name of each entry.
uniprot_infor['Gene Names'] = uniprot_infor['Gene Names'].str.split(' ',expand=True)[0]

# 1. Integrate data: one (probe, protein sequence, per-residue labels) record for each
#    protein observed with each probe. Probe 'AJ5' is excluded.
dataset, prot_labels = [], []
for probe_name, probe_psms in tqdm(merge_PSM.groupby(by='probe')):
    if probe_name == 'AJ5':
        continue
    for accession, site_map in ResLabel(probe_psms, thre=0.85).items():
        # site_map keys are f'{residue}{position}', so the first char is the residue letter.
        residues = ''.join(key[0] for key in site_map)
        residue_labels = list(site_map.values())
        dataset.append((probe_name, residues, residue_labels))
        # Protein-level label used for stratified splitting: 1 if any residue is labelled.
        prot_labels.append(int(1 in residue_labels))

In [None]:
#Distribution of identified protein length
# Distribution of identified protein lengths.
# NOTE(review): column 7 of uniprot_infor is plotted as 'Protein length' — confirm
# against the TSV header.
prot_list = merge_PSM['Master Protein Accessions'].unique()
# Build the accession -> column-7 lookup once, keeping the first row per Entry as
# .iloc[0, 7] did, instead of filtering the whole UniProt table twice per protein.
first_rows = uniprot_infor.drop_duplicates(subset='Entry')
entry_to_len = dict(zip(first_rows['Entry'], first_rows.iloc[:, 7]))
prot_len = [entry_to_len[prot] for prot in prot_list if prot in entry_to_len]
plt.figure(figsize=(1,3))
sns.boxplot(prot_len,width=0.5,showfliers=False)
plt.xlabel('Identified\nproteins',fontsize=12)
plt.ylabel('Protein length',fontsize=12)
plt.show()

In [None]:
#2. 划分数据集
# 2. Split the dataset: 10% test, then 1/9 of the remainder as validation (about 80/10/10),
#    stratified by whether a protein has any labelled residue.
tv_index, test_index = train_test_split(list(range(len(dataset))), test_size=0.1, random_state=42,
                                        shuffle=True, stratify=prot_labels)
tv_set, test_set = [dataset[i] for i in tv_index], [dataset[i] for i in test_index]
train_set, val_set = train_test_split(tv_set, test_size=1/9, random_state=42,
                                      shuffle=True, stratify=np.array(prot_labels)[tv_index])
train_set = CustomDataset(train_set)
val_set = CustomDataset(val_set)
test_set = CustomDataset(test_set)
# drop_last only for training: the evaluation loaders must see every sample, otherwise
# validation/test metrics silently ignore the final partial batch.
train_set = DataLoader(train_set, shuffle=True, batch_size=8, drop_last=True, collate_fn=collate_fn)
val_set = DataLoader(val_set, shuffle=False, batch_size=8, drop_last=False, collate_fn=collate_fn)
test_set = DataLoader(test_set, shuffle=False, batch_size=8, drop_last=False, collate_fn=collate_fn)

In [None]:
#3. 模型训练与评估
# 3. Model training and evaluation
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
model_path = 'E:/Proteomics/PhD_script/1. Dizco/'

fff = FFFTrans().to(device)
# No save_path is passed, so model_train disables early stopping and keeps the
# epoch with the best validation MCC as model_path + 'best_model.pth'.
metrics_result = model_train(fff, train_set, val_set, model_path=model_path, lr=1e-4)
metrics_plots(metrics_result, thre_col='mcc')

In [None]:
#4. 测试集评估
# 4. Test-set evaluation
# map_location places the checkpoint's tensors on the current device, so a checkpoint
# written from a GPU run can still be loaded when only a CPU is available.
fff.load_state_dict(torch.load(f'{model_path}best_model.pth', map_location=device))

test_output,test_labels = model_test(fff,test_set)
test_data = pd.DataFrame({'prob': test_output, 'label': test_labels})
test_data['model'] = 'Transformer'

# Boxplot of predicted probabilities for label-1 vs label-0 residues, annotated with an
# independent t-test between the two groups.
plt.figure(figsize=(2,4),dpi=100)
plt.rcParams['font.sans-serif'] = 'Arial'
ax=sns.boxplot(data=test_data,x='model',y='prob',hue='label',
            width=0.5,showfliers=False)
box_pairs = [(('Transformer',1),('Transformer',0))]
annot = Annotator(ax, data=test_data,x='model',y='prob',hue='label',pairs=box_pairs)
annot.configure(test='t-test_ind', text_format='star',line_height=0.03,line_width=1)
annot.apply_and_annotate()
plt.show()