# ESM breakdown 
## ESM functionalities
ESM (ESM_1, ESM_1b, ESM_2): general protein language models   
ESMFold: structure prediction based on the language model   
ESM_1v: language model specialized for protein variants   
ESM_IF1: designs protein sequences conditioned on a protein backbone (similar to MPNN)   
High-level programming: combines the language model with MCMC to design sequences with high fitness   
We will start on ESM first.   

## ESM model 
1. Based on the BERT (Bidirectional Encoder Representations from Transformers) language model: masked language modelling, where a masked token is predicted from both its left and right context     
2. general architecture:     
2.1 input representation token embedding + position embedding    
2.2 Transformer encoding layer   
output of a layer can be used as input for a new layer ( Number of layers)   
within transformer encoder:   
MultiHeadAttention   
Residual connection + LayerNorm   
FeedForward NN   
Residual Connection + LayerNorm   
2.3 Take the updated embedding at each masked position and use a neural network to predict the token  

## Part I. Protein token representation

code: https://github.com/facebookresearch/esm/blob/main/esm/data.py#L14

In [1]:
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter 

# We simplify the ESM Alphabet class here.
# The real ESM class has many flags for different models and more flexible data handling: see data.ipynb
class Alphabet:
    """Minimal amino-acid vocabulary: five special tokens followed by 20 residues.

    Each token is a single entry in ``tokens``; its position in that list is
    its integer index.
    """

    def __init__(self):
        # Special tokens come first so that "<pad>" gets index 0.
        special_tokens = ["<pad>", "<cls>", "<eos>", "<unk>", "-"]
        residues = list("ARNDCQEGHILKMFPSTWYV")
        self.tokens = special_tokens + residues

        # Build both lookup directions in a single pass.
        self.token_to_idx = {}
        self.idx_to_token = {}
        for position, symbol in enumerate(self.tokens):
            self.token_to_idx[symbol] = position
            self.idx_to_token[position] = symbol

        # Indices of the special tokens used elsewhere.
        self.padding_idx = self.token_to_idx["<pad>"]
        self.cls_idx = self.token_to_idx["<cls>"]
        self.eos_idx = self.token_to_idx["<eos>"]
        self.unk_idx = self.token_to_idx["<unk>"]

    def encode(self, sequence):
        """Map every element of ``sequence`` to its index (unknown ones -> ``unk_idx``)."""
        lookup = self.token_to_idx.get
        fallback = self.unk_idx
        return [lookup(symbol, fallback) for symbol in sequence]

    def decode(self, indices):
        """Join the tokens for ``indices`` into one string, skipping padding indices."""
        pieces = []
        for index in indices:
            if index == self.padding_idx:
                continue
            pieces.append(self.idx_to_token[index])
        return "".join(pieces)

    def vocab_size(self):
        """Number of tokens in the vocabulary."""
        return len(self.tokens)





In [None]:
# Demo: tokenize one protein sequence and look up a 32-dimensional embedding per residue.
seq='MPPMLSGLLARLVKLLLGRHGSALHWRAAGAATVLLVIVLLAGSYLAVLAERGAPGAQLI'
tokenize=Alphabet()
token=torch.tensor(tokenize.encode(seq),dtype=torch.long)  # one integer index per character of seq

print(token[:5])
tokenize.vocab_size()  # return value is discarded here; the call is repeated on the next line
embed_tokens=nn.Embedding(tokenize.vocab_size(),32,padding_idx=tokenize.padding_idx) # lookup table [vocab_size, 32]; the padding row is initialised to zeros
x=embed_tokens(token)  # [len(seq), 32] -- no batch dimension in this demo
print(x.shape)


tensor([17, 19, 19, 17, 15])
torch.Size([60, 32])


## Part II. Positional embedding
There are three ways to do this, as far as I am aware:
1. Sinusoidal positional embedding from the original Transformer (https://github.com/facebookresearch/esm/blob/main/esm/modules.py#L260)
2. Rotary embedding (multiplicative instead of additive)
3. Learnable embedding

In [3]:
# Simplified version of ESM's SinusoidalPositionalEmbedding

import torch
import torch.nn as nn
import math

class SinusoidalPositionEmbedding(nn.Module):
    """Fixed (non-trainable) sinusoidal positional embedding added to the input.

    The table ``pe`` of shape ``(max_len, embedding_dim)`` is computed once in
    ``__init__``: even columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``, with ``freq = 10000 ** (-2i / embedding_dim)``.
    """

    def __init__(self, embedding_dim: int, max_len: int = 5000):
        """
        Args:
            embedding_dim (int): The size of each embedding vector. Odd sizes
                are supported; the last column then holds a sine term only.
            max_len (int): The maximum sequence length.
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.max_len = max_len

        # Positions as a column vector: (max_len, 1).
        position = torch.arange(0, max_len).unsqueeze(1)
        # One base frequency per (sin, cos) column pair: ceil(embedding_dim / 2) entries.
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2) * -(math.log(10000.0) / embedding_dim)
        )
        angles = position * div_term  # (max_len, ceil(embedding_dim / 2))

        pe = torch.zeros(max_len, embedding_dim)
        pe[:, 0::2] = torch.sin(angles)  # even columns
        # There are only floor(embedding_dim / 2) odd columns, so drop the
        # unpaired last frequency when embedding_dim is odd (previously this
        # assignment failed with a shape mismatch).
        pe[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])  # odd columns

        # Buffer: saved/moved with the module but not returned by parameters().
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add sinusoidal positional embeddings to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, embedding_dim).

        Returns:
            torch.Tensor: Input tensor with positional embeddings added.

        Raises:
            ValueError: If ``seq_len`` (``x.size(1)``) exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(f"Sequence length ({seq_len}) exceeds maximum length ({self.max_len}).")
        # (1, seq_len, embedding_dim) broadcasts over the batch dimension.
        return x + self.pe[:seq_len].unsqueeze(0)


In [4]:
# Learnable 
# padding handling
import torch.nn.functional as F
class LearnedPositionalEmbedding(nn.Embedding):
    """
    Trainable positional embedding table with a fixed maximum length.

    Position ids are derived from the token ids passed to ``forward``:
    non-padding tokens are numbered ``padding_idx + 1, padding_idx + 2, ...``
    from left to right, and padding tokens are mapped to ``padding_idx``.
    The table therefore has ``num_embeddings + padding_idx + 1`` rows when a
    ``padding_idx`` is given.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        # Reserve rows 0..padding_idx so that real positions start after padding_idx.
        table_rows = num_embeddings
        if padding_idx is not None:
            table_rows = num_embeddings + padding_idx + 1
        super().__init__(table_rows, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, input: torch.Tensor):
        """Return position embeddings for token ids of size [bsz x seqlen].

        Raises:
            ValueError: If seqlen exceeds ``max_positions``.
        """
        too_long = input.size(1) > self.max_positions
        if too_long:
            raise ValueError(
                f"Sequence length {input.size(1)} above maximum "
                f" sequence length of {self.max_positions}"
            )

        # 1 for real tokens, 0 for padding.
        not_pad = input.ne(self.padding_idx).int()
        # Running count of real tokens; multiplying by the mask zeroes padding slots.
        running_count = torch.cumsum(not_pad, dim=1).type_as(not_pad)
        positions = (running_count * not_pad).long() + self.padding_idx

        # weight, padding_idx, max_norm, ... are attributes provided by nn.Embedding.
        return F.embedding(
            positions,
            self.weight,
            padding_idx=self.padding_idx,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
        )

In [5]:
# Main idea: build a trainable matrix used as a lookup table, and assign each position its embedding by index lookup.

# Example parameters
num_embeddings = 10  # Maximum sequence length
embedding_dim = 4    # Embedding size
padding_idx = 0      # Index for padding token
batch_size = 2       # Number of samples in a batch (informational; not referenced below)
seq_length = 6       # Length of each sequence (informational; not referenced below)

# Create embedding layer
embedding = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)

# Input tensor of token ids: shape [batch_size, seq_length]
input_tensor = torch.tensor([[0, 1, 2, 3, 4, 5], 
                             [0, 1, 1, 2, 3, 3]])  # id 0 is padding; the other values only matter as "not padding"

# Forward pass
output = embedding(input_tensor)

# Shapes
print("self.weight shape:", embedding.weight.shape)  # [num_embeddings + padding_idx + 1, embedding_dim]
print("Input shape:", input_tensor.shape)            # [batch, seq_length]
print("Output shape:", output.shape)  # [batch, seq_length, embedding_dim]
print(output[0,0,:])   # first token is padding, so its lookup result is [0]*embed_dim   


self.weight shape: torch.Size([11, 4])
Input shape: torch.Size([2, 6])
Output shape: torch.Size([2, 6, 4])
tensor([0., 0., 0., 0.], grad_fn=<SliceBackward>)


## Rotary Embedding
reference: https://arxiv.org/abs/2104.09864  



In [6]:
## Rotary Embedding
#  Code from https://github.com/facebookresearch/esm/blob/main/esm/rotary_embedding.py


def rotate_half(x):
    """Split the last dimension in two halves (a, b) and return (-b, a)."""
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat([second_half.neg(), first_half], dim=-1)


def apply_rotary_pos_emb(x, cos, sin):
    """Rotate ``x`` with the given cos/sin tables.

    The tables are indexed as [1, positions, dim]; they are truncated to the
    sequence length of ``x`` (its second-to-last dimension) before use.
    """
    seq_len = x.shape[-2]
    cos_term = x * cos[:, :seq_len, :]
    sin_term = rotate_half(x) * sin[:, :seq_len, :]
    return cos_term + sin_term

class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embedding as introduced in RoFormer_ (Su et al.).

    Instead of adding a position vector to the input, queries and keys are
    rotated by position-dependent angles. The cos/sin tables for the current
    sequence length are cached on the module and rebuilt only when the length
    or the device changes.

    Related implementations: the Rotary Transformer repo_ and GPT-NeoX_.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
    .. warning: This module transforms existing q/k tensors; it does not
        create an embedding dimension of its own.
    """

    def __init__(self, dim: int, *_, **__):
        super().__init__()
        # Inverse frequencies, one per pair of channels (dim / 2 entries).
        # Stored as a buffer: part of the module state, but not a parameter.
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

        # Cache of the tables computed for the most recent sequence length.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        """Return ``(cos, sin)`` tables of shape [1, seq_len, dim] for ``x``.

        ``seq_len`` is read from ``x.shape[seq_dimension]``. Cached tables are
        reused when neither the sequence length nor the device has changed.
        """
        seq_len = x.shape[seq_dimension]

        # The length test runs first, so the device test is never evaluated
        # while the cache is still empty (None).
        stale = seq_len != self._seq_len_cached or self._cos_cached.device != x.device
        if stale:
            self._seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            # Outer product: angle[i, j] = position i * inverse frequency j.
            freqs = torch.einsum("i,j->ij", positions, self.inv_freq)
            # Duplicate so that both halves of the channel dimension share the same angles.
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, :, :]
            self._sin_cached = emb.sin()[None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotation to ``q`` and ``k``; the sequence length is taken from ``k.shape[-2]``."""
        cos_table, sin_table = self._update_cos_sin_tables(k, seq_dimension=-2)
        self._cos_cached, self._sin_cached = cos_table, sin_table

        rotated_q = apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached)
        rotated_k = apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached)
        return rotated_q, rotated_k

In [16]:
# Example of rotary embedding on a single sequence of 4 positions with 8 channels.
x = torch.tensor([
    [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],  # Position 1
     [0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6],  # Position 2
     [1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4],  # Position 3
     [2.5, 2.6, 2.7, 2.8, 2.9, 3.0, 3.1, 3.2]]  # Position 4
])
# The literal above already carries the batch dimension: x is [B=1, L=4, emb_dim=8].

rotary_test=RotaryEmbedding(dim=8)
print(rotary_test.inv_freq)
# _update_cos_sin_tables returns (cos, sin) in that order. The names used to be
# swapped here; the result was still right only because the swapped variables
# were passed positionally below.
cos, sin = rotary_test._update_cos_sin_tables(x)
print(cos.shape)  # [1, L, emb_dim]

updated_x=apply_rotary_pos_emb(x, cos, sin)
print(updated_x)

tensor([1.0000, 0.1000, 0.0100, 0.0010])
torch.Size([1, 4, 8])
tensor([[[ 0.1000,  0.2000,  0.3000,  0.4000,  0.5000,  0.6000,  0.7000,
           0.8000],
         [-0.6076,  0.8552,  1.0849,  1.1984,  1.4597,  1.4928,  1.5109,
           1.6012],
         [-2.6170,  1.3270,  1.8536,  1.9952,  0.6719,  2.5138,  2.3375,
           2.4040],
         [-2.8842,  1.5973,  2.6058,  2.7904, -2.5182,  3.6344,  3.1796,
           3.2084]]])


# Part III. Contact prediction
paper: https://www.biorxiv.org/content/10.1101/622803v4.full.pdf


In [3]:
# Use the attention matrices to predict residue-residue contacts (this head is trained with supervision).
# Remove the appended (<eos>) and prepended (<cls>) tokens through masking and slicing.
# Symmetrize each matrix and subtract the background (apc function).
# Concatenate the layers * heads matrices and regress all values for a residue pair into one value,
# then use an activation function (sigmoid) to convert it to a probability.
from torch import nn
from typing import Optional
def symmetrize(x):
    """Return ``x`` plus its transpose over the last two dimensions (used for contact prediction)."""
    transposed = x.transpose(-2, -1)
    return x + transposed


def apc(x):
    """Average product correction over the last two dimensions (used for contact prediction).

    For each matrix, subtracts ``row_sum * col_sum / total_sum`` from every entry.
    """
    row_sums = x.sum(-1, keepdim=True)
    col_sums = x.sum(-2, keepdim=True)
    total = x.sum((-1, -2), keepdim=True)

    # Divide in place on the temporary product to avoid one extra allocation.
    background = (row_sums * col_sums).div_(total)
    return x - background

class ContactPredictionHead(nn.Module):
    """Logistic regression over symmetrized, APC-corrected attention maps.

    ``forward`` strips the eos/bos rows and columns (when the alphabet uses
    them), flattens layers and heads into one feature axis, applies
    ``symmetrize`` and ``apc``, and maps the ``in_features`` values of every
    residue pair to a single sigmoid output.
    """

    def __init__(
        self,
        in_features: int,
        prepend_bos: bool,
        append_eos: bool,
        bias=True,
        eos_idx: Optional[int] = None,
    ):
        super().__init__()
        self.in_features = in_features
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        # The eos index is needed to build the eos mask in forward().
        missing_eos_idx = append_eos and eos_idx is None
        if missing_eos_idx:
            raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        """Return contact probabilities of shape B x T x T.

        ``attentions`` is unpacked as (batch, layers, heads, seqlen, seqlen);
        ``tokens`` is the matching B x seqlen token-id tensor.
        """
        if self.append_eos:
            # Zero every row/column that belongs to an eos token, then drop
            # the last row and column.
            not_eos = tokens.ne(self.eos_idx).to(attentions)
            pair_mask = not_eos.unsqueeze(1) * not_eos.unsqueeze(2)
            attentions = attentions * pair_mask[:, None, None, :, :]
            attentions = attentions[..., :-1, :-1]
        if self.prepend_bos:
            # Drop the first row and column (the prepended token).
            attentions = attentions[..., 1:, 1:]

        batch_size, layers, heads, seqlen, _ = attentions.size()
        # features: B x C x T x T with C = layers * heads
        features = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # Only the device is aligned with the regression weights here; the
        # original note says attentions are float32 and may need a dtype
        # conversion for float16 models.
        features = features.to(self.regression.weight.device)
        features = apc(symmetrize(features))
        # Move the feature axis last so nn.Linear acts on it: B x T x T x C.
        features = features.permute(0, 2, 3, 1)
        scores = self.regression(features).squeeze(3)
        return self.activation(scores)

In [None]:
# Convert the updated representation to logits
class RobertaLMHead(nn.Module):
    """Masked-language-modelling head: dense -> gelu -> layer norm -> vocabulary projection.

    ``weight`` is the projection matrix handed in by the caller; it is stored
    as-is, and only the output bias is created here.
    NOTE(review): ``ESM1bLayerNorm`` and ``gelu`` are not defined in the
    visible part of this file -- confirm they are in scope.
    """

    def __init__(self, embed_dim, output_dim, weight):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = ESM1bLayerNorm(embed_dim)
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        """Map ``features`` to logits using ``weight`` and the learned bias."""
        hidden = self.layer_norm(gelu(self.dense(features)))
        # Project onto the rows of ``weight`` and add the output bias.
        logits = F.linear(hidden, self.weight)
        return logits + self.bias

## Overall ESM2 architecture (unbelievably simple)

ESM1 and ESM1b are similar; they differ in details such as the positional embedding, scaling, and the layer-normalization method.    
Overall building blocks:  
1. Generate the token embedding with positional information (in ESM2 the rotary embedding is applied inside the transformer layers).   
2. Pass it through a loop of transformer layers, saving intermediate representations and attention matrices if needed.  
3. Apply layer norm to the updated representation and use it to predict logits for the masked positions.  
4. Use the attention matrices to predict contacts.   


In [None]:

class ESM2(nn.Module):
    """ESM-2 style masked protein language model with an attention-based contact head.

    ``forward`` embeds the tokens, runs them through a stack of
    ``TransformerLayer`` blocks (built with rotary embeddings enabled),
    applies a final layer norm and produces vocabulary logits with
    ``RobertaLMHead``. Optionally the attention maps of all layers are
    stacked and passed to ``ContactPredictionHead``.

    NOTE(review): ``esm``, ``Union``, ``TransformerLayer`` and
    ``ESM1bLayerNorm`` are not imported or defined in the visible part of
    this file -- confirm they are in scope before running this cell.
    """

    def __init__(
        self,
        num_layers: int = 33,
        embed_dim: int = 1280,
        attention_heads: int = 20,
        alphabet: Union[esm.data.Alphabet, str] = "ESM-1b",
        token_dropout: bool = True,
    ):
        """
        Args:
            num_layers: Number of transformer layers.
            embed_dim: Embedding / hidden size.
            attention_heads: Number of attention heads per layer.
            alphabet: An ``esm.data.Alphabet`` instance, or an architecture
                name resolved through ``Alphabet.from_architecture``.
            token_dropout: Whether ``forward`` zeroes mask-token embeddings
                and rescales the remaining ones.
        """
        super().__init__()
        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.attention_heads = attention_heads
        if not isinstance(alphabet, esm.data.Alphabet):
            alphabet = esm.data.Alphabet.from_architecture(alphabet)
        self.alphabet = alphabet
        # Vocabulary size as reported by the alphabet.
        self.alphabet_size = len(alphabet)
        # Indices of the special tokens, and whether bos/eos tokens are added.
        self.padding_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos
        self.token_dropout = token_dropout

        # Building the submodules in a separate method keeps __init__ short
        # and gives one place to switch between model variants.
        self._init_submodules()

    def _init_submodules(self):
        """Create the embedding table, transformer stack, contact head and LM head."""
        self.embed_scale = 1
        # Lookup table of shape [alphabet_size, embed_dim].
        self.embed_tokens = nn.Embedding(
            self.alphabet_size,
            self.embed_dim,
            padding_idx=self.padding_idx,
        )

        # Transformer stack; positional information comes from the rotary
        # embeddings enabled inside each layer.
        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    self.embed_dim,
                    4 * self.embed_dim,
                    self.attention_heads,
                    add_bias_kv=False,
                    use_esm1b_layer_norm=True,
                    use_rotary_embeddings=True,
                )
                for _ in range(self.num_layers)
            ]
        )

        # Contact head: one input feature per (layer, head) attention map.
        self.contact_head = ContactPredictionHead(
            self.num_layers * self.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
        )
        self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)

        # The LM head receives the embedding matrix itself as its projection
        # weight, so input embedding and output projection share parameters.
        self.lm_head = RobertaLMHead(
            embed_dim=self.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def forward(self, tokens, repr_layers=(), need_head_weights=False, return_contacts=False):
        """Run the model on a batch of token ids.

        Args:
            tokens: 2-D tensor of token ids, B x T.
            repr_layers: Iterable of layer indices whose representations are
                returned. 0 is the embedding output, ``i`` the output of the
                i-th layer; the entry for the last layer has the final layer
                norm applied.
            need_head_weights: Also return the stacked attention maps.
            return_contacts: Also return contact predictions (implies
                ``need_head_weights``).

        Returns:
            dict with ``"logits"`` and ``"representations"``, plus
            ``"attentions"`` and ``"contacts"`` when requested.
        """
        if return_contacts:
            need_head_weights = True

        assert tokens.ndim == 2
        padding_mask = tokens.eq(self.padding_idx)  # B, T

        x = self.embed_scale * self.embed_tokens(tokens)

        if self.token_dropout:
            # Zero the embeddings at mask-token positions.
            x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
            # x: B x T x C
            # Presumably 15% of positions selected x 80% replaced by the mask
            # token during training -- confirm against the training setup.
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)  # non-padding tokens per sequence
            mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths  # [B]
            # Rescale by the ratio of kept fractions (training vs. this input).
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        # Zero the embeddings at padding positions (padding_mask is always a
        # tensor at this point, so no None check is needed).
        x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        # Layers whose intermediate representation should be returned.
        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        if need_head_weights:
            attn_weights = []

        # (B, T, E) => (T, B, E): the layers are fed sequence-first tensors.
        x = x.transpose(0, 1)

        if not padding_mask.any():
            padding_mask = None

        # Run the transformer stack, saving intermediate results if requested.
        for layer_idx, layer in enumerate(self.layers):
            x, attn = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_head_weights=need_head_weights,
            )
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.transpose(0, 1)
            if need_head_weights:
                # (H, B, T, T) => (B, H, T, T)
                attn_weights.append(attn.transpose(1, 0))

        x = self.emb_layer_norm_after(x)
        x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)

        # The last hidden representation should have layer norm applied, so
        # overwrite the entry stored inside the loop. len(self.layers) is used
        # instead of the leaked loop variable, which is undefined for an
        # empty stack.
        last_layer = len(self.layers)
        if last_layer in repr_layers:
            hidden_representations[last_layer] = x
        x = self.lm_head(x)  # representation -> logits

        result = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # attentions: B x L x H x T x T
            attentions = torch.stack(attn_weights, 1)
            if padding_mask is not None:
                # 2-D mask that zeroes rows and columns of padding tokens.
                attention_mask = 1 - padding_mask.type_as(attentions)
                attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
                attentions = attentions * attention_mask[:, None, None, :, :]
            result["attentions"] = attentions
            if return_contacts:
                contacts = self.contact_head(tokens, attentions)
                result["contacts"] = contacts

        return result

    def predict_contacts(self, tokens):
        """Return only the ``"contacts"`` entry of ``forward(tokens, return_contacts=True)``."""
        return self(tokens, return_contacts=True)["contacts"]


## MSA Transformer
ESM also contains an MSA Transformer model, which takes a multiple sequence alignment (MSA) as input.
The overall model architecture is the same, except that the transformer layers use axial attention for efficiency.  
The axial transformer contains row attention and column attention (the same as in AlphaFold2).

