In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell runs, so edits to source files on disk are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Load EMA weights

In [2]:
from pathlib import Path
from omegaconf import OmegaConf

import torch
from plaid.diffusion import FunctionOrganismDiffusion
from plaid.denoisers import FunctionOrganismDiT, DenoiserKwargs
from plaid.constants import COMPRESSION_INPUT_DIMENSIONS, COMPRESSION_SHORTEN_FACTORS

In [67]:
# Run whose checkpoint and config are loaded below.
model_id = "qfvl29in"

# Each run keeps its artefacts under <ckpt_dir>/<model_id>/.
ckpt_dir = Path("/data/lux70/plaid/checkpoints/plaid_compositional_conditioning")
model_path = ckpt_dir / model_id / "last.ckpt"
config_path = ckpt_dir / model_id / "config.yaml"

# Load the run configuration from the path built above instead of a second,
# hard-coded copy of the same string, so changing `model_id` or `ckpt_dir`
# cannot leave the config and the checkpoint pointing at different runs.
cfg = OmegaConf.load(config_path)

In [68]:
# The run config names the compression model its latents were produced with;
# that id keys the two lookup tables imported from plaid.constants.
compression_model_id = cfg.compression_model_id
input_dim = COMPRESSION_INPUT_DIMENSIONS[compression_model_id]
shorten_factor = COMPRESSION_SHORTEN_FACTORS[compression_model_id]

In [69]:
# Preview the denoiser settings directly from the loaded config. Reading from
# `cfg` (rather than from `denoiser_kwargs`, which is only assigned in a later
# cell) lets this cell run on a fresh kernel. The "_target_" entry is filtered
# out because it is not a constructor argument.
{key: value for key, value in cfg.denoiser.items() if key != "_target_"}

{'hidden_size': 768, 'max_seq_len': 512, 'depth': 12, 'num_heads': 12, 'mlp_ratio': 4.0, 'use_self_conditioning': False}

In [70]:
# Build constructor kwargs from shallow copies of the config sections so that
# `cfg` itself is left untouched and this cell can be re-run safely. Popping
# with a default also keeps the cell from failing if "_target_" is absent.
denoiser_kwargs = dict(cfg.denoiser)
denoiser_kwargs.pop("_target_", None)

diffusion_kwargs = dict(cfg.diffusion)
# Last expression: displays the removed target path as a quick check of which
# diffusion class the run was configured with.
diffusion_kwargs.pop("_target_", None)

'plaid.diffusion.FunctionOrganismDiffusion'

In [71]:
# Instantiate the denoiser from the config kwargs plus the latent input width,
# then wrap it in the diffusion object that drives sampling.
denoiser = FunctionOrganismDiT(input_dim=input_dim, **denoiser_kwargs)
diffusion = FunctionOrganismDiffusion(model=denoiser, **diffusion_kwargs)

line 87

Number of visible CUDA GPUs: 1
Current GPU ID device number: 0
Current CUDA memory allocated: 22309.15771484375 MB
Current CUDA memory reserved: 24866.0 MB
Number of visible CPUs: 8

line 94

Number of visible CUDA GPUs: 1
Current GPU ID device number: 0
Current CUDA memory allocated: 22309.15771484375 MB
Current CUDA memory reserved: 24866.0 MB
Number of visible CPUs: 8

line 108

Number of visible CUDA GPUs: 1
Current GPU ID device number: 0
Current CUDA memory allocated: 22309.15771484375 MB
Current CUDA memory reserved: 24866.0 MB
Number of visible CPUs: 8



In [72]:
# NOTE: torch.load unpickles the file, so only open checkpoints from a trusted
# source.
# map_location="cpu" restores every stored tensor on the CPU instead of the
# device it was saved from, so the full checkpoint (which also holds optimizer
# state) does not take up GPU memory. The weights are copied into the model by
# load_state_dict, and the model is moved to the GPU in a later cell.
ckpt = torch.load(model_path, map_location="cpu")

In [73]:
# Show the top-level entries stored in the checkpoint; the cells below read
# 'state_dict' and 'ema_state_dict' from it.
ckpt.keys()

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters', 'ema_state_dict'])

In [74]:
# Restore the raw weights and the EMA weights into their respective
# sub-modules of the diffusion wrapper.
diffusion.model.load_state_dict(ckpt["state_dict"])
diffusion.ema_model.load_state_dict(ckpt["ema_state_dict"])

# Inference only: switch to eval mode and turn off gradients for every
# parameter. Assigning to `_` suppresses the module repr in the cell output.
diffusion.eval()
_ = diffusion.requires_grad_(False)

# Sample Latent

In [82]:
from cheap.pretrained import load_model_from_id
from cheap.proteins import LatentToSequence, LatentToStructure

# `device` is otherwise first assigned in a cell further down the notebook, so
# on a fresh top-to-bottom run this cell would stop with a NameError. Define it
# here, using the same value as the later cells.
device = torch.device("cuda")

# Compression model matching the id stored in the run config; used below to
# decode sampled latents.
cheap_model = load_model_from_id(compression_model_id)
_ = cheap_model.to(device)

# Decoders that turn the uncompressed latent into sequences / structures.
latent_to_sequence = LatentToSequence()
latent_to_sequence.to(device)

latent_to_structure = LatentToStructure()
latent_to_structure.to(device)

Using tanh layer at bottleneck...
Finished loading HPCT model with shorten factor 2 and 32 channel dimensions.


Creating ESMFold...
ESMFold model loaded in 33.93 seconds.


<cheap.proteins.LatentToStructure at 0x7f61dc098c70>

In [75]:
# torch is already imported in the import cell at the top of the notebook, so
# the repeated import is dropped here.
# Select the GPU and move the diffusion wrapper (and its sub-modules) onto it;
# assigning to `_` suppresses the module repr in the cell output.
device = torch.device("cuda")
_ = diffusion.to(device)

Conditioning target — organism: human; function (GO term): carbohydrate metabolic process.

In [76]:
# Conditioning labels are handed to the sampler as integer indices. The two
# values below are hard-coded; the commented-out lookup records how they map
# to "HUMAN" and "carbohydrate metabolic process". `org_df` and `go_df` are not
# defined in the visible cells — TODO confirm the indices still match the label
# tables the model was trained with.
# organism_idx = org_df[org_df.organism_id == "HUMAN"].organism_index.iloc[0]
# function_idx = go_df[go_df.GO_term == "carbohydrate metabolic process"].GO_idx.iloc[0]
# print(organism_idx, function_idx)

organism_idx = 1326
function_idx = 55

In [77]:
# NOTE(review): this repeats the device selection and model placement already
# done two cells above; it looks redundant and could probably be removed.
device = torch.device("cuda")
diffusion = diffusion.to(device)

In [99]:
# Batch of N latents, each of length L with `input_dim` channels.
N, L = 32, 128
shape = (N, L, input_dim)

# Guidance strength forwarded to the sampler as `cond_scale`.
cond_scale = 50

# Sampling is stochastic; seed torch's global RNG so that re-running the
# notebook is repeatable.
SEED = 0
torch.manual_seed(SEED)

diffusion.sampling_timesteps = 1000

# return_all_timesteps=True keeps every intermediate step; the next cell
# selects the final one.
sampled_latent = diffusion.ddim_sample_loop(
    shape,
    organism_idx,
    function_idx,
    return_all_timesteps=True,
    cond_scale=cond_scale,
)

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

In [100]:
# The sampler output stacks every step along dimension 1; keep only the last
# entry along that axis as the finished sample.
print(sampled_latent.shape)
final_sample = sampled_latent[:, -1, ...]
print(final_sample.shape)

# Quick range check on the final latent.
print(final_sample.max(), final_sample.min())

torch.Size([32, 1001, 128, 32])
torch.Size([32, 128, 32])
tensor(0.9276, device='cuda:0') tensor(-1.0014, device='cuda:0')


In [101]:
# Decode the sampled latent with the compression model; no mask is supplied.
sampled_uncompressed = cheap_model.decode(
    final_sample,
    downsampled_mask=None,
)
# Range check on the decoded values.
print(sampled_uncompressed.min(), sampled_uncompressed.max())

tensor(-0.8746, device='cuda:0') tensor(0.9289, device='cuda:0')


In [102]:
# import pickle as pkl
# with open("test_sample.pkl", "wb") as f:
#     pkl.dump(sampled_uncompressed,f)

In [103]:
from cheap.utils import LatentScaler
# Undo the latent scaling with a default-constructed LatentScaler; the
# unscaled values are what the sequence and structure decoders receive below.
latent_scaler = LatentScaler()
sampled_unscaled = latent_scaler.unscale(sampled_uncompressed) 

In [104]:
# Shape and value range of the unscaled embedding (the recorded run printed
# a shape of [32, 256, 1024]).
print(sampled_unscaled.shape)
print(sampled_unscaled.max(), sampled_unscaled.min())

torch.Size([32, 256, 1024])
tensor(2707.6621, device='cuda:0') tensor(-839.5864, device='cuda:0')


In [105]:
# to_sequence returns an indexable result; only its last element is kept. In
# the recorded run that element is a list of amino-acid strings, one per sample.
sequences = latent_to_sequence.to_sequence(sampled_unscaled)[-1]

In [106]:
# Preview the first three decoded sequences.
sequences[:3]

['NLRSRHSIASYMLPSLPFLPAQSAPHAPSSFGQKPNAQLPRYNTWDCRPKSDNSIYFCGLSGNIGYTIKKEETDYSAVVEIIMTECQFDGSMVGADDWGSLDAVLMSLSNMPAEKTRTLPGMNHDPEMYNVTNGARNHKIQPEKFCDFERRGYERNLAIRVANGGGEILVLYDGYSLDGVIRPVQIDGRDTSHRLPGFMPRFALVDGGLSSYDGVPAGVNFYCPDSQQDAFGSHHCSARQAQITILNGWPGTNCQA',
 'DDSVPHFQIAFFRCLFLSNAWNRCRSEDFKQQRDAEIYKGARFFFAENETALDRAMDPCAQQSHKAHLYHNPTLDQDSIDPPLSALQLEGRTRFSECATCSNSYDRVLLFDETSNYLRLPSQRHRYEHFRQRFSLINWSRSLKIFLPLRAMNIIPMSATGPAFAYVLDDLKPHLNPPAFNEPKLYIDCDASPVIGRGYDDRLANQGLQTRSVKLTATETGVEFEAHYIGASQSHVVGRDNKDRQNSCNNGQVFFRA',
 'LPDLTGNGRRELIYAVGGRATGLSTFAGSRDRHGGFVVKVKWNLSVVLRASSLISTGDSTAFVVKQCTDDTAPPTIGPEVLSPAMPRVPANIMFGYRAVTDSEVQKTIQYQCAPGRSFAANGPLPEGAGIMSISLANGTILVSDNDMAVVEVEGAGGHGAVVRVSGNLTRYNNFLPVKSPSVSSLLPKNVMSGERLTIRGMPTRFGAPALTELSSTSAFDRALSRDTIPQEVKKCDGNGNGDRTLYLLDLVDSKQA']

In [107]:
# Predict structures for the decoded sequences. Asking for the raw outputs as
# well gives access to the 'plddt' values summarised in the next cell.
pdb_strs, raw_outputs = latent_to_structure.to_structure(
    sampled_unscaled,
    return_raw_outputs=True,
    sequences=sequences,
    batch_size=32,
)

(Generating structure): 100%|████████████████████| 1/1 [02:11<00:00, 131.32s/it]


In [108]:
# Average the 'plddt' output over its last two dimensions, leaving one summary
# value per generated sample (32 values in the recorded run).
raw_outputs['plddt'].mean(dim=-1).mean(dim=-1)

tensor([33.1343, 36.9338, 34.9471, 35.6851, 40.8498, 37.4101, 36.2797, 36.0825,
        30.1263, 36.2148, 45.2444, 38.5934, 33.0624, 34.6797, 44.8370, 38.5368,
        37.2278, 37.0289, 35.0024, 52.0163, 44.9334, 39.2806, 41.2009, 35.7472,
        35.8245, 33.9206, 35.0714, 36.4963, 33.5843, 35.0759, 35.3503, 36.1060])

In [109]:
import py3Dmol

# Upper bound on how many structures to render. Slicing (rather than indexing
# with a fixed range) means a batch with fewer structures is shown in full
# instead of raising IndexError.
MAX_STRUCTURES_TO_SHOW = 10

# Shared colour specification for cartoon and surface: colour by the 'b'
# property on a 0-100 'roygb' gradient (the pLDDT colour scheme).
plddt_colors = {'prop': 'b', 'gradient': 'roygb', 'min': 0, 'max': 100}

for pdb_str in pdb_strs[:MAX_STRUCTURES_TO_SHOW]:
    view = py3Dmol.view(width=600, height=600)
    view.addModelsAsFrames(pdb_str)

    # Cartoon backbone coloured with the pLDDT scheme.
    view.setStyle({'cartoon': {'color': plddt_colors}})

    # Semi-transparent van der Waals surface using the same colours.
    view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'colorscheme': plddt_colors})

    view.zoomTo()
    view.show()

In [110]:
from plaid.evaluation import RITAPerplexity

# Score the decoded sequences with the RITAPerplexity evaluator on the same
# device as the models above. In the recorded run batch_eval produced a single
# number for the whole batch.
perplexity_calc = RITAPerplexity(device=device)
perplexities = perplexity_calc.batch_eval(sequences)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [111]:
# Display the perplexity result computed in the previous cell.
perplexities

18.85411834716797