# Tutorial 6: Protein Structure Embedding with ESM3

In this notebook we will see how to embed a batch of structures using ESM3, as well as how to explore the outputs of its different layers.

# Imports

In [1]:
import torch
from esm.models.esm3 import ESM3
from cookbook.snippets.structure_embed import get_layer_embedding
from esm.utils.constants import esm3 as C

# Loading the model

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("CUDA available:", device == "cuda")

# Fetch the pretrained open ESM3 checkpoint, then place it on the chosen device.
print("Loading ESM3 model...")
model = ESM3.from_pretrained("esm3_sm_open_v1")
model = model.to(device)

CUDA available: False
Loading ESM3 model...


Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

  state_dict = torch.load(


# Getting embeddings from specific layers

We will demo two use cases:
1. Getting embeddings for a complete protein structure
2. Getting embeddings for only the backbone structure

## 1. Getting embeddings for complete protein structure

## 2. Getting embeddings for only backbone structure

In [6]:
# Structure to process: PDB entry and chain identifier.
pdb_id, chain_id = "1mjc", "A"
# Layer index handed to the helper.
# NOTE(review): presumably selects which model layer to read from -- confirm
# against cookbook.snippets.structure_embed.get_layer_embedding.
layer = 45
structure_tokens = get_layer_embedding(
    pdb_id,
    chain_id,
    model,
    layer,
    device,
)

Loading protein chain for PDB ID: 1mjc, Chain ID: A
Number of residues with all backbone atoms present: 69 out of 69
<class 'numpy.ndarray'>
Filtered backbone coordinate shape: torch.Size([1, 69, 3, 3])


Checking backbone structure token correctness

In [7]:
# Show the raw token ids, then how many distinct ids occur among them.
print(structure_tokens)
print(f"number of unique structure tokens: {torch.unique(structure_tokens).numel()}")

tensor([[ 721, 2431, 1244, 3952, 2165, 3892, 3703, 3967, 2435, 2784, 3203, 3058,
          745, 4006, 1412, 3414, 2463,  638, 3304, 1439,  257, 3376, 1397, 1400,
         1083, 2871, 3337,  815, 3912, 2971, 1466,  602, 2348, 2605, 1725, 2881,
         1921,   96, 1853,  497,  446, 3089, 3232, 3350,  698, 1829,  546,  481,
          567, 1451, 1247,  973, 2004, 1877, 1973,  448,  747, 2983, 4073, 2157,
         2824, 3968, 2015,   26,  612,  592, 2977,  615, 3234]])
number of unique structure tokens: 69


In [8]:
from esm.utils.decoding import decode_structure
from esm.sdk.api import ESMProtein

def add_bos_eos(tokens: torch.Tensor):
    """Wrap one chain of structure tokens with the BOS and EOS special tokens.

    Args:
        tokens: Structure token ids, either 1-D ``(length,)`` or 2-D with a
            leading batch dimension of exactly one, ``(1, length)``.

    Returns:
        A 1-D tensor laid out as ``[BOS, *tokens, EOS]``, created on the same
        device as ``tokens``.

    Raises:
        ValueError: If ``tokens`` is neither ``(length,)`` nor ``(1, length)``.
            Batched input cannot be flattened into a single chain, and would
            otherwise fail later inside ``torch.cat`` with an opaque message.
    """
    # Only drop the batch axis when it really is a singleton; squeeze(0) is a
    # silent no-op for larger batches.
    if tokens.ndim == 2 and tokens.shape[0] == 1:
        tokens = tokens.squeeze(0)
    if tokens.ndim != 1:
        raise ValueError(
            "add_bos_eos expects tokens of shape (length,) or (1, length), "
            f"got {tuple(tokens.shape)}"
        )
    # Build the special tokens on the same device so torch.cat does not hit a
    # device mismatch.
    bos = torch.tensor([C.VQVAE_SPECIAL_TOKENS["BOS"]], device=tokens.device)
    eos = torch.tensor([C.VQVAE_SPECIAL_TOKENS["EOS"]], device=tokens.device)
    return torch.cat([bos, tokens, eos], dim=0)

@torch.no_grad()
def decode_from_structure_tokens(structure_tokens, model, device):
    """Decode structure tokens back into atom coordinates and confidence scores.

    Args:
        structure_tokens: Structure token ids without BOS/EOS, in a shape
            accepted by ``add_bos_eos``.
        model: Object providing ``get_structure_decoder()`` and
            ``tokenizers.structure``.
        device: Device the tokens are moved to before decoding. This should be
            the device ``model`` lives on.

    Returns:
        The ``(atom37_coords, plddt, ptm)`` tuple produced by
        ``decode_structure``.
    """
    # Previously `device` was accepted but ignored; placing the tokens on it
    # keeps them co-located with the decoder (a no-op if already there).
    structure_tokens_with_special = add_bos_eos(structure_tokens).to(device)

    atom37_coords, plddt, ptm = decode_structure(
        structure_tokens_with_special,
        structure_decoder=model.get_structure_decoder(),
        structure_tokenizer=model.tokenizers.structure,
        sequence=None
    )
    return atom37_coords, plddt, ptm

recon_coords, plddt, ptm = decode_from_structure_tokens(structure_tokens, model, device)

# Tensor.numpy() only works on CPU tensors, so bring the coordinates to the CPU
# (not to `device`, which may be "cuda") before converting.
protein = ESMProtein(coordinates=recon_coords.detach().cpu().numpy())

# Report the mean per-residue pLDDT and the pTM score returned by the decoder.
print("PLDDT:", plddt.mean().item())
print("pTM:", ptm.item())


PLDDT: 0.8731199502944946
pTM: 0.7162373661994934
