In [18]:
import os
import sys
import esm
import time
import torch
import random
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


from Bio import SeqIO
from torch import einsum
from einops import rearrange
from torch.utils.data import DataLoader


In [19]:
class Args:
    def __init__(self, Seq=None, Feat=None, Out=None, esm_model_path=None, lr=None, btz=None,epoch=None,maxlen=None):
        self.Seq = Seq
        self.Feat = Feat
        self.Out = Out
        self.esm_model_path = esm_model_path
        self.lr = lr
        self.btz = btz
        self.epoch = epoch
        self.maxlen = maxlen

args = Args(Seq = "./Demo_train/Demo_trainset.faa",
            Feat = ["./Demo_train/Demo_trainset_featA", "./Demo_train/Demo_trainset_featB", "./Demo_train/Demo_trainset_featC"],
            Out = "Demo_clef",
            esm_model_path = "./pretrained_model/esm2_t33_650M_UR50D.pt",
            lr = 0.00002,
            epoch = 20,
            btz = 128,
            maxlen = 256)

print(args)


<__main__.Args object at 0x000001D2E74EF1F0>


In [20]:
seq_path = args.Seq
modal_paths = args.Feat
modal_path_dict = {f'modal_{i}':path for i,path in enumerate(modal_paths)}
input_file_config = {'seq':seq_path}
input_file_config.update(modal_path_dict)
train_config = {'lr':args.lr, 'batch_size':args.btz, 'num_epoch':args.epoch, 'maxlen':args.maxlen}
output_dir = args.Out
esm_config = {'maxlen':args.maxlen}
if args.esm_model_path:
    esm_config['pretrained_model_params'] = args.esm_model_path
config = {
    'input_file_config':input_file_config,
    'output_dir':output_dir,
    'train_config':train_config,
    'esm_config':esm_config
}
#train_clef(**config)
print(config)

{'input_file_config': {'seq': './Demo_train/Demo_trainset.faa', 'modal_0': './Demo_train/Demo_trainset_featA', 'modal_1': './Demo_train/Demo_trainset_featB', 'modal_2': './Demo_train/Demo_trainset_featC'}, 'output_dir': 'Demo_clef', 'train_config': {'lr': 2e-05, 'batch_size': 128, 'num_epoch': 20, 'maxlen': 256}, 'esm_config': {'maxlen': 256, 'pretrained_model_params': './pretrained_model/esm2_t33_650M_UR50D.pt'}}


In [21]:
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_key=64, dim_value=64, dropout=0.):
        super().__init__()
        self.scale = dim_key ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(dim, dim_key * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_key * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_value * heads, bias=False)
        self.to_out = nn.Linear(dim_value * heads, dim)
        self.attn_dropout = nn.Dropout(dropout)
        # 初始化模型参数
        self.reset_parameter()

    def reset_parameter(self):
        # xavier初始化使输入输出方差一致：xavier_uniform_均匀分布初始化，xavier_normal_正态分布初始化
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)
        nn.init.xavier_uniform_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, x, mask=None):
        # x:[batchsize, sequence_length, dim]
        n, h = x.shape[-2], self.heads

        # 从 x 生成 q, k, v  [batchsize, sequence_length, dim_k * heads]
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        # map （函数，可迭代对象）； lambda 变量：处理方式 ；
        #  [batch_size, heads, sequence_length, dim_key]
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        # q/dim**0.5
        q = q * self.scale
        
        # q, k 计算点积注意力 [batchsize, head, sequence_length, sequence_length]
        logits = einsum('b h i d, b h j d -> b h i j', q, k)

        # -1e9掩码
        if mask is not None:
            logits.masked_fill(mask, -1e9)

        # softmax(q*k/d**0.5) [batchsize, head, sequence_length, sequence_length]
        attn = logits.softmax(dim=-1)
        
        # dropout
        attn = self.attn_dropout(attn)

        # v*softmax(q*k/d**0.5) [batchsize, head, sequence_length, dim_value]
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        
        #  [batch_size, sequence_length,  dim_value * heads] 
        out = rearrange(out, 'b h n d -> b n (h d)')

        #  dim_value * heads -> dim 
        return self.to_out(out), attn

class TransformerLayer(nn.Module):

    def __init__(self, hid_dim, heads, dropout_rate, att_dropout=0.05):
        super().__init__()
        
        # dim, head, qk_dim, v_dim, dropout， 隐藏层维度整除分类头
        self.attn = Attention(hid_dim, heads, hid_dim //
                              heads, hid_dim // heads, att_dropout)
        
        # feedforward network
        self.ffn = nn.Sequential(
            nn.LayerNorm(hid_dim),
            nn.Linear(hid_dim, hid_dim * 2),
            nn.GELU(),
            nn.Linear(hid_dim * 2, hid_dim),
            nn.Dropout(dropout_rate))
        self.layernorm = nn.LayerNorm(hid_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, mask=None):

        # [batch_size, sequence_length, hid_dim]
        residual = x
        x = self.layernorm(x)  # pre-LN
        
        # x = [batch_size, sequence_length,  hid_dim] 
        # attn = [batchsize, head, sequence_length, sequence_length]
        x, attn = self.attn(x, mask)
        x = self.dropout(x)
        x = residual + x

        residual = x
        x = self.ffn(x)
        x = residual + x

        return x, attn

In [22]:
def sequence_mask(X, valid_lens):
    # [batchsize, sequence_length],生成初始mask并默认False
    mask = torch.zeros((X.shape[0], X.shape[1]), dtype = torch.bool).to(X.device) 
    
    # [batchsize, 1] -> [batchsize, sequence_length]，扩展有效长度并与mask一致                 
    expanded_valid_lens = valid_lens.view(-1, 1).expand(X.shape[0], X.shape[1])    
    
    # 将超出expanded_valid_lens的部分记为True
    src_key_padding_mask = mask.masked_fill(torch.arange(X.shape[1]).to(X.device).view(1, -1).expand(X.shape[0], X.shape[1]) >= expanded_valid_lens, True)
    return src_key_padding_mask

In [23]:
def standardize(feature):
    mean = feature.mean(dim=0, keepdim=True)
    std = feature.std(dim=0, keepdim=True)
    return (feature - mean) / std

In [24]:
# Encoder A
class clef_enc(nn.Module): 

    def __init__(self, num_embeds, num_hiddens=128, finial_drop=0.1, mlp_relu=True):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerLayer(num_embeds, 8, 0.45, 0.05)
                for _ in range(2)
            ]
        )
        self.Dropout = nn.Dropout(finial_drop)
        self.ln = nn.LayerNorm(num_embeds)
        if mlp_relu:
            # True则包含两个线性层和一个ReLU
            self.mlp = nn.Sequential(nn.Linear(num_embeds, 2 * num_embeds), nn.ReLU(),
                                     nn.Linear(2 * num_embeds, num_hiddens))
        else:
            # False则只包含一个线性层
            self.mlp = nn.Linear(num_embeds, num_hiddens)


    def forward(self, batch, Return_res_rep=False):

        X, valid_lens = batch['esm_feature'], batch['valid_lens']

        src_key_padding_mask = sequence_mask(X, valid_lens)
        
        # 前向传播过程，[b, n] -> [b, 1, 1, n]
        for layer in self.layers:
            X, _ = layer(X, mask=src_key_padding_mask.unsqueeze(1).unsqueeze(2))

        # whether return embeddings per-residue
        if not Return_res_rep:   
            X = torch.cat([X[i, :valid_lens[i] + 2].mean(0).unsqueeze(0)
                           for i in range(X.size(0))], dim=0)
            proj_X = self.mlp(self.Dropout(X))
        else:
            proj_X = torch.cat([X[i, :valid_lens[i]].mean(0).unsqueeze(0)
                                for i in range(X.size(0))], dim=0)
            proj_X = self.mlp(self.Dropout(proj_X))

        return X, proj_X



class clef_multimodal(nn.Module):

    def __init__(self, num_embeds, feat_dim_config, num_hiddens=128,
                 finial_drop=0.1, mlp_relu=True, feat_mlp_relu=True,
                 feature_norm=True):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerLayer(num_embeds, 8, 0.45, 0.05)
                for _ in range(2)
            ]
        )
        self.Dropout = nn.Dropout(finial_drop)
        self.ln = nn.LayerNorm(num_embeds)
        if mlp_relu:
            self.mlp = nn.Sequential(nn.Linear(num_embeds, 2 * num_embeds), nn.ReLU(),
                                     nn.Linear(2 * num_embeds, num_hiddens))
        else:
            self.mlp = nn.Linear(num_embeds, num_hiddens)
            
            
        # 定义多模态字典
        self.modal_indeces = {key:i for i, key in enumerate(feat_dim_config)}

        # 将各个特征维度的值加和，之后通过线性层变为1024维度
        feat_dim = sum([dim for dim in feat_dim_config.values()])
        self.feat_encoder = nn.Sequential(nn.Linear(feat_dim, 1024, bias=False), nn.ReLU(),
                              nn.Linear(1024, num_hiddens))



        # 再对每个维度进行layernorm
        self.feature_norm = feature_norm
        if self.feature_norm:
            self.ln_f = nn.LayerNorm(num_hiddens)

    def forward(self, batch, Return_res_rep=False):

        # 多个模态信息
        has_modal = min([x in batch for x in self.modal_indeces.keys()])
        
        X, valid_lens = batch['esm_feature'], batch['valid_lens']

        src_key_padding_mask = sequence_mask(X, valid_lens)
        for layer in self.layers:
            X, _ = layer(X, mask=src_key_padding_mask.unsqueeze(1).unsqueeze(2))

        if not Return_res_rep:   # If return embeddings per-residue
            X = torch.cat([X[i, :valid_lens[i] + 2].mean(0).unsqueeze(0)
                           for i in range(X.size(0))], dim=0)
            proj_X = self.mlp(self.Dropout(X))
        else:
            proj_X = torch.cat([X[i, :valid_lens[i]].mean(0).unsqueeze(0)
                                for i in range(X.size(0))], dim=0)
            proj_X = self.mlp(self.Dropout(proj_X))

        # 如果含有多模态信息
        if has_modal:
            cross_features = []
            for modal, i in self.modal_indeces.items():
                if modal in batch:
                    # 归一化
                    norm_feat = standardize(batch[modal]) if len(self.modal_indeces.items())>1 else batch[modal]
                    cross_features.append(norm_feat)
            # 特征拼接在一起
            cross_features = torch.cat(cross_features, -1)
            
            # 编码融合后的特征
            proj_features = self.feat_encoder(cross_features)
            if self.feature_norm:
                proj_features = self.ln_f(proj_features)

            # X是transformer的输出，proj_x经过mlp处理，proj_features融合后的多模态
            return X, proj_X, proj_features
        else:
            return X, proj_X

In [25]:
def fasta_to_EsmRep(input_fasta, output_file = None, 
                      pretrained_model_params = None,
                      maxlen = 256,
                      Return = True, 
                      Final_pool = False):
  '''
  input_file : input local fasta file path 
  output_file : output encoded file path 
  '''
  pretrained_model_params = pretrained_model_params if pretrained_model_params else './pretrained_model/esm2_t33_650M_UR50D.pt'
  aa_dict = {amino_acid: i for i, amino_acid in enumerate("ACDEFGHIKLMNPQRSTVWYX")}
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  try:
      input_embedding_net, alphabet = esm.pretrained.load_model_and_alphabet_local(pretrained_model_params)
  except:
      print(f"Skip loading local pre-trained ESM2 model from {pretrained_model_params}.\nTry to use ESM2-650M downloaded from hub")
      weight_path = os.path.dirname(os.path.abspath(pretrained_model_params))
      if os.path.exists(weight_path) and os.path.isdir(weight_path):
            torch.hub.set_dir(weight_path)
      else:
            print(f"Download ESM2-650M to ./cache")
      input_embedding_net, alphabet = esm.pretrained.esm2_t33_650M_UR50D()

  batch_converter = alphabet.get_batch_converter()
  input_embedding_net = input_embedding_net.to(device)
  input_embedding_net.eval()
  output_dict = {}
  real_maxlen = max(1, maxlen - 2)
  num_layer = len(input_embedding_net.layers)
  for record in SeqIO.parse(open(input_fasta), 'fasta'):
    sequence = str(record.seq[: real_maxlen])  
    sequence = "".join([x if x in aa_dict else 'X' for x in sequence])
    data = [
    ("protein1", sequence),
      ]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
      results = input_embedding_net(batch_tokens, repr_layers=[num_layer], return_contacts=True)
    token_representations = results["representations"][num_layer]
    embedding = np.array(token_representations.squeeze(0).detach().to('cpu')).astype(np.float16)
    embedding = embedding[:real_maxlen + 2, ]
    embedding = embedding.mean(0) if Final_pool else embedding
    output_dict[record.id] = embedding
  if output_file:
      try:
          with open(output_file, 'wb') as f:
            pickle.dump(output_dict, f)
          print(f'ESM2 array saved as {output_file}')
      except:
          print(f'ESM2 array failed to save as {output_file}')
          import uuid
          tmp_name = str(uuid.uuid4())+'_esm'
          output_file =os.path.join(os.path.dirname(input_fasta), tmp_name) 
          with open(output_file, 'wb') as f:
            pickle.dump(output_dict, f)
          print(f'Temp ESM2 array saved as {output_file}')
  if Return:
      return output_dict

In [26]:
def check_hidden_layer_dimensions(data_dict):
    hidden_layer_size = None
    for key, value in data_dict.items():
        if not isinstance(value, np.ndarray):
            raise ValueError(f"Value for key '{key}' is not a numpy array.")

        current_size = value.shape[-1]
        if hidden_layer_size is None:
            hidden_layer_size = current_size
        elif hidden_layer_size != current_size:
            return None  

    return hidden_layer_size

In [27]:
def load_feature_from_local(feature_path, silence=False):
    '''
    load feature dict from local path (Using pickle.load() or torch.load())
    the dictionary is like:
        {
          Protein_ID : feature_array [a 1D numpy array]
        }
    '''
    # Try pickle.load() function 
    try:
        with open(feature_path, 'rb') as f:
            obj = pickle.load(f)
        if not silence:
            print("File is loaded using pickle.load()")
        return obj
    except (pickle.UnpicklingError, EOFError):
        pass

    # Try torch.load() function
    try:
        obj = torch.load(feature_path)
        if not silence:
            print("File is loaded using torch.load()")
        return obj
    except (torch.serialization.UnsupportedPackageTypeError, RuntimeError):
        pass

    print("Unable to load file.")
    return None

In [28]:
def is_fasta_file(file_path):
    try:
        with open(file_path, 'r') as file:
            first_line = file.readline().strip()
            return first_line.startswith(">")
    except Exception:
        return False

In [29]:
def find_root_path():
    try:
        current_dir = os.path.dirname(os.path.abspath(__file__))
    except:
        current_dir = os.getcwd()
    project_root = os.path.abspath(os.path.join(current_dir, os.pardir))
    return project_root

In [1]:
class Potein_rep_datasets:  
    def __init__(self, input_path, train_range = None, test_range = None, label_tag = 'label'):
        '''
        [input_path] is a Path_dict containing feature ID and corresponding Local_path
        e.g {'feature':'./path/to/you/feature_file'}
        '''
        sequence_data = {}
        try:
            for key, value in input_path.items():
                if isinstance(value, str):
                    print(f"try to load feature from path:{value}")
                    tmp = load_feature_from_local(value)
                elif isinstance(value, np.ndarray):
                    print(f"try to load feature from numpy_array")
                    tmp = value
                else:
                    print(f"can not load feature {key}")
                    continue
                for ID, feat in tmp.items():
                    if ID not in sequence_data:
                        sequence_data[ID] = {key:feat}
                    else:
                        sequence_data[ID].update( {key:feat} )
            
            self.sequence_data = {}   
            for key, value in sequence_data.items():
                if len(value) < len(input_path):
                   print(f"imcomplete feature ID {key} removed")
                else:
                   self.sequence_data[key] = value
            
            if label_tag not in input_path:
                print(f"Add mock label [{label_tag}] of 0 for each sample")
                for key in self.sequence_data:
                    self.sequence_data[key][label_tag] = 0
        except:
            print(f"No valid [input_path] to load : {input_path}, return an empty dataset")
            self.sequence_data = {}
               
        self.data_indices = {i : ID for i, ID in enumerate(self.sequence_data)}
        
        self.label_tag = label_tag
        
        print(f"total {len(self.data_indices)} sample loaded")
        self.feature_list = list(input_path.keys())
        
        self.train_range = train_range
        self.test_range = test_range
        
        if not self.train_range:
            self.train_range = range(len(self.data_indices))
      
        if not self.test_range:
            self.test_range = range(len(self.data_indices))
            
    def __len__(self):
        # 返回数据集的样本数量
        return len(self.data_indices)

    def __getitem__(self, idx):
        ID = self.data_indices[idx]
        sample = self.sequence_data[ID]
        features = {key: value for key, value in sample.items() if key != self.label_tag}
        label = sample[self.label_tag]
        return features, label


In [31]:
class InfoNCELoss(nn.Module):
    def __init__(self, temperature=1.0):
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature

    def forward(self, image_embeds, text_embeds):
        
        # L2归一化
        image_embeds = F.normalize(image_embeds, dim=-1, p=2)
        text_embeds = F.normalize(text_embeds, dim=-1, p=2)

        # [batchsize, 1, embedding_dim] [1, batchsize, embedding_dim] -> [batchsize, batchsize]
        similarity_matrix = F.cosine_similarity(image_embeds.unsqueeze(1), text_embeds.unsqueeze(0), dim=-1) / self.temperature

        # 提取对角元素
        positives = torch.diag(similarity_matrix)

        nce_loss = -torch.log(torch.exp(positives) / torch.exp(similarity_matrix).sum(dim=-1)).mean()

        return nce_loss

In [32]:
def train_clef(input_file_config,
                output_dir,
                model_initial = clef_multimodal,
                tmp_dir = "./tmp",
                embedding_generator = fasta_to_EsmRep,
                esm_config = None,
                train_config = None,
                ):
    # 初始化日志
    log = []
    
    # 如果临时目录不存在，则创建它
    if not os.path.exists(tmp_dir):
        os.mkdir(tmp_dir)
        
    # input_file_config 中包含 seq 键
    assert "seq" in  input_file_config , "For training.sequence representation or fasta file"   
    input_file = input_file_config['seq']
    
    # 遍历其他键值对，检查dim形状
    feat_dim_config = {}
    for key, value in input_file_config.items():
        if key != 'seq':
            feat_dim = check_hidden_layer_dimensions(load_feature_from_local(value, silence=True))
            assert feat_dim, f'Dimension numbers of the last dimension is not same; {value}'
            feat_dim_config[key] = feat_dim
            line = f'{key}--{value}; num_dims:{feat_dim}'
            print(line)
            log.append(f'{line}\n')
            
    # 如果输入是fasta文件，则配置ESM并输出成为embedding
    # if is_fasta_file(input_file):
    #     print(f"Transform representation from fasta file {input_file}")
    #     import uuid
    #     tmp_file = str(uuid.uuid4())+'_tmp_esm'
    #     tmp_file = os.path.join(tmp_dir, tmp_file)
    #     esm_config = esm_config if isinstance(esm_config, dict) else {'Final_pool':False, 'maxlen':256}
    #     esm_config['input_fasta'] = input_file
    #     esm_config['output_file'] = tmp_file
    #     esm_config['Return'] = False
    #     if 'pretrained_model_params' not in esm_config:
    #         esm_config['pretrained_model_params'] = os.path.join(find_root_path(), "./pretrained_model/esm2_t33_650M_UR50D.pt")
    #     try:
    #         embedding_generator(**esm_config)
    #     except:
    #         print("Failed to transform fasta into ESM embeddings, make sure esm config is correct")
    #         import shutil
    #         shutil.rmtree(tmp_dir)
    #         sys.exit(1)
    # else:
    #     print(f"Direct use {input_file} as sequence representation")
    #     tmp_file = input_file
    tmp_file = "./tmp/15610eeb-88d7-40ba-98be-7fb6c6df25e4_tmp_esm"
    
    # 训练设备
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu' 
    if device == 'cpu':
        print(f'**Note:model will be trained on CPU, it may take very long time')
        
    # 加载序列嵌入数据并检查维度
    num_embeds = check_hidden_layer_dimensions(load_feature_from_local(tmp_file, silence=True))
    assert num_embeds, f'make sure {tmp_file} is a dict with ID:sequence_reps; sequence_reps should be 2D numpy with shape of [protein_length, num_hidden]'
    line = f'Sequnce data--{input_file}; num_dims:{num_embeds}'
    log.append(f'{line}\n')
    
    # 初始化模型
    model=model_initial(num_embeds, feat_dim_config).to(device)
    
    # 配置数据路径
    data_path_config = {'esm_feature':tmp_file}
    data_path_config.update({key:value for key, value in input_file_config.items() if key != 'seq'})
    Dataset = Potein_rep_datasets(data_path_config)
    
    

    # 如果数据集为空，抛出异常并删除临时目录，然后退出程序
    # if len(Dataset) == 0:
    #     raise ValueError("Failed to load feature for training")
    #     import shutil
    #     shutil.rmtree(tmp_dir)
    #     sys.exit(1)
    
    # 定义损失函数
    loss_function= InfoNCELoss()
    
    # 传入超参数，或者使用默认值
    if train_config:
        for x in ['lr', 'batch_size', 'num_epoch']:
            assert x in  train_config, "'lr', 'batch_size' and 'num_epoch' are needed when not using default train configuration "
        lr = train_config['lr']
        batch_size = train_config['batch_size']
        num_epoch = train_config['num_epoch']
        maxlen = 256 if 'maxlen' not in train_config else train_config['maxlen']
    else:
        lr = 0.00002
        batch_size = 128
        num_epoch = 20
        maxlen = 256
        
    dataloader = DataLoader(dataset=Dataset, batch_size=batch_size, shuffle=True)
        
    # 优化器
    optimizer=optim.Adam(model.parameters(),lr=lr)
    
    # 不使用激活函数
    activation = None
    
    # 模型临时保存点
    temp_check_point = output_dir
    
    # 不跟随epoch 保存 
    save_by_epoch = False
    check_list = []
    
    # 创建检查点目录
    if not os.path.exists(temp_check_point):
        os.mkdir(temp_check_point)
        
    # 训练开始时间
    s = time.time()
    
    # epoch循环
    for epoch in range(num_epoch):
        Loss=0
        sum=0 
        print("Start training: "+str(epoch)+"\n")
        
        # batch循环，Dataloader加载数据
        # for batch in Dataset.Dataloader(batch_size=batch_size,  max_num_padding=maxlen, device=device):
        for batch in dataloader:    
            # eval
            with torch.no_grad():
                model.eval()
                _, proj, Y = model(batch)
                loss = loss_function(proj, Y)
                # 损失不是NaN
                assert not np.isnan(loss.item()), 'odd loss value appears'
                Loss += loss.item()
                sum += Y.shape[0]
                
            # train
            model.train()
            _, proj, Y = model(batch)
            loss = loss_function(proj,Y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # epoch训练结束时间，平均损失
        e = time.time()
        t = e - s
        avg_loss = Loss / sum
        line = f'Epoch: {epoch}; Train loss:{avg_loss}; time:{t} s'
        print(line)
        log.append(f'{line}\n')
        
        # 保存checkpoint
        tmp_check_path = os.path.join(temp_check_point, f"./checkpoint_{epoch}.pt")
        if save_by_epoch:
            torch.save(model.state_dict(), tmp_check_path)
        elif epoch == num_epoch - 1 or epoch + 1 in check_list:
            torch.save(model.state_dict(), tmp_check_path)
            line = f'Epoch:{epoch} Checkpoint weights saved--{tmp_check_path}'
            print(line)
            log.append(f'{line}\n')
    
    # 训练结束时间
    e = time.time()
    print(f'Total training time : {e-s} seconds')
    print(f"Done..")
    
    
    import shutil 
    try:
        shutil.rmtree(tmp_dir)
    except:
        print(f"Failed to remove temp file in {tmp_dir}.")
    log_path = os.path.join(output_dir, 'log.txt')
    with open(log_path, 'w') as f:
        f.writelines(log)
    print(f"Log text file saved to {log_path}.")  

In [33]:
import torch
print(torch.cuda.is_available())

True


In [None]:
train_clef(**config)

#### 
modal_0--./Demo_train/Demo_trainset_featA; num_dims:1024

modal_1--./Demo_train/Demo_trainset_featB; num_dims:768

modal_2--./Demo_train/Demo_trainset_featC; num_dims:256

Transform representation from fasta file ./Demo_train/Demo_trainset.faa

Skip loading local pre-trained ESM2 model from pretrained_model\esm2_t33_650M_UR50D.pt.

Try to use ESM2-650M downloaded from hub

ESM2 array saved as ./tmp\aeaae50d-41b3-476b-b6cf-fc685a7cc631_tmp_esm

try to load feature from path:./tmp\aeaae50d-41b3-476b-b6cf-fc685a7cc631_tmp_esm

File is loaded using pickle.load()

try to load feature from path:./Demo_train/Demo_trainset_featA

File is loaded using pickle.load()

try to load feature from path:./Demo_train/Demo_trainset_featB

File is loaded using pickle.load()

try to load feature from path:./Demo_train/Demo_trainset_featC

File is loaded using pickle.load()

Add mock label [label] of 0 for each sample

Add mock label [label] of 0 for each sample

Add mock label [label] of 0 for each sample

Add mock label [label] of 0 for each sample

Add mock label [label] of 0 for each sample

Add mock label [label] of 0 for each sample

total 3178 sample loaded

total 3178 sample loaded

Epoch: 0; Train loss:0.06013690780137904; time:222.6505126953125 s

Epoch: 1; Train loss:0.056203456927276095; time:449.27208948135376 s

Epoch: 2; Train loss:0.054905163334029665; time:673.4195473194122 s

Epoch: 3; Train loss:0.05423521492899852; time:892.0980730056763 s

Epoch: 4; Train loss:0.05371454446251247; time:1112.6760714054108 s

Epoch: 5; Train loss:0.0534589481173708; time:1329.012489080429 s

Epoch: 6; Train loss:0.0532391931817095; time:1548.2551367282867 s

Epoch: 7; Train loss:0.05299747657895913; time:1764.3475527763367 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 9; Train loss:0.0526635484773457; time:3390.4365899562836 s


Epoch: 6; Train loss:0.0532391931817095; time:1548.2551367282867 s

Epoch: 7; Train loss:0.05299747657895913; time:1764.3475527763367 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 9; Train loss:0.0526635484773457; time:3390.4365899562836 s

Epoch: 6; Train loss:0.0532391931817095; time:1548.2551367282867 s

Epoch: 7; Train loss:0.05299747657895913; time:1764.3475527763367 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 9; Train loss:0.0526635484773457; time:3390.4365899562836 s

Epoch: 6; Train loss:0.0532391931817095; time:1548.2551367282867 s

Epoch: 7; Train loss:0.05299747657895913; time:1764.3475527763367 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 6; Train loss:0.0532391931817095; time:1548.2551367282867 s

Epoch: 7; Train loss:0.05299747657895913; time:1764.3475527763367 s

Epoch: 6; Train loss:0.0532391931817095; time:1548.2551367282867 s

Epoch: 6; Train loss:0.0532391931817095; time:1548.2551367282867 s

Epoch: 7; Train loss:0.05299747657895913; time:1764.3475527763367 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 9; Train loss:0.0526635484773457; time:3390.4365899562836 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 9; Train loss:0.0526635484773457; time:3390.4365899562836 s

Epoch: 10; Train loss:0.052566751809597315; time:3605.0055277347565 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 9; Train loss:0.0526635484773457; time:3390.4365899562836 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 9; Train loss:0.0526635484773457; time:3390.4365899562836 s

Epoch: 10; Train loss:0.052566751809597315; time:3605.0055277347565 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 9; Train loss:0.0526635484773457; time:3390.4365899562836 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 9; Train loss:0.0526635484773457; time:3390.4365899562836 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 8; Train loss:0.05278185840820351; time:1979.615389585495 s

Epoch: 9; Train loss:0.0526635484773457; time:3390.4365899562836 s

Epoch: 10; Train loss:0.052566751809597315; time:3605.0055277347565 s

Epoch: 11; Train loss:0.0525134675078155; time:3822.644381761551 s

Epoch: 12; Train loss:0.05236718435569557; time:4044.220722913742 s

Epoch: 9; Train loss:0.0526635484773457; time:3390.4365899562836 s

Epoch: 10; Train loss:0.052566751809597315; time:3605.0055277347565 s

Epoch: 11; Train loss:0.0525134675078155; time:3822.644381761551 s
Epoch: 9; Train loss:0.0526635484773457; time:3390.4365899562836 s

Epoch: 10; Train loss:0.052566751809597315; time:3605.0055277347565 s

Epoch: 9; Train loss:0.0526635484773457; time:3390.4365899562836 s

Epoch: 9; Train loss:0.0526635484773457; time:3390.4365899562836 s

Epoch: 10; Train loss:0.052566751809597315; time:3605.0055277347565 s

Epoch: 11; Train loss:0.0525134675078155; time:3822.644381761551 s

Epoch: 11; Train loss:0.0525134675078155; time:3822.644381761551 s

Epoch: 12; Train loss:0.05236718435569557; time:4044.220722913742 s

Epoch: 13; Train loss:0.052220565331964994; time:4268.119744300842 s

Epoch: 13; Train loss:0.052220565331964994; time:4268.119744300842 s

Epoch: 14; Train loss:0.05217440335385975; time:4483.992960929871 s

Epoch: 14; Train loss:0.05217440335385975; time:4483.992960929871 s

Epoch: 15; Train loss:0.05205994315534511; time:4700.0446581840515 s

Epoch: 16; Train loss:0.05203925685510461; time:4916.996402025223 s

Epoch: 16; Train loss:0.05203925685510461; time:4916.996402025223 s

Epoch: 17; Train loss:0.05193668345072646; time:5133.026588916779 s

Epoch: 17; Train loss:0.05193668345072646; time:5133.026588916779 s

Epoch: 18; Train loss:0.05185567987572854; time:5349.241844177246 s

Epoch: 19; Train loss:0.051844323034028424; time:5565.7364683151245 s

Epoch:19 Checkpoint weights saved--Demo_clef\./checkpoint_19.pt

Total training time : 5567.153565168381 seconds

Done..

Log text file saved to Demo_clef\log.txt.