In [1]:
import sys
sys.path.append("./source")
import torch
import py3Dmol

from model import CoFlowModel
from utils import to_protein
from esm.pretrained import ESM3_structure_decoder_v0
from esm.pretrained import ESM3_structure_encoder_v0
from esm.tokenization import StructureTokenizer, EsmSequenceTokenizer
from esm.utils.structure.protein_chain import ProteinChain

To run the model, both ESM3 weights and CoFlow weights must be downloaded. ESM3
weights can be downloaded from [huggingface](https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1). You can download CoFlow weights from [here](https://doi.org/10.5281/zenodo.14842367), and extract them to the `checkpoint` directory. In addition, ESM3 documentation can be found at the [official reposity](https://github.com/evolutionaryscale/esm).

In [2]:
device = "cuda:0"
checkpoint_dir = "./checkpoint"
model = CoFlowModel.from_pretrained(checkpoint_dir)
model = model.to(device)
# Initialize structure encoder and decoder for structure generation.
# Note that you have downloaded the ESM3 weights, either it wull occur errors here.
decoder = ESM3_structure_decoder_v0(device)
encoder = ESM3_structure_encoder_v0(device)
struc_tokenizer = StructureTokenizer()
seq_tokenizer = EsmSequenceTokenizer()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


## Unconditional Generation

To generate a protein of length 200 through api of `model.sample`, with specific parameters such as `strategy=3`, `length=200`, `steps=400`, and temperature values for sequence and structure set to `0.7`. The output includes the generated protein's structure and sequence, which are then converted into a protein object using the `to_protein` utility function. This object can be further visualized or analyzed.

In [3]:
length=200
sample_out = model.sample(
    strategy=3,
    length=length,
    steps=400,
    eta=length*0.08,
    sequence_temp=0.7,
    structure_temp=0.7,
    device=device,
)
structure, sequence = sample_out['structure'], sample_out['sequence']   # torch.LongTensor, (N,)
uncond_protein, _, _ = to_protein(
    structure=structure,
    sequence=sequence,
    decoder=decoder,
    struc_tokenizer=struc_tokenizer,
    seq_tokenizer=seq_tokenizer,
    strip=False,
)       # ESMProtein, including sequence, coordinates, plddt and ptm

Sample Parallel Peroidical:   0%|          | 0/400 [00:00<?, ?it/s]

Sample Parallel Peroidical: 100%|██████████| 400/400 [00:51<00:00,  7.83it/s]


Visualize the protein with [py3DMol](https://github.com/3dmol/3Dmol.js)

In [4]:
print(uncond_protein.sequence)
view = py3Dmol.view(width=500, height=500)
pdb_str = uncond_protein.to_pdb_string()
view.addModel(pdb_str, "pdb")
view.setStyle({"cartoon": {"color": "spectrum"}})
view.zoomTo()
view.show()

MATVVLTIPRAQLPVVEAVLAAAGVEATAEVDGDRVTITATVDEDKLAALLAALGAAGVDLRALAIALPSDDPDLVRLLLLIPADAAAALVAAIAAAGIDPADLRSIDVGTDTGHDALRAALEAAFPGVPIEVHPGWAAFAARLAAETGVTLAPPPGAPLTALWLTMSRASARALLAQLAKVPADLDVTVTLPDGEVIEL


## Conditional Generation


The conditional generation takes a protein (PDB: [7LUH](https://www.rcsb.org/structure/7LUH)) as example for the folding and inverse folding task. We frist tokenize the oracle protein sequence and structure. The protein sequence is tokenized using the `EsmSequenceTokenizer`, which converts the sequence into token IDs suitable for input into the model. Similarly, the structure information of the oracle protein is prepared for encoding by extracting coordinates, residue indices, and other structural features. These processed inputs are then encoded using the `ESM3_structure_encoder_v0` to obtain a representation of the oracle structure. This step is essential for conditioning the model on a specific protein sequence and structure during conditional generation tasks.


In [5]:
oracle_protein = ProteinChain.from_pdb("./7LUH.pdb")
oracle_sequence = torch.LongTensor(
    seq_tokenizer(oracle_protein.sequence, add_special_tokens=False).input_ids)
oracle_sequence = oracle_sequence.to(device)

coord, plddt, res_idx = oracle_protein.to_structure_encoder_inputs()    # prepare for the structure encoder
_, oracle_structure = encoder.encode(                   # turn 3D coordinates into discrete tokens
    coords=coord.to(device),
    residue_index=res_idx.to(device),
)

To fold a given sequence, the model's `sample` method is used with the **oracle sequence as input**. The method generates a structure corresponding to the sequence by sampling from the model's learned distribution. The generated structure and sequence are then converted into a protein object using the `to_protein` utility function.

In [None]:
sample_out = model.sample(
    strategy=3,
    steps=400,
    eta=len(sequence)*0.08,
    purity=False,
    sequence_temp=0.7,
    structure_temp=0.7,
    device=device,
    sequence=oracle_sequence[None, :],      # takes sequence as input for folding task
    sample=False,
)
structure, sequence = sample_out['structure'], sample_out['sequence']     # torch.LongTensor, (N,)
fold_generated_protein, _, _ = to_protein(
    structure=structure,
    sequence=sequence,
    decoder=decoder,
    struc_tokenizer=struc_tokenizer,
    seq_tokenizer=seq_tokenizer,
    strip=False,
)

Sample Parallel Peroidical: 100%|██████████| 400/400 [00:24<00:00, 16.35it/s]


Visualize the protein with py3DMol. Structure of rainbow color is generated. Grey structure is orcale. 

In [7]:
view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))
view.addModel(fold_generated_protein.to_pdb_string(), "pdb", viewer=(0, 0))
view.addModel(oracle_protein.to_pdb_string(), "pdb", viewer=(0, 1))
view.setStyle({"cartoon": {"color": "spectrum"}}, viewer=(0, 0))
view.setStyle({"cartoon": {"color": "lightgrey"}}, viewer=(0, 1))
view.zoomTo()
view.show()

To inverse fold a given backbone, the model's `sample` method is used with the **oracle structure as input**. The method generates a sequence corresponding to the structure by sampling from the model's learned distribution. The generated structure and sequence are then converted into a protein object using the `to_protein` utility function.

In [None]:
sample_out = model.sample(
    strategy=3,
    steps=400,
    eta=len(sequence)*0.08,
    purity=False,
    sequence_temp=0.7,
    structure_temp=0.7,
    device=device,
    structure=structure[None, :],       # takes structure as input for inverse folding task
    sample=False,
)
structure, sequence = sample_out['structure'], sample_out['sequence']     # torch.LongTensor, (N,)
inverse_generated_protein, _, _ = to_protein(
    structure=structure,
    sequence=sequence,
    decoder=decoder,
    struc_tokenizer=struc_tokenizer,
    seq_tokenizer=seq_tokenizer,
    strip=False,
)

Sample Parallel Peroidical: 100%|██████████| 400/400 [00:24<00:00, 16.20it/s]


In [9]:
print("generated:", inverse_generated_protein.sequence)
print("oracle:", oracle_protein.sequence)

generated: GASEPVAGKEYVELSSPQPVSAPAGKVEVVELFWYGCPHCYAFEPTIEKWVAKQSDKVYFKRLPAMFRESFVPHAQLFYTLIAMGVEHDVHNAVFEAIHKEHKRLLTPEEMADFLATKGVDKEKFLSTYNSFAIKGQVEKAKQLAMNYQVTGVPTMVVNGKYRFDIGSAGSPEGTTKLADYLVNKELAAMK
oracle: SPSAPVAGKDFEVMKSPQPVSAPAGKVEVIEFFWYGCPHCYEFEPTIEAWVKKQGDKIAFKRVPVAFRDDFVPHSKLFYALAALGVSEKVTPAVFNAIHKEKNYLLTPQAQADFLATQGVDKKKFLDAYNSFSVQGQVKQSAELLKNYNIDGVPTIVVQGKYKTGPAYTNSLEGTAQVLDFLVKQVQDKKL


To perform motif scaffolding, a motif sequence and structure are defined. The motif sequence contains masked regions represented by underscores `_`, which are replaced with `<mask>` tokens for processing. The corresponding structure is represented as a list of integers, where `4096` serves as the mask token index. The indices of non-masked residues in the sequence are identified and stored in `motif_mask`. The processed motif sequence is then tokenized using the `seq_tokenizer` to prepare it for input into the model.

In [None]:
motif_sequence = "________FSLFDKDGDGTITTKELGTV__________INEVDADGNGTIDFPEFLTM__________________________________________"
motif_structure = [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 3233, 1335, 214, 1438, 332, 3776, 820, 1669, 1339, 1083, 1918, 714, 1143, 1773, 2874, 3987, 320, 1031, 3714, 1760, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 2286, 3101, 1888, 215, 4074, 9, 3638, 1669, 754, 3717, 626, 736, 2357, 2287, 4059, 715, 4029, 3241, 3702, 2629, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096] # 4096 is the mask token index
motif_mask = [idx for idx, a in enumerate(motif_sequence) if a!="_" ]
motif_sequence = motif_sequence.replace("_", "<mask>")
motif_sequence = seq_tokenizer(motif_sequence, add_special_tokens=False).input_ids

The motif scaffolding process involves generating a protein structure and sequence based on a predefined motif. The `model.sample` method is used to sample from the model's learned distribution, conditioned on the motif sequence and structure. The generated output includes the complete protein's structure and sequence, which are then converted into a protein object using the `to_protein` utility function. 

In [None]:
sample_out = model.sample(
    strategy=3,
    steps=400,
    eta=len(motif_sequence)*0.8,
    purity=False,
    sequence_temp=0.7,
    structure_temp=0.7,
    device=device,
    sequence=torch.LongTensor([motif_sequence]),  # takes both sequence and strucutre as input for motif scaffolding 
    structure=torch.LongTensor([motif_structure]),
    sample=True,
)
structure, sequence = sample_out['structure'], sample_out['sequence']     # torch.LongTensor, (N,)

motif_generated_protein, ptm, plddt = to_protein(
    structure=structure,
    sequence=sequence,
    decoder=decoder,
    struc_tokenizer=struc_tokenizer,
    seq_tokenizer=seq_tokenizer,
)

Sample Parallel Peroidical:   0%|          | 0/400 [00:00<?, ?it/s]

Sample Parallel Peroidical: 100%|██████████| 400/400 [00:48<00:00,  8.20it/s]


Visualize the protein with py3DMol. Structure of grey color is generated, cyan color is the motif.

In [None]:
view = py3Dmol.view(width=500, height=500, )

view.addModel(motif_generated_protein.to_pdb_string(), "pdb")
view.setStyle({"cartoon": {"color": "grey"}})
view.addStyle({"resi": motif_mask}, {"cartoon": {"color": "cyan"}})
view.zoomTo()
view.show()

## Common Issues

1. FileNotFoundError: [Errno 2] No such file or directory: ".cache/huggingface/hub/models--EvolutionaryScale--esm3-sm-open-v1/snapshots/66ecd636588d3100e13598a5720678db6583d01c/data/weights/esm3_structure_encoder_v0.pth". 
   
    The error you're encountering indicates that ESM3 (or the Hugging Face Hub) is unable to find a specific model weight file (esm3_structure_encoder_v0.pth) in the expected cache directory. The most likely reasons is that the model files might not have downloaded completely due to network issues or interruptions. Please make sure the ESM3 weights has been downloaded to the target directory.

2. RuntimeError: CUDA out of memory. 

    This error occurs when the GPU does not have enough memory to handle the current task. To resolve this, try reducing the batch size, sequence length, or model complexity. Alternatively, you can use a machine with a GPU that has more memory or switch to CPU mode by setting `device="cpu"`.

3. KeyError: 'structure'

    This error indicates that the key 'structure' is missing in the output dictionary. Ensure that the `model.sample` method is called with the correct parameters and that the model has been properly initialized with the required weights.

4. AssertionError: Tokenizer mismatch.

    This error occurs when the sequence or structure tokenizer is not compatible with the model. Verify that the correct tokenizers (`StructureTokenizer` and `EsmSequenceTokenizer`) are being used and that they match the model's requirements.

5. ValueError: Invalid PDB file format.

    This error suggests that the input PDB file is not in the correct format or is corrupted. Check the PDB file for formatting issues and ensure it adheres to the standard PDB format.

6. ImportError: No module named 'esm'.

    This error occurs when the `esm` library is not installed. Install the library using `pip install esm` and ensure it is included in your Python environment.