In [1]:
# Re-import edited project modules (plaid, cheap) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Load EMA weights

In [2]:
from pathlib import Path
from omegaconf import OmegaConf

import torch
from plaid.diffusion import FunctionOrganismDiffusion
from plaid.denoisers import FunctionOrganismUDiT, DenoiserKwargs
from plaid.constants import COMPRESSION_INPUT_DIMENSIONS, COMPRESSION_SHORTEN_FACTORS

# Every model and tensor in this notebook is placed on the (current) CUDA device.
DEVICE_TYPE = "cuda"
device = torch.device(DEVICE_TYPE)

  @torch.library.impl_abstract("xformers_flash::flash_fwd")
  @torch.library.impl_abstract("xformers_flash::flash_bwd")


In [3]:
# Training run to evaluate; each run has its own sub-directory under ckpt_dir.
model_id = "5j007z42"

ckpt_dir = Path("/data/lux70/plaid/checkpoints/plaid-compositional")
run_dir = ckpt_dir / model_id
model_path = run_dir / "last.ckpt"
config_path = run_dir / "config.yaml"

# Training-time config (denoiser / diffusion / compression settings).
cfg = OmegaConf.load(config_path)

In [4]:
# Look up the shorten factor and latent channel dimension of the compression
# model named in the training config.
compression_model_id = cfg.compression_model_id
shorten_factor, input_dim = (
    COMPRESSION_SHORTEN_FACTORS[compression_model_id],
    COMPRESSION_INPUT_DIMENSIONS[compression_model_id],
)

In [5]:
# Build the denoiser from the training config. Filter the Hydra `_target_` key into a
# new dict rather than popping it from `cfg.denoiser`: the in-place pop mutated the
# loaded config, so re-running this cell raised a missing-key error.
denoiser_kwargs = {k: v for k, v in cfg.denoiser.items() if k != "_target_"}
denoiser = FunctionOrganismUDiT(**denoiser_kwargs, input_dim=input_dim)

In [6]:
# last.ckpt automatically links to the EMA weights.
# map_location="cpu": by default torch.load restores every tensor (including the
# optimizer states stored in this Lightning checkpoint) onto the device it was saved
# from. If that was a GPU, those tensors would stay resident for the rest of the
# session. Only the model weights are needed, and they are copied into the CPU-built
# denoiser below anyway.
# NOTE(review): this file is trusted; Lightning checkpoints hold non-tensor objects,
# so switching to weights_only=True may fail — confirm before changing.
ckpt = torch.load(model_path, map_location="cpu")
ckpt.keys()

  ckpt = torch.load(model_path)


dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])

In [7]:
# Keep only the denoiser weights and strip their wrapper prefix. `_orig_mod` is the
# attribute torch.compile's wrapper stores the original module under, so the saved
# keys look like "model._orig_mod.<param>".
COMPILED_PREFIX = "model._orig_mod."
mod_state_dict = {
    k[len(COMPILED_PREFIX):]: v
    for k, v in ckpt["state_dict"].items()
    if k.startswith(COMPILED_PREFIX)
}
# Fail loudly here rather than with a long missing-keys error in load_state_dict.
assert mod_state_dict, f"no keys with prefix {COMPILED_PREFIX!r} found in checkpoint"

In [8]:
# Strict loading (torch's default): raises if any key is missing or unexpected.
load_report = denoiser.load_state_dict(mod_state_dict, strict=True)
load_report

<All keys matched successfully>

In [9]:
# Build the diffusion wrapper around the denoiser. Filter out the Hydra `_target_` key
# into a new dict instead of popping it from `cfg.diffusion` in place; the in-place pop
# mutated the loaded config and made this cell fail with a missing-key error on re-run.
diffusion_kwargs = {k: v for k, v in cfg.diffusion.items() if k != "_target_"}

# Optional overrides of the training-time diffusion settings:
# diffusion_kwargs['beta_scheduler_name'] = "sigmoid"
# diffusion_kwargs['sampling_timesteps'] = 500

diffusion = FunctionOrganismDiffusion(model=denoiser, **diffusion_kwargs)

In [10]:
from cheap.pretrained import load_model_from_id
from cheap.proteins import LatentToSequence, LatentToStructure

# Compression model matching the denoiser's latent space; used below to decode samples.
cheap_model = load_model_from_id(compression_model_id)
_ = cheap_model.to(device)

# Decoders from unscaled embeddings to amino-acid sequences and to structures.
# ESMFold is created while this cell runs (see the cell output).
latent_to_sequence = LatentToSequence()
latent_to_sequence.to(device)

latent_to_structure = LatentToStructure()
latent_to_structure.to(device)

  ckpt = torch.load(checkpoint_fpath)


Using tanh layer at bottleneck...
Finished loading HPCT model with shorten factor 2 and 32 channel dimensions.


  ckpt = torch.load(ckpt_path)
Creating ESMFold...
ESMFold model loaded in 33.80 seconds.


<cheap.proteins.LatentToStructure at 0x7fa07c1fbc70>

# Sample

In [11]:
# organism_idx = org_df[org_df.organism_id == "HUMAN"].organism_index.iloc[0]
# function_idx = go_df[go_df.GO_term == "carbohydrate metabolic process"].GO_idx.iloc[0]
# print(organism_idx, function_idx)

In [12]:
# Move the diffusion wrapper (and the denoiser it holds) to the GPU chosen in the setup
# cell. `device` is intentionally not redefined here so that changing it in the setup
# cell (e.g. to "cuda:1") is not silently overridden.
diffusion = diffusion.to(device)

In [22]:
from plaid.datasets import NUM_ORGANISM_CLASSES, NUM_FUNCTION_CLASSES
import hydra
from omegaconf import OmegaConf
from plaid.pipeline import SampleLatent

# Pipeline sampling config. It gets its own name so it no longer clobbers `cfg`, the
# training config that the denoiser/diffusion cells above read; re-running those cells
# after this one previously picked up the wrong config.
sample_cfg = OmegaConf.load("/homefs/home/lux70/code/plaid/configs/pipeline/sample_latent.yaml")

N, L = 64, 88
assert L % 4 == 0  # presumably required by the denoiser's down/up-sampling — TODO confirm
shape = (N, L, input_dim)

# organism_idx == NUM_ORGANISM_CLASSES presumably selects the "no organism" /
# unconditional slot — TODO confirm.
organism_idx = NUM_ORGANISM_CLASSES
function_idx = 28  # protein kinase activity
cond_scale = 8.

# NOTE(review): N, L, shape and organism_idx are NOT passed to SampleLatent below, so
# the sampler uses whatever sample_latent.yaml specifies for batch size, length and
# organism. Set the corresponding config keys if these values should take effect.
sample_cfg['sample_scheduler'] = "dpmpp_3m_sde"
sample_cfg['sampling_timesteps'] = 30
sample_cfg['function_idx'] = function_idx
sample_cfg['cond_scale'] = cond_scale

# NOTE(review): this cell previously hit CUDA OOM while `diffusion`, `cheap_model` and
# the structure decoder were already on the GPU. If SampleLatent loads its own model
# copy, free the notebook copies first (or sample before loading the decoders).
solver = SampleLatent(**sample_cfg)
# The following cells read `sampled_latent` (previously never defined); `x` is kept as
# an alias. Assumes sample() returns the latent trajectory tensor — TODO confirm.
sampled_latent = solver.sample()
x = sampled_latent

OutOfMemoryError: CUDA out of memory. Tried to allocate 48.00 MiB. GPU 0 has a total capacity of 79.15 GiB of which 45.62 MiB is free. Including non-PyTorch memory, this process has 79.10 GiB memory in use. Of the allocated memory 78.30 GiB is allocated by PyTorch, and 320.53 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Assumes `sampled_latent` is (N, steps, L, C) with the sampling trajectory on axis 1 —
# TODO confirm; the last step along that axis is taken as the final sample.
print(sampled_latent.shape)
final_sample = sampled_latent[:, -1, :, :]
print(final_sample.shape)

sample_max, sample_min = final_sample.max(), final_sample.min()
print(sample_max, sample_min)

In [None]:
# Decode the compressed latent back to the uncompressed embedding space. Inference
# only: run under no_grad so autograd does not record the decoder's forward pass and
# retain extra GPU memory (the notebook is already memory-constrained).
with torch.no_grad():
    sampled_uncompressed = cheap_model.decode(final_sample, downsampled_mask=None)
print(sampled_uncompressed.min(), sampled_uncompressed.max())

In [None]:
# import pickle as pkl
# with open("test_sample.pkl", "wb") as f:
#     pkl.dump(sampled_uncompressed,f)

In [None]:
from cheap.utils import LatentScaler

# Scaler with its default settings; presumably reverses the normalization applied to
# embeddings before compression — TODO confirm.
latent_scaler = LatentScaler()
sampled_unscaled = latent_scaler.unscale(sampled_uncompressed)

In [None]:
# Sanity check on the value range after unscaling.
unscaled_max, unscaled_min = sampled_unscaled.max(), sampled_unscaled.min()
print(sampled_unscaled.shape)
print(unscaled_max, unscaled_min)

In [None]:
# to_sequence returns several outputs; the last one is used as the list of sequences.
sequence_outputs = latent_to_sequence.to_sequence(sampled_unscaled)
sequences = sequence_outputs[-1]

In [None]:
# Peek at the first three decoded sequences.
sequences[0:3]

In [None]:
# Fold the decoded sequences; raw_outputs keeps the raw predictions (including pLDDT).
pdb_strs, raw_outputs = latent_to_structure.to_structure(
    sampled_unscaled,
    sequences=sequences,
    return_raw_outputs=True,
    num_recycles=1,
    batch_size=32,
)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Per-sample mean pLDDT: average over the last two axes of the plddt output
# (presumably atoms, then residues — TODO confirm the tensor layout).
mean_plddt = raw_outputs['plddt'].mean(dim=-1).mean(dim=-1).cpu().numpy()

# sns.distplot is deprecated; histplot with a KDE and density normalization
# reproduces the same figure.
fig, ax = plt.subplots()
sns.histplot(mean_plddt, bins=30, kde=True, stat="density", ax=ax)
ax.set(title="Mean pLDDT of generated structures", xlabel="mean pLDDT", ylabel="density");

In [None]:
import py3Dmol

# Which structures to render. Slicing (instead of indexing a fixed range) avoids an
# IndexError when fewer than VIEW_STOP structures were generated.
VIEW_START, VIEW_STOP = 20, 30
# Color range for the B-factor column ('prop': 'b'), used here as pLDDT.
PLDDT_COLOR_MIN, PLDDT_COLOR_MAX = 50, 90

plddt_colors = {'prop': 'b', 'gradient': 'roygb', 'min': PLDDT_COLOR_MIN, 'max': PLDDT_COLOR_MAX}

for pdb_str in pdb_strs[VIEW_START:VIEW_STOP]:
    view = py3Dmol.view(width=600, height=600)
    view.addModelsAsFrames(pdb_str)

    # Cartoon and translucent VDW surface, both colored by pLDDT.
    view.setStyle({'cartoon': {'color': plddt_colors}})
    view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'colorscheme': plddt_colors})

    view.zoomTo()
    view.show()

In [None]:
from plaid.evaluation import RITAPerplexity

# Score generated sequences with the RITA language model, in batches.
perplexity_calc = RITAPerplexity(device=device)
batch_perplexities = perplexity_calc.batch_eval(sequences)
perplexities = batch_perplexities

In [None]:
# NOTE(review): this overwrites the batched `perplexities` from the previous cell with
# one-at-a-time scores — presumably a cross-check of batch_eval; keep only one.
perplexities = list(map(perplexity_calc.calc_perplexity, sequences))

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# sns.distplot is deprecated; histplot with a KDE and density normalization
# reproduces the same figure.
fig, ax = plt.subplots()
sns.histplot(perplexities, bins=30, kde=True, stat="density", ax=ax)
ax.set(title="RITA perplexity of generated sequences", xlabel="perplexity", ylabel="density");

In [None]:
import numpy as np

# Average perplexity over all generated sequences.
np.asarray(perplexities).mean()

In [None]:
from plaid.utils import write_to_fasta

# Save every decoded sequence with headers sample0, sample1, ...
fasta_path = "/homefs/home/lux70/generated.fasta"
fasta_headers = ["sample" + str(i) for i in range(len(sequences))]
write_to_fasta(sequences, fasta_path, fasta_headers)