In [1]:
# --- Notebook configuration -------------------------------------------------
# FASTA file whose sequences are embedded (relative to the notebook's working directory).
DATA_PATH = "../data/results/2023-Dec-05-04:14:26_generate_10.fasta" # Path to data
# Checkpoint location used as the default `model_location` of ProteinExtractionParams.
# NOTE(review): absolute, user-specific path -- has to be edited on any other machine.
EMB_PRE_PATH = "/home/bli/.cache/torch/hub/checkpoints/esm1b_t33_650M_UR50S.pt" 
# Directory the per-sequence embedding files are written to and later read back from.
EMBED_PATH ='./ESM_embed/'
# Index of the representation layer of interest.
# NOTE(review): no visible cell reads this module-level value; load_esm_embed
# picks its layer on its own -- confirm whether this constant is still needed.
EMB_LAYER = 33

In [2]:
import pathlib
import pandas as pd
import torch

from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer


import os

class ProteinExtractionParams:
    """Option container consumed by ``run`` (stands in for CLI arguments).

    Attributes set on the instance:
        model_location: value handed to ``pretrained.load_model_and_alphabet``.
        fasta_file: path of a FASTA input; takes precedence over ``csv_file`` in ``run``.
        csv_file: path of a CSV input (``run`` reads its ``id`` and ``seq`` columns).
        output_dir: ``pathlib.Path`` of the directory results are written to;
            created (with parents) on construction.
        toks_per_batch: token budget passed to ``get_batch_indices``.
        repr_layers: list of layer indices to extract; negative values count
            from the end (``run`` normalises them).
        include: which results to store; ``run`` tests membership of
            ``"per_tok"``, ``"mean"``, ``"bos"`` and ``"contacts"`` with ``in``,
            so both a string and a list of strings work.
        truncation_seq_length: maximum sequence length kept per sequence.
        nogpu: when true, ``run`` stays on CPU even if CUDA is available.

    Raises:
        TypeError: if ``output_dir`` is not provided.
    """

    def __init__(
        self,
        model_location=EMB_PRE_PATH,
        fasta_file = None,
        csv_file = None,
        output_dir = None,
        toks_per_batch=10,
        repr_layers=[-1],
        include='mean',
        truncation_seq_length=512,
        nogpu=False,
    ):
        # Fail early with a readable message instead of the opaque error
        # pathlib.Path(None) would produce.
        if output_dir is None:
            raise TypeError("output_dir is required (got None)")

        self.model_location = model_location
        self.fasta_file = fasta_file
        self.csv_file = csv_file

        self.output_dir = pathlib.Path(output_dir)
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.toks_per_batch = toks_per_batch
        # Copy so instances never share (or mutate) the default list object.
        self.repr_layers = list(repr_layers)
        self.include = include
        self.truncation_seq_length = truncation_seq_length
        self.nogpu = nogpu


def run(args):
    """Embed every input sequence and save one ``.pt`` file per sequence.

    Args:
        args: object exposing the attributes of ``ProteinExtractionParams``
            (``model_location``, ``fasta_file``, ``csv_file``, ``output_dir``,
            ``toks_per_batch``, ``repr_layers``, ``include``,
            ``truncation_seq_length``, ``nogpu``).

    Side effects:
        * Writes ``<output_dir>/<label>.pt`` for every sequence. Each file is a
          dict with ``"label"`` plus, depending on ``args.include``,
          ``"representations"``, ``"mean_representations"``,
          ``"bos_representations"`` and ``"contacts"``.
        * Sets ``args.output_file`` to the path of the file being written
          (kept for compatibility with the original behaviour).
        * Prints progress messages.

    Raises:
        ValueError: if the loaded model is an MSA Transformer, or if neither
            ``args.fasta_file`` nor ``args.csv_file`` is set.
        AssertionError: if a requested layer index is out of range.
    """
    model, alphabet = pretrained.load_model_and_alphabet(args.model_location)
    model.eval()
    if isinstance(model, MSATransformer):
        raise ValueError(
            "This script currently does not handle models with MSA input (MSA Transformer)."
        )
    use_gpu = torch.cuda.is_available() and not args.nogpu
    if use_gpu:
        model = model.cuda()
        print("Transferred model to GPU")

    # Build the dataset from whichever input was supplied; FASTA wins over CSV.
    if args.fasta_file:
        source = args.fasta_file
        dataset = FastaBatchedDataset.from_file(source)
    elif args.csv_file:
        source = args.csv_file
        data_df = pd.read_csv(source)
        dataset = FastaBatchedDataset(data_df['id'], data_df['seq'])
    else:
        # Previously this only printed and then crashed on an undefined
        # data_loader; raise a clear error instead.
        raise ValueError("no file! Set either fasta_file or csv_file.")

    batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(args.truncation_seq_length),
        batch_sampler=batches,
    )
    print(f"Read {source} with {len(dataset)} sequences")

    args.output_dir.mkdir(parents=True, exist_ok=True)
    return_contacts = "contacts" in args.include

    # Map negative layer indices onto [0, num_layers] (index 0 is a valid layer slot).
    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in args.repr_layers)
    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in args.repr_layers]

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if use_gpu:
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=return_contacts)

            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }
            if return_contacts:
                contacts = out["contacts"].to(device="cpu")

            for i, label in enumerate(labels):
                args.output_file = args.output_dir / f"{label}.pt"
                # A label may contain path separators, so make sure the parent exists.
                args.output_file.parent.mkdir(parents=True, exist_ok=True)
                result = {"label": label}
                truncate_len = min(args.truncation_seq_length, len(strs[i]))
                # Call clone on tensors to ensure tensors are not views into a larger representation
                # See https://github.com/pytorch/pytorch/issues/1995
                # Position 0 is skipped for per-token/mean results and stored
                # separately under "bos_representations".
                if "per_tok" in args.include:
                    result["representations"] = {
                        layer: t[i, 1 : truncate_len + 1].clone()
                        for layer, t in representations.items()
                    }
                if "mean" in args.include:
                    result["mean_representations"] = {
                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    }
                if "bos" in args.include:
                    result["bos_representations"] = {
                        layer: t[i, 0].clone() for layer, t in representations.items()
                    }
                if return_contacts:
                    # Fixed: the original referenced the misspelled name `conacts`.
                    result["contacts"] = contacts[i, :truncate_len, :truncate_len].clone()

                torch.save(
                    result,
                    args.output_file,
                )


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def extract_embed(data_file):
    """Compute embeddings for the sequences in ``data_file``.

    The results are written to the module-level ``EMBED_PATH`` directory
    (created if missing) by delegating to ``run`` with default
    ``ProteinExtractionParams`` settings.

    Args:
        data_file: path to the FASTA file to embed.
    """
    # Fixed: honour the argument instead of silently reading the DATA_PATH global.
    input_data = data_file
    output_dir = EMBED_PATH
    try:
        # Create the output folder
        os.makedirs(output_dir)
        print(f"Folder '{output_dir}' has been created.")
    except FileExistsError:
        print(f"Folder '{output_dir}' already exists.")
    args = ProteinExtractionParams(fasta_file=input_data, output_dir=output_dir)
    run(args)
    print('Extract ESM embeddings for {}, save in {}'.format(input_data,output_dir))

In [4]:

# Embed the sequences of DATA_PATH; the embedding files end up under EMBED_PATH.
extract_embed(DATA_PATH)

Folder './ESM_embed/' already exists.
Transferred model to GPU
Read ../data/results/2023-Dec-05-04:14:26_generate_10.fasta with 10 sequences
Processing 1 of 10 batches (1 sequences)
Processing 2 of 10 batches (1 sequences)
Processing 3 of 10 batches (1 sequences)
Processing 4 of 10 batches (1 sequences)
Processing 5 of 10 batches (1 sequences)
Processing 6 of 10 batches (1 sequences)
Processing 7 of 10 batches (1 sequences)
Processing 8 of 10 batches (1 sequences)
Processing 9 of 10 batches (1 sequences)
Processing 10 of 10 batches (1 sequences)
Extract ESM embeddings for ../data/results/2023-Dec-05-04:14:26_generate_10.fasta, save in ./ESM_embed/


In [5]:
def load_esm_embed(EMBED_PATH, emb_layer=33, return_labels=False):
    """Load the saved mean embeddings from a directory into one array.

    Only entries whose name ends in ``.pt`` are read, and they are processed
    in sorted file-name order so the row order is reproducible
    (``os.listdir`` alone returns entries in arbitrary order).

    Args:
        EMBED_PATH: directory containing the ``.pt`` files.
        emb_layer: key looked up in each file's ``"mean_representations"``
            dict (defaults to 33, the value previously hard-coded here).
        return_labels: when true, also return the label of every row so rows
            can be matched back to their sequences.

    Returns:
        ``Xs`` -- a NumPy array with one row per file; or ``(Xs, labels)``
        when ``return_labels`` is true. A label is the file's ``"label"``
        entry, falling back to the file name without ``.pt``.

    Raises:
        FileNotFoundError: if the directory holds no ``.pt`` files.
    """
    file_names = sorted(
        name for name in os.listdir(EMBED_PATH) if name.endswith('.pt')
    )
    if not file_names:
        raise FileNotFoundError(f"No '.pt' embedding files found in {EMBED_PATH!r}")

    Xs = []
    labels = []
    for name in file_names:
        # NOTE: torch.load unpickles data -- only point this at trusted files.
        embs = torch.load(os.path.join(EMBED_PATH, name), map_location="cpu")
        Xs.append(embs['mean_representations'][emb_layer])
        labels.append(embs.get('label', name[: -len('.pt')]))

    Xs = torch.stack(Xs, dim=0).numpy()
    print('load esm embedding')

    if return_labels:
        return Xs, labels
    return Xs




In [6]:
# Gather the saved mean embeddings from EMBED_PATH into a single array.
embed_temp = load_esm_embed(EMBED_PATH)

load esm embedding


In [7]:
# Show the embedding array as a table (one row per loaded file).
pd.DataFrame(embed_temp)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1270,1271,1272,1273,1274,1275,1276,1277,1278,1279
0,0.167233,0.130327,-0.168505,0.030995,-0.039343,-0.08955,-0.243623,-0.142271,-0.119898,0.069863,...,0.125417,0.020173,-0.048031,0.2146,-1.258521,-0.07579,0.121951,0.013886,0.198329,0.207286
1,0.142538,0.001049,-0.095545,0.11419,0.041693,-0.1765,-0.208702,-0.056122,-0.150602,-0.11546,...,0.132358,0.097449,-0.016638,0.017523,-1.032016,-0.104907,0.002389,0.187973,0.164439,0.159598
2,0.166239,0.100672,-0.163813,0.078611,0.019763,-0.116436,-0.133373,-0.009965,-0.058678,0.031321,...,0.137263,-0.009128,-0.049559,0.120188,-1.012077,-0.077144,0.12606,0.010027,0.176264,0.155619
3,0.028822,0.143447,-0.191251,0.218951,-0.138921,-0.007613,-0.143106,-0.04132,-0.072974,-0.03892,...,0.167929,-0.11038,-0.07319,0.170297,-0.821683,-0.033038,0.023359,-0.10837,-0.05446,0.023952
4,0.117276,0.050493,-0.087267,0.177359,-0.047596,-0.171305,-0.322699,0.062665,-0.02913,-0.088583,...,-0.049856,0.120746,-0.132633,0.104556,-1.391821,0.102206,0.010904,0.151753,0.136671,0.025402
5,0.112103,0.153404,-0.13976,0.124497,-0.136296,-0.074639,-0.157331,-0.027425,-0.050942,0.061439,...,0.129701,-0.052547,-0.06477,0.175037,-1.251352,0.041543,0.146135,0.041446,0.105533,0.136077
6,0.208613,0.230555,-0.153985,0.193166,-0.014747,-0.149238,-0.32604,-0.134078,-0.19303,-0.023235,...,0.118994,-0.015515,-0.017523,0.049031,-1.140906,0.021154,-0.052373,-0.030905,0.135686,0.227094
7,0.184214,0.109668,-0.151756,0.121802,-0.060418,-0.080945,-0.140876,-0.017313,-0.079954,0.078913,...,0.048046,0.055762,-0.103678,0.083915,-1.258884,-0.029053,0.100055,-0.004953,0.13381,0.107471
8,0.150803,0.097124,-0.152745,0.142981,-0.033813,-0.106572,-0.20877,-0.066382,-0.096352,0.050012,...,0.014088,0.018893,-0.085045,0.119966,-1.387731,-0.031153,0.086471,0.080461,0.09371,0.135901
9,0.193531,0.13022,-0.190075,0.108972,0.046108,-0.16308,-0.299659,-0.121659,-0.151458,-0.031004,...,0.052929,0.016717,-0.051666,0.112521,-1.235597,-0.018505,-0.045528,0.01252,0.091805,0.215697


In [8]:
from autogluon.tabular import TabularDataset,TabularPredictor
import pandas as pd
# Load the AutoGluon TabularPredictor stored under ./AutoML_ESM/ and predict a
# label for every embedding row; the bare `y_pred` displays the result.
# NOTE(review): rows follow the order returned by load_esm_embed -- confirm how
# they map back to the original sequences before using the predictions.
predicter = TabularPredictor.load('./AutoML_ESM/')
y_pred = predicter.predict(pd.DataFrame(embed_temp))
y_pred

0    0
1    0
2    0
3    0
4    0
5    0
6    0
7    0
8    0
9    0
Name: label, dtype: int64

In [12]:
# Scratch cell. `list(y_pred)` is evaluated but its result is discarded.
list(y_pred)
# Toy demonstration of mask-based selection: keep the entries of `b` whose
# matching flag in `a` equals 1.
# NOTE(review): presumably a prototype for filtering sequences by `y_pred` --
# confirm and replace the placeholder lists with the real sequences/predictions.
b = ['a','b','c','d','e','f']
a = [1,1,0,0,1,0]
selected_seq= [string for flag, string in zip(a, b) if flag == 1]
selected_seq

['a', 'b', 'e']