In [1]:
# Automatically re-import edited modules (e.g. the plaid package) before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

import torch

# print(os.environ['CHEAP_CACHE'], os.environ['PLAID_CACHE'])

from plaid.pipeline import SampleLatent, DecodeLatent
from plaid.constants import GO_TERM_TO_FUNCTION_IDX, ORGANISM_5_LETTER_CODE_TO_IDX

  from xformers.components.attention import ScaledDotProduct


# Stage 1: Sample Latent

In [3]:
# Conditioning: organism code 'ECOLI' + GO term "6-phosphofructokinase activity".
organism_idx = ORGANISM_5_LETTER_CODE_TO_IDX['ECOLI']
function_idx = GO_TERM_TO_FUNCTION_IDX['6-phosphofructokinase activity']

# Root directory for sampled latents. Set the PLAID_SAMPLES_DIR environment variable
# instead of editing this user-specific default.
output_root_dir = os.environ.get('PLAID_SAMPLES_DIR', '/shared/amyxlu/plaid/artifacts/samples/')

# Seed torch's global RNG so re-running this cell reproduces the same samples.
# NOTE(review): assumes SampleLatent draws its noise from torch's global RNG — confirm;
# some GPU kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)

sample_config = {
    'function_idx': function_idx,
    'organism_idx': organism_idx,
    'batch_size': -1,  # batch size of -1 means we sample all at once
    'cond_scale': 3.0, # 0.0 for unconditional
    'length': None,  # Auto-choose length based on known Pfam domains. Only works if you provide a function to condition on!
    'model_id': 'PLAID-100M',  # PLAID-2B or PLAID-100M
    'num_samples': 16,
    'output_root_dir': output_root_dir,
    'return_all_timesteps': True,  # saves latents for intermediate timesteps for visualization and debugging
    'sample_scheduler': 'ddim',  # sampling scheduler
    'sampling_timesteps': 500,  # number of diffusion timesteps
    'use_compile': False,  # use JIT compilation. This only makes sense if you are running multiple batches.
    'use_condition_output_suffix': True,  # appends the conditioning code to the output folder
    'use_uid_output_suffix': False  # appends a unique ID to the output folder
}

sample_latent = SampleLatent(**sample_config)
sample_latent = sample_latent.run()
# Path of the saved latent .npz; Stage 2 decodes from this by default.
npz_path = sample_latent.outpath

  return torch.load(model_path)


Auto-choosing length 280 (implicit length in GPU memory: 140).


Sampling batches:   0%|          | 0/1 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/500 [00:00<?, ?it/s]

Sampling batches: 100%|██████████| 1/1 [01:03<00:00, 63.16s/it]


Sampling took 63.16 seconds.
Saved .npz file to /shared/amyxlu/plaid/artifacts/samples/f166_o1030_l140_s3/latent.npz [shape=(16, 501, 140, 32)].


# Stage 2: Decode Latent

In [4]:
# ESMFold is handed to DecodeLatent below to predict structures from decoded sequences.
# Load it once here (slow) and place it on the GPU in a single step.
from plaid.esmfold import esmfold_v1

esmfold = esmfold_v1().to('cuda')

Creating ESMFold...
ESMFold model loaded in 31.01 seconds.


In [5]:
# Decode the latents written by Stage 1. By default this reuses `npz_path` from the
# sampling cell, so changing `sample_config` can't silently decode a stale run.
# To decode an earlier run without re-sampling, point NPZ_PATH_OVERRIDE at its latent.npz.
NPZ_PATH_OVERRIDE = None  # e.g. '/shared/amyxlu/plaid/artifacts/samples/f166_o1030_l140_s3/latent.npz'

if NPZ_PATH_OVERRIDE is not None:
    npz_path = NPZ_PATH_OVERRIDE
elif 'npz_path' not in globals():
    raise NameError("npz_path is undefined: run the Stage 1 cell first or set NPZ_PATH_OVERRIDE.")
assert str(npz_path).endswith('.npz'), f"Expected a .npz latent file, got {npz_path!r}"

decode_config = {
    'npz_path': str(npz_path),
    'output_root_dir': Path(npz_path).parent,  # write sequences/structures next to the latents
    'batch_size': 4,  # for structure decoding only
    'device': 'cuda',
    'num_recycles': 4,
    'use_compile': False,
    'chunk_size': 128,
}

decode_latent = DecodeLatent(**decode_config, esmfold=esmfold)
seq_strs, pdb_paths = decode_latent.run()

Output root dir: /shared/amyxlu/plaid/artifacts/samples/f166_o1030_l140_s3
Using checkpoint at /home/amyxlu/.cache/cheap/checkpoints/j1v1wv6w.
Using tanh layer at bottleneck...
Finished loading HPCT model with shorten factor 2 and 32 channel dimensions.
Loading latent samples from /shared/amyxlu/plaid/artifacts/samples/f166_o1030_l140_s3/latent.npz
Decompressing latent samples
Constructing sequences and writing to /shared/amyxlu/plaid/artifacts/samples/f166_o1030_l140_s3
Wrote 16 sequences to /shared/amyxlu/plaid/artifacts/samples/f166_o1030_l140_s3/sequences.fasta.
Constructing structures and writing to /shared/amyxlu/plaid/artifacts/samples/f166_o1030_l140_s3/structures


(Generating structure): 100%|██████████| 4/4 [03:34<00:00, 53.68s/it]


In [11]:
import py3Dmol

def view_structure_with_confidence(pdbstr, width=800, height=600, bfactor_min=0, bfactor_max=100):
    """Build a py3Dmol viewer showing a structure as a cartoon colored by B-factor.

    The B-factor column is treated as per-residue confidence and mapped onto the
    "roygb" (red -> orange -> yellow -> green -> blue) gradient, so values at
    ``bfactor_min`` are red and values at ``bfactor_max`` are blue.

    Args:
        pdbstr: Contents of a PDB file as a string.
        width: Viewer width passed to ``py3Dmol.view``.
        height: Viewer height passed to ``py3Dmol.view``.
        bfactor_min: B-factor mapped to the low (red) end of the gradient.
        bfactor_max: B-factor mapped to the high (blue) end of the gradient. The
            default of 100 assumes confidence is stored on a 0-100 scale; pass 1
            if the PDB stores it as a fraction. NOTE(review): confirm the scale
            written by the structure decoder.

    Returns:
        The py3Dmol view; it renders when it is the last expression in a cell.
    """
    view = py3Dmol.view(width=width, height=height)
    view.addModel(pdbstr, "pdb")

    # Style the model just added (-1 = last model): cartoon colored by B-factor.
    confidence_style = {
        'cartoon': {
            'colorscheme': {
                'prop': 'b',
                'gradient': 'roygb',
                'min': bfactor_min,
                'max': bfactor_max,
            }
        },
    }
    view.setStyle({'model': -1}, confidence_style)

    view.zoomTo()
    return view

# Read the first decoded structure and render it colored by confidence.
pdbstr = Path(pdb_paths[0]).read_text()

view_structure_with_confidence(pdbstr)

<py3Dmol.view at 0x715954dcd090>

In [12]:
first_seq = seq_strs[0]
print(first_seq)  # print (not a bare expression) so the sequence isn't shown as a quoted repr

SIAVLTSGGDSPGMNAAIRAAVRRAAQHKIRIRGMKEGYSGLIQGEFQEIDPRDVNRILIKGGTILGSARCTTMRDREGKKKLAENLKKNGINALVVVGGDGSMRGAMAFAHEWDIPVVGVPQTIDSDIPETDITIGYDTAVSIAIEAIDRIRDTSSSFNRVFVVEIMGRDVGHIALQAGISGGADVVLIPEHDHSFEKIAMQLKPAHNRGKTAGIIVAAEGFFGNIRASELAQIIKEEGRSGSKPRVIILGHVLRGGTPTLQDRILATRMGVEAVEALK
