# Explore the embeddings of the discrete structure VAE produced for ESM3
The VAE encoder embeds each residue's local structural neighbourhood (its k nearest neighbours in structure space), runs geometric attention within each neighbourhood, then discretizes the per-residue outputs into tokens.

## 1. Provided structural embeddings

__This is how esm does it__

Internally, the ESM model API (e.g. `model.generate`) calls `model.encode` when the inputs are not already tensors. That, in turn, calls the util function `tokenize_structure`:



```
def tokenize_structure(
    coordinates: torch.Tensor,
    structure_encoder: StructureTokenEncoder,
    structure_tokenizer: StructureTokenizer,
    reference_sequence: str = "",
    add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Turn atom37 coordinates into discrete ESM3 structure tokens.

    Builds a ``ProteinChain`` from ``coordinates``, runs
    ``structure_encoder.encode`` on the chain's encoder inputs and keeps only
    the token indices. With ``add_special_tokens`` every output gets one extra
    position at each end for the BOS/EOS tokens.

    Args:
        coordinates: atom37 coordinates; ``coordinates.size(0)`` is compared
            with ``len(reference_sequence)``, so dim 0 is the residue axis.
        structure_encoder: encoder whose ``encode`` returns a pair; only the
            second element (the token indices) is used. Inputs are moved to
            the device of its first parameter.
        structure_tokenizer: supplies ``pad_token_id``, ``bos_token_id`` and
            ``eos_token_id``.
        reference_sequence: optional sequence passed to
            ``ProteinChain.from_atom37``; the empty string means "none".
        add_special_tokens: if True (the default), pad every output by one
            position on each side for BOS/EOS.

    Returns:
        ``(coordinates, plddt, structure_tokens)`` with shapes ``(L, 37, 3)``,
        ``(L,)`` and ``(L,)`` -- or ``L + 2`` along dim 0 with special tokens,
        where padded coordinate rows are ``+inf``, padded pLDDT is 0, and
        ``structure_tokens[0]`` / ``structure_tokens[-1]`` are the BOS / EOS ids.

    Raises:
        ValueError: if ``reference_sequence`` is non-empty and its length
            differs from ``coordinates.size(0)``.
    """
    device = next(structure_encoder.parameters()).device
    chain = ProteinChain.from_atom37(
        coordinates, sequence=reference_sequence if reference_sequence else None
    )

    # Validate that the reference sequence matches the number of residues.
    # NOTE(review): this check runs only after ProteinChain.from_atom37 above,
    # so a mismatched sequence may fail inside from_atom37 before this clearer
    # error is reached -- confirm.
    if reference_sequence and len(reference_sequence) != coordinates.size(0):
        raise ValueError(
            f"Reference sequence length ({len(reference_sequence)}) does not match the number of residues in the coordinates ({coordinates.size(0)})"
        )

    # Setup padding: one slot on each side when special tokens are requested.
    left_pad = 0
    right_pad = 0

    if add_special_tokens:
        left_pad += 1  # Add space for BOS token
        right_pad += 1  # Add space for EOS token

    # Encoder inputs come from the chain (batched with a leading dim of 1),
    # not from the raw `coordinates` argument.
    coordinates, plddt, residue_index = chain.to_structure_encoder_inputs()
    coordinates = coordinates.to(device)  # (1, L, 37, 3)
    plddt = plddt.to(device)  # (1, L)
    residue_index = residue_index.to(device)  # (1, L)
    # Only the token indices are kept; the first output of encode is discarded.
    _, structure_tokens = structure_encoder.encode(
        coordinates, residue_index=residue_index
    )
    # Drop the batch dimension.
    coordinates = torch.squeeze(coordinates, dim=0)  # (L, 37, 3)  # type: ignore
    plddt = torch.squeeze(plddt, dim=0)  # (L,)  # type: ignore
    structure_tokens = torch.squeeze(structure_tokens, dim=0)  # (L,)  # type: ignore

    # Add space for BOS and EOS tokens
    if add_special_tokens:
        # Pad only the residue axis (last pair of the pad tuple); BOS/EOS rows are +inf.
        coordinates = F.pad(
            coordinates,
            (0, 0, 0, 0, left_pad, right_pad),
            value=torch.inf,
        )
        plddt = F.pad(plddt, (left_pad, right_pad), value=0)
        structure_tokens = F.pad(
            structure_tokens,
            (left_pad, right_pad),
            value=structure_tokenizer.pad_token_id,
        )
        # The pad value at both ends is immediately overwritten with BOS/EOS.
        structure_tokens[0] = structure_tokenizer.bos_token_id
        structure_tokens[-1] = structure_tokenizer.eos_token_id
    return coordinates, plddt, structure_tokens
```

#### The `StructureTokenEncoder` is a torch module.

In [1]:
from esm.pretrained import load_local_model
from huggingface_hub import login


In [None]:
import os

# Never hard-code access tokens in a notebook: they leak through version
# control and shared outputs. Provide the token via the HF_TOKEN environment
# variable; if it is unset, login() prompts for a token interactively.
# NOTE: the token that was previously hard-coded in this cell should be revoked.
login(token=os.environ.get("HF_TOKEN"))

In [2]:
# Fetch the ESM3 structure-encoder weights and load them on the CPU.
encoder = load_local_model(
    "esm3_structure_encoder_v0",
    device="cpu",
)

Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

#### Tokenize as expected

In [10]:
from esm.utils.encoding import tokenize_structure
from esm.utils.structure.protein_chain import ProteinChain
import torch
from esm.tokenization import get_model_tokenizers

In [11]:
# Tokenizers that ship with the open ESM3 model; `.structure` is used below.
tokenizers = get_model_tokenizers("esm3_sm_open_v1")

In [7]:
# Chain A of PDB entry 1ITU, fetched from RCSB.
example = ProteinChain.from_rcsb("1ITU", "A")

In [8]:
# Echo the chain's repr (id, sequence, chain/entity ids, residue_index, ...).
example

ProteinChain(id='1ITU', sequence='DFFRDEAERIMRDSPVIDGHNDLPWQLLDMFNNRLQDERANLTTLAGTHTNIPKLRAGFVGGQFWSVYTPCDTQNKDAVRRTLEQMDVVHRMCRMYPETFLYVTSSAGIRQAFREGKVASLIGVEGGHSIDSSLGVLRALYQLGMRYLTLTHSCNTPWADNWLVDTGDSEPQSQGLSPFGQRVVKELNRLGVLIDLAHVSVATMKATLQLSRAPVIFSHSSAYSVCASRRNVPDDVLRLVKQTDSLVMVNFYNNYISCTNKANLSQVADHLDHIKEVAGARAVGFGGDFDGVPRVPEGLEDVSKYPDLIAELLRRNWTEAEVKGALADNLLRVFEAVEQASNLTQAPEEEPIPLDQLGGSCRTHYGYSS', chain_id='A', entity_id=1, residue_index=array([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,
        27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,
        40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
        53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,
        66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,
        79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,
        92,  93,  94,  95,  96,  97,  98,  99, 10

In [22]:
# Run the library tokenization path end to end; the result is the tuple
# (coordinates, plddt, structure_tokens).
outs = tokenize_structure(
    torch.tensor(example.atom37_positions),
    encoder,
    tokenizers.structure,
    reference_sequence=example.sequence,
)

In [23]:
# First output: atom37 coordinates, shape (L + 2, 37, 3). Besides the chain's
# own residues it has one extra all-inf row at each end, reserved for the
# BOS/EOS tokens (add_special_tokens defaults to True).
outs[0].shape

torch.Size([371, 37, 3])

In [24]:
# Second output: per-residue pLDDT (not an attention mask), shape (L + 2,),
# padded with 0 at the BOS/EOS positions. Masking residues is done through the
# `attention_mask` argument of `encoder.encode`, which is used as the KNN
# padding mask.
outs[1].shape

torch.Size([371])

In [25]:
# Third output: discrete structure tokens, shape (L + 2,):
# BOS, one codebook index per residue, EOS.
outs[2].shape

torch.Size([371])

Looking at the tokenization function above, the discretization happens inside the encoder: `structure_encoder.encode` already returns the discrete structure tokens.

#### Dig into the encoder module to get continuous embeddings

In [20]:
# Print the module structure of the loaded structure encoder.
encoder

StructureTokenEncoder(
  (transformer): GeometricEncoderStack(
    (blocks): ModuleList(
      (0-1): 2 x UnifiedTransformerBlock(
        (geom_attn): GeometricReasoningOriginalImpl(
          (s_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (proj): Linear(in_features=1024, out_features=1920, bias=True)
          (out_proj): Linear(in_features=384, out_features=1024, bias=True)
        )
        (ffn): Sequential(
          (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=1024, out_features=8192, bias=True)
          (2): SwiGLU()
          (3): Linear(in_features=4096, out_features=1024, bias=True)
        )
      )
    )
    (norm): Identity()
  )
  (pre_vq_proj): Linear(in_features=1024, out_features=128, bias=True)
  (codebook): EMACodebook()
  (relative_positional_embedding): RelativePositionEmbedding(
    (embedding): Embedding(66, 1024)
  )
)

Here is the encoder class:

In [21]:
class StructureTokenEncoder(nn.Module):
    """Reference copy of the ESM3 structure-token (VQ-VAE) encoder.

    ``encode`` embeds every residue from its k-nearest-neighbour structural
    neighbourhood (``encode_local_structure``), projects the result to
    ``d_out`` dimensions with ``pre_vq_proj`` and quantizes it with an
    ``n_codes``-entry EMA codebook.

    NOTE(review): this cell does not import ``nn``, ``Affine3D``,
    ``GeometricEncoderStack``, ``EMACodebook``, ``RelativePositionEmbedding``,
    ``node_gather``, ``knn_graph`` or ``build_affine3d_from_coordinates``, so it
    only runs once those names are in scope (hence the ``NameError`` it raised).
    """

    def __init__(self, d_model, n_heads, v_heads, n_layers, d_out, n_codes):
        super().__init__()
        # We only support fully-geometric structure token encoders for now...
        # setting n_layers_geom to something that's not n_layers won't work because
        # sequence ID isn't supported fully in this repo for plain-old transformers
        self.transformer = GeometricEncoderStack(d_model, n_heads, v_heads, n_layers)
        self.pre_vq_proj = nn.Linear(d_model, d_out)
        self.codebook = EMACodebook(n_codes, d_out)
        self.relative_positional_embedding = RelativePositionEmbedding(
            32, d_model, init_std=0.02
        )
        # Neighbourhood size: each residue is encoded from its 16 nearest
        # neighbours (itself included) by CA position.
        self.knn = 16

    def encode_local_structure(
        self,
        coords: torch.Tensor,
        affine: Affine3D,
        attention_mask: torch.Tensor,
        sequence_id: torch.Tensor | None,
        affine_mask: torch.Tensor,
        residue_index: torch.Tensor | None = None,
    ):
        """This function allows for a multi-layered encoder to encode tokens with a local receptive fields. The implementation is as follows:

        1. Starting with (B, L) frames, we find the KNN in structure space. This now gives us (B, L, K) where the last dimension is the local
        neighborhood of all (B, L) residues.
        2. We reshape these frames to (B*L, K) so now we have a large batch of a bunch of local neighborhoods.
        3. Pass the (B*L, K) local neighborhoods through a stack of geometric reasoning blocks, effectively getting all to all communication between
        all frames in the local neighborhood.
        4. This gives (B*L, K, d_model) embeddings, from which we need to get a single embedding per local neighborhood. We do this by simply
        taking the embedding corresponding to the query node. This gives us (B*L, d_model) embeddings.
        5. Reshape back to (B, L, d_model) embeddings

        ``residue_index`` (when given) feeds the relative positional embedding;
        otherwise the KNN edge indices are used as positions.
        """
        assert coords.size(-1) == 3 and coords.size(-2) == 3, "need N, CA, C"
        with torch.no_grad():
            knn_edges, _ = self.find_knn_edges(
                coords,
                ~attention_mask,
                coord_mask=affine_mask,
                sequence_id=sequence_id,
                knn=self.knn,
            )
            B, L, E = knn_edges.shape

            affine_tensor = affine.tensor  # for easier manipulation
            T_D = affine_tensor.size(-1)
            knn_affine_tensor = node_gather(affine_tensor, knn_edges)
            knn_affine_tensor = knn_affine_tensor.view(-1, E, T_D).contiguous()
            affine = Affine3D.from_tensor(knn_affine_tensor)
            # Every per-edge tensor is flattened to (B * L, E), so the default
            # ids must have B * L rows as well (an (L, E) default only lines up
            # with the neighbourhood batch when B == 1).
            knn_sequence_id = (
                node_gather(sequence_id.unsqueeze(-1), knn_edges).view(-1, E)
                if sequence_id is not None
                else torch.zeros(B * L, E, dtype=torch.int64, device=coords.device)
            )
            knn_affine_mask = node_gather(affine_mask.unsqueeze(-1), knn_edges).view(
                -1, E
            )
            knn_chain_id = torch.zeros(
                B * L, E, dtype=torch.int64, device=coords.device
            )

            if residue_index is None:
                res_idxs = knn_edges.view(-1, E)
            else:
                res_idxs = node_gather(residue_index.unsqueeze(-1), knn_edges).view(
                    -1, E
                )

        # Positions relative to each neighbourhood's query node (column 0).
        z = self.relative_positional_embedding(res_idxs[:, 0], res_idxs)

        z, _ = self.transformer.forward(
            x=z,
            sequence_id=knn_sequence_id,
            affine=affine,
            affine_mask=knn_affine_mask,
            chain_id=knn_chain_id,
        )

        # Unflatten the output and take the query node embedding, which will always be the first one because
        # a node has distance 0 with itself and the KNN are sorted.
        z = z.view(B, L, E, -1)
        z = z[:, :, 0, :]

        return z

    @staticmethod
    def find_knn_edges(
        coords,
        padding_mask,
        coord_mask,
        sequence_id: torch.Tensor | None = None,
        knn: int | None = None,
    ) -> tuple:
        """Find the ``knn`` nearest neighbours of every residue by CA position.

        Coordinates of residues where ``coord_mask`` is False are zeroed on a
        copy before the search; the caller's tensor is not modified. Returns
        ``(edges, edge_mask)`` as produced by ``knn_graph``.
        """
        assert knn is not None, "Must specify a non-null knn to find_knn_edges"
        # Coords are N, CA, C
        coords = coords.clone()
        coords[~coord_mask] = 0

        if sequence_id is None:
            sequence_id = torch.zeros(
                (coords.shape[0], coords.shape[1]), device=coords.device
            ).long()

        # Distances are computed in full precision (CUDA autocast disabled).
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):  # type: ignore
            ca = coords[..., 1, :]
            edges, edge_mask = knn_graph(
                ca,
                coord_mask,
                padding_mask,
                sequence_id,
                no_knn=knn,
            )

        return edges, edge_mask

    def encode(
        self,
        coords: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        sequence_id: torch.Tensor | None = None,
        residue_index: torch.Tensor | None = None,
    ):
        """Encode backbone coordinates into quantized embeddings and tokens.

        Only the first three atoms of ``coords`` (N, CA, C) are used.
        ``attention_mask`` defaults to all True and ``sequence_id`` to all
        zeros. Residues without a valid frame (``affine_mask`` False) get a
        zero embedding before projection.

        Returns:
            ``(z_q, min_encoding_indices)``: the codebook output for the
            ``d_out``-dimensional projected embeddings, and the selected code
            index per residue (the structure token). The continuous,
            pre-quantization latent is ``z`` right before ``self.codebook``.
        """
        coords = coords[..., :3, :]
        affine, affine_mask = build_affine3d_from_coordinates(coords=coords)

        if attention_mask is None:
            attention_mask = torch.ones_like(affine_mask, dtype=torch.bool)
        attention_mask = attention_mask.bool()

        if sequence_id is None:
            sequence_id = torch.zeros_like(affine_mask, dtype=torch.int64)

        z = self.encode_local_structure(
            coords=coords,
            affine=affine,
            attention_mask=attention_mask,
            sequence_id=sequence_id,
            affine_mask=affine_mask,
            residue_index=residue_index,
        )

        z = z.masked_fill(~affine_mask.unsqueeze(2), 0)
        z = self.pre_vq_proj(z)

        z_q, min_encoding_indices, _ = self.codebook(z)

        return z_q, min_encoding_indices

NameError: name 'nn' is not defined

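Is `z_q` the latent space? `encode` returns `z_q`, the output of `self.codebook(z)`, i.e. the *quantized* per-residue latent. The continuous (pre-quantization) latent is `z` right after `pre_vq_proj`, before the codebook lookup. Getting it therefore requires stopping `encode` just before `self.codebook`.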

In [27]:
# Mirror tokenize_structure's padding counters; no padding is applied here.
left_pad = right_pad = 0
# The same encoder inputs tokenize_structure builds internally, kept on the
# CPU and without BOS/EOS padding.
coordinates, plddt, residue_index = example.to_structure_encoder_inputs()


In [31]:
# Call the encoder directly to get both outputs: z (the codebook output, i.e.
# the quantized per-residue embedding) and the discrete structure tokens.
# This is pure inference, so run it under no_grad: no autograd graph is built or
# retained, and z can later be converted with .numpy() even if the encoder's
# parameters require grad.
with torch.no_grad():
    z, structure_tokens = encoder.encode(
        coordinates, residue_index=residue_index
    )

In [36]:
# Average the per-residue embeddings over dim 1 (the residue axis, assuming z
# is (B, L, d_out) -- TODO confirm).
z.mean(dim=1)

tensor([[-0.5313, -2.3629, -0.6525,  0.3052, -0.4117,  2.9609,  0.6918,  2.1311,
          1.5440,  0.6354, -1.1304,  1.1496, -0.1755,  1.9234, -2.5304, -2.9934,
         -0.9476, -1.0383, -0.1722, -1.8722,  1.7672, -0.3774,  1.9237,  1.4243,
         -2.0182,  2.0026, -1.1837,  0.2101, -0.8612,  0.2164,  0.8020, -0.1875,
          0.3853, -0.6748, -0.3127,  2.9619, -0.6505,  0.3225, -0.9738, -0.8902,
         -0.4902,  0.5850, -0.5019, -0.5598, -0.2611,  0.5285, -0.5970, -0.3727,
          1.9450, -1.2231, -0.1247,  2.5219, -1.6484,  1.6937,  0.6202,  1.1796,
          1.6026,  1.2698, -1.7287, -1.0446,  0.3377, -0.4107,  0.0990,  1.4283,
         -0.5145,  0.6338, -0.2815, -0.0670,  0.8958,  0.2626, -1.8318, -1.7010,
         -1.2507,  2.0013, -2.4350, -1.3212, -0.6036,  0.7297, -0.8470,  0.1455,
          1.3283,  0.9365, -0.3025,  0.5832,  0.7132,  1.0818, -0.2943,  1.9198,
          1.9231,  0.4123,  0.7660, -0.4020,  2.0509, -1.4754,  0.6602,  1.6624,
         -0.7840,  0.6760,  