In [1]:
import numpy as np
import torch
import py3Dmol
from esm.models.esm3 import ESM3
from esm.sdk.api import ESMProtein, GenerationConfig
from esm.utils.structure.protein_chain import ProteinChain

In [2]:
from esm.utils.misc import huggingfacehub_login

# Log in to the Hugging Face Hub (interactive widget) before the model files
# are fetched from it.
huggingfacehub_login()

# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing while loading the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ESM3.from_pretrained("esm3_sm_open_v1", device=device)

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

  state_dict = torch.load(


PDB ID = 1ITU: human renal dipeptidase (人肾二肽酶)

In [None]:
# Load chain A of PDB entry 1ITU (renal dipeptidase) from RCSB.
pdb_id, chain_id = "1ITU", "A"
renal_dipep_chain = ProteinChain.from_rcsb(pdb_id, chain_id)

In [5]:
# Inspect the loaded chain: its one-letter sequence and its atom37 coordinate
# array. Atom slots without coordinates hold NaN, as the output below shows.
renal_dipep_atom37 = renal_dipep_chain.atom37_positions
print(renal_dipep_chain.sequence)
print("atom37_positions shape: ", renal_dipep_atom37.shape)
print(renal_dipep_atom37[:3])

DFFRDEAERIMRDSPVIDGHNDLPWQLLDMFNNRLQDERANLTTLAGTHTNIPKLRAGFVGGQFWSVYTPCDTQNKDAVRRTLEQMDVVHRMCRMYPETFLYVTSSAGIRQAFREGKVASLIGVEGGHSIDSSLGVLRALYQLGMRYLTLTHSCNTPWADNWLVDTGDSEPQSQGLSPFGQRVVKELNRLGVLIDLAHVSVATMKATLQLSRAPVIFSHSSAYSVCASRRNVPDDVLRLVKQTDSLVMVNFYNNYISCTNKANLSQVADHLDHIKEVAGARAVGFGGDFDGVPRVPEGLEDVSKYPDLIAELLRRNWTEAEVKGALADNLLRVFEAVEQASNLTQAPEEEPIPLDQLGGSCRTHYGYSS
atom37_positions shape:  (369, 37, 3)
[[[-40.525  -9.87   -2.643]
  [-39.79   -9.325  -3.825]
  [-38.765 -10.354  -4.294]
  [-39.096  -8.012  -3.45 ]
  [-37.878 -10.748  -3.53 ]
  [-38.41   -7.359  -4.629]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [-39.105  -7.036  -5.617]
  [-37.177  -7.161  -4.562]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [ 

In [6]:
# py3Dmol consumes PDB text, so serialise the ProteinChain first.
# `pdb_str` is reused by later cells.
pdb_str = renal_dipep_chain.to_pdb_string()

# Render the whole chain as a cartoon coloured along the chain (spectrum).
view = py3Dmol.view(width=500, height=500)
view.addModel(pdb_str, "pdb")
view.setStyle({"cartoon": {"color": "spectrum"}})
view.zoomTo()
view.show()

In [7]:
# Zero-based residue positions 123..145 of the chain form the motif to scaffold.
motif_inds = np.arange(123, 146)

# ProteinChain supports numpy-style indexing: slice it once, then read both the
# sequence and the atom37 coordinates from the resulting sub-chain.
motif_chain = renal_dipep_chain[motif_inds]
motif_sequence = motif_chain.sequence
motif_atom37_positions = motif_chain.atom37_positions

print("Motif sequence: ", motif_sequence)
print("Motif atom37_positions shape: ", motif_atom37_positions.shape)

Motif sequence:  VEGGHSIDSSLGVLRALYQLGMR
Motif atom37_positions shape:  (23, 37, 3)


In [8]:
# py3Dmol selects residues by their number in the PDB text, so shift the
# zero-based motif positions by one.
# NOTE(review): assumes the residue numbers written by `to_pdb_string()` start
# at 1 and are contiguous for this chain -- confirm before trusting the colours.
motif_res_inds = [int(pos) + 1 for pos in motif_inds]

# Grey cartoon for the whole chain, with the motif highlighted in cyan.
view = py3Dmol.view(width=500, height=500)
view.addModel(pdb_str, "pdb")
view.setStyle({"cartoon": {"color": "lightgrey"}})
view.addStyle({"resi": motif_res_inds}, {"cartoon": {"color": "cyan"}})
view.zoomTo()
view.show()

In [9]:
prompt_length = 200
# Zero-based offset at which the motif is placed inside the prompt. The value is
# an arbitrary fixed choice (not random). The alignment cell further down
# hardcodes the same offset (72), so keep the two in sync.
motif_start = 72
motif_end = motif_start + len(motif_sequence)
# List slice assignment silently grows a list when the slice runs past its end,
# which would produce a prompt longer than `prompt_length`. Fail loudly instead.
assert motif_end <= prompt_length, "motif does not fit inside the prompt"

# Sequence prompt: mask tokens ("_") everywhere except the motif residues.
sequence_prompt = ["_"] * prompt_length
sequence_prompt[motif_start:motif_end] = list(motif_sequence)
sequence_prompt = "".join(sequence_prompt)
print("Sequence prompt: ", sequence_prompt)
print("Length of sequence prompt: ", len(sequence_prompt))

# Structure prompt: NaN coordinates (no conditioning) everywhere except the
# motif, which carries its native atom37 coordinates.
structure_prompt = torch.full((prompt_length, 37, 3), np.nan)
structure_prompt[motif_start:motif_end] = torch.tensor(motif_atom37_positions)
print("Structure prompt shape: ", structure_prompt.shape)
# A residue counts as conditioned when at least one of its atoms has no NaN
# coordinate.
print(
    "Indices with structure conditioning: ",
    torch.where(~torch.isnan(structure_prompt).any(dim=-1).all(dim=-1))[0].tolist(),
)

# Compose the sequence and structure prompts into a single ESM3 prompt.
protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)

Sequence prompt:  ________________________________________________________________________VEGGHSIDSSLGVLRALYQLGMR_________________________________________________________________________________________________________
Length of sequence prompt:  200
Structure prompt shape:  torch.Size([200, 37, 3])
Indices with structure conditioning:  [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94]


In [10]:
# Decode the masked positions of the sequence track. The number of decoding
# steps is half the number of mask tokens; temperature 0.5 controls how
# random the sampling is.
n_masked = sequence_prompt.count("_")
sequence_generation_config = GenerationConfig(
    track="sequence",
    num_steps=n_masked // 2,
    temperature=0.5,
)
sequence_generation = model.generate(protein_prompt, sequence_generation_config)

print("Sequence Prompt:\n\t", protein_prompt.sequence)
print("Generated sequence:\n\t", sequence_generation.sequence)

  state_dict = torch.load(
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):  # type: ignore
100%|██████████| 88/88 [00:04<00:00, 19.73it/s]
  state_dict = torch.load(
  state_dict = torch.load(


Sequence Prompt:
	 ________________________________________________________________________VEGGHSIDSSLGVLRALYQLGMR_________________________________________________________________________________________________________
Generated sequence:
	 LDKLRAGGVGAQFWSVYVPAEYQGGDAVRRTLEQIDLVKRLVAAYPDDFELAYTAADIPRIVASGKIASMIGVEGGHSIDSSLGVLRALYQLGMRYMTLTWNADNDWADSATGAGPVHNGLSDFGREVVREMNRLGIMVDLSHVSDATFWDALAVSTAPVIASHSSARALADHPRNMTDEQLAALAANGGVIMINFYAGY


In [11]:
# Fold the designed sequence: give ESM3 only the sequence and let it generate
# the structure track (length // 8 decoding steps, temperature 0.7).
structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)
structure_prediction_config = GenerationConfig(
    track="structure",
    num_steps=len(sequence_generation) // 8,
    temperature=0.7,
)
structure_prediction = model.generate(
    structure_prediction_prompt,
    structure_prediction_config,
)

100%|██████████| 25/25 [00:01<00:00, 13.09it/s]


In [12]:
# Convert the generated structure back into a ProteinChain object.
structure_prediction_chain = structure_prediction.to_protein_chain()

# Positions of the motif inside the generated protein. The offset (72) must
# match the one used when the prompt was built.
motif_inds_in_generation = np.arange(72, 72 + len(motif_sequence))

# Superimpose the generated structure onto the native one using the motif
# residues. `align` returns the aligned chain (the helix-shortening cell also
# uses its return value), so keep the result; discarding it leaves the rendered
# model unaligned.
# NOTE(review): if `align` also works in place, the assignment is merely redundant.
structure_prediction_chain = structure_prediction_chain.align(
    renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds
)

# Motif cRMSD between the generated and native structures.
crmsd = structure_prediction_chain.rmsd(
    renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds
)
print(
    "cRMSD of the motif in the generated structure vs the original structure: ", crmsd
)

# Side-by-side view: native (left, grey) vs design (right, green); motif in cyan.
view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))
view.addModel(pdb_str, "pdb", viewer=(0, 0))
view.addModel(structure_prediction_chain.to_pdb_string(), "pdb", viewer=(0, 1))
view.setStyle({"cartoon": {"color": "lightgrey"}}, viewer=(0, 0))
view.setStyle({"cartoon": {"color": "lightgreen"}}, viewer=(0, 1))
view.addStyle({"resi": motif_res_inds}, {"cartoon": {"color": "cyan"}}, viewer=(0, 0))
view.addStyle(
    {"resi": (motif_inds_in_generation + 1).tolist()},
    {"cartoon": {"color": "cyan"}},
    viewer=(0, 1),
)
view.zoomTo()
view.show()

cRMSD of the motif in the generated structure vs the original structure:  0.1737958035499259


In [13]:
# Chain A of PDB entry 7XBQ. Its helix-coil-helix region (zero-based residues
# 38..110) is the part that will be shortened below.
helix_shortening_chain = ProteinChain.from_rcsb("7XBQ", "A")
helix_region = np.arange(38, 111)  # zero-indexed

# Grey cartoon with the helix-coil-helix region in light blue. py3Dmol residue
# numbers are one-based, hence the +1.
view = py3Dmol.view(width=500, height=500)
view.addModel(helix_shortening_chain.to_pdb_string(), "pdb")
view.setStyle({"cartoon": {"color": "lightgrey"}})
view.addStyle(
    {"resi": [int(res) + 1 for res in helix_region]},
    {"cartoon": {"color": "lightblue"}},
)
view.zoomTo()
view.show()

# Hardcoded per-residue SS8 string for this chain. Besides H and C it also uses
# the codes S, T, I and G, which the printed legend does not list.
helix_shortening_ss8 = "CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC"
print(
    "Secondary structure of protein: (H: Alpha Helix, E: Beta Strand, C: Coil) \n\t",
    helix_shortening_ss8,
)

Secondary structure of protein: (H: Alpha Helix, E: Beta Strand, C: Coil) 
	 CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC


In [14]:
shortened_region_length = 45
# Length of the coil that joins the two helices in the shortened region.
linker_length = 3

# Split the remaining positions between the two helices. The second helix takes
# whatever the first does not, so the designed block is always exactly
# `shortened_region_length` long. Using `(n - 3) // 2` for both helices would
# come out one residue short whenever n - 3 is odd, which would desynchronise
# the sequence and SS8 prompts.
first_helix_length = (shortened_region_length - linker_length) // 2
second_helix_length = shortened_region_length - linker_length - first_helix_length

# Plain ints so they can be used directly for slicing and string repetition.
region_start = int(helix_region[0])
region_stop = int(helix_region[-1]) + 1  # exclusive end of the original region

# Sequence prompt: keep the flanking residues and mask the (shortened)
# helix-coil-helix region.
sequence_prompt = (
    helix_shortening_chain.sequence[:region_start]
    + "_" * shortened_region_length
    + helix_shortening_chain.sequence[region_stop:]
)
print("Sequence prompt:\n\t", sequence_prompt)

# SS8 prompt: keep the flanks' secondary structure, and request two shorter
# helices joined by a short coil in the masked region.
ss8_prompt = (
    helix_shortening_ss8[:region_start]
    + "H" * first_helix_length
    + "C" * linker_length
    + "H" * second_helix_length
    + helix_shortening_ss8[region_stop:]
)
# The two prompts are built position by position, so their lengths must agree.
assert len(ss8_prompt) == len(sequence_prompt), "sequence/SS8 prompt length mismatch"
print("SS8 prompt:\n\t", ss8_prompt)
print(
    "Proposed SS8 for shortened helix-coil-helix region:\n\t",
    " " * region_start
    + ss8_prompt[region_start : region_start + shortened_region_length],
)

print("")
print("Original sequence:\n\t", helix_shortening_chain.sequence)
print("Original SS8:\n\t", helix_shortening_ss8)
print(
    "Original SS8 for helix-coil-helix region:\n\t",
    " " * region_start + helix_shortening_ss8[region_start:region_stop],
)

# Compose the sequence and secondary-structure prompts into one ESM3 prompt.
protein_prompt = ESMProtein(sequence=sequence_prompt, secondary_structure=ss8_prompt)

Sequence prompt:
	 MAREENVYMAKLAEQAERYEEMVQFMEKVSTSLGSEEL_____________________________________________SASNGDSKVFYLKMKGDYHRYLAEFKTGAERKEAAESTLSAYKAAQDIANTELAPTHPIRLGLALNFSVFYYEILNSPDRACNLAKQAFDEAIAELDTLGEESYKDSTLIMQLLRDNLTLWT
SS8 prompt:
	 CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCHHHHHHHHHHHHHHHHHHHHHCCCHHHHHHHHHHHHHHHHHHHHHGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC
Proposed SS8 for shortened helix-coil-helix region:
	                                       HHHHHHHHHHHHHHHHHHHHHCCCHHHHHHHHHHHHHHHHHHHHH

Original sequence:
	 MAREENVYMAKLAEQAERYEEMVQFMEKVSTSLGSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEESRGNEEHVKCIKEYRSKIESELSNICDGILKLLDSNLIPSASNGDSKVFYLKMKGDYHRYLAEFKTGAERKEAAESTLSAYKAAQDIANTELAPTHPIRLGLALNFSVFYYEILNSPDRACNLAKQAFDEAIAELDTLGEESYKDSTLIMQLLRDNLTLWT
Original SS8:
	 CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHH

In [15]:
# Step 1: fill in the masked residues (half as many steps as mask tokens, T=0.5).
seq_config = GenerationConfig(
    track="sequence",
    num_steps=protein_prompt.sequence.count("_") // 2,
    temperature=0.5,
)
print("Generating protein sequence...")
sequence_generation = model.generate(protein_prompt, seq_config)

# Step 2: fold the designed sequence (length // 4 steps, temperature 0).
print("Folding protein...")
fold_config = GenerationConfig(
    track="structure", num_steps=len(protein_prompt) // 4, temperature=0
)
structure_prediction = model.generate(
    ESMProtein(sequence=sequence_generation.sequence), fold_config
)

Generating protein sequence...


100%|██████████| 22/22 [00:01<00:00, 17.90it/s]


Folding protein...


100%|██████████| 51/51 [00:02<00:00, 19.72it/s]


In [16]:
# Number of C-terminal residues used to superimpose the design onto the native
# chain. These residues come from the unmasked C-terminal flank copied from the
# native sequence, so they correspond one-to-one between the two chains. Must
# not exceed the length of that flank.
n_anchor = 120

predicted_chain = structure_prediction.to_protein_chain()
predicted_chain = predicted_chain.align(
    helix_shortening_chain,
    mobile_inds=np.arange(len(predicted_chain) - n_anchor, len(predicted_chain)),
    target_inds=np.arange(
        len(helix_shortening_chain) - n_anchor, len(helix_shortening_chain)
    ),
)

# Native (left, original region in light blue) vs design (right, shortened
# region in pink). The shortened region starts where the original one did and
# is `shortened_region_length` residues long; +1 converts to PDB numbering.
shortened_region = np.arange(
    helix_region[0], helix_region[0] + shortened_region_length
)
view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))
view.addModel(helix_shortening_chain.to_pdb_string(), "pdb", viewer=(0, 0))
view.addModel(predicted_chain.to_pdb_string(), "pdb", viewer=(0, 1))
view.setStyle({"cartoon": {"color": "lightgrey"}})
view.addStyle(
    {"resi": (helix_region + 1).tolist()},
    {"cartoon": {"color": "lightblue"}},
    viewer=(0, 0),
)
view.addStyle(
    {"resi": (shortened_region + 1).tolist()},
    {"cartoon": {"color": "pink"}},
    viewer=(0, 1),
)
view.zoomTo()
view.show()

In [17]:
# Chain A of PDB entry 1LBS. The zero-based span [span_start, span_end) is
# highlighted in red; py3Dmol residue numbers are one-based.
lipase_chain = ProteinChain.from_rcsb("1LBS", "A")
span_start, span_end = 105, 116
span_resi = list(range(span_start + 1, span_end + 1))

view = py3Dmol.view(width=500, height=500)
view.addModel(lipase_chain.to_pdb_string(), "pdb")
view.setStyle({"cartoon": {"color": "lightgrey"}})
view.addStyle({"resi": span_resi}, {"cartoon": {"color": "red"}})
view.zoomTo()
view.show()

# Hardcoded per-residue SS8 string for this chain.
lipase_ss8 = "CCSSCCCCSSCHHHHHHTEEETTBBTTBCSSEEEEECCTTCCHHHHHTTTHHHHHHHTTCEEEEECCTTTTCSCHHHHHHHHHHHHHHHHHHTTSCCEEEEEETHHHHHHHHHHHHCGGGGGTEEEEEEESCCTTCBGGGHHHHHTTCBCHHHHHTBTTCHHHHHHHHTTTTBCSSCEEEEECTTCSSSCCCCSSSTTSTTCCBTSEEEEHHHHHCTTCCCCSHHHHHBHHHHHHHHHHHHCTTSSCCGGGCCSTTCCCSBCTTSCHHHHHHHHSTHHHHHHHHHHSCCBSSCCCCCGGGGGGSTTCEETTEECCC"

In [18]:
# Condition only on the highlighted span. Every other residue gets NaN
# coordinates, no SASA value, and a masked sequence position.
span = slice(span_start, span_end)
span_len = span_end - span_start

structure_prompt = torch.full((len(lipase_chain), 37, 3), torch.nan)
structure_prompt[span] = torch.tensor(
    lipase_chain[span].atom37_positions, dtype=torch.float32
)

# SASA target of 40.0 for each residue in the span; None elsewhere (presumably
# "unconditioned" for ESM3 -- confirm).
sasa_prompt = [None] * len(lipase_chain)
sasa_prompt[span] = [40.0] * span_len

print("SASA prompt (just for buried region): ", sasa_prompt[span])

protein_prompt = ESMProtein(
    sequence="_" * len(lipase_chain), coordinates=structure_prompt, sasa=sasa_prompt
)

SASA prompt (just for buried region):  [40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0]


In [19]:
# Sample several designs for the same prompt, then rank them by pTM.
N_SAMPLES = 16
# Seed torch so that re-running the notebook reproduces the same designs.
# NOTE(review): assumes ESM3 sampling draws from torch's global RNG -- confirm;
# some CUDA kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)

generated_proteins = []
for _ in range(N_SAMPLES):
    print("Generating protein sequence...")
    sequence_generation = model.generate(
        protein_prompt,
        GenerationConfig(
            track="sequence", num_steps=len(protein_prompt) // 8, temperature=0.7
        ),
    )
    print("Folding protein...")
    structure_prediction = model.generate(
        ESMProtein(sequence=sequence_generation.sequence),
        GenerationConfig(track="structure", num_steps=len(protein_prompt) // 32),
    )
    generated_proteins.append(structure_prediction)

# Sort generations by pTM, best first.
generated_proteins = sorted(
    generated_proteins, key=lambda protein: protein.ptm.item(), reverse=True
)

  with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):  # type: ignore


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 18.41it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 16.71it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 17.76it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 16.74it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 17.72it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 16.48it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 17.22it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 16.90it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 17.17it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 16.52it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 16.57it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 16.78it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 16.90it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 16.22it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 16.43it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 16.24it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 17.21it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 16.15it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 17.53it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 16.62it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 17.11it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 15.78it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 17.37it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 16.56it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 17.35it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 16.52it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 17.26it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 16.15it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 17.10it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 15.53it/s]


Generating protein sequence...


100%|██████████| 39/39 [00:02<00:00, 17.28it/s]


Folding protein...


100%|██████████| 9/9 [00:00<00:00, 16.90it/s]


In [20]:
# Show the native chain next to the best-ranked designs (highest pTM first).
N_SAMPLES_TO_SHOW = 4
# Never index past the available designs; the grid is sized to what is shown.
n_show = min(N_SAMPLES_TO_SHOW, len(generated_proteins))

view = py3Dmol.view(width=1000, height=500, viewergrid=(1, n_show + 1))
view.addModel(lipase_chain.to_pdb_string(), "pdb", viewer=(0, 0))
for rank, protein in enumerate(generated_proteins[:n_show], start=1):
    print("PTM of generated protein {}: {:.2f}".format(rank, protein.ptm.item()))
    view.addModel(
        protein.to_protein_chain().to_pdb_string(),
        "pdb",
        viewer=(0, rank),
    )
view.setStyle({"cartoon": {"color": "lightgrey"}})
# Designs have the same length as the native chain and are conditioned on the
# same span, so the same residue numbers are highlighted in every panel.
view.addStyle(
    {"resi": (np.arange(span_start, span_end) + 1).tolist()},
    {"cartoon": {"color": "red"}},
)
view.zoomTo()
view.show()

PTM of generated protein 1: 0.62
PTM of generated protein 2: 0.53
PTM of generated protein 3: 0.40
PTM of generated protein 4: 0.39
