A JAX implementation of the TGM algorithm. Contains:
- Algorithm implementations of:
  - TGM (encompasses GFN)
  - SAC
  - PPO
- Code and data for the following synthetic and biological sequence design tasks:
  - BitSequence
  - Untranslated region (UTR)
  - Antimicrobial peptides (AMP)
  - Green fluorescent protein (GFP)
- Proxy training
To get started (currently tested for Python 3.10):
```bash
git clone https://github.com/marcojira/tgm.git

# Or your favorite virtual env
python -m venv env && source env/bin/activate && pip install --upgrade pip

pip install -e tgm
python example.py
```
The easiest way to train a sampler is to use the `run` function from `medium_rl.run` and pass it a `Config` object from `src/medium_rl/config.py`.
The `Config` object contains general training options (e.g. sampling policy, minibatch size, evaluation, etc.). It also contains the following three sub-configurations:
- `EnvConfig`: Configuration for the environment
- `AlgConfig`: Configuration for the training algorithm
- `NetworkConfig`: Configuration for the neural network
To see the available options, check `src/medium_rl/config.py`. It also contains base configurations for each, which can then be composed as follows:
```python
from medium_rl.config import (
    AMPConfig,
    TGMConfig,
    BaseTransformerConfig,
    Config,
)

env_cfg = AMPConfig()  # Taking default options
alg_cfg = TGMConfig(alpha=1, omega=1, q=0.75)  # Changing some values
network_cfg = BaseTransformerConfig(dropout=0.05)

cfg = Config(
    env=env_cfg,
    alg=alg_cfg,
    network=network_cfg,
    reward_exp=64,  # Change beta
    lr=1e-4,  # Change lr
)
```
Once the `Config` object is created, running training simply requires:
```python
from medium_rl.run import run

run(cfg)
```
On an L40S GPU, training for 100k samples should be quick:
- `BitSequence`: <3 minutes
- `UTR`: <5 minutes
- `AMP`: <5 minutes
- `GFP`: <30 minutes
All environments are subclasses of `SequenceEnv`, which describes a generic sequence generation DCG. Similar to PGX, the core object is a `State` that contains information about the current sequence. `SequenceEnv` then defines `init`, `step` and `get_rewards` functions to initialize the state, step the state and get proxy rewards for a sequence.
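As a rough sketch of how these pieces fit together, a random rollout might look like the following. Everything here beyond `get_rewards` taking a [B, T] array is an assumption for illustration: the import path, constructor arguments, `init`/`step` signatures and `State` field names may all differ, so check `src/medium_rl/envs/sequence_env.py` for the actual interface.

```python
import jax

# Hypothetical import path for one of the provided environments.
from medium_rl.envs.bitsequence import BitSequence

env = BitSequence(min_len=8, max_len=32)  # SequenceEnv takes (min_len, max_len)

key = jax.random.PRNGKey(0)
state = env.init(batch_size=16)  # assumed: returns a batched State

# Illustrative random policy; `state.done` and `state.token_seq` are assumed field names.
while not bool(state.done.all()):
    key, subkey = jax.random.split(key)
    actions = jax.random.randint(subkey, (16,), 0, env.num_tokens)
    state = env.step(state, actions)  # assumed: appends one token per sequence

rewards = env.get_rewards(state.token_seq)  # [B, T] token indexes -> proxy rewards
```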
For each of the biological sequence design tasks, the checkpoint for the proxy reward function is provided in `src/medium_rl/envs/proxies/<env_name>/proxy.pkl` and the validation mean/std in `src/medium_rl/envs/proxies/<env_name>/val_stats.pkl`.
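These can be inspected directly; here is a minimal sketch, assuming they are plain pickle files and using `amp` as an illustrative `<env_name>` (what exactly is stored in each file is environment-specific):

```python
import pickle

# Illustrative paths; substitute the environment of interest for "amp".
with open("src/medium_rl/envs/proxies/amp/proxy.pkl", "rb") as f:
    proxy = pickle.load(f)  # proxy reward function checkpoint

with open("src/medium_rl/envs/proxies/amp/val_stats.pkl", "rb") as f:
    val_stats = pickle.load(f)  # validation mean/std for this proxy

print(type(proxy), val_stats)
```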
BitSequence: Synthetic task described in Trajectory balance: Improved credit assignment in GFlowNets.
UTR: Sequence design task for the 5' UTR, the mRNA region that regulates translation of the main coding sequence. Data to train the proxy comes from brandontrabucco/design-bench#11 (comment) and consists of 250,000 sequences and their associated ribosome loads.
AMP: Antimicrobial peptide design task. The proxy was trained as a binary classifier to predict whether a sequence is antimicrobial, on a dataset of 9,222 non-AMP sequences and 6,438 AMP sequences from https://github.com/MJ10/clamp-gen-data/tree/master/data/dataset. The logit of the classifier is used as the proxy reward.
GFP: Green fluorescent protein design task. Data to train the proxy was sourced from brandontrabucco/design-bench#11 (comment) and consists of 56,086 variations of the original GFP protein and their associated fluorescence.
To create a custom environment, one can extend `SequenceEnv` as follows:
```python
from jax import Array

from medium_rl.envs.sequence_env import SequenceEnv


class NEW_ENV_NAMESequence(SequenceEnv):
    # Everything below needs to be specified
    name = "NEW_ENV_NAME"
    num_tokens = len(NEW_ENV_ALPHABET)
    alphabet = NEW_ENV_ALPHABET
    dict = {NEW_ENV_ALPHABET[i]: i for i in range(len(NEW_ENV_ALPHABET))}

    CLS = 0  # CLS or BOS token index
    PAD = 1  # PAD token index
    EOS = 2  # EOS token index

    def __init__(self, min_len: int, max_len: int, **kwargs):
        super().__init__(min_len, max_len)
        self.proxy = NEW_ENV_PROXY()  # Initialize proxy if necessary

    def get_rewards(
        self,
        token_seq: Array,  # [B, T], batch of sequences of tokens
    ):
        # Need to specify/write a proxy reward function that takes a [B, T]
        # array of token indexes
        # - B: Batch size
        # - T: Sequence length
        # and returns the proxy reward
        rewards = self.proxy.evaluate(token_seq)
        return rewards
```
Then, an `EnvConfig` can be specified as follows:
```python
class NEW_ENVConfig(EnvConfig):
    name: str = "NEW_ENV"
    min_len: int = 5
    max_len: int = 10
    ...
```
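From there, the custom environment can presumably be trained the same way as the built-in ones by composing the new `EnvConfig` into a `Config`. The sketch below assumes the environment is resolved from the config's `name` field; whether an explicit registration step is needed depends on the codebase.

```python
from medium_rl.config import BaseTransformerConfig, Config, TGMConfig
from medium_rl.run import run

# Assumes NEW_ENVConfig is defined as above and that the "NEW_ENV" name is
# known to the framework (the exact registration mechanism is an assumption).
cfg = Config(
    env=NEW_ENVConfig(min_len=5, max_len=10),
    alg=TGMConfig(),
    network=BaseTransformerConfig(),
)
run(cfg)
```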
`src/medium_rl/envs/proxies/train_proxy.py` contains code for training proxy reward functions from data. `train_model` expects an `x` (a [N, T] array of token indexes) and a `y` (a [N,] array of either floats to regress to or binary classes), as well as a `model_cfg` specifying the hyperparameters of the network. See `train_proxies.py` for example usage.
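As a minimal sketch of what a call might look like with toy data; the exact signature of `train_model`, its return value, and the type expected for `model_cfg` are assumptions, so defer to `train_proxies.py` for real usage:

```python
import numpy as np

from medium_rl.config import BaseTransformerConfig  # assumed model_cfg type
from medium_rl.envs.proxies.train_proxy import train_model

# Toy regression data: N sequences of length T over a 4-token alphabet.
N, T, num_tokens = 1000, 10, 4
x = np.random.randint(0, num_tokens, size=(N, T))  # [N, T] token indexes
y = np.random.rand(N)  # [N,] float targets to regress to

model_cfg = BaseTransformerConfig()
proxy = train_model(x=x, y=y, model_cfg=model_cfg)  # assumed keyword arguments
```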
The biological environments are JAX implementations, with moderate modifications, of the environments of Biological Sequence Design with GFlowNets as well as the benchmarks of Design-Bench: Benchmarks for Data-Driven Offline Model-Based Optimization. The training process for the proxy reward functions comes from the former and the data used from the latter. The BitSequence environment comes from Trajectory balance: Improved credit assignment in GFlowNets. The design of the `SequenceEnv` environment is inspired by the PGX library.