In [111]:
import os
import sys
import esm
import time
import torch
import random
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


from Bio import SeqIO
from torch import einsum
from pathlib import Path
from einops import rearrange
from torch.utils.data import DataLoader

In [112]:
class Args:
    def __init__(self, mode=None, esm_model_path=None, weight=None, In=None, Out=None, Maxlen=None):
        self.mode = mode
        self.esm_model_path = esm_model_path
        self.weight = weight
        self.In = In
        self.Out = Out
        self.Maxlen = Maxlen

args = Args(mode = 'clef',
            esm_model_path = './pretrained_model/esm2_t33_650M-UR50D.pt',
            weight = './pretrained_model/CLEF-DP+MSA+3Di+AT.pt',
            In = 'Test_demo.faa',
            Out = 'Test_clef_rep')

args

<__main__.Args at 0x262ec0f7400>

In [113]:
mode = args.mode
esm_model_path = args.esm_model_path
model_params_path = args.weight
input_file = args.In
output_file = args.Out
maxlength = args.Maxlen

#esm_model_path = Path(os.path.abspath(esm_model_path))
esm_config = {'pretrained_model_params':esm_model_path} if esm_model_path else None

model_params_path = Path(os.path.abspath(model_params_path))

config = {
    'input_file':input_file,
    'output_file':output_file,
    'model_params_path':model_params_path,
    'esm_config':esm_config,
    'maxlength':256,
    'mode':mode
    
}

config


{'input_file': 'Test_demo.faa',
 'output_file': 'Test_clef_rep',
 'model_params_path': WindowsPath('f:/FDU/CLEF/Code/pretrained_model/CLEF-DP+MSA+3Di+AT.pt'),
 'esm_config': {'pretrained_model_params': './pretrained_model/esm2_t33_650M-UR50D.pt'},
 'maxlength': 256,
 'mode': 'clef'}

In [114]:
def fasta_to_EsmRep(input_fasta, output_file = None, 
                      pretrained_model_params = None,
                      maxlen = 256,
                      Return = True, 
                      Final_pool = False):
  '''
  input_file : input local fasta file path 
  output_file : output encoded file path 
  '''
  import torch
  import esm
  pretrained_model_params = pretrained_model_params if pretrained_model_params else os.path.join(find_root_path, 'Code/pretrained_model/esm2_t33_650M_UR50D.pt')
  aa_dict = {amino_acid: i for i, amino_acid in enumerate("ACDEFGHIKLMNPQRSTVWYX")}
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  try:
      input_embedding_net, alphabet = esm.pretrained.load_model_and_alphabet_local(pretrained_model_params)
  except:
      print(f"Skip loading local pre-trained ESM2 model from {pretrained_model_params}.\nTry to use ESM2-650M downloaded from hub")
      weight_path = os.path.dirname(os.path.abspath(pretrained_model_params))
      if os.path.exists(weight_path) and os.path.isdir(weight_path):
            torch.hub.set_dir(weight_path)
      else:
            print(f"Download ESM2-650M to ./cache")
      input_embedding_net, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
  batch_converter = alphabet.get_batch_converter()
  input_embedding_net = input_embedding_net.to(device)
  input_embedding_net.eval()
  output_dict = {}
  real_maxlen = max(1, maxlen - 2)
  num_layer = len(input_embedding_net.layers)
  for record in SeqIO.parse(open(input_fasta), 'fasta'):
    sequence = str(record.seq[: real_maxlen])  
    sequence = "".join([x if x in aa_dict else 'X' for x in sequence])
    data = [
    ("protein1", sequence),
      ]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
      results = input_embedding_net(batch_tokens, repr_layers=[num_layer], return_contacts=True)
    token_representations = results["representations"][num_layer]
    embedding = np.array(token_representations.squeeze(0).detach().to('cpu')).astype(np.float16)
    embedding = embedding[:real_maxlen + 2, ]
    embedding = embedding.mean(0) if Final_pool else embedding
    output_dict[record.id] = embedding
  if output_file:
      try:
          with open(output_file, 'wb') as f:
            pickle.dump(output_dict, f)
          print(f'ESM2 array saved as {output_file}')
      except:
          print(f'ESM2 array failed to save as {output_file}')
          import uuid
          tmp_name = str(uuid.uuid4())+'_esm'
          output_file =os.path.join(os.path.dirname(input_file), tmp_name) 
          with open(output_file, 'wb') as f:
            pickle.dump(output_dict, f)
          print(f'Temp ESM2 array saved as {output_file}')
  if Return:
      return output_dict

In [115]:
def is_fasta_file(file_path):
    try:
        with open(file_path, 'r') as file:
            first_line = file.readline().strip()
            return first_line.startswith(">")
    except Exception:
        return False


In [116]:
def find_root_path():
    try:
        current_dir = os.path.dirname(os.path.abspath(__file__))
    except:
        current_dir = os.getcwd()
    project_root = os.path.abspath(os.path.join(current_dir, os.pardir))
    return project_root

In [117]:
def check_hidden_layer_dimensions(data_dict):
    hidden_layer_size = None
    for key, value in data_dict.items():
        if not isinstance(value, np.ndarray):
            raise ValueError(f"Value for key '{key}' is not a numpy array.")

        current_size = value.shape[-1]
        if hidden_layer_size is None:
            hidden_layer_size = current_size
        elif hidden_layer_size != current_size:
            return None  

    return hidden_layer_size

In [118]:
def load_feature_from_local(feature_path, silence=False):
    '''
    load feature dict from local path (Using pickle.load() or torch.load())
    the dictionary is like:
        {
          Protein_ID : feature_array [a 1D numpy array]
        }
    '''
    # Try pickle.load() function 
    try:
        with open(feature_path, 'rb') as f:
            obj = pickle.load(f)
        if not silence:
            print("File is loaded using pickle.load()")
        return obj
    except (pickle.UnpicklingError, EOFError):
        pass

    # Try torch.load() function
    try:
        obj = torch.load(feature_path)
        if not silence:
            print("File is loaded using torch.load()")
        return obj
    except (torch.serialization.UnsupportedPackageTypeError, RuntimeError):
        pass

    print("Unable to load file.")
    return None

In [119]:
def generate_clef_feature(input_file, 
                          output_file,
                          model,
                          params_path = None,
                          loader_config = {'batch_size':64, 'max_num_padding':256},
                          esm_config = {'Final_pool':False, 'maxlen':256, 'Return':False},
                          MLP_proj = False,
                          res_rep = False,
                          Return = False):
    from Data_utils import Potein_rep_datasets
    import torch
    if is_fasta_file(input_file):
        print(f"Transform representation from fasta file {input_file}")
        import uuid
        tmp_file = str(uuid.uuid4())+'_tmp'
        tmp_file = os.path.join(os.path.dirname(input_file), tmp_file)
        esm_config = esm_config if isinstance(esm_config, dict) else {'Final_pool':False, 'maxlen':256, 'Return':False}
        esm_config['input_fasta'] = input_file
        esm_config['output_file'] = tmp_file
        esm_config['Return'] = False
        try:
            fasta_to_EsmRep(**esm_config)
        except:
            print("Failed to transform fasta into ESM embeddings, make sure esm config is correct")
        tmpset = Potein_rep_datasets({'esm_feature':tmp_file})
        try:
            os.remove(tmp_file)
            print("Tmp esm file {tmp_file} removed.")
        except:
            pass
    else:
        print(f"Direct load esm representations from {input_file}")
        tmpset = Potein_rep_datasets({'esm_feature':input_file})
    
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    if isinstance(model, torch.nn.Module):
        model.eval()
    if params_path:
      print(f"Try to load model weights from {params_path}")
      try:
        loaded_params = torch.load(params_path, map_location=device)
        model.load_state_dict(loaded_params)
        print(f"Load model weights successfully")
      except:
        print(f"Failed to load model weights from {params_path}")
    
    loader_config = loader_config if isinstance(loader_config, dict) else {'batch_size':64, 'max_num_padding':256}
    loader_config['shuffle'] = False
    loader_config['device'] = device
    IDs = []
    features = []
    for batch in tmpset.Dataloader(**loader_config):
        with torch.no_grad():
            feat, proj_feat = model(batch, Return_res_rep = res_rep)
        feature = proj_feat if MLP_proj else feat
        feature_list = [feature[i,:].detach().to('cpu').numpy() for i in range(feature.shape[0])]
        IDs.extend(batch['ID'])
        features.extend(feature_list)
    output_dict = {ID:feat.astype(np.float16) for ID, feat in zip(IDs, features)}
    
    if output_file:
        try:
            with open(output_file, 'wb') as f:
              pickle.dump(output_dict, f)
            print(f'CLEF array saved as {output_file}')
        except:
            print(f'CLEF array failed to save as {output_file}')
            import uuid
            tmp_name = str(uuid.uuid4())+'_clef'
            output_file =os.path.join(os.path.dirname(input_file), tmp_name) 
            with open(output_file, 'wb') as f:
              pickle.dump(output_dict, f)
            print(f'Temp CLEF array saved as {output_file}')
    if Return:
        return output_dict

In [120]:
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_key=64, dim_value=64, dropout=0.):
        super().__init__()
        self.scale = dim_key ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(dim, dim_key * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_key * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_value * heads, bias=False)
        self.to_out = nn.Linear(dim_value * heads, dim)
        self.attn_dropout = nn.Dropout(dropout)
        # 初始化模型参数
        self.reset_parameter()

    def reset_parameter(self):
        # xavier初始化使输入输出方差一致：xavier_uniform_均匀分布初始化，xavier_normal_正态分布初始化
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)
        nn.init.xavier_uniform_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, x, mask=None):
        # x:[batchsize, sequence_length, dim]
        n, h = x.shape[-2], self.heads

        # 从 x 生成 q, k, v  [batchsize, sequence_length, dim_k * heads]
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        # map （函数，可迭代对象）； lambda 变量：处理方式 ；
        #  [batch_size, heads, sequence_length, dim_key]
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        # q/dim**0.5
        q = q * self.scale
        
        # q, k 计算点积注意力 [batchsize, head, sequence_length, sequence_length]
        logits = einsum('b h i d, b h j d -> b h i j', q, k)

        # -1e9掩码
        if mask is not None:
            logits.masked_fill(mask, -1e9)

        # softmax(q*k/d**0.5) [batchsize, head, sequence_length, sequence_length]
        attn = logits.softmax(dim=-1)
        
        # dropout
        attn = self.attn_dropout(attn)

        # v*softmax(q*k/d**0.5) [batchsize, head, sequence_length, dim_value]
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        
        #  [batch_size, sequence_length,  dim_value * heads] 
        out = rearrange(out, 'b h n d -> b n (h d)')

        #  dim_value * heads -> dim 
        return self.to_out(out), attn

class TransformerLayer(nn.Module):

    def __init__(self, hid_dim, heads, dropout_rate, att_dropout=0.05):
        super().__init__()
        
        # dim, head, qk_dim, v_dim, dropout， 隐藏层维度整除分类头
        self.attn = Attention(hid_dim, heads, hid_dim //
                              heads, hid_dim // heads, att_dropout)
        
        # feedforward network
        self.ffn = nn.Sequential(
            nn.LayerNorm(hid_dim),
            nn.Linear(hid_dim, hid_dim * 2),
            nn.GELU(),
            nn.Linear(hid_dim * 2, hid_dim),
            nn.Dropout(dropout_rate))
        self.layernorm = nn.LayerNorm(hid_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, mask=None):

        # [batch_size, sequence_length, hid_dim]
        residual = x
        x = self.layernorm(x)  # pre-LN
        
        # x = [batch_size, sequence_length,  hid_dim] 
        # attn = [batchsize, head, sequence_length, sequence_length]
        x, attn = self.attn(x, mask)
        x = self.dropout(x)
        x = residual + x

        residual = x
        x = self.ffn(x)
        x = residual + x

        return x, attn

In [121]:

def sequence_mask(X, valid_lens):
    mask = torch.zeros((X.shape[0], X.shape[1]), dtype = torch.bool).to(X.device)                  
    expanded_valid_lens = valid_lens.view(-1, 1).expand(X.shape[0], X.shape[1])    
    src_key_padding_mask = mask.masked_fill(torch.arange(X.shape[1]).to(X.device).view(1, -1).expand(X.shape[0], X.shape[1]) >= expanded_valid_lens, True)
    return src_key_padding_mask


In [122]:
# Encoder A
class clef_enc(nn.Module): 

    def __init__(self, num_embeds, num_hiddens=128, finial_drop=0.1, mlp_relu=True):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerLayer(num_embeds, 8, 0.45, 0.05)
                for _ in range(2)
            ]
        )
        self.Dropout = nn.Dropout(finial_drop)
        self.ln = nn.LayerNorm(num_embeds)
        if mlp_relu:
            # True则包含两个线性层和一个ReLU
            self.mlp = nn.Sequential(nn.Linear(num_embeds, 2 * num_embeds), nn.ReLU(),
                                     nn.Linear(2 * num_embeds, num_hiddens))
        else:
            # False则只包含一个线性层
            self.mlp = nn.Linear(num_embeds, num_hiddens)


    def forward(self, batch, Return_res_rep=False):

        X, valid_lens = batch['esm_feature'], batch['valid_lens']

        src_key_padding_mask = sequence_mask(X, valid_lens)
        
        # 前向传播过程，[b, n] -> [b, 1, 1, n]
        for layer in self.layers:
            X, _ = layer(X, mask=src_key_padding_mask.unsqueeze(1).unsqueeze(2))

        # whether return embeddings per-residue
        if not Return_res_rep:   
            X = torch.cat([X[i, :valid_lens[i] + 2].mean(0).unsqueeze(0)
                           for i in range(X.size(0))], dim=0)
            proj_X = self.mlp(self.Dropout(X))
        else:
            proj_X = torch.cat([X[i, :valid_lens[i]].mean(0).unsqueeze(0)
                                for i in range(X.size(0))], dim=0)
            proj_X = self.mlp(self.Dropout(proj_X))

        return X, proj_X

In [123]:
def generate_ESM_feature(input_embeddings_path, output_file = "temp"):
    
    with open(input_embeddings_path, 'rb') as f:
        input_embedding = pickle.load(f)
    
    output_dict = {}
    
    for key, value in input_embedding.items():

        output_feat = value.mean(0)   
        output = output_feat
        output_dict[key] = output
        
    with open(output_file, 'wb') as f:
        pickle.dump(output_dict, f)

In [124]:
def generate_protein_representation(input_file,
                    output_file,
                    model_params_path = None,
                    tmp_dir = "./tmp",
                    embedding_generator = fasta_to_EsmRep,
                    esm_config = None,
                    remove_tmp = True,
                    mode = 'clef',
                    maxlength = 256    # Hyperparameter determining how many amino acids are used in protein-encoding by PLM
                    ):
    if not os.path.exists(tmp_dir):
        os.mkdir(tmp_dir)
        print(f'Make a temp directory:{tmp_dir}')
    if is_fasta_file(input_file):
        print(f"Transform representation from fasta file {input_file}")
        import uuid
        tmp_file = str(uuid.uuid4())+'_tmp_esm'
        tmp_file = os.path.join(tmp_dir, tmp_file)
        esm_config = esm_config if isinstance(esm_config, dict) else {'Final_pool':False, 'maxlen':maxlength}
        esm_config['input_fasta'] = input_file
        esm_config['output_file'] = tmp_file
        esm_config['Return'] = False
        if 'pretrained_model_params' not in esm_config:
            esm_config['pretrained_model_params'] = os.path.join(find_root_path(), "./pretrained_model/esm2_t33_650M_UR50D.pt")
        try:
            embedding_generator(**esm_config)
        except:
            print("Failed to transform fasta into ESM embeddings, make sure esm config is correct")
            import shutil
            shutil.rmtree(tmp_dir)
            sys.exit(1)

       
    if mode.lower() == 'clef':
        print(f"Using pre-trained encoder in CLEF to generate protein representations")
        num_hidden = check_hidden_layer_dimensions(load_feature_from_local(tmp_file, silence=True))
        assert num_hidden, "Dimension numbers of the last dimension is not same"
        device='cuda:0' if torch.cuda.is_available() else 'cpu'
        encoder = clef_enc(num_hidden).to(device)
        try:
            encoder.load_state_dict(torch.load(model_params_path, map_location=torch.device('cpu')), strict=False)
            print(f"Successfully load CLEF params from {model_params_path}.")
        except:
            print(f"Failed to load CLEF params from {model_params_path}, make sure it is a valid weights for CLEF")
            import shutil
            shutil.rmtree(tmp_dir)
            sys.exit(1)
        tmp_output= output_file
        loader_config = {'batch_size':64, 'max_num_padding':256}
        config = {
          'input_file':tmp_file,
          'output_file':tmp_output,
          'model':encoder,
          'params_path':None,
          'loader_config':loader_config
        }
        generate_clef_feature(**config)

    elif mode.lower() == 'esm':
        print(f"Direct generate esm representations")
        conf = {
        'input_embeddings_path' : tmp_file,
        'output_file' : output_file,
        }
        generate_ESM_feature(**conf)
        print(f"ESM2 (protein) array saved as {output_file}")
    else:
        print(f"{mode} is not a valid mode tag, please select [clef] or [esm] for protein-reps generation")
        import shutil
        shutil.rmtree(tmp_dir)
        sys.exit(1)
        
    print(f"Done..")
    
    if remove_tmp:
        import shutil 
        try:
            shutil.rmtree(tmp_dir)
            print(f"Remove temp directory: {tmp_dir}.")
        except:
            print(f"Failed to remove temp file in {tmp_dir}.")

In [125]:
generate_protein_representation(**config)

Transform representation from fasta file Test_demo.faa
Skip loading local pre-trained ESM2 model from ./pretrained_model/esm2_t33_650M-UR50D.pt.
Try to use ESM2-650M downloaded from hub
ESM2 array saved as ./tmp\c1012338-b13d-47a4-b398-6ef20f3c630c_tmp_esm
Using pre-trained encoder in CLEF to generate protein representations
Successfully load CLEF params from f:\FDU\CLEF\Code\pretrained_model\CLEF-DP+MSA+3Di+AT.pt.
Direct load esm representations from ./tmp\c1012338-b13d-47a4-b398-6ef20f3c630c_tmp_esm
try to load feature from path:./tmp\c1012338-b13d-47a4-b398-6ef20f3c630c_tmp_esm
File is loaded using pickle.load()
Add mock label [label] of 0 for each sample
total 10 sample loaded
CLEF array saved as Test_clef_rep
Done..
Remove temp directory: ./tmp.


In [126]:
def read_faa_file(faa_file_path):
    """
    读取 .faa 文件并打印每个序列的 ID 和描述信息
    
    param faa_file_path: .faa 文件的路径
    """
    try:
        # 使用 SeqIO.parse 读取 .faa 文件
        for record in SeqIO.parse(faa_file_path, "fasta"):
            # 打印序列的 ID 和描述信息
            print(f"Sequence ID: {record.id}")
            print(f"Description: {record.description}")
            print(f"Sequence: {record.seq}")
            print("-" * 50)
    except FileNotFoundError:
        print(f"文件未找到，请检查路径是否正确: {faa_file_path}")
    except Exception as e:
        print(f"读取文件时发生错误: {e}")



In [127]:
faa_file_path = "Test_demo.faa"

read_faa_file(faa_file_path)

Sequence ID: NP_250554.1
Description: NP_250554.1 NP_250554.1 NC_002516:c2023166-2022411 [Pseudomonas aeruginosa PAO1]|T6SE
Sequence: MTTRLPQLLLALLASAVSLAASADEVQVAVAANFTAPIQAIAKEFEKDTGHRLVAAYGATGQFYTQIKNGAPFQVFLSADDSTPAKLEQEGEVVPGSRFTYAIGTLALWSPKAGYVDAEGEVLKSGSFRHLSIANPKTAPYGLAATQAMDKLGLAATLGPKLVEGQNISQAYQFVSSGNAELGFVALSQIYKDGKVATGSAWIVPTELHDPIRQDAVILNKGKDNAAAKALVDYLKGAKAAALIKSYGYEL
--------------------------------------------------
Sequence ID: NP_249515.1
Description: NP_249515.1 NP_249515.1 NC_002516:c899656-899138 [Pseudomonas aeruginosa PAO1]|T6SE
Sequence: MSGKPAARVTDPTTCPVPGHGSNPIVQGSPDVVFDGLPAARQGDASACGSPMISAVSSTVLINGLPAVTLGSIGAHGNVVIGGSGTVLIGDVFTPAPRAPALPLNRNSVPCSGRFQLIDHETGKPVAGRRVRVWSSGGWNAFDTTDADGMTSWIERPTAEILYIDLVQRCDA
--------------------------------------------------
Sequence ID: 4F0V_A
Description: 4F0V_A 4F0V_A Tse1; Chain A, Crystal Structure Of Type Effector Tse1 From Pseudomonas Aeruginousa [Pseudomonas fluorescens]|T6SE
Sequence: MGSSHHHHHHSSGENLYFEGSHMASMTGGQQMGRM

In [128]:
def load_clef_representations(file_path):
    """
    从指定路径加载 CLEF 表示文件(pickle 格式)并返回一个 Pandas DataFrame
    
    param file_path: CLEF 表示文件的路径
    return: 包含蛋白质表示的 Pandas DataFrame
    """
    if not os.path.exists(file_path):
        print(f"文件未找到，请检查路径是否正确: {file_path}")
        return None
    
    try:
        # 使用 pickle 加载文件
        with open(file_path, 'rb') as file:
            data = pickle.load(file)
        
        # 将数据转换为 Pandas DataFrame
        df = pd.DataFrame(data)
        print(f"成功加载文件: {file_path}")
        return df
    except Exception as e:
        print(f"读取文件时发生错误: {e}")
        return None



In [140]:
file_path = "Test_clef_rep.pkl"
df_clef_rep = load_clef_representations(file_path)

if df_clef_rep is not None:
    print(df_clef_rep)

成功加载文件: Test_clef_rep.pkl
      NP_250554.1  NP_249515.1    4F0V_A  YP_898952.1  CBG37356.1  \
0        0.461426    -2.650391 -1.516602     0.036682   -1.694336   
1        1.628906     0.894043 -0.152954    -0.503418   -1.369141   
2        1.270508    -1.605469 -0.790039    -0.363770    0.463135   
3       -0.419434     0.491455  1.163086     0.724609    2.080078   
4       -2.646484    -0.020370 -0.164795    -0.841309   -0.457764   
...           ...          ...       ...          ...         ...   
1275     1.754883    -0.254395 -0.066284     0.112549   -0.114441   
1276     2.541016    -0.245239 -0.409424     0.169678    0.155518   
1277     0.095154    -1.577148 -1.087891    -0.794434   -0.750000   
1278     0.327637    -0.099121 -0.260498    -1.374023   -0.250488   
1279     1.052734    -0.165283 -0.069275    -0.048523   -0.764648   

      WP_151253718.1  ADZ63249.1    P46922    Q9TYU9    Q92F67  
0           0.158569   -1.480469  1.726562 -0.763672  1.833008  
1          -0.5