In [158]:

import torch

import transformers
import tokenizers
from torch.utils.data import Dataset
from enum import auto, Enum

import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
from typing import List, Union,Tuple, Optional,Any


In [3]:
from packaging import version
# True when the installed `tokenizers` release is 0.14 or newer
# (despite the name, the comparison is inclusive of 0.14 itself).
IS_TOKENIZER_GREATER_THAN_0_14 = not (version.parse(tokenizers.__version__) < version.parse('0.14'))


In [4]:
IS_TOKENIZER_GREATER_THAN_0_14

True

### Input Definition

In [240]:
@dataclass
class ModelArguments:
    """Model-side options: base checkpoint, adapter control and the two protein towers.

    Every field carries a default, so ``ModelArguments()`` is a valid instance.
    """

    # --- Core model ---
    model_name_or_path: Optional[str] = "facebook/opt-125m"
    version: Optional[str] = "v0"

    # --- Global control ---
    freeze_backbone: bool = False
    tune_mm_mlp_adapter: bool = False
    pretrain_mm_mlp_adapter: Optional[str] = None

    # --- Sequence tower ---
    use_seq_tower: bool = True
    mm_seq_tower: Optional[str] = "ESM"  # One of: "ProtST", "ESM"
    mm_seq_select_layer: Optional[int] = -1
    mm_seq_projector_type: Optional[str] = "linear"
    mm_use_seq_start_end: bool = False
    mm_use_seq_patch_token: bool = False

    # --- Structure tower ---
    use_str_tower: bool = True
    mm_struc_tower: Optional[str] = "ESM3"  # One of: "ESMIF", "ESM3"
    mm_str_projector_type: Optional[str] = "linear"
    mm_use_str_start_end: bool = False
    mm_use_str_patch_token: bool = False

    # --- Fusion control (optional) ---
    mm_fusion_type: Optional[str] = "concat"  # e.g., "concat", "sum", "crossattn"

In [205]:


class SeparatorStyle(Enum):
    """Prompt layouts that ``Conversation.get_prompt`` knows how to render."""
    SINGLE = 1
    TWO = 2
    MPT = 3
    PLAIN = 4
    LLAMA_2 = 5

@dataclass
class ProteinInput:
    """A protein passed as message content: sequence plus optional extras.

    Attributes:
        sequence: Amino acid sequence (required).
        structure: Optional structure payload, e.g. path to a PDB file or
            structural data.
        annotations: Optional functional annotations or descriptions.
    """
    sequence: str
    structure: Optional[str] = None
    annotations: Optional[str] = None

@dataclass
class Conversation:
    """Prompt template plus the message history of one chat.

    Attributes:
        system: System prompt placed at the start of the rendered prompt.
        roles: Speaker names; ``roles[0]`` must author the first message when
            the LLAMA_2 layout is rendered.
        messages: ``[role, message]`` pairs. A message is either a plain string
            or a :class:`ProteinInput`.
        offset: Stored and copied; not read by any method of this class.
        sep_style: The :class:`SeparatorStyle` layout used by ``get_prompt``.
        sep: Primary separator.
        sep2: Secondary separator for the TWO and LLAMA_2 layouts; ``sep`` is
            used when it is unset.
        version: Template version tag.
    """
    system: str
    roles: List[str]
    messages: List[List[Union[str, ProteinInput]]]
    offset: int = 0
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: Optional[str] = None
    version: str = "Pannot"

    def get_prompt(self) -> str:
        """Render the system prompt and all messages as a single string.

        Returns:
            The prompt laid out according to ``self.sep_style``.

        Raises:
            ValueError: if ``self.sep_style`` is not a handled layout.
            AssertionError: in the LLAMA_2 layout, if the first message is not
                from ``roles[0]``.
            TypeError: if a message is neither a string nor a ``ProteinInput``
                (for example ``None``), because it is concatenated as text.
        """
        messages = self.messages
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                ret += role + ": " + self._format_message(message) + self.sep
        elif self.sep_style == SeparatorStyle.TWO:
            # Even-indexed messages are closed with `sep`, odd-indexed with `sep2`.
            seps = [self.sep, self.sep2 or self.sep]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                ret += role + ": " + self._format_message(message) + seps[i % 2]
        elif self.sep_style == SeparatorStyle.PLAIN:
            ret = self.system + "\n"
            for role, message in messages:
                ret += role + ": " + self._format_message(message) + "\n"
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            ret = ""
            for i, (role, message) in enumerate(messages):
                formatted = self._format_message(message)
                if i == 0:
                    assert role == self.roles[0], "First message must be from user"
                    # The system prompt, when non-empty, is folded into the
                    # first instruction.
                    if self.system:
                        formatted = f"<<SYS>>\n{self.system}\n<</SYS>>\n\n" + formatted
                if i % 2 == 0:
                    ret += self.sep + f"[INST] {formatted} [/INST]"
                else:
                    ret += " " + formatted + " " + (self.sep2 or self.sep)
            # Remove exactly one leading separator. `str.lstrip(self.sep)` would
            # treat `sep` as a *set of characters* and could also strip the
            # start of "[INST]" (e.g. with sep="[BOS]").
            if self.sep and ret.startswith(self.sep):
                ret = ret[len(self.sep):]
        elif self.sep_style == SeparatorStyle.MPT:
            # No ": " between role and message in this layout.
            ret = self.system + self.sep
            for role, message in messages:
                ret += role + self._format_message(message) + self.sep
        else:
            raise ValueError(f"Invalid separator style: {self.sep_style}")
        return ret

    def _format_message(self, message: Union[str, ProteinInput]) -> str:
        """Turn one message into text.

        A ``ProteinInput`` becomes newline-joined tagged lines: ``<seq>`` is
        always present, ``<str>`` and ``<anno>`` only when the corresponding
        field is truthy. Any other value is returned unchanged.
        """
        if isinstance(message, ProteinInput):
            lines = [f"<seq> {message.sequence} </seq>"]
            if message.structure:
                lines.append(f"<str> {message.structure} </str>")
            if message.annotations:
                lines.append(f"<anno> {message.annotations} </anno>")
            return "\n".join(lines)
        return message

    def append_message(self, role: str, message: Union[str, ProteinInput]):
        """Append a ``[role, message]`` pair to ``self.messages`` in place."""
        self.messages.append([role, message])

    def copy(self) -> 'Conversation':
        """Return a new Conversation with a fresh list of ``[role, message]`` pairs.

        ``messages`` is rebuilt as a list of new two-element lists (so a tuple
        template becomes appendable); ``roles`` and the message objects
        themselves are shared with the original.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[r, m] for r, m in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
        )

# Module-wide default template ("v1", TWO layout): even-indexed messages are
# followed by " " and odd-indexed ones by "</s>".
# `messages` is an empty tuple here; Conversation.copy() rebuilds it as a list,
# so append to a copy rather than to this object.
default_conversation = Conversation(
    system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    roles=["USER", "ASSISTANT"],
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1",
)

In [206]:
default_conversation.get_prompt()

"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "

In [207]:

# Label value the preprocess functions write over masked positions.
IGNORE_INDEX = -100

# Negative sentinel ids. tokenizer_protein_token emits the SEQ / STR ones in
# place of the "<seq>" / "<str>" placeholders.
PROT_TOKEN_INDEX = -300
SEQ_TOKEN_INDEX = -330
STR_TOKEN_INDEX = -360

# Generic protein tokens.
DEFAULT_PROT_TOKEN = "<prot>"
DEFAULT_PROT_PATCH_TOKEN = "<prot_patch>"
DEFAULT_PROT_START_TOKEN = "<prot_start>"
DEFAULT_PROT_END_TOKEN = "<prot_end>"
PROT_PLACEHOLDER = "<prot-placeholder>"

# Sequence-modality tokens.
DEFAULT_SEQ_TOKEN = "<seq>"
DEFAULT_SEQ_PATCH_TOKEN = "<seq_patch>"
DEFAULT_SEQ_START_TOKEN = "<seq_start>"
DEFAULT_SEQ_END_TOKEN = "<seq_end>"

# Structure-modality tokens.
DEFAULT_STR_TOKEN = "<str>"
DEFAULT_STR_PATCH_TOKEN = "<str_patch>"
DEFAULT_STR_START_TOKEN = "<str_start>"
DEFAULT_STR_END_TOKEN = "<str_end>"

In [None]:

@dataclass
class DataArguments:
    """Data-side options: dataset path, preprocessing flags and modality folders.

    NOTE(review): preprocess_multimodal also looks up ``use_seq_start_end`` and
    ``use_str_start_end`` on this object with ``getattr(..., False)``; they are
    not declared here, so they stay False unless a caller sets them.
    """
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = field(default=False)
    is_multimodal: bool = field(default=False)
    seq_folder: Optional[str] = None
    struc_folder: Optional[str] = None


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with project-specific options.

    Adds cache/attention settings, a maximum sequence length, quantization
    settings, LoRA settings and a separate projector learning rate, and
    overrides the defaults of ``optim`` and ``remove_unused_columns``.
    """
    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    remove_unused_columns: bool = False
    freeze_mm_mlp_adapter: bool = False
    mpt_attn_impl: Optional[str] = "triton"
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )

    # --- Quantization ---
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."},
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
    )
    bits: int = field(default=16, metadata={"help": "How many bits to use."})

    # --- LoRA ---
    lora_enable: bool = field(default=False)
    lora_r: int = field(default=64)
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.05)
    lora_weight_path: str = field(default="")
    lora_bias: str = field(default="none")

    # --- Multimodal extras ---
    mm_projector_lr: Optional[float] = field(default=None)
    group_by_modality_length: bool = False


In [209]:

# Rank of this process; rank0_print stays silent until something sets it to 0.
local_rank = None


def rank0_print(*args):
    """Forward ``args`` to ``print`` only when the module-level ``local_rank`` equals 0."""
    if local_rank != 0:
        return
    print(*args)




### Preprocess Function

In [210]:

def _front_load_token(text, token, start_token, end_token, wrap):
    """Collapse every ``token`` in ``text`` into one leading ``token`` line.

    Text without ``token`` is returned untouched. Otherwise all occurrences are
    removed, the remainder is stripped, and ``token + '\\n'`` is prepended. With
    ``wrap`` true, each ``token`` is then surrounded by ``start_token`` and
    ``end_token``.
    """
    if token not in text:
        return text
    remainder = text.replace(token, '').strip()
    relocated = token + '\n' + remainder
    if wrap:
        relocated = relocated.replace(token, start_token + token + end_token)
    return relocated


def preprocess_multimodal(
    sources: Sequence[str],
    data_args: DataArguments
) -> Dict:
    """Move ``<seq>`` / ``<str>`` placeholders to the front of every turn.

    When ``data_args.is_multimodal`` is false, ``sources`` is returned as is.
    Otherwise each turn's ``'value'`` is rewritten IN PLACE: the ``<seq>``
    placeholder is hoisted first, then the ``<str>`` placeholder (so ``<str>``
    ends up ahead of ``<seq>`` when both are present), and the result is
    stripped. The same ``sources`` object is returned.

    NOTE(review): the wrap flags are read as ``use_seq_start_end`` /
    ``use_str_start_end``. DataArguments does not declare them and
    ModelArguments names its flags ``mm_use_seq_start_end`` /
    ``mm_use_str_start_end`` -- confirm the caller sets the attributes under
    the names read here.
    """
    if not data_args.is_multimodal:
        return sources

    wrap_seq = getattr(data_args, "use_seq_start_end", False)
    wrap_str = getattr(data_args, "use_str_start_end", False)

    for conversation in sources:
        for turn in conversation:
            text = turn['value']
            text = _front_load_token(
                text, DEFAULT_SEQ_TOKEN, DEFAULT_SEQ_START_TOKEN, DEFAULT_SEQ_END_TOKEN, wrap_seq
            )
            text = _front_load_token(
                text, DEFAULT_STR_TOKEN, DEFAULT_STR_START_TOKEN, DEFAULT_STR_END_TOKEN, wrap_str
            )
            turn['value'] = text.strip()

    return sources


In [211]:
import re
def tokenizer_protein_token(prompt, tokenizer, seq_token_index=SEQ_TOKEN_INDEX, str_token_index=STR_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt``, substituting sentinel ids for ``<seq>`` / ``<str>``.

    The prompt is split on the two placeholders. Each placeholder contributes
    a single id (``seq_token_index`` or ``str_token_index``); every other piece,
    including empty ones produced by the split, is passed to
    ``tokenizer.encode(..., add_special_tokens=False)`` and its ids are
    appended in order.

    Args:
        prompt: Text that may contain ``<seq>`` and/or ``<str>``.
        tokenizer: Object exposing ``encode(text, add_special_tokens=...)``.
        seq_token_index: Id emitted for each ``<seq>``.
        str_token_index: Id emitted for each ``<str>``.
        return_tensors: ``None`` for a Python list, ``'pt'`` for a 1-D
            ``torch.long`` tensor.

    Raises:
        ValueError: if ``return_tensors`` is neither ``None`` nor ``'pt'``.
    """
    placeholder_ids = {'<seq>': seq_token_index, '<str>': str_token_index}

    token_ids = []
    # The capture group keeps the placeholders themselves in the split output.
    for piece in re.split(r'(<seq>|<str>)', prompt):
        if piece in placeholder_ids:
            token_ids.append(placeholder_ids[piece])
        else:
            token_ids.extend(tokenizer.encode(piece, add_special_tokens=False))

    if return_tensors is None:
        return token_ids
    if return_tensors != 'pt':
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return torch.tensor(token_ids, dtype=torch.long)


In [212]:


# def preprocess_llama_2(
#     sources,
#     tokenizer: transformers.PreTrainedTokenizer,
#     has_image: bool = False
# ) -> Dict:
#     conv = default_conversation.copy()
#     roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

#     # Apply prompt templates
#     conversations = []
#     for i, source in enumerate(sources):
#         if roles[source[0]["from"]] != conv.roles[0]:
#             # Skip the first one if it is not from human
#             source = source[1:]

#         conv.messages = []
#         for j, sentence in enumerate(source):
#             role = roles[sentence["from"]]
#             assert role == conv.roles[j % 2], f"{i}"
#             conv.append_message(role, sentence["value"])
#         conversations.append(conv.get_prompt())

#     # Tokenize conversations

#     if has_image:
#         input_ids = torch.stack([tokenizer_protein_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
#     else:
#         input_ids = tokenizer(
#             conversations,
#             return_tensors="pt",
#             padding="longest",
#             max_length=tokenizer.model_max_length,
#             truncation=True,
#         ).input_ids

#     targets = input_ids.clone()

#     assert conv.sep_style == SeparatorStyle.LLAMA_2

#     # Mask targets
#     sep = "[/INST] "
#     for conversation, target in zip(conversations, targets):
#         total_len = int(target.ne(tokenizer.pad_token_id).sum())

#         rounds = conversation.split(conv.sep2)
#         cur_len = 1
#         target[:cur_len] = IGNORE_INDEX
#         for i, rou in enumerate(rounds):
#             if rou == "":
#                 break

#             parts = rou.split(sep)
#             if len(parts) != 2:
#                 break
#             parts[0] += sep

#             if has_image:
#                 round_len = len(tokenizer_protein_token(rou, tokenizer))
#                 instruction_len = len(tokenizer_protein_token(parts[0], tokenizer)) - 2
#             else:
#                 round_len = len(tokenizer(rou).input_ids)
#                 instruction_len = len(tokenizer(parts[0]).input_ids) - 2

#             target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

#             cur_len += round_len
#         target[cur_len:] = IGNORE_INDEX

#         if cur_len < tokenizer.model_max_length:
#             if cur_len != total_len:
#                 target[:] = IGNORE_INDEX
#                 print(
#                     f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
#                     f" (ignored)"
#                 )

#     return dict(
#         input_ids=input_ids,
#         labels=targets,
#     )


# def preprocess_v1(
#     sources,
#     tokenizer: transformers.PreTrainedTokenizer,
#     has_image: bool = False
# ) -> Dict:
#     conv = default_conversation.copy()
#     roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

#     # Apply prompt templates
#     conversations = []
#     for i, source in enumerate(sources):
#         if roles[source[0]["from"]] != conv.roles[0]:
#             # Skip the first one if it is not from human
#             source = source[1:]

#         conv.messages = []
#         for j, sentence in enumerate(source):
#             role = roles[sentence["from"]]
#             assert role == conv.roles[j % 2], f"{i}"
#             conv.append_message(role, sentence["value"])
#         conversations.append(conv.get_prompt())

#     # Tokenize conversations

#     if has_image:
#         input_ids = torch.stack([tokenizer_protein_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
#     else:
#         input_ids = tokenizer(
#             conversations,
#             return_tensors="pt",
#             padding="longest",
#             max_length=tokenizer.model_max_length,
#             truncation=True,
#         ).input_ids

#     targets = input_ids.clone()

#     assert conv.sep_style == SeparatorStyle.TWO

#     # Mask targets
#     sep = conv.sep + conv.roles[1] + ": "
#     for conversation, target in zip(conversations, targets):
#         total_len = int(target.ne(tokenizer.pad_token_id).sum())

#         rounds = conversation.split(conv.sep2)
#         cur_len = 1
#         target[:cur_len] = IGNORE_INDEX
#         for i, rou in enumerate(rounds):
#             if rou == "":
#                 break

#             parts = rou.split(sep)
#             if len(parts) != 2:
#                 break
#             parts[0] += sep

#             if has_image:
#                 round_len = len(tokenizer_protein_token(rou, tokenizer))
#                 instruction_len = len(tokenizer_protein_token(parts[0], tokenizer)) - 2
#             else:
#                 round_len = len(tokenizer(rou).input_ids)
#                 instruction_len = len(tokenizer(parts[0]).input_ids) - 2

#             if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
#                 round_len -= 1
#                 instruction_len -= 1

#             target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

#             cur_len += round_len
#         target[cur_len:] = IGNORE_INDEX

#         if cur_len < tokenizer.model_max_length:
#             if cur_len != total_len:
#                 target[:] = IGNORE_INDEX
#                 print(
#                     f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
#                     f" (ignored)"
#                 )

#     return dict(
#         input_ids=input_ids,
#         labels=targets,
#     )


# def preprocess_mpt(
#     sources,
#     tokenizer: transformers.PreTrainedTokenizer,
#     has_image: bool = False
# ) -> Dict:
#     conv = default_conversation.copy()
#     roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

#     # Apply prompt templates
#     conversations = []
#     for i, source in enumerate(sources):
#         if roles[source[0]["from"]] != conv.roles[0]:
#             # Skip the first one if it is not from human
#             source = source[1:]

#         conv.messages = []
#         for j, sentence in enumerate(source):
#             role = roles[sentence["from"]]
#             assert role == conv.roles[j % 2], f"{i}"
#             conv.append_message(role, sentence["value"])
#         conversations.append(conv.get_prompt())

#     # Tokenize conversations

#     if has_image:
#         input_ids = torch.stack([tokenizer_protein_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
#     else:
#         input_ids = tokenizer(
#             conversations,
#             return_tensors="pt",
#             padding="longest",
#             max_length=tokenizer.model_max_length,
#             truncation=True,
#         ).input_ids

#     targets = input_ids.clone()
#     assert conv.sep_style == SeparatorStyle.MPT

#     # Mask targets
#     sep = conv.sep + conv.roles[1]
#     for conversation, target in zip(conversations, targets):
#         total_len = int(target.ne(tokenizer.pad_token_id).sum())

#         rounds = conversation.split(conv.sep)
#         re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
#         for conv_idx in range(3, len(rounds), 2):
#             re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))    # user + gpt
#         cur_len = 0
#         target[:cur_len] = IGNORE_INDEX
#         for i, rou in enumerate(re_rounds):
#             if rou == "":
#                 break

#             parts = rou.split(sep)
#             if len(parts) != 2:
#                 break
#             parts[0] += sep

#             if has_image:
#                 round_len = len(tokenizer_protein_token(rou, tokenizer))
#                 instruction_len = len(tokenizer_protein_token(parts[0], tokenizer)) - 1
#             else:
#                 round_len = len(tokenizer(rou).input_ids)
#                 instruction_len = len(tokenizer(parts[0]).input_ids) - 1

#             if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
#                 round_len += 1
#                 instruction_len += 1

#             target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

#             cur_len += round_len
#         target[cur_len:] = IGNORE_INDEX

#         if cur_len < tokenizer.model_max_length:
#             if cur_len != total_len:
#                 target[:] = IGNORE_INDEX
#                 print(
#                     f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
#                     f" (ignored)"
#                 )

#     return dict(
#         input_ids=input_ids,
#         labels=targets,
#     )


# def preprocess_plain(
#     sources: Sequence[str],
#     tokenizer: transformers.PreTrainedTokenizer,
# ) -> Dict:
#     # add end signal and concatenate together
#     conversations = []
#     for source in sources:
#         assert len(source) == 2
#         assert DEFAULT_IMAGE_TOKEN in source[0]['value']
#         source[0]['value'] = DEFAULT_IMAGE_TOKEN
#         conversation = source[0]['value'] + source[1]['value'] + default_conversation.sep
#         conversations.append(conversation)
#     # tokenize conversations
#     input_ids = [tokenizer_protein_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
#     targets = copy.deepcopy(input_ids)
#     for target, source in zip(targets, sources):
#         tokenized_len = len(tokenizer_protein_token(source[0]['value'], tokenizer))
#         target[:tokenized_len] = IGNORE_INDEX

#     return dict(input_ids=input_ids, labels=targets)




# def preprocess_plain(
#     sources: Sequence[str],
#     tokenizer: transformers.PreTrainedTokenizer,
# ) -> Dict:
#     # add end signal and concatenate together
#     conversations = []
#     for source in sources:
#         assert len(source) == 2
#         assert DEFAULT_IMAGE_TOKEN in source[0]['value']
#         source[0]['value'] = DEFAULT_IMAGE_TOKEN
#         conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
#         conversations.append(conversation)
#     # tokenize conversations
#     input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
#     targets = copy.deepcopy(input_ids)
#     for target, source in zip(targets, sources):
#         tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
#         target[:tokenized_len] = IGNORE_INDEX

#     return dict(input_ids=input_ids, labels=targets)

In [213]:
def preprocess_llama_2_protein(sources, tokenizer, has_protein=True) -> Dict:
    """Render conversations with the default template and build LLAMA_2-style masked labels.

    Args:
        sources: Conversations; each is a list of turns indexed by a "from" key
            ("human" or "gpt") and a "value" key holding the turn text.
        tokenizer: Called directly when ``has_protein`` is false and passed to
            ``tokenizer_protein_token`` otherwise. ``pad_token_id`` is always
            read; ``model_max_length`` only on the non-protein path.
        has_protein: When true, prompts are tokenized with
            ``tokenizer_protein_token`` so ``<seq>`` / ``<str>`` become sentinel ids.

    Returns:
        dict with ``input_ids`` (one row per conversation) and ``labels`` (a
        clone of ``input_ids`` in which instruction spans, position 0 and
        everything past the last processed round are set to ``IGNORE_INDEX``).

    Raises:
        AssertionError: if the turns do not alternate ``roles[0]`` / ``roles[1]``,
            or if the default conversation's ``sep_style`` is not ``LLAMA_2``.
    """
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    conversations = []

    for i, source in enumerate(sources):
        # Drop a leading turn that is not from the human side.
        if roles[source[0]["from"]] != conv.roles[0]:
            source = source[1:]
        # `conv` is reused across conversations, so its history is reset each time.
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    if has_protein:
        # NOTE(review): torch.stack needs every prompt to tokenize to the same
        # length; presumably callers pass one conversation at a time -- confirm.
        input_ids = torch.stack(
            [tokenizer_protein_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0
        )
    else:
        input_ids = tokenizer(conversations, return_tensors="pt", padding="longest",
                              max_length=tokenizer.model_max_length, truncation=True).input_ids

    targets = input_ids.clone()
    assert conv.sep_style == SeparatorStyle.LLAMA_2
    # Text that ends the instruction part of a round.
    sep = "[/INST] "

    for conversation, target in zip(conversations, targets):
        # Number of positions that differ from the pad id.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        rounds = conversation.split(conv.sep2)
        # Position 0 is always masked.
        # NOTE(review): starting at 1 looks like it assumes a leading BOS token,
        # but tokenizer_protein_token encodes with add_special_tokens=False --
        # confirm the alignment on the has_protein path.
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX

        for i, rou in enumerate(rounds):
            if not rou:
                break
            parts = rou.split(sep)
            # A round must split into exactly (instruction, answer); otherwise stop.
            if len(parts) != 2:
                break
            parts[0] += sep

            # NOTE(review): the "- 2" presumably discounts boundary tokens that
            # appear when the instruction is tokenized on its own -- confirm it
            # is still correct for both tokenization paths.
            if has_protein:
                round_len = len(tokenizer_protein_token(rou, tokenizer))
                instr_len = len(tokenizer_protein_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instr_len = len(tokenizer(parts[0]).input_ids) - 2

            # Mask the instruction; the answer tokens keep their ids as labels.
            target[cur_len:cur_len + instr_len] = IGNORE_INDEX
            cur_len += round_len

        # Everything after the last processed round is masked.
        target[cur_len:] = IGNORE_INDEX

        # A stricter variant used to blank the whole target on any mismatch:
        # if cur_len != total_len:
        #     target[:] = IGNORE_INDEX
        #     print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")


        # Now a mismatch of more than 2 positions is only reported; the labels
        # computed above are kept.
        if abs(cur_len - total_len) > 2:
            print(f"[W] Mismatch ignored: cur_len={cur_len}, total_len={total_len}")


    return dict(input_ids=input_ids, labels=targets)

def preprocess_v1_protein(sources, tokenizer, has_protein=True) -> Dict:
    """Render conversations with the default template and build TWO-style ("v1") masked labels.

    Args:
        sources: Conversations; each is a list of turns indexed by a "from" key
            ("human" or "gpt") and a "value" key holding the turn text.
        tokenizer: Called directly when ``has_protein`` is false and passed to
            ``tokenizer_protein_token`` otherwise. ``pad_token_id`` is always
            read; ``model_max_length`` only on the non-protein path.
        has_protein: When true, prompts are tokenized with
            ``tokenizer_protein_token`` so ``<seq>`` / ``<str>`` become sentinel ids.

    Returns:
        dict with ``input_ids`` (one row per conversation) and ``labels`` (a
        clone of ``input_ids`` in which instruction spans, position 0 and
        everything past the last processed round are set to ``IGNORE_INDEX``).

    Raises:
        AssertionError: if the turns do not alternate ``roles[0]`` / ``roles[1]``,
            or if the default conversation's ``sep_style`` is not ``TWO``.
    """
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    conversations = []

    for i, source in enumerate(sources):
        # Drop a leading turn that is not from the human side.
        if roles[source[0]["from"]] != conv.roles[0]:
            source = source[1:]
        # `conv` is reused across conversations, so its history is reset each time.
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    if has_protein:
        # NOTE(review): torch.stack needs every prompt to tokenize to the same
        # length; presumably callers pass one conversation at a time -- confirm.
        input_ids = torch.stack(
            [tokenizer_protein_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0
        )
    else:
        input_ids = tokenizer(conversations, return_tensors="pt", padding="longest",
                              max_length=tokenizer.model_max_length, truncation=True).input_ids

    targets = input_ids.clone()
    assert conv.sep_style == SeparatorStyle.TWO
    # Text that separates the instruction from the answer inside one round:
    # the primary separator followed by the second role name and ": ".
    sep = conv.sep + conv.roles[1] + ": "

    for conversation, target in zip(conversations, targets):
        # Number of positions that differ from the pad id.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        # Rounds are delimited by the secondary separator.
        rounds = conversation.split(conv.sep2)
        # Position 0 is always masked.
        # NOTE(review): starting at 1 looks like it assumes a leading BOS token,
        # but tokenizer_protein_token encodes with add_special_tokens=False --
        # confirm the alignment on the has_protein path.
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX

        for i, rou in enumerate(rounds):
            if not rou:
                break
            parts = rou.split(sep)
            # A round must split into exactly (instruction, answer); otherwise stop.
            if len(parts) != 2:
                break
            parts[0] += sep

            # NOTE(review): the "- 2" presumably discounts boundary tokens that
            # appear when the instruction is tokenized on its own -- confirm it
            # is still correct for both tokenization paths.
            if has_protein:
                round_len = len(tokenizer_protein_token(rou, tokenizer))
                instr_len = len(tokenizer_protein_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instr_len = len(tokenizer(parts[0]).input_ids) - 2

            # Mask the instruction; the answer tokens keep their ids as labels.
            target[cur_len:cur_len + instr_len] = IGNORE_INDEX
            cur_len += round_len

        # Everything after the last processed round is masked.
        target[cur_len:] = IGNORE_INDEX

        # A stricter variant used to blank the whole target on any mismatch:
        # if cur_len != total_len:
        #     target[:] = IGNORE_INDEX
        #     print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")


        # Now a mismatch of more than 2 positions is only reported; the labels
        # computed above are kept.
        if abs(cur_len - total_len) > 2:
            print(f"[W] Mismatch ignored: cur_len={cur_len}, total_len={total_len}")

    return dict(input_ids=input_ids, labels=targets)


def preprocess_mpt_protein(sources, tokenizer, has_protein=True) -> Dict:
    """Render conversations with the default template and build MPT-style masked labels.

    Args:
        sources: Conversations; each is a list of turns indexed by a "from" key
            ("human" or "gpt") and a "value" key holding the turn text.
        tokenizer: Called directly when ``has_protein`` is false and passed to
            ``tokenizer_protein_token`` otherwise. ``pad_token_id`` is always
            read; ``model_max_length`` only on the non-protein path.
        has_protein: When true, prompts are tokenized with
            ``tokenizer_protein_token`` so ``<seq>`` / ``<str>`` become sentinel ids.

    Returns:
        dict with ``input_ids`` (one row per conversation) and ``labels`` (a
        clone of ``input_ids`` in which instruction spans and everything past
        the last processed round are set to ``IGNORE_INDEX``).

    Raises:
        AssertionError: if the turns do not alternate ``roles[0]`` / ``roles[1]``,
            or if the default conversation's ``sep_style`` is not ``MPT``.
    """
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    conversations = []

    for i, source in enumerate(sources):
        # Drop a leading turn that is not from the human side.
        if roles[source[0]["from"]] != conv.roles[0]:
            source = source[1:]
        # `conv` is reused across conversations, so its history is reset each time.
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    if has_protein:
        # NOTE(review): torch.stack needs every prompt to tokenize to the same
        # length; presumably callers pass one conversation at a time -- confirm.
        input_ids = torch.stack(
            [tokenizer_protein_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0
        )
    else:
        input_ids = tokenizer(conversations, return_tensors="pt", padding="longest",
                              max_length=tokenizer.model_max_length, truncation=True).input_ids

    targets = input_ids.clone()
    assert conv.sep_style == SeparatorStyle.MPT
    # Text that separates the instruction from the answer inside one round:
    # the separator immediately followed by the second role name.
    sep = conv.sep + conv.roles[1]

    for conversation, target in zip(conversations, targets):
        # Number of positions that differ from the pad id.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        # Split on the single separator, then regroup the pieces into rounds:
        # the first round takes three pieces, later rounds take two each.
        rounds = conversation.split(conv.sep)
        re_rounds = [conv.sep.join(rounds[:3])]  # system + user + gpt
        for k in range(3, len(rounds), 2):
            re_rounds.append(conv.sep.join(rounds[k:k+2]))

        # Unlike the other variants, masking starts at position 0 with no
        # unconditional mask of the first token.
        cur_len = 0
        for i, rou in enumerate(re_rounds):
            if not rou:
                break
            parts = rou.split(sep)
            # A round must split into exactly (instruction, answer); otherwise stop.
            if len(parts) != 2:
                break
            parts[0] += sep

            # NOTE(review): the commented-out preprocess_mpt earlier in this
            # notebook subtracts 1 here and applies a legacy-tokenizer
            # adjustment; this version subtracts 2 and has no adjustment --
            # confirm which offset is intended.
            if has_protein:
                round_len = len(tokenizer_protein_token(rou, tokenizer))
                instr_len = len(tokenizer_protein_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instr_len = len(tokenizer(parts[0]).input_ids) - 2

            # Mask the instruction; the answer tokens keep their ids as labels.
            target[cur_len:cur_len + instr_len] = IGNORE_INDEX
            cur_len += round_len

        # Everything after the last processed round is masked.
        target[cur_len:] = IGNORE_INDEX

        # A stricter variant used to blank the whole target on any mismatch:
        # if cur_len != total_len:
        #     target[:] = IGNORE_INDEX
        #     print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")


        # Now a mismatch of more than 2 positions is only reported; the labels
        # computed above are kept.
        if abs(cur_len - total_len) > 2:
            print(f"[W] Mismatch ignored: cur_len={cur_len}, total_len={total_len}")

    return dict(input_ids=input_ids, labels=targets)


def preprocess_plain_protein(
    sources: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Build inputs and labels for two-turn (prompt, answer) samples.

    Each source must hold exactly two turns, and the first turn's ``'value'``
    must contain ``<seq>`` or ``<str>``. The training text is the first value,
    the second value and ``default_conversation.sep`` concatenated, tokenized
    with ``tokenizer_protein_token``.

    Returns:
        dict with ``input_ids`` (a list of 1-D tensors, one per source) and
        ``labels`` (a deep copy in which the positions covered by the tokenized
        first turn are set to ``IGNORE_INDEX``).

    Raises:
        AssertionError: if a source does not have two turns or its first turn
            lacks both placeholders.
    """
    prompts = []
    for pair in sources:
        assert len(pair) == 2
        question = pair[0]['value']
        assert DEFAULT_SEQ_TOKEN in question or DEFAULT_STR_TOKEN in question, \
            "Expected <seq> or <str> in the input."
        prompts.append(question + pair[1]['value'] + default_conversation.sep)

    # One tensor per prompt; lengths may differ, so they stay in a list.
    input_ids = [
        tokenizer_protein_token(text, tokenizer, return_tensors='pt')
        for text in prompts
    ]

    # Labels start as an independent copy, then lose the prompt part.
    targets = copy.deepcopy(input_ids)
    for label_row, pair in zip(targets, sources):
        question_len = len(tokenizer_protein_token(pair[0]['value'], tokenizer))
        label_row[:question_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)



In [214]:

def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string on its own and report ids plus non-pad lengths.

    Every string is encoded in a separate tokenizer call (so "longest" padding
    is relative to that single string), truncated to
    ``tokenizer.model_max_length``.

    Returns:
        dict with ``input_ids`` and ``labels`` (the SAME list object holding
        one 1-D tensor per string) and ``input_ids_lens`` / ``labels_lens``
        (the SAME list object holding, per string, the count of ids that differ
        from ``tokenizer.pad_token_id``).
    """
    per_text_ids = []
    per_text_lens = []
    for text in strings:
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        per_text_ids.append(encoded.input_ids[0])
        non_pad = encoded.input_ids.ne(tokenizer.pad_token_id)
        per_text_lens.append(non_pad.sum().item())

    # Inputs and labels deliberately share one list, as do the two length lists.
    return dict(
        input_ids=per_text_ids,
        labels=per_text_ids,
        input_ids_lens=per_text_lens,
        labels_lens=per_text_lens,
    )


# The following commented-out variant would be better adapted to the protein-aware tokenizer.

# def _tokenize_fn(strings: Sequence[str],
#                  tokenizer: transformers.PreTrainedTokenizer) -> Dict:
#     """Tokenize a list of strings using protein-aware tokenizer."""
#     input_ids = []
#     input_ids_lens = []

#     for text in strings:
#         ids = tokenizer_protein_token(text, tokenizer, return_tensors='pt')
#         input_ids.append(ids)
#         input_ids_lens.append(ids.ne(tokenizer.pad_token_id).sum().item())

#     return dict(
#         input_ids=input_ids,
#         labels=copy.deepcopy(input_ids),
#         input_ids_lens=input_ids_lens,
#         labels_lens=input_ids_lens,
#     )



# def _mask_targets(target, tokenized_lens, speakers):
#     cur_idx = tokenized_lens[0]  # system prompt
#     target[:cur_idx] = IGNORE_INDEX
#     tokenized_lens = tokenized_lens[1:]

#     for tokenized_len, speaker in zip(tokenized_lens, speakers):
#         if speaker == "human":
#             target[cur_idx:cur_idx + tokenized_len] = IGNORE_INDEX
#         cur_idx += tokenized_len

#     target[cur_idx:] = IGNORE_INDEX  # mask remainder if any


# def _add_speaker_and_signal(header, source, get_conversation=True):
#     """Add speaker tokens and signals to each sentence in the dialog."""
#     BEGIN_SIGNAL = "### "
#     END_SIGNAL = "\n"
#     conversation = header

#     for sentence in source:
#         from_str = sentence["from"].lower()
#         if from_str == "human":
#             from_str = default_conversation.roles[0]
#         elif from_str == "gpt":
#             from_str = default_conversation.roles[1]
#         else:
#             from_str = "unknown"

#         # Normalize multimodal tokens if needed
#         sentence["value"] = sentence["value"].replace("<SEQ>", "<seq>").replace("<STR>", "<str>")
#         sentence["value"] = BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL

#         if get_conversation:
#             conversation += sentence["value"]

#     conversation += BEGIN_SIGNAL
#     return conversation


def _mask_targets(target, tokenized_lens, speakers):
    """Overwrite, in place, the non-trainable positions of ``target``.

    Args:
        target: Label sequence supporting slice assignment of a scalar.
        tokenized_lens: Token counts; entry 0 is the header, the remaining
            entries line up with ``speakers`` one turn at a time.
        speakers: Speaker tag per turn; only the exact tag ``"human"`` is masked.

    The header span is always masked. For a human turn everything except its
    first two positions is masked.
    NOTE(review): the two-position offset is inherited from the original code;
    presumably it skips the begin-signal tokens - confirm against the tokenizer.
    """
    header_len, turn_lens = tokenized_lens[0], tokenized_lens[1:]
    target[:header_len] = IGNORE_INDEX

    offset = header_len
    for speaker, turn_len in zip(speakers, turn_lens):
        turn_end = offset + turn_len
        if speaker == "human":
            target[offset + 2:turn_end] = IGNORE_INDEX
        offset = turn_end



def _add_speaker_and_signal(header, source, get_conversation=True):
    """Wrap every turn with its speaker label and begin/end signals.

    Each ``sentence["value"]`` in ``source`` is rewritten IN PLACE to
    ``"### <role>: <text>\\n"``; callers that later read ``sentence["value"]``
    see the wrapped text. "human"/"gpt" (case-insensitive) map to
    ``default_conversation.roles[0]``/``roles[1]``; anything else is labelled
    ``'unknown'``.

    Returns:
        ``header`` followed by the wrapped turns (only when
        ``get_conversation`` is true) and a trailing ``"### "``.
    """
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"

    pieces = [header]
    for sentence in source:
        role_key = sentence["from"].lower()
        if role_key == "human":
            speaker = default_conversation.roles[0]
        elif role_key == "gpt":
            speaker = default_conversation.roles[1]
        else:
            speaker = 'unknown'

        # Stored back on the turn so later length computations see the same text.
        sentence["value"] = "".join(
            [BEGIN_SIGNAL, speaker, ": ", sentence["value"], END_SIGNAL]
        )
        if get_conversation:
            pieces.append(sentence["value"])

    pieces.append(BEGIN_SIGNAL)
    return "".join(pieces)


In [215]:
def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_protein: bool = True  # replaces has_image
) -> Dict:
    """Turn conversations into ``input_ids`` and masked ``labels``.

    Conversation styles with a dedicated preprocessor (PLAIN, LLAMA_2, "v1*",
    "mpt") are delegated to it, checked in that order. Otherwise the generic
    path runs:

    1. Prefix the system prompt and wrap each turn with '### ' ... '\n'
       (this rewrites the turns inside ``sources`` in place);
    2. Tokenize each concatenated conversation;
    3. Deep-copy the ids as labels and mask header/human spans with IGNORE_INDEX.

    Args:
        sources: One conversation per element; each conversation is indexed as
            a list of dicts with "from" and "value" keys.
        tokenizer: Decoder tokenizer.
        has_protein: Use the protein-aware tokenizer when true, ``_tokenize_fn``
            otherwise.

    Returns:
        ``{"input_ids": [...], "labels": [...]}``, one entry per conversation.
    """
    # Style-specific preprocessors take precedence over the generic path.
    sep_style = default_conversation.sep_style
    if sep_style == SeparatorStyle.PLAIN:
        return preprocess_plain_protein(sources, tokenizer)
    if sep_style == SeparatorStyle.LLAMA_2:
        return preprocess_llama_2_protein(sources, tokenizer, has_protein=has_protein)

    conv_version = default_conversation.version
    if conv_version.startswith("v1"):
        return preprocess_v1_protein(sources, tokenizer, has_protein=has_protein)
    if conv_version == "mpt":
        return preprocess_mpt_protein(sources, tokenizer, has_protein=has_protein)

    # Generic path: system header + speaker-tagged turns.
    system_header = f"{default_conversation.system}\n\n"
    conversations = [
        _add_speaker_and_signal(system_header, source) for source in sources
    ]

    if has_protein:
        input_ids = []
        for prompt in conversations:
            input_ids.append(
                tokenizer_protein_token(prompt, tokenizer, return_tensors='pt')
            )
    else:
        input_ids = _tokenize_fn(conversations, tokenizer)["input_ids"]

    targets = copy.deepcopy(input_ids)

    for target, source in zip(targets, sources):
        # The turn values were already wrapped above, so these segment lengths
        # refer to the same text that was tokenized into ``input_ids``.
        segments = [system_header]
        segments.extend(turn["value"] for turn in source)

        if has_protein:
            tokenized_lens = [
                len(tokenizer_protein_token(segment, tokenizer))
                for segment in segments
            ]
        else:
            tokenized_lens = _tokenize_fn(segments, tokenizer)["input_ids_lens"]

        _mask_targets(target, tokenized_lens, [turn["from"] for turn in source])

    return {"input_ids": input_ids, "labels": targets}

### Dataset object

In [None]:
class LazySupervisedProteinDataset(Dataset):
    """Protein multimodal dataset for instruction tuning.

    The file at ``data_path`` is read as JSON Lines: every non-blank line is
    parsed as one record. A record must contain ``"conversations"`` (a list of
    ``{"from", "value"}`` turns) and may contain ``"sequence"``,
    ``"structure_path"`` and ``"structure_chain"``.
    """

    # Keys whose presence marks a record as multimodal for length grouping.
    # "sequence"/"structure_path" are the keys __getitem__ actually reads;
    # "seq"/"str" are the names the length properties used to check and are
    # kept so records written with them are still recognised.
    PROTEIN_KEYS = ("sequence", "structure_path", "seq", "str")
    # Flat allowance added to ``lengths`` for multimodal records.
    PROTEIN_LENGTH_BONUS = 128

    def __init__(self,
                 data_path: str,
                 tokenizer: "transformers.PreTrainedTokenizer",
                 data_args,
                 seq_tower=None,
                 struc_tower=None):
        """
        Args:
            data_path: Path to the JSON Lines training file.
            tokenizer: Decoder tokenizer handed to ``preprocess``.
            data_args: Options forwarded to ``preprocess_multimodal``.
            seq_tower: Optional object exposing ``.tokenize()``; when None no
                ``seq_input_ids`` are produced.
            struc_tower: Optional object exposing ``.structure_processor()``;
                when None no ``struc_coords`` are produced.
        """
        super().__init__()
        with open(data_path, 'r', encoding='utf-8') as f:
            self.list_data_dict = [json.loads(line) for line in f if line.strip()]
        self.tokenizer = tokenizer
        self.data_args = data_args
        self.seq_tower = seq_tower  # Should have .tokenize()
        self.struc_tower = struc_tower  # Should have .structure_processor()

    def __len__(self):
        return len(self.list_data_dict)

    @classmethod
    def _has_protein(cls, sample) -> bool:
        """True when the record carries any protein sequence/structure key."""
        return any(key in sample for key in cls.PROTEIN_KEYS)

    @staticmethod
    def _text_length(sample) -> int:
        """Whitespace-separated word count summed over all turns."""
        return sum(len(conv['value'].split()) for conv in sample['conversations'])

    @property
    def lengths(self):
        """Approximate length per record: word count plus a multimodal bonus."""
        return [
            self._text_length(sample) +
            (self.PROTEIN_LENGTH_BONUS if self._has_protein(sample) else 0)
            for sample in self.list_data_dict
        ]

    @property
    def modality_lengths(self):
        """Word count per record, positive if multimodal and negative if text-only."""
        lengths = []
        for sample in self.list_data_dict:
            base_len = self._text_length(sample)
            lengths.append(base_len if self._has_protein(sample) else -base_len)
        return lengths

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Build one training example.

        Returns a dict with ``input_ids`` and ``labels``; ``seq_input_ids`` /
        ``seq_attention_mask`` are added when the record has ``"sequence"`` and
        a sequence tower is set; ``struc_coords`` is added when the record has
        ``"structure_path"`` and a structure tower is set (``None`` if loading
        the structure failed).
        """
        sample = self.list_data_dict[idx]
        # Deep copy: the preprocessing steps rewrite the turns in place.
        conversations = copy.deepcopy(sample["conversations"])
        # Step 1: Insert <seq>/<str> tokens
        sources = preprocess_multimodal([conversations], self.data_args)

        # Step 2: Tokenize conversation for decoder
        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_protein=('sequence' in sample or 'structure_path' in sample)
        )
        data_dict = {
            "input_ids": data_dict["input_ids"][0],
            "labels": data_dict["labels"][0]
        }

        # Step 3: Protein sequence processing
        if "sequence" in sample and self.seq_tower is not None:
            seq_tokenized = self.seq_tower.tokenize([sample["sequence"]],
                                                    return_tensors='pt', padding=True, truncation=True)
            data_dict["seq_input_ids"] = seq_tokenized["input_ids"][0]
            data_dict["seq_attention_mask"] = seq_tokenized["attention_mask"][0]

        # Step 4: Structure preprocessing (L, 3, 3)
        if "structure_path" in sample and self.struc_tower is not None:
            try:
                coords = self.struc_tower.structure_processor(
                    sample["structure_path"],
                    chain=sample.get("structure_chain", "A")
                )
                data_dict["struc_coords"] = coords  # tensor, will be moved in collator
            except Exception as e:
                # Best effort: keep the sample and let the collator fill the gap.
                print(f"[WARN] Structure loading failed for idx {idx}: {e}")
                data_dict["struc_coords"] = None
        return data_dict


In [244]:
class OPISupervisedDataset(Dataset):
    """
    Dataset over OPI-style records for Pannot multimodal instruction tuning.

    Every non-blank line of ``data_path`` is parsed as one JSON record with
    'instruction' and 'output', plus optional 'input' (protein sequence) and
    'structure'.
    """

    def __init__(self, data_path: str,
                 tokenizer,
                 data_args):
        super().__init__()
        records = []
        with open(data_path, 'r') as f:
            for line in f:
                if line.strip():
                    records.append(json.loads(line))

        self.list_data_dict = records
        self.tokenizer = tokenizer
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        """
        Approximate size of each record: whitespace word counts of the
        instruction and output, plus 4 for each of 'input' / 'structure'
        that is present (a rough stand-in for the <seq>/<str> placeholder).
        """
        return [
            len(sample['instruction'].split())
            + len(sample['output'].split())
            + 4 * (('input' in sample) + ('structure' in sample))
            for sample in self.list_data_dict
        ]

    @property
    def modality_lengths(self):
        """
        Same values as ``lengths``; every entry is non-negative, which the
        length-grouped sampler reads as "multimodal".
        """
        return self.lengths

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Return input_ids/labels plus the raw 'sequence'/'structure' if present."""
        sample = self.list_data_dict[idx]
        instruction, output = sample['instruction'], sample['output']
        sequence = sample.get('input', None)
        structure = sample.get('structure', None)
        has_sequence = sequence is not None
        has_structure = structure is not None

        # Human prompt: instruction, then one placeholder line per modality.
        prompt_lines = [instruction]
        if has_sequence:
            prompt_lines.append(DEFAULT_SEQ_TOKEN)
        if has_structure:
            prompt_lines.append(DEFAULT_STR_TOKEN)

        conversation = [
            {"from": "human", "value": '\n'.join(prompt_lines)},
            {"from": "gpt", "value": output},
        ]

        # Placeholder handling, then tokenization and label masking.
        sources = preprocess_multimodal([conversation], self.data_args)
        processed = preprocess(sources, self.tokenizer, has_protein=True)

        data_dict = dict(
            input_ids=processed["input_ids"][0],
            labels=processed["labels"][0],
        )
        # Raw modality payloads travel alongside the token ids.
        if has_sequence:
            data_dict["sequence"] = sequence
        if has_structure:
            data_dict["structure"] = structure

        return data_dict

In [245]:

@dataclass
class DataCollatorForSupervisedProteinDataset(object):
    """Pad a list of dataset examples into one batch.

    Attributes:
        tokenizer: Decoder tokenizer; its ``pad_token_id`` pads ``input_ids``.
        seq_pad_token_id: Padding id for ``seq_input_ids``. Those ids come from
            the sequence tower's own tokenizer, so pass that tokenizer's pad id
            here. When None, the decoder pad id is used (the previous
            behaviour).

    Returned keys: ``input_ids``, ``labels``, ``attention_mask`` and, when any
    instance provides them, ``seq_input_ids``/``seq_attention_mask`` and
    ``struc_coords``.
    """

    tokenizer: "transformers.PreTrainedTokenizer"
    seq_pad_token_id: Optional[int] = None

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        # Record true lengths before padding. Comparing against pad_token_id
        # would also hide genuine tokens whenever the pad id doubles as a real
        # token (e.g. pad_token == eos_token).
        lengths = torch.tensor([ids.shape[0] for ids in input_ids], dtype=torch.long)

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )

        positions = torch.arange(input_ids.shape[1], device=input_ids.device)
        attention_mask = positions.unsqueeze(0) < lengths.to(input_ids.device).unsqueeze(1)

        batch = {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }

        # Optional protein sequence features. Any instance may provide them;
        # instances without a sequence get an all-padding, fully masked row so
        # the batch dimension stays aligned with input_ids.
        seq_owner = next(
            (inst for inst in instances if inst.get("seq_input_ids") is not None), None
        )
        if seq_owner is not None:
            seq_pad = (self.tokenizer.pad_token_id if self.seq_pad_token_id is None
                       else self.seq_pad_token_id)
            empty_ids = seq_owner["seq_input_ids"].new_empty((0,))
            empty_mask = seq_owner["seq_attention_mask"].new_empty((0,))
            seq_input_ids, seq_attention_mask = [], []
            for inst in instances:
                if inst.get("seq_input_ids") is None:
                    seq_input_ids.append(empty_ids)
                    seq_attention_mask.append(empty_mask)
                else:
                    seq_input_ids.append(inst["seq_input_ids"])
                    seq_attention_mask.append(inst["seq_attention_mask"])
            batch["seq_input_ids"] = torch.nn.utils.rnn.pad_sequence(
                seq_input_ids, batch_first=True, padding_value=seq_pad
            )
            batch["seq_attention_mask"] = torch.nn.utils.rnn.pad_sequence(
                seq_attention_mask, batch_first=True, padding_value=0
            )

        # Optional structure features, padded along the first (length) axis
        # with NaN. A missing key is treated like an explicit None (failed
        # load): the instance contributes an all-NaN block.
        if any("struc_coords" in inst for inst in instances):
            coords = [inst.get("struc_coords") for inst in instances]
            available = [coord for coord in coords if coord is not None]
            max_len = max((coord.shape[0] for coord in available), default=0)

            coords_list = []
            for coord in coords:
                if coord is None:
                    if available:
                        # Match dtype/device/trailing dims of the real data.
                        ref = available[0]
                        padded = ref.new_full(
                            (max_len,) + tuple(ref.shape[1:]), float("nan")
                        )
                    else:
                        padded = torch.full((max_len, 3, 3), float("nan"))
                else:
                    # F.pad takes (left, right) pairs starting from the LAST
                    # dim, so the final pair addresses dim 0.
                    pad_spec = (0, 0) * (coord.dim() - 1) + (0, max_len - coord.shape[0])
                    padded = torch.nn.functional.pad(coord, pad_spec, value=float("nan"))
                coords_list.append(padded)

            batch["struc_coords"] = torch.stack(coords_list)

        return batch
# @dataclass
# class DataCollatorForSupervisedProteinDataset:
#     tokenizer: transformers.PreTrainedTokenizer

#     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
#         input_ids = [instance["input_ids"] for instance in instances]
#         labels = [instance["labels"] for instance in instances]

#         input_ids = torch.nn.utils.rnn.pad_sequence(
#             input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
#         )
#         labels = torch.nn.utils.rnn.pad_sequence(
#             labels, batch_first=True, padding_value=IGNORE_INDEX
#         )

#         batch = {
#             "input_ids": input_ids,
#             "labels": labels,
#             "attention_mask": input_ids.ne(self.tokenizer.pad_token_id),
#         }

#         # Optional: keep raw sequence/structure for encoder use
#         if "sequence" in instances[0]:
#             batch["sequences"] = [inst.get("sequence", None) for inst in instances]

#         if "structure" in instances[0]:
#             batch["structures"] = [inst.get("structure", None) for inst in instances]

#         return batch


In [246]:
@dataclass
class DataCollatorForOPI:
    """Collate OPI examples: pad, truncate and carry raw modality payloads.

    Attributes:
        tokenizer: Decoder tokenizer; supplies ``pad_token_id`` and
            ``model_max_length``.

    Returned keys: ``input_ids``, ``labels``, ``attention_mask`` and, when any
    instance carries them, ``sequences`` / ``structures`` (lists aligned with
    the batch, ``None`` where an instance lacks the field).
    """

    tokenizer: "transformers.PreTrainedTokenizer"

    def __call__(self, instances: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
        # Token IDs & Labels
        input_ids = [inst["input_ids"] for inst in instances]
        labels = [inst["labels"] for inst in instances]
        # True lengths, taken before padding. Deriving the mask from
        # ``input_ids != pad_token_id`` would also mask the real EOS token
        # when the pad id is shared with EOS.
        lengths = torch.tensor([ids.shape[0] for ids in input_ids], dtype=torch.long)

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )

        # Truncate to max model length
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]

        # Position p of row i is real iff p < length_i; positions beyond the
        # truncated width simply do not exist in ``positions``.
        positions = torch.arange(input_ids.shape[1], device=input_ids.device)
        attention_mask = positions.unsqueeze(0) < lengths.to(input_ids.device).unsqueeze(1)

        batch = {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }

        # Optional: keep raw sequence/structure for encoder use. Checked over
        # the whole batch so the result does not depend on instance order.
        if any("sequence" in inst for inst in instances):
            batch["sequences"] = [inst.get("sequence", None) for inst in instances]

        if any("structure" in inst for inst in instances):
            batch["structures"] = [inst.get("structure", None) for inst in instances]

        return batch

In [247]:
def make_supervised_data_module(tokenizer, data_args, seq_tower=None, struc_tower=None) -> Dict:
    """Build the train dataset and collator for supervised protein tuning.

    Args:
        tokenizer: Decoder tokenizer shared by the dataset and the collator.
        data_args: Must expose ``data_path``; forwarded to the dataset.
        seq_tower: Optional sequence tower forwarded to the dataset. Without it
            the dataset never emits ``seq_input_ids``.
        struc_tower: Optional structure tower forwarded to the dataset. Without
            it the dataset never emits ``struc_coords``.

    Returns:
        Dict with ``train_dataset``, ``eval_dataset`` (always None) and
        ``data_collator``.
    """
    train_dataset = LazySupervisedProteinDataset(
        tokenizer=tokenizer,
        data_path=data_args.data_path,
        data_args=data_args,
        seq_tower=seq_tower,
        struc_tower=struc_tower,
    )
    data_collator = DataCollatorForSupervisedProteinDataset(tokenizer=tokenizer)
    return dict(
        train_dataset=train_dataset,
        eval_dataset=None,
        data_collator=data_collator
    )

In [248]:
def make_opi_supervised_data_module(tokenizer, data_args) -> Dict:
    """Assemble the OPI training dataset together with its collator.

    ``data_args.data_path`` selects the file; there is no evaluation split, so
    ``eval_dataset`` is always None.
    """
    return {
        "train_dataset": OPISupervisedDataset(
            tokenizer=tokenizer,
            data_path=data_args.data_path,
            data_args=data_args,
        ),
        "eval_dataset": None,
        "data_collator": DataCollatorForOPI(tokenizer=tokenizer),
    }

In [249]:
from transformers import AutoModel, AutoTokenizer, PretrainedConfig
import torch.nn as nn
class ESMSeqTower(nn.Module):
    """Frozen protein sequence encoder loaded through ``AutoModel``.

    Args:
        model_name: Hugging Face checkpoint id for the encoder and tokenizer.
        args: Optional config object. Attributes read from it:
            ``protein_select_layer`` or ``mm_seq_select_layer`` (hidden-state
            index, default -1), ``protein_pooling`` ('cls' or 'mean', default
            'cls') and ``unfreeze_mm_seq_tower``.
        delay_load: Skip loading weights in ``__init__``; they are then loaded
            on first use.
        no_pooling: Return per-token hidden states instead of a pooled vector.

    NOTE(review): ``forward`` always runs under ``torch.no_grad()`` and
    ``load_model`` freezes the encoder, so ``unfreeze_mm_seq_tower`` currently
    only forces eager loading.
    """

    def __init__(
        self,
        model_name: str = 'facebook/esm2_t6_8M_UR50D',
        args=None,
        delay_load: bool = False,
        no_pooling: bool = False,
    ):
        super().__init__()
        self.is_loaded = False
        self.model_name = model_name
        self.args = args

        self.select_layer = self._resolve_select_layer(args)
        self.pooling = getattr(args, 'protein_pooling', 'cls')  # 'cls' or 'mean'
        self.no_pooling = no_pooling

        if not delay_load or getattr(args, 'unfreeze_mm_seq_tower', False):
            self.load_model()

    @staticmethod
    def _resolve_select_layer(args) -> int:
        """Pick the hidden-state index from ``args``.

        ``protein_select_layer`` keeps priority for existing callers;
        ``mm_seq_select_layer`` is the name ``ModelArguments`` defines, so it
        is honoured as well. Falls back to -1 (last layer).
        """
        for attr in ('protein_select_layer', 'mm_seq_select_layer'):
            value = getattr(args, attr, None)
            if value is not None:
                return value
        return -1

    def load_model(self, device_map=None):
        """Load tokenizer and encoder once; the encoder is frozen."""
        if self.is_loaded:
            print(f'{self.model_name} is already loaded. Skipping load.')
            return

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
        self.encoder = AutoModel.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            output_hidden_states=True,
            device_map=device_map
        )
        self.encoder.requires_grad_(False)
        self.is_loaded = True

    @torch.no_grad()
    def forward(self, input_ids, attention_mask):
        """Encode tokenized sequences.

        Returns:
            (B, L, D) hidden states when ``no_pooling`` is set, otherwise a
            (B, D) vector: the first token for 'cls', the mask-weighted
            average for 'mean'.

        Raises:
            ValueError: If ``self.pooling`` is neither 'cls' nor 'mean'.
        """
        if not self.is_loaded:
            self.load_model()

        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.hidden_states[self.select_layer]

        if self.no_pooling:
            return hidden_states  # (B, L, D)

        if self.pooling == 'cls':
            return hidden_states[:, 0, :]  # (B, D)
        elif self.pooling == 'mean':
            # Cast the mask to the activation dtype so the sum and the
            # division stay in floating point on every torch version.
            mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (B, L, 1)
            sum_emb = torch.sum(hidden_states * mask, dim=1)
            # Counts are whole numbers, so a floor of 1 only affects fully
            # masked rows (which then yield zeros) and cannot underflow in
            # half precision the way a tiny epsilon would.
            counts = mask.sum(dim=1).clamp(min=1.0)
            return sum_emb / counts
        else:
            raise ValueError(f"Unsupported pooling type: {self.pooling}")

    def tokenize(self, sequences, return_tensors='pt', padding=True, truncation=True, max_length=1024):
        """Tokenize raw sequences with the encoder's own tokenizer."""
        if not self.is_loaded:
            self.load_model()
        return self.tokenizer(
            sequences,
            return_tensors=return_tensors,
            padding=padding,
            truncation=truncation,
            max_length=max_length
        )

    @property
    def dummy_feature(self):
        """All-zero feature shaped like a single ``forward`` output."""
        if self.no_pooling:
            return torch.zeros(1, 1, self.hidden_size, device=self.device, dtype=self.dtype)
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.encoder.dtype if self.is_loaded else torch.get_default_dtype()

    @property
    def device(self):
        return next(self.encoder.parameters()).device if self.is_loaded else torch.device('cpu')

    @property
    def config(self):
        # Before loading, only the config is read via from_pretrained.
        return self.encoder.config if self.is_loaded else PretrainedConfig.from_pretrained(self.model_name)

    @property
    def hidden_size(self):
        return self.config.hidden_size

### Load OPI data

In [250]:
# # === Step 1: Create dummy OPI-like test data ===
# dummy_data = [
#     {
#         "instruction": "What is the function of this protein?",
#         "input": "MSEQNNTEMTFQIQRIYTKDISFEAPNAPHVFQKDWMA",
#         "structure": "dummy_structure_info",
#         "output": "It acts as a kinase inhibitor."
#     }
# ]

# with NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
#     json.dump(dummy_data, f)
#     dummy_json_path = f.name
# Location of the OPI demo split. Set the OPI_DEMO_PATH environment variable
# to point at your own copy; the original author's path stays as the default.
# The file is read one JSON record per line despite the .json suffix.
opi_demo_path = os.environ.get(
    "OPI_DEMO_PATH",
    "/home/yining_yang/Documents/lm/Pannot/data/OPI_full_1.61M_train_first_10000.json",
)


In [251]:
# Prepare the decoder-side tokenizer used by the dataset and collator below.
# The embeddings are ultimately meant for PannotLlamaForCausalLM; only the
# TinyLlama tokenizer is loaded here (the config load is commented out).

# === Step 2: Load the TinyLlama tokenizer ===
pretrained_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
# config = PannotConfig.from_pretrained(pretrained_model_name)

# Register the <seq>/<str> placeholders as special tokens.
tokenizer.add_tokens([DEFAULT_SEQ_TOKEN, DEFAULT_STR_TOKEN], special_tokens=True)
# Reuse EOS as the padding token. NOTE: any attention mask computed as
# `input_ids != pad_token_id` will then also mask genuine EOS tokens.
tokenizer.pad_token = tokenizer.eos_token



In [252]:
# Sanity check: the <seq>/<str> placeholders come back as negative sentinel
# ids (-330 and -360 in the output below) instead of vocabulary ids.
tokenizer_protein_token("This is a great protein I want to study: <seq> <str>", tokenizer)

[910,
 338,
 263,
 2107,
 26823,
 306,
 864,
 304,
 6559,
 29901,
 29871,
 -330,
 259,
 -360]

In [253]:
# Show the tokenizer repr to confirm the added special tokens and pad token.
tokenizer

LlamaTokenizerFast(name_or_path='TinyLlama/TinyLlama-1.1B-Chat-v1.0', vocab_size=32000, model_max_length=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '</s>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32000: AddedToken("<seq>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32001: AddedToken("<str>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [254]:
# Instantiate the sequence tower with its default ESM-2 checkpoint;
# no_pooling=True makes forward() return per-token hidden states (B, L, D).
# The OPI data module built below does not take this tower as an argument.
seq_tower = ESMSeqTower(no_pooling=True)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [255]:
from types import SimpleNamespace

In [256]:
# Minimal stand-in for the parsed data arguments.
data_args = SimpleNamespace(
    data_path=opi_demo_path,
    is_multimodal=True,
    use_seq_start_end=False,
    use_str_start_end=False
)

# === Step 3: Build the OPI dataset and collator via make_opi_supervised_data_module ===
data_module = make_opi_supervised_data_module(tokenizer, data_args)
dataset = data_module["train_dataset"]
collator = data_module["data_collator"]

In [257]:
# print("Prompt:", sample)
# print("Tokenized:", tokenizer(sample).input_ids)
# print("Expected total len:", sample)
# print("Calculated cur_len:", sample)

In [258]:
# === Step 4: Test a batch ===
# Fetch one processed example. For the OPI dataset it holds input_ids, labels
# and, when the record has an 'input' field, the raw 'sequence' string.
sample = dataset[0]
print("\n--- Sample Output ---")
print("input_ids:", sample["input_ids"])
print("labels:", sample["labels"])
print("sequence:", sample["sequence"])
# print("structure:", sample["structure"])

# Collate one batch
batch = collator([sample])
print("\n--- Collated Batch ---")
print("input_ids shape:", batch["input_ids"].shape)
print("labels shape:", batch["labels"].shape)
print("attention_mask shape:", batch["attention_mask"].shape)
# DataCollatorForOPI emits raw 'sequences'/'structures' rather than
# seq_input_ids/struc_coords, so the next two lines print the fallback strings.
print("seq_input_ids:", batch.get("seq_input_ids", "[seq_input_ids not returned by collator]"))
print("struc_coords:", batch.get("struc_coords", "[struc_coords not returned by collator]"))



--- Sample Output ---
input_ids: tensor([  319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116, 21082,
        20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,   322,
         1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155, 29889,
         3148,  1001, 29901, 29871,  -330, 29871,    13, 23084,   919,   278,
        13303, 29361,   310,   278,  1494, 26823, 15602,   319,  1799,  9047,
        13566, 29901,   319,  1195, 29877,   562,  2904, 29899, 29873, 29934,
         3521, 14710,   300,   559, 29936, 27884, 29899, 19672, 29936,  8045,
         3332,  3333, 29885, 29936, 21894,   559, 29936,   405,  1682,   280,
          327,   680, 29899, 19672, 29936, 14409,   262,   289,  2363,   948,
        26533,     2])
labels: tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,

In [259]:
# Display the full collated batch dict for inspection.
batch

{'input_ids': tensor([[  319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116, 21082,
          20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,   322,
           1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155, 29889,
           3148,  1001, 29901, 29871,  -330, 29871,    13, 23084,   919,   278,
          13303, 29361,   310,   278,  1494, 26823, 15602,   319,  1799,  9047,
          13566, 29901,   319,  1195, 29877,   562,  2904, 29899, 29873, 29934,
           3521, 14710,   300,   559, 29936, 27884, 29899, 19672, 29936,  8045,
           3332,  3333, 29885, 29936, 21894,   559, 29936,   405,  1682,   280,
            327,   680, 29899, 19672, 29936, 14409,   262,   289,  2363,   948,
          26533,     2]]),
 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -10

In [260]:
# Encode the instruction on its own for comparison with the sample's ids;
# the output below starts with id 1, the <s> (BOS) token.
tokenizer.encode("Predict the functional keywords of the following protein sequences")

[1, 21099, 919, 278, 13303, 29361, 310, 278, 1494, 26823, 15602]

In [261]:
# Decode the ids that precede the <seq> sentinel (-330) in the sample above:
# the system prompt followed by "USER: ".
tokenizer.decode([  319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116, 21082,
        20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,   322,
         1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155, 29889,
         3148,  1001, 29901, 29871])

"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "

In [262]:
# Decode the ids that follow the <seq> sentinel: the instruction, the
# "ASSISTANT:" tag and the answer terminated by </s>.
tokenizer.decode([29871,    13, 23084,   919,   278,
        13303, 29361,   310,   278,  1494, 26823, 15602,   319,  1799,  9047,
        13566, 29901,   319,  1195, 29877,   562,  2904, 29899, 29873, 29934,
         3521, 14710,   300,   559, 29936, 27884, 29899, 19672, 29936,  8045,
         3332,  3333, 29885, 29936, 21894,   559, 29936,   405,  1682,   280,
          327,   680, 29899, 19672, 29936, 14409,   262,   289,  2363,   948,
        26533,2])

'\nPredict the functional keywords of the following protein sequences ASSISTANT: Aminoacyl-tRNA synthetase; ATP-binding; Cytoplasm; Ligase; Nucleotide-binding; Protein biosynthesis</s>'

In [263]:
# Decode only the answer ids (the tail of the sample after "ASSISTANT:"),
# i.e. the span expected to remain unmasked in the labels.
tokenizer.decode([319,  1195, 29877,   562,  2904, 29899, 29873, 29934,
         3521, 14710,   300,   559, 29936, 27884, 29899, 19672, 29936,  8045,
         3332,  3333, 29885, 29936, 21894,   559, 29936,   405,  1682,   280,
          327,   680, 29899, 19672, 29936, 14409,   262,   289,  2363,   948,
        26533,     2])

'Aminoacyl-tRNA synthetase; ATP-binding; Cytoplasm; Ligase; Nucleotide-binding; Protein biosynthesis</s>'

In [None]:
# {
#   "conversations": [
#     {"from": "human", "value": "Here is a protein: <seq> and its structure: <str>. What is its function?"},
#     {"from": "gpt", "value": "This protein is likely involved in..." }
#   ],
#   "seq_feat": [ ... ],  // 1D or 2D float array from ESM2
#   "str_feat": [ ... ]   // 1D or 2D float array from ESM-IF1
# }