<a href="https://colab.research.google.com/github/Mattywonger/ESM3_tutorial/blob/main/ESM3_tokenization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Tokenization refers to the process of converting high-dimensional data into numerical representations.

For instance, in GPT-3 model, sentences are broken down into individual words. Individual words are further broken down into sub-words, also known as tokens. Then there is a dictionary that associates a numerical value with a token. For example, the word "My" might correspond to the number 1 and the word "name" corresponds to the number 16.

Then the phrase my name might be represented with the matrix (1,16)

In the ESM3 model, there are 6 inputs that needs to be tokenized:


1.   Sequence
2.   Structure
3.   Function Annotation
4.   Secondary Structure
5.   Surface Area Solvent Accessible (SASA)
6.   Residue Annotation



# Sequence Tokenization

In [1]:
pip install --upgrade pip

Collecting pip
  Downloading pip-24.2-py3-none-any.whl.metadata (3.6 kB)
Downloading pip-24.2-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-24.2


In [2]:
!pip install sentencepiece



In [3]:
!pip install tokenizers
!pip install transformers



In [4]:
from typing import Protocol, runtime_checkable


@runtime_checkable
class EsmTokenizerBase(Protocol):
    def encode(self, *args, **kwargs):
        ...

    def decode(self, *args, **kwargs):
        ...

    @property
    def mask_token(self) -> str:
        ...

    @property
    def mask_token_id(self) -> int:
        ...

    @property
    def bos_token(self) -> str:
        ...

    @property
    def bos_token_id(self) -> int:
        ...

    @property
    def eos_token(self) -> str:
        ...

    @property
    def eos_token_id(self) -> int:
        ...

    @property
    def pad_token(self) -> str:
        ...

    @property
    def pad_token_id(self) -> int:
        ...


In [5]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast


In [6]:
SEQUENCE_VOCAB = [
    "<cls>", "<pad>", "<eos>", "<unk>",
    "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
    "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
    "O", ".", "-", "|",
    "<mask>",
]


In [7]:
class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase):
    """
    Constructs an ESM tokenizer.
    """

    model_input_names = ["sequence_tokens", "attention_mask"]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        chainbreak_token="|",
        **kwargs,
    ):
        all_tokens = SEQUENCE_VOCAB
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # a character-level tokenizer is the same as BPE with no token merges
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
        special_tokens = [cls_token, pad_token, mask_token, eos_token, chainbreak_token]
        additional_special_tokens = [chainbreak_token]

        tokenizer.add_special_tokens(
            special_tokens,
        )

        # This is where we configure the automatic addition of special tokens when we call
        # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
        # sequences are merged if you want.
        tokenizer.post_processor = TemplateProcessing(  # type: ignore
            single="<cls> $A <eos>",
            special_tokens=[
                ("<cls>", tokenizer.token_to_id("<cls>")),
                ("<eos>", tokenizer.token_to_id("<eos>")),
            ],
        )
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
    @property
    def bos_token(self):
        return self.cls_token

    @property
    def bos_token_id(self):
        return self.cls_token_id


In [64]:
sequence_tokenizer = EsmSequenceTokenizer()
sequence1 = "AGCTGACCTGAAGTCCGATCGTAACTGGCATAGCGTATGCCGTACGTAGGCTACGATCGATAGCTGACCGT"
sequence2 = "ATGCTAGCTGACCGTACGTTAGCTAGCTGATCGTAGCTAGTCGATCGTAGCTGATCGTAGCTAGCTAGCTA"
print(len(sequence1))
print(len(sequence2))
sequence_tokens1 = sequence_tokenizer.encode(
        sequence1, add_special_tokens=True
    )
sequence_tokens2 = sequence_tokenizer.encode(
        sequence2, add_special_tokens=True
    )
sequence_2d = torch.stack((torch.tensor(sequence_tokens1), torch.tensor(sequence_tokens2)), dim=0)
print(sequence_2d.shape)

71
71
torch.Size([2, 73])


#Feature Encoding


In [81]:
import torch.nn as nn
import torch

In [82]:
sequence_embed = nn.Embedding(64, 1536)

In [83]:
sequence_after_embed = sequence_embed(sequence_2d)

In [84]:
print(sequence_after_embed.shape)

torch.Size([2, 73, 1536])


In [72]:
print(sequence_after_embed[0].shape)

torch.Size([73, 1536])


The Inputs to the transformer architecture must be

> "The input tensor of shape (batch_size, sequence_length, d_model)."


>  sequence_id (torch.Tensor): The sequence ID tensor of shape (batch_size, sequence_length).





# Transformer architecture

In [73]:
import functools
!pip install einops
import torch.nn.functional as F
import einops



### Multi-head attention class

In [74]:
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        bias: bool = False,
        qk_layernorm: bool = True,
    ):
        super().__init__()

        self.d_model = d_model
        self.n_heads = n_heads

        self.d_head = self.d_model // self.n_heads
        self.layernorm_qkv = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=bias)
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)

        if qk_layernorm:
            self.q_ln = nn.LayerNorm(d_model, bias=bias)
            self.k_ln = nn.LayerNorm(d_model, bias=bias)
        else:
            self.q_ln = nn.Identity()
            self.k_ln = nn.Identity()
    def forward(self,x,seq_id):
        qkv_BLD3 = self.layernorm_qkv(x)
        query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
        query_BLD, key_BLD = self.q_ln(query_BLD), self.k_ln(key_BLD)
        #query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD) ##WILL IMPLEMENT

        n_heads = self.n_heads
        reshaper = functools.partial(
            einops.rearrange, pattern="b s (h d) -> b h s d", h=n_heads
        )

        query_BHLD, key_BHLD, value_BHLD = map(
            reshaper, (query_BLD, key_BLD, value_BLD)
        )

        # Where True, enable participation in attention.
        mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
        mask_BHLL = mask_BLL.unsqueeze(1)

        context_BHLD = F.scaled_dot_product_attention(
            query_BHLD, key_BHLD, value_BHLD, mask_BHLL
        )
        context_BLD = einops.rearrange(context_BHLD, "b h s d -> b s (h d)")
        return self.out_proj(context_BLD)

##SwiGLU activation function

In [75]:
class SwiGLU(nn.Module):
    """
    SwiGLU activation function as an nn.Module, allowing it to be used within nn.Sequential.
    This module splits the input tensor along the last dimension and applies the SiLU (Swish)
    activation function to the first half, then multiplies it by the second half.
    """

    def __init__(self):
        super(SwiGLU, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return F.silu(x1) * x2

The chunk function splits the input into two halves along the last dimension.

Then SiLU activation is applied to X1 and multiplied element-wise with

In [76]:
def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    # set hidden dimesion to nearest multiple of 256 after expansion ratio
    return int(((expansion_ratio * d_model) + 255) // 256 * 256)

In [77]:
def swiglu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool):
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(
            d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=bias
        ),
        SwiGLU(),
        nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=bias),
    )

##Transformer Block

In [78]:
class UnifiedTransformerBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        v_heads: int | None = None,
        bias: bool = False,
        expansion_ratio: float = 4.0,
        residue_scaling_factor: float = 1,
        mask_and_zero_frameless: bool = False,
        qk_layernorm: bool = True,
    ):
        super().__init__()
        self.attn = MultiHeadAttention(
                d_model, n_heads, bias, qk_layernorm=qk_layernorm
            )
        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio, bias)
        self.scaling_factor = residue_scaling_factor


In [79]:
import math
class TransformerStack(nn.Module):
    """
    A stack of transformer blocks used in the ESM-3 model. Each block is a UnifiedTransformerBlock,
    which can either be geometric attention or standard multi-head attention.
"""
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        v_heads: int | None,
        n_layers: int,
        n_layers_geom: int = 1,
        scale_residue: bool = True,
        mask_and_zero_frameless: bool = False,
        bias: bool = False,
        qk_layernorm: bool = True,
        ffn_type: str = "swiglu",  # swiglu | gelu
        expansion_ratio: float = 8 / 3,
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                UnifiedTransformerBlock(
                    d_model,
                    n_heads,
                    v_heads=v_heads,
                    residue_scaling_factor=(
                        math.sqrt(n_layers / 36) if scale_residue else 1.0
                    ),
                    expansion_ratio=expansion_ratio,
                    mask_and_zero_frameless=mask_and_zero_frameless,
                    bias=bias,
                    qk_layernorm=qk_layernorm,
                )
                for i in range(n_layers)
            ]
        )
        self.norm = nn.LayerNorm(d_model, bias=False)

In [80]:
transformer = TransformerStack(
            1536,
            24,
            64,
            48,
        )

In [85]:
count = 0
for block in transformer.blocks:
    count += 1
    r1 = block.attn(sequence_after_embed,None)
    r3 =  block.ffn(r1)/block.scaling_factor
    x = r1+r3
    print(x.shape)


print(count)



AttributeError: 'NoneType' object has no attribute 'unsqueeze'

In [48]:
out = nn.Linear(1536,1536)


In [None]:
query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
query_BLD, key_BLD = self.q_ln(query_BLD), self.k_ln(key_BLD)
query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)



query_BHLD, key_BHLD, value_BHLD = map(
            reshaper, (query_BLD, key_BLD, value_BLD)
        )

        # Where True, enable participation in attention.
        mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
        mask_BHLL = mask_BLL.unsqueeze(1)

        context_BHLD = F.scaled_dot_product_attention(
            query_BHLD, key_BHLD, value_BHLD, mask_BHLL
        )
        context_BLD = einops.rearrange(context_BHLD, "b h s d -> b s (h d)")
        return self.out_proj(context_BLD)