# ESM3 on SageMaker JumpStart

This demo showcases ESM3's ability to perform several protein design tasks.

![1](images/all.png)

## 1. Setup

In [None]:
from IPython.display import clear_output

%pip install -U -r requirements.txt
%pip install -U esm --no-deps

clear_output()

In [None]:
ENDPOINT_NAME = "<PASTE YOUR ENDPOINT NAME HERE>"
MODEL_NAME = "<PASTE YOUR MODEL NAME HERE>"

In [None]:
from esm.sdk.api import ESMProtein, GenerationConfig
from esm.sdk.sagemaker import ESM3SageMakerClient
from src.esmhelpers import format_seq, quick_pdb_plot, quick_aligment_plot

model = ESM3SageMakerClient(endpoint_name=ENDPOINT_NAME, model=MODEL_NAME)

---
## 2. Sequence + Structure Generation

![Sequence and Structure Generation](images/seq_str_out.png)

ESM3 is a generative model, so the most basic task it can accomplish is to create the sequence and structure of a new protein. All ESM3 inference requests must include sequence information, so in this case we will pass a string of "_" symbols. This is the "mask" token that indicates where we want ESM3 to fill in the blanks.
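
As a small illustration of the mask convention, the sketch below builds a fully masked prompt string and a partially masked one in plain Python. The helper `mask_span` and the toy sequence are hypothetical and only show how the "_" character marks positions to be filled in. They are not part of the `esm` SDK.

```python
MASK = "_"  # mask character used for prompts in this notebook


def mask_span(sequence: str, start: int, end: int) -> str:
    """Return `sequence` with positions [start, end) replaced by mask characters.

    Hypothetical helper for illustration; indices are 0-based and end-exclusive.
    """
    if not 0 <= start <= end <= len(sequence):
        raise ValueError("span must lie within the sequence")
    return sequence[:start] + MASK * (end - start) + sequence[end:]


# A prompt where every position is left for the model to fill in
fully_masked = MASK * 8

# A prompt where only four positions of a made-up sequence are masked
partially_masked = mask_span("MKTAYIAKQR", 3, 7)

print(fully_masked)
print(partially_masked)
```

Either string could then be passed as the `sequence` argument of `ESMProtein`, as the next cell does with a fully masked string.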

We start by generating a new protein sequence.

In [None]:
%%time

n_masked = 64

masked_sequence = "_" * n_masked

prompt = ESMProtein(sequence=masked_sequence)
sequence_generation_config = GenerationConfig(
    track="sequence", # We want ESM3 to generate tokens for the sequence track
    num_steps=prompt.sequence.count("_") // 4, # We'll use num(mask tokens) // 4 steps to decode the sequence
    temperature=0.7, # We'll use a temperature of 0.7 to increase the randomness of the decoding process
)

# Call the ESM3 inference endpoint
generated_protein = model.generate(
    prompt,
    sequence_generation_config,
)

# View the generated sequence
print(f"Sequence length: {len(generated_protein.sequence)}")
print(format_seq(generated_protein.sequence))
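
The cell above sets `num_steps` to the number of mask characters divided by 4 using floor division. For a prompt with fewer than four masked positions that expression evaluates to 0, which may not be a useful value. A minimal sketch of a guard, assuming at least one decoding step is always wanted (the helper name `decoding_steps` is hypothetical):

```python
def decoding_steps(sequence: str, tokens_per_step: int, mask: str = "_") -> int:
    """Return masked positions // tokens_per_step, but never fewer than one step.

    Hypothetical helper; mirrors the `count("_") // 4` expression used above.
    """
    if tokens_per_step < 1:
        raise ValueError("tokens_per_step must be at least 1")
    n_masked = sequence.count(mask)
    return max(1, n_masked // tokens_per_step)


print(decoding_steps("_" * 64, 4))  # same value as 64 // 4
print(decoding_steps("_" * 3, 4))   # floor division alone would give 0 here
```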


Next, we predict the structure of the generated sequence and display the results.

In [None]:
%%time

import py3Dmol

prompt = generated_protein

structure_generation_config = GenerationConfig(
    track="structure", # We want ESM3 to generate tokens for the structure track
    num_steps=len(generated_protein.sequence) // 8,
    temperature=0.0,  # Lower temperature means more deterministic predictions
)

generated_protein = model.generate(
    prompt,
    structure_generation_config,
)
print(f"Structure coordinates dimensions: {tuple(generated_protein.coordinates.shape)}")

quick_pdb_plot(generated_protein.to_protein_chain().infer_oxygen().to_pdb_string(), color="spectrum")
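
The two cells above use different temperatures: 0.7 for sequence generation and 0.0 for structure prediction. The sketch below shows the general idea of temperature scaling on a toy set of scores. It is a generic softmax illustration and an assumption about how temperature is typically applied, not ESM3's internal sampling code.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Convert raw scores into probabilities, sharpened or flattened by temperature."""
    if temperature == 0.0:
        # Limiting case: all probability mass goes to the highest-scoring option
        best = max(range(len(logits)), key=lambda i: logits[i])
        return [1.0 if i == best else 0.0 for i in range(len(logits))]
    scaled = [x / temperature for x in logits]
    top = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


toy_logits = [2.0, 1.0, 0.5]  # made-up scores for three candidate tokens
for t in (1.0, 0.7, 0.0):
    probs = softmax_with_temperature(toy_logits, t)
    print(t, [round(p, 3) for p in probs])
```

As the temperature drops, the highest-scoring candidate takes a larger share of the probability, so repeated runs become less varied.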


Let's repeat the sequence + structure generation a few more times. In this case we'll generate all of the tokens in a single step. This makes inference much faster, but reduces accuracy.

In [None]:
# Generate sequence
for i in range(3):
    print(f"Iteration {i+1}")
    sequence_prompt = ESMProtein(sequence="_" * n_masked)
    sequence_generation_config = GenerationConfig(
        track="sequence",
        num_steps=1,
        temperature=0.7,
    )
    generated_protein = model.generate(
        sequence_prompt,
        sequence_generation_config,
    )
    print(format_seq(generated_protein.sequence))

    # Generate structure
    structure_prompt = generated_protein
    structure_generation_config = GenerationConfig(
        track="structure",
        num_steps=1,
        temperature=0.0,
    )

    generated_protein = model.generate(
        structure_prompt,
        structure_generation_config,
    )

    quick_pdb_plot(
        generated_protein.to_protein_chain().infer_oxygen().to_pdb_string(),
        width=400,
        height=300,
        color="spectrum",
    )

---
## 3. Sequence to Function Prediction

![Sequence In - Function Out](images/seq-func.png)

Another common task is function prediction. Given an unknown amino acid sequence, can we predict the function of its domains? Let's try an example.

For this example, we'll look at pyruvate kinase (PDB ID: [1PKN](https://www.rcsb.org/structure/1PKN)), a key enzyme involved in the breakdown of sugar into energy. It is composed of two different domains, or functional units, the “Barrel Domain” (colored in green below) and the “C-Terminal Domain” (colored in orange).

In [None]:
from esm.utils.structure.protein_chain import ProteinChain
import py3Dmol

pdb_id = "1PKN"
chain_id = "A"

# Download the mmCIF file for 1PKN from PDB
pyruvate_kinase_chain = ProteinChain.from_rcsb(pdb_id, chain_id)

# Display the sequence
print(format_seq(pyruvate_kinase_chain.sequence))

# Display the structure
view = py3Dmol.view(width=400, height=300)
view.addModel(pyruvate_kinase_chain.to_pdb_string(), "pdb")
view.setStyle({"cartoon": {"color": "lightgrey"}})
view.addStyle({"resi": list(range(40, 373))}, {"cartoon": {"color": "#38EF7D"}})
view.addStyle({"resi": list(range(408, 526))}, {"cartoon": {"color": "#FF9900"}})
view.rotate(150, "x")
view.rotate(45, "y")
view.rotate(45, "z")
view.zoomTo()
view.show()

Let's submit the pyruvate kinase sequence to ESM3 and request functional annotations by setting the `track` parameter to `function`.

In [None]:
prompt = ESMProtein.from_protein_chain(pyruvate_kinase_chain)

function_prediction_config = GenerationConfig(
    track="function",
    num_steps=len(prompt.sequence) // 8,
)

generated_protein = model.generate(
    prompt,
    function_prediction_config,
)

for annotation in generated_protein.function_annotations:
    print(annotation.to_tuple())

In [None]:
from src.esmhelpers import parse_annotations_by_label, format_annotations

parsed_annotations = parse_annotations_by_label(generated_protein.function_annotations)

print(
    " ".ljust(25),
    format_seq(
        generated_protein.sequence,
        width=len(generated_protein.sequence) + 1,
        line_numbers=False,
    ),
)

for label, flags in format_annotations(
    parsed_annotations,
    len(generated_protein.sequence),
    [
        "Pyruvate kinase (IPR001697)",
        "barrel",
        "Pyruvate kinase, C-terminal (IPR015795)",
        "Pyruvate kinase, active site (IPR018209)",
        "acetyltransferase",
    ],
).items():
    print(
        label[:24].ljust(25),
        format_seq(
            flags,
            width=len(generated_protein.sequence) + 1,
            line_numbers=False,
        ),
    )
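
The cell above prints one row of flags per annotation label, aligned with the sequence. The sketch below shows one way such a row could be built. It assumes each annotation is a `(label, start, end)` tuple with 1-indexed, inclusive positions. The function `annotation_track` and the example tuples are hypothetical and are not the implementation of `format_annotations` in `src.esmhelpers`.

```python
def annotation_track(length, annotations, label, on="#", off="."):
    """Build a per-residue string marking where `label` applies.

    `annotations` is an iterable of (label, start, end) tuples, assumed to be
    1-indexed and inclusive. Positions outside 1..length are ignored.
    """
    track = [off] * length
    for ann_label, start, end in annotations:
        if ann_label != label:
            continue
        for position in range(max(start, 1), min(end, length) + 1):
            track[position - 1] = on  # convert to a 0-based list index
    return "".join(track)


# Made-up annotations on a 10-residue sequence
example_annotations = [("barrel", 3, 6), ("C-terminal", 8, 10), ("barrel", 9, 9)]

for name in ("barrel", "C-terminal"):
    print(name.ljust(12), annotation_track(10, example_annotations, name))
```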

ESM3 was able to correctly identify the barrel and C-terminal domains, as well as some additional sequence annotations.

---
## 4. Sequence to Structure Prediction

![Sequence In - Structure Out](images/seq-str.png)



Another common task for bioFMs is to translate between sequence and structure (protein folding). Let's try to predict the structure of human beta 3 alcohol dehydrogenase, the enzyme responsible for breaking down alcohol in the liver.

In [None]:
from esm.utils.structure.protein_chain import ProteinChain
import py3Dmol

pdb_id = "1HTB"
chain_id = "A"

# Download the mmCIF file for 1HTB from PDB
adh_ref_chain = ProteinChain.from_rcsb(pdb_id, chain_id)

# Display the sequence
print(format_seq(adh_ref_chain.sequence))

# Display the structure
quick_pdb_plot(adh_ref_chain.to_pdb_string(), color="#007FAA", width=400, height=300)

Now we use ESM3 to predict the structure, conditioned on the sequence.

In [None]:
prompt = ESMProtein.from_protein_chain(adh_ref_chain)

structure_generation_config = GenerationConfig(
    track="structure",
    num_steps=len(prompt.sequence) // 8,
    temperature=0.0,  # Lower temperature means more deterministic predictions.
)

generated_protein = model.generate(
    prompt,
    structure_generation_config,
)

generated_chain = generated_protein.to_protein_chain()
generated_chain = generated_chain.align(adh_ref_chain)

quick_pdb_plot(
    generated_protein.to_pdb_string(), color="#00f174", width=400, height=300
)

Finally, we compute the cRMSD between the aligned generated structure and the reference structure, and view the two overlaid.

In [None]:
# Calculate the cRMSD between the generated chain and the reference chain
crmsd = generated_chain.rmsd(adh_ref_chain)
print(
    "cRMSD of the generated structure vs the reference structure: ", crmsd
)

view = py3Dmol.view(width=800, height=600)
view.addModel(adh_ref_chain.to_pdb_string(), "pdb")
view.addModel(generated_chain.to_pdb_string(), "pdb")
view.setStyle({"model": 0}, {"cartoon": {"color": "#007FAA"}})
view.setStyle({"model": 1}, {"cartoon": {"color": "#00f174"}})
view.zoomTo()
view.show()

The structure prediction is quite good, with a cRMSD of less than 1 Å. The reference structure was determined by X-ray diffraction at a resolution of 2.4 Å, so the deviation between the prediction and the experimental model is smaller than the resolution of the experiment itself.
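
For reference, an RMSD is the square root of the mean squared distance between paired atoms. The sketch below computes it for two toy coordinate lists. It assumes the structures are already superimposed and the atoms are paired one-to-one, and it is not the implementation of `ProteinChain.rmsd`. The coordinates are made up.

```python
import math


def rmsd(coords_a, coords_b):
    """Root-mean-square deviation between two equally long lists of (x, y, z) points."""
    if len(coords_a) != len(coords_b) or not coords_a:
        raise ValueError("coordinate lists must be non-empty and the same length")
    squared_sum = 0.0
    for (xa, ya, za), (xb, yb, zb) in zip(coords_a, coords_b):
        squared_sum += (xa - xb) ** 2 + (ya - yb) ** 2 + (za - zb) ** 2
    return math.sqrt(squared_sum / len(coords_a))


# Three made-up atoms, and the same atoms shifted by 1 unit along z
reference_coords = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (3.0, 0.0, 0.0)]
shifted_coords = [(0.0, 0.0, 1.0), (1.5, 0.0, 1.0), (3.0, 0.0, 1.0)]

print(rmsd(reference_coords, shifted_coords))
```

Because every atom in the toy example is displaced by the same distance, the RMSD equals that distance.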

---
## 5. Structure to Sequence Prediction

We can also translate in the other direction, from structure to sequence.

![Structure In - Sequence Out](images/str-seq.png)

In [None]:
masked_sequence = "_" * len(adh_ref_chain.sequence)

prompt = ESMProtein(
    sequence=masked_sequence,
    coordinates=generated_protein.coordinates,
)
sequence_generation_config = GenerationConfig(
    track="sequence",
    num_steps=prompt.sequence.count("_") // 4,
    temperature=0.0,
)
generated_protein = model.generate(
    prompt,
    sequence_generation_config,
)
print(format_seq(generated_protein.sequence))

In [None]:
quick_aligment_plot(adh_ref_chain.sequence, generated_protein.sequence)

Given only the predicted 3D structure of ADH, ESM3 was able to recover more than 85% of the actual sequence.
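
A recovery percentage of this kind can be computed as the fraction of positions where the two sequences agree. The sketch below assumes the sequences have equal length and need no gaps, which holds here because the prompt contained one mask character per reference residue. The function `sequence_recovery` and the toy sequences are hypothetical, and this is not necessarily how `quick_aligment_plot` derives its display.

```python
def sequence_recovery(reference: str, generated: str) -> float:
    """Fraction of positions at which two equal-length sequences match."""
    if len(reference) != len(generated) or not reference:
        raise ValueError("sequences must be non-empty and the same length")
    matches = sum(a == b for a, b in zip(reference, generated))
    return matches / len(reference)


# Made-up 10-residue sequences that differ at two positions
toy_reference = "MKTAYIAKQR"
toy_generated = "MKTAYLAKQK"

print(f"{sequence_recovery(toy_reference, toy_generated):.0%}")
```

In the notebook, the same function could be called with `adh_ref_chain.sequence` and `generated_protein.sequence`.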