In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
from plaid.datasets import CATHShardedDataModule

dm = CATHShardedDataModule(storage_type="hdf5", dtype="fp32", seq_len=64)
dm.setup("fit")
dl = dm.train_dataloader()
batch = next(iter(dl))

seq_to_header = {v: k for k, v in dm.train_dataset.header_to_seq.items()}

In [3]:
from plaid.utils import LatentToSequence, LatentToStructure
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
l2seq = LatentToSequence(device)
l2struct = LatentToStructure(device)

loaded decoder from /home/amyxlu/plaid/plaid/utils/../../cached_tensors/decoder_vocab_21.ckpt


In [6]:
x, sequences = batch
headers = [seq_to_header[s] for s in sequences]
seqs = [s[:64] for s in sequences]

KeyError: 'IFARDPSKLPDYRMIISHPMWWDLIKARLTRYEYTSPSAFINDMRLVVQNCYDYNREESPFSTL'

In [12]:
cath_ids = [h.split("|")[2].split("/")[0] for h in headers]
cath_ids[:10]

['1h8mA00',
 '3tgnA03',
 '4g3aA00',
 '4rhiA00',
 '3cwzB03',
 '3m1cA04',
 '3qv0A00',
 '1rlhA02',
 '1u69A00',
 '1oznA00']

In [13]:
pdbstrs, metrics = l2struct.to_structure(x.to(torch.float32), seqs, num_recycles=4)

(Generating structure from latents..):   0%|          | 0/1 [00:00<?, ?it/s]

(Generating structure from latents..): 100%|██████████| 1/1 [00:19<00:00, 19.81s/it]


In [14]:
metrics

Unnamed: 0,plddt,ptm,aligned_confidence_probs,predicted_aligned_error
0,76.790588,0.76446,0.015625,2.91292
1,53.388073,0.426433,0.015625,12.359621
2,72.762375,0.730029,0.015625,3.67455
3,43.81052,0.126293,0.015625,21.946461
4,65.219398,0.621859,0.015625,7.050356
5,41.152943,0.206621,0.015625,16.407764
6,54.154346,0.28525,0.015625,20.142078
7,44.535862,0.275349,0.015625,13.622183
8,64.865936,0.606047,0.015625,7.890108
9,74.966965,0.725256,0.015625,3.258469


In [15]:
for i in range(10):
    with open("pred_from_latent_{}.pdb".format(i), "w") as f:
        f.write(pdbstrs[i])

In [None]:
# del l2struct

In [16]:
from plaid.esmfold import esmfold_v1
efold = esmfold_v1()

In [17]:
efold.to(device)
with torch.no_grad():
    pdbstr_inferred = efold.infer_pdbs(seqs)

In [18]:
for i in range(10):
    with open("pred_from_seq_{}.pdb".format(i), "w") as f:
        f.write(pdbstr_inferred[i])

In [20]:
import os
import shutil
for i, name in enumerate(cath_ids[:10]):
    original = "/shared/amyxlu/data/cath/full/dompdb/" + name
    shutil.copy(original, f"{i}_name.pdb")
    

In [19]:
from plaid.utils import run_tmalign
scores = []
for i in range(10):
    tmscore = run_tmalign(f"pred_from_latent_{i}.pdb", f"pred_from_seq_{i}.pdb")
    scores.append(tmscore)

scores

[0.3898,
 0.89816,
 0.4072,
 0.33646,
 0.82221,
 0.26164,
 0.20608,
 0.36169,
 0.78232,
 0.60658]

In [21]:
from plaid.utils import run_tmalign
scores2 = []
for i in range(10):
    tmscore = run_tmalign(f"pred_from_latent_{i}.pdb", f"{i}_name.pdb")
    scores2.append(tmscore)
print(scores2)

[0.42449, 0.49204, 0.26532, 0.13074, 0.46336, 0.22204, 0.20049, 0.33428, 0.33495, 0.22423]


In [22]:
from plaid.utils import run_tmalign
scores3 = []
for i in range(10):
    tmscore = run_tmalign(f"pred_from_seq_{i}.pdb", f"{i}_name.pdb")
    scores3.append(tmscore)
print(scores3)

[0.21496, 0.49122, 0.15789, 0.13635, 0.4812, 0.20423, 0.21574, 0.45479, 0.37716, 0.22346]
