PyTorch implementation for the paper "FRAIL: Fragment-Based Reinforcement Learning for Molecular Design and Benchmarking on Fatty Acid Amide Hydrolase 1 (FAAH-1)".

AppliedAI-Lab/FRAIL

FRAIL: Fragment-Based Reinforcement Learning for Molecular Design and Benchmarking on FAAH-1

This is the source code for the paper "FRAIL: Fragment-Based Reinforcement Learning for Molecular Design and Benchmarking on Fatty Acid Amide Hydrolase 1 (FAAH-1)".

Overview

The system combines:

  • A generator based on DrugEx, using a fragment-based representation (graph transformer) and trained with multi-objective reinforcement learning.
  • A predictor (QSAR) for properties such as SA, MW, TPSA, HBD, HBA, LogP, and pIC50 on FAAH-1, using RDKit, PyTDC, and a trained Chemprop model.
  • A complete inference pipeline for generating and filtering candidate molecules for FAAH-1.

Environment Setup

It is recommended to use conda to install RDKit and create a separate environment.

  • Step 1: Create environment
conda create -n frail python=3.10 -y
conda activate frail
  • Step 2: Install RDKit (required by PyTDC and Chemprop)
conda install -c conda-forge rdkit -y
  • Step 3: Install remaining Python libraries
pip install -r requirements.txt
  • Step 4: Check path configuration

Adjust the following configuration files to match your system (especially absolute paths to data and checkpoints):

  • src/configs/generators/DrugexConfigs.py
    • contains DATASETS_PATH, DATASETS_ENCODED_PATH, DATA_FILE, TARGET_NAME, MODEL_PATH, VOCAB_PATH, ...
  • src/configs/predictors/ChempropConfigs.py
    • contains PIC50_PATH, ROUND_DIGITS, FILTER_THRES
  • (optional) src/settings.py if you want to synchronize paths.
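
Because the shipped values are absolute paths from the experimental environment, it can help to check them before starting a run. The snippet below is a minimal sketch and not part of the repository: find_missing_paths is a hypothetical helper, the dictionary keys only mirror the configuration names listed above, and the values are placeholders.

```python
from pathlib import Path


def find_missing_paths(paths: dict[str, str]) -> list[str]:
    """Return the names of configured paths that do not exist on disk."""
    return [name for name, value in paths.items() if not Path(value).exists()]


# Placeholder values (assumptions); pass the values from your own config files instead.
configured = {
    "DATASETS_PATH": "data",
    "MODEL_PATH": "data/models/placeholder_checkpoint",
}
for name in find_missing_paths(configured):
    print(f"{name} does not point to an existing file or directory")
```

Any name that is printed should be corrected in the corresponding configuration file before training or inference.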

Main Structure

Key files:

  • src/run.py: complete inference pipeline (generate + filter) for FAAH-1.
  • src/engine/generator.py: wrapper around DrugEx for fine-tuning and reinforcement learning on fragment graphs.
  • src/engine/predictor.py: property predictor using RDKit + PyTDC + Chemprop pIC50, supports molecule filtering.

Example data directories:

  • data/: contains the input dataset (SMILES, pIC50 labels, etc.)

Inference / Generate Molecules with run.py

By default, src/run.py defines the Pipeline class:

  • Pipeline.generator: instance of Generator (DrugEx) from src/engine/generator.py
  • Pipeline.predictor: instance of Predictor from src/engine/predictor.py
  • Method Pipeline.generate(...):
    • generates molecules using DrugEx from input fragments
    • predicts properties
    • filters the molecules and writes the results to a CSV file.
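
The predict, filter, and write steps above can be sketched with stand-in callables. This is an illustration of the control flow only, not the code in src/run.py: run_pipeline, predict_fn, and is_valid_fn are hypothetical names, and the generation step is represented by an already produced list of candidate SMILES.

```python
import csv
from typing import Callable, Iterable


def run_pipeline(
    candidates: Iterable[str],
    predict_fn: Callable[[str], dict],
    is_valid_fn: Callable[[str, dict], bool],
    output_file: str,
) -> int:
    """Predict properties for each candidate, keep valid ones, and write them to CSV.

    Returns the number of molecules written.
    """
    rows = []
    for smiles in candidates:
        props = predict_fn(smiles)
        if is_valid_fn(smiles, props):
            rows.append({"SMILES": smiles, **props})

    # Use the first kept row to define the columns; fall back to SMILES only.
    fieldnames = list(rows[0]) if rows else ["SMILES"]
    with open(output_file, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    return len(rows)
```

In the real pipeline, the roles of predict_fn and is_valid_fn would be played by the Predictor methods described later in this README.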

Run default inference

From the root directory:

conda activate frail
python src/run.py

Output:

  • A CSV file containing generated molecules and properties, for example:
    • SMILES, SA, MW, TPSA, HBD, HBA, LogP, pIC50
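
A results file in this shape can be inspected with the standard library alone. The sketch below assumes the CSV header uses the column names SMILES and pIC50 exactly as listed above; top_by_pic50 is a hypothetical helper, not a function provided by the repository.

```python
import csv
from typing import TextIO


def top_by_pic50(handle: TextIO, n: int = 10) -> list[dict]:
    """Return the n rows with the highest pIC50 from an open results CSV."""
    rows = list(csv.DictReader(handle))
    # CSV values are strings, so convert to float before sorting.
    rows.sort(key=lambda row: float(row["pIC50"]), reverse=True)
    return rows[:n]
```

Call it with an open handle to the CSV written by the pipeline, for example inside a with open(path, newline="") block, and print the SMILES and pIC50 fields of the returned rows.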

Customize inference

  • Edit directly in src/run.py:
    • NUM_SAMPLES: desired number of molecules.
    • INPUT_FRAGMENTS: seed fragments (SMILES); change these if you want to try targets other than FAAH-1.
    • output_file: path of the file where results are saved (e.g., in gen_data/).

Or use Pipeline in your own script:

from src.run import Pipeline

pipeline = Pipeline()
smiles_out_path = "gen_data/my_gen_results.csv"

pipeline.generate(
    input_fragments=["ClC1=CC=C2CCNCC2=C1Cl"],
    num_samples=1000,
    output_file=smiles_out_path,
    upscale=5,  # optional, increase raw samples generated for filtering
)

Training Generator (DrugEx) with src/engine/generator.py

The file src/engine/generator.py defines the Generator class with main methods:

  • finetune(...): fine-tunes a pretrained DrugEx model on your FAAH-1 dataset.
  • train_rl(...): runs multi-objective RL training with DrugExEnvironment + custom scorers.
  • generate(...): generates molecules from a trained model.

General steps:

  1. Prepare data:

    • Configure paths in src/configs/generators/DrugexConfigs.py:
      • DATASETS_PATH: directory containing raw data files.
      • DATA_FILE: name of CSV file containing SMILES column (and pIC50 label if needed).
      • TARGET_NAME: target name (e.g., "FAAH"), used to name encoded train/test files.
    • Generator will:
      • read data,
      • standardize SMILES,
      • fragment and encode,
      • create train_set and test_set in DATASETS_ENCODED_PATH.
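
To make the train/test step concrete, here is a minimal sketch of a reproducible split over SMILES strings. It is not the repository's implementation: SMILES standardization, fragmentation, and encoding are handled by Generator and DrugEx and are not reproduced here, and split_smiles, the deduplication step, and the default test fraction are all assumptions.

```python
import random


def split_smiles(smiles: list[str], test_fraction: float = 0.1, seed: int = 0):
    """Deduplicate SMILES strings and split them into (train, test) lists."""
    unique = sorted(set(smiles))          # sorted for a reproducible starting order
    random.Random(seed).shuffle(unique)   # seeded shuffle so the split is repeatable
    # Keep at least one test molecule whenever there is any data at all.
    n_test = max(1, int(len(unique) * test_fraction)) if unique else 0
    return unique[n_test:], unique[:n_test]
```

Fixing the seed keeps the split identical across runs, which makes fine-tuning results easier to compare.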
  2. Fine-tune DrugEx model:

Simple example script:

from engine.generator import Generator
from configs.generators.DrugexConfigs import MODEL_PATH, VOCAB_PATH

generator = Generator()

finetuned_model, vocab = generator.finetune(
    model_path=MODEL_PATH,       # pretrained DrugEx checkpoint
    vocab_path=VOCAB_PATH,       # corresponding vocabulary
    epochs=10,
    batch_size=64,
    save_path="data/models/finetune/FAAH_FT_ep10"
)
  3. Train with reinforcement learning:

Once you have a good model (pretrained or fine-tuned), you can run RL:

from engine.generator import Generator
from configs.generators.DrugexConfigs import MODEL_PATH, VOCAB_PATH

generator = Generator()

rl_explorer, env = generator.train_rl(
    agent_model_path=MODEL_PATH,        # or finetuned checkpoint
    agent_vocab_path=VOCAB_PATH,
    mutate_model_path=None,             # default uses agent as mutate
    mutate_vocab_path=None,
    epochs=20,
    batch_size=64,
    epsilon=0.2,
    save_path="data/models/rl/FAAH_RL_ep20"
)

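You can put the above code into a separate script (e.g., train_generator.py). The imports above assume that src/ is on the Python path, for example by placing the script inside src/. Then run: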

python train_generator.py
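
Multi-objective RL needs a way to turn several predicted properties into one reward. The sketch below shows one common approach, a clipped linear desirability per property combined by a geometric mean. It is not the scoring code in src/engine/custom_scorers.py and not the DrugEx scoring API: desirability, combined_reward, and the target ranges are hypothetical and chosen only for illustration.

```python
import math


def desirability(value: float, low: float, high: float) -> float:
    """Map a property value linearly onto [0, 1], clipping outside [low, high]."""
    if high <= low:
        raise ValueError("high must be greater than low")
    return min(1.0, max(0.0, (value - low) / (high - low)))


def combined_reward(props: dict[str, float], targets: dict[str, tuple[float, float]]) -> float:
    """Geometric mean of per-property desirabilities; 0 if any objective scores 0."""
    scores = [desirability(props[name], low, high) for name, (low, high) in targets.items()]
    if not scores or min(scores) == 0.0:
        return 0.0
    return math.exp(sum(math.log(score) for score in scores) / len(scores))


# Placeholder target ranges (assumptions, not thresholds from the paper).
example_targets = {"pIC50": (5.0, 9.0)}
print(combined_reward({"pIC50": 7.0}, example_targets))
```

A geometric mean drops to zero as soon as one objective fails completely, so a molecule cannot compensate for a failed objective with high scores elsewhere. A weighted arithmetic mean is a more forgiving alternative.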

Predictor and Training / Configuration with src/engine/predictor.py

The file src/engine/predictor.py defines the Predictor class:

  • Combines:
    • PyTDC Oracles for logP and SA.
    • RDKit for MW, TPSA, HBD, HBA.
    • Chemprop model for pIC50 (FAAH-1), loaded from PIC50_PATH.
  • Provides functions:
    • predict_* for each property.
    • predict(smiles) returns a dict of properties.
    • is_valid(smiles, props) to check if molecule is within filter range in FILTER_THRES.
    • filter(smiles_list) to compute properties & filter molecule list (used in RL and pipeline).
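
The range check behind is_valid and filter can be pictured as follows. This sketch assumes thresholds are stored as a mapping from property name to an inclusive (min, max) pair; the actual layout of FILTER_THRES in the repository may differ, and is_within_thresholds and filter_molecules are hypothetical names.

```python
def is_within_thresholds(
    props: dict[str, float],
    thresholds: dict[str, tuple[float, float]],
) -> bool:
    """Return True if every thresholded property lies inside its inclusive range."""
    for name, (low, high) in thresholds.items():
        value = props.get(name)
        # A missing property counts as a failure rather than a silent pass.
        if value is None or not (low <= value <= high):
            return False
    return True


def filter_molecules(records, thresholds):
    """Keep the (smiles, props) pairs whose properties pass every threshold."""
    return [
        (smiles, props)
        for smiles, props in records
        if is_within_thresholds(props, thresholds)
    ]
```

Properties that have no entry in the threshold mapping are ignored, so adding a new predicted property does not change which molecules pass until a range is configured for it.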

Using Predictor independently

Example:

from engine.predictor import Predictor

predictor = Predictor()
smiles = "COc1ccc(OC(=O)N2CCC(c3nc(C4=NOC(c5ccccc5)C4)cs3)CC2)cc1"

props = predictor.predict(smiles)
is_ok = predictor.is_valid(smiles, props)

print(props)
print("Valid:", is_ok)

Training / Replacing pIC50 Model (Chemprop)

In this repo, Predictor assumes you already have a trained Chemprop checkpoint for pIC50 on FAAH-1:

  • The checkpoint path is configured in PIC50_PATH (file src/configs/predictors/ChempropConfigs.py).
  • Model is loaded via engine.predictors.chemprop.chemprop.

To retrain the pIC50 model:

  • Use the original Chemprop code or the src/engine/predictors/chemprop directory (following the Chemprop instructions).
  • Train the model on a pIC50 dataset for FAAH-1.
  • Update PIC50_PATH to point to the new checkpoint.

After updating, Predictor will automatically use the new model when you run:

  • src/run.py (inference pipeline)
  • src/engine/custom_scorers.py during RL.

Notes

  • Paths in the current repo are absolute paths from the experimental environment; when running on another machine, update them in the configuration files.
  • RL training and inference with DrugEx may require a GPU to run efficiently; if you use one, make sure CUDA/cuDNN and a matching torch version are installed.
  • If you change the data structure or add new properties, update the following accordingly:
    • src/engine/custom_scorers.py
    • src/engine/predictor.py
    • configuration in src/configs/.
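
Before a long RL run, a quick check shows whether PyTorch can see a GPU. This is a small sketch, not repository code: pick_device is a hypothetical helper, and it falls back to "cpu" when torch is not installed.

```python
def pick_device() -> str:
    """Return "cuda" if PyTorch is installed and sees a GPU, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


print(f"Selected device: {pick_device()}")
```

If this prints cpu on a machine with a GPU, the installed torch build most likely does not match the local CUDA setup.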

Citations
