In [11]:
import sys
sys.path.append("./source")
import torch
import py3Dmol

from model import CoFlowModel
from utils import to_protein
from esm.pretrained import ESM3_structure_decoder_v0
from esm.pretrained import ESM3_structure_encoder_v0
from esm.tokenization import StructureTokenizer, EsmSequenceTokenizer
from esm.utils.structure.protein_chain import ProteinChain


To run the model, weights must be downloaded from [here](https://doi.org/10.5281/zenodo.14842367), and extract them to the `checkpoint` directory.

In [2]:
device = "cuda:0"
checkpoint_dir = "./checkpoint"
model = CoFlowModel.from_pretrained(checkpoint_dir)
model = model.to(device)
decoder = ESM3_structure_decoder_v0(device)
encoder = ESM3_structure_encoder_v0(device)
struc_tokenizer = StructureTokenizer()
seq_tokenizer = EsmSequenceTokenizer()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


## Unconditional Generation

In [3]:
length=200
sample_out = model.sample(
    strategy=3,
    length=length,
    steps=400,
    eta=length*0.08,
    sequence_temp=0.7,
    structure_temp=0.7,
    device=device,
)
structure, sequence = sample_out['structure'], sample_out['sequence']
uncond_protein, _, _ = to_protein(
    structure=structure,
    sequence=sequence,
    decoder=decoder,
    struc_tokenizer=struc_tokenizer,
    seq_tokenizer=seq_tokenizer,
    strip=False,
)

Sample Parallel Peroidical: 100%|██████████| 400/400 [00:53<00:00,  7.51it/s]


Visualize the protein with py3DMol

In [4]:
print(uncond_protein.sequence)
view = py3Dmol.view(width=500, height=500)
pdb_str = uncond_protein.to_pdb_string()
view.addModel(pdb_str, "pdb")
view.setStyle({"cartoon": {"color": "spectrum"}})
view.zoomTo()
view.show()

AAPTLTLTVTVTAPPTVTLTAGGTATATFTVTVTLTSSTHPTVTATGTVTLTGVYDPTATPATVAAGSAGEVVLASATGTATVTVDGVTLTVTVTVTVTLNTSTNTVTTTVTVTVTLTPGDPNDPNALPVTVTVTVTLTDDAGTTTGTATVSGTPATATVTEPWTPGTRTVTVTVTVTLTVGSSSATGTGTATLNINLVN


## Conditional Generation

Tokenize oracle protein sequence and structure.

In [16]:
oracle_protein = ProteinChain.from_pdb("./7LUH.pdb")
oracle_sequence = torch.LongTensor(
    seq_tokenizer(oracle_protein.sequence, add_special_tokens=False).input_ids)
oracle_sequence = oracle_sequence.to(device)

coord, plddt, res_idx = oracle_protein.to_structure_encoder_inputs()
_, oracle_structure = encoder.encode(
    coords=coord.to(device),
    residue_index=res_idx.to(device),
)

To fold given sequence.

In [17]:
sample_out = model.sample(
    strategy=3,
    steps=400,
    eta=len(sequence)*0.08,
    purity=False,
    sequence_temp=0.7,
    structure_temp=0.7,
    device=device,
    sequence=oracle_sequence[None, :],
    sample=False,
)
structure, sequence = sample_out['structure'], sample_out['sequence']
fold_generated_protein, _, _ = to_protein(
    structure=structure,
    sequence=sequence,
    decoder=decoder,
    struc_tokenizer=struc_tokenizer,
    seq_tokenizer=seq_tokenizer,
    strip=False,
)

Sample Parallel Peroidical: 100%|██████████| 400/400 [00:26<00:00, 14.93it/s]


Visualize the protein with py3DMol. Structure of rainbow color is generated. Grey structure is orcale.

In [18]:
view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))
view.addModel(fold_generated_protein.to_pdb_string(), "pdb", viewer=(0, 0))
view.addModel(oracle_protein.to_pdb_string(), "pdb", viewer=(0, 1))
view.setStyle({"cartoon": {"color": "spectrum"}}, viewer=(0, 0))
view.setStyle({"cartoon": {"color": "lightgrey"}}, viewer=(0, 1))
view.zoomTo()
view.show()

To inverse fold given backbone.

In [19]:
sample_out = model.sample(
    strategy=3,
    steps=500,
    eta=length*0.08,
    purity=False,
    sequence_temp=0.7,
    structure_temp=0.7,
    device=device,
    structure=structure[None, :],
    sample=False,
)
structure, sequence = sample_out['structure'], sample_out['sequence']
inverse_generated_protein, _, _ = to_protein(
    structure=structure,
    sequence=sequence,
    decoder=decoder,
    struc_tokenizer=struc_tokenizer,
    seq_tokenizer=seq_tokenizer,
    strip=False,
)

Sample Parallel Peroidical: 100%|██████████| 500/500 [00:38<00:00, 12.93it/s]


In [21]:
print("generated:", inverse_generated_protein.sequence)
print("oracle:", oracle_protein.sequence)

generated: GASAPVAGKEYVELSSPQPVSAPAGKIEVVELFWYGCPHCYAFEPTIEKWAAKQGADVQFKRVPAIFRESFVPHAQLFYTLISMGVEHDVHNAVFEAIHKEHKRLATPEEMADFLAGKGVDKEKFLSMYNSFAIKGQVEKAKQLAMAYQVTGVPTMVVNGKYKFGIGMAGSPEGTTKLADYLVEKEKAAKK
oracle: SPSAPVAGKDFEVMKSPQPVSAPAGKVEVIEFFWYGCPHCYEFEPTIEAWVKKQGDKIAFKRVPVAFRDDFVPHSKLFYALAALGVSEKVTPAVFNAIHKEKNYLLTPQAQADFLATQGVDKKKFLDAYNSFSVQGQVKQSAELLKNYNIDGVPTIVVQGKYKTGPAYTNSLEGTAQVLDFLVKQVQDKKL


## Motif Scaffolding

In [22]:
motif_sequence = "________FSLFDKDGDGTITTKELGTV__________INEVDADGNGTIDFPEFLTM__________________________________________"
motif_structure = [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 3233, 1335, 214, 1438, 332, 3776, 820, 1669, 1339, 1083, 1918, 714, 1143, 1773, 2874, 3987, 320, 1031, 3714, 1760, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 2286, 3101, 1888, 215, 4074, 9, 3638, 1669, 754, 3717, 626, 736, 2357, 2287, 4059, 715, 4029, 3241, 3702, 2629, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096] # 4096 is the mask token index
motif_mask = [idx for idx, a in enumerate(motif_sequence) if a!="_" ]
motif_sequence = motif_sequence.replace("_", "<mask>")
motif_sequence = seq_tokenizer(motif_sequence, add_special_tokens=False).input_ids

In [23]:
sample_out = model.sample(
    strategy=3,
    steps=400,
    eta=len(sequence)*0.8,
    purity=False,
    sequence_temp=0.7,
    structure_temp=0.7,
    device=device,
    sequence=torch.LongTensor([motif_sequence]),
    structure=torch.LongTensor([motif_structure]),
    sample=True,
)
structure, sequence = sample_out['structure'], sample_out['sequence']

motif_generated_protein, ptm, plddt = to_protein(
    structure=structure,
    sequence=sequence,
    decoder=decoder,
    struc_tokenizer=struc_tokenizer,
    seq_tokenizer=seq_tokenizer,
)

Sample Parallel Peroidical: 100%|██████████| 400/400 [00:54<00:00,  7.40it/s]


In [24]:
view = py3Dmol.view(width=500, height=500, )

view.addModel(motif_generated_protein.to_pdb_string(), "pdb")
view.setStyle({"cartoon": {"color": "grey"}})
view.addStyle({"resi": motif_mask}, {"cartoon": {"color": "cyan"}})
view.zoomTo()
view.show()