# ESM breakdown 
## ESM functionalities
ESM (ESM_1, ESM_1b, ESM_2): general protein language models  
ESMFold: structure prediction based on the language model  
ESM_1v: language model specialized for protein variants  
ESM_IF1: designs protein sequences conditioned on a protein backbone (inverse folding, similar to ProteinMPNN)  
High-level programming: combines the language model with MCMC to design sequences with high fitness  
We will start on ESM first.   

## ESM model 
1. Based on the BERT (Bidirectional Encoder Representations from Transformers) language model: masked language modelling, where each masked token is predicted from both its left and right context  
2. general architecture:     
2.1 input representation token embedding + position embedding    
2.2 Transformer encoding layer   
output of a layer can be used as input for a new layer ( Number of layers)   
within transformer encoder:   
MultiHeadAttention   
Residual connection + LayerNorm  
FeedForward NN   
Residual Connection + LayerNorm   
2.3 Update the embedding at each masked position, then use a NN to predict the masked token  

## Part I. Protein token representation

code: https://github.com/facebookresearch/esm/blob/main/esm/data.py#L14

In [1]:
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter 

# We simplify the ESM Alphabet class here.
# The original ESM class has many flags for different models.
class Alphabet:
    """Minimal protein-token vocabulary, simplified from the ESM ``Alphabet``.

    The vocabulary holds four special tokens (``<pad>``, ``<cls>``, ``<eos>``,
    ``<unk>``), the gap symbol ``-`` and the 20 standard amino-acid letters.

    Attributes:
        tokens: Vocabulary list; a token's index in this list is its id.
        token_to_idx: Mapping from token string to integer id.
        idx_to_token: Mapping from integer id to token string.
        padding_idx, cls_idx, eos_idx, unk_idx: Ids of the special tokens.
    """

    def __init__(self):
        # Special tokens come first so their ids are small and stable.
        self.tokens = ["<pad>", "<cls>", "<eos>", "<unk>", "-",
                       "A", "R", "N", "D", "C", "Q", "E", "G",
                       "H", "I", "L", "K", "M", "F", "P", "S",
                       "T", "W", "Y", "V"]

        # Forward and reverse lookup tables.
        self.token_to_idx = {token: idx for idx, token in enumerate(self.tokens)}
        self.idx_to_token = {idx: token for idx, token in enumerate(self.tokens)}

        # Ids of the special tokens.
        self.padding_idx = self.token_to_idx["<pad>"]
        self.cls_idx = self.token_to_idx["<cls>"]
        self.eos_idx = self.token_to_idx["<eos>"]
        self.unk_idx = self.token_to_idx["<unk>"]

    def encode(self, sequence):
        """Convert a sequence of tokens to a list of integer ids.

        Args:
            sequence: Iterable of tokens. A plain string is iterated one
                character at a time, i.e. one amino-acid letter per token.

        Returns:
            List of ids; tokens missing from the vocabulary map to ``unk_idx``.
        """
        return [self.token_to_idx.get(token, self.unk_idx) for token in sequence]

    def decode(self, indices):
        """Convert ids back to a string, dropping padding positions.

        Args:
            indices: Iterable of ids. Elements may be Python ints or
                scalar-like objects exposing ``.item()`` (for example the
                elements of a 1-D ``torch`` tensor).

        Returns:
            The concatenated token strings of all non-padding ids.

        Raises:
            KeyError: If an id is not part of the vocabulary.
        """
        pieces = []
        for idx in indices:
            # Tensor elements hash by identity, so they would never match the
            # integer keys of ``idx_to_token``; unwrap them to plain ints.
            if hasattr(idx, "item"):
                idx = idx.item()
            if idx != self.padding_idx:
                pieces.append(self.idx_to_token[idx])
        return "".join(pieces)

    def vocab_size(self):
        """Return the total number of tokens in the vocabulary."""
        return len(self.tokens)





In [2]:
# Demo: tokenize one example protein sequence and embed every residue.
seq='MPPMLSGLLARLVKLLLGRHGSALHWRAAGAATVLLVIVLLAGSYLAVLAERGAPGAQLI'
tokenize=Alphabet()
# encode() maps each character to its vocabulary id; wrap the ids in a 1-D long tensor.
token=torch.tensor(tokenize.encode(seq),dtype=torch.long)

print(token[:5])
# NOTE(review): bare expression whose value is discarded here — presumably left over from interactive inspection.
tokenize.vocab_size()
# Trainable lookup table with one 32-dimensional vector per vocabulary token.
embed_tokens=nn.Embedding(tokenize.vocab_size(),32)
# token has no batch dimension, so x is (sequence length, 32).
x=embed_tokens(token)
print(x.shape)


tensor([17, 19, 19, 17, 15])
torch.Size([60, 32])


## Part II, positional embedding
There are three ways to do this, as far as I am aware:
1. Sinusoidal positional embedding from the original Transformer (https://github.com/facebookresearch/esm/blob/main/esm/modules.py#L260)
2. Rotary embedding (multiplicative instead of additive)
3. Learnable embedding

In [3]:
# Simplified SinusoidalPositionalEmbedding

import torch
import torch.nn as nn
import math

class SinusoidalPositionEmbedding(nn.Module):
    """Fixed (non-trainable) sinusoidal positional encoding.

    For position ``pos`` and frequency index ``i`` the table holds::

        pe[pos, 2i]     = sin(pos / 10000 ** (2i / embedding_dim))
        pe[pos, 2i + 1] = cos(pos / 10000 ** (2i / embedding_dim))

    The table is computed once in ``__init__`` and added to the input in
    ``forward``.
    """

    def __init__(self, embedding_dim: int, max_len: int = 5000):
        """
        Args:
            embedding_dim (int): The size of each embedding vector. Odd sizes
                are supported; the last column then only has a sine term.
            max_len (int): The maximum sequence length.
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.max_len = max_len

        # Column vector of positions, shape (max_len, 1).
        position = torch.arange(0, max_len).unsqueeze(1)
        # One base frequency per sine column, shape (ceil(embedding_dim / 2),).
        div_term = torch.exp(torch.arange(0, embedding_dim, 2) * -(math.log(10000.0) / embedding_dim))
        # Broadcast to the full angle table, shape (max_len, ceil(embedding_dim / 2)).
        angles = position * div_term

        pe = torch.zeros(max_len, embedding_dim)
        pe[:, 0::2] = torch.sin(angles)  # even columns
        # With an odd embedding_dim there is one fewer odd column than even
        # column, so trim the angle table to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])  # odd columns

        # A buffer follows the module across devices and is saved in the
        # state dict, but is not a parameter and therefore is never trained.
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add sinusoidal positional embeddings to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, embedding_dim).

        Returns:
            torch.Tensor: Input tensor with positional embeddings added.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(f"Sequence length ({seq_len}) exceeds maximum length ({self.max_len}).")
        # (1, seq_len, embedding_dim) broadcasts over the batch dimension.
        return x + self.pe[:seq_len].unsqueeze(0)


In [4]:
# Learnable 
# padding handling
import torch.nn.functional as F
class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.

    Two modes are supported:

    * ``padding_idx`` is an int: ``forward`` receives token ids, derives the
      positions from the padding mask and offsets them by ``padding_idx`` so
      that the padding row of the table is never used for a real token.
    * ``padding_idx`` is ``None``: ``forward`` receives the position ids
      themselves and looks them up directly.

    Attributes such as ``weight``, ``padding_idx``, ``max_norm``,
    ``norm_type``, ``scale_grad_by_freq`` and ``sparse`` are inherited from
    ``nn.Embedding``.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int]):
        """
        Args:
            num_embeddings: Maximum sequence length that can be embedded.
            embedding_dim: Size of each positional embedding vector.
            padding_idx: Id of the padding token, or ``None`` when explicit
                position ids are passed to ``forward``.
        """
        if padding_idx is not None:
            # Real positions start at padding_idx + 1, so the table needs
            # padding_idx + 1 extra rows in front of them.
            num_embeddings_ = num_embeddings + padding_idx + 1
        else:
            num_embeddings_ = num_embeddings
        super().__init__(num_embeddings_, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, input: torch.Tensor):
        """Look up positional embeddings for a batch.

        Args:
            input: Tensor of size [bsz x seqlen]. Token ids when
                ``padding_idx`` is set, position ids when it is ``None``.

        Returns:
            Tensor of size [bsz x seqlen x embedding_dim].

        Raises:
            ValueError: If ``seqlen`` exceeds ``max_positions``.
        """
        if input.size(1) > self.max_positions:
            raise ValueError(
                f"Sequence length {input.size(1)} above maximum "
                f" sequence length of {self.max_positions}"
            )
        if self.padding_idx is None:
            # No padding convention: the caller supplies position ids directly.
            positions = input
        else:
            # 1 for real tokens, 0 for padding.
            mask = input.ne(self.padding_idx).int()
            # The running count of real tokens gives 1-based positions;
            # multiplying by the mask resets padding slots to 0, and the
            # offset maps padding to padding_idx and real tokens to
            # padding_idx + 1, padding_idx + 2, ...
            positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

In [5]:
# Main idea: keep a trainable matrix as a lookup table and fetch each
# position's embedding by index.

# Example parameters
num_embeddings = 10  # Maximum sequence length
embedding_dim = 4    # Embedding size
padding_idx = 0      # Index for padding token
batch_size = 2       # Number of samples in a batch (descriptive only; not used below)
seq_length = 6       # Length of each sequence (descriptive only; not used below)

# Create embedding layer
embedding = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)

# Input tensor: shape [batch_size, seq_length]
input_tensor = torch.tensor([[0, 1, 2, 3, 4, 5], 
                             [0, 1, 1, 2, 3, 3]])  # Token ids; 0 is padding. Positions come from the padding mask, not from the token values.

# Forward pass
output = embedding(input_tensor)

# Shapes
print("self.weight shape:", embedding.weight.shape)  # [num_embeddings + padding_idx + 1, embedding_dim]
print("Input shape:", input_tensor.shape)            # [batch, seq_length]
print("Output shape:", output.shape)  # [batch, seq_length, embedding_dim]
print(output[0,0,:])   # first token is padding, so its embedding is [0]*embedding_dim


self.weight shape: torch.Size([11, 4])
Input shape: torch.Size([2, 6])
Output shape: torch.Size([2, 6, 4])
tensor([0., 0., 0., 0.], grad_fn=<SliceBackward>)


## Rotary Embedding
reference: https://arxiv.org/abs/2104.09864  



In [6]:
## Rotary Embedding
#  Code from https://github.com/facebookresearch/esm/blob/main/esm/rotary_embedding.py


def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    The last dimension is split into two chunks ``(a, b)`` and the result is
    ``(-b, a)`` concatenated along that same dimension.
    """
    front, back = torch.chunk(x, chunks=2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)


def apply_rotary_pos_emb(x, cos, sin):
    """Rotate ``x`` using precomputed cosine and sine tables.

    Both tables are cut along their second dimension to the sequence length
    of ``x`` (its second-to-last dimension) before being applied.

    Returns:
        ``x * cos + rotate_half(x) * sin``.
    """
    seq_len = x.shape[-2]
    cos_table = cos[:, :seq_len, :]
    sin_table = sin[:, :seq_len, :]

    straight = x * cos_table
    crossed = rotate_half(x) * sin_table
    return straight + crossed

class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embedding, following RoFormer_ (Su et al.).

    Queries and keys are rotated by position-dependent angles. The cosine and
    sine tables are built from a fixed inverse-frequency buffer and cached
    until the sequence length or the device changes.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
        (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
    """

    def __init__(self, dim: int, *_, **__):
        """
        Args:
            dim: Size of the last dimension of the tensors to rotate. Extra
                positional and keyword arguments are accepted and ignored.
        """
        super().__init__()
        # One inverse frequency per pair of channels, i.e. dim / 2 entries.
        exponents = torch.arange(0, dim, 2).float() / dim
        # A buffer is part of the module state but is not trainable.
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

        # Cache, so inputs of an already-seen length reuse the tables.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        """Return ``(cos, sin)`` tables for ``x``, rebuilding them if stale.

        Args:
            x: Tensor whose ``seq_dimension`` gives the sequence length.
            seq_dimension: Index of the sequence dimension of ``x``.

        Returns:
            Tuple ``(cos, sin)``, each of shape [1, seq_len, dim].
        """
        seq_len = x.shape[seq_dimension]

        # Tables are reusable only for the same length on the same device
        # (the device can change, for instance, due to tracing).
        cache_hit = seq_len == self._seq_len_cached and self._cos_cached.device == x.device
        if not cache_hit:
            self._seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            # Outer product: angle[i, j] = position i * inverse frequency j.
            angles = torch.einsum("i,j->ij", positions, self.inv_freq)
            # Duplicate so the table spans the full last dimension.
            table = torch.cat([angles, angles], dim=-1).to(x.device)

            self._cos_cached = table.cos().unsqueeze(0)
            self._sin_cached = table.sin().unsqueeze(0)

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding to queries ``q`` and keys ``k``.

        The tables are sized from the second-to-last dimension of ``k``.

        Returns:
            Tuple of the rotated ``q`` and the rotated ``k``.
        """
        cos_table, sin_table = self._update_cos_sin_tables(k, seq_dimension=-2)
        self._cos_cached, self._sin_cached = cos_table, sin_table

        rotated_q = apply_rotary_pos_emb(q, cos_table, sin_table)
        rotated_k = apply_rotary_pos_emb(k, cos_table, sin_table)
        return rotated_q, rotated_k

In [16]:
# Example of rotary embedding.
# Positions are 0-based: the table for position 0 is cos = 1, sin = 0, so the
# first row comes out unchanged.
x = torch.tensor([
    [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],  # Position 0
     [0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6],  # Position 1
     [1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4],  # Position 2
     [2.5, 2.6, 2.7, 2.8, 2.9, 3.0, 3.1, 3.2]]  # Position 3
])
# x already has shape [B=1, L=4, emb_dim=8], so no unsqueeze is needed.

rotary_test = RotaryEmbedding(dim=8)
print(rotary_test.inv_freq)
# _update_cos_sin_tables returns the tables in (cos, sin) order.
cos, sin = rotary_test._update_cos_sin_tables(x)
print(cos.shape)

updated_x = apply_rotary_pos_emb(x, cos, sin)
print(updated_x)

tensor([1.0000, 0.1000, 0.0100, 0.0010])
torch.Size([1, 4, 8])
tensor([[[ 0.1000,  0.2000,  0.3000,  0.4000,  0.5000,  0.6000,  0.7000,
           0.8000],
         [-0.6076,  0.8552,  1.0849,  1.1984,  1.4597,  1.4928,  1.5109,
           1.6012],
         [-2.6170,  1.3270,  1.8536,  1.9952,  0.6719,  2.5138,  2.3375,
           2.4040],
         [-2.8842,  1.5973,  2.6058,  2.7904, -2.5182,  3.6344,  3.1796,
           3.2084]]])
