# Enzyme Engineering with ESM3 using Amazon SageMaker Real-time inference endpoints

This demo shows how ESM3 can modify enzyme sequences and structures.

Rather than generating new sequences from scratch, it is often more useful to modify an existing protein sequence: for example, to increase binding to a ligand, or to design a new protein that incorporates a known active site.

![Protein Modification](images/part_seq-seq.png)

## 1. Setup

Note: you need to run the first notebook, `1-deploy-esm3-inference-endpoint.ipynb`, before this one to get the `ENDPOINT_NAME` used below. If you deployed the model through the SageMaker console instead, you can find the endpoint name in the Inference section of the console.

In [None]:
ENDPOINT_NAME = ""
MODEL_NAME = "esm3-sm-open-v1"

In [None]:
from esm.sdk.api import ESMProtein, GenerationConfig
from esm.sdk.sagemaker import ESM3SageMakerClient
from src.esmhelpers import format_seq

model = ESM3SageMakerClient(endpoint_name=ENDPOINT_NAME, model=MODEL_NAME)

---
## 2. Download enzyme structure

Ornithine transcarbamylase (OTC) deficiency is a rare genetic disorder that affects the liver's ability to process ammonia, a waste product of protein breakdown. It is the most common urea cycle disorder. Treatment involves a low-protein diet, ammonia-lowering medications, and, in severe cases, liver transplantation. Early diagnosis and management are crucial to prevent brain damage and other complications.

One treatment approach for certain genetic diseases like OTCD is enzyme replacement therapy, where patients regularly receive an intravenous infusion of the missing or deficient enzyme. This can be effective but is expensive. As an alternative, researchers have proposed engineered versions of these enzymes that require lower or less frequent doses.

Let's see how ESM3 can support protein engineering projects like this. First, we download the OTC reference structure from the PDB and highlight the active-site residues needed for its function.

In [None]:
from esm.utils.structure.protein_chain import ProteinChain
import py3Dmol

pdb_id = "1OTH"
chain_id = "A"

# Download chain A of 1OTH from the RCSB PDB
otc_reference_chain = ProteinChain.from_rcsb(pdb_id, chain_id)
# Renumber residues so the first residue is 1; the active-site numbers below use this numbering
otc_reference_chain.residue_index = (
    otc_reference_chain.residue_index - otc_reference_chain.residue_index[0] + 1
)
otc_reference_protein = ESMProtein.from_protein_chain(otc_reference_chain)

# Display the sequence
print(format_seq(otc_reference_chain.sequence))

# Active-site residues to keep fixed, as 1-based residue numbers (matching py3Dmol "resi")
active_site_residues = [
    56,
    57,
    58,
    59,
    60,
    61,
    108,
    130,
    135,
    138,
    165,
    166,
    167,
    230,
    231,
    234,
    235,
    270,
    271,
    272,
    297,
]

# Display the structure
view = py3Dmol.view(width=800, height=600)
view.addModel(otc_reference_chain.infer_oxygen().to_pdb_string(), "pdb")
view.setStyle({"cartoon": {"color": "#007FAA"}})
view.addStyle({"resi": active_site_residues}, {"cartoon": {"color": "#eb982c"}})

view.zoomTo()
view.show()

---
## 3. Prepare masked prompt

Next, we encode the reference sequence and structure into tokens. Working at the token level makes it easy to select specific portions of the protein for redesign, especially on the structure track.
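Each encoded track wraps the residues in a beginning-of-sequence (BOS) and an end-of-sequence (EOS) token, so an N-residue chain becomes N + 2 tokens, and with residues numbered from 1, residue i sits at token index i. The pure-Python sketch below illustrates that bookkeeping with placeholder strings instead of real token IDs; the helper names and the short example sequence are hypothetical.

```python
def to_token_track(residues, bos="<bos>", eos="<eos>"):
    """Wrap per-residue items in BOS/EOS, mirroring the layout of an encoded ESM3 track."""
    return [bos, *residues, eos]


def token_position(residue_number):
    """Map a 1-based residue number to its index in a BOS/EOS-wrapped track."""
    # Index 0 holds BOS, so residue 1 is at index 1, residue 2 at index 2, and so on
    return residue_number


sequence = "MLFNLRIL"  # hypothetical 8-residue fragment
track = to_token_track(list(sequence))
print(len(track))                # 10 tokens for 8 residues
print(track[token_position(3)])  # third residue: F
```

Keeping this offset in mind matters whenever residue numbers are used to index token tensors directly.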

In [None]:
otc_reference_tokens = model.encode(otc_reference_protein)
print(f"Encoded sequence:\n{otc_reference_tokens.sequence}")
print(f"Encoded structure:\n{otc_reference_tokens.structure}")

Next, we create a prompt that masks the whole protein except the active-site residues highlighted above. We start with a sequence track made entirely of mask tokens, put the BOS and EOS tokens at the ends, and then copy in the reference tokens at the active-site positions.
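The same masking idea can also be expressed at the string level. In the ESM3 README examples, an underscore `_` in an `ESMProtein` sequence marks a masked position. The sketch below assumes that convention and builds such a string from a reference sequence and a list of 1-based residue numbers to keep; `mask_sequence` is a hypothetical helper shown only to make the masking logic concrete.

```python
def mask_sequence(reference, keep_residues, mask_char="_"):
    """Mask every residue of `reference` except the 1-based positions in keep_residues."""
    keep = set(keep_residues)
    # enumerate(..., start=1) yields 1-based residue numbers, matching residue_index
    return "".join(
        aa if i in keep else mask_char
        for i, aa in enumerate(reference, start=1)
    )


reference_fragment = "MLFNLRILLNNAAFRNGHNF"  # hypothetical 20-residue sequence
masked = mask_sequence(reference_fragment, [3, 4, 5, 18])
print(masked)
```

This notebook builds the token tensors by hand instead, because it also needs to mask the structure track in the same way.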

In [None]:
import torch
from esm.utils.constants import esm3 as esm3_constants

prompt_token_length = len(otc_reference_tokens.sequence)
print(f"Sequence token count: {prompt_token_length}")
masked_sequence_tokens = torch.full(
    [prompt_token_length], esm3_constants.SEQUENCE_MASK_TOKEN
)
masked_sequence_tokens[0] = esm3_constants.SEQUENCE_BOS_TOKEN
masked_sequence_tokens[-1] = esm3_constants.SEQUENCE_EOS_TOKEN

# Token 0 is BOS, so 1-based residue number idx sits at token index idx
for idx in active_site_residues:
    masked_sequence_tokens[idx] = otc_reference_tokens.sequence[idx]

masked_sequence_token_count = (
    (masked_sequence_tokens == esm3_constants.SEQUENCE_MASK_TOKEN).sum().item()
)

print(f"Masked sequence token count: {masked_sequence_token_count}")

Next, we do the same for the structure. Rather than working with 3D coordinates, we use the encoded structure tokens: we construct an empty structure track, `<bos> <mask> ... <mask> <eos>`, and then fill in the structure tokens for the active site.
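Because the sequence and structure tracks share the same layout, one routine can mask either of them. The sketch below is a hypothetical, list-based version of that step; the sentinel values for MASK, BOS, and EOS are placeholders chosen so they cannot collide with real token IDs (in the notebook, use the constants from `esm.utils.constants.esm3`). Applying it to both tracks with the same residue list is what keeps the two masked-token counts equal, as the assertion in the next cell checks.

```python
# Placeholder special-token IDs for illustration only (not the SDK's values)
MASK, BOS, EOS = -1, -2, -3


def build_masked_track(reference_tokens, keep_residues, mask_id, bos_id, eos_id):
    """Mask every residue position except keep_residues, which are copied from reference_tokens.

    reference_tokens is assumed to be laid out as BOS, one token per residue, EOS,
    so 1-based residue number i is stored at index i.
    """
    track = [mask_id] * len(reference_tokens)
    track[0], track[-1] = bos_id, eos_id
    for residue in keep_residues:
        track[residue] = reference_tokens[residue]
    return track


reference = [BOS, 17, 250, 3001, 42, 999, EOS]  # hypothetical 5-residue structure track
masked = build_masked_track(reference, [2, 4], MASK, BOS, EOS)
n_masked = sum(token == MASK for token in masked)
print(masked, n_masked)
```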


In [None]:
import torch

masked_structure_tokens = torch.full(
    [prompt_token_length], esm3_constants.STRUCTURE_MASK_TOKEN
)

masked_structure_tokens[0] = esm3_constants.STRUCTURE_BOS_TOKEN
masked_structure_tokens[-1] = esm3_constants.STRUCTURE_EOS_TOKEN

# Reuse the tokens encoded earlier; token 0 is BOS, so residue idx sits at token index idx
for idx in active_site_residues:
    masked_structure_tokens[idx] = otc_reference_tokens.structure[idx]

masked_structure_token_count = (
    (masked_structure_tokens == esm3_constants.STRUCTURE_MASK_TOKEN).sum().item()
)

print(f"Masked structure token count: {masked_structure_token_count}")

assert masked_sequence_token_count == masked_structure_token_count
masked_token_count = masked_sequence_token_count

In [None]:
from esm.sdk.api import ESMProteinTensor

encoded_prompt = ESMProteinTensor(
    sequence=masked_sequence_tokens, 
    structure=masked_structure_tokens
)

In [None]:
print("Reference sequence:")
print(
    format_seq(
        otc_reference_chain.sequence, width=prompt_token_length + 1, line_numbers=False
    )
)
print("Masked sequence:")
print(
    format_seq(
        model.decode(encoded_prompt).sequence,
        width=prompt_token_length + 1,
        line_numbers=False,
    )
)
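# Tokens below 4096 are real structure tokens; mask, BOS, and EOS use IDs from 4096 up.
# The [1:-1] slice drops the BOS and EOS positions.
print("Masked structure:")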
print(
    format_seq(
        "".join(["✔" if st < 4096 else "_" for st in encoded_prompt.structure][1:-1]),
        width=prompt_token_length + 1,
        line_numbers=False,
    )
)

---
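## 4. Generate structure

First, we generate the masked structure tokens, conditioned on the active-site sequence and structure tokens in the prompt. `num_steps` sets how many iterative decoding steps are used to fill in the masked positions.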

In [None]:
structure_generation_config = GenerationConfig(
    track="structure", num_steps=masked_token_count // 8, temperature=1.0
)

generated_protein_1 = model.generate(encoded_prompt, structure_generation_config)

decoded_protein_chain = model.decode(generated_protein_1).to_protein_chain()

view = py3Dmol.view(width=600, height=400)
view.addModel(decoded_protein_chain.infer_oxygen().to_pdb_string(), "pdb")
view.setStyle({"cartoon": {"color": "#007FAA"}})
view.addStyle({"resi": active_site_residues}, {"cartoon": {"color": "#eb982c"}})
view.zoomTo()
view.show()

Verify that:

1. The new structure has an active site very similar to the reference's.
2. The new structure has a very DISSIMILAR backbone.
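Both checks rely on RMSD after optimal superposition: the two coordinate sets are centered, rotated onto each other with the Kabsch algorithm, and the root-mean-square distance between matched atoms is reported. The NumPy sketch below illustrates that calculation on plain N x 3 coordinate arrays. It is not the `ProteinChain.rmsd` implementation, which may differ in details such as which atoms are compared.

```python
import numpy as np


def kabsch_rmsd(mobile, target):
    """RMSD between two N x 3 coordinate sets after optimal rigid superposition."""
    mobile = np.asarray(mobile, dtype=float)
    target = np.asarray(target, dtype=float)
    # Remove translation by centering both sets on their centroids
    p = mobile - mobile.mean(axis=0)
    q = target - target.mean(axis=0)
    # The SVD of the covariance matrix yields the optimal rotation
    u, _, vt = np.linalg.svd(p.T @ q)
    # Flip the last axis if needed so the result is a proper rotation, not a reflection
    d = np.sign(np.linalg.det(vt.T @ u.T))
    rotation = vt.T @ np.diag([1.0, 1.0, d]) @ u.T
    aligned = p @ rotation.T
    return float(np.sqrt(np.mean(np.sum((aligned - q) ** 2, axis=1))))


# A rotated and translated copy of the same points should give an RMSD close to 0
rng = np.random.default_rng(0)
coords = rng.normal(size=(20, 3))
theta = np.deg2rad(30.0)
rot_z = np.array([
    [np.cos(theta), -np.sin(theta), 0.0],
    [np.sin(theta), np.cos(theta), 0.0],
    [0.0, 0.0, 1.0],
])
moved = coords @ rot_z.T + np.array([5.0, -2.0, 1.0])
print(f"{kabsch_rmsd(coords, moved):.6f}")
```

Superposition is why a small active-site RMSD means the local geometry is preserved even when the rest of the protein has moved.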

In [None]:
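# ProteinChain indexing is 0-based and positional, while active_site_residues holds
# 1-based residue numbers, so shift by one before selecting residues
active_site_indices = [idx - 1 for idx in active_site_residues]
constrained_site_rmsd = otc_reference_chain[active_site_indices].rmsd(
    decoded_protein_chain[active_site_indices]
)
backbone_rmsd = otc_reference_chain.rmsd(decoded_protein_chain)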

c_pass = "✅" if constrained_site_rmsd < 1.5 else "❌"
b_pass = "✅" if backbone_rmsd > 1.5 else "❌"

print(f"Constrained site RMSD: {constrained_site_rmsd:.2f} Ang {c_pass}")
print(f"Backbone RMSD: {backbone_rmsd:.2f} Ang {b_pass}")

---
## 5. Generate sequence

Next, we condition on the generated structure, together with the fixed active-site residues, to generate a new sequence.
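Each `generate` call sets `num_steps` from `masked_token_count`. The model fills the masked positions iteratively, committing a share of them at each step, so `masked_token_count // 4` works out to roughly four tokens per step on average (and `// 8` to roughly eight). The sketch below splits masked positions evenly across steps to show that arithmetic. It assumes a linear schedule; the SDK's `GenerationConfig` may use a different schedule by default, so treat the per-step counts as illustrative. `linear_unmask_schedule` and the masked-token count are hypothetical.

```python
def linear_unmask_schedule(n_masked, num_steps):
    """Split n_masked positions into num_steps near-equal batches (linear schedule assumption)."""
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    base, extra = divmod(n_masked, num_steps)
    # The first `extra` steps take one additional token so the total is exact
    return [base + (1 if step < extra else 0) for step in range(num_steps)]


n_masked = 302  # hypothetical masked-token count
steps = linear_unmask_schedule(n_masked, n_masked // 4)
print(len(steps), sum(steps), max(steps))  # 75 302 5
```

Fewer steps means more tokens committed at once, which is faster but gives the model fewer chances to condition on its own earlier choices. Note that a very small masked region can make `masked_token_count // 4` equal 0.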

In [None]:
sequence_generation_config = GenerationConfig(
    track="sequence",
    num_steps=masked_token_count // 4,
    temperature=1.0,
)
generated_protein_2 = model.generate(generated_protein_1, sequence_generation_config)

print("Reference sequence:")
print(
    format_seq(
        otc_reference_chain.sequence, width=prompt_token_length + 1, line_numbers=False
    )
)
print("Masked sequence:")
print(
    format_seq(
        model.decode(encoded_prompt).sequence,
        width=prompt_token_length + 1,
        line_numbers=False,
    )
)
print("Generated sequence:")
print(
    format_seq(
        model.decode(generated_protein_2).sequence, width=prompt_token_length + 1, line_numbers=False
    )
)

Finally, we refold the generated sequence on its own, without any structure conditioning.
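The refolding call uses `temperature=0.0`, whereas the earlier calls used `1.0`. In the usual convention, the logits are divided by the temperature before the softmax: 1.0 samples from the model's distribution unchanged, lower values sharpen it, and 0.0 is treated as greedy argmax decoding, which makes the refold deterministic. The pure-Python sketch below shows that convention on a toy logit vector; it is not the SDK's sampling code, and `sample_with_temperature` is a hypothetical helper.

```python
import math
import random


def sample_with_temperature(logits, temperature, rng=random):
    """Pick an index from logits; a temperature of 0 means greedy argmax."""
    if temperature == 0:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [x / temperature for x in logits]
    # Subtract the maximum before exponentiating for numerical stability
    top = max(scaled)
    weights = [math.exp(x - top) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


toy_logits = [1.0, 3.0, 0.5]  # hypothetical scores for three tokens
print(sample_with_temperature(toy_logits, 0.0))  # 1 (greedy)
rng = random.Random(0)
draws = [sample_with_temperature(toy_logits, 1.0, rng) for _ in range(1000)]
print(draws.count(1) / len(draws))
```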

In [None]:
prompt = ESMProteinTensor(sequence=generated_protein_2.sequence, structure=None)

structure_generation_config = GenerationConfig(
    track="structure", num_steps=masked_token_count // 8, temperature=0.0
)

generated_protein_3 = model.generate(prompt, structure_generation_config)
final_protein = model.decode(generated_protein_3)
print(format_seq(final_protein.sequence))

---
## 6. Validation

Compare the generated sequence to the reference.
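Sequence identity is computed on the alignment rather than the raw strings, because insertions and deletions shift positions. One simple definition counts identical aligned columns over the total number of alignment columns. biotite's `get_sequence_identity` supports other denominators through its `mode` argument, so its numbers can differ from this sketch. The hypothetical function below works on two equal-length gapped strings, with `-` marking a gap.

```python
def percent_identity(aligned_a, aligned_b, gap="-"):
    """Identity over all columns of a pairwise alignment given as equal-length gapped strings."""
    if len(aligned_a) != len(aligned_b):
        raise ValueError("aligned sequences must have equal length")
    # A column counts as identical only if both residues match and neither is a gap
    matches = sum(
        1 for a, b in zip(aligned_a, aligned_b) if a == b and a != gap
    )
    return 100.0 * matches / len(aligned_a)


print(f"{percent_identity('AC-GT', 'ACAGA'):.1f}%")  # 60.0%
```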

In [None]:
import biotite.sequence as seq
import biotite.sequence.align as align
import biotite.sequence.graphics as graphics
import matplotlib.pyplot as pl


seq1 = seq.ProteinSequence(otc_reference_protein.sequence)
seq2 = seq.ProteinSequence(final_protein.sequence)

alignments = align.align_optimal(
    seq1,
    seq2,
    align.SubstitutionMatrix.std_protein_matrix(),
    gap_penalty=(-10, -1),
)

alignment = alignments[0]

identity = align.get_sequence_identity(alignment)
print(f"Sequence identity: {100*identity:.2f}%")

print("\nSequence alignment:")
fig = pl.figure(figsize=(8.0, 4.0))
ax = fig.add_subplot(111)
graphics.plot_alignment_similarity_based(
    ax, alignment, symbols_per_line=45, spacing=2,
    show_numbers=True,
)
fig.tight_layout()
pl.show()

Compare the generated structure to the reference.

In [None]:
generated_chain = final_protein.to_protein_chain()
generated_chain = generated_chain.align(otc_reference_chain)

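# Shift the 1-based residue numbers to 0-based positions for ProteinChain indexing
active_site_indices = [idx - 1 for idx in active_site_residues]
constrained_site_rmsd = otc_reference_chain[active_site_indices].rmsd(
    generated_chain[active_site_indices]
)
backbone_rmsd = otc_reference_chain.rmsd(generated_chain)

c_pass = "✅" if constrained_site_rmsd < 1.5 else "❌"
# No pass/fail threshold is applied to the refolded backbone; it is reported for reference
b_pass = "🤷‍♂️"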

print(f"Constrained site RMSD: {constrained_site_rmsd:.2f} Ang {c_pass}")
print(f"Backbone RMSD: {backbone_rmsd:.2f} Ang {b_pass}")


In [None]:
view = py3Dmol.view(width=600, height=600)
view.addModel(otc_reference_chain.infer_oxygen().to_pdb_string(), "pdb")
view.addModel(generated_chain.infer_oxygen().to_pdb_string(), "pdb")
# Reference in blue, generated design in green, active site in orange on both
view.setStyle({"model": 0}, {"cartoon": {"color": "#007FAA"}})
view.setStyle({"model": 1}, {"cartoon": {"color": "lightgreen"}})
view.addStyle({"resi": active_site_residues}, {"cartoon": {"color": "#eb982c"}})
view.zoomTo()
view.show()

We have generated a new protein with an active site similar to the reference's but a different backbone structure and sequence. Repeating this process many times can produce a library of candidates for lab testing.
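To turn repeated runs into a candidate library, each design can be recorded with the metrics computed above and then filtered. The sketch below assumes a hypothetical list of result dictionaries with `site_rmsd`, `backbone_rmsd`, and `identity` fields. It reuses the 1.5 Å active-site threshold from the checks above; the identity cutoff is an arbitrary illustrative value, not a recommendation.

```python
def select_candidates(results, max_site_rmsd=1.5, max_identity=50.0):
    """Keep designs that preserve the active site but diverge in sequence; best site RMSD first."""
    kept = [
        r for r in results
        if r["site_rmsd"] < max_site_rmsd and r["identity"] < max_identity
    ]
    return sorted(kept, key=lambda r: r["site_rmsd"])


# Hypothetical metrics from three design runs
runs = [
    {"name": "design_1", "site_rmsd": 0.9, "backbone_rmsd": 3.2, "identity": 41.0},
    {"name": "design_2", "site_rmsd": 2.1, "backbone_rmsd": 4.0, "identity": 35.0},
    {"name": "design_3", "site_rmsd": 0.6, "backbone_rmsd": 2.7, "identity": 47.5},
]
print([r["name"] for r in select_candidates(runs)])  # ['design_3', 'design_1']
```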

## Congratulations

You've gone through all the example notebooks. Feel free to experiment further, or go on to the next notebook to learn how to delete the inference endpoints.