## Example of GraphEBM: Compositional Generation

```python
import torch
from rdkit import RDLogger

from dig.ggraph.dataset import ZINC250k
from dig.ggraph.method import GraphEBM
```
```python
# 'cuda:1' selects the second GPU; change the index (or use 'cpu') to match your machine
device = torch.device('cuda:1')
```

#### Prepare Dataset

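```python
# Load ZINC250k in the one-shot (dense) format used by GraphEBM
dataset = ZINC250k(one_shot=True, root='./')
# Atomic numbers corresponding to each atom-type index
atomic_num_list = dataset.atom_list
```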

#### Generation

To generate molecules that satisfy multiple objectives in a compositional manner, we need two models trained for goal-directed generation, one per objective (here, QED and penalized logP).
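The following is a minimal sketch of the idea. It assumes, as in the GraphEBM paper, that composing objectives means adding the energies of the two models and then sampling with Langevin dynamics on the summed energy. The toy quadratic energies, the helper names (`composed_grad`, `langevin_sample`, `quadratic_grad`) and the exact update rule are illustrative assumptions. DIG's `run_comp_gen` works on graph tensors and may scale the step size and noise differently.

```python
import random
from typing import Callable, List, Optional, Tuple

Vector = List[float]
GradFn = Callable[[Vector], Vector]


def composed_grad(grad_fns: List[GradFn]) -> GradFn:
    """Gradient of E(x) = sum_i E_i(x): the gradients of summed energies add up."""
    def grad(x: Vector) -> Vector:
        total = [0.0] * len(x)
        for g in grad_fns:
            for j, gj in enumerate(g(x)):
                total[j] += gj
        return total
    return grad


def langevin_sample(grad_fn: GradFn, x0: Vector, n_steps: int, step_size: float,
                    noise_std: float, clamp: Optional[Tuple[float, float]] = None,
                    seed: int = 0) -> Vector:
    """Repeat x <- x - step_size * dE/dx + noise_std * N(0, 1), optionally clamping x."""
    rng = random.Random(seed)
    x = list(x0)
    for _ in range(n_steps):
        g = grad_fn(x)
        x = [xi - step_size * gi + noise_std * rng.gauss(0.0, 1.0)
             for xi, gi in zip(x, g)]
        if clamp is not None:
            lo, hi = clamp
            x = [min(max(xi, lo), hi) for xi in x]
    return x


def quadratic_grad(center: Vector) -> GradFn:
    """Gradient of the toy energy E(x) = 0.5 * ||x - center||^2 (hypothetical stand-in for a trained model)."""
    return lambda x: [xi - ci for xi, ci in zip(x, center)]


# Two toy "objectives" whose energies are lowest at different points
grad = composed_grad([quadratic_grad([1.0, 0.0]), quadratic_grad([0.0, 1.0])])
sample = langevin_sample(grad, [0.0, 0.0], n_steps=300, step_size=0.05, noise_std=0.0)
print([round(v, 3) for v in sample])  # [0.5, 0.5]
```

The gradients of summed energies simply add, so the noise-free run settles halfway between the two toy minima. The result is a compromise between both objectives rather than the optimum of either.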

**Skip training**: You can also download our trained models for goal-directed generation towards [QED](https://github.com/divelab/DIG_storage/blob/main/ggraph/GraphEBM/GraphEBM_zinc250k_goal_qed.pt) and [plogp](https://github.com/divelab/DIG_storage/blob/main/ggraph/GraphEBM/GraphEBM_zinc250k_goal_plogp.pt). Note: loading the checkpoints may fail if they were downloaded with `wget`. If you run into this error, download the files manually (e.g., through the browser).
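If `wget` fails, the most likely cause is that the links above point to GitHub's HTML page for each file (`/blob/`). In that case, `wget` saves a web page instead of the checkpoint. The sketch below fetches the `/raw/` URL instead and rejects files that look like HTML. The helper names are hypothetical, and the sketch assumes GitHub's `/raw/` redirect works for these files. Downloading through the browser remains the fallback.

```python
import os
import urllib.request

QED_URL = "https://github.com/divelab/DIG_storage/blob/main/ggraph/GraphEBM/GraphEBM_zinc250k_goal_qed.pt"
PLOGP_URL = "https://github.com/divelab/DIG_storage/blob/main/ggraph/GraphEBM/GraphEBM_zinc250k_goal_plogp.pt"


def to_raw_url(blob_url: str) -> str:
    """Turn a GitHub '/blob/' page URL into a direct-download '/raw/' URL."""
    return blob_url.replace("/blob/", "/raw/", 1)


def is_html_bytes(data: bytes) -> bool:
    """Heuristic: a saved web page starts with '<'; PyTorch checkpoints do not."""
    return data.lstrip()[:1] == b"<"


def download_checkpoint(blob_url: str, dest_dir: str = ".") -> str:
    """Download a checkpoint next to the notebook and fail loudly if it is HTML."""
    dest = os.path.join(dest_dir, blob_url.rsplit("/", 1)[-1])
    urllib.request.urlretrieve(to_raw_url(blob_url), dest)
    with open(dest, "rb") as f:
        if is_html_bytes(f.read(512)):
            raise RuntimeError(f"{dest} looks like an HTML page; download it manually")
    return dest
```

For example, `download_checkpoint(QED_URL)` and `download_checkpoint(PLOGP_URL)` should place both files in the current directory, which is where the generation cell looks for them.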

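```python
# n_atom: maximum number of atoms per molecule (38 for ZINC250k);
# n_atom_type / n_edge_type: sizes of the atom and bond vocabularies; hidden: hidden size
graphebm = GraphEBM(n_atom=38, n_atom_type=10, n_edge_type=4, hidden=64, device=device)
```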
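As we read them, the constructor arguments fix the size of the dense (one-shot) graph representation:

- a one-hot node matrix of shape `n_atom × n_atom_type`
- an adjacency tensor of shape `n_edge_type × n_atom × n_atom`

The sketch below builds such tensors from plain Python lists. The atom vocabulary (nine heavy-atom types plus a padding entry `0`) is an assumption modeled on common ZINC250k preprocessing and is not read from DIG, as is the bond-channel order (single, double, triple, no bond). Check `dataset.atom_list` for the actual order. `encode_dense` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple

# Assumed atom vocabulary (atomic numbers); 0 marks a padding/virtual atom.
ATOM_LIST = [6, 7, 8, 9, 15, 16, 17, 35, 53, 0]
# Assumed bond channels: 0 = single, 1 = double, 2 = triple, 3 = no bond.
N_EDGE_TYPE = 4
NO_BOND = 3


def encode_dense(atoms: Sequence[int], bonds: Sequence[Tuple[int, int, int]],
                 n_atom: int = 38) -> Tuple[List[List[int]], List[List[List[int]]]]:
    """Encode a molecule as one-hot x [n_atom][n_atom_type] and adj [n_edge_type][n_atom][n_atom].

    atoms: atomic numbers (unknown elements raise ValueError via list.index);
    bonds: (i, j, channel) with channel in 0..2.
    """
    if len(atoms) > n_atom:
        raise ValueError(f"molecule has {len(atoms)} atoms, max is {n_atom}")
    x = [[0] * len(ATOM_LIST) for _ in range(n_atom)]
    for i in range(n_atom):
        z = atoms[i] if i < len(atoms) else 0  # pad with the virtual atom
        x[i][ATOM_LIST.index(z)] = 1
    # Start with "no bond" everywhere, then write each real bond symmetrically.
    adj = [[[0] * n_atom for _ in range(n_atom)] for _ in range(N_EDGE_TYPE)]
    for i in range(n_atom):
        for j in range(n_atom):
            adj[NO_BOND][i][j] = 1
    for i, j, channel in bonds:
        for a, b in ((i, j), (j, i)):
            adj[NO_BOND][a][b] = 0
            adj[channel][a][b] = 1
    return x, adj


# Ethanol heavy atoms C-C-O with two single bonds
x, adj = encode_dense([6, 6, 8], [(0, 1, 0), (1, 2, 0)])
print(len(x), len(x[0]), len(adj), len(adj[0]))  # 38 10 4 38
```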

```python
# Silence error and warning messages printed by RDKit
RDLogger.DisableLog('rdApp.error')
RDLogger.DisableLog('rdApp.warning')

# Sample with Langevin dynamics (ld_* arguments) using both goal-directed models
gen_mols = graphebm.run_comp_gen(
    checkpoint_path_qed='./GraphEBM_zinc250k_goal_qed.pt',
    checkpoint_path_plogp='./GraphEBM_zinc250k_goal_plogp.pt',
    n_samples=10000, c=0,
    ld_step=300, ld_noise=0.005, ld_step_size=30, clamp=True,
    atomic_num_list=atomic_num_list)
```

```
Loading paramaters from ./GraphEBM_zinc250k_goal_qed.pt
Loading paramaters from ./GraphEBM_zinc250k_goal_plogp.pt
Initializing samples...
Generating samples...
```
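Once sampling finishes, a quick sanity check is to look at validity and the objective scores. This sketch assumes `gen_mols` is a list of RDKit `Mol` objects in which failed conversions appear as `None`. We have not confirmed DIG's exact return type, and `summarize` is a hypothetical helper.

```python
from statistics import mean
from typing import Any, Callable, Dict, Iterable, Optional


def summarize(mols: Iterable[Optional[Any]], score_fn: Callable[[Any], float]) -> Dict[str, float]:
    """Report the fraction of non-None molecules and simple statistics of score_fn over them."""
    mols = list(mols)
    valid = [m for m in mols if m is not None]
    scores = [score_fn(m) for m in valid]
    return {
        "n_total": len(mols),
        "valid_fraction": len(valid) / len(mols) if mols else 0.0,
        "mean_score": mean(scores) if scores else float("nan"),
        "max_score": max(scores) if scores else float("nan"),
    }
```

With RDKit installed, `summarize(gen_mols, QED.qed)` (after `from rdkit.Chem import QED`) reports QED statistics for the generated set. Penalized logP would need a separate scoring function.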
