Enhanced Template-Free Reaction Prediction with Molecular Graphs and Sequence-based Data Augmentation

This repository contains the code for "Enhanced Template-Free Reaction Prediction with Molecular Graphs and Sequence-based Data Augmentation", published at CIKM'23.

We have now re-uploaded the missing modules, along with our CIKM'23 presentation slides (PowerPoint), which include additional figures describing the algorithm.

1. Overview

The directory layout is as follows:

├─index_elemtwise---------------#the Python package for the custom kernel
├─index_elemtwise_cuda----------#the C++ and CUDA source code for the custom index_select + elementwise operation
│  ├─csrc
│  │  └─cuda
│  └─index_elemtwise
├─model-------------------------#the py source for our model
│  ├─ckpt
│  │  └─Model0
│  │      ├─50k
│  │      ├─50k_class
│  │      ├─full
│  │      └─mit
│  ├─inference------------------#the py source for beam search (some of them are from huggingface)
│  └─preprocess-----------------#the data preprocess source for our model
│      └─data
│          ├─uspto_50k
│          │  ├─process
│          │  └─raw
│          ├─uspto_full
│          │  ├─process
│          │  └─raw
│          └─uspto_MIT
│              ├─process
│              └─raw
└─scripts-----------------------#the .sh file for quick start

2. Environment setup

The code was run and evaluated with:

- python 3.10.9
- pytorch 2.0.0 (for SDPA kernel)
- torch-scatter 2.1.1+pt20cu117
- rdkit 2022.03.2

Models were trained on an RTX A5000 with 24GB of memory, which allows a larger batch size (e.g. 64*2). Training also works with less GPU memory if you choose a smaller batch size and more gradient accumulation steps (e.g. 32*2 with 4 accumulation steps for 6GB).
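As a small illustration of why gradient accumulation can stand in for a larger batch, the sketch below uses a toy scalar linear model with a mean-squared-error loss. It is not the training code from this repository, and the names mse_grad and accumulated_grad are hypothetical. It assumes equal-sized micro-batches and a loss averaged over the batch, in which case the averaged micro-batch gradients match the full-batch gradient.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar weight w."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, micro_batch_size):
    """Average the per-micro-batch gradients, as gradient accumulation does."""
    if len(xs) % micro_batch_size != 0:
        raise ValueError("this sketch assumes equal-sized micro-batches")
    steps = len(xs) // micro_batch_size
    total = 0.0
    for i in range(steps):
        lo, hi = i * micro_batch_size, (i + 1) * micro_batch_size
        # Dividing by `steps` keeps the gradient on the scale of one large batch.
        total += mse_grad(w, xs[lo:hi], ys[lo:hi]) / steps
    return total


# Toy data: 8 samples processed either at once or as 4 micro-batches of 2.
xs = [float(i) for i in range(8)]
ys = [2.0 * x + 1.0 for x in xs]
full = mse_grad(0.5, xs, ys)
accum = accumulated_grad(0.5, xs, ys, micro_batch_size=2)
```

Here full and accum agree up to floating-point rounding, so the optimizer step is the same and only the peak memory per forward/backward pass changes.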

Note that a different version of rdkit may produce different SMILES canonicalization results.

3. Custom indexing kernel

The custom CUDA kernel for the operation src1.index_select(0, idx1) ~ src2.index_select(0, idx2) is now available. This is one of the operations in our padding-free global attention for molecular graphs. It is faster than the naive pytorch operation and also supports AMP in pytorch. It should work well in most situations and will be further optimized later. You can install it by running:

python index_elemtwise_cuda/ install
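To make the semantics of the fused operation concrete, here is a plain-Python reference sketch that runs without a GPU. It is not the kernel's actual interface: the function name index_elementwise is hypothetical, and since the "~" above does not name a specific elementwise operation, multiplication is assumed as the default.

```python
import operator


def index_elementwise(src1, idx1, src2, idx2, op=operator.mul):
    """Reference semantics: out[k] = op(src1[idx1[k]], src2[idx2[k]]), row by row.

    src1 and src2 are lists of equal-width rows; idx1 and idx2 select rows
    along dimension 0, mirroring index_select(0, idx).
    """
    if len(idx1) != len(idx2):
        raise ValueError("idx1 and idx2 must have the same length")
    out = []
    for i, j in zip(idx1, idx2):
        # Gather one row from each source, then combine them elementwise.
        out.append([op(a, b) for a, b in zip(src1[i], src2[j])])
    return out


src1 = [[1.0, 2.0], [3.0, 4.0]]
src2 = [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]]
result = index_elementwise(src1, [0, 1, 1], src2, [2, 0, 1])
# result == [[50.0, 120.0], [30.0, 80.0], [90.0, 160.0]]
```

The naive approach materializes both gathered tensors before combining them. A fused kernel can presumably skip those intermediates, which would be one source of the speedup mentioned above.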

4. Data preprocessing

We mainly use the USPTO-50k, USPTO-full and USPTO-MIT datasets for training and evaluation. You can download them manually at the following address:


Note that USPTO-50k is already included in the source files.

After downloading them, please put them into model/preprocess/data/${dataset_name}/raw, and then run the following scripts for preprocessing:

scripts/ -> uspto_50k
scripts/ -> uspto_MIT
scripts/ -> uspto_full

5. Training

You can train the model by:

scripts/ -> uspto_50k
scripts/ -> uspto_50k with reaction class
scripts/ -> uspto_MIT
scripts/ -> uspto_full

The log file, checkpoint queue list, tensorboard file, and checkpoints are saved to model/ckpt/Model0/${training_start_time}. Note that the model is evaluated automatically during training, and averaged checkpoints are generated from the checkpoint queue. For large datasets such as uspto_full, a random subset is used for evaluation during training.
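The idea of averaging checkpoints from a queue can be sketched as follows. This is an illustrative assumption and not the repository's implementation: checkpoints are modeled as plain dictionaries of float lists, the queue length of 3 is arbitrary, and average_checkpoints is a hypothetical helper.

```python
from collections import deque


def average_checkpoints(checkpoints):
    """Parameter-wise mean over a non-empty iterable of {name: [floats]} dicts."""
    checkpoints = list(checkpoints)
    if not checkpoints:
        raise ValueError("need at least one checkpoint")
    averaged = {}
    for name in checkpoints[0]:
        # Walk the same parameter across all checkpoints, position by position.
        columns = zip(*(ckpt[name] for ckpt in checkpoints))
        averaged[name] = [sum(col) / len(checkpoints) for col in columns]
    return averaged


# A bounded queue keeps only the most recent checkpoints (length is hypothetical).
queue = deque(maxlen=3)
for step in range(5):
    queue.append({"w": [float(step), 2.0 * step], "b": [1.0]})

avg = average_checkpoints(queue)
# The queue holds steps 2, 3 and 4, so avg == {"w": [3.0, 6.0], "b": [1.0]}
```

With real model weights the same parameter-wise mean would be applied to each tensor in the state dictionaries.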

6. Evaluation

The checkpoints are available at the following address, and include the .pt files, log files, and the evaluation results reported in the manuscript:


After downloading them, please unzip them and put them into model/ckpt/Model0/${dataset_name}, and then run the following scripts for evaluation and testing:

scripts/ -> uspto_50k
scripts/ -> uspto_50k with reaction class
scripts/ -> uspto_MIT
scripts/ -> uspto_full

You can find the results in model/ckpt/Model0/${dataset_name}, including the top-10 accuracy and top-10 invalid rate. You can also try different search hyperparameters such as temperature (T), top-k sampling, top-p sampling, and group beam search.
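For readers unfamiliar with these metrics, the sketch below shows one common way to compute top-k accuracy and top-k invalid rate from ranked predictions. It rests on assumptions rather than on the repository's evaluation code: the predictions are taken to be already canonicalized SMILES strings, the invalid rate is the fraction of invalid strings among the top-k candidates, and top_k_metrics is a hypothetical helper. The validity check is supplied by the caller. In practice it could test whether RDKit's Chem.MolFromSmiles returns a molecule, while the toy check here only counts parentheses.

```python
def top_k_metrics(predictions, targets, is_valid, k=10):
    """Return (top-k accuracy, top-k invalid rate).

    predictions: one ranked list of candidate strings per sample (best first).
    targets: the ground-truth string for each sample.
    is_valid: caller-supplied predicate deciding whether a candidate is valid.
    """
    if not targets or len(predictions) != len(targets):
        raise ValueError("predictions and targets must be non-empty and aligned")
    hits = 0
    invalid = 0
    total = 0
    for preds, target in zip(predictions, targets):
        top = preds[:k]
        hits += target in top  # exact string match within the top k
        invalid += sum(not is_valid(p) for p in top)
        total += len(top)
    invalid_rate = invalid / total if total else 0.0
    return hits / len(targets), invalid_rate


# Toy validity check (an assumption for illustration): balanced parentheses.
def balanced(smiles):
    return smiles.count("(") == smiles.count(")")


predictions = [["CCO", "CC(O", "CO"], ["CC", "C(C"]]
targets = ["CO", "CCC"]
acc3, inv3 = top_k_metrics(predictions, targets, balanced, k=3)
```

On this toy input the first sample is recovered at rank 3 and the second is missed, so acc3 is 0.5. Two of the five candidates are unbalanced, so inv3 is 0.4.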

