In [17]:
import os  # this config cell runs before the main import block below

DATA_PATH = "../dataset/dataset2.0/"  # root directory holding the CSV splits
# Pretrained ESM checkpoint in the torch-hub cache. Resolved from the current
# user's home directory instead of a hard-coded /home/<user>/ path so the
# notebook also runs under other accounts (same value for the original user).
EMB_PRE_PATH = os.path.expanduser(
    "~/.cache/torch/hub/checkpoints/esm1b_t33_650M_UR50S.pt"
)
EMBED_PATH = DATA_PATH + 'ESM_finetune_embed/'  # where fine-tuned embeddings are saved
EMB_LAYER = 33  # ESM representation layer index; not read by the visible code
import os
import pandas as pd
import pathlib
import pandas as pd
import torch
import numpy as np
from torch import nn 
from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer
import utils
from model import LayerNormNet
from sklearn.metrics import accuracy_score,roc_auc_score,roc_curve,f1_score,recall_score,precision_score,matthews_corrcoef

In [18]:
class ProteinExtractionParams:
    """Settings shared by the ESM model wrappers and data loaders in this notebook.

    Args:
        model_location: path of the pretrained ESM checkpoint handed to
            ``pretrained.load_model_and_alphabet``.
        fasta_file: optional FASTA input path (stored only).
        csv_file: optional CSV input path (stored only).
        output_dir: optional output directory, stored as ``self.output_dir``
            (previously accepted but silently discarded).
        toks_per_batch: token budget per batch used when bucketing sequences.
        repr_layers: ESM layer indices whose representations are requested.
            A fresh list is stored per instance so the default is never shared.
        include: output-selector string; the model wrappers look for the
            substrings "mean", "per_tok", "bos" and "contacts" in it.
        truncation_seq_length: maximum number of residues kept per sequence.
        nogpu: stored flag; not read by the visible code.
    """

    def __init__(
        self,
        model_location=EMB_PRE_PATH,
        fasta_file=None,
        csv_file=None,
        output_dir=None,
        toks_per_batch=1024,
        repr_layers=(33,),
        include='mean',
        truncation_seq_length=512,
        nogpu=False,
    ):
        self.model_location = model_location
        self.fasta_file = fasta_file
        self.csv_file = csv_file
        self.output_dir = output_dir
        self.toks_per_batch = toks_per_batch
        # Copy so that mutating one instance's layers never leaks into another.
        self.repr_layers = list(repr_layers)
        self.include = include
        self.truncation_seq_length = truncation_seq_length
        self.nogpu = nogpu


args = ProteinExtractionParams(output_dir='./ESM_finetuning/')

In [19]:
class fine_tunning_esm(nn.Module):
    """ESM backbone with a single linear head on the mean-pooled residue embedding.

    ``forward(strs, toks)`` runs the ESM model on the token batch, averages the
    representation of the highest requested layer over each sequence's residue
    positions (position 0 is skipped; it is the slot the original code treated
    as BOS), and maps the pooled 1280-d vector to ``n_classes`` outputs.
    """

    def __init__(self, n_classes=2, args=None):
        super(fine_tunning_esm, self).__init__()
        # Keep the settings on the instance instead of reading the notebook-global
        # ``args`` inside forward(). A plain object is not part of the state_dict.
        self.args = args
        self.esm1b, self.alphabet = pretrained.load_model_and_alphabet(args.model_location)
        # Parameter-free modules not used by forward(); kept so that code
        # referring to these attributes keeps working.
        self.drop1 = nn.Dropout(p=0.2)
        self.drop2 = nn.Dropout(p=0.5)
        self.drop3 = nn.Dropout(p=0.2)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(1280, n_classes)
        self.softmax = nn.Softmax(dim=1)

    def _mean_representation(self, strs, toks):
        """Return a (batch, dim) tensor: per-sequence mean over residue positions."""
        args = self.args
        out = self.esm1b(toks, repr_layers=args.repr_layers,
                         return_contacts="contacts" in args.include)
        # Highest requested layer; equals the previously hard-coded 33 whenever
        # 33 is among repr_layers for the 33-layer checkpoint.
        reps = out["representations"][max(args.repr_layers)]
        pooled = []
        for i, seq in enumerate(strs):
            n_res = min(args.truncation_seq_length, len(seq))
            # Residues occupy positions 1..n_res; position 0 is skipped.
            pooled.append(reps[i, 1:n_res + 1].mean(0))
        return torch.stack(pooled)

    def forward(self, strs, toks):
        """Classify a batch: ``strs`` are the raw sequences, ``toks`` their token ids."""
        return self.linear(self._mean_representation(strs, toks))
class fine_tunning_esm_LN(nn.Module):
    """ESM backbone with a ``LayerNormNet`` head on the mean-pooled residue embedding.

    ``forward(strs, toks)`` runs the ESM model, averages the representation of
    the highest requested layer over each sequence's residue positions
    (position 0 is skipped), and feeds the pooled vectors to ``self.model``.
    """

    def __init__(self, args=None):
        super(fine_tunning_esm_LN, self).__init__()
        # Keep the settings on the instance instead of reading the notebook-global
        # ``args`` inside forward(). A plain object is not part of the state_dict.
        self.args = args
        self.esm1b, self.alphabet = pretrained.load_model_and_alphabet(args.model_location)
        self.model = LayerNormNet(hidden_dim=512, out_dim=2)

    def _mean_representation(self, strs, toks):
        """Return a (batch, dim) tensor: per-sequence mean over residue positions."""
        args = self.args
        out = self.esm1b(toks, repr_layers=args.repr_layers,
                         return_contacts="contacts" in args.include)
        # Highest requested layer; equals the previously hard-coded 33 whenever
        # 33 is among repr_layers for the 33-layer checkpoint.
        reps = out["representations"][max(args.repr_layers)]
        pooled = []
        for i, seq in enumerate(strs):
            n_res = min(args.truncation_seq_length, len(seq))
            # Residues occupy positions 1..n_res; position 0 is skipped.
            pooled.append(reps[i, 1:n_res + 1].mean(0))
        return torch.stack(pooled)

    def forward(self, strs, toks):
        """Classify a batch: ``strs`` are the raw sequences, ``toks`` their token ids."""
        return self.model(self._mean_representation(strs, toks))
class ESM_CNN1d(nn.Module):
    """ESM backbone followed by parallel 1-D convolutions over the token representations.

    Each convolution's output is average-pooled over the sequence axis, passed
    through ReLU, concatenated across convolutions, and fed (after dropout) to a
    linear decoder producing ``n_classes`` outputs.

    Args:
        kernel_sizes: kernel size of each convolution.
        n_classes: number of decoder outputs.
        args: settings object (``model_location``, ``repr_layers``, ``include``).
        num_channels: output channels of each convolution, paired with
            ``kernel_sizes`` position by position.
    """

    def __init__(self, kernel_sizes=(3, 4, 5), n_classes=2, args=None,
                 num_channels=(512, 512, 512)):
        super(ESM_CNN1d, self).__init__()
        # Keep the settings on the instance instead of reading the notebook-global
        # ``args`` inside forward(). A plain object is not part of the state_dict.
        self.args = args
        self.esm1b, self.alphabet = pretrained.load_model_and_alphabet(args.model_location)
        self.dropout = nn.Dropout(0.5)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.relu = nn.ReLU()
        # One Conv1d per (channels, kernel) pair over the 1280-d ESM features.
        self.convs = nn.ModuleList(
            nn.Conv1d(1280, c, k) for c, k in zip(num_channels, kernel_sizes)
        )
        # Size the decoder from the convs actually built. zip() stops at the
        # shorter list, so sum(num_channels) could disagree with the concat width.
        self.decoder = nn.Linear(sum(conv.out_channels for conv in self.convs), n_classes)
        self.softmax = nn.Softmax(dim=1)  # not used by forward(); kept for compatibility

    def forward(self, strs, toks):
        """Classify a batch of token ids ``toks`` (``strs`` is accepted but unused)."""
        args = self.args
        out = self.esm1b(toks, repr_layers=args.repr_layers,
                         return_contacts="contacts" in args.include)
        # (batch, seq_len, 1280) -> (batch, 1280, seq_len): Conv1d wants channels first.
        embeddings = out["representations"][max(args.repr_layers)].permute(0, 2, 1)
        # NOTE(review): pooling spans every token position, including any
        # special/padding tokens, not just residues.
        features = [
            self.relu(self.pool(conv(embeddings))).squeeze(-1) for conv in self.convs
        ]
        return self.decoder(self.dropout(torch.cat(features, dim=1)))

In [20]:
def get_testDataloader(test_df, alphabet, args):
    """Build a token-budgeted DataLoader over a DataFrame of sequences.

    Args:
        test_df: DataFrame with ``label`` and ``seq`` columns.
        alphabet: ESM alphabet whose batch converter tokenizes each batch,
            truncating to ``args.truncation_seq_length``.
        args: settings object providing ``toks_per_batch`` and
            ``truncation_seq_length``.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(labels, strs, toks)``.
        Batches come from ``get_batch_indices``, which groups sequences by
        length (per the esm implementation — verify if upgrading), so batch
        order need not follow DataFrame row order.
    """
    # Plain lists: the dataset indexes positionally, which a pandas Series
    # with a non-default index (e.g. a filtered frame) would not honour.
    dataset = FastaBatchedDataset(test_df['label'].tolist(), test_df['seq'].tolist())
    batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1)
    return torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(args.truncation_seq_length),
        batch_sampler=batches,
    )

In [21]:
def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())
def evaluate_accuracy_gpu(net, data_iter, devices, return_scores=False):
    """Run ``net`` over ``data_iter``; collect labels, predictions and accuracy.

    Originally defined in :numref:`sec_lenet`; adapted to ``(labels, strs, toks)``
    batches.

    Args:
        net: model called as ``net(strs, toks)``; expected to return a 2-D
            score tensor with one column per class.
        data_iter: iterable of ``(labels, strs, toks)`` batches.
        devices: sequence of devices; labels and tokens go to ``devices[0]``.
        return_scores: if True, also return the softmax probability of class 1
            for every sample (usable as ``y_score`` for ROC-AUC).

    Returns:
        ``(y, y_hat, acc)``: labels and argmax predictions concatenated in
        iteration order (on ``devices[0]``) and the fraction of correct
        predictions. ``(y, y_hat, acc, scores)`` when ``return_scores`` is True.

    Raises:
        ValueError: if ``data_iter`` yields no batches.
    """
    net.eval()  # evaluation mode (disables dropout etc.)
    device = devices[0]
    n_correct, n_total = 0.0, 0.0
    labels_all, preds_all, scores_all = [], [], []
    with torch.no_grad():
        for labels, strs, toks in data_iter:
            labels = torch.tensor(labels).to(device)
            y_hat = net(strs, toks.to(device))
            n_correct += accuracy(y_hat, labels)
            n_total += labels.numel()
            preds_all.append(y_hat.argmax(axis=1))
            labels_all.append(labels)
            if return_scores:
                scores_all.append(torch.softmax(y_hat, dim=1)[:, 1])
    if not labels_all:
        raise ValueError("data_iter yielded no batches")
    y = torch.cat(labels_all, dim=0)
    y_hat = torch.cat(preds_all, dim=0)
    acc = n_correct / n_total
    if return_scores:
        return y, y_hat, acc, torch.cat(scores_all, dim=0)
    return y, y_hat, acc

def evaluate(true_y, pred_y, pred_score=None):
    """Compute binary-classification metrics, print them, and return them.

    Args:
        true_y: tensor of ground-truth labels (any device).
        pred_y: tensor of predicted hard labels (any device).
        pred_score: optional positive-class scores (tensor, array or list).
            When given, ROC-AUC is computed from these scores. When omitted,
            ROC-AUC falls back to the hard labels as before. With 0/1
            predictions that value equals balanced accuracy rather than a
            threshold-free AUC.

    Returns:
        dict with keys 'accuracy', 'mcc', 'roc_auc', 'f1_score',
        'recall_score', 'precision'.
    """
    true_y = true_y.cpu()
    pred_y = pred_y.cpu()
    auc_input = pred_y if pred_score is None else torch.as_tensor(pred_score).cpu()
    # Evaluated in the same order as before: accuracy, mcc, roc_auc, f1, recall, precision.
    results = {
        'accuracy': accuracy_score(true_y, pred_y),
        'mcc': matthews_corrcoef(true_y, pred_y),
        'roc_auc': roc_auc_score(true_y, auc_input),
        'f1_score': f1_score(true_y, pred_y),
        'recall_score': recall_score(true_y, pred_y),
        'precision': precision_score(true_y, pred_y),
    }
    print(results)
    return results

In [22]:
def test(net, alphabet, test_csv='test_data1_2023-6-9_15_31.csv', devices=None):
    """Evaluate ``net`` on a CSV split under DATA_PATH; print and return metrics.

    Args:
        net: model called as ``net(strs, toks)``.
        alphabet: ESM alphabet used to build the batch converter.
        test_csv: file name (relative to DATA_PATH) with ``label`` and ``seq`` columns.
        devices: devices to run on. Defaults to ``utils.try_all_gpus()``,
            resolved at call time rather than once when the function is defined.

    Returns:
        The metrics dict produced by ``evaluate``.
    """
    if devices is None:
        devices = utils.try_all_gpus()
    print('start test---------------->>>>>>>>>>')
    test_df = pd.read_csv(DATA_PATH + test_csv)
    test_dataloader = get_testDataloader(test_df=test_df, alphabet=alphabet, args=args)
    y, y_hat, acc = evaluate_accuracy_gpu(net, test_dataloader, devices=devices)
    print(acc)
    return evaluate(true_y=y, pred_y=y_hat)

   

In [23]:

# Restore the fine-tuned classifier (ESM backbone + LayerNormNet head) and its optimizer.
net = fine_tunning_esm_LN(args=args)
# The model already loaded the checkpoint together with its alphabet; reuse
# that instead of loading the whole ESM checkpoint a second time.
alphabet = net.alphabet
path_checkpoint = "./ESM_finetuning/checkpoint_esm2_2.0/bestmodel.pkl"  # fine-tuned weights
# map_location='cpu' lets the file load regardless of the device it was saved
# from; the model is moved to the GPU right after.
checkpoint = torch.load(path_checkpoint, map_location='cpu')
net.load_state_dict(checkpoint['net'])  # learnable parameters
net.cuda()
# Rebuild the optimizer with the same single parameter group and restore its
# state (load_state_dict moves the state onto the parameters' device).
# It is not used by the evaluation cells shown here.
trainer = torch.optim.Adam([{"params": net.esm1b.parameters(), "lr": 1e-7}],
                           lr=1e-7, weight_decay=1e-3)
trainer.load_state_dict(checkpoint['optimizer'])  # optimizer state



In [32]:
# Run the evaluation pipeline on the second test CSV.
test(net, alphabet, test_csv='test_data2_2023-6-9_15_31.csv')

start test---------------->>>>>>>>>>


0.708029197080292
{'accuracy': 0.708029197080292, 'mcc': 0.46829651968447306, 'roc_auc': 0.6788172043010753, 'f1_score': 0.5348837209302326, 'recall_score': 0.3709677419354839, 'precision': 0.9583333333333334}


In [25]:
class f_esm(nn.Module):
    """Embedding extractor: the ``fine_tunning_esm`` layout, but forward returns features.

    ``forward(strs, toks)`` returns the mean-pooled representation of the
    highest requested ESM layer, averaged over each sequence's residue
    positions (position 0 skipped). The linear head is still created so that a
    state_dict containing ``linear.*`` keys can be loaded, but it is not applied.
    """

    def __init__(self, n_classes=2, args=None):
        super(f_esm, self).__init__()
        # Keep the settings on the instance instead of reading the notebook-global
        # ``args`` inside forward(). A plain object is not part of the state_dict.
        self.args = args
        self.esm1b, self.alphabet = pretrained.load_model_and_alphabet(args.model_location)
        # Parameter-free modules not used by forward(); kept for compatibility.
        self.drop1 = nn.Dropout(p=0.2)
        self.drop2 = nn.Dropout(p=0.5)
        self.drop3 = nn.Dropout(p=0.2)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(1280, n_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, strs, toks):
        """Return a (batch, dim) tensor of mean-pooled ESM embeddings."""
        args = self.args
        esm_out = self.esm1b(toks, repr_layers=args.repr_layers,
                             return_contacts="contacts" in args.include)
        # Highest requested layer; equals the previously hard-coded 33 whenever
        # 33 is among repr_layers for the 33-layer checkpoint.
        reps = esm_out["representations"][max(args.repr_layers)]
        pooled = []
        for i, seq in enumerate(strs):
            n_res = min(args.truncation_seq_length, len(seq))
            # Residues occupy positions 1..n_res; position 0 is skipped.
            pooled.append(reps[i, 1:n_res + 1].mean(0))
        # The classifier output used to be computed here and discarded; only
        # the embedding is returned, so the head is no longer evaluated.
        return torch.stack(pooled)

In [26]:
def get_finetune_embed(net, data_iter, devices, restore_order=True):
    """Run ``net`` over ``data_iter`` and stack the per-sequence outputs.

    Args:
        net: model called as ``net(strs, toks)`` returning a (batch, dim) tensor.
        data_iter: iterable of ``(labels, strs, toks)`` batches, e.g. from
            ``get_testDataloader``.
        devices: sequence of devices; tokens are moved to ``devices[0]``.
        restore_order: when ``data_iter`` exposes a materialized
            ``batch_sampler`` (a list/tuple of index lists, as built by
            ``get_testDataloader``), put the rows back in dataset order.
            Length-bucketed batches do not follow CSV row order, so without
            this the saved embeddings would not line up with the CSV labels.

    Returns:
        numpy array with one row per sequence.
    """
    net.eval()  # evaluation mode (disables dropout etc.)
    order = None
    if restore_order:
        sampler = getattr(data_iter, "batch_sampler", None)
        # Only a materialized list/tuple is replayed deterministically; a
        # random sampler would yield a different permutation if re-iterated.
        if isinstance(sampler, (list, tuple)):
            order = [idx for batch in sampler for idx in batch]
    chunks = []
    with torch.no_grad():
        for _labels, strs, toks in data_iter:
            chunks.append(net(strs, toks.to(devices[0])).cpu().numpy())
    embedding = np.vstack(chunks)
    if order is not None and sorted(order) == list(range(len(embedding))):
        restored = np.empty_like(embedding)
        restored[order] = embedding  # row k of the stack belongs to dataset index order[k]
        embedding = restored
    print(embedding.shape)
    return embedding

In [27]:
def extract_embed(data_file, net):
    """Embed every sequence of ``DATA_PATH + data_file`` with ``net`` and save a .npy.

    The output file is ``EMBED_PATH + <data_file up to its first '.'> + '_embeds.npy'``.
    Its directory is created if it does not exist yet.
    """
    input_data = DATA_PATH + data_file
    output_dir = EMBED_PATH + data_file.split('.')[0]
    test_df = pd.read_csv(input_data)
    dataloader = get_testDataloader(test_df=test_df, alphabet=alphabet, args=args)
    embeds = get_finetune_embed(net=net, data_iter=dataloader, devices=utils.try_all_gpus())
    out_file = output_dir + '_embeds.npy'
    # np.save does not create missing directories.
    os.makedirs(os.path.dirname(out_file) or '.', exist_ok=True)
    np.save(out_file, embeds)
    # Report the file actually written, not just its prefix.
    print('Extract ESM fine_tuning embeddings for {}, save in {}'.format(input_data, out_file))

In [28]:
# net = f_esm(args = args)
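# path_checkpoint = "./ESM_finetuning/checkpoint_esm2_2.0/bestmodel.pkl"  # model weight path
# checkpoint = torch.load(path_checkpoint)  # load the weights
# net.load_state_dict(checkpoint['net'])  # load the model's learnable parameters
# NOTE(review): this checkpoint is the one loaded into fine_tunning_esm_LN (head stored under `model.*`);
# f_esm's head is `linear.*`, so a strict load_state_dict here will likely fail on mismatched keys.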
# net.cuda()
# for file in os.listdir(DATA_PATH):
#     if file.endswith('.csv'):
#         print(file)
#         extract_embed(file,net)
        

In [29]:
# import os
# import torch
# import pandas as pd
# import numpy as np
# from sklearn.decomposition import PCA
# def load_embed(csv_file):
#     Embed_PATH = EMBED_PATH+csv_file.split('.')[0]+'_embeds.npy'
#     data_df =  pd.read_csv(DATA_PATH+csv_file)
#     ys = data_df['label']
#     Xs = np.load(Embed_PATH)
#     print('load{} esm finetuning embedding from {}'.format(csv_file,Embed_PATH))
#     print(len(ys))
#     print(Xs.shape)
#     return Xs,ys
# Xs = []
# ys = []
# for file in os.listdir(DATA_PATH):
#     if file.endswith('.csv'):
#         if file.endswith('llpsdb_d.csv'):continue
#         x,y = load_embed(file)
#         Xs.append(x)
#         ys.append(y)
# Xs = np.vstack((*Xs,))
# ys = [y for sub in ys for y in sub]
# print(len(ys))
# print(Xs.shape)
# num_pca_components = 2
# pca = PCA(num_pca_components)
# Xs_pca = pca.fit_transform(Xs)

In [30]:
# import matplotlib.pyplot as plt
# fig_dims = (7, 6)
# fig, ax = plt.subplots(figsize=fig_dims)
# sc = ax.scatter(Xs_pca[:,0], Xs_pca[:,1], c=ys, marker='.')
# ax.set_xlabel('ESM finetuning PCA first principal component')
# ax.set_ylabel('ESM finetuning PCA second principal component')
# plt.colorbar(sc, label='LLPS tendency')