## Example of GraphEBM: Random Generation

```python
import torch
from rdkit import RDLogger

try:
    from torch_geometric.loader import DenseDataLoader  # PyG >= 2.0
except ImportError:
    from torch_geometric.data import DenseDataLoader  # older PyG releases

from dig.ggraph.dataset import QM9, ZINC250k
from dig.ggraph.method import GraphEBM
from dig.ggraph.evaluation import RandGenEvaluator
```

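```python
# Use a GPU if one is available; set an explicit index (e.g. 'cuda:1') to pick a specific card.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```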

#### Prepare Dataset

```python
dataset = ZINC250k(one_shot=True, root='./')
splits = dataset.get_split_idx()
train_set = dataset[splits['train_idx']]
train_dataloader = DenseDataLoader(train_set, batch_size=128, shuffle=True, num_workers=0)
```

```
making raw files: ./raw
Downloading https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/zinc250k_property.csv
Processing...
making processed files: ./zinc250k_property/processed_oneshot
Done!
```
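With `one_shot=True`, every molecule is stored as fixed-size dense tensors, which is why batches are built with `DenseDataLoader` (it stacks equally sized tensors along a batch dimension). The sketch below shows one way such an encoding can look, in plain Python. It assumes a node matrix of shape `(n_atom, n_atom_type)` whose padding rows use the virtual atom type, and an adjacency tensor of shape `(n_edge_type, n_atom, n_atom)` whose last channel marks "no bond". The helper `encode_dense` and the channel order are hypothetical; DIG's own preprocessing may lay the data out differently.

```python
# Hypothetical one-shot dense encoding; channel layout is an assumption, not DIG's code.
ATOMIC_NUM_LIST = [6, 7, 8, 9, 15, 16, 17, 35, 53, 0]  # ZINC250k list from this notebook; 0 = virtual
BOND_ORDERS = [1, 2, 3]  # assumed: channels 0-2 = single/double/triple, last channel = no bond


def encode_dense(atoms, bonds, n_atom=38):
    """atoms: list of atomic numbers; bonds: list of (i, j, bond_order) tuples."""
    n_type = len(ATOMIC_NUM_LIST)
    n_edge = len(BOND_ORDERS) + 1
    virtual = ATOMIC_NUM_LIST.index(0)
    # Node matrix: one one-hot row per atom slot; unused slots get the virtual type.
    x = [[0] * n_type for _ in range(n_atom)]
    for i in range(n_atom):
        col = ATOMIC_NUM_LIST.index(atoms[i]) if i < len(atoms) else virtual
        x[i][col] = 1
    # Adjacency tensor: start with every pair in the "no bond" channel.
    adj = [[[0] * n_atom for _ in range(n_atom)] for _ in range(n_edge)]
    for i in range(n_atom):
        for j in range(n_atom):
            adj[n_edge - 1][i][j] = 1
    for i, j, order in bonds:
        c = BOND_ORDERS.index(order)
        for a, b in ((i, j), (j, i)):  # bonds are undirected: fill both directions
            adj[c][a][b] = 1
            adj[n_edge - 1][a][b] = 0
    return x, adj


# Ethanol's heavy atoms: C-C-O with two single bonds.
x, adj = encode_dense([6, 6, 8], [(0, 1, 1), (1, 2, 1)])
```

The fixed `n_atom=38` and `n_atom_type=10` used here are the same sizes passed to the `GraphEBM` constructor in the next step.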


#### Training

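Before training, we create `graphebm` as an instance of `GraphEBM`. For ZINC250k, `n_atom=38` is the maximum number of (heavy) atoms per molecule, `n_atom_type=10` matches the ten entries of `dataset.atom_list` (nine elements plus the virtual type), and `n_edge_type=4` is the number of channels in the adjacency tensor (three bond orders plus a virtual "no bond" channel).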

**Skip training**: Instead of training, you can download our trained models for [ZINC250k](https://github.com/divelab/DIG_storage/blob/main/ggraph/GraphEBM/GraphEBM_zinc250k_uncond.pt) and [QM9](https://github.com/divelab/DIG_storage/blob/main/ggraph/GraphEBM/GraphEBM_qm9_uncond.pt). Note: models downloaded with `wget` may fail to load. If you hit this error, download the models manually through your browser. The generation step expects the ZINC250k checkpoint at `./GraphEBM_zinc250k_uncond.pt`.
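A likely cause of the `wget` problem is that a `github.com/.../blob/...` URL serves GitHub's HTML page for the file, not the file itself. The raw bytes live under `raw.githubusercontent.com`, the same host the dataset download above uses. The sketch below rewrites a blob URL into its raw form and fetches it with the standard library. It assumes the checkpoint is served as a plain file (not a Git LFS pointer) and that the branch name contains no `/`. `blob_to_raw` and `download_checkpoint` are hypothetical helpers, not part of DIG.

```python
import urllib.request
from urllib.parse import urlparse


def blob_to_raw(url: str) -> str:
    """Turn github.com/<owner>/<repo>/blob/<ref>/<path> into its raw-file URL."""
    parts = urlparse(url)
    segments = parts.path.strip("/").split("/")
    if parts.netloc != "github.com" or len(segments) < 5 or segments[2] != "blob":
        raise ValueError(f"not a GitHub blob URL: {url}")
    owner, repo, _, ref, *path = segments
    return "https://raw.githubusercontent.com/" + "/".join([owner, repo, ref, *path])


def download_checkpoint(url: str, dest: str) -> None:
    # Fetch the raw file rather than the HTML page that the blob URL returns.
    urllib.request.urlretrieve(blob_to_raw(url), dest)
```

For example, `download_checkpoint('https://github.com/divelab/DIG_storage/blob/main/ggraph/GraphEBM/GraphEBM_zinc250k_uncond.pt', './GraphEBM_zinc250k_uncond.pt')` places the file where the generation step looks for it. If loading still fails, fall back to the manual browser download.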

```python
graphebm = GraphEBM(n_atom=38, n_atom_type=10, n_edge_type=4, hidden=64, device=device)
```

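```python
graphebm.train_rand_gen(
    train_dataloader,
    lr=1e-4, wd=0,         # learning rate and weight decay
    max_epochs=20,
    c=0,                   # scaling hyperparameter for dequantization
    ld_step=150,           # Langevin dynamics: number of steps,
    ld_noise=0.005,        #   std of the added Gaussian noise,
    ld_step_size=30,       #   and step size
    clamp=True,            # clamp gradients during Langevin dynamics
    alpha=1,               # weight of the regularization term in the loss
    save_interval=1,       # save a checkpoint every epoch
    save_dir='./checkpoints',
)
```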
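The `ld_*` arguments configure the Langevin dynamics that GraphEBM uses to draw samples from its energy function, both as negative samples during training and as generated molecules later. The sketch below shows noisy, gradient-clamped Langevin updates on a one-dimensional toy energy `E(x) = 0.5 * (x - target)**2` in plain Python. It assumes the update order "add noise, take a clamped gradient step, clip to [0, 1]". The toy energy, the gradient bound `grad_bound`, and the function names are hypothetical, and this is not DIG's implementation, which operates on whole node matrices and adjacency tensors.

```python
import random


def energy_grad(x: float, target: float) -> float:
    """Gradient of the toy energy E(x) = 0.5 * (x - target) ** 2."""
    return x - target


def langevin_sample(n_samples: int, target: float, ld_step: int = 150,
                    ld_noise: float = 0.005, ld_step_size: float = 30.0,
                    clamp: bool = True, grad_bound: float = 0.01,
                    seed: int = 0) -> list[float]:
    rng = random.Random(seed)
    # Start from uniform noise in [0, 1).
    xs = [rng.random() for _ in range(n_samples)]
    for _ in range(ld_step):
        for i, x in enumerate(xs):
            x += rng.gauss(0.0, ld_noise)          # inject Gaussian noise
            g = energy_grad(x, target)
            if clamp:                              # bound the gradient magnitude
                g = max(-grad_bound, min(grad_bound, g))
            x -= ld_step_size * g                  # move downhill in energy
            xs[i] = max(0.0, min(1.0, x))          # keep values in [0, 1]
    return xs


# With a small step size and no clamping, samples settle near the energy minimum.
samples = langevin_sample(200, target=0.7, ld_step=500, ld_step_size=0.1, clamp=False)
mean = sum(samples) / len(samples)
```

Here the mean ends up very close to `0.7`. With the notebook's large `ld_step_size=30`, an unclamped step would overshoot badly. Clamping the gradient caps each move at `ld_step_size * grad_bound`, which is one reason a large step size can be paired with `clamp=True`.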

#### Generation

To turn the generated node matrices and adjacency tensors into molecules, `run_rand_gen` needs `atomic_num_list`, which maps each dimension (column) of the node matrix to an atomic number. `0` denotes the virtual (padding) atom type.
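A minimal sketch of the decoding step that `atomic_num_list` makes possible is shown below in plain Python. For each row of the node matrix it picks the highest-scoring column, maps it to an atomic number, and drops virtual atoms. For each remaining pair it picks the highest-scoring adjacency channel. The channel layout (single, double, triple, then "no bond" last) and the helper names are assumptions. The sketch also omits any valency checking or correction a full decoder may perform before building an RDKit molecule.

```python
# Hypothetical decoding sketch; channel layout is an assumption, not DIG's code.
ATOMIC_NUM_LIST = [6, 7, 8, 9, 15, 16, 17, 35, 53, 0]  # from this notebook; 0 = virtual
BOND_ORDERS = [1, 2, 3]  # assumed: channels 0-2 = single/double/triple, last = no bond


def argmax(values):
    return max(range(len(values)), key=lambda k: values[k])


def decode(x, adj):
    """x: n_atom rows of atom-type scores; adj: [n_edge][n_atom][n_atom] bond scores."""
    # Keep rows whose best atom type is not virtual, remembering their original index.
    keep, atoms = [], []
    for i, row in enumerate(x):
        z = ATOMIC_NUM_LIST[argmax(row)]
        if z != 0:
            keep.append(i)
            atoms.append(z)
    new_index = {old: new for new, old in enumerate(keep)}
    bonds = []
    for a in keep:
        for b in keep:
            if a < b:
                channel = argmax([adj[c][a][b] for c in range(len(adj))])
                if channel < len(BOND_ORDERS):  # the last channel means "no bond"
                    bonds.append((new_index[a], new_index[b], BOND_ORDERS[channel]))
    return atoms, bonds
```

The returned atom list and `(i, j, bond_order)` tuples are what a caller would then hand to RDKit to assemble a molecule.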

```python
# Suppress error and warning messages printed by RDKit
RDLogger.DisableLog('rdApp.error')
RDLogger.DisableLog('rdApp.warning')

atomic_num_list = dataset.atom_list  # [6, 7, 8, 9, 15, 16, 17, 35, 53, 0] for ZINC250k
# These Langevin dynamics settings match the ones used for training above.
gen_mols = graphebm.run_rand_gen(
    checkpoint_path='./GraphEBM_zinc250k_uncond.pt',  # or a checkpoint saved in ./checkpoints
    n_samples=10000,
    c=0, ld_step=150, ld_noise=0.005, ld_step_size=30, clamp=True,
    atomic_num_list=atomic_num_list,
)
```

```
Loading paramaters from ./GraphEBM_zinc250k_uncond.pt
Initializing samples...
Generating samples...
```


#### Evaluations

```python
train_smiles = [data.smile for data in dataset[splits['train_idx']]]
res_dict = {'mols': gen_mols, 'train_smiles': train_smiles}
evaluator = RandGenEvaluator()
results = evaluator.eval(res_dict)
print(results)
```

```
Valid Ratio: 10000/10000 = 100.00%
Unique Ratio: 9805/10000 = 98.05%
Novel Ratio: 10000/10000 = 100.00%
{'valid_ratio': 100.0, 'unique_ratio': 98.05, 'novel_ratio': 100.0}
```
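The three ratios can be reproduced from SMILES strings with a few lines of plain Python. The sketch below takes the share of valid samples, the share of distinct molecules among the valid ones, and the share of valid molecules (duplicates included) that do not appear in the training set. Those denominators are consistent with the printout above (9805 unique but 10000 novel out of 10000). Whether `RandGenEvaluator` computes them exactly this way is an assumption. The sketch also assumes invalid samples are represented as `None` and that all SMILES are already canonical (for example, produced by RDKit's `Chem.MolToSmiles`). `rand_gen_metrics` is a hypothetical helper.

```python
def rand_gen_metrics(gen_smiles, train_smiles):
    """gen_smiles: canonical SMILES per generated sample, or None if invalid."""

    def pct(num, den):
        return 100.0 * num / den if den else 0.0

    valid = [s for s in gen_smiles if s is not None]
    train = set(train_smiles)
    n_unique = len(set(valid))
    # Novelty counts every valid sample, duplicates included.
    n_novel = sum(1 for s in valid if s not in train)
    return {
        "valid_ratio": pct(len(valid), len(gen_smiles)),
        "unique_ratio": pct(n_unique, len(valid)),
        "novel_ratio": pct(n_novel, len(valid)),
    }


metrics = rand_gen_metrics(["CCO", "CCO", "c1ccccc1", None], ["c1ccccc1"])
```

In this toy call, three of the four samples are valid (75%), two of those three are distinct, and two of the three are absent from the training list.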
