In [1]:
import biotite.structure as bs
import py3Dmol

from esm.sdk.api import ESMProtein, GenerationConfig
from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction

class PTMScoringFunction(GuidedDecodingScoringFunction):
    def __call__(self, protein: ESMProtein) -> float:
        # Minimal example of a scoring function that scores proteins based on their pTM score
        # Given that ESM3 already has a pTM prediction head, we can directly access the pTM score
        assert protein.ptm is not None, "Protein must have pTM scores to be scored"
        return float(protein.ptm)


In [2]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
from esm.models.esm3 import ESM3
model = ESM3.from_pretrained()  


Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

In [4]:
ptm_guided_decoding = ESM3GuidedDecoding(
    client=model, scoring_function=PTMScoringFunction()
)

In [5]:
# Start from a fully masked protein
PROTEIN_LENGTH = 256
starting_protein = ESMProtein(sequence="_" * PROTEIN_LENGTH)

# Call guided_generate
generated_protein = ptm_guided_decoding.guided_generate(
    protein=starting_protein,
    num_decoding_steps=len(starting_protein) // 8,
    num_samples_per_step=10,
)

Current score: 0.93: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [36:39<00:00, 68.73s/it]


In [6]:
generated_protein_no_guided: ESMProtein = model.generate(
    input=starting_protein,
    config=GenerationConfig(track="sequence", num_steps=len(starting_protein) // 8),
)  # type: ignore

generated_protein_no_guided: ESMProtein = model.generate(
    input=generated_protein_no_guided,
    config=GenerationConfig(track="structure", num_steps=1),
)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [01:27<00:00,  2.74s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.61s/it]


In [7]:
print(f"pTM Without guidance: {generated_protein_no_guided.ptm[0].item():.3f}")
print(f"pTM With guidance: {generated_protein.ptm[0].item():.3f}")


pTM Without guidance: 0.251
pTM With guidance: 0.934


In [17]:
print("Sequence (no guidance):", generated_protein_no_guided.sequence)
print("Sequence (with guidance):", generated_protein.sequence)
import py3Dmol
from IPython.display import display

def show_protein(pdb_string, style='cartoon'):
    view = py3Dmol.view(width=400, height=400)
    view.addModel(pdb_string, 'pdb')
    view.setStyle({style: {'color': 'spectrum'}})
    view.zoomTo()
    display(view.show()) 


Sequence (no guidance): MLEALLLLLLLLLLGLLLLLLVLLLLLALLLDLLLLLLLLLLLGLSLGAAMAAAAAAAAVAAAAAAAAAAAAAAAYLLLLLLALNLLLLGAFLLLLLLLLLLLLLAQLGAEAAAAFASAAAAAAAAAAAAAKAAAAALAAAATVAAALAALLLLLLSLLPAAAAAAAAAAGAALAAAAAEATATAAAAAAAAAAALVAAAPAAAAAIAAAAAKAAAAAAAAAAAAASAALLLLLLLALLLLNAALLLVLLLLVLLL
Sequence (with guidance): MNLMGKIAVVTGASGGIGLKTAACLADLGVKIYLVARNDEKLETARRLIAERGGTAFALQADLLDDASLLQGKARLLALEDSVDFLVLNAAYWVEGNAESAPRNGVHALFAYNVFGAIAMIQVMLPLLRAAGGCIIAVGSALHADPGPFNDTDRYPAGKKPPRYIEFAAAKQALAGMVKQLAAELAPEVRAMAYDPGLQVDTAANHANWPQHAHRWTPPEMAARAICYLVSPEAEDWHGKVIACRDVLKLRGEIKW


AttributeError: 'ESMProtein' object has no attribute 'pdb'

In [20]:
print("Sequence (no guidance):", generated_protein_no_guided.sequence)
print("Sequence (with guidance):", generated_protein.sequence)

import py3Dmol
from IPython.display import display

def show_protein(pdb_string, style='cartoon'):
    view = py3Dmol.view(width=400, height=400)
    view.addModel(pdb_string, 'pdb')
    view.setStyle({style: {'color': 'spectrum'}})
    view.zoomTo()
    return view.show()  

# visualize both
show_protein(generated_protein_no_guided.to_pdb_string())
show_protein(generated_protein.to_pdb_string())


Sequence (no guidance): MLEALLLLLLLLLLGLLLLLLVLLLLLALLLDLLLLLLLLLLLGLSLGAAMAAAAAAAAVAAAAAAAAAAAAAAAYLLLLLLALNLLLLGAFLLLLLLLLLLLLLAQLGAEAAAAFASAAAAAAAAAAAAAKAAAAALAAAATVAAALAALLLLLLSLLPAAAAAAAAAAGAALAAAAAEATATAAAAAAAAAAALVAAAPAAAAAIAAAAAKAAAAAAAAAAAAASAALLLLLLLALLLLNAALLLVLLLLVLLL
Sequence (with guidance): MNLMGKIAVVTGASGGIGLKTAACLADLGVKIYLVARNDEKLETARRLIAERGGTAFALQADLLDDASLLQGKARLLALEDSVDFLVLNAAYWVEGNAESAPRNGVHALFAYNVFGAIAMIQVMLPLLRAAGGCIIAVGSALHADPGPFNDTDRYPAGKKPPRYIEFAAAKQALAGMVKQLAAELAPEVRAMAYDPGLQVDTAANHANWPQHAHRWTPPEMAARAICYLVSPEAEDWHGKVIACRDVLKLRGEIKW


In [29]:
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from Bio.SeqUtils import molecular_weight

class HydrophobicityScoringFunction(GuidedDecodingScoringFunction):
    
    def __init__(self, target_gravy: float = 0.0, penalty_weight: float = 1.0):
        self.target_gravy = target_gravy
        self.penalty_weight = penalty_weight
    
    def __call__(self, protein: ESMProtein) -> float:
        seq = str(protein.sequence).replace("_", "")
        
        if len(seq) < 10:
            return -100.0
        
        try:
            analyzer = ProteinAnalysis(seq)
            gravy = analyzer.gravy()
            return -abs(gravy - self.target_gravy) * self.penalty_weight
        except Exception:
            return -100.0


In [41]:
scoring_fn = HydrophobicityScoringFunction(target_gravy=0.5)
guided_decoding = ESM3GuidedDecoding(
     client=model,
     scoring_function=scoring_fn
 )
starting_protein = ESMProtein(sequence="_" * 256)
generated_protein2 = guided_decoding.guided_generate(
     protein=starting_protein,
     num_decoding_steps=32,
     num_samples_per_step=10,
 )


Current score: -0.01: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [38:09<00:00, 71.55s/it]


In [47]:
generated_protein_no_guided2: ESMProtein = model.generate(
    input=starting_protein,
    config=GenerationConfig(track="sequence", num_steps=len(starting_protein) // 8),
)  

generated_protein_no_guided2: ESMProtein = model.generate(
    input=generated_protein_no_guided2,
    config=GenerationConfig(track="structure", num_steps=1),
)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [01:26<00:00,  2.70s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.02s/it]


In [49]:
from Bio.SeqUtils.ProtParam import ProteinAnalysis

# Compute GRAVY for the unguided protein
seq_no_guided2 = str(generated_protein_no_guided2.sequence).replace("_", "")
gravy_no_guided = ProteinAnalysis(seq_no_guided2).gravy()

# Compute GRAVY for the guided protein
seq_guided2 = str(generated_protein2.sequence).replace("_", "")
gravy_guided2 = ProteinAnalysis(seq_guided2).gravy()
print(f"GRAVY Without guidance: {gravy_no_guided2:.3f}")
print(f"GRAVY With guidance:    {gravy_guided2:.3f}")


GRAVY Without guidance: 2.470
GRAVY With guidance:    0.491


In [62]:
class InstabilityIndexScoringFunction(GuidedDecodingScoringFunction):
    """
    Scores based on instability index (lower is more stable).
    
    The instability index estimates protein stability in a test tube.
    - Values < 40 indicate the protein is predicted to be stable
    - Values > 40 indicate the protein may be unstable
    
    Returns negative of instability index to favor stable proteins.
    """
    
    def __call__(self, protein: ESMProtein) -> float:
        seq = str(protein.sequence).replace("_", "")
        
        if len(seq) < 10:
            return -100.0
        
        try:
            analyzer = ProteinAnalysis(seq)
            instability = analyzer.instability_index()
            return -instability
        except Exception:
            return -100.0


In [63]:
scoring_fn = InstabilityIndexScoringFunction()
guided_decoding = ESM3GuidedDecoding(
     client=model,
     scoring_function=scoring_fn
 )
starting_protein = ESMProtein(sequence="_" * 256)
generated_protein3 = guided_decoding.guided_generate(
     protein=starting_protein,
     num_decoding_steps=32,
     num_samples_per_step=10,
 )

Current score: -1.06: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [43:30<00:00, 81.57s/it]


In [64]:
generated_protein_no_guided3: ESMProtein = model.generate(
    input=starting_protein,
    config=GenerationConfig(track="sequence", num_steps=len(starting_protein) // 8),
)  

generated_protein_no_guided3: ESMProtein = model.generate(
    input=generated_protein_no_guided,
    config=GenerationConfig(track="structure", num_steps=1),
)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [01:29<00:00,  2.78s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.59s/it]


In [65]:
from Bio.SeqUtils.ProtParam import ProteinAnalysis

# Compute instability for the unguided protein
seq_no_guided3 = str(generated_protein_no_guided3.sequence).replace("_", "")
instability_no_guided = ProteinAnalysis(seq_no_guided3).instability_index()

# Compute instability for the guided protein
seq_guided3 = str(generated_protein3.sequence).replace("_", "")
instability_guided = ProteinAnalysis(seq_guided3).instability_index()
print(f"Instability Without guidance: {instability_no_guided:.3f}")
print(f"Instability With guidance:    {instability_guided:.3f}")


Instability Without guidance: 12.126
Instability With guidance:    1.065


In [67]:
class AromaticityScoringFunction(GuidedDecodingScoringFunction):
    """
    Scores based on aromaticity (fraction of Phe, Trp, Tyr).
    
    Aromatic residues are important for:
    - Protein-protein interactions
    - Binding sites
    - Structural stability
    
    Typical aromaticity: 5-15% of residues
    
    Args:
        target_aromaticity: Desired fraction of aromatic residues (0-1)
        penalty_weight: Weight for deviation from target
    """
    
    def __init__(self, target_aromaticity: float = 0.1, penalty_weight: float = 10.0):
        self.target_aromaticity = target_aromaticity
        self.penalty_weight = penalty_weight
    
    def __call__(self, protein: ESMProtein) -> float:
        seq = str(protein.sequence).replace("_", "")
        
        if len(seq) < 10:
            return -100.0
        
        try:
            analyzer = ProteinAnalysis(seq)
            aromaticity = analyzer.aromaticity()
            return -abs(aromaticity - self.target_aromaticity) * self.penalty_weight
        except Exception:
            return -100.0
scoring_fn = AromaticityScoringFunction(target_aromaticity=0.15)
guided_decoding = ESM3GuidedDecoding(
     client=model,
     scoring_function=scoring_fn
 )
starting_protein = ESMProtein(sequence="_" * 256)
generated_protein4 = guided_decoding.guided_generate(
     protein=starting_protein,
     num_decoding_steps=32,
     num_samples_per_step=10,
 )

Current score: -0.02: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [39:46<00:00, 74.57s/it]


In [69]:
generated_protein_no_guided4: ESMProtein = model.generate(
    input=starting_protein,
    config=GenerationConfig(track="sequence", num_steps=len(starting_protein) // 8),
)  

generated_protein_no_guided4: ESMProtein = model.generate(
    input=generated_protein_no_guided,
    config=GenerationConfig(track="structure", num_steps=1),
)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [01:33<00:00,  2.92s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.31s/it]


In [72]:
from Bio.SeqUtils.ProtParam import ProteinAnalysis

# Compute instability for the unguided protein
seq_no_guided4 = str(generated_protein_no_guided4.sequence).replace("_", "")
aromaticity_no_guided4 = ProteinAnalysis(seq_no_guided4).aromaticity()

# Compute instability for the guided protein
seq_guided4 = str(generated_protein4.sequence).replace("_", "")
aromaticity_guided4 = ProteinAnalysis(seq_guided4).aromaticity()
print(f"Aromaticity Without guidance: {aromaticity_no_guided4:.3f}")
print(f"Aromaticity With guidance:    {aromaticity_guided4:.3f}")


Aromaticity Without guidance: 0.012
Aromaticity With guidance:    0.148
