## 使用ESM预训练模型提取特征向量
参考：https://github.com/facebookresearch/esm/blob/c9c7d4f0fec964ce10c3e11dccec6c16edaa5144/scripts/extract.py

In [1]:
from transformers import EsmForMaskedLM, EsmTokenizer, EsmModel
from esm import Alphabet, FastaBatchedDataset
from torch.utils.data import DataLoader, Dataset
import os
import pathlib
import pandas as pd
import torch
import numpy as np
import os
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ProteinExtractionParams:
    def __init__(
        self,
        model='ESM-1b',
        EMB_LAYER = 33,
        model_seed = 1,
        fasta_file = None,
        csv_file = '../data/DMS_substitutions.csv',

        batch_size=32,
        repr_layers=[-1],
        include='mean',
        truncation_seq_length=1022,
        nogpu=False,
    ):
        self.model=model
        self.model_seed = model_seed
        self.EMB_LAYER = EMB_LAYER
        self.fasta_file = fasta_file
        self.csv_file = csv_file
        self.batch_size = batch_size
        self.repr_layers = repr_layers
        self.include = include
        self.truncation_seq_length = truncation_seq_length
        self.nogpu = nogpu
        self.save_path = csv_file.split('.csv')[0]+'esm_embed'

In [3]:

config = ProteinExtractionParams()

In [4]:
class Protein_Dataset(Dataset):
    def __init__(self, df, tokenizer, sep_len=1022):
        self.df = df
        self.tokenizer = tokenizer
        self.seq_len = sep_len
        self.seq, self.attention_mask = tokenizer(list(self.df['target_seq']), padding='max_length',
                                                  truncation=True,
                                                  max_length=self.seq_len).values()
        self.DMS_id = np.asarray(df['DMS_id'])
        self.pid = np.asarray(df.index)

    def __getitem__(self, idx):
        return [self.seq[idx], self.attention_mask[idx],self.DMS_id[idx],self.pid[idx]]

    def __len__(self):
        return len(self.df)

    def collate_fn(self, data):
        seq = torch.tensor(np.array([u[0] for u in data]))
        att_mask = torch.tensor(np.array([u[1] for u in data]))
        DMS_id = [u[2] for u in data]
        pid = torch.tensor(np.array([u[3] for u in data]))
        return seq, att_mask, DMS_id,pid

In [5]:
def embed(model,data_loader,save_path,repr_layers=33,return_contacts=False):
    model.eval()
    with torch.no_grad():
        for batch_idx, data in enumerate(data_loader):

            seq, mask, DMS_id,pid= data[0],data[1],data[2],data[3]
            print(
                f"Processing {batch_idx + 1} of {len(data)} batches ({seq.size(0)} sequences)"
            )
            if torch.cuda.is_available() and not config.nogpu:
                seq = seq.to(device="cuda", non_blocking=True)
                mask = mask.to(device="cuda", non_blocking=True)

            out = model(seq,mask)
            # print(out.last_hidden_state.mean(dim=1).shape)
            # print(out.pooler_output.shape)
            batch_representations = out.pooler_output
            for index,dms in enumerate(DMS_id):
                representations = batch_representations[index]
                torch.save(representations, f'{save_path}/{dms}.pt')
                # print(representations)



In [6]:
torch.cuda.is_available()

True

In [7]:
def run(config):
    if config.model == 'ESM-1v':
        esm_model = EsmModel.from_pretrained(f'facebook/esm1v_t33_650M_UR90S_{config.model_seed}')
        tokenizer = EsmTokenizer.from_pretrained(f'facebook/esm1v_t33_650M_UR90S_{config.model_seed}')
    elif config.model == 'ESM-2':
        esm_model = EsmModel.from_pretrained('facebook/esm2_t48_15B_UR50D')
        tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t48_15B_UR50D')
    elif config.model == 'ESM-1b':
        esm_model = EsmModel.from_pretrained('facebook/esm1b_t33_650M_UR50S')
        tokenizer = EsmTokenizer.from_pretrained('facebook/esm1b_t33_650M_UR50S')
    esm_model.eval()
    # print(esm_model)
    if torch.cuda.is_available() and not config.nogpu:
        esm_model = esm_model.cuda()
        print("Transferred model to GPU")
    if(config.csv_file):
        data_df = pd.read_csv('../data/DMS_substitutions.csv')
        dfset = Protein_Dataset(data_df,tokenizer=tokenizer)
        dfloader = DataLoader(dfset, batch_size=config.batch_size, collate_fn=dfset.collate_fn, shuffle=False)
    else:
        print('no file!')
    save_path = config.save_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    embed(esm_model,dfloader,save_path)


In [8]:
run(config)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm1b_t33_650M_UR50S and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Transferred model to GPU
Processing 1 of 4 batches (32 sequences)
Processing 2 of 4 batches (32 sequences)
Processing 3 of 4 batches (32 sequences)
Processing 4 of 4 batches (32 sequences)
Processing 5 of 4 batches (32 sequences)
Processing 6 of 4 batches (32 sequences)
Processing 7 of 4 batches (25 sequences)
