In [1]:
%load_ext autoreload
%autoreload 2

# DNA-templated transcription initiation (GO:0006352)

Pfam families in the dataset and their average lengths:

| pfam_id | Value      |
|---------|------------|
| PF00352 | 79.276549  |
| PF02291 | 119.764097 |
| PF03540 | 48.356998  |
| PF04539 | 75.140669  |
| PF04542 | 68.777518  |
| PF04545 | 50.665917  |
| PF04963 | 191.626108 |

# Load EMA weights

In [12]:
from pathlib import Path
from omegaconf import OmegaConf

import torch
from plaid.diffusion import FunctionOrganismDiffusion
from plaid.denoisers import FunctionOrganismUDiT, DenoiserKwargs
from plaid.constants import COMPRESSION_INPUT_DIMENSIONS, COMPRESSION_SHORTEN_FACTORS

# Prefer the GPU, but fall back to the CPU so the notebook can still be loaded
# on a machine without CUDA instead of failing at the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Identifier of the training run to load; it names the sub-directory of
# ckpt_dir that holds both the checkpoint and its saved configuration.
model_id = "5j007z42"

ckpt_dir = Path("/data/lux70/plaid/checkpoints/plaid-compositional")
model_path = ckpt_dir.joinpath(model_id, "last.ckpt")
config_path = ckpt_dir.joinpath(model_id, "config.yaml")

# Configuration stored next to the checkpoint.
cfg = OmegaConf.load(config_path)

In [3]:
# The compression model id recorded in the run config selects the per-model
# constants (latent channel width and shorten factor) from plaid.constants.
compression_model_id = cfg["compression_model_id"]
input_dim = COMPRESSION_INPUT_DIMENSIONS[compression_model_id]
shorten_factor = COMPRESSION_SHORTEN_FACTORS[compression_model_id]

In [4]:
# Build the denoiser from the saved config.
# The constructor kwargs are copied into a plain dict with "_target_" filtered
# out, rather than popped from cfg.denoiser: popping mutates cfg in place, so
# re-running the cell fails once the key is gone.
denoiser_kwargs = {key: value for key, value in cfg.denoiser.items() if key != "_target_"}
denoiser = FunctionOrganismUDiT(**denoiser_kwargs, input_dim=input_dim)

In [5]:
# last.ckpt automatically links to the EMA

# The checkpoint stores more than tensors (see the keys displayed below), so
# it is loaded with full unpickling. `weights_only=False` is spelled out so the
# behaviour does not depend on the torch version's default.
# NOTE: full unpickling can execute arbitrary code -- only load trusted files.
ckpt = torch.load(model_path, weights_only=False)
ckpt.keys()

  ckpt = torch.load(model_path)


dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])

In [6]:
# Keep only the entries stored under the "model._orig_mod." prefix and re-key
# them with that prefix removed; every other checkpoint entry is dropped.
prefix = "model._orig_mod."
mod_state_dict = {
    name[len(prefix):]: tensor
    for name, tensor in ckpt["state_dict"].items()
    if name.startswith(prefix)
}

In [7]:
# Load the prefix-stripped weights into the denoiser; the value displayed by
# this cell reports whether all keys matched.
denoiser.load_state_dict(mod_state_dict)

<All keys matched successfully>

In [8]:
# Build the diffusion wrapper around the loaded denoiser.
# The kwargs are copied into a plain dict with "_target_" filtered out, rather
# than popped from cfg.diffusion: popping mutates cfg in place, so re-running
# the cell fails once the key is gone.
diffusion_kwargs = {key: value for key, value in cfg.diffusion.items() if key != "_target_"}

# Optional overrides of the saved settings:
# diffusion_kwargs['beta_scheduler_name'] = "sigmoid"
# diffusion_kwargs['sampling_timesteps'] = 500

diffusion = FunctionOrganismDiffusion(model=denoiser, **diffusion_kwargs)

In [9]:
from cheap.pretrained import load_model_from_id
from cheap.proteins import LatentToSequence, LatentToStructure

# Compression model matching the id stored in the run config.
cheap_model = load_model_from_id(compression_model_id)
_ = cheap_model.to(device)

# Decoder from latents to sequences.
latent_to_sequence = LatentToSequence()
latent_to_sequence.to(device)

# Decoder from latents to structures. The last expression is left bare, so the
# cell displays the returned object.
latent_to_structure = LatentToStructure()
latent_to_structure.to(device)

  ckpt = torch.load(checkpoint_fpath)


Using tanh layer at bottleneck...
Finished loading HPCT model with shorten factor 2 and 32 channel dimensions.


  ckpt = torch.load(ckpt_path)
Creating ESMFold...
ESMFold model loaded in 39.36 seconds.


<cheap.proteins.LatentToStructure at 0x7f0cf7eeebb0>

# Sample

In [97]:
# Candidate latent lengths, presumably for the length sweep that is commented
# out in the sampling cell. The longer list below is an earlier variant.
# lengths = [40, 60, 24, 36, 96]
lengths = [40, 60, 96]

In [102]:
from plaid.datasets import NUM_ORGANISM_CLASSES, NUM_FUNCTION_CLASSES
import numpy as np

# Sweep the conditioning scale (cond_scale) at a fixed latent length and keep
# the full sampling trajectory for every scale.
N = 8   # number of samples drawn per scale
L = 96  # latent length; the sweep over `lengths` is currently disabled
assert L % 4 == 0

# `input_dim` is intentionally not re-assigned here. It was looked up from the
# compression model id when the denoiser was built, and the sampled latents
# must use that same channel width.
shape = (N, L, input_dim)

# NOTE(review): presumably NUM_ORGANISM_CLASSES is the "no organism" index and
# 162 is the function class for the GO term in the notebook title -- confirm
# against plaid.datasets.
organism_idx = NUM_ORGANISM_CLASSES
function_idx = 162

# Everything above does not change between scales, so it is set up once.
diffusion.sampling_timesteps = 1000

by_length = {}  # stays empty while the length sweep is disabled
by_scale = {}

# Plain Python ints (2..8) are used as keys so later lookups, printed keys and
# pickled files do not carry numpy scalar types.
for cond_scale in range(2, 9):
    sampled_latent = diffusion.ddim_sample_loop(
        shape=shape,
        organism_idx=organism_idx,
        function_idx=function_idx,
        return_all_timesteps=True,
        cond_scale=float(cond_scale),
    )

    # by_length[L] = sampled_latent
    by_scale[cond_scale] = sampled_latent

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

In [106]:
# Show which sweep settings currently hold samples (one dict per line).
print(by_length.keys(), by_scale.keys(), sep="\n")

dict_keys([])
dict_keys([2, 3, 4, 5, 6, 7, 8])


In [118]:
import pickle as pkl

# Persist the per-scale sampling trajectories. The object written must be
# `by_scale` (not `by_length`, which is empty in this run) so that the file
# content matches its name.
with open("/data/lux70/plaid/artifacts/by_scale_GO0006352.pkl", "wb") as f:
    pkl.dump(by_scale, f)

In [119]:
!ls /data/lux70/plaid/artifacts/

GO_0036873		 eval		noise_schedule_figures
by_length_GO0006352.pkl  interpolation	samples
by_scale_GO0006352.pkl	 mask_tokens	zero_out_massive_activations
by_scale_GO0046873.pkl	 natural


In [135]:
# sampled_latent = by_length[24]
# Select the trajectory sampled at one conditioning scale for decoding below.
# `cond_scale` is also read by the cell that writes the structure files, where
# it becomes part of each file name.
cond_scale = 3
sampled_latent = by_scale[cond_scale]

In [136]:
# Keep only the last entry along the second axis of the trajectory, then
# report the shapes before/after and the value range of the kept sample.
final_sample = sampled_latent[:, -1, :, :]
print(sampled_latent.shape, final_sample.shape, sep="\n")

print(final_sample.max(), final_sample.min())

torch.Size([8, 1001, 96, 32])
torch.Size([8, 96, 32])
tensor(1.0280, device='cuda:0') tensor(-1.0027, device='cuda:0')


In [137]:
# Decode the final sample with the compression model (no mask is supplied) and
# print the value range of the result.
sampled_uncompressed = cheap_model.decode(final_sample, downsampled_mask=None)
print(sampled_uncompressed.min(), sampled_uncompressed.max())

tensor(-0.8620, device='cuda:0') tensor(1.0151, device='cuda:0')


In [138]:
# import pickle as pkl
# with open("test_sample.pkl", "wb") as f:
#     pkl.dump(sampled_uncompressed,f)

In [139]:
from cheap.utils import LatentScaler
# Undo the latent scaling with a default-constructed LatentScaler; the result
# is what the sequence and structure decoders below receive.
latent_scaler = LatentScaler()
sampled_unscaled = latent_scaler.unscale(sampled_uncompressed) 

In [140]:
# Sanity check: shape and value range of the unscaled tensor.
print(sampled_unscaled.shape)
print(sampled_unscaled.max(), sampled_unscaled.min())

torch.Size([8, 192, 1024])
tensor(2517.4624, device='cuda:0') tensor(-802.8044, device='cuda:0')


In [141]:
# Only the last item returned by to_sequence is kept. Presumably these are the
# decoded amino-acid strings, one per sample (the next cell displays one) --
# TODO confirm against cheap.proteins.
sequences = latent_to_sequence.to_sequence(sampled_unscaled)[-1]

In [142]:
# Spot-check a single decoded sequence.
sequences[5]

'DTYNSDKKLLIDKLKKINIENKDILVVSNEFYMDVYESIVIILKKEGAKLIVRDCSFIDNQEDVEFEISDLKRFNPDLGILEIDEIDIVTLSGDEIFLNKFKRIILWVKNKEKTPISVVHLSDFTRAKELPKMERPEGYINNLKDYLESVFGKNTRIFDDLNTYDNKNVFSFNGLEAKELGELISRIVKNKL'

In [143]:
# Predict a structure for every sample from the unscaled tensor and the decoded
# sequences. Each entry of the result is later passed to py3Dmol and written to
# disk as text, so it is presumably a PDB-format string -- TODO confirm.
pdb_strs = latent_to_structure.to_structure(sampled_unscaled, sequences=sequences, batch_size=8, num_recycles=4)

(Generating structure): 100%|██████████████| 1/1 [00:20<00:00, 20.31s/it]


In [144]:
import py3Dmol

# Colouring is driven by atom property 'b' (per the original comment this is
# the pLDDT colour scheme). The cartoon uses a 0-90 range, the surface 0-100.
cartoon_style = {"cartoon": {"color": {"prop": "b", "gradient": "roygb", "min": 0, "max": 90}}}
surface_style = {"opacity": 0.7, "colorscheme": {"prop": "b", "gradient": "roygb", "min": 0, "max": 100}}

# One 400x400 viewer per predicted structure.
for pdb_str in pdb_strs:
    view = py3Dmol.view(width=400, height=400)
    view.addModelsAsFrames(pdb_str)

    view.setStyle(cartoon_style)
    # Semi-transparent van der Waals surface on top of the cartoon.
    view.addSurface(py3Dmol.VDW, surface_style)

    view.zoomTo()
    view.show()

In [145]:
# pdbs = pdb_strs[1]
# with open("pdb1_GO0006352.pdb", "w") as f:
#     f.write(pdbs)

In [146]:
from pathlib import Path

# Write every predicted structure to its own text file. The file name combines
# the conditioning scale selected earlier (`cond_scale`) with the sample index.
# Names are kept without an extension to match the files already in this
# directory.
samples_dir = Path("/data/lux70/plaid/artifacts/GO_0006352/by_scale")
# exist_ok=True replaces the separate exists() check and tolerates the
# directory being created between the check and the call.
samples_dir.mkdir(parents=True, exist_ok=True)

for i, pdb_str in enumerate(pdb_strs):
    with open(samples_dir / f"condscale{cond_scale}_sample{i}", "w") as f:
        f.write(pdb_str)

In [147]:
!ls /data/lux70/plaid/artifacts/GO_0006352/by_scale

condscale2_sample0  condscale2_sample6	condscale3_sample4  condscale8_sample2
condscale2_sample1  condscale2_sample7	condscale3_sample5  condscale8_sample3
condscale2_sample2  condscale3_sample0	condscale3_sample6  condscale8_sample4
condscale2_sample3  condscale3_sample1	condscale3_sample7  condscale8_sample5
condscale2_sample4  condscale3_sample2	condscale8_sample0  condscale8_sample6
condscale2_sample5  condscale3_sample3	condscale8_sample1  condscale8_sample7
