In [1]:
import pandas as pd
import torch
import torch.nn as nn


class WeightedLayerPooling(nn.Module):
    def __init__(self, num_hidden_layers, layer_start: int = 4):
        super().__init__()
        # layer_start表示我们从倒数第几层开始用
        self.layer_start = layer_start
        self.num_layers_to_pool = num_hidden_layers - layer_start + 1

        # 定义可学习的权重，初始化为1
        # 这样在训练开始时，它等价于一个简单的平均
        self.layer_weights = nn.Parameter(
            torch.ones(self.num_layers_to_pool)
        )

    def forward(self, all_hidden_states):
        # 1. 选取我们需要的层
        # all_hidden_states 是一个 tuple, 我们把它转成 tensor
        # 我们要的是模型最后几层，所以从后面开始取
        layers_to_pool = torch.stack(
            all_hidden_states[-self.num_layers_to_pool:],
            dim=0
        )  # Shape: (num_layers, batch, seq_len, hidden_dim)

        # 2. 计算权重
        # 为了让权重更稳定且具有可解释性，通常会用softmax
        # 这样所有权重加起来等于1，代表了贡献度的百分比
        weight_softmax = torch.softmax(self.layer_weights, dim=0)

        # 3. 调整权重形状以进行广播 (broadcast)
        # 从 (num_layers,) 变成 (num_layers, 1, 1, 1)
        # 这样它就可以和 (num_layers, batch, seq_len, hidden_dim) 的层输出相乘
        reshaped_weights = weight_softmax.view(-1, 1, 1, 1)

        # 4. 执行加权求和
        # 两个张量相乘，然后在一个维度上求和
        weighted_average = torch.sum(layers_to_pool * reshaped_weights, dim=0)

        return weighted_average  # Shape: (batch, seq_len, hidden_dim)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import AutoModel, AutoTokenizer

esm_model = AutoModel.from_pretrained(
    "../../scoures/ESM2-35M",
    trust_remote_code=True,
    output_hidden_states=True  # 确保获取所有层输出
)
tokenizer = AutoTokenizer.from_pretrained(
    "../../scoures/ESM2-35M",
    trust_remote_code=True)


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.6 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "/home/ouyanganqi/miniconda3/envs/myenv/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ouyanganqi/miniconda3/envs/myenv/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/ouyanganqi/miniconda3/envs/myenv/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/home/ouyanganqi/miniconda3/envs/myenv/lib/python3.10/site-packages/traitlets/config/applicat

In [3]:
esm_config = esm_model.config
embedding_size = esm_config.hidden_size

weighted_pooling = WeightedLayerPooling(
    num_hidden_layers=esm_config.num_hidden_layers,
    layer_start=esm_config.num_hidden_layers - 4
)

In [4]:
import pandas as pd

df = pd.read_csv("../../data/EC_X_val_40.csv")
sequences = df["SEQUENCE"].tolist()[:100]  # 取前100个序列进行测试
inputs = tokenizer(
    sequences,
    padding=True,
    truncation=True,
    return_tensors="pt"
)
inputs_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [5]:
esm_outputs = esm_model(
    input_ids=inputs_ids,
    attention_mask=attention_mask
)
all_hidden_states = esm_outputs.hidden_states

# B. 使用层权重融合，得到一个增强的ESM序列表示
esm_output_fused = weighted_pooling(all_hidden_states)
esm_output_fused.shape

torch.Size([100, 42, 480])

In [6]:
tokenizer(
    sequences[1],
    padding=True,
    truncation=True,
    return_tensors="pt"
)

{'input_ids': tensor([[ 0,  6, 15,  4,  4,  8,  4,  4,  8,  4,  4,  6,  4,  4,  2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [7]:
from transformers import DataCollatorWithPadding
from mic_model import MyDataset

dataset = MyDataset(
    df=df,
    tokenizer=tokenizer,
    max_length=1024
)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=False,
    collate_fn=data_collator
)

In [10]:
next(iter(dataloader))["genome"].shape

torch.Size([256, 84])