## Example of GraphEBM: Goal-Directed Generation

In [1]:
import os
import torch
from torch_geometric.data import DenseDataLoader
from rdkit import RDLogger

from dig.ggraph.dataset import ZINC250k, ZINC800
from dig.ggraph.method import GraphEBM
from dig.ggraph.evaluation import PropOptEvaluator, ConstPropOptEvaluator



In [2]:
# Pick the compute device. The second GPU ('cuda:1') is preferred, as in the
# original setup, but we fall back to the first GPU or to the CPU so the
# notebook still runs after "Restart & Run All" on machines with fewer devices.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

#### Prepare Dataset

In [4]:
# Loader settings shared by both training sets.
train_loader_kwargs = dict(batch_size=128, shuffle=True, num_workers=0)

# ZINC250k loaded with prop_name='qed', restricted to its training split.
dataset_qed = ZINC250k(one_shot=True, root='./zinc250k_qed', prop_name='qed')
splits = dataset_qed.get_split_idx()
qed_train_idx = splits['train_idx']
train_set_qed = dataset_qed[qed_train_idx]
train_dataloader_qed = DenseDataLoader(train_set_qed, **train_loader_kwargs)

# ZINC250k loaded with prop_name='penalized_logp', restricted to its training split.
# NOTE: `splits` is rebound here, so cells below see the plogp split indices.
dataset_plogp = ZINC250k(one_shot=True, root='./zinc250k_plogp', prop_name='penalized_logp')
splits = dataset_plogp.get_split_idx()
plogp_train_idx = splits['train_idx']
train_set_plogp = dataset_plogp[plogp_train_idx]
train_dataloader_plogp = DenseDataLoader(train_set_plogp, **train_loader_kwargs)

Processing...
making processed files: zinc250k_plogp/zinc250k_property/processed_oneshot
Done!


#### Training

Before starting training, we need to define an object `graphebm` as an instance of class `GraphEBM`.

**Skip training**: You can also download our trained models for goal-directed generation towards [QED](https://github.com/divelab/DIG_storage/blob/main/ggraph/GraphEBM/GraphEBM_zinc250k_goal_qed.pt) and [plogp](https://github.com/divelab/DIG_storage/blob/main/ggraph/GraphEBM/GraphEBM_zinc250k_goal_plogp.pt). Note: loading a trained model may fail if the file was downloaded with `wget`. If you encounter this error, please download the models manually.

In [5]:
# Single GraphEBM instance used by the training, generation and optimization
# cells below; it is placed on `device`.
graphebm = GraphEBM(
    n_atom=38,
    n_atom_type=10,
    n_edge_type=4,
    hidden=64,
    device=device,
)

In [None]:
# Goal-directed training on the QED training loader; a checkpoint is written
# to ./checkpoints_goal_qed according to `save_interval`.
graphebm.train_goal_directed(
    train_dataloader_qed,
    lr=1e-4,
    wd=0,
    max_epochs=20,
    c=0,
    ld_step=150,
    ld_noise=0.005,
    ld_step_size=30,
    clamp=True,
    alpha=1,
    save_interval=1,
    save_dir='./checkpoints_goal_qed',
)

In [None]:
# Goal-directed training on the penalized-logP training loader.
# Checkpoints go to their own directory: saving into './checkpoints_goal_qed'
# (as this cell previously did) would mix them with / overwrite the QED model's
# checkpoints.
# NOTE(review): this call continues from whatever state `graphebm` is in. If the
# QED training cell was run first, presumably `graphebm` should be re-created
# before this cell to train the plogp model from scratch -- confirm.
graphebm.train_goal_directed(
    train_dataloader_plogp,
    lr=1e-4,
    wd=0,
    max_epochs=20,
    c=0,
    ld_step=150,
    ld_noise=0.005,
    ld_step_size=30,
    clamp=True,
    alpha=1,
    save_interval=1,
    save_dir='./checkpoints_goal_plogp',
)

#### Generation

To construct molecules from our generated node matrices and adjacency tensors, we need the `atomic_num_list`, which denotes what atom each dimension of the node matrix corresponds to. `0` denotes the virtual atom type.

In [7]:
# Silence RDKit's error and warning log channels.
for rdkit_channel in ('rdApp.error', 'rdApp.warning'):
    RDLogger.DisableLog(rdkit_channel)

# Atom list exposed by the dataset, passed to generation as `atomic_num_list`.
atomic_num_list = dataset_qed.atom_list

# Generate samples from the QED goal-directed checkpoint.
gen_mols = graphebm.run_rand_gen(
    checkpoint_path='./GraphEBM_zinc250k_goal_qed.pt',
    n_samples=10000,
    c=0,
    ld_step=150,
    ld_noise=0.005,
    ld_step_size=30,
    clamp=True,
    atomic_num_list=atomic_num_list,
)

Loading paramaters from ./GraphEBM_zinc250k_goal_qed.pt
Initializing samples...
Generating samples...


#### Property Optimization

Property optimization, and the constrained property optimization in the next section, take more time to run.

In [None]:
# Silence RDKit's error and warning log channels (repeated here so this cell
# can be run on its own).
RDLogger.DisableLog('rdApp.error')
RDLogger.DisableLog('rdApp.warning')

# SMILES of the QED training molecules. Read them from `train_set_qed` rather
# than re-indexing `dataset_qed` with `splits`: by this point `splits` has been
# rebound to the split of the plogp dataset.
train_smiles = [data.smile for data in train_set_qed]

# Unshuffled loader over the QED training set, used to initialize optimization.
initialization_loader_qed = DenseDataLoader(train_set_qed, batch_size=10000, shuffle=False, num_workers=0)

save_mols_list, prop_list = graphebm.run_prop_opt(
    './GraphEBM_zinc250k_goal_qed.pt',
    initialization_loader=initialization_loader_qed,
    c=0,
    ld_step=300,
    ld_noise=0.005,
    ld_step_size=0.2,
    clamp=True,
    atomic_num_list=atomic_num_list,
    train_smiles=train_smiles,
)
print(prop_list)

# Evaluate the optimized molecules.
res_dict = {'mols': save_mols_list}
evaluator = PropOptEvaluator()
results = evaluator.eval(res_dict)
print(results)

#### Constrained Property Optimization

In [None]:
dataset_zinc800 = ZINC800(one_shot=True, root='./zinc800_plogp')

# Keep the loader unshuffled: `train_smiles` below is in dataset order and is
# passed alongside the returned molecule lists as `inp_smiles`, so the loader
# must presumably yield molecules in that same order for the per-molecule
# comparison to line up -- confirm against the evaluator.
initialization_dataloader = DenseDataLoader(dataset_zinc800, batch_size=800, shuffle=False, num_workers=0)

# Silence RDKit's error and warning log channels.
RDLogger.DisableLog('rdApp.error')
RDLogger.DisableLog('rdApp.warning')

# SMILES of the starting molecules (named `train_smiles` to match the keyword
# expected by `run_const_prop_opt`).
train_smiles = [data.smile for data in dataset_zinc800]

# The starting molecules come from the plogp ZINC800 set, so load the
# plogp goal-directed checkpoint rather than the QED one.
# Each of the eight returned lists gets its own name; the last target used to be
# a second `imp_4_list`, which silently discarded one of the results.
(
    mols_0_list, mols_2_list, mols_4_list, mols_6_list,
    imp_0_list, imp_2_list, imp_4_list, imp_6_list,
) = graphebm.run_const_prop_opt(
    './GraphEBM_zinc250k_goal_plogp.pt',
    initialization_loader=initialization_dataloader,
    c=0,
    ld_step=500,
    ld_noise=0.005,
    ld_step_size=0.2,
    clamp=True,
    atomic_num_list=atomic_num_list,
    train_smiles=train_smiles,
)

# Evaluate the optimized molecules against their inputs.
res_dict = {
    'inp_smiles': train_smiles,
    'mols_0': mols_0_list,
    'mols_2': mols_2_list,
    'mols_4': mols_4_list,
    'mols_6': mols_6_list,
}
evaluator = ConstPropOptEvaluator()
results = evaluator.eval(res_dict)
print(results)