In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import k_diffusion as K
from pathlib import Path

model_dir = Path(f"/shared/amyxlu/kdplaid/")
model_id = "etbey7fe"
filename = K.model_loading_utils.infer_ckpt_path(model_dir, model_id)

In [3]:
model, model_config = K.model_loading_utils.load_model(filename)

Loading checkpoint from /shared/amyxlu/kdplaid/checkpoints/etbey7fe/03715000.pth


In [19]:
from k_diffusion.callback import SampleCallback
import torch

config = K.config.SampleCallbackConfig(
    batch_size=64,
    seq_len=64,
    n_to_sample=128,
    save_to_disk=False,
    log_to_wandb=False,
)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
sampler = SampleCallback(
    model=model,
    config=config,
    model_config=model_config,
    is_wandb_setup=True,
    device=device,
)


In [20]:
sampled_latent, sampled_raw, _ = sampler.sample_latent(
    clip_range=config.clip_range, save=config.save_to_disk, log_wandb_stats=config.log_to_wandb, return_raw=True
)

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

Sampled latent mean: tensor(2.0054) std: tensor(67.7701)
Raw sampled latent mean: tensor(0.0362) std: tensor(0.8872)


In [21]:
if not config.n_to_construct == -1:
    sampled_latent = sampled_latent[torch.randperm(sampled_latent.shape[0])][
        : config.n_to_construct
    ]

print("Constructing sequences...")
_, _, strs, _ = sampler.construct_sequence(
    sampled_latent,
    calc_perplexity=config.calc_perplexity,
    save_to_disk=config.save_to_disk,
    log_to_wandb=config.log_to_wandb,
)

print(strs)

Constructing sequences...
percentage similarty to argmax idx: 0.887
Perplexity: 4.311
['ADGAGADDAGAAAAAADAADAGGAAAAGAAAAGGGAAADATGAADAAAAAAAAAAAAGAATGAA', 'GGDTTGDTTGTTTTTTGTTGGTTTTTTDTATTTDTTTTGTGDTTDGDTTTTTTTTTTTDTTTTT', 'ATDAAAAATAAAATDDADTATDADTAAADDAAAATGATAADTAATDTDATTATATDDTAADTAA', 'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT', 'TTTTTTATTATTTGTTTTTTTTTTTTTTTTTTTTTTTTTATATTTTAATATATTAATTAATTTT', 'TTTTTTATAATTTTTDTTTTAATTTAATTTATDTDTTTTATTTTTTAATTAATTTTTTATTTTD', 'TTDTATDAAGTAGADATDTAAGDAAAAGAAADTTGAAADATAATTATAADAGDADTAGDTTTAD', 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGAAAAAAAAAAAAAAA', 'TDGGGGTTTDGGTDGTTTGTGTTTTGGTTTDGGTMGADTTTGDGDTGGDDGGTGGTDTGTTTTT', 'AADAAAADAAAAAAAADAAAAAAAAAAAAAAAAAAADAAADDAAAAAAADAAAAAAAADAAAAA', 'AATTTTAAAATAATADAAAATDATAAADDTAATAAAAAAATAAADDADTAATAAAAADAAAADA', 'ADDDDDDAADDADDDDDDADDDDDADDDDDDDDDDDADDDDDDDDDDDDADDDDDDDDDDDDDD', 'DTDTDDDDTTTDTTTDDTDTTDDDDTDDDDTDTTTTDDTDDTTDDTDTDTTTTDDDTDTDTTTD', 'ADDDDADADADDDDADDDADDDDADAAD

In [22]:
# print("Constructing sequences from unnormalized latent...")
# _, _, strs = sampler.construct_sequence(
#     sampled_raw_expanded,
#     calc_perplexity=config.calc_perplexity,
#     save_to_disk=config.save_to_disk,
#     log_to_wandb=config.log_to_wandb,
#     wandb_prefix="raw_unnormalized_"
# )

print("Constructing structures...")
pdb_strs, metrics, _ = sampler.construct_structure(
    sampled_latent,
    strs,
    save_to_disk=config.save_to_disk,
    log_to_wandb=config.log_to_wandb,
)
print(metrics)


Constructing structures...


(Generating structure from latents..): 100%|██████████| 1/1 [01:12<00:00, 72.06s/it]

        plddt       ptm  aligned_confidence_probs  predicted_aligned_error
0   35.382561  0.123928                  0.015625                16.697651
1   35.380360  0.096650                  0.015625                18.911467
2   35.587723  0.083282                  0.015625                21.584911
3   58.036613  0.195988                  0.015625                12.102503
4   40.899563  0.107495                  0.015625                17.771069
5   33.387318  0.094670                  0.015625                18.864834
6   36.881855  0.094901                  0.015625                18.680901
7   37.283455  0.118193                  0.015625                15.929408
8   37.794956  0.084063                  0.015625                21.613316
9   41.267616  0.169703                  0.015625                11.281233
10  33.464329  0.083977                  0.015625                20.255028
11  49.891956  0.183397                  0.015625                13.905256
12  37.080059  0.109578  




In [None]:
def view_py3dmol(pdbpaths):
    import py3Dmol

    for pdbpath in pdbpaths:
        with open(pdbpath) as ifile:
            system = "".join([x for x in ifile])
        view = py3Dmol.view(width=400, height=300)
        view.addModelsAsFrames(system)
        view.setStyle({"model": -1}, {"cartoon": {"color": "spectrum"}})
        view.zoomTo()
        view.show()

In [29]:
import py3Dmol

system = pdb_strs[4]
view = py3Dmol.view(width=400, height=300)
view.addModelsAsFrames(system)
view.setStyle({"model": -1}, {"cartoon": {"color": "spectrum"}})
view.zoomTo()
view.show()

In [None]:
import py3Dmol

py3Dmol.view()