In [3]:
# precompute_embeddings.py
import math
import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import login
from esm.models.esm3 import ESM3

In [4]:
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig

In [5]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [6]:
# Authenticate with the Hugging Face Hub (interactive prompt) before the
# model weights are fetched further down in the notebook.
# NOTE(review): presumably needed because the "esm3-open" weights are gated — confirm.
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [7]:
def run_roundtrip_for_sequence(
    seq: str,
    out_dir: str,
    protein_id: str,
    num_steps: int = 8,
    temperature: float = 0.7,
):
    """
    Complete a (possibly masked) sequence, fold it, then round-trip it.

    The pipeline is: fill in masked positions of ``seq`` -> predict a
    structure -> discard the sequence and regenerate it from the structure ->
    discard the coordinates and fold again. Two PDB files are written:
    ``<protein_id>_generation.pdb`` (after the first fold) and
    ``<protein_id>_round_tripped.pdb`` (after the round trip).

    Relies on the notebook-level ``model`` object being defined before the
    function is called.

    Args:
        seq: Amino-acid sequence; ``'_'`` characters mark masked positions.
        out_dir: Directory that receives the PDB files. Created (including
            parents) when it does not exist yet.
        protein_id: Prefix used for the two output file names.
        num_steps: Number of generation steps used for every ``generate`` call.
        temperature: Sampling temperature for the initial sequence completion
            only. Defaults to 0.7.

    Returns:
        A ``(generation_pdb_path, round_tripped_pdb_path)`` tuple of strings.
    """
    # Imported locally because the notebook's import cell does not provide
    # ``Path``; without it the first statement below raised NameError.
    from pathlib import Path

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    protein = ESMProtein(sequence=seq)

    # 1) Complete / generate the sequence (when it contains '_' masks).
    protein = model.generate(
        protein,
        GenerationConfig(track="sequence", num_steps=num_steps, temperature=temperature),
    )

    # 2) Predict a structure for that sequence.
    protein = model.generate(
        protein,
        GenerationConfig(track="structure", num_steps=num_steps),
    )
    gen_pdb_path = out_dir / f"{protein_id}_generation.pdb"
    protein.to_pdb(str(gen_pdb_path))

    # 3) Round trip: drop the sequence and regenerate it from the structure,
    #    then drop the coordinates and fold the regenerated sequence again.
    protein.sequence = None
    protein = model.generate(
        protein,
        GenerationConfig(track="sequence", num_steps=num_steps),
    )

    protein.coordinates = None
    protein = model.generate(
        protein,
        GenerationConfig(track="structure", num_steps=num_steps),
    )
    rt_pdb_path = out_dir / f"{protein_id}_round_tripped.pdb"
    protein.to_pdb(str(rt_pdb_path))

    return str(gen_pdb_path), str(rt_pdb_path)

In [8]:
# This will download the model weights and instantiate the model on your machine.
# The model is moved to `device` (chosen in an earlier cell) instead of a
# hard-coded "cuda", so the notebook also runs on hosts without a GPU.
model = ESM3.from_pretrained("esm3-open").to(device)

Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

  state_dict = torch.load(


In [10]:
# torch.save(model, "./esm3-open.pth")  # saves the entire model object

In [11]:

# Round-trip design demo on a partial Carbonic Anhydrase (2vvb).
# The '_' characters in the prompt are the masked positions to be filled in.
prompt = "___________________________________________________DQATSLRILNNGHAFNVEFDDSQDKAVLKGGPLDGTYRLIQFHFHWGSLDGQGSEHTVDKKKYAAELHLVHWNTKYGDFGKAVQQPDGLAVLGIFLKVGSAKPGLQKVVDVLDSIKTKGKSADFTNFDPRGLLPESLDYWTYPGSLTTPP___________________________________________________________"
protein = ESMProtein(sequence=prompt)

# Step 1: iteratively unmask the sequence track.
protein = model.generate(
    protein,
    GenerationConfig(track="sequence", num_steps=8, temperature=0.7),
)

# Step 2: predict the structure for the generated sequence and save it.
protein = model.generate(
    protein,
    GenerationConfig(track="structure", num_steps=8),
)
protein.to_pdb("./generation.pdb")

# Step 3: inverse-fold — drop the sequence and regenerate it from the structure.
protein.sequence = None
protein = model.generate(
    protein,
    GenerationConfig(track="sequence", num_steps=8),
)

# Step 4: drop the coordinates, re-fold the new sequence, and save the result.
protein.coordinates = None
protein = model.generate(
    protein,
    GenerationConfig(track="structure", num_steps=8),
)
protein.to_pdb("./round_tripped.pdb")

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00,  9.20it/s]
  state_dict = torch.load(
  state_dict = torch.load(
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 12.54it/s]
  state_dict = torch.load(
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):  # type: ignore
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 12.02it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 11.87it/s]


In [None]:

# Export one mean-pooled ESM3 embedding per sequence into a float32 memmap.
#
# The previous version of this cell followed the Hugging Face pattern
# (`tokenizer(...)`, `model.config.hidden_size`, `outputs.last_hidden_state`),
# but no `tokenizer` is ever created in this notebook and `model` is the ESM3
# object, so the cell could not run. It now goes through the ESM SDK instead:
# encode the protein, then request embeddings via `LogitsConfig`.
from esm.sdk.api import LogitsConfig

# TODO: replace this with your real dataset loader
# (e.g. read from a TSV, FASTA, etc.).
sequences = [
    # ("protein_id_1", "MKT..."),
    # ("protein_id_2", "AGT..."),
]

N = len(sequences)
if N == 0:
    raise RuntimeError("A lista 'sequences' está vazia. Carregue suas sequências aqui.")

emb_path = "./embed_protein.dat"
batch_size = 8  # number of sequences embedded between two writes to the memmap

logits_config = LogitsConfig(sequence=True, return_embeddings=True)


def _mean_embedding(seq: str) -> np.ndarray:
    """Return the mean-pooled embedding of one sequence as a 1-D float32 array."""
    protein_tensor = model.encode(ESMProtein(sequence=seq))
    output = model.logits(protein_tensor, logits_config)
    if output.embeddings is None:
        raise RuntimeError("model.logits returned no embeddings; check LogitsConfig.")
    # Collapse any leading batch axis so the tensor is (tokens, embed_dim).
    # NOTE(review): the mean runs over every returned token position, which
    # presumably includes the BOS/EOS tokens added by `encode` — confirm
    # whether those should be excluded for your downstream use.
    per_token = output.embeddings.reshape(-1, output.embeddings.shape[-1])
    return per_token.mean(dim=0).float().cpu().numpy()


# The embedding width is taken from the first model output rather than from a
# config attribute, so the memmap is allocated lazily on the first batch.
emb_memmap = None
embed_dim = None

with torch.no_grad():
    for start in tqdm(range(0, N, batch_size), desc="Gerando embeddings"):
        end = min(start + batch_size, N)
        emb_batch = np.stack(
            [_mean_embedding(seq) for _, seq in sequences[start:end]]
        ).astype("float32")

        if emb_memmap is None:
            embed_dim = emb_batch.shape[1]
            # Disk-backed array so the full (N, embed_dim) matrix never has
            # to be held in RAM.
            emb_memmap = np.memmap(
                emb_path,
                dtype="float32",
                mode="w+",
                shape=(N, embed_dim),
            )

        emb_memmap[start:end, :] = emb_batch

emb_memmap.flush()
print(f"Salvo em {emb_path} com shape {(N, embed_dim)}")