## Generative Model: GCPN
___

Graph generation with GCPN on the ZINC250k dataset.

This script demonstrates how to use the Graph Convolutional Policy Network (GCPN)
to generate molecular graphs from the ZINC250k dataset. It builds an RGCN model,
wraps it in a graph generation task, trains the task for one epoch, and saves the checkpoint.
It then reloads the checkpoint and generates molecule samples, which are printed as SMILES strings.

Dependencies:
- torch
- torchdrug

Date: Sep.23.2023
Place: UC Merced

In [None]:
import torch
from torch import optim
from torchdrug import datasets
from torchdrug import core, models, tasks

In [None]:
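# Load ZINC250k from ~/molecule-datasets/ (downloaded on first use).
# kekulize=True represents aromatic rings with explicit single and double bonds.
dataset = datasets.ZINC250k("~/molecule-datasets/", kekulize=True,
                            atom_feature="symbol")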

In [None]:
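# Relational GCN backbone: one relation per bond type, four hidden layers of width 256
model = models.RGCN(
    input_dim=dataset.node_feature_dim,
    num_relation=dataset.num_bond_type,
    hidden_dims=[256, 256, 256, 256],
    batch_norm=False,
)

# Generation task; criterion="nll" trains with the negative log-likelihood objective
task = tasks.GCPNGeneration(
    model, dataset.atom_types, max_edge_unroll=12, max_node=38, criterion="nll"
)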


In [None]:
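import os

# Adam over all trainable parameters of the task (which wraps the RGCN model)
optimizer = optim.Adam(task.parameters(), lr=1e-3)
# Train on the full dataset; the two None arguments mean no validation or test set
solver = core.Engine(
    task, dataset, None, None, optimizer, gpus=(0,), batch_size=128, log_interval=10
)

solver.train(num_epoch=1)
# Create the checkpoint directory first so that saving does not fail
os.makedirs("path_to_dump/graphgeneration", exist_ok=True)
solver.save("path_to_dump/graphgeneration/gcpn_zinc250k_1epoch.pkl")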


In [None]:
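# Reload the checkpoint saved after training, then sample 32 molecules
solver.load("path_to_dump/graphgeneration/gcpn_zinc250k_1epoch.pkl")
results = task.generate(num_sample=32, max_resample=5)
# Print the generated molecules as SMILES strings
print(results.to_smiles())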
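
After one epoch the samples are often repetitive, so a quick count of distinct strings can be more informative than reading the raw printout. The sketch below is not part of the original script: `summarize_smiles` is a hypothetical helper, the `example` list is made-up data standing in for `results.to_smiles()`, and the sketch assumes that call returns a list of strings. It compares SMILES as plain text, so two different spellings of the same molecule count as distinct.

```python
from collections import Counter


def summarize_smiles(smiles_list):
    """Return simple counts for a list of generated SMILES strings."""
    # Assumption for this sketch: an empty string marks a failed sample.
    non_empty = [s for s in smiles_list if s]
    counts = Counter(non_empty)
    total = len(smiles_list)
    return {
        "total": total,
        "non_empty": len(non_empty),
        "unique": len(counts),
        # Share of all samples that are distinct, non-empty strings
        "unique_ratio": len(counts) / total if total else 0.0,
        # Up to three most frequent strings with their counts
        "most_common": counts.most_common(3),
    }


# Made-up strings standing in for results.to_smiles()
example = ["CCO", "CCO", "C1=CC=CC=C1", "", "CC(=O)O"]
summary = summarize_smiles(example)
print(summary)
```

In the script itself, the same helper could be called as `summarize_smiles(results.to_smiles())`. For a chemistry-aware comparison, the strings would first need to be canonicalized with a cheminformatics toolkit, which this sketch does not do.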
