In [1]:
import sys

# Make the directory two levels up importable (presumably where the local
# `train` and `draw` modules live). Guarded so that re-running this cell does
# not keep appending duplicate entries to sys.path.
if "../.." not in sys.path:
    sys.path.append("../..")

In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import tqdm

In [3]:
# Checkpoint names (loaded from ../../checkpoints/<name>.ckpt), in the order
# they are generated and later laid out in the comparison grid.
ckpts = [
    "small_w512", "small_w1024", "small_w1536",
    "base_w512", "base_w1024", "base_w1536",
    "big_w512", "big_w1024",
    "original",
]

In [None]:
# Number of sampling steps passed to model.generate for every checkpoint.
DIFF_STEPS = 30

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
from train import Config, LitQuetzal

def gen(ckpt, bsz=8, seed=0, *, ckpt_dir="../../checkpoints", num_steps=None,
        max_len=32, pbar=True):
    """Load a checkpoint and draw one seeded batch of samples from it.

    Args:
        ckpt: checkpoint name; the file loaded is ``{ckpt_dir}/{ckpt}.ckpt``.
        bsz: batch size passed to ``model.generate``.
        seed: passed to ``torch.manual_seed`` immediately before sampling, so
            every checkpoint starts sampling from the same RNG state.
        ckpt_dir: directory containing the ``.ckpt`` files.
        num_steps: forwarded to ``model.generate``; ``None`` means use the
            notebook-level ``DIFF_STEPS`` (read at call time).
        max_len: forwarded to ``model.generate`` as ``max_len``.
        pbar: forwarded to ``model.generate`` as ``pbar``.

    Returns:
        Whatever ``model.generate`` returns; later cells index it as
        ``result[0][b]`` to get individual samples.
    """
    if num_steps is None:
        num_steps = DIFF_STEPS

    lit = LitQuetzal.load_from_checkpoint(f"{ckpt_dir}/{ckpt}.ckpt", map_location=DEVICE)
    # Sample from `lit.ema.module` (presumably the EMA copy of the weights).
    model = lit.ema.module
    model.eval()

    # Seed right before sampling so the seed alone determines the RNG state
    # for generation, independent of anything done while loading.
    torch.manual_seed(seed)
    return model.generate(
        bsz,
        device=DEVICE,
        num_steps=num_steps,
        pbar=pbar,
        max_len=max_len,
    )

# One seeded batch per checkpoint, in the same order as `ckpts`.
outs = list(map(gen, ckpts))

 66%|██████▌   | 21/32 [00:02<00:01,  8.79it/s]
 66%|██████▌   | 21/32 [00:02<00:01,  9.51it/s]
 72%|███████▏  | 23/32 [00:02<00:00,  9.38it/s]
 66%|██████▌   | 21/32 [00:02<00:01,  9.33it/s]
 66%|██████▌   | 21/32 [00:02<00:01,  9.48it/s]
 66%|██████▌   | 21/32 [00:02<00:01,  9.32it/s]
 72%|███████▏  | 23/32 [00:02<00:00,  9.33it/s]
 66%|██████▌   | 21/32 [00:02<00:01,  9.46it/s]
 72%|███████▏  | 23/32 [00:02<00:00,  9.29it/s]


In [5]:
# Persist the generated batches (30 = DIFF_STEPS used above) for later reuse.
torch.save(obj=outs, f="gen_30_midrotate.pt")

In [10]:
import copy

import py3Dmol
from draw import make_html

# Render one interactive 3D grid per sample index. Each panel shows the same
# seeded sample from a different checkpoint; panel order follows `ckpts`/`outs`.
N_SAMPLES = 8   # must match the `bsz` used when calling gen() above
GRID_COLS = 3
GRID_ROWS = -(-len(outs) // GRID_COLS)  # ceil(len(outs) / GRID_COLS)
PANEL_PX = 960  # per-panel size; 3 panels x 960 px = 2880 px per side


def _num_real_atoms(atoms):
    """Count the entries before the first 0 in ``atoms``; all of them if there is no 0.

    NOTE(review): atom code 0 is assumed to mark padding / end of molecule — confirm.
    """
    for idx, a in enumerate(atoms):
        if a == 0:
            return idx
    return len(atoms)


def _principal_frame(coords):
    """Return ``(mean, rot)`` for a coordinate tensor (assumed shape (n_atoms, 3)).

    ``mean`` is the (1, D) centroid, and ``rot`` is the left-singular-vector
    matrix of the centered coordinates with det > 0. As a result,
    ``(x - mean) @ rot`` expresses ``x`` in the principal-axis frame.
    """
    coords = coords.detach()
    mean = coords.mean(dim=0, keepdim=True)
    # SVD in torch keeps the result on coords' device and dtype. A numpy
    # round-trip would fail for CUDA tensors.
    rot, _, _ = torch.linalg.svd((coords - mean).T)
    if torch.linalg.det(rot) < 0:
        # Flip one axis so the transform is a proper rotation, not a reflection.
        rot[:, -1] *= -1
    return mean, rot


for b_idx in range(N_SAMPLES):
    view = py3Dmol.view(
        width=PANEL_PX * GRID_COLS,
        height=PANEL_PX * GRID_ROWS,
        viewergrid=(GRID_ROWS, GRID_COLS),
    )

    # Align every panel to the principal frame of the last checkpoint's sample
    # (ckpts[-1]), so molecules from different checkpoints share an orientation.
    ref = outs[-1][0][b_idx]
    mean, rot = _principal_frame(ref.coords[:_num_real_atoms(ref.atoms)])

    for panel, out in enumerate(outs):
        row, col = divmod(panel, GRID_COLS)
        # Shallow-copy so `outs` itself is never modified. Re-running this cell,
        # or saving `outs` afterwards, then still sees the original coordinates.
        mol = copy.copy(out[0][b_idx])
        mol.coords = (mol.coords - mean) @ rot
        view = mol.show(view=view, viewer=(row, col), zoom=True)

    make_html(view, f"seeded_30_{b_idx}.html")
    # view.show()  # uncomment to also render the grid inline