In [27]:
import torch
from torchdrug import datasets, core, models, tasks, utils
# NOTE(review): the next line rebinds `utils` to `torch.utils`, shadowing the
# torchdrug `utils` imported on the line above.
from torch import optim,utils

# Load ZINC250k from ./molecule-datasets/ with kekulize=True and "symbol" node
# features. This call is slow (the original note says it takes about 8 minutes).
dataset = datasets.ZINC250k("./molecule-datasets/", kekulize=True,
                            node_feature="symbol")

Loading ./molecule-datasets/250k_rndm_zinc_drugs_clean_3.csv:  50%|█████     | 249456/498911 [00:01<00:01, 138004.60it/s]
Constructing molecules from SMILES: 100%|██████████| 249455/249455 [03:06<00:00, 1336.08it/s]


In [32]:

import copy

# Build a tiny 3-molecule dataset for quick experiments and a small pickle.
# `subset = dataset` would only alias the same object, so assigning
# `subset.data` would silently truncate `dataset` as well (and re-running the
# cell could never recover the full data). A shallow copy gives `subset` its
# own attribute namespace while still sharing every untouched field.
subset = copy.copy(dataset)
subset.data = dataset.data[:3]


In [30]:
# Display the dataset's `num_bond_type` attribute as the cell output.
dataset.num_bond_type

3

In [21]:
# Print the atom types reported by `subset`.
# NOTE(review): the recorded output has no 15, while the task cells pass a
# hard-coded list that includes 15 -- confirm which list is intended.
print(subset.atom_types)

[6, 7, 8, 9, 16, 17, 35, 53]


In [35]:
import torch
from torchdrug import datasets, core, models, tasks, utils
from torch import optim,utils

import pickle
# Persist the small dataset to disk ...
with open("zinc250k.pkl", "wb") as pkl_out:
    pickle.dump(subset, pkl_out)

# ... and read it straight back, rebinding `dataset` to the unpickled object.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open("zinc250k.pkl", "rb") as pkl_in:
    dataset = pickle.load(pkl_in)

In [53]:
# Report the dataset's node feature dimension and number of bond types.
# NOTE(review): the recorded output shows dim 67 while the RGCN cells hard-code
# input_dim=18 -- confirm the dataset features match the model input.
print(f"dim:{dataset.node_feature_dim}, num_relation={dataset.num_bond_type}, ")

dim:67, num_relation=3, 


In [31]:
# Clear any transform attached to the dataset.
dataset.transform = None

# (1) Graph encoder: RGCN with four hidden layers of width 256.
model = models.RGCN(
    input_dim=18,
    num_relation=3,
    hidden_dims=[256] * 4,
    batch_norm=False,
)

# (2) GCPN generation task over a fixed list of atom types, using the
#     "nll" criterion.
task = tasks.GCPNGeneration(
    model,
    [6, 7, 8, 9, 15, 16, 17, 35, 53],
    max_edge_unroll=12,
    max_node=38,
    criterion="nll",
)


In [33]:
# Adam over every parameter of the task.
optimizer = optim.Adam(task.parameters(), lr=1e-3)

# Engine driving the task on GPU 0 with `subset` as its dataset; the two
# `None` arguments presumably stand for the validation and test sets.
solver = core.Engine(
    task,
    subset,
    None,
    None,
    optimizer,
    gpus=(0,),
    batch_size=128,
    log_interval=10,
)

23:12:33   Preprocess training set
23:12:33   {'batch_size': 128,
 'class': 'core.Engine',
 'gpus': (0,),
 'gradient_interval': 1,
 'log_interval': 10,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'capturable': False,
               'class': 'optim.Adam',
               'differentiable': False,
               'eps': 1e-08,
               'foreach': None,
               'fused': False,
               'lr': 0.001,
               'maximize': False,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'agent_update_interval': 10,
          'atom_types': [6, 7, 8, 9, 15, 16, 17, 35, 53],
          'baseline_momentum': 0.9,
          'class': 'tasks.GCPNGeneration',
          'criterion': 'nll',
          'gamma': 0.9,
          'hidden_dim_mlp': 128,
          'max_edge_unroll': 12,
          'max_node': 38,
          'model': {'activation': 'relu',
                    'batch_norm': False,
         

In [34]:
# Load the pretrained checkpoint into the solver, ask the task for 5 molecules
# (max_resample=5) and print them as SMILES strings.
solver.load("gcpn_zinc250k_5epoch.pkl")
results = task.generate(num_sample=5, max_resample=5)
print(results.to_smiles())

23:12:35   Load checkpoint from gcpn_zinc250k_5epoch.pkl




['CC(=NC(N)=C(C)C(=O)Cl)C(N)=O', 'CN1N=C(C(=O)O)C=C1N1CCC(O)C1', 'ClC1=CC=C(OCCN2CCNCC2)C=C1', 'CCN=C1SC=C(C)N1C(=O)C1CCNC1', 'CNCC1CN(C2=CC=C(C)C=C2)C(=O)NC2=CC=CC=C2OCCCO1']


In [42]:
# Silence Python warnings and disable logging up to CRITICAL, presumably so the
# generation helper below only shows its return value.
# NOTE(review): a later cell re-enables logging with
# logging.disable(logging.NOTSET), but the warnings filter is never reset in
# the visible notebook.
import warnings
warnings.filterwarnings("ignore")
import logging
logging.disable(logging.CRITICAL)



def simple_molecule_generation(num_sample=1, max_resample=10,
                               dataset_path="zinc250k.pkl",
                               checkpoint_path="gcpn_zinc250k_5epoch.pkl",
                               gpus=(0,)):
    """Generate molecules with the pretrained GCPN model and return SMILES.

    Every call rebuilds the RGCN encoder and the GCPN task, wraps them in a
    ``core.Engine``, restores the checkpoint weights and samples molecules.
    Calling it with no arguments reproduces the original hard-coded setup.

    Args:
        num_sample: forwarded to ``task.generate`` as ``num_sample``.
        max_resample: forwarded to ``task.generate`` as ``max_resample``.
        dataset_path: pickle file holding the dataset handed to the engine.
        checkpoint_path: checkpoint file passed to ``solver.load``.
        gpus: forwarded to ``core.Engine`` as ``gpus`` (per the TorchDrug
            docs ``None`` should select the CPU -- confirm for your version).

    Returns:
        Whatever ``to_smiles()`` returns for the generated molecules
        (a list of SMILES strings in the recorded output).
    """
    # NOTE: pickle.load can execute arbitrary code; only point dataset_path
    # at files produced locally or obtained from a trusted source.
    with open(dataset_path, "rb") as fin:
        dataset = pickle.load(fin)
    dataset.transform = None

    # Architecture must match the one the checkpoint was trained with.
    model = models.RGCN(input_dim=18,
                        num_relation=3,
                        hidden_dims=[256, 256, 256, 256], batch_norm=False)
    task = tasks.GCPNGeneration(model, [6, 7, 8, 9, 15, 16, 17, 35, 53], max_edge_unroll=12,
                                max_node=38, criterion="nll")

    # The engine is only used here to restore the checkpoint into `task`.
    optimizer = optim.Adam(task.parameters(), lr=1e-3)
    solver = core.Engine(task, dataset, None, None, optimizer,
                         gpus=gpus, batch_size=128, log_interval=10)
    solver.load(checkpoint_path)

    results = task.generate(num_sample=num_sample, max_resample=max_resample)
    return results.to_smiles()

In [44]:
# Run the helper with its defaults; the returned SMILES list is the cell output.
simple_molecule_generation()

['O=S(=O)(NC1=CC=CN=C1Cl)C1=CC=CC=C1Cl']

In [57]:
# Re-enable logging that was switched off earlier with
# logging.disable(logging.CRITICAL).
import logging
logging.disable(logging.NOTSET)

In [None]:
# Reload the full ZINC250k dataset for RL fine-tuning.
# node_feature="symbol" keeps this load consistent with the first load in this
# notebook, which feeds the same RGCN(input_dim=18) model.
# NOTE(review): a recorded cell prints node_feature_dim 67 for a dataset that
# presumably came from a load without node_feature -- confirm.
dataset = datasets.ZINC250k("./molecule-datasets/", kekulize=True,
                            node_feature="symbol")

In [60]:
# Define the RL fine-tuning task and start from the pretrained weights.
dataset.transform = None

model = models.RGCN(
    input_dim=18,
    num_relation=3,
    hidden_dims=[256] * 4,
    batch_norm=False,
)

# Same atom types and size limits as the "nll" cells, but configured with the
# "ppo" criterion on the "plogp" task.
task = tasks.GCPNGeneration(
    model,
    [6, 7, 8, 9, 15, 16, 17, 35, 53],
    max_edge_unroll=12,
    max_node=38,
    task="plogp",
    criterion="ppo",
    reward_temperature=1,
    agent_update_interval=3,
    gamma=0.9,
)

optimizer = optim.Adam(task.parameters(), lr=1e-5)
solver = core.Engine(
    task,
    dataset,
    None,
    None,
    optimizer,
    gpus=(0,),
    batch_size=512,
    log_interval=10,
)

# load_optimizer=False: presumably only the model weights are restored and the
# fresh Adam state above is kept.
solver.load('gcpn_zinc250k_5epoch.pkl', load_optimizer=False)

23:42:06   Preprocess training set
23:42:06   {'batch_size': 512,
 'class': 'core.Engine',
 'gpus': (0,),
 'gradient_interval': 1,
 'log_interval': 10,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'capturable': False,
               'class': 'optim.Adam',
               'differentiable': False,
               'eps': 1e-08,
               'foreach': None,
               'fused': False,
               'lr': 1e-05,
               'maximize': False,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'agent_update_interval': 3,
          'atom_types': [6, 7, 8, 9, 15, 16, 17, 35, 53],
          'baseline_momentum': 0.9,
          'class': 'tasks.GCPNGeneration',
          'criterion': 'ppo',
          'gamma': 0.9,
          'hidden_dim_mlp': 128,
          'max_edge_unroll': 12,
          'max_node': 38,
          'model': {'activation': 'relu',
                    'batch_norm': False,
          

In [61]:
# RL finetuning: run a single training epoch on the solver built for the
# "ppo" task. The recorded log spans roughly an hour for this one epoch.
solver.train(num_epoch=1)

23:42:08   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:42:08   Epoch 0 begin
23:42:22   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:42:22   PPO objective: 1.70762
23:42:22   Penalized logP: -3.193
23:42:22   Penalized logP (max): 2.56596
23:42:23   2 / 512 molecules are invalid even after 20 resampling
23:42:46   1 / 512 molecules are invalid even after 20 resampling
23:43:22   1 / 512 molecules are invalid even after 20 resampling
23:43:36   1 / 512 molecules are invalid even after 20 resampling
23:43:47   1 / 512 molecules are invalid even after 20 resampling
23:43:58   1 / 512 molecules are invalid even after 20 resampling
23:44:13   1 / 327 molecules are invalid even after 20 resampling
23:44:20   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
23:44:20   PPO objective: 1.2842
23:44:20   Penalized logP: -2.82945
23:44:20   Penalized logP (max): 2.64114
23:44:37   1 / 132 molecules are invalid even after 20 resampling
23:44:54   1 / 512 molecules are invalid even after 20 resampling
23:45:02   1 / 48 molecules are inva

In [62]:
# Write the fine-tuned state to a new checkpoint file so the pretrained
# "gcpn_zinc250k_5epoch.pkl" checkpoint is left untouched.
solver.save("gcpn_zinc250k_rl_1epoch.pkl")

00:39:59   Save checkpoint to gcpn_zinc250k_rl_1epoch.pkl


In [84]:
# Sample one molecule from the fine-tuned task, move the result to the CPU
# with .cpu(), and draw it.
results = task.generate(num_sample=1, max_resample=5).cpu()
results.visualize()