# ESMFold breakdown
ESMFold takes the output of the ESM language model (sequence representation, attention matrix) as input
to the folding trunk, followed by the structure module, to predict the structure.  

Utility modules (misc.py):
Read-sequence function (encode_sequence): takes the input sequence, adds a chain_linker (optional, default 'G'*25), and adds a residue_index_offset (a jump in the residue index when residues are in different chains).
Returns: encoded (residue tokens), residx, linker_mask, chain_index

Read sequences in batch

output_to_pdb: input: the output of the model (in dict format); output: PDB files

Attention: computes the attention used in the Transformer

Dropout: 

SequenceToPair:

PairToSequence:





## This is how the different ESM models get loaded into the ESMFold model

In [2]:
def _load_model(model_name):
    """Build an ESMFold model and populate it from a checkpoint.

    Args:
        model_name: Either a path ending in ".pt" (read from local disk) or
            a model name that is resolved to a download URL under
            https://dl.fbaipublicfiles.com/fair-esm/models/.

    Returns:
        The ESMFold instance with the checkpoint weights loaded.

    Raises:
        RuntimeError: If the checkpoint lacks any parameter that the model
            expects and whose name does not start with "esm.".
    """
    is_local_file = model_name.endswith(".pt")
    if is_local_file:
        checkpoint = torch.load(str(Path(model_name)), map_location="cpu")
    else:
        url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
        checkpoint = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")

    # The checkpoint bundles the model configuration with the weights and biases.
    cfg = checkpoint["cfg"]["model"]
    model_state = checkpoint["model"]
    model = ESMFold(esmfold_config=cfg)

    # Verify compatibility before loading: every parameter the model expects
    # must be in the checkpoint, except those under the "esm." prefix
    # (presumably the ESM language-model weights are supplied separately --
    # confirm against the ESMFold constructor).
    absent_keys = set(model.state_dict().keys()) - set(model_state.keys())
    missing_essential_keys = [key for key in absent_keys if not key.startswith("esm.")]
    if missing_essential_keys:
        raise RuntimeError(f"Keys '{', '.join(missing_essential_keys)}' are missing.")

    # strict=False so that the tolerated "esm." gaps do not raise here.
    model.load_state_dict(model_state, strict=False)
    return model

## Triangular self-attention block
Inputs:
sequence_state: Shape (B, L, sequence_state_dim) (batch, sequence length, feature dimension).  
pairwise_state: Shape (B, L, L, pairwise_state_dim) (batch, pairwise interactions, feature dimension).  
mask: Optional (B, L) mask indicating valid sequence positions.  
![Folding Trunk](folding_trunk.png)

### Operation (very close to the AlphaFold2 Evoformer module):
### (used to update the sequence and pair representations by passing messages between the two)
### (when updating the pair representation, it uses triangular multiplication and triangular attention)
triangular multiplication: a way of message passing: if i, k interact and j, k interact, then i, j should interact. (This message passing is encoded through element-wise multiplication.)
triangular attention: simply a way to calculate inter-residue attention column-wise or row-wise
Sequence State Updates:
Applies pair-to-sequence projection (self.pair_to_sequence) to derive biases for attention.
Applies self-attention (self.seq_attention) with the computed biases.
Adds residual connections and applies an MLP (self.mlp_seq) to the sequence state.
Pairwise State Updates:

Updates pairwise state with sequence-to-pair projections (self.sequence_to_pair).
Applies triangular multiplications (outgoing and incoming).
Applies triangular attention (starting and ending nodes).
Adds residual connections and applies an MLP (self.mlp_pair) to the pairwise state.
Returns the updated sequence and pairwise states.




## We illustrate some technical details with the code below
1. how sequence info is passed to the pair representation: SequenceToPair
2. how pair info is passed to the sequence representation: PairToSequence

In [3]:
# Code for sequence_to_pair
import torch
import torch.nn as nn

class SequenceToPair(nn.Module):
    """Turn a per-residue sequence state into a pairwise state.

    Each position is layer-normalised and projected to two ``inner_dim``
    vectors (called q and k below). For every pair of positions the
    element-wise product and difference of those vectors are concatenated
    and projected to ``pairwise_state_dim``.
    """

    def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
        super().__init__()

        doubled_inner = inner_dim * 2
        self.layernorm = nn.LayerNorm(sequence_state_dim)
        self.proj = nn.Linear(sequence_state_dim, doubled_inner, bias=True)
        self.o_proj = nn.Linear(doubled_inner, pairwise_state_dim, bias=True)

        # Both projections start with zero bias.
        for layer in (self.proj, self.o_proj):
            nn.init.zeros_(layer.bias)

    def forward(self, sequence_state):
        """
        Inputs:
          sequence_state: B x L x sequence_state_dim

        Output:
          pairwise_state: B x L x L x pairwise_state_dim

        Intermediate state:
          B x L x L x 2*inner_dim
        """
        assert sequence_state.dim() == 3

        projected = self.proj(self.layernorm(sequence_state))
        q, k = torch.chunk(projected, 2, dim=-1)

        # Entry [b, i, j] pairs k at position i with q at position j.
        q_rows = q.unsqueeze(1)  # B x 1 x L x inner_dim
        k_cols = k.unsqueeze(2)  # B x L x 1 x inner_dim

        pair_features = torch.cat((q_rows * k_cols, q_rows - k_cols), dim=-1)
        return self.o_proj(pair_features)
    


In [4]:
# This is the code for pair to sequence
# It uses pair info as a bias for the transformer attention matrix
class PairToSequence(nn.Module):
    """Project the pairwise state to one scalar per attention head.

    The pairwise state is layer-normalised and passed through a bias-free
    linear layer with ``num_heads`` outputs.
    """

    def __init__(self, pairwise_state_dim, num_heads):
        super().__init__()

        self.layernorm = nn.LayerNorm(pairwise_state_dim)
        self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)

    def forward(self, pairwise_state):
        """
        Inputs:
          pairwise_state: B x L x L x pairwise_state_dim

        Output:
          pairwise_bias: B x L x L x num_heads
        """
        assert pairwise_state.dim() == 4
        normalised = self.layernorm(pairwise_state)
        return self.linear(normalised)


In [5]:

# Customized Dropout so that the same dropout mask is consistently applied along a chosen dimension
# for example, along dim 0, so that every sample in the batch gets the same dropout mask, to improve learning
import typing as T
class Dropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask
    along a particular dimension.

    The mask is sampled by applying ``nn.Dropout`` to a tensor of ones whose
    size is 1 along every dimension listed in ``batch_dim``; multiplying the
    input by that tensor broadcasts the same mask along those dimensions.

    Args:
        r: Dropout probability handed to ``nn.Dropout``.
        batch_dim: Dimension index, or list of indices, along which the mask
            is shared. ``None`` samples an independent mask per element.
    """

    def __init__(self, r: float, batch_dim: T.Optional[T.Union[int, T.List[int]]]):
        super().__init__()

        self.r = r
        # isinstance (rather than an exact type comparison) so that int
        # subclasses are also wrapped into a list; otherwise forward() would
        # fail trying to iterate over a bare integer.
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied by a dropout mask shared along ``batch_dim``."""
        mask_shape = list(x.shape)
        if self.batch_dim is not None:
            for dim in self.batch_dim:
                mask_shape[dim] = 1
        # new_ones keeps the dtype and device of x; the product broadcasts the
        # size-1 dimensions of the mask over x.
        return x * self.dropout(x.new_ones(mask_shape))

In [6]:
# example of dropout function

import torch
import torch.nn as nn

# Toy input holding the values 1..24 (batch size = 4, sequence length = 3, feature dim = 2)
x = torch.arange(1.0, 25.0).reshape(4, 3, 2)

# Dropout layer with r=0.5 whose mask is shared along the batch dimension (dim 0)
dropout_layer = Dropout(r=0.5, batch_dim=0)

# Seed the RNG before sampling the mask so the result is reproducible
torch.manual_seed(42)

# Run the input through the layer
output = dropout_layer(x)

print("Input Tensor:", x, sep="\n")
print("\nOutput Tensor:", output, sep="\n")

Input Tensor:
tensor([[[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.]],

        [[ 7.,  8.],
         [ 9., 10.],
         [11., 12.]],

        [[13., 14.],
         [15., 16.],
         [17., 18.]],

        [[19., 20.],
         [21., 22.],
         [23., 24.]]])

Output Tensor:
tensor([[[ 2.,  4.],
         [ 6.,  8.],
         [ 0., 12.]],

        [[14., 16.],
         [18., 20.],
         [ 0., 24.]],

        [[26., 28.],
         [30., 32.],
         [ 0., 36.]],

        [[38., 40.],
         [42., 44.],
         [ 0., 48.]]])


In [7]:
# Special Attention unit
# 1. adds a bias from pair info 2. with gating option 3. with masking option
class Attention(nn.Module):
    """Multi-head self-attention with optional pairwise bias, padding mask and gating.

    Args:
        embed_dim: Width of the input and output features; must equal
            ``num_heads * head_width``.
        num_heads: Number of attention heads.
        head_width: Feature width of each head.
        gated: If True, the attention output is multiplied element-wise by
            ``sigmoid(g_proj(x))`` before the output projection.
    """

    def __init__(self, embed_dim, num_heads, head_width, gated=False):
        super().__init__()
        assert embed_dim == num_heads * head_width

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_width = head_width

        # One fused projection produces q, k and v for every head.
        self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.gated = gated
        if gated:
            self.g_proj = nn.Linear(embed_dim, embed_dim)
            # Zero weight and unit bias: the gate starts out as the constant
            # sigmoid(1), independent of the input.
            torch.nn.init.zeros_(self.g_proj.weight)
            torch.nn.init.ones_(self.g_proj.bias)

        self.rescale_factor = self.head_width**-0.5

        torch.nn.init.zeros_(self.o_proj.bias)

    def forward(self, x, mask=None, bias=None, indices=None):
        """
        Basic self attention with optional mask and external pairwise bias.
        To handle sequences of different lengths, use mask.

        Inputs:
          x: batch of input sequences (.. x L x C)
          mask: batch of boolean masks where 1=valid, 0=padding position (.. x L_k). optional.
          bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads). optional.
          indices: accepted for interface compatibility; not used by this implementation.

        Outputs:
          sequence projection (.. x L x embed_dim), attention maps (.. x Lq x Lk x num_heads)
        """
        # (.. x L x 3*embed_dim) -> (.. x L x H x 3*head_width) -> (.. x H x L x 3*head_width)
        fused = self.proj(x)
        fused = fused.reshape(*fused.shape[:-1], self.num_heads, 3 * self.head_width)
        q, k, v = fused.transpose(-3, -2).chunk(3, dim=-1)

        q = self.rescale_factor * q
        a = torch.einsum("...qc,...kc->...qk", q, k)  # .. x H x Lq x Lk

        # Add external attention bias, moving its head axis in front of (Lq, Lk).
        if bias is not None:
            a = a + bias.movedim(-1, -3)

        # Do not attend to padding tokens. The key mask is broadcast over
        # heads and query positions instead of being materialised per head.
        if mask is not None:
            padding = mask[..., None, None, :].logical_not()
            a = a.masked_fill(padding, float("-inf"))

        a = torch.softmax(a, dim=-1)

        y = torch.einsum("...hqk,...hkc->...qhc", a, v)
        y = y.flatten(-2)  # merge (H, head_width) back into embed_dim

        if self.gated:
            y = self.g_proj(x).sigmoid() * y
        y = self.o_proj(y)

        # `a` is laid out as (.. x H x Lq x Lk); move the head axis last so the
        # returned map has the documented (.. x Lq x Lk x H) layout.
        return y, a.movedim(-3, -1)



In [1]:
# distogram (used for calculating the training loss)
# calculate the CB position based on the CA, N, C positions
# calculate the inter-residue CB distances, then bin the distances (output [..., L, L])
# for example, a distance of 2.5 with bin boundaries [1, 2, 3, 4] gives output 2 (it exceeds two boundaries)

def distogram(coords, min_bin, max_bin, num_bins):
    """Bin pairwise CB-CB distances computed from backbone coordinates.

    Args:
        coords: Tensor of shape [..., L, 3, 3] holding the xyz coordinates of
            the N, CA and C atoms (in that order) of each residue.
        min_bin: Smallest bin boundary (as a distance).
        max_bin: Largest bin boundary (as a distance).
        num_bins: Number of bins; ``num_bins - 1`` evenly spaced boundaries
            are placed between ``min_bin`` and ``max_bin``.

    Returns:
        Integer tensor of shape [..., L, L]; each entry counts how many
        boundaries the corresponding CB-CB distance exceeds.
    """
    # Compare squared distances against squared boundaries (no square root needed).
    squared_edges = torch.linspace(min_bin, max_bin, num_bins - 1, device=coords.device).pow(2)

    n_xyz, ca_xyz, c_xyz = (part.squeeze(-2) for part in torch.chunk(coords, 3, dim=-2))

    # Infer the CB position from the backbone using fixed coefficients.
    n_to_ca = ca_xyz - n_xyz
    ca_to_c = c_xyz - ca_xyz
    normal = torch.cross(n_to_ca, ca_to_c, dim=-1)
    cb_xyz = -0.58273431 * normal + 0.56802827 * n_to_ca - 0.54067466 * ca_to_c + ca_xyz

    # All-pairs squared distances, keeping a trailing axis of size 1 so the
    # comparison broadcasts against the boundary vector.
    offsets = cb_xyz.unsqueeze(-3) - cb_xyz.unsqueeze(-2)
    squared_dists = offsets.pow(2).sum(dim=-1, keepdim=True)

    return (squared_dists > squared_edges).sum(dim=-1)  # [..., L, L]

In [None]:
# We need to inspect the structure module: how the seq and pair representations are used for building the 3D structure
# this structure module comes from OpenFold: https://github.com/aqlaboratory/openfold/blob/main/openfold/model/structure_module.py
# a detailed explanation of the structure module comes from the AlphaFold2 paper (Algorithm 20): https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf


In [None]:
# Assemble all components
class FoldingTrunk(nn.Module):
    """Stack of TriangularSelfAttentionBlock layers followed by a structure module, with recycling.

    Keyword arguments are forwarded to ``FoldingTrunkConfig``. The forward
    pass runs the block stack and the structure module repeatedly, feeding
    the previous iteration's sequence state, pairwise state and binned
    inter-residue distances back into the next one.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.cfg = FoldingTrunkConfig(**kwargs)
        assert self.cfg.max_recycles > 0

        c_s = self.cfg.sequence_state_dim
        c_z = self.cfg.pairwise_state_dim

        assert c_s % self.cfg.sequence_head_width == 0
        assert c_z % self.cfg.pairwise_head_width == 0
        block = TriangularSelfAttentionBlock

        self.pairwise_positional_embedding = RelativePosition(self.cfg.position_bins, c_z)

        self.blocks = nn.ModuleList(
            [
                block(
                    sequence_state_dim=c_s,
                    pairwise_state_dim=c_z,
                    sequence_head_width=self.cfg.sequence_head_width,
                    pairwise_head_width=self.cfg.pairwise_head_width,
                    dropout=self.cfg.dropout,
                )
                for _ in range(self.cfg.num_blocks)
            ]
        )

        # Recycling inputs: layer norms for the previous states and an
        # embedding of the binned distances from the previous structure.
        self.recycle_bins = 15
        self.recycle_s_norm = nn.LayerNorm(c_s)
        self.recycle_z_norm = nn.LayerNorm(c_z)
        self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
        # Bin 0 starts with an all-zero embedding; the first iteration uses
        # all-zero bins, so it initially adds nothing to the pairwise state.
        self.recycle_disto.weight[0].detach().zero_()

        self.structure_module = StructureModule(**self.cfg.structure_module)  # type: ignore
        # Linear adapters from the trunk widths to the structure module widths.
        self.trunk2sm_s = nn.Linear(c_s, self.structure_module.c_s)
        self.trunk2sm_z = nn.Linear(c_z, self.structure_module.c_z)

        self.chunk_size = self.cfg.chunk_size

    def set_chunk_size(self, chunk_size):
        # This parameter means the axial attention will be computed
        # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
        # It's equivalent to running a for loop over chunks of the dimension we're iterative over,
        # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-lengthed chunks.
        self.chunk_size = chunk_size

    def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles: T.Optional[int] = None):
        """
        Inputs:
          seq_feats:     B x L x C            tensor of sequence features
          pair_feats:    B x L x L x C        tensor of pair features
          true_aa:       passed through unchanged to the structure module
                         (presumably the B x L amino-acid indices -- confirm against StructureModule)
          residx:        B x L                long tensor giving the position in the sequence
          mask:          B x L                boolean tensor indicating valid residues
          no_recycles:   optional int >= 0; the trunk then runs no_recycles + 1 iterations.
                         If None, it runs cfg.max_recycles iterations.

        Output:
          The dict returned by the structure module on the last iteration, with
          two extra entries: "s_s" (final sequence state) and "s_z" (final pairwise state).
        """

        device = seq_feats.device
        s_s_0 = seq_feats
        s_z_0 = pair_feats

        if no_recycles is None:
            no_recycles = self.cfg.max_recycles
        else:
            assert no_recycles >= 0, "Number of recycles must not be negative."
            no_recycles += 1  # First 'recycle' is just the standard forward pass through the model.

        def trunk_iter(s, z, residx, mask):
            # One pass through the block stack, after adding the
            # relative-position embedding to the pairwise state.
            z = z + self.pairwise_positional_embedding(residx, mask=mask)

            for block in self.blocks:
                s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
            return s, z

        s_s = s_s_0
        s_z = s_z_0
        # The first iteration recycles all-zero states and all-zero distance bins.
        recycle_s = torch.zeros_like(s_s)
        recycle_z = torch.zeros_like(s_z)
        recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)

        assert no_recycles > 0
        for recycle_idx in range(no_recycles):
            # Only the last iteration runs with gradients enabled (ExitStack()
            # is an empty context); earlier iterations run under no_grad.
            with ExitStack() if recycle_idx == no_recycles - 1 else torch.no_grad():
                # === Recycling ===
                recycle_s = self.recycle_s_norm(recycle_s.detach())
                recycle_z = self.recycle_z_norm(recycle_z.detach())
                recycle_z += self.recycle_disto(recycle_bins.detach())

                s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)

                # === Structure module ===
                structure = self.structure_module(
                    {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
                    true_aa,
                    mask.float(),
                )

                recycle_s = s_s
                recycle_z = s_z
                # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold.
                # `distogram` is the module-level function of this file; the
                # class as shown here defines no attribute of that name.
                recycle_bins = distogram(
                    structure["positions"][-1][:, :, :3],
                    3.375,
                    21.375,
                    self.recycle_bins,
                )

        assert isinstance(structure, dict)  # type: ignore
        structure["s_s"] = s_s
        structure["s_z"] = s_z

        return structure