## Encoder Test

In this notebook, we will look into the different encoders (for sequence and structure). 

In [1]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, PretrainedConfig

from Bio import SeqIO

In [2]:
class ESMSeqEncoder(nn.Module):
    def __init__(
        self,
        model_name: str = 'facebook/esm2_t6_8M_UR50D',
        args=None,
        delay_load: bool = False,
        no_pooling: bool = True,   # NEW: return full per-residue embeddings?
    ):
        super().__init__()
        self.is_loaded = False

        self.model_name = model_name
        self.select_layer = getattr(args, 'protein_select_layer', -1)
        self.pooling = getattr(args, 'protein_pooling', 'cls')  # 'cls' or 'mean'
        self.no_pooling = no_pooling  # NEW flag

        if not delay_load:
            self.load_model()

    def load_model(self, device_map=None):
        if self.is_loaded:
            print(f'{self.model_name} is already loaded. Skipping load.')
            return

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
        self.encoder = AutoModel.from_pretrained(
            self.model_name,
            device_map=device_map,
            trust_remote_code=True,
            output_hidden_states=True
        )
        # Freeze encoder weights by default
        self.encoder.requires_grad_(False)

        self.is_loaded = True

    def tokenize(self, sequences):
        return self.tokenizer(
            sequences,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=1024
        )

    @torch.no_grad()
    def forward(self, sequences):
        if not self.is_loaded:
            self.load_model()

        # Tokenize & move to model device
        inputs = self.tokenize(sequences)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Run through ESM, grab hidden states
        outputs = self.encoder(**inputs)
        # hidden_states is a tuple: (layer0, layer1, ..., layerN)
        hidden_states = outputs.hidden_states[self.select_layer]  # (batch, seq_len, hidden_size)

        if self.no_pooling:
            # Return full sequence embeddings
            return hidden_states

        # Otherwise pool to single vector per sequence
        if self.pooling == 'cls':
            # CLS token is at position 0
            features = hidden_states[:, 0, :]
        elif self.pooling == 'mean':
            mask = inputs['attention_mask'].unsqueeze(-1).expand_as(hidden_states)
            sum_emb = torch.sum(hidden_states * mask, dim=1)
            counts = mask.sum(dim=1).clamp(min=1e-9)
            features = sum_emb / counts
        else:
            raise ValueError(f"Unsupported pooling type: {self.pooling}")

        return features
    
    # Load sequence from FASTA
    def load_fasta_sequence(self, fasta_path):
        record = next(SeqIO.parse(fasta_path, "fasta"))
        return str(record.seq)

    @property
    def dtype(self):
        if not self.is_loaded:
            # If not loaded, infer from config (usually fp32)
            return torch.get_default_dtype()
        return self.encoder.dtype

    # @property
    # def device(self):
    #     if not self.is_loaded:
    #         return torch.device('cpu')
    #     # encoder.device may be a map for multi-GPU; pick first
    #     dev = next(self.encoder.parameters()).device
    #     return dev
    @property
    def device(self):
        return self.encoder.device


    @property
    def config(self):
        if self.is_loaded:
            return self.encoder.config
        return PretrainedConfig.from_pretrained(self.model_name)

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def dummy_feature(self):
        """
        Returns a zero tensor matching the shape of the output:
        - (1, seq_len, hidden_size) if sequence_output, else (1, hidden_size)
        Note: seq_len = 1 for dummy by default.
        """
        if self.no_pooling:
            # dummy single residue embedding
            return torch.zeros(1, 1, self.hidden_size, device=self.device, dtype=self.dtype)
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)


In [3]:


# Initialize encoder
encoder = ESMSeqEncoder(
    model_name='facebook/esm2_t6_8M_UR50D',
    no_pooling=True  # Set to True if you want per-residue embeddings
)
encoder.load_model()

# Load test sequence
sequence = encoder.load_fasta_sequence("../asset/demo_seq_str/pdb_1ubq/rcsb_pdb_1UBQ.fasta")

# Run through encoder
with torch.no_grad():
    output = encoder(sequence)  # Note: input is a list of sequences

print("Output shape:", output.shape)  # (1, hidden_size) or (1, seq_len, hidden_size) if no_pooling=True

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


facebook/esm2_t6_8M_UR50D is already loaded. Skipping load.
Output shape: torch.Size([1, 78, 320])


In [4]:
output

tensor([[[ 0.1612,  0.9960, -0.2067,  ...,  1.1084, -0.2607, -0.5740],
         [-0.0533,  0.7383, -0.3275,  ...,  0.6771, -0.3348,  0.0325],
         [ 0.0191,  0.1796, -0.0121,  ...,  0.2409,  0.1158,  0.0507],
         ...,
         [ 0.8836, -0.0777, -0.1994,  ...,  0.2270,  0.3404,  0.0509],
         [ 0.3051, -0.4682,  0.1557,  ..., -0.3124,  0.4537, -0.8090],
         [ 0.1718,  0.5675,  0.0740,  ...,  0.2063, -0.1933, -0.4174]]])

In [5]:
import torch
import torch.nn as nn
import esm
from esm.inverse_folding.util import extract_coords_from_structure, get_encoder_output

from typing import Optional
from biotite.structure import AtomArray
from biotite.structure.io.pdbx import CIFFile, get_structure as get_structure_cif
from biotite.structure.io.pdb import PDBFile, get_structure as get_structure_pdb
import biotite.structure as struc  # for filter_peptide_backbone

def load_structure(
    file_path: str,
    chain: Optional[str] = None,
    model: int = 1
) -> AtomArray:
    """
    Load a protein structure from .cif/.mmcif or .pdb, select one model & chain,
    then filter to peptide backbone atoms only.
    """
    ext = file_path.split('.')[-1].lower()
    # Read & convert to AtomArray
    if ext in ("cif", "mmcif"):
        cif    = CIFFile.read(file_path)
        struct = get_structure_cif(cif, model=model)
    elif ext == "pdb":
        pdb    = PDBFile.read(file_path)
        struct = get_structure_pdb(pdb, model=model)
    else:
        raise ValueError(f"Unsupported extension '.{ext}'")

    # Optional chain selection
    if chain is not None:
        struct = struct[struct.chain_id == chain]

    # **Filter to peptide backbone (drops waters, side-chains, non-standard residues)**
    backbone_mask = struc.filter_peptide_backbone(struct)
    struct = struct[backbone_mask]

    return struct

class ESMIFEncoder(nn.Module):
    def __init__(
        self,
        model_name: str = "esm_if1_gvp4_t16_142M_UR50",
        args=None,
        delay_load: bool = False,
        no_pooling: bool = True,
    ):
        super().__init__()
        self.model_name = model_name
        self.no_pooling = no_pooling
        self.is_loaded = False

        if not delay_load:
            self.load_model()

    def load_model(self):
        if self.is_loaded:
            print(f"{self.model_name} already loaded. Skipping.")
            return

        # Load the inverse-folding model and its alphabet
        model, alphabet = getattr(esm.pretrained, self.model_name)()
        model = model.eval().requires_grad_(False)
        self.model = model
        self.alphabet = alphabet
        self.is_loaded = True

    @torch.no_grad()
    def forward(self, structure_path: str, chain: str = None):
        if not self.is_loaded:
            self.load_model()

        # 1) Load and filter backbone atoms / select chain
        structure = load_structure(structure_path, chain)

        # 2) Extract (L × 3 × 3) coords tensor + sequence string
        coords, seq = extract_coords_from_structure(structure)

        # 3) Convert coords to torch tensor
        coords_tensor = torch.tensor(coords, dtype=torch.float32)

        # 4) Run the inverse-folding model
        encoder_out = get_encoder_output(self.model, self.alphabet, coords_tensor)
        # embeddings = encoder_out["representations"]  # (L, hidden_size)
        embeddings = encoder_out
        

        if self.no_pooling:
            # Return per-residue (1, L, hidden_size)
            return embeddings.unsqueeze(0)
        else:
            # Mean pool over L residues → (1, hidden_size)
            return embeddings.mean(dim=0, keepdim=True)

    # @property
    # def device(self):
    #     if not self.is_loaded:
    #         return torch.device("cpu")
    #     return next(self.model.parameters()).device

    @property
    def device(self):
        return self.model.device

    @property
    def dtype(self):
        if not self.is_loaded:
            return torch.get_default_dtype()
        return next(self.model.parameters()).dtype

    @property
    def hidden_size(self):
        if not self.is_loaded:
            self.load_model()
        return self.model.embed_dim

    @property
    def dummy_feature(self):
        """
        - If no_pooling=True: returns (1,1,hidden_size)
        - Else: (1,hidden_size)
        """
        if self.no_pooling:
            return torch.zeros(1, 1, self.hidden_size,
                               device=self.device, dtype=self.dtype)
        else:
            return torch.zeros(1, self.hidden_size,
                               device=self.device, dtype=self.dtype)

In [6]:
# Initialize the encoder
encoder = ESMIFEncoder(no_pooling = True)

# Specify the path to your PDB or mmCIF file
structure_path = "../asset/demo_seq_str/pdb_1ubq/1ubq.cif"

# Optionally, specify the chain of interest
chain = "A"

# Get the protein embeddings
embeddings = encoder(structure_path, chain)

  F.pad(torch.tensor(cd), (0, 0, 0, 0, 1, 1), value=np.inf)


In [7]:
embeddings.size()

torch.Size([1, 76, 512])

In [8]:
def build_seq_tower(seq_tower_cfg, **kwargs):
    seq_tower = getattr(seq_tower_cfg, 'mm_seq_tower', getattr(seq_tower_cfg, 'seq_tower', None))

    if seq_tower == 'ESM':
        return ESMSeqEncoder(model_name='facebook/esm2_t33_650M_UR50D', args=seq_tower_cfg, **kwargs)

    raise ValueError(f'Unknown sequence encoder: {seq_tower}')


def build_struc_tower(struc_tower_cfg, **kwargs):
    struc_tower = getattr(struc_tower_cfg, 'mm_struc_tower', getattr(struc_tower_cfg, 'struc_tower', None))
    if struc_tower == 'ESMIF':
        return ESMIFEncoder(model_name='esm_if1_gvp4_t16_142M_UR50', args=struc_tower_cfg, **kwargs)
    # elif struc_tower == 'ESM':
    #     return ESMStructEncoder(model_name='facebook/esm2_t33_650M_UR50D', args=struc_tower_cfg, **kwargs)
    
    raise ValueError(f'Unknown structure encoder: {struc_tower}')

In [9]:
import torch
import torch.nn as nn
import re


class IdentityMap(nn.Module):
    def __init__(self, attr_config: str):
        super().__init__()
        self.attr = attr_config

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {self.attr: 'identity'}


class SimpleResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels)
        )
    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


def build_seq_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_seq_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_seq_hidden_size, config.hidden_size)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_seq_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap('mm_seq_projector_type')

    raise ValueError(f'Unknown projector type: {projector_type}')


def build_struc_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_struc_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_struc_hidden_size, config.hidden_size)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_struc_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap('mm_struc_projector_type')

    raise ValueError(f'Unknown projector type: {projector_type}')


## pannot_arch.py

In [10]:
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
# build_vision_tower


IGNORE_INDEX = -100

PROT_TOKEN_INDEX = -300
DEFAULT_PROT_TOKEN = "<prot>"
DEFAULT_PROT_PATCH_TOKEN = "<prot_patch>"
DEFAULT_PROT_START_TOKEN = "<prot_start>"
DEFAULT_PROT_END_TOKEN = "<prot_end>"
PROT_PLACEHOLDER = "<prot-placeholder>"

SEQ_TOKEN_INDEX = -330
DEFAULT_SEQ_TOKEN = "<seq>"
DEFAULT_SEQ_PATCH_TOKEN = "<seq_patch>"
DEFAULT_SEQ_START_TOKEN = "<seq_start>"
DEFAULT_SEQ_END_TOKEN = "<seq_end>"

STR_TOKEN_INDEX = -360
DEFAULT_STR_TOKEN = "<str>"
DEFAULT_STR_PATCH_TOKEN = "<str_patch>"
DEFAULT_STR_START_TOKEN = "<str_start>"
DEFAULT_STR_END_TOKEN = "<str_end>"



# from pannot.mm_utils import get_anyres_image_grid_shape





In [11]:

class PannotMetaModel:

    def __init__(self, config):
        super(PannotMetaModel, self).__init__(config)

        if hasattr(config, "mm_seq_tower"):
            self.seq_tower = build_seq_tower(config, delay_load=True)
            self.mm_seq_projector = build_seq_projector(config)

        if hasattr(config, "mm_struc_tower"):
            self.struc_tower = build_struc_tower(config, delay_load=True)
            self.mm_struc_projector = build_struc_projector(config)

    def get_seq_tower(self):
        return getattr(self, 'seq_tower', None)

    def get_struc_tower(self):
        return getattr(self, 'struc_tower', None)

    def initialize_seq_modules(self, model_args, fsdp=None):
        self.config.mm_seq_tower = model_args.seq_tower
        self.config.mm_seq_select_layer = model_args.mm_seq_select_layer
        self.config.mm_seq_select_feature = model_args.mm_seq_select_feature
        self.config.use_mm_seq_proj = True

        if self.get_seq_tower() is None:
            seq_tower = build_seq_tower(model_args)
            self.seq_tower = [seq_tower] if fsdp else seq_tower
        else:
            seq_tower = self.seq_tower[0] if fsdp else self.seq_tower
            seq_tower.load_model()

        self.config.mm_seq_hidden_size = seq_tower.hidden_size
        self.config.mm_seq_projector_type = getattr(model_args, 'mm_seq_projector_type', 'linear')

        if getattr(self, 'mm_seq_projector', None) is None:
            self.mm_seq_projector = build_seq_projector(self.config)
        else:
            for p in self.mm_seq_projector.parameters():
                p.requires_grad = True

        if model_args.pretrain_mm_seq_mlp_adapter is not None:
            seq_projector_weights = torch.load(model_args.pretrain_mm_seq_mlp_adapter, map_location='cpu')
            self.mm_seq_projector.load_state_dict(
                {k.split('mm_seq_projector.')[1]: v for k, v in seq_projector_weights.items() if 'mm_seq_projector' in k}
            )

    def initialize_str_modules(self, model_args, fsdp=None):
        self.config.mm_struc_tower = model_args.struc_tower
        self.config.mm_str_select_layer = model_args.mm_str_select_layer
        self.config.mm_str_select_feature = model_args.mm_str_select_feature
        self.config.use_mm_str_proj = True

        if self.get_struc_tower() is None:
            struc_tower = build_struc_tower(model_args)
            self.struc_tower = [struc_tower] if fsdp else struc_tower
        else:
            struc_tower = self.struc_tower[0] if fsdp else self.struc_tower
            struc_tower.load_model()

        self.config.mm_str_hidden_size = struc_tower.hidden_size
        self.config.mm_struc_projector_type = getattr(model_args, 'mm_struc_projector_type', 'linear')

        if getattr(self, 'mm_struc_projector', None) is None:
            self.mm_struc_projector = build_struc_projector(self.config)
        else:
            for p in self.mm_struc_projector.parameters():
                p.requires_grad = True

        if model_args.pretrain_mm_str_mlp_adapter is not None:
            struc_projector_weights = torch.load(model_args.pretrain_mm_str_mlp_adapter, map_location='cpu')
            self.mm_struc_projector.load_state_dict(
                {k.split('mm_struc_projector.')[1]: v for k, v in struc_projector_weights.items() if 'mm_struc_projector' in k}
            )


In [13]:
class PannotMetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def get_seq_tower(self):
        return self.get_model().get_seq_tower()

    def get_struc_tower(self):
        return self.get_model().get_struc_tower()

    def encode_seqs(self, seqs):
        seq_features = self.get_seq_tower()(seqs)
        if seq_features.device != self.device:
            seq_features = seq_features.to(self.device)
        seq_features = self.get_model().mm_seq_projector(seq_features)
        return seq_features

    def encode_strs(self, strsp, chain=None):
        str_features = self.get_struc_tower()(strsp, chain=chain)
        if str_features.device != self.device:
            str_features = str_features.to(self.device)
        str_features = self.get_model().mm_struc_projector(str_features)
        return str_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels,
        seqs=None, strs=None
    ):
        seq_tower = self.get_seq_tower()
        struc_tower = self.get_struc_tower()

        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask


        if seq_tower is None and struc_tower is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()

        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)

        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # TODO: seq start / end and str start/end is not implemented here to support pretraining.
        if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_seq_start_end', False) and getattr(self.config, 'mm_use_str_start_end', False):
            raise NotImplementedError

        # Remove padding
        _input_ids = input_ids
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_seq_idx = 0
        cur_str_idx = 0

        for batch_idx, cur_input_ids in enumerate(input_ids):
            cur_labels = labels[batch_idx]
            cur_input_embeds_no_special = []
            cur_labels_no_special = []

            token_indices = {
                SEQ_TOKEN_INDEX: torch.where(cur_input_ids == SEQ_TOKEN_INDEX)[0].tolist(),
                STR_TOKEN_INDEX: torch.where(cur_input_ids == STR_TOKEN_INDEX)[0].tolist()
            }

            # Break input around special tokens
            all_specials = sorted([(i, 'seq') for i in token_indices[SEQ_TOKEN_INDEX]] +
                                  [(i, 'str') for i in token_indices[STR_TOKEN_INDEX]])
            all_specials = [(-1, None)] + all_specials + [(cur_input_ids.shape[0], None)]

            cur_embed_segments = []
            cur_label_segments = []

            for i in range(len(all_specials) - 1):
                start = all_specials[i][0] + 1
                end = all_specials[i + 1][0]
                cur_embed_segments.append(self.get_model().embed_tokens(cur_input_ids[start:end]))
                cur_label_segments.append(cur_labels[start:end])

                # Add special modality features
                if all_specials[i + 1][1] == 'seq':
                    seq_feature = self.encode_seqs(seqs[cur_seq_idx]).squeeze(0)
                    cur_embed_segments.append(seq_feature)
                    cur_label_segments.append(torch.full((seq_feature.shape[0],), IGNORE_INDEX, dtype=cur_labels.dtype, device=cur_labels.device))
                    
                    #     # Embed the start and end tokens (1, D)
                    # start_embed = self.get_model().embed_tokens(
                    #     torch.tensor([SEQ_START_TOKEN_ID], device=seq_feature.device)
                    # )
                    # end_embed = self.get_model().embed_tokens(
                    #     torch.tensor([SEQ_END_TOKEN_ID], device=seq_feature.device)
                    # )

                    # # Concatenate to make (L + 2, D)
                    # seq_feature = torch.cat([start_embed, seq_feature, end_embed], dim=0)

                    # # Create dummy labels for these tokens (set to IGNORE_INDEX)
                    # seq_labels = torch.full((seq_feature.shape[0],), IGNORE_INDEX, dtype=cur_labels.dtype, device=cur_labels.device)

                    # cur_embed_segments.append(seq_feature)
                    # cur_label_segments.append(seq_labels)
                                    
                    cur_seq_idx += 1
                elif all_specials[i + 1][1] == 'str':
                    str_feature = self.encode_strs(strs[cur_str_idx]).squeeze(0)
                    cur_embed_segments.append(str_feature)
                    cur_label_segments.append(torch.full((str_feature.shape[0],), IGNORE_INDEX, dtype=cur_labels.dtype, device=cur_labels.device))
                    cur_str_idx += 1

            final_embed = torch.cat(cur_embed_segments, dim=0).to(self.device)
            final_labels = torch.cat(cur_label_segments, dim=0).to(self.device)

            new_input_embeds.append(final_embed)
            new_labels.append(final_labels)

        # Truncate and pad
        tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=torch.bool, device=new_labels[0].device)
        position_ids = torch.zeros((batch_size, max_len), dtype=torch.long, device=new_labels[0].device)

        for i, (emb, lab) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = emb.shape[0]
            padding_side = getattr(self.config, 'tokenizer_padding_side', 'right')

            if padding_side == "left":
                padded = torch.cat([torch.zeros((max_len - cur_len, emb.shape[1]), device=emb.device), emb], dim=0)
                new_input_embeds_padded.append(padded)
                new_labels_padded[i, -cur_len:] = lab
                attention_mask[i, -cur_len:] = True
                position_ids[i, -cur_len:] = torch.arange(cur_len, device=emb.device)
            else:
                padded = torch.cat([emb, torch.zeros((max_len - cur_len, emb.shape[1]), device=emb.device)], dim=0)
                new_input_embeds_padded.append(padded)
                new_labels_padded[i, :cur_len] = lab
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(cur_len, device=emb.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        # 如果原始标签为空，则将新的标签设置为空
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        # 如果原始注意力掩码为空，则将新的注意力掩码设置为空
        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        # 如果原始位置id为空，则将新的位置id设置为空
        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels_padded
    
    def initialize_seq_tokenizer(self, model_args, tokenizer):
        if model_args.mm_use_seq_patch_token:
            tokenizer.add_tokens([DEFAULT_SEQ_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_seq_start_end:
            num_new_tokens = tokenizer.add_tokens(
                [DEFAULT_SEQ_START_TOKEN, DEFAULT_SEQ_END_TOKEN], special_tokens=True
            )
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                output_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_avg
                output_embeddings[-num_new_tokens:] = output_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(
                        f"Unexpected embed_tokens_weight shape. "
                        f"Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. "
                        f"Number of new tokens: {num_new_tokens}."
                    )
        elif model_args.mm_use_seq_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

    def initialize_str_tokenizer(self, model_args, tokenizer):
        if model_args.mm_use_str_patch_token:
            tokenizer.add_tokens([DEFAULT_STR_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_str_start_end:
            num_new_tokens = tokenizer.add_tokens(
                [DEFAULT_STR_START_TOKEN, DEFAULT_STR_END_TOKEN], special_tokens=True
            )
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
                output_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_avg
                output_embeddings[-num_new_tokens:] = output_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(
                        f"Unexpected embed_tokens_weight shape. "
                        f"Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. "
                        f"Number of new tokens: {num_new_tokens}."
                    )
        elif model_args.mm_use_str_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False



## Pannot_llama.py

In [14]:

from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM, \
                         LlamaConfig, LlamaModel, LlamaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

# from ..pannot_arch import PannotMetaModel, PannotMetaForCausalLM


# from ..pannot_arch import PannotMetaModel, PannotMetaForCausalLM

class PannotConfig(LlamaConfig):
    model_type = "pannot_llama"


class PannotLlamaModel(PannotMetaModel, LlamaModel):
    config_class = PannotConfig

    def __init__(self, config: LlamaConfig):
        super().__init__(config)


class PannotLlamaForCausalLM(LlamaForCausalLM, PannotMetaForCausalLM):
    config_class = PannotConfig

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = PannotLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        seqs: Optional[torch.FloatTensor] = None,  # (B, L_seq, D)
        strs: Optional[torch.FloatTensor] = None,  # (B, L_str, D)
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        
        kwargs.pop('cache_position', None)

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                seqs,
                strs
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        seqs: Optional[torch.FloatTensor] = None,
        strs: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        kwargs.pop('cache_position', None)

        if seqs is not None or strs is not None:
            (
                input_ids,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                None,
                None,
                seqs,
                strs
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(input_ids)

        return super().generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        **kwargs
    ):
        seqs = kwargs.pop("seqs", None)
        strs = kwargs.pop("strs", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            **kwargs
        )
        if seqs is not None:
            inputs["seqs"] = seqs
        if strs is not None:
            inputs["strs"] = strs
        return inputs


AutoConfig.register("pannot_llama", PannotConfig)
AutoModelForCausalLM.register(PannotConfig, PannotLlamaForCausalLM)

In [15]:
# Send the embeddings to the PannotLlamaForCausalLM
# Load TinyLlama config and tokenizer
pretrained_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
config = PannotConfig.from_pretrained(pretrained_model_name)




You are using a model of type llama to instantiate a model of type pannot_llama. This is not supported for all configurations of models and can yield errors.


In [16]:
config

PannotConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 5632,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "pannot_llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 22,
  "num_key_value_heads": 4,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.48.1",
  "use_cache": true,
  "vocab_size": 32000
}

In [17]:
# Add missing multimodal fields manually
config.mm_seq_tower = "ESM"
config.mm_struc_tower = "ESMIF"
config.mm_seq_hidden_size = 1280

config.mm_struc_hidden_size = 512
config.mm_seq_projector_type = "linear"
config.mm_struc_projector_type = "linear"
config.mm_seq_select_layer = -1
config.mm_seq_select_feature = "cls"
config.mm_str_select_layer = -1
config.mm_str_select_feature = "mean"
config.use_mm_seq_proj = True
config.use_mm_str_proj = True

In [18]:
config

PannotConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 5632,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "mm_seq_hidden_size": 1280,
  "mm_seq_projector_type": "linear",
  "mm_seq_select_feature": "cls",
  "mm_seq_select_layer": -1,
  "mm_seq_tower": "ESM",
  "mm_str_select_feature": "mean",
  "mm_str_select_layer": -1,
  "mm_struc_hidden_size": 512,
  "mm_struc_projector_type": "linear",
  "mm_struc_tower": "ESMIF",
  "model_type": "pannot_llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 22,
  "num_key_value_heads": 4,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.48.1",
  "use_cache": true,
  "use_mm_s

In [19]:
# Initialize model
model = PannotLlamaForCausalLM(config)
# model.load_state_dict(torch.load("tinyllama_pannot_weights.pt", map_location="cpu"))  # Optional, if you have weights

# Move model and data to appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [20]:
model.config.mm_seq_hidden_size = 1280

In [21]:
# Load test sequence
sequence = model.get_seq_tower().load_fasta_sequence("../asset/demo_seq_str/pdb_1ubq/rcsb_pdb_1UBQ.fasta")
# sequence = sequence.to(device)
# Run through encoder
with torch.no_grad():
    output = model.encode_seqs(sequence)  # Note: input is a list of sequences
    # seq_features = model.get_seq_tower()([sequence])
    # print(seq_features.device)
    # print(model.get_seq_tower().device)
    # # seq_features = model.get_model().mm_seq_projector(seq_features)

print("Output shape:", output.shape)  # (1, hidden_size) or (1, seq_len, hidden_size) if no_pooling=True



Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Output shape: torch.Size([1, 78, 2048])


In [None]:
# seq_features

In [23]:
output[0][0]

tensor([ 0.1026, -0.1568,  0.0708,  ...,  0.0030, -0.1666,  0.2931],
       device='cuda:0')

In [24]:

# Specify the path to your PDB or mmCIF file
structure_path = "../asset/demo_seq_str/pdb_1ubq/1ubq.cif"

# Optionally, specify the chain of interest
chain = "A"

# structure = model.get_struc_tower().load_structure(structure_path, chain)
# Get the protein embeddings

with torch.no_grad():    
    str_embeddings = model.encode_strs(structure_path, chain = chain)

print("Structure emb shape:", str_embeddings.shape)

Structure emb shape: torch.Size([1, 76, 2048])


In [25]:
sequence

'MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG'

## Modify the tokenizer to include custom tokens

In [26]:
tokenizer.decode(torch.tensor([1,2,910, 306,  4944, 29901,11762, 29958,   529,]))

'<s></s> This I provided:seq> <'

In [27]:
import re

def tokenizer_protein_token(prompt, tokenizer, seq_token_index=SEQ_TOKEN_INDEX, str_token_index=STR_TOKEN_INDEX, return_tensors=None):
    # Split the prompt on both <seq> and <str> while preserving the split tokens
    prompt_chunks = re.split(r'(<seq>|<str>)', prompt)

    # Tokenize the chunks and replace <seq> and <str> with their respective token indices
    tokenized_input = []
    for chunk in prompt_chunks:
        if chunk == '<seq>':
            tokenized_input.append(seq_token_index)
        elif chunk == '<str>':
            tokenized_input.append(str_token_index)
        else:
            # Tokenize the chunk normally
            tokenized_input.extend(tokenizer.encode(chunk, add_special_tokens=False))

    # If return_tensors is specified, return the result as a PyTorch tensor
    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(tokenized_input, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')

    return tokenized_input

In [28]:
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict,
    tokenizer,
    model,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg

In [29]:
# # Define the special tokens you want to add
# special_tokens_dict = {
#     'additional_special_tokens': [
#         "<seq>", "<seq_patch>", "<seq_start>", "<seq_end>",
#         "<str>", "<str_patch>", "<str_start>", "<str_end>"
#     ]
# }

# # Assuming you already have the tokenizer and model objects defined
# smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model)


In [30]:
tokenizer_protein_token("This is a great protein I want to study: <seq> <str>", tokenizer)

[910,
 338,
 263,
 2107,
 26823,
 306,
 864,
 304,
 6559,
 29901,
 29871,
 -330,
 259,
 -360]

In [31]:
model_args = {
    'mm_use_seq_patch_token': False,
    'mm_use_seq_start_end': False,
    'mm_use_str_patch_token': False,
    'mm_use_str_start_end': False,
    'tune_mm_mlp_adapter': True,
    'pretrain_mm_mlp_adapter': 'path_to_pretrained_mlp_adapter.pth'
}


In [32]:
model.initialize_seq_tokenizer(model_args, tokenizer)

AttributeError: 'dict' object has no attribute 'mm_use_seq_patch_token'

In [33]:
tokenizer.decode(torch.tensor([-1, -2]))

OverflowError: out of range integral type conversion attempted

In [34]:


# Prepare input IDs
input_ids = tokenizer_protein_token("<s> This is the protein I provided: <seq> <str> Could you introduce its information?", tokenizer,return_tensors='pt').unsqueeze(0).cuda()

print(input_ids.shape)

# Generate output
with torch.no_grad():
    output = model.generate(
        input_ids=input_ids,
        seqs=[sequence],
        strs=[structure_path],
        max_new_tokens=100,
        do_sample=False
    )

# Decode and print
print(tokenizer.decode(output[0], skip_special_tokens=True))

torch.Size([1, 20])


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


invånaregetElementById invånareѝ invånareѝ/~ Theoryு)" ufficiale participe ufficiale participe)" ufficiale participe)" Bagுுுுுுுae ufficiale présent ufficialegetElementByIdѝ rid participe ufficiale ufficiale ufficiale ufficiale ufficiale ufficiale ufficiale ufficiale ufficiale ufficiale ufficiale ufficiale ufficiale ufficiale ufficiale participeionales coalptrтельства xmlns més ufficiale ufficiale ufficiale ufficiale ufficiale ufficiale载⠀dependenciesakedAfterAfterAfterAfterAfterAfterAfterAfterAfterAfterAfterAfter ufficiale intersect senior ufficiale intersect senior ufficialeÞ ufficialeÞ musesubscribeAfter ufficiale ufficiale ufficiale ufficiale ufficiale ufficiale ufficialeAfterční
