
Trajectory General Mellowmax (TGM)

A JAX implementation of the TGM algorithm. It contains:

  • Implementations of the following algorithms:
    • TGM (encompasses GFN)
    • SAC
    • PPO
  • Code and data for the following synthetic and biological sequence design tasks:
    • BitSequence
    • Untranslated region (UTR)
    • Antimicrobial peptides (AMP)
    • Green fluorescent protein (GFP)
  • Proxy training

Demo of TGM training

Example run of TGM.

Setup

To get started (currently tested with Python 3.10):

git clone https://github.com/marcojira/tgm.git

# Or your favorite virtual env
python -m venv env && source env/bin/activate && pip install --upgrade pip
pip install -e tgm
python example.py

Training

The easiest way to train a sampler is to use the run function from medium_rl.run and pass it a Config object from src/medium_rl/config.py.

Config

The Config object contains general training options (e.g. sampling policy, minibatch size, evaluation). It also contains the following three sub-configurations:

  • EnvConfig: Configuration for the environment
  • AlgConfig: Configuration for the training algorithm
  • NetworkConfig: Configuration for the neural network

To see the available options, check src/medium_rl/config.py, which also already contains base configurations for each. These can then be composed as follows:

from medium_rl.config import (
    AMPConfig,
    TGMConfig,
    BaseTransformerConfig,
    Config,
)

env_cfg = AMPConfig()  # Taking default options
alg_cfg = TGMConfig(alpha=1, omega=1, q=0.75)  # Changing some values
network_cfg = BaseTransformerConfig(dropout=0.05)

cfg = Config(
    env=env_cfg,
    alg=alg_cfg,
    network=network_cfg,
    reward_exp=64,  # Change beta
    lr=1e-4,  # Change lr
)

Run

Once the Config object is created, running training simply requires:

from medium_rl.run import run

run(cfg)

Training speed

On an L40S GPU, training for 100k samples should be quick:

  • BitSequence: <3 minutes.
  • UTR: <5 minutes.
  • AMP: <5 minutes.
  • GFP: <30 minutes.

Environments

All environments are subclasses of SequenceEnv, which describes a generic sequence-generation DAG. Similarly to PGX, the core object is a State that contains information about the current sequence. SequenceEnv then defines init, step and get_rewards functions to initialize the state, step the state, and get proxy rewards for a sequence.
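
As an illustration, a rollout with this interface might look like the sketch below. The exact signatures are assumptions: how PRNG keys and termination are exposed, the done flag and token_seq field on State, and the policy argument are all hypothetical names. Check src/medium_rl/envs/sequence_env.py for the real API.

import jax
from medium_rl.envs.sequence_env import SequenceEnv

def rollout(env: SequenceEnv, policy, key: jax.Array):
    state = env.init(key)               # fresh State holding an empty sequence
    while not state.done:               # `done` flag on State is an assumed name
        key, subkey = jax.random.split(key)
        action = policy(state, subkey)  # stand-in for any sampling policy
        state = env.step(state, action) # append one token to the sequence
    # `token_seq` is an assumed field name; get_rewards expects a [B, T] batch
    return env.get_rewards(state.token_seq[None])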

For each of the biological sequence design tasks, the checkpoint for the proxy reward function is provided in src/medium_rl/envs/proxies/<env_name>/proxy.pkl and the validation mean/std in src/medium_rl/envs/proxies/<env_name>/val_stats.pkl.
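
For example, a minimal sketch of loading the AMP proxy files with the standard pickle module (the exact object stored in each file, and the (mean, std) layout of val_stats.pkl, are assumptions):

import pickle

with open("src/medium_rl/envs/proxies/amp/proxy.pkl", "rb") as f:
    proxy = pickle.load(f)  # proxy reward model checkpoint
with open("src/medium_rl/envs/proxies/amp/val_stats.pkl", "rb") as f:
    val_mean, val_std = pickle.load(f)  # assumed (mean, std) layout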

BitSequence

Synthetic task described in Trajectory balance: Improved credit assignment in GFlowNets.

UTR

Sequence design task for the 5' UTR, the mRNA region that regulates translation of the main coding sequence. Data to train the proxy comes from brandontrabucco/design-bench#11 (comment) and consists of 250,000 sequences and their associated ribosome loads.

AMP

Antimicrobial peptide design task. The proxy was trained as a binary classifier that predicts whether a sequence is antimicrobial, on a dataset of 9,222 non-AMP sequences and 6,438 AMP sequences from https://github.com/MJ10/clamp-gen-data/tree/master/data/dataset. The logit of the classifier is used as the proxy reward.

GFP

Green fluorescent protein design task. Data to train the proxy was sourced from brandontrabucco/design-bench#11 (comment) and consists of 56,086 variants of the original GFP protein and their associated fluorescence.

Custom environments

To create a custom environment, one can extend SequenceEnv as follows:

from jax import Array

from medium_rl.envs.sequence_env import SequenceEnv

class NEW_ENV_NAMESequence(SequenceEnv):
    # Everything below needs to be specified
    name = "NEW_ENV_NAME"

    num_tokens = len(NEW_ENV_ALPHABET)
    alphabet = NEW_ENV_ALPHABET
    dict = {tok: i for i, tok in enumerate(NEW_ENV_ALPHABET)}  # token -> index

    CLS = 0 # CLS or BOS token index
    PAD = 1 # PAD token index
    EOS = 2 # EOS token index

    def __init__(self, min_len: int, max_len: int, **kwargs):
        super().__init__(min_len, max_len)
        self.proxy = NEW_ENV_PROXY() # Initialize proxy if necessary

    def get_rewards(
        self,
        token_seq: Array,  # [B, T], batch of sequence of tokens
    ):
        # Implement a proxy reward function that takes a [B, T] array of token indices
        # - B: batch size
        # - T: sequence length
        # and returns the proxy rewards
        rewards = self.proxy.evaluate(token_seq)
        return rewards

Then, an EnvConfig can be specified as follows:

class NEW_ENVConfig(EnvConfig):
    name: str = "NEW_ENV"
    min_len: int = 5
    max_len: int = 10
    ...
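
The custom environment can then be trained like the built-in ones by composing its config into a Config (this assumes run resolves the environment from the config's name field, which is a guess about the internals):

from medium_rl.config import Config, TGMConfig, BaseTransformerConfig
from medium_rl.run import run

cfg = Config(
    env=NEW_ENVConfig(),
    alg=TGMConfig(),
    network=BaseTransformerConfig(),
)
run(cfg)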

Proxy training

src/medium_rl/envs/proxies/train_proxy.py contains code for training proxy reward functions from data. train_model expects x (an [N, T] array of token indices) and y (an [N] array of either floats to regress to or binary class labels), as well as model_cfg specifying the hyperparameters of the network. See train_proxies.py for example uses.
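
As a sketch, a train_model call might look as follows. The keyword-argument form is an assumption about the signature, and the data files are hypothetical; the real signature and example configs are in train_proxies.py.

import numpy as np
from medium_rl.envs.proxies.train_proxy import train_model

x = np.load("amp_tokens.npy")  # [N, T] token indices (hypothetical file)
y = np.load("amp_labels.npy")  # [N] binary labels, or floats for regression
model_cfg = ...                # network hyperparameters; see train_proxies.py
proxy = train_model(x, y, model_cfg=model_cfg)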

Acknowledgments

The biological environments are JAX implementations, with moderate modifications, of the environments of Biological Sequence Design with GFlowNets as well as the benchmarks of Design-Bench: Benchmarks for Data-Driven Offline Model-Based Optimization. The training process for the proxy reward functions comes from the former and the data used from the latter. The BitSequence environment comes from Trajectory balance: Improved credit assignment in GFlowNets. The design of the SequenceEnv environment is inspired by the PGX library.
