In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from pathlib import Path
from plaid.compression.hourglass_vq import HourglassVQLightningModule
import os

In [3]:
# soft-violet-83
dirpath = Path("/homefs/home/lux70/storage/plaid/checkpoints/hourglass_vq/2024-03-05T06-20-52")
print(os.listdir(str(dirpath)))
model = HourglassVQLightningModule.load_from_checkpoint(dirpath / "last.ckpt")

['last-v1.ckpt', 'epoch=263-step=105000.ckpt', 'epoch=363-step=145000.ckpt', 'last.ckpt']


In [4]:
from plaid.datasets import CATHStructureDataModule

shard_dir = "/homefs/home/lux70/storage/data/cath/shards/"
pdb_dir = "/data/bucket/lux70/data/cath/dompdb"
# shard_dir = "/homefs/home/lux70/storage/data/rocklin/shards/"
# pdb_dir = "/data/bucket/lux70/data/rocklin/structures/"

max_seq_len=256
dm = CATHStructureDataModule(
    shard_dir,
    pdb_dir,
    seq_len=max_seq_len,
    batch_size=32,
    max_num_samples=32,
    shuffle_val_dataset=False
) 
    
dm.setup()
val_dataloader = dm.val_dataloader()
batch = next(iter(val_dataloader))
print(len(val_dataloader.dataset))

32


In [5]:
# grab saved embedding
import torch

device = torch.device("cuda")
x = batch[0].to(device)
sequences = batch[1]
x = x.to(device)

# make mask
from plaid.esmfold.misc import batch_encode_sequences
_, mask, _, _, _ = batch_encode_sequences(sequences)

from plaid.transforms import trim_or_pad_batch_first
mask = mask.to(device)
model = model.to(device)
mask = trim_or_pad_batch_first(mask, pad_to=max_seq_len, pad_idx=0)

# scale
from plaid.utils import LatentScaler
latent_scaler = LatentScaler()
x_norm = latent_scaler.scale(x)

In [6]:
print(x_norm.shape)
print(x_norm.max())

torch.Size([32, 256, 1024])
tensor(1.1303, device='cuda:0')


In [7]:
# model forward pass!!

recons_norm, loss, log_dict, quant_out = model(x_norm, mask.bool(), log_wandb=False)# , return_vq_output=True)
print(torch.mean((recons_norm - x_norm) ** 2))

tensor(0.0376, device='cuda:0', grad_fn=<MeanBackward0>)


In [8]:
N, L, _ = x.shape
print(quant_out['min_encoding_indices'].shape)
print(quant_out['min_encoding_indices'].reshape(N, -1).shape)
print(quant_out['min_encoding_indices'].reshape(N, L, -1).shape)

torch.Size([131072, 1])
torch.Size([32, 4096])
torch.Size([32, 256, 16])


In [None]:
from plaid.proteins import LatentToStructure

del model  # save some GPU space

structure_constructor = LatentToStructure()
structure_constructor.to(device)

recons = latent_scaler.unscale(recons_norm)
recons_struct = structure_constructor.to_structure(recons, sequences, return_raw_features=True, batch_size=4, num_recycles=1)
orig_struct = structure_constructor.to_structure(x, sequences, return_raw_features=True, batch_size=4, num_recycles=1)

loading esmfold model...
ESMFold model created in 46.58 seconds.


(Generating structure): 100%|████████████████████████████████████| 8/8 [01:12<00:00,  9.00s/it]
(Generating structure):  25%|█████████                           | 2/8 [00:18<00:54,  9.09s/it]

In [None]:
for i, pdbstr in enumerate(recons_struct[0]):
    with open(f"/homefs/home/lux70/cache/recons_pred_{i}.pdb", "w") as f: 
        f.write(pdbstr)

In [None]:
for i, pdbstr in enumerate(orig_struct[0]):
    with open(f"/homefs/home/lux70/cache/orig_pred_{i}.pdb", "w") as f: 
        f.write(pdbstr)

In [None]:
import glob
recons_pdbs = glob.glob("/homefs/home/lux70/cache/recons_*.pdb")
orig_pdbs = glob.glob("/homefs/home/lux70/cache/orig_*.pdb")
print(orig_pdbs)

In [None]:
from plaid.utils import run_tmalign

for orig, recons in zip(orig_pdbs, recons_pdbs):
    print(run_tmalign(orig, recons))

In [None]:
# import einops
# tens = torch.masked_select(
#     recons_struct[1]['plddt'],
#     einops.repeat(mask.bool(), "b l -> b l c", c=recons_struct[1]['plddt'].shape[-1])
# )
# print(tens.shape)
# print(torch.mean(tens))

In [None]:
from plaid.utils import view_py3Dmol
import py3Dmol

for i in range(x.shape[0]):
    view = py3Dmol.view(width=400, height=300)
    view.addModelsAsFrames(orig_struct[0][i])
    view.addModelsAsFrames(recons_struct[0][i])
    view.setStyle({"model": 0}, {"cartoon": {"color": "orange"}})
    view.setStyle({"model": 1}, {"cartoon": {"color": "blue"}})
    view.zoomTo()
    view.show()

In [None]:
recons_struct[-1]['positions'].shape

In [None]:
from plaid.evaluation import drmsd

print(recons_struct[-1]['positions'].shape)
r = drmsd(recons_struct[-1]['positions'], orig_struct[-1]['positions'])
print(r.shape)
print(r.mean(dim=(-1)))
print(r.mean(dim=(0, -1)))