In [93]:
import sys
sys.path.append("./source")
import torch
import py3Dmol

from model import CoFlowModel
from utils import to_protein
from esm.pretrained import ESM3_structure_decoder_v0
from esm.tokenization import StructureTokenizer, EsmSequenceTokenizer

To run the model, download the weights from [here](https://doi.org/10.5281/zenodo.14842367) and extract them into the `checkpoint` directory.

In [94]:
# Runtime configuration: everything below is placed on the first CUDA device.
device = "cuda:0"
checkpoint_dir = "./checkpoint"

# Load the pretrained CoFlow weights from the extracted checkpoint and move them to `device`.
model = CoFlowModel.from_pretrained(checkpoint_dir).to(device)

# ESM3 structure decoder, plus the tokenizers used to convert between tokens and residues.
decoder = ESM3_structure_decoder_v0(device)
struc_tokenizer, seq_tokenizer = StructureTokenizer(), EsmSequenceTokenizer()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


## Unconditional Generation

In [123]:
# Sample a protein of `length` residues with no sequence or structure conditioning.
length = 200
sample_kwargs = dict(
    strategy=3,
    length=length,
    steps=400,
    eta=length * 0.08,  # eta is scaled linearly with the requested length
    sequence_temp=0.7,
    structure_temp=0.7,
    device=device,
)
sample_out = model.sample(**sample_kwargs)
structure = sample_out['structure']
sequence = sample_out['sequence']

# Decode the sampled token tensors into a protein object for inspection / rendering.
# NOTE(review): strip=False presumably keeps every position — confirm in utils.to_protein.
uncond_protein, _, _ = to_protein(
    structure=structure,
    sequence=sequence,
    decoder=decoder,
    struc_tokenizer=struc_tokenizer,
    seq_tokenizer=seq_tokenizer,
    strip=False,
)

Sample Parallel Peroidical: 100%|██████████| 400/400 [00:52<00:00,  7.57it/s]


Visualize the generated protein with py3Dmol.

In [133]:
print(uncond_protein.sequence)

# Render the unconditional sample as a cartoon colored along the chain (spectrum).
pdb_str = uncond_protein.to_pdb_string()
cartoon_style = {"cartoon": {"color": "spectrum"}}
view = py3Dmol.view(width=500, height=500)
view.addModel(pdb_str, "pdb")
view.setStyle(cartoon_style)
view.zoomTo()
view.show()

MTRVLLVADGGAAGVAGALALALLARGYAVTLFVADTPAARRWAERVAAAGGRALLAPAGVTLGEPAALEQVLAAAAEVDAVVIVVLARDAALAAAGLPPEAAAAAAAALAAQVAALVAALRAARPDLRVSVLLLTGTDAPSPGADRIAAALPADVPVTRVVTGVRLSADSPTGVTAADVDALDAATVARLADVVEAQLE


## Conditional Generation

In [125]:
# Oracle protein used as ground truth for the folding and inverse-folding demos below.
oracle_sequence = "SPSAPVAGKDFEVMKSPQPVSAPAGKVEVIEFFWYGCPHCYEFEPTIEAWVKKQGDKIAFKRVPVAFRDDFVPHSKLFYALAALGVSEKVTPAVFNAIHKEKNYLLTPQAQADFLATQGVDKKKFLDAYNSFSVQGQVKQSAELLKNYNIDGVPTIVVQGKYKTGPAYTNSLEGTAQVLDFLVKQVQDKKL"

# Pre-computed structure tokens for the oracle (presumably one per residue — TODO confirm).
tokenized_structure = torch.LongTensor([
    144,3056,2769,1280,1709,2119,3979,2063,413,2876,391,269,943,3949,247,1688,2968,4003,1164,3949,3674,352,1005,3681,546,3891,2502,4012,444,768,984,3264,2283,883,2230,4093,1785,2483,3590,193,441,3023,2869,2144,1394,3833,462,2303,3259,3272,3918,588,588,3370,612,1464,1598,1739,3589,832,2820,3495,3857,3973,480,2354,2701,3816,3681,3979,1667,46,2011,1176,450,3958,191,584,4060,819,401,1450,2205,667,82,462,74,1059,2827,2697,3372,1619,749,50,508,1894,3313,2337,2137,3407,1126,1474,76,1785,187,2247,248,2471,445,588,3049,201,2874,2424,3755,2737,588,199,2504,726,4057,833,116,1200,1783,21,588,3815,3793,1509,2408,3108,1803,2371,3145,123,631,1521,137,1352,1107,668,1450,840,3194,987,3042,2249,2186,1702,2366,2990,1571,1952,788,3134,1726,593,2711,3437,3032,3577,1329,1867,3731,431,1598,2080,2075,3717,1729,3668,953,965,717,375,1248,1616,101,3209,2301,101,2481,3961,1800,833,1783,1035,3581,4006,1025
]).to(device)

# Tokenize the oracle sequence without adding special tokens.
sequence_ids = seq_tokenizer(oracle_sequence, add_special_tokens=False).input_ids
tokenized_sequence = torch.LongTensor(sequence_ids).to(device)

# Decode the oracle tokens into a protein object so it can be rendered next to the predictions.
oracle, _, _ = to_protein(
    structure=tokenized_structure,
    sequence=tokenized_sequence,
    decoder=decoder,
    struc_tokenizer=struc_tokenizer,
    seq_tokenizer=seq_tokenizer,
    pad=True,
)

Fold the given (oracle) sequence, i.e. generate a structure conditioned on the sequence.

In [126]:
# Fold the oracle: condition on its sequence tokens and let the model generate the structure.
sample_out = model.sample(
    strategy=3,
    steps=400,
    # Scale eta with the length of the protein being folded. Deriving it from the oracle keeps
    # this cell independent of any `sequence` variable left over from an earlier cell.
    eta=len(oracle_sequence)*0.08,
    purity=False,
    sequence_temp=0.7,
    structure_temp=0.7,
    device=device,
    sequence=tokenized_sequence[None, :],  # add a leading (presumably batch) dimension
    sample=False,
)
structure, sequence = sample_out['structure'], sample_out['sequence']

# Decode the predicted structure/sequence tokens into a protein object for rendering.
fold_generated_protein, _, _ = to_protein(
    structure=structure,
    sequence=sequence,
    decoder=decoder,
    struc_tokenizer=struc_tokenizer,
    seq_tokenizer=seq_tokenizer,
    strip=False,
)

Sample Parallel Peroidical: 100%|██████████| 400/400 [00:26<00:00, 15.16it/s]


Visualize both proteins with py3Dmol: the rainbow-colored structure (left) is the generated fold, and the grey structure (right) is the oracle.

In [127]:
# Side-by-side grid: generated fold in the left panel, oracle in the right panel.
panels = [
    (fold_generated_protein, (0, 0), "spectrum"),
    (oracle, (0, 1), "lightgrey"),
]
view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))
# Load every model first, then style each panel (same call order as adding then styling).
for protein, grid_pos, _colour in panels:
    view.addModel(protein.to_pdb_string(), "pdb", viewer=grid_pos)
for _protein, grid_pos, colour in panels:
    view.setStyle({"cartoon": {"color": colour}}, viewer=grid_pos)
view.zoomTo()
view.show()

Inverse-fold the given (oracle) structure, i.e. generate a sequence conditioned on the structure.

In [128]:
# Inverse-fold the oracle: condition on its structure tokens and let the model design a sequence.
# Conditioning on `tokenized_structure` (not the `structure` variable, which holds the previous
# cell's fold output and is overwritten below) makes the cell re-runnable and matches the
# comparison against `oracle_sequence` in the next cell.
sample_out = model.sample(
    strategy=3,
    steps=500,
    # Scale eta with the oracle's length rather than `length` from the unconditional cell.
    eta=len(oracle_sequence)*0.08,
    purity=False,
    sequence_temp=0.7,
    structure_temp=0.7,
    device=device,
    structure=tokenized_structure[None, :],  # add a leading (presumably batch) dimension
    sample=False,
)
structure, sequence = sample_out['structure'], sample_out['sequence']

# Decode the designed sequence (and structure tokens) into a protein object.
inverse_generated_protein, _, _ = to_protein(
    structure=structure,
    sequence=sequence,
    decoder=decoder,
    struc_tokenizer=struc_tokenizer,
    seq_tokenizer=seq_tokenizer,
    strip=False,
)

Sample Parallel Peroidical: 100%|██████████| 500/500 [00:40<00:00, 12.20it/s]


In [129]:
# Print the inverse-folded sequence above the oracle sequence for a visual comparison.
for label, seq in (
    ("generated:", inverse_generated_protein.sequence),
    ("oracle:", oracle_sequence),
):
    print(label, seq)

generated: GASEPVAGKEYVELSSPQPVSTPAGKIEVVELFWYGCPHCYAFEPTITKWAEKQPDNVHFVRVPAMFRESFVPHGQLFYALISMGVEHDVHNAVFDAIHKEHKRLATPEEMADFLATKGVDKEKFLATYNSFAIKGQVEKAKKLAMNYQVTGVPTMVVNGKYRFDIGMTGSPEGTTKLADYLVNKEAAAAK
oracle: SPSAPVAGKDFEVMKSPQPVSAPAGKVEVIEFFWYGCPHCYEFEPTIEAWVKKQGDKIAFKRVPVAFRDDFVPHSKLFYALAALGVSEKVTPAVFNAIHKEKNYLLTPQAQADFLATQGVDKKKFLDAYNSFSVQGQVKQSAELLKNYNIDGVPTIVVQGKYKTGPAYTNSLEGTAQVLDFLVKQVQDKKL


## Motif Scaffolding

In [130]:
# Motif-scaffolding template: "_" marks free positions, letters are the fixed motif residues.
motif_template = "________FSLFDKDGDGTITTKELGTV__________INEVDADGNGTIDFPEFLTM__________________________________________"
motif_structure = [4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 3233, 1335, 214, 1438, 332, 3776, 820, 1669, 1339, 1083, 1918, 714, 1143, 1773, 2874, 3987, 320, 1031, 3714, 1760, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 2286, 3101, 1888, 215, 4074, 9, 3638, 1669, 754, 3717, 626, 736, 2357, 2287, 4059, 715, 4029, 3241, 3702, 2629, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096, 4096] # 4096 is the mask token index

# 0-based positions of the fixed (non-"_") motif residues in the template.
motif_mask = [pos for pos, residue in enumerate(motif_template) if residue != "_"]

# Turn every free position into an explicit "<mask>" token, then tokenize without special tokens.
masked_template = motif_template.replace("_", "<mask>")
motif_sequence = seq_tokenizer(masked_template, add_special_tokens=False).input_ids

In [136]:
# Scaffold the motif: the fixed residues / structure tokens are given, and the masked positions
# ("<mask>" in the sequence, 4096 in the structure) are left for the model to fill in.
sample_out = model.sample(
    strategy=3,
    steps=400,
    # Derive eta from the scaffold's own length instead of a `sequence` variable left over
    # from an earlier (unrelated) cell.
    # NOTE(review): every other cell uses a 0.08 factor; confirm that 0.8 is intentional here.
    eta=len(motif_structure)*0.8,
    purity=False,
    sequence_temp=0.7,
    structure_temp=0.7,
    device=device,
    sequence=torch.LongTensor([motif_sequence]),
    structure=torch.LongTensor([motif_structure]),
    sample=True,
)
structure, sequence = sample_out['structure'], sample_out['sequence']

# Decode the scaffolded design; to_protein also returns pTM and pLDDT values.
motif_generated_protein, ptm, plddt = to_protein(
    structure=structure,
    sequence=sequence,
    decoder=decoder,
    struc_tokenizer=struc_tokenizer,
    seq_tokenizer=seq_tokenizer,
)

Sample Parallel Peroidical: 100%|██████████| 400/400 [00:52<00:00,  7.68it/s]


In [137]:
# Render the scaffolded design in grey and highlight the fixed motif residues in cyan.
# `motif_mask` holds 0-based template positions, while py3Dmol's `resi` selects PDB residue
# numbers, which are conventionally 1-based; shift by one so the highlight lands on the motif.
# NOTE(review): assumes to_protein's PDB output numbers residues from 1 — confirm.
motif_resi = [idx + 1 for idx in motif_mask]

view = py3Dmol.view(width=500, height=500)
view.addModel(motif_generated_protein.to_pdb_string(), "pdb")
view.setStyle({"cartoon": {"color": "grey"}})
view.addStyle({"resi": motif_resi}, {"cartoon": {"color": "cyan"}})
view.zoomTo()
view.show()