## 使用ESM预训练模型提取特征向量
参考：https://github.com/facebookresearch/esm/blob/c9c7d4f0fec964ce10c3e11dccec6c16edaa5144/scripts/extract.py

In [9]:
from transformers import EsmForMaskedLM, EsmTokenizer, BertModel
from esm import Alphabet, FastaBatchedDataset
from torch.utils.data import DataLoader, Dataset
import os
import pathlib
import pandas as pd
import torch
import numpy as np
import os
import re

In [10]:
class ProteinExtractionParams:
    """Plain container for the settings used by the extraction cells.

    Every constructor argument is stored unchanged as an attribute of the
    same name; no validation or I/O happens here.

    Args:
        model: Name of the model family to load ('ESM-1v', 'ESM-2' or
            'ESM-1b' are the names handled by ``run``).
        EMB_LAYER: Layer index setting (stored only; not used in this class).
        model_seed: Seed suffix used to pick an ESM-1v checkpoint.
        fasta_file: Optional path to a FASTA input.
        csv_file: Path to the CSV input.
        output_dir: Optional output directory. It is stored but NOT created.
        toks_per_batch: Token budget per batch.
        repr_layers: List of layer indices; ``None`` means ``[-1]``.
        include: Kind of representation to keep (default 'mean').
        truncation_seq_length: Maximum sequence length setting.
        nogpu: When True, callers should stay on the CPU.
    """

    def __init__(
        self,
        model='ESM-1v',
        EMB_LAYER=33,
        model_seed=1,
        fasta_file=None,
        csv_file='../data/DMS_substitutions.csv',
        output_dir=None,
        toks_per_batch=4096,
        repr_layers=None,
        include='mean',
        truncation_seq_length=1022,
        nogpu=False,
    ):
        self.model = model
        self.model_seed = model_seed
        self.EMB_LAYER = EMB_LAYER
        self.fasta_file = fasta_file
        self.csv_file = csv_file
        # Previously accepted but silently dropped; keep it so later cells can
        # read it. Directory creation is intentionally left to the caller.
        self.output_dir = output_dir
        self.toks_per_batch = toks_per_batch
        # Build a fresh list per instance so that mutating one config's
        # repr_layers can never leak into another instance's default.
        self.repr_layers = [-1] if repr_layers is None else repr_layers
        self.include = include
        self.truncation_seq_length = truncation_seq_length
        self.nogpu = nogpu

In [11]:

# Module-level settings object built with every default left in place.
config = ProteinExtractionParams()

In [12]:
class Protein_Dataset(Dataset):
    """Dataset that tokenizes the ``target_seq`` column of a DataFrame up front.

    Args:
        df: DataFrame with at least the columns ``target_seq`` and ``DMS_id``.
        tokenizer: Callable invoked as
            ``tokenizer(list_of_sequences, padding=..., truncation=...,
            max_length=...)`` whose result can be indexed with the keys
            ``'input_ids'`` and ``'attention_mask'``.
        sep_len: Length every sequence is padded/truncated to (passed to the
            tokenizer as ``max_length``).
    """

    def __init__(self, df, tokenizer, sep_len=1024):
        self.df = df
        self.tokenizer = tokenizer
        self.seq_len = sep_len
        encoded = tokenizer(
            list(self.df['target_seq']),
            padding='max_length',
            truncation=True,
            max_length=self.seq_len,
        )
        # Look the fields up by name instead of unpacking ``.values()``: the
        # positional unpack breaks (or silently swaps fields) as soon as the
        # tokenizer returns an extra field or a different ordering.
        self.seq = encoded['input_ids']
        self.attention_mask = encoded['attention_mask']
        self.DMS_id = np.asarray(df['DMS_id'])
        self.pid = np.asarray(df.index)

    def __getitem__(self, idx):
        """Return ``[token ids, attention mask, DMS_id, DataFrame index]`` for row ``idx``."""
        return [self.seq[idx], self.attention_mask[idx], self.DMS_id[idx], self.pid[idx]]

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.df)

    def collate_fn(self, data):
        """Stack a list of ``__getitem__`` results into one batch.

        Returns:
            ``(seq, att_mask, DMS_id, pid)`` where ``seq``, ``att_mask`` and
            ``pid`` are tensors and ``DMS_id`` stays a plain Python list.
        """
        seq = torch.tensor(np.array([item[0] for item in data]))
        att_mask = torch.tensor(np.array([item[1] for item in data]))
        DMS_id = [item[2] for item in data]
        pid = torch.tensor(np.array([item[3] for item in data]))
        return seq, att_mask, DMS_id, pid

In [13]:
def embed(model, data_loader, repr_layers=33, return_contacts=False):
    """Run the model on the first batch of ``data_loader`` and print the output.

    This is an exploratory helper: it stops after one batch and returns
    nothing. ``repr_layers`` and ``return_contacts`` are accepted for
    signature compatibility but are not used yet.

    Args:
        model: Module called as ``model(seq, mask)``; its result must support
            ``out["logits"]``.
        data_loader: Iterable of batches whose first four items are
            ``(seq, mask, DMS_id, pid)``, with ``seq`` and ``mask`` tensors.
        repr_layers: Unused.
        return_contacts: Unused.
    """
    model.eval()

    # Put the inputs wherever the model's weights are, so the two can never
    # end up on different devices regardless of any global configuration.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    # The progress line must report the number of batches, not the length of
    # a single batch tuple (which is always 4).
    try:
        total_batches = len(data_loader)
    except TypeError:
        # Loaders without a defined length (e.g. over iterable datasets).
        total_batches = "?"

    with torch.no_grad():
        for batch_idx, data in enumerate(data_loader):
            seq, mask, DMS_id, pid = data[0], data[1], data[2], data[3]
            print(
                f"Processing {batch_idx + 1} of {total_batches} batches ({seq.size(0)} sequences)"
            )
            if device is not None:
                seq = seq.to(device=device, non_blocking=True)
                mask = mask.to(device=device, non_blocking=True)

            out = model(seq, mask)

            print(out)
            print(out["logits"].shape)
            # NOTE: deliberately inspects only the first batch.
            break



In [14]:
# Ask PyTorch whether CUDA is usable in this environment.
torch.cuda.is_available()

True

In [15]:
def run(config):
    """Load the configured ESM checkpoint and run ``embed`` on the CSV data.

    Args:
        config: Object exposing ``model``, ``csv_file`` and ``nogpu``
            (plus ``model_seed`` when ``model == 'ESM-1v'``).

    Raises:
        ValueError: If ``config.model`` is not one of 'ESM-1v', 'ESM-2',
            'ESM-1b'.
    """
    # Resolve the checkpoint name first so an unsupported model fails with a
    # clear message instead of an unbound-variable error later on.
    if config.model == 'ESM-1v':
        checkpoint = f'facebook/esm1v_t33_650M_UR90S_{config.model_seed}'
    elif config.model == 'ESM-2':
        checkpoint = 'facebook/esm2_t48_15B_UR50D'
    elif config.model == 'ESM-1b':
        checkpoint = 'facebook/esm1b_t33_650M_UR50S'
    else:
        raise ValueError(
            f"Unsupported model {config.model!r}; expected 'ESM-1v', 'ESM-2' or 'ESM-1b'"
        )

    # Without an input file there is nothing to embed; stop before loading
    # the weights rather than falling through to an undefined data loader.
    if not config.csv_file:
        print('no file!')
        return

    esm_model = EsmForMaskedLM.from_pretrained(checkpoint)
    tokenizer = EsmTokenizer.from_pretrained(checkpoint)
    esm_model.eval()
    if torch.cuda.is_available() and not config.nogpu:
        esm_model = esm_model.cuda()
        print("Transferred model to GPU")

    # Read the file named in the config, not a hard-coded path.
    data_df = pd.read_csv(config.csv_file)
    dfset = Protein_Dataset(data_df, tokenizer=tokenizer)
    dfloader = DataLoader(dfset, batch_size=32, collate_fn=dfset.collate_fn, shuffle=False)

    embed(esm_model, dfloader)


In [16]:
# Execute the pipeline with the module-level `config` object.
run(config)

Some weights of EsmForMaskedLM were not initialized from the model checkpoint at facebook/esm1v_t33_650M_UR90S_1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Transferred model to GPU
Processing 1 of 4 batches (32 sequences)
MaskedLMOutput(loss=None, logits=tensor([[[ 27.4595,  -9.7221,  -9.7196,  ...,  -9.4770,  -8.8551,  -9.5298],
         [-11.7753, -19.3972, -19.4954,  ..., -15.4997, -14.7554, -19.3613],
         [-14.4387, -20.4217, -20.3754,  ..., -15.5238, -15.7760, -20.6390],
         ...,
         [-14.0112, -18.7812, -18.7424,  ..., -15.5570, -15.5467, -18.8976],
         [-13.7338, -18.7242, -18.5399,  ..., -15.8084, -16.0953, -18.9025],
         [-16.1599, -19.2491, -19.1543,  ..., -15.8731, -15.9174, -19.3599]],

        [[ 25.8212,  -4.4358,  -4.4236,  ...,  -7.9350,  -9.1862,  -4.3968],
         [ -5.4564, -20.2427, -20.6580,  ..., -14.1852, -13.6468, -20.4660],
         [-14.0237, -20.4926, -20.8519,  ..., -14.5333, -14.4374, -20.6440],
         ...,
         [-16.1606, -21.0762, -21.4961,  ..., -14.6801, -15.1779, -21.1792],
         [-16.1606, -21.0762, -21.4961,  ..., -14.6801, -15.1779, -21.1792],
         [-16.1606, -21.