In [6]:
from abc import ABC, abstractmethod
import torch
import torch.nn.functional as F
from tqdm import tqdm
import attr
import random
import os
import sys
from datetime import datetime

from esm.sdk.api import (
    ESM3InferenceClient,
    ESMProtein,
    ESMProteinError,
    ESMProteinTensor,
    SamplingConfig,
    SamplingTrackConfig,
    LogitsConfig,
)
from esm.models.esm3 import ESM3
from esm.tokenization import get_esm3_model_tokenizers

# Import classes from denoising_strategies.py
# Import the Tee and PrintFormatter classes from denoising_strategies.py
from denoising_strategies import Tee, PrintFormatter

# Import the BaseDenoising and MaxProbBasedDenoising classes
from denoising_strategies import BaseDenoising, MaxProbBasedDenoising, EntropyBasedDenoising
from benchmarking_utils import single_metric_UACCE

In [7]:
## On Forge with larger ESM3 models
from esm.sdk import client
token = os.getenv("ESM_FORGE_API_KEY")
client = client(model="esm3-open", url="https://forge.evolutionaryscale.ai", token=token)

In [8]:
# --- Configuration ---
TEST_SEQUENCE = "ACDE"
NOISE_PERCENTAGE = 50.0 # Mask 50% initially (2 positions for length 4)
NUM_DECODING_STEPS = 2 # Number of steps to unmask
TEMPERATURE = 0.0
TRACK = "sequence"
# --- End Configuration ---

# Create a dummy protein
protein = ESMProtein(sequence=TEST_SEQUENCE)
protein_tensor = client.encode(protein)
print(f"Original Protein: {protein.sequence}\n")

Original Protein: ACDE



In [9]:
# Instantiate Denoiser with local model
denoiser = EntropyBasedDenoising(client)
denoiser.track = TRACK # Set track for prints

Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

In [10]:
denoiser.denoise(protein, NOISE_PERCENTAGE, NUM_DECODING_STEPS, TEMPERATURE, TRACK)

Starting entropy-based denoising process
├── Adding noise to protein tensor
│   ├── Masked positions (sequence): [2, 1]
└── Resulting tensor: tensor([ 0, 32, 32, 13,  9,  2])
├── Initial sequence: __DE
├── Starting denoising steps:


Denoising:   0%|          | 0/2 [00:00<?, ?it/s]

│   ├── Step 1/2
│   │   ├── Computing position entropies
│   │   │   ├── Raw entropies: tensor([1.8328, 1.5358, 2.8640, 2.8807, 2.4746, 2.6426])
│   │   └── Masked positions entropies: tensor([   inf, 1.5358, 2.8640,    inf,    inf,    inf])
│   │   ├── Unmasking position 1
<class 'int'> 4
│   │   └── Predicted token: L (4)


Denoising:  50%|█████     | 1/2 [00:01<00:01,  1.44s/it]

│   └── Current sequence: L_DE
│   ├── Step 2/2
│   │   ├── Computing position entropies
│   │   │   ├── Raw entropies: tensor([1.8412, 1.6270, 2.8515, 2.8643, 2.3652, 2.5886])
│   │   └── Masked positions entropies: tensor([   inf,    inf, 2.8515,    inf,    inf,    inf])
│   │   ├── Unmasking position 2
<class 'int'> 10
│   │   └── Predicted token: R (10)


Denoising: 100%|██████████| 2/2 [00:02<00:00,  1.41s/it]

│   └── Current sequence: LRDE





│   ├── Final denoised sequence: LRDE
└── Total model calls: 3


ESMProtein(sequence='LRDE', secondary_structure=None, sasa=None, function_annotations=None, coordinates=tensor([[[ -0.2313,   1.7491,   4.2768],
         [  0.8281,   0.7578,   4.1250],
         [  0.3268,  -0.4602,   3.3563],
         [  1.3604,   0.3280,   5.4946],
         [ -0.8824,  -0.6548,   3.2133],
         [  2.2913,   1.3151,   6.2005],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [  2.4358,   0.9468,   7.6733],
         [  3.6536,   1.3487,   5.5162],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      nan],
         [     nan,      nan,      

In [11]:
protein, cost = denoiser.return_generation()
print(f"Final Protein: {protein.sequence}\n")
print(f"Cost: {cost}\n")

Final Protein: LRDE

Cost: 3



In [12]:
print(f"pLDDT: {protein.plddt}\n")

pLDDT: tensor([0.6986, 0.7365, 0.7167, 0.6839])



In [13]:
print(f"pTM: {protein.ptm}\n")

pTM: 0.020479248836636543

