In [1]:
import ase
import numpy as np
import ase.io as aio

import torch

In [2]:
from models.bond_predictor import BondPredictor
from models.scaffolded_model import ScaffoldedMolDiff
from utils.data import traj_to_ase
from utils.reconstruct import MolReconsError, reconstruct_from_generated_with_edges
from utils.sample import seperate_outputs
from utils.transforms import FeaturizeMol, make_data_placeholder

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the analogue structures. index=":" asks ASE for every frame in the file
# rather than a single one; `atoms` is indexed by position in the next cell.
atoms = aio.read("./penicillin_analogues.xyz", index=":")
# Index data locating the shared core; core_ids[0] is paired with atoms[0] below.
# NOTE(review): presumably one entry of atom indices per structure in the xyz file — confirm.
core_ids = np.load("./penicillin_core_ids.npy")

In [4]:
# Take the first structure as the reference and select its core atoms with the
# matching entry of core_ids; the result is used as the fixed scaffold below.
penicillin_core = atoms[0][core_ids[0]]

In [5]:
# Translate the scaffold so its centre of mass sits at the origin, then hand the
# coordinates to the GPU as a float32 tensor.
# NOTE(review): presumably the sampler expects a centred scaffold — confirm.
core_xyz_centered = (
    penicillin_core.get_positions() - penicillin_core.get_center_of_mass()
)
core_pos = torch.tensor(core_xyz_centered, dtype=torch.float32).to("cuda")

In [6]:
# NOTE(security): torch.load unpickles the checkpoint, which can execute
# arbitrary code — only load checkpoint files that come from a trusted source.
ckpt = torch.load("./ckpt/MolDiff.pt", map_location="cuda")
train_config = ckpt["config"]
chem_cfg = train_config.chem
transform_cfg = train_config.transform

# Rebuild the featurizer from the configuration stored in the checkpoint so the
# node/edge type counts passed to the model come from the same settings.
featurizer = FeaturizeMol(
    chem_cfg.atomic_numbers,
    chem_cfg.mol_bond_types,
    use_mask_node=transform_cfg.use_mask_node,
    use_mask_edge=transform_cfg.use_mask_edge,
)

# Instantiate the network, move it to the GPU, and restore the stored weights.
model = ScaffoldedMolDiff(
    config=train_config.model,
    num_node_types=featurizer.num_node_types,
    num_edge_types=featurizer.num_edge_types,
)
model = model.to("cuda")
model.load_state_dict(ckpt["model"])
# Switch to inference mode; kept as the cell's last expression so the notebook
# still echoes the module summary.
model.eval()

  ckpt = torch.load("./ckpt/MolDiff.pt", map_location="cuda")


ScaffoldedMolDiff(
  (pos_transition): ContigousTransition()
  (node_transition): GeneralCategoricalTransition()
  (edge_transition): GeneralCategoricalTransition()
  (node_embedder): Linear(in_features=8, out_features=246, bias=False)
  (edge_embedder): Linear(in_features=6, out_features=54, bias=False)
  (time_emb): Sequential(
    (0): GaussianSmearing()
  )
  (denoiser): NodeEdgeNet(
    (distance_expansion): GaussianSmearing()
    (node_blocks_with_edge): ModuleList(
      (0-5): 6 x NodeBlock(
        (node_net): MLP(
          (net): Sequential(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
            (2): ReLU()
            (3): Linear(in_features=256, out_features=256, bias=True)
          )
        )
        (edge_net): MLP(
          (net): Sequential(
            (0): Linear(in_features=64, out_features=256, bias=True)
            (1): LayerNorm((256,), eps=1e-05, elementwise_aff

In [7]:
# Translate the scaffold's atomic numbers into the featurizer's node-type
# indices and place them on the GPU as a long tensor.
core_numbers = penicillin_core.get_atomic_numbers()
core_type_ids = [featurizer.ele_to_nodetype[z] for z in core_numbers]
core_node_types = torch.tensor(core_type_ids, dtype=torch.long).to("cuda")

In [8]:
# Sampling setup: how many graphs to generate at once, and the max_size handed
# to the placeholder builder.
# NOTE(review): presumably mol_size is the atom count per generated molecule — confirm.
batch_size = 16
mol_size = 40
batch_holder = make_data_placeholder(
    n_graphs=batch_size,
    device="cuda",
    max_size=mol_size,
)

In [9]:
# Collect the three placeholder tensors the sampler takes by the same names.
graph_layout = {
    name: batch_holder[name]
    for name in ("batch_node", "halfedge_index", "batch_halfedge")
}
# Run scaffold-conditioned sampling with no bond predictor and no guidance.
# The scaffold is described by the core coordinates and node types built above.
# The recorded run took roughly a minute and a half for 1000 steps.
outputs = model.sample(
    n_graphs=batch_size,
    **graph_layout,
    bond_predictor=None,
    guidance=None,
    scaffold_positions=core_pos,
    scaffold_node_types=core_node_types,
    readd_noise=True,
)

100%|██████████| 1000/1000 [01:23<00:00, 11.91it/s]


In [10]:
# Move every sampled tensor to host memory as a NumPy array.
# `outputs` is rebound in place, so a second run of this cell sees arrays rather
# than tensors; the is_tensor guard passes those through unchanged instead of
# failing on the missing `.cpu()` method, which keeps the cell re-runnable.
outputs = {
    key: [v.cpu().numpy() if torch.is_tensor(v) else v for v in value]
    for key, value in outputs.items()
}
# The placeholder index tensors are converted the same way. They are read from
# `batch_holder`, which this cell never rebinds, so they are always tensors here.
batch_node = batch_holder["batch_node"].cpu().numpy()
halfedge_index = batch_holder["halfedge_index"].cpu().numpy()
batch_halfedge = batch_holder["batch_halfedge"].cpu().numpy()
# Split the batched arrays into one entry per generated graph.
# NOTE(review): presumably each entry is a dict; a "traj" key is read downstream — confirm.
output_list = seperate_outputs(
    outputs, batch_size, batch_node, halfedge_index, batch_halfedge
)

In [11]:
# Convert each generated molecule's "traj" entry into an ASE object.
# NOTE(review): the -1 argument presumably selects the final frame of each
# trajectory — confirm against utils.data.traj_to_ase.
trajectories = [
    traj_to_ase(output_mol["traj"], featurizer, -1) for output_mol in output_list
]

In [12]:
# Save all converted structures into a single xyz file in the working directory.
aio.write("gen_test.xyz", trajectories)