#### 以CNN为核心构建的氨基酸层面的模型进行训练

In [None]:
import sys
sys.path.append('E:/Proteomics/PhD_script/1. Dizco/')
sys.path.append('D:/All_for_paper/1. PhD Work Program/3. Research project/1. Dizco/Ligand Discovery/fragment-embedding/')
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset,DataLoader
from os import listdir
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score,precision_score,recall_score,f1_score,accuracy_score,matthews_corrcoef,average_precision_score,brier_score_loss,auc,precision_recall_curve
from sklearn.ensemble import RandomForestClassifier
from models import dizco_CNN
from prefetch_generator import BackgroundGenerator
from torch.optim import lr_scheduler
from EarlyStopping import EarlyStopping
from tqdm import tqdm
import copy
import warnings
warnings.filterwarnings('ignore')


In [None]:
#Integrate and split the data in the same way as DrugMap
def data_integation(probe,position_set,structure_set):
    """Build the train/validation/test/prediction sets for one probe.

    Rows whose DataClass is 'predictive_labeling' form the prediction set. The
    remaining rows ('labeled_site' -> 1, 'unlabeled_site' -> 0) are split into a
    stratified 10% test set, and the rest into a stratified 80/20 train/val split.

    Parameters
    ----------
    probe : str
        Probe name; first component of every ``position_set`` key.
    position_set : dict
        Maps ``'{probe}_{col0}_{col1}_{int(col2)}'`` to a distance matrix
        (e.g. the dict returned by ``data_extract_pos``).
    structure_set : pandas.DataFrame
        Must contain 'probe', 'Secondary structure' and 'DataClass' columns.
        After dropping 'probe' and one-hot encoding 'Secondary structure',
        columns 0-2 build the position key and columns 4 onward are the
        structural features. NOTE(review): this relies on a fixed column order
        (presumably ProtName, Residue, Res_id, DataClass, features...) — confirm.

    Returns
    -------
    tuple
        ``(train_set, val_set, test_set, pre_set, str_x_train, str_y_train,
        str_x_test, str_y_test)`` where train_set/val_set are DataLoaders of
        ``(str_x, pos_x, y)`` batches of size 5, test_set is
        ``((str_x, pos_x), y)`` tensors, pre_set is ``(str_x, pos_x)`` tensors
        and the remaining items are float32 numpy arrays.
    """
    structure_set = structure_set.drop(columns='probe')
    structure_set = pd.get_dummies(structure_set, columns=['Secondary structure'], dtype=int)

    # 'predictive_labeling' rows have no ground truth and form the prediction set.
    is_pre = structure_set['DataClass'] == 'predictive_labeling'
    str_set = structure_set[~is_pre].reset_index(drop=True)
    str_pre_set = structure_set[is_pre].reset_index(drop=True)
    str_set['DataClass'] = str_set['DataClass'].replace({'labeled_site': 1, 'unlabeled_site': 0})

    # Stratified 10% hold-out test set.
    tv_index, test_index = train_test_split(list(str_set.index),
                                            test_size=0.1, random_state=42,
                                            shuffle=True, stratify=str_set['DataClass'])
    str_tv_set = str_set.loc[tv_index].reset_index(drop=True)
    str_test_set = str_set.loc[test_index].reset_index(drop=True)

    # Stratified 80/20 train/validation split; row indices are kept so the
    # matching distance matrices can be looked up afterwards.
    str_x_train, str_x_val, str_y_train, str_y_val, train_idex, val_idex = train_test_split(
        str_tv_set.iloc[:,4:].to_numpy(),
        str_tv_set['DataClass'].to_numpy(),
        list(str_tv_set.index),
        test_size=0.2, random_state=42,
        stratify=str_tv_set['DataClass'].to_numpy(),
        shuffle=True)

    def _position_key(row):
        # '{probe}_{col0}_{col1}_{int(col2)}', the key format of position_set.
        return '_'.join([probe, *row[:2], str(int(row[2]))])

    def _positions(frame, ref_shape=None):
        # Stack the frame's distance matrices (each transposed) into a float32
        # array; an empty frame yields an empty array of shape (0, *ref_shape).
        mats = [position_set[_position_key(row)] for row in frame.values]
        if not mats and ref_shape is not None:
            return np.empty((0, *ref_shape), dtype=np.float32)
        return np.asarray(mats).astype(np.float32).transpose(0,2,1)

    pos_x_train = _positions(str_tv_set.loc[train_idex])
    pos_shape = pos_x_train.shape[1:]
    pos_x_val = _positions(str_tv_set.loc[val_idex], pos_shape)
    pos_x_test = _positions(str_test_set, pos_shape)
    pos_x_pre = _positions(str_pre_set, pos_shape)

    # Fit the scaler on the training split only and apply that same transform
    # to every other split. A fresh scaler per split gave each split its own
    # feature scale (the same raw value mapped to different model inputs) and
    # used held-out statistics during preprocessing.
    scaler = StandardScaler().fit(str_x_train)
    n_features = str_x_train.shape[1]

    def _scale(features):
        if len(features) == 0:
            return np.empty((0, n_features), dtype=np.float32)
        return np.asarray(scaler.transform(features)).astype(np.float32)

    str_x_train = _scale(str_x_train)
    str_x_val = _scale(str_x_val)
    str_x_test = _scale(str_test_set.iloc[:,4:].to_numpy())
    str_x_pre = _scale(str_pre_set.iloc[:,4:].to_numpy())

    str_y_train = np.asarray(str_y_train).astype(np.float32)
    str_y_val = np.asarray(str_y_val).astype(np.float32)
    str_y_test = np.asarray(str_test_set['DataClass'].values).astype(np.float32)

    train_set = DataLoader(TensorDataset(torch.Tensor(str_x_train),torch.Tensor(pos_x_train),torch.Tensor(str_y_train)),
                           batch_size=5, drop_last=False, shuffle=True)
    # Validation order does not need randomising; a fixed order keeps the
    # batch-averaged validation loss reproducible between epochs.
    val_set = DataLoader(TensorDataset(torch.Tensor(str_x_val),torch.Tensor(pos_x_val),torch.Tensor(str_y_val)),
                         batch_size=5, drop_last=False, shuffle=False)

    test_set = ((torch.Tensor(str_x_test),torch.Tensor(pos_x_test)),torch.Tensor(str_y_test))
    pre_set = (torch.Tensor(str_x_pre),torch.Tensor(pos_x_pre))

    return train_set,val_set,test_set,pre_set,str_x_train,str_y_train,str_x_test,str_y_test


In [None]:
#Read the residue-level distance matrices of every protein for one probe
def data_extract_pos(position_path,probe):
    """Read every distance-matrix CSV under ``position_path/probe/<class>/``.

    Parameters
    ----------
    position_path : str
        Root directory, with or without a trailing separator.
    probe : str
        Sub-directory of ``position_path`` holding one folder per site class.

    Returns
    -------
    dict
        Maps each file name up to its first '.' to the DataFrame read from it.
        When the same name appears in several class folders, the file from the
        alphabetically first folder is kept (e.g. 'labeled_site' before
        'unlabeled_site', presumably the folder names — confirm).
    """
    import os

    probe_dir = os.path.join(position_path, probe)
    dist_dict = {}
    # Sorted so the "first one wins" rule does not depend on the filesystem's
    # listing order.
    for sit_class in sorted(listdir(probe_dir)):
        class_dir = os.path.join(probe_dir, sit_class)
        for file in sorted(listdir(class_dir)):
            key = file.split('.')[0]
            # Check before reading: setdefault() would parse the CSV only to discard it.
            if key in dist_dict:
                continue
            dist_dict[key] = pd.read_csv(os.path.join(class_dir, file))
    return dist_dict

#Within one probe some residues of a protein appear in both the labeled and the unlabeled files:
#keep 'labeled_site' with priority, then de-duplicate.
def data_relabel(structure_set):
    """Resolve conflicting DataClass labels for residues listed more than once.

    Rows are grouped by ('ProtName', 'Residue', 'Res_id'). In each group with two
    or more rows, every row is relabelled 'labeled_site' if any row carries that
    label, otherwise 'unlabeled_site' if any row carries that label; all other
    groups are left unchanged. Exact duplicate rows are then dropped and the
    result is sorted by DataClass, ProtName, Res_id with a fresh RangeIndex.

    Rows with a missing value in a grouping column are dropped, because pandas
    ``groupby`` excludes NaN keys. An input with no groups returns an empty
    frame with the input's columns.
    """
    frames = []
    for _, table in structure_set.groupby(by=['ProtName','Residue','Res_id']):
        if len(table) >= 2:
            classes = set(table['DataClass'])
            if 'labeled_site' in classes:
                table = table.assign(DataClass='labeled_site')
            elif 'unlabeled_site' in classes:
                table = table.assign(DataClass='unlabeled_site')
        frames.append(table)
    if not frames:
        return structure_set.iloc[0:0].reset_index(drop=True)
    # One concat instead of one per group: concatenating inside the loop copies
    # the growing frame on every iteration (quadratic in the number of residues).
    structure_set_ = pd.concat(frames)
    structure_set_ = structure_set_.sort_values(by=['DataClass','ProtName','Res_id'],ascending=[True,True,True])
    return structure_set_.drop_duplicates().reset_index(drop=True)

In [None]:
#Model training
def model_train(model,train_set,val_set,save_best_model=True,model_path=None,save_path=None,earlyStop=True,epochs=100,lr=0.001):
    """Train ``model`` with BCE loss and keep the epoch with the best validation MCC.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(str_values, pos_values)``; its output goes straight into
        ``nn.BCELoss``, so it must already be a probability in [0, 1].
    train_set, val_set : iterable
        Batches of ``(str_values, pos_values, labels)`` (e.g. the DataLoaders from
        ``data_integation``); ``len()`` must give the number of batches.
    save_best_model : bool
        Print the best epoch's metrics and, when ``model_path`` is given, save its
        ``state_dict`` to ``model_path + 'best_model.pth'``.
    model_path : str or None
        Prefix of the saved weights file (a directory with trailing slash, or a
        directory plus a file-name prefix).
    save_path : str or None
        Passed to ``EarlyStopping``; early stopping is disabled when None.
    earlyStop : bool
        Stop when ``EarlyStopping`` on the validation loss says so (needs ``save_path``).
    epochs : int
        Maximum number of epochs.
    lr : float
        Reference learning rate; CyclicLR cycles between ``lr/10`` and ``lr*10``
        and is stepped once per batch, so a half-cycle lasts one epoch.

    Returns
    -------
    pandas.DataFrame
        One row per completed epoch with the columns listed in ``metrics_name``.
    """
    early_stopping = EarlyStopping(save_path) if save_path is not None else None
    if early_stopping is None:
        earlyStop = False
    metrics_name = ['train_loss','val_loss','acc','precision','recall','f1','ap','bsl','mcc','auc_score','prc_score']
    #Loss function and optimizer
    criterion = nn.BCELoss()
    # Optimise the parameters of the model passed in, not a global.
    optimizer = torch.optim.AdamW(model.parameters(),lr=lr,weight_decay=1e-3)
    scheduler = lr_scheduler.CyclicLR(optimizer, base_lr=lr/10, max_lr=lr*10,
                                      cycle_momentum=False,step_size_up=len(train_set))

    metrics_result = []
    best_mcc, best_model, best_metrics = None, None, None
    for epoch in tqdm(range(epochs)):
        model.train()
        total_loss = 0

        for str_values,pos_values,labels in BackgroundGenerator(train_set):
            str_values = str_values.to(device)
            pos_values = pos_values.to(device)
            labels = labels.float().to(device)
            optimizer.zero_grad()
            # reshape(-1) instead of squeeze(): a final batch of one sample would
            # otherwise become a 0-d tensor and BCELoss rejects the size mismatch.
            # Assumes the model emits one value per sample — confirm.
            output = model(str_values, pos_values).reshape(-1)
            loss = criterion(output, labels)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            scheduler.step()
        train_loss = total_loss / len(train_set)

        eval_matrics = model_eval(model,val_set,criterion)
        metrics_result.append((train_loss,)+eval_matrics)

        # eval_matrics = (val_loss, acc, precision, recall, f1, ap, bsl, mcc, auc, prc):
        # index 7 is the MCC used for model selection.
        if best_mcc is None or eval_matrics[7] > best_mcc:
            best_mcc = eval_matrics[7]
            best_model = copy.deepcopy(model)
            best_metrics = (epoch,train_loss)+eval_matrics

        if earlyStop:
            early_stopping(eval_matrics[0], model)
            if early_stopping.early_stop:
                print("Early stopping")
                break

    # best_model is None only when no epoch ran (epochs == 0).
    if save_best_model and best_model is not None:
        if model_path is not None:
            torch.save(best_model.state_dict(),model_path+'best_model.pth')
        print('The metrics of best model are:\n')
        print(f'epoch: {best_metrics[0]}, train_loss: {best_metrics[1]}, val_loss: {best_metrics[2]}\n')
        print(f'acc: {best_metrics[3]}, precision: {best_metrics[4]}, recall: {best_metrics[5]}\n')
        print(f'f1: {best_metrics[6]}, ap: {best_metrics[7]}, bsl: {best_metrics[8]}\n')
        print(f'mcc: {best_metrics[9]}, AUC: {best_metrics[10]}, PRC: {best_metrics[11]}')

    return pd.DataFrame(metrics_result,columns=metrics_name)

In [None]:
#Model validation
def model_eval(model,val_set,criterion,thre=0.5):
    """Evaluate ``model`` on ``val_set`` without gradient tracking.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(str_values, pos_values)``; outputs are probabilities.
    val_set : iterable
        Batches of ``(str_values, pos_values, labels)``; ``len()`` is the batch count.
    criterion : callable
        Loss applied to ``(probabilities, labels)`` of each batch.
    thre : float
        Probabilities ``>= thre`` are predicted as class 1.

    Returns
    -------
    tuple
        ``(val_loss,) + metrics_calculation(y_true, y_pred, y_prob)``, where
        val_loss is the mean of the per-batch losses.
    """
    model.eval()
    total_loss = 0
    y_true,y_pred,y_prob = [],[],[]

    with torch.no_grad():
        for str_values,pos_values,labels in BackgroundGenerator(val_set):
            str_values = str_values.to(device)
            pos_values = pos_values.to(device)
            labels = labels.float().to(device)
            # reshape(-1) keeps a 1-D (batch,) tensor even for a batch of one;
            # squeeze() would give a 0-d tensor, which BCELoss rejects and whose
            # .tolist() is a bare float that list.extend cannot consume.
            # Assumes one output value per sample — confirm.
            prob = model(str_values, pos_values).reshape(-1)
            total_loss += criterion(prob, labels).item()
            y_true.extend(labels.tolist())
            y_pred.extend((prob >= thre).int().tolist())
            y_prob.extend(prob.tolist())

    val_loss = total_loss/len(val_set)
    eval_matrics = metrics_calculation(y_true,y_pred,y_prob)
    return (val_loss,)+eval_matrics

In [None]:
#Compute the model performance metrics
def metrics_calculation(y_true,y_pred,y_prob):
    """Compute binary-classification metrics for class 1.

    Parameters
    ----------
    y_true : sequence of 0/1 ground-truth labels
    y_pred : sequence of 0/1 hard predictions
    y_prob : sequence of predicted probabilities of class 1

    Returns
    -------
    tuple
        ``(acc, precision, recall, f1, ap, bsl, mcc, auc_score, prc_score)``:
        bsl is the Brier score, auc_score the ROC AUC and prc_score the area
        under the precision-recall curve (trapezoidal rule).
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    y_prob = np.asarray(y_prob, dtype=np.float64)

    auc_score = roc_auc_score(y_true,y_prob)
    acc = accuracy_score(y_true,y_pred)
    # Same binary (pos_label=1) averaging as recall and f1 below.
    precision = precision_score(y_true,y_pred)
    recall = recall_score(y_true,y_pred)
    f1 = f1_score(y_true,y_pred)
    mcc = matthews_corrcoef(y_true,y_pred)
    ap = average_precision_score(y_true,y_prob,average=None)
    bsl = brier_score_loss(y_true,y_prob)
    # precision_recall_curve returns (precision, recall, thresholds); the PR
    # area integrates precision (y) over recall (x).
    precision_curve,recall_curve,_ = precision_recall_curve(y_true,y_prob)
    prc_score = auc(recall_curve,precision_curve)

    return (acc,precision,recall,f1,ap,bsl,mcc,auc_score,prc_score)

In [None]:
#Model testing
def model_test(model,test_set):
    """Run ``model`` on the whole test set in a single forward pass.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(str_values, pos_values)``.
    test_set : tuple
        ``((str_x, pos_x), labels)`` tensors, as built by ``data_integation``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(outputs, labels)`` as float64 arrays; outputs keep the model's output shape.
    """
    model.eval()
    test_str,test_pos,test_labels = test_set[0][0],test_set[0][1],test_set[1]
    test_str, test_pos, test_labels = test_str.to(device), test_pos.to(device), test_labels.float().to(device)
    # Inference only: no_grad avoids building an autograd graph over the whole test set.
    with torch.no_grad():
        test_output = model(test_str,test_pos)
    return np.float64(test_output.tolist()),np.float64(test_labels.tolist())

In [None]:
# Input folders: per-probe structure tables and per-probe distance matrices.
structure_path = 'D:/All_for_paper/1. PhD Work Program/3. Research project/1. Dizco/Structural Information Calculation/data output/'
position_path = 'D:/All_for_paper/1. PhD Work Program/3. Research project/1. Dizco/distance_matrix/'

# Output folder/prefix for the trained model weights.
model_path = 'E:/Proteomics/PhD_script/1. Dizco/'

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
#Build a separate model for each probe from that probe's data.
#Structure and position data are the features of the CNN model.
metrics_dir = 'D:/All_for_paper/1. PhD Work Program/3. Research project/1. Dizco/Model_Metrics_Result/'
for probe in listdir(position_path):
    position_set = data_extract_pos(position_path,probe)
    structure_set = pd.read_csv(structure_path+f'{probe}_structure_infor.csv')
    structure_set = data_relabel(structure_set)
    train_set,val_set,test_set,pre_set,str_x_train,str_y_train,str_x_test,str_y_test = data_integation(probe,position_set,structure_set)

    #Random-forest baseline on the structural features only (disabled)
    # rfc = RandomForestClassifier()
    # rfc.fit(str_x_train,str_y_train)
    # print(rfc.score(str_x_test,str_y_test))
    # print(roc_auc_score(str_y_test,rfc.predict_proba(str_x_test)[:,1]))
    # print(matthews_corrcoef(str_y_test,rfc.predict(str_x_test)))

    # Model input sizes are taken from the last training batch.
    # NOTE(review): pos_dim = batch[1].shape[:3] includes that batch's size —
    # confirm dizco_CNN does not depend on pos_dim[0].
    for batch in train_set: str_dim,pos_dim = batch[0].shape[1],batch[1].shape[:3]
    dizco_cnn = dizco_CNN(str_dim,pos_dim).to(device)
    # Per-probe prefix so each probe's best model does not overwrite the previous one.
    metrics_result = model_train(dizco_cnn,train_set,val_set,model_path=f'{model_path}{probe}_',lr=0.00001)

    # To evaluate the best-MCC weights instead of the final ones:
    #dizco_cnn.load_state_dict(torch.load(f'{model_path}{probe}_best_model.pth'))
    test_output,test_labels = model_test(dizco_cnn,test_set)

    # Per-probe file name: with one fixed name, every iteration overwrote the
    # previous probe's metrics and only the last probe's survived.
    metrics_result.to_csv(f'{metrics_dir}{probe}_AdamW_lr(1e-5)_batch(5).csv',index=False)