In [6]:
import torch
from torchdrug import datasets

# Root directory for the molecule datasets. The "~" is expanded by the loader
# (the recorded output shows it resolving to an absolute home path).
DATA_DIR = "~/molecule-datasets/"

# ZINC250k with kekulized bonds and symbol-only atom features.
dataset = datasets.ZINC250k(
    DATA_DIR,
    kekulize=True,
    atom_feature="symbol",
)

Loading /Users/ca/molecule-datasets/250k_rndm_zinc_drugs_clean_3.csv:  50%|█████     | 249456/498911 [00:03<00:03, 71432.88it/s]
Constructing molecules from SMILES: 100%|██████████| 249455/249455 [06:07<00:00, 678.36it/s]


In [7]:
from pna import PNA
from torch_geometric.utils import degree

In [11]:
def degree_histogram(graphs_dataset, min_bins=10):
    """Histogram of node degrees over every graph in ``graphs_dataset``.

    ``hist[k]`` is the total number of nodes (across all samples) that appear
    exactly ``k`` times in the second column of their graph's ``edge_list``.
    Isolated nodes are counted in bin 0 because ``num_nodes`` is passed.

    Args:
        graphs_dataset: iterable of samples, each a mapping with a ``'graph'``
            entry exposing ``edge_list`` and ``num_node``.
        min_bins: minimum length of the returned histogram (default 10, the
            size the original fixed-length histogram used).

    Returns:
        1-D ``torch.long`` tensor with at least ``min_bins`` entries. It grows
        when a larger degree is seen, instead of failing on a shape mismatch.
    """
    hist = torch.zeros(min_bins, dtype=torch.long)
    for sample in graphs_dataset:
        graph = sample['graph']
        node_deg = degree(graph.edge_list[:, 1], num_nodes=graph.num_node,
                          dtype=torch.long)
        # bincount returns max(minlength, max_degree + 1) entries, so `counts`
        # may be longer than `hist`.
        counts = torch.bincount(node_deg, minlength=hist.numel())
        if counts.numel() > hist.numel():
            # Zero-pad the running histogram so the addition below lines up.
            hist = torch.cat([hist, hist.new_zeros(counts.numel() - hist.numel())])
        hist += counts
    return hist


# NOTE(review): the execution counts show this cell (In [11]) ran after the
# prepare_GSN_dataset cell (In [9]). Confirm the histogram is the same on a
# fresh Restart & Run All.
deg = degree_histogram(dataset)

# PNA encoder built from the degree histogram `deg` computed in this cell.
# The model is bound to a name so later cells can reuse or inspect it; the
# bare last-line expression still renders the same repr as before.
pna_model = PNA(
    input_dim=dataset.node_feature_dim,
    hidden_dim=256,
    num_layer=3,
    edge_input_dim=dataset.edge_feature_dim,
    num_relation=dataset.num_bond_type,
    aggregators=['mean'],
    scalers=['identity'],
    deg=deg,
    batch_norm=False,
)
pna_model

PNA(
  (node_encoder): Linear(in_features=18, out_features=256, bias=True)
  (layers): ModuleList(
    (0): PNALayer(256, 256, towers=1)
    (1): PNALayer(256, 256, towers=1)
    (2): PNALayer(256, 256, towers=1)
  )
  (readout): SumReadout()
)

In [9]:
from gsn import GSN
from gsn import prepare_GSN_dataset
# Upper bound passed to the GSN preprocessing. Presumably this is the longest
# cycle length counted as a substructure; TODO confirm in gsn.prepare_GSN_dataset.
MAX_CYCLE = 8
# The return value is unused, so this presumably annotates `dataset` in place.
prepare_GSN_dataset(dataset, max_cycle=MAX_CYCLE)

In [12]:
from torchdrug import core, models, tasks

# GSN encoder sized from the dataset's node/edge feature dimensions.
model = GSN(
    input_dim=dataset.node_feature_dim,
    hidden_dim=256,
    num_layer=3,
    edge_input_dim=dataset.edge_feature_dim,
    num_relation=dataset.num_bond_type,
    batch_norm=False,
)

# GCPN generation task over the dataset's atom vocabulary, trained with an NLL
# criterion. NOTE(review): max_node=38 presumably matches the largest ZINC250k
# molecule (heavy atoms); confirm before reusing with another dataset.
task = tasks.GCPNGeneration(
    model,
    dataset.atom_types,
    max_edge_unroll=12,
    max_node=38,
    criterion="nll",
)

In [13]:
from torch import nn, optim
# Adam over all task parameters (the GSN encoder plus the GCPN heads).
optimizer = optim.Adam(task.parameters(), lr=1e-3)

# Train on the full dataset with no validation/test split; gpus is left at its
# default (None), as in the original call.
solver = core.Engine(
    task,
    train_set=dataset,
    valid_set=None,
    test_set=None,
    optimizer=optimizer,
    batch_size=128,
    log_interval=10,
)

solver.train(num_epoch=1)

14:35:11   Preprocess training set
14:35:12   {'batch_size': 128,
 'class': 'core.Engine',
 'gpus': None,
 'gradient_interval': 1,
 'log_interval': 10,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'class': 'optim.Adam',
               'eps': 1e-08,
               'lr': 0.001,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'agent_update_interval': 10,
          'atom_types': [6, 7, 8, 9, 15, 16, 17, 35, 53],
          'baseline_momentum': 0.9,
          'class': 'tasks.GCPNGeneration',
          'criterion': 'nll',
          'gamma': 0.9,
          'hidden_dim_mlp': 128,
          'max_edge_unroll': 12,
          'max_node': 38,
          'model': {'activation': 'relu',
                    'batch_norm': False,
                    'class': 'GSN',
                    'concat_hidden': False,
                    'edge_input_dim': 18,
                    'hidden_dim': 256,
                    '

KeyboardInterrupt: 