In [None]:
# hide
# all_tutorial

# Tutorial - RL Train Cycle Overview

>Overview of the RL training cycle

## RL Train Cycle Overview

The goal of this tutorial is to walk through the RL fit cycle to familiarize ourselves with the `Events` cycle and get a better understanding of how `Callback` and `Environment` classes work.

## High Level Overview

### The Environment

At the highest level, we have the `Environment` class. The `Environment` holds together several sub-modules and orchestrates them during the fit loop. The following are contained in the `Environment`:
- `agent` - This is the actual model we're training
- `template_cb` - this holds a `Template` class that we use to define our chemical space
- `samplers` - samplers generate new samples to train on
- `buffer` - the buffer collects and distributes samples from all the `samplers`
- `rewards` - rewards score samples
- `losses` - losses generate values we can backpropagate through 
- `log` - the log holds a record of all samples in the training process

### Callbacks and the Event Cycle

Each one of the above items is a `Callback`. A `Callback` is a general class that can hook into the `Environment` fit cycle at a number of pre-defined `Events`. When the `Environment` calls a specific `Event`, the event name is passed to every callback in the `Environment`. If a given `Callback` has a defined function named after the event, that function is called. This creates a very flexible system for customizing training loops.
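
To make the dispatch idea concrete, here is a minimal sketch of name-based event dispatch. `ToyEnvironment`, `ToySampler` and `ToyLogger` are hypothetical stand-ins written for this illustration, not MRL classes, and the toy passes the environment to each handler explicitly, which is a simplification.

```python
class ToyCallback:
    """Hypothetical base class: subclasses define methods named after events."""
    name = 'base'


class ToySampler(ToyCallback):
    name = 'sampler'

    def build_buffer(self, env):
        # This callback only reacts to the `build_buffer` event
        env.calls.append((self.name, 'build_buffer'))


class ToyLogger(ToyCallback):
    name = 'logger'

    def build_buffer(self, env):
        env.calls.append((self.name, 'build_buffer'))

    def after_batch(self, env):
        env.calls.append((self.name, 'after_batch'))


class ToyEnvironment:
    """Hypothetical environment: calling it with an event name runs every
    callback method that has the same name as the event."""

    def __init__(self, cbs):
        self.cbs = cbs
        self.calls = []  # record of (callback name, event) pairs

    def __call__(self, event):
        for cb in self.cbs:
            handler = getattr(cb, event, None)
            if callable(handler):
                handler(self)


env = ToyEnvironment([ToySampler(), ToyLogger()])
env('build_buffer')  # both callbacks define this event
env('after_batch')   # only the logger defines this event
env('zero_grad')     # no callback defines this event, so nothing happens
print(env.calls)
```

Callbacks that do not define a method for an event are simply skipped, which is why each callback only needs to implement the events it cares about.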

We'll be looking more at `Events` later. For now, we'll just list them in brief. These are the events called during the RL training cycle in the order they are executed:

- `setup` - called when the `Environment` is created, used to set up values
- `before_train` - called before training is started
- `build_buffer` - draws samples from `samplers` into the `buffer`
- `filter_buffer` - filters samples in the buffer
- `after_build_buffer` - called after buffer filtering. Used for cleanup, logging, etc
- `before_batch` - called before a batch starts, used to set up the `batch state`
- `sample_batch` - samples are drawn from `samplers` and `buffer` into the `batch state`
- `before_filter_batch` - allows preprocessing of samples before filtering
- `filter_batch` - filters samples in `batch state`
- `after_sample` - used for calculating sampling metrics
- `before_compute_reward` - used to set up any values needed for reward computation 
- `compute_reward` - used by `rewards` to compute rewards for all samples in the `batch state`
- `after_compute_reward` - used for logging reward metrics
- `reward_modification` - modify rewards in ways not tracked by the log
- `after_reward_modification` - log reward modification metrics
- `get_model_outputs` - generate necessary tensors from the model
- `after_get_model_outputs` - used for any processing required prior to loss calculation 
- `compute_loss` - compute loss values
- `zero_grad` - zero grad
- `before_step` - used for computation before optimizer step (e.g. gradient clipping)
- `step` - step optimizer
- `after_batch` - compute batch stats
- `after_train` - final event after all training batches

In [None]:
import sys
sys.path.append('..')

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *


In [None]:
from collections import Counter

In [None]:
set_global_pool(10)

## Getting Started

We start by creating all the components we need to train a model

### Agent

The `Agent` is the actual model we want to train. For this example, we will use the `LSTM_LM_Small_ZINC` model, which is a `LSTM_LM` model trained on a chunk of the ZINC database.

The agent actually contains two versions of the model: the main model, which we train with every update iteration, and a baseline model, which is updated as an exponentially weighted moving average of the main model. Both models are used in the RL training algorithm we will set up later.
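
As a rough illustration of what an exponentially weighted moving average update looks like, here is a minimal sketch on plain Python lists. The function `ema_update` and the `decay` value are assumptions made for this example; the update rule and decay actually used by the `Agent` are not shown in this tutorial.

```python
def ema_update(base_params, main_params, decay):
    """Hypothetical helper: move each baseline parameter toward the main
    parameter, keeping a `decay` fraction of the old baseline value."""
    return [decay * b + (1.0 - decay) * m
            for b, m in zip(base_params, main_params)]


main = [1.0, -2.0]  # stand-in for the main model's parameters
base = [0.0, 0.0]   # stand-in for the baseline model's parameters

# The baseline drifts toward the main model a little more on every update
for _ in range(3):
    base = ema_update(base, main, decay=0.5)

print(base)
```

A decay closer to 1 makes the baseline change more slowly, so it lags further behind the main model.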

In [None]:
agent = LSTM_LM_Small_ZINC(drop_scale=0.5,opt_kwargs={'lr':5e-5})

### Template

The `Template` class is used to control the chemical space. We can set parameters on what molecular properties we want to allow. For this example, we set the following:

- Hard Filters - must have qualities
    - `ValidityFilter` - must be a valid chemical structure
    - `SingleCompoundFilter` - samples must be single compounds
    - `RotBondFilter` - compounds can have at most 8 rotatable bonds
    - `ChargeFilter` - compounds must have no net charge
- Soft Filters - nice to have qualities
    - `QEDFilter` - Compounds get a score bonus of +1 if their QED value is greater than 0.5
    - `SAFilter` - compounds get a score bonus of +1 if their SA score is less than 5
    
We then pass the `Template` to the `TemplateCallback` which integrates the template into the fit loop. Note that we pass `prefilter=True` to the `TemplateCallback`, which ensures compounds that don't meet our hard filters are removed from training

In [None]:
template = Template([ValidityFilter(), 
                     SingleCompoundFilter(), 
                     RotBondFilter(None, 8),
                     ChargeFilter(0, 0)],
                    [QEDFilter(0.5, None, score=1.),
                     SAFilter(None, 5, score=1.)])

template_cb = TemplateCallback(template, prefilter=True)

### Reward

For the reward, we will load a scikit-learn linear regression model. This model was trained to predict affinity against erbB1 using molecular fingerprints as inputs

This score function is extremely simple and likely won't translate well to real affinity. It is used as a lightweight example

In [None]:
class FP_Regression_Score():
    def __init__(self, fname):
        self.model = torch.load(fname)
        self.fp_function = partial(failsafe_fp, fp_function=ECFP6)
        
    def __call__(self, samples):
        mols = to_mols(samples)
        fps = maybe_parallel(self.fp_function, mols)
        fps = [fp_to_array(i) for i in fps]
        x_vals = np.stack(fps)
        preds = self.model.predict(x_vals)
        return preds
    
reward_function = FP_Regression_Score('files/erbB1_regression.sklearn')

reward = Reward(reward_function, weight=10.)

aff_reward = RewardCallback(reward, 'aff')

We can think of the score function as a black box that takes in samples (SMILES strings) and returns a single numeric score for each sample. Any score function that follows this paradigm can be integrated into MRL
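
For example, the following toy score function fits the same paradigm. `toy_length_score` is a hypothetical function written only for illustration and has no chemical meaning: it takes a list of SMILES strings and returns one number per string.

```python
def toy_length_score(samples):
    """Hypothetical score function: one float per input sample.

    The score here is just the string length, used only to show the
    expected input/output shape of a score function.
    """
    return [float(len(s)) for s in samples]


toy_samples = ['CCO', 'c1ccccc1']
toy_scores = toy_length_score(toy_samples)
print(toy_scores)
```

Anything with this shape (list of samples in, one score per sample out) could in principle be wrapped the same way `FP_Regression_Score` is wrapped above.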

In [None]:
samples = ['Brc1cc2c(NCc3cccs3)ncnc2s1',
           'Brc1cc2c(NCc3ccncc3)ncnc2s1']

reward_function(samples)

array([-0.80169969, -0.66942228])

### Loss Function

For our loss, we will use the `PPO` reinforcement learning algorithm. See the [PPO](https://arxiv.org/pdf/1707.06347.pdf) paper for full details.

The gist of it is that the loss function takes a batch of samples and directs the model to increase the probability of above-average samples (relative to the batch mean) and decrease the probability of below-average samples.
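
The "relative to the batch mean" idea can be sketched in a few lines. This is a deliberately simplified illustration, not the PPO loss itself: the real algorithm also involves clipped probability ratios, a value head and other terms. The helper name `centered_advantages` is hypothetical.

```python
def centered_advantages(rewards):
    """Hypothetical helper: subtract the batch mean from each reward.

    Positive values mark above-average samples (probability pushed up),
    negative values mark below-average samples (probability pushed down).
    """
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]


toy_rewards = [1.0, 2.0, 6.0]
advantages = centered_advantages(toy_rewards)
print(advantages)
```

Only the last sample ends up with a positive value here, so it is the only one the update would favor.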

In [None]:
pg = PPO(0.99,
        0.5,
        lam=0.95,
        v_coef=0.5,
        cliprange=0.3,
        v_cliprange=0.3,
        ent_coef=0.01,
        kl_target=0.03,
        kl_horizon=3000,
        scale_rewards=True)

loss = PolicyLoss(pg, 'PPO', 
                   value_head=ValueHead(256), 
                   v_update_iter=2, 
                   vopt_kwargs={'lr':1e-3})

### Samplers

`Samplers` fill the role of generating samples to train on. We will use four samplers for this run:

- `sampler1`: `ModelSampler` - this sampler will draw samples from the main model in the `Agent`. We set `buffer_size=1000`, which means we will generate 1000 samples every time we build the buffer. We set `p_batch=0.5`, which means during training, 50% of each batch will be sampled on the fly from the main model and the rest of the batch will come from the buffer
- `sampler2`: `ModelSampler` - this sampler is the same as `sampler1`, but we draw from the baseline model instead of the main model. We set `p_batch=0.`, so this sampler will only contribute to the buffer
- `sampler3`: `LogSampler` - this sampler looks through the log of previous samples. Based on our input arguments, it grabs the top `95` percentile of samples in the log, and randomly selects `100` samples from that subset
- `sampler4`: `DatasetSampler` - this sampler is seeded with erbB1 training data used to train the score function. This sampler will randomly select 4 samples from the dataset to add to the buffer

In [None]:
gen_bs = 1500
df = pd.read_csv('files/erbB1_affinity_data.csv')
df = df[df.value<-1]

sampler1 = ModelSampler(agent.vocab, agent.model, 'live', 1000, 0.5, gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base', 1000, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 95, 100)
sampler4 = DatasetSampler(df.smiles.values, 'erbB1_data', buffer_size=4)

samplers = [sampler1, sampler2, sampler3, sampler4]

### Other Callbacks

We'll add three more callbacks:

- `MaxCallback`: this will grab the max reward within a batch that came from the source `live`. `live` is the name we gave to `sampler1` above. This means the max callback will grab all outputs from `sampler1` corresponding to samples from the live model and add the largest to the batch metrics
- `PercentileCallback`: this does the same as `MaxCallback` but instead of printing the maximum score, it prints the 90th percentile score
- `NoveltyReward`: this is a reward modification that gives a bonus score of `0.05` to new samples (ie samples that haven't appeared before in training)

In [None]:
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)
new_cb = NoveltyReward(weight=0.05)

cbs = [new_cb, live_p90, live_max]

## Training Walkthrough

Now we will step through the training cycle looking at how each callback event is used

### Setup

The first event occurs when we create our `Environment` using the callbacks we set up before. Instantiating the `Environment` registers all callbacks and runs the `setup` event. Many callbacks use the `setup` event to add terms to the batch log or the metrics log.

In [None]:
env = Environment(agent, template_cb, samplers=samplers, rewards=[aff_reward], losses=[loss],
                 cbs=cbs)

Inside the environment, we just created a `Buffer` and a `Log`.

The `Buffer` holds a list of samples, which is currently empty

In [None]:
env.buffer

buffer

In [None]:
env.buffer.buffer

[]

The `Log` holds a number of containers for tracking training outputs

- `metrics`: dictionary of batch metrics. Each key maps to a list where each value in the list is the metric term for a given batch
- `batch_log`: dictionary of batch items. Each key maps to a list. Each element in that list is a list containing the batch values for that key in a given batch
- `unique_samples`: dictionary of unique samples and the rewards for those samples. Useful for looking up if a sample has been seen before
- `df`: dataframe of unique samples and all associated values stored in the `batch_log`

We can see that these log terms have already been populated during the `setup` event

In [None]:
env.log.metrics

{'rewards': [],
 'rewards_final': [],
 'new': [],
 'diversity': [],
 'bs': [],
 'template': [],
 'valid': [],
 'live_diversity': [],
 'live_valid': [],
 'live_rewards': [],
 'live_new': [],
 'aff': [],
 'novel': [],
 'PPO': [],
 'rewards_live_p90': [],
 'rewards_live_max': []}

In [None]:
env.log.batch_log

{'samples': [],
 'sources': [],
 'rewards': [],
 'rewards_final': [],
 'template': [],
 'aff': [],
 'novel': [],
 'PPO': []}

In [None]:
env.log.df

Unnamed: 0,samples,sources,rewards,rewards_final,template,aff,novel,PPO


The keys in the above dictionaries were added by the associated callbacks. For example, look at the `setup` method in `ModelSampler`, the type of sampler we used for `sampler1`:

```
    def setup(self):
        if self.p_batch>0. and self.track:
            log = self.environment.log
            log.add_metric(f'{self.name}_diversity')
            log.add_metric(f'{self.name}_valid')
            log.add_metric(f'{self.name}_rewards')
            log.add_metric(f'{self.name}_new')
```

We gave `sampler1` the name `live`. As a result, the terms `live_diversity`, `live_valid`, `live_rewards` and `live_new` were added to the metrics.

We can also look at the `setup` method of our loss function `loss`:

```
    def setup(self):
        if self.track:
            log = self.environment.log
            log.add_metric(self.name)
            log.add_log(self.name)
```

This is responsible for the `PPO` terms in the `batch_log` and the `metrics`. The PPO metrics term will store the average PPO loss value across a batch, while the PPO batch log term will store the PPO value for each item in a batch
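
The difference between the two containers can be sketched with a toy log. `ToyLog` is a hypothetical stand-in that mirrors only the `add_metric` and `add_log` names seen above; the real `Log` class may do more than this, and the numbers are made up.

```python
class ToyLog:
    """Hypothetical stand-in for the log: two dictionaries of lists."""

    def __init__(self):
        self.metrics = {}    # one scalar per batch for each key
        self.batch_log = {}  # one list of per-sample values per batch

    def add_metric(self, name):
        self.metrics.setdefault(name, [])

    def add_log(self, name):
        self.batch_log.setdefault(name, [])


toy_log = ToyLog()

# What a sampler named `live` would register during `setup`
for suffix in ('diversity', 'valid', 'rewards', 'new'):
    toy_log.add_metric(f'live_{suffix}')

# What a loss named `PPO` would register during `setup`
toy_log.add_metric('PPO')
toy_log.add_log('PPO')

# After one batch: the batch log keeps every per-sample value,
# the metric keeps only the batch average
per_sample_loss = [0.25, 0.75]
toy_log.batch_log['PPO'].append(per_sample_loss)
toy_log.metrics['PPO'].append(sum(per_sample_loss) / len(per_sample_loss))

print(toy_log.metrics)
print(toy_log.batch_log)
```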

### The Fit Cycle

At this point, we could start training using `Environment.fit`. We could call `env.fit(200, 90, 10, 2)` to train for 10 batches with a batch size of 200. For this tutorial, we will step through each part of the fit cycle and observe what is happening

### Before Train

The first stage of the fit cycle is the `before_train` stage. This sets the batch size and sequence length based on the inputs to `Environment.fit` (which we will set manually) and prints the top of the log

In [None]:
env.bs = 200 # batch size of 200
env.sl = 90 # max sample length of 90 steps
mb = master_bar(range(1))
env.log.pbar = mb
env.report = 1
env.log.report = 1 # report stats every batch
env('before_train')

### Build Buffer

The next stage of the cycle is the `build_buffer` stage. This consists of the following events:
- `build_buffer`: samplers add items to the buffer
- `filter_buffer`: the buffer is filtered
- `after_build_buffer`: use as needed

Going into this stage, our buffer is empty:

In [None]:
env.buffer.buffer

[]

#### build_buffer

By calling the `build_buffer` event, our samplers will add items to the buffer

In [None]:
env('build_buffer')

Now we have 2004 items in the buffer.

In [None]:
len(env.buffer.buffer)

2004

We can use the `buffer_sources` attribute to see where each item came from. We have 1000 items from `live_buffer` which corresponds to `sampler1`, sampling from the main model.

We have 1000 items from `base_buffer` which corresponds to `sampler2`, sampling from the baseline model.

We have 4 items from `erbB1_data_buffer`, our dataset sampler (`sampler4`).

Our log sampler, `sampler3`, was set to start sampling after 10 training iterations, so we don't currently have any samples from that sampler

In [None]:
Counter(env.buffer.buffer_sources)

Counter({'live_buffer': 1000, 'base_buffer': 1000, 'erbB1_data_buffer': 4})

#### filter_buffer

It's likely some of these samples don't match our compound requirements defined in the `Template` we used, so we want to filter the buffer for passing compounds. This is what the `filter_buffer` event does. For this current example, the only callback doing any buffer filtering is the template callback. However, the `filter_buffer` event can be used to implement any form of buffer filtering.

Any callback that passes a list of boolean values to `Buffer._filter_buffer` can filter the buffer.
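
The boolean-mask idea looks roughly like the sketch below. `filter_with_mask` is a hypothetical helper written for illustration; it is not the implementation of `Buffer._filter_buffer`, and the sample strings are placeholders.

```python
def filter_with_mask(items, sources, keep):
    """Hypothetical helper: keep only the entries whose mask value is True.

    Samples and their sources are filtered together so they stay aligned.
    """
    if not (len(items) == len(sources) == len(keep)):
        raise ValueError('items, sources and keep must have the same length')
    kept_items = [item for item, k in zip(items, keep) if k]
    kept_sources = [src for src, k in zip(sources, keep) if k]
    return kept_items, kept_sources


toy_buffer = ['CCO', 'not_a_molecule', 'c1ccccc1']
toy_sources = ['live_buffer', 'base_buffer', 'erbB1_data_buffer']
toy_keep = [True, False, True]  # e.g. the result of a validity check

toy_buffer, toy_sources = filter_with_mask(toy_buffer, toy_sources, toy_keep)
print(toy_buffer)
print(toy_sources)
```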

After filtering, we have 1830 remaining samples

In [None]:
env('filter_buffer')

In [None]:
len(env.buffer.buffer)

1830

In [None]:
Counter(env.buffer.buffer_sources)

Counter({'live_buffer': 914, 'base_buffer': 914, 'erbB1_data_buffer': 2})

#### after_build_buffer

Next is the `after_build_buffer` event. None of our current callbacks make use of this event, but it exists to allow for evaluation, postprocessing or logging after buffer creation.

### Sample Batch

The next event stage is the `sample_batch` stage. This consists of the following events:

- `before_batch`: set up/refresh any required state prior to batch sampling
- `sample_batch`: draw one batch of samples
- `before_filter_batch`: evaluate unfiltered batch
- `filter_batch`: filter batch
- `after_sample`: compute sample based metrics

#### before_batch

This event is used to create a new `BatchState` for the environment. The batch state is a container designed to hold any values required by the batch

In [None]:
env.batch_state = BatchState()
env('before_batch')

Currently the batch state only has placeholder values for commonly generated terms

In [None]:
env.batch_state

{'samples': [],
 'sources': [],
 'rewards': tensor(0., device='cuda:0'),
 'loss': tensor(0., device='cuda:0', grad_fn=<CopyBackwards>),
 'latent_data': {}}

#### sample_batch

Now we actually draw samples to form a batch. All of our `Sampler` objects have a `p_batch` value, which designates what percentage of the batch should come from that sampler. Batch sampling is designed such that individual sampler `p_batch` values are respected, and any remaining batch percentage comes from the buffer.

Only `sampler1` has `p_batch>0.`, with a value of `p_batch=0.5`. This means 50% of the batch will be sampled on the fly from `sampler1`, and the remaining 50% of the batch will come from the buffer.

Using a hybrid of live sampling and buffer sampling seems to work best. That said, it is possible to have every batch be 100% buffer samples (like offline RL), or have 100% be live samples (like online RL)
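
The arithmetic behind the batch split can be sketched as follows. `batch_composition` is a hypothetical helper, and truncating with `int` is an assumption made for this example; how the library rounds fractional counts is not shown here.

```python
def batch_composition(bs, p_batches):
    """Hypothetical helper: number of samples each sampler contributes
    on the fly, with the remainder of the batch taken from the buffer."""
    counts = {name: int(bs * p) for name, p in p_batches.items()}
    remainder = bs - sum(counts.values())
    if remainder < 0:
        raise ValueError('p_batch values sum to more than 1')
    counts['buffer'] = remainder
    return counts


# Our setup: batch size 200, `live` has p_batch=0.5, `base` has p_batch=0.
composition = batch_composition(200, {'live': 0.5, 'base': 0.0})
print(composition)
```

Setting every `p_batch` to 0 would give a batch drawn entirely from the buffer, while values summing to 1 would give a batch made only of live samples.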

In [None]:
env('sample_batch')

Now we can see we've populated several terms in the batch state. `BatchState.samples` now has a list of samples. `BatchState.sources` has the source of each sample.

We also added `BatchState.live_raw` and `BatchState.base_raw`. These terms hold the outputs of `sampler1` and `sampler2`. When we filter `BatchState.samples`, we can refer to the `_raw` terms to see what samples were removed.

Note that `BatchState.base_raw` is an empty list since `sampler2.p_batch=0.`

In [None]:
env.batch_state.keys()

dict_keys(['samples', 'sources', 'rewards', 'loss', 'latent_data', 'live_raw', 'base_raw'])

`BatchState.sources` holds the source of each sample. We have 100 samples from `live`, which corresponds to our on the fly samples from `sampler1`. The remaining 100 samples come from `live_buffer` and `base_buffer`. This means they came from either `sampler1` (live) or `sampler2` (base) by way of being sampled from the buffer

In [None]:
Counter(env.batch_state['sources'])

Counter({'live_buffer': 55, 'base_buffer': 45, 'live': 100})

In [None]:
env.batch_state['samples'][:5]

['CC1CCN(C(=O)C2CCN(C(=O)Nc3nc(-c4ccccc4)cs3)CC2)CC1',
 'COc1ccnc(O)c1C(=O)NCC1CC(NC(=O)CCC(F)F)C1',
 'C[C@@H]([C@@H](C)NC(=O)Nc1cccc(S(N)(=O)=O)c1)N1CCOCC1',
 'CCN1C(=O)[C@@H]2CN(C(=O)Nc3ccc(OC(F)F)cn3)CCN2C1=O',
 'CC#CCN[C@@H]1CCN(C(=O)CN(C)S(=O)(=O)CC)C1']

In [None]:
env.batch_state['sources'][:5]

['live_buffer', 'live_buffer', 'base_buffer', 'base_buffer', 'live_buffer']

In [None]:
env.batch_state['live_raw'][:5]

['COc1ccc(N2CCN(CCNC(=O)Nc3ccc(N(C)C(=O)OC(C)(C)C)cc3)CC2)cc1',
 'C=CCCOCC(=O)N[C@@H]1CCCN(C(=O)[C@H](C)C#N)CC1',
 'COCCO[C@H](C)c1nnc(N2CC[C@H](C(N)=O)C2)n1CC1CC1',
 'CCN(CCNCc1cnsn1)C(=O)c1ccc2c(c1)N(C)CC2',
 'FC(F)(F)c1nnc(N2CCC[C@@H](OCCO)C2)s1']

In [None]:
env.batch_state['base_raw']

[]

#### before_filter_batch

This event is not used by any of our current callbacks. It provides a hook to influence the batch state prior to filtering

#### filter_batch

Now the batch will be filtered by our `Template`, as well as any other callbacks with a `filter_batch` method

In [None]:
env('filter_batch')

We can see that 6 of our 200 samples were removed by filtering

In [None]:
len(env.batch_state['samples'])

194

We can compare the values in `BatchState.samples` and `BatchState.live_raw` to see what was filtered

In [None]:
raw_samples = env.batch_state['live_raw']
filtered_samples = [env.batch_state['samples'][i] for i in range(len(env.batch_state['samples'])) 
                    if env.batch_state.sources[i]=='live']

len(filtered_samples), len(raw_samples)

(94, 100)

In [None]:
# filtered compounds
[i for i in raw_samples if i not in filtered_samples]

['COCCO[C@H](C)c1nnc(N2CC[C@H](C(N)=O)C2)n1CC1CC1',
 'FC(F)(F)c1nnc(N2CCC[C@@H](OCCO)C2)s1',
 'CC(C)(C)[S@@](=O)Cc1cccc(NC(=O)c2cc(Br)cn2C)c1',
 'CC(C)C[C@H](CNC(=O)NC[C@@H]1CCc2ccccc2C1)N[C@@H](C)c1ccccc1',
 'CCc1cnccc1C(=O)N1C[C@H]2CCC[C@@H](C1)N2C(=O)C1CC1',
 'Cc1cc(C(=O)N[C@@H]2COC3(CN(C(=O)C[C@H]4C[C@H]4C4CC4)C3)C2)on1',
 'Cc1ccncc1CNC(=O)Cc1coc2c1ccc(F)c2F',
 'CCn1cc(CC(=O)N2CCC([C@H](C)NCCCOC)CC2)cn1',
 'O=C(NC1CCS(=O)(=O)CC1)C(=O)N[C@@H]1CCC[C@H]2OCC[C@H]21',
 'O=S(=O)(NCc1cccc(C(F)(F)F)c1)c1ccc(Cl)c2ccccc21',
 'CCC[C@@H](C(=O)NCCN(CC)Cc1nnc(CC)s1)C(C)C',
 'C#CC[C@H](CCOC)NC(=O)N[C@@H](C)C(=O)NCc1ccco1',
 'COC1CCN(C(C)(C)Cn2c(CC(C)C)nnc2N(C)Cc2ccc(F)c(F)c2)CC1',
 'Cn1cc([C@@H]2CSCCCN2C(=O)c2cccc(S(=O)(=O)N3[C@H](C)CCC[C@@H]3C)c2)cn1',
 'CCn1cc(C(=O)N2C[C@@H](C)[C@H](NC(=O)c3ccncc3)C2)c(C2CC2)n1']

#### after_sample

The `after_sample` event is used to calculate metrics related to sampling

In [None]:
env('after_sample')

We can see that several values have been added to `Environment.log.metrics`

- `new - 1.0`: fraction of samples that have not been seen before
- `diversity - 1.0`: number of unique samples relative to the number of total samples
- `bs - 194`: true batch size after filtering
- `valid - 0.97`: fraction of samples that passed filtering
- `live_diversity - 1.0`: number of unique samples relative to the number of total samples from `sampler1`
- `live_valid - 0.94`: fraction of samples from `sampler1` that passed filtering
- `live_new - 1.0`: fraction of samples from `sampler1` that have not been seen before
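
These metrics are simple ratios. The sketch below shows one way they could be computed; `sampling_metrics` is a hypothetical helper and the strings are toy data, so this illustrates the definitions rather than the library's code.

```python
def sampling_metrics(raw, filtered, seen):
    """Hypothetical helper computing the three ratio metrics.

    raw      - samples before filtering
    filtered - samples that survived filtering
    seen     - set of samples already present in the log
    """
    n = len(filtered)
    return {
        'valid': len(filtered) / len(raw),
        'diversity': len(set(filtered)) / n if n else 0.0,
        'new': sum(s not in seen for s in filtered) / n if n else 0.0,
    }


toy_raw = ['a', 'b', 'b', 'c', 'x']
toy_filtered = ['a', 'b', 'b', 'c']  # 'x' was removed by filtering
toy_seen = {'a'}                     # 'a' appeared in an earlier batch

toy_metrics = sampling_metrics(toy_raw, toy_filtered, toy_seen)
print(toy_metrics)

# The values reported for our batch follow the same arithmetic
print(194 / 200, 94 / 100)
```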

In [None]:
env.log.metrics

{'rewards': [],
 'rewards_final': [],
 'new': [1.0],
 'diversity': [1.0],
 'bs': [194],
 'template': [],
 'valid': [0.97],
 'live_diversity': [1.0],
 'live_valid': [0.94],
 'live_rewards': [],
 'live_new': [1.0],
 'aff': [],
 'novel': [],
 'PPO': [],
 'rewards_live_p90': [],
 'rewards_live_max': []}

### Compute Reward

After we sample a batch, we enter the `compute_reward` stage. This consists of the following events:

- `before_compute_reward` - used to set up any values needed for reward computation 
- `compute_reward` - used by `rewards` to compute rewards for all samples in the `batch state`
- `after_compute_reward` - used for logging reward metrics
- `reward_modification` - modify rewards in ways not tracked by the log
- `after_reward_modification` - log reward modification metrics

#### before_compute_reward

This event can be used to set up any values needed for reward computation. Most rewards only need the raw samples as inputs, but rewards can use other inputs if needed. The only requirement for a reward is that it returns a tensor with one value per batch item.

By default, the `Agent` class will tensorize the samples present at this step. Our `PPO` loss will also add placeholder values for the terms needed by that function

In [None]:
env('before_compute_reward')

A number of new items have populated the batch state

In [None]:
env.batch_state.keys()

dict_keys(['samples', 'sources', 'rewards', 'loss', 'latent_data', 'live_raw', 'base_raw', 'model_gathered_logprobs', 'base_gathered_logprobs', 'mask', 'trajectory_rewards', 'model_logprobs', 'base_logprobs', 'value_input', 'x', 'y', 'bs', 'lengths', 'sl'])

In [None]:
env.batch_state.x # x tensor

tensor([[ 0, 23, 23,  ...,  2,  2,  2],
        [ 0, 23, 28,  ...,  2,  2,  2],
        [ 0, 23, 31,  ...,  2,  2,  2],
        ...,
        [ 0, 23, 23,  ...,  2,  2,  2],
        [ 0, 23, 31,  ...,  2,  2,  2],
        [ 0, 23, 23,  ...,  2,  2,  2]], device='cuda:0')

In [None]:
env.batch_state.y # y tensor

tensor([[23, 23, 11,  ...,  2,  2,  2],
        [23, 28, 34,  ...,  2,  2,  2],
        [23, 31, 23,  ...,  2,  2,  2],
        ...,
        [23, 23, 31,  ...,  2,  2,  2],
        [23, 31, 23,  ...,  2,  2,  2],
        [23, 23,  5,  ...,  2,  2,  2]], device='cuda:0')

In [None]:
env.batch_state.mask # padding mask

tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]], device='cuda:0')

#### compute_reward

This step actually computes rewards. The `BatchState` has a tensor of 0s as a placeholder for reward values. Rewards will compute a numeric score for each item in the batch and add it to `BatchState.rewards`

In [None]:
env.batch_state.rewards

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.], device='cuda:0')

In [None]:
env('compute_reward')

In [None]:
env.batch_state.rewards

tensor([ -4.9207,  -6.4757, -10.4049,  -4.2980,  -7.7184,  -3.4992,  -8.7251,
         -6.5808,  -1.9598,  -5.0743,  -5.3781,  -6.2712,   0.4300,  -8.3497,
         -8.6163,  -3.9752,  -8.2539,  -4.1198, -14.1471,   3.9196,  -7.4355,
         -5.7973,  -2.4372,  -4.5564,  -7.3128,  -3.7687, -10.7635,  -6.4387,
        -10.8553,  -8.7834,  -7.7953, -12.2559,  -8.6774,  -3.6228,  -5.5580,
         -7.0956,  -3.0343,  -0.1838,  -7.5708, -10.2341,  -9.2917,  -3.3889,
         -7.2033,  -2.3414,  -0.5930,  -6.9157,  -9.9772, -17.4336,  -4.8487,
         -6.6143,   0.8611,  -6.0077,  -7.2941,   0.7670,  -8.1461,  -2.2909,
         -4.9466, -12.3601,  -5.5133,  -6.9176,  -2.6929, -11.1348,  -6.7806,
         -5.8056, -11.1539,  -7.9624,  -2.9736,  -7.4745,  -7.9849, -13.5966,
         -4.9899,  -0.3546,  -8.9225,  -5.8561,  -4.5420,  -0.6707,  -7.9550,
         -2.6196,  -7.2471,  -3.8212,  -4.4772,   0.4433, -13.9170,  -5.6456,
         -8.0470,  -8.7309,  -1.2033,  -8.6667,  -5.4399,  -5.15

So where did these rewards come from?

One reward term comes from our `Template`. We specified soft rewards for compounds with `QED>=0.5` and `SA<=5`. Compounds could score a maximum of 2 from the template.

We also have the reward from the erbB1 regression model we set up earlier.

The specific rewards from each of these sources are logged in the `BatchState`

For the `Template`, we have `BatchState.template` and `BatchState.template_passes`

In [None]:
env.batch_state.keys()

dict_keys(['samples', 'sources', 'rewards', 'loss', 'latent_data', 'live_raw', 'base_raw', 'model_gathered_logprobs', 'base_gathered_logprobs', 'mask', 'trajectory_rewards', 'model_logprobs', 'base_logprobs', 'value_input', 'x', 'y', 'bs', 'lengths', 'sl', 'template', 'template_passes', 'aff'])

Template scores:

In [None]:
env.batch_state.template

array([2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 1., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 1., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1., 2., 2., 2., 1., 2.,
       2., 2., 2., 2., 2., 2., 2.])

`BatchState.template_passes` shows which samples passed the hard filters. Since we decided to prefilter with our template earlier, all remaining samples are passing

In [None]:
env.batch_state.template_passes

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,

And here we have the erbB1 regression scores

In [None]:
env.batch_state.aff

tensor([ -6.9207,  -8.4757, -12.4049,  -6.2980,  -9.7184,  -5.4992, -10.7251,
         -8.5808,  -3.9598,  -7.0743,  -7.3781,  -8.2712,  -1.5700, -10.3497,
        -10.6163,  -5.9752, -10.2539,  -6.1198, -16.1471,   1.9196,  -9.4355,
         -7.7973,  -4.4372,  -6.5564,  -9.3128,  -4.7687, -12.7635,  -8.4387,
        -12.8553, -10.7834,  -9.7953, -14.2559, -10.6774,  -5.6228,  -7.5580,
         -9.0956,  -5.0343,  -2.1838,  -9.5708, -12.2341, -11.2917,  -5.3889,
         -9.2033,  -4.3414,  -2.5930,  -7.9157, -11.9772, -19.4336,  -6.8487,
         -8.6143,  -1.1389,  -8.0077,  -9.2941,  -1.2330, -10.1461,  -4.2909,
         -6.9466, -14.3601,  -7.5133,  -8.9176,  -4.6929, -13.1348,  -8.7806,
         -7.8056, -13.1539,  -9.9624,  -4.9736,  -9.4745,  -9.9849, -15.5966,
         -6.9899,  -2.3546, -10.9225,  -6.8561,  -6.5420,  -2.6707,  -9.9550,
         -4.6196,  -9.2471,  -5.8212,  -6.4772,  -1.5567, -15.9170,  -7.6456,
         -9.0470, -10.7309,  -3.2033, -10.6667,  -7.4399,  -7.15
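
Comparing the printed tensors suggests that each total reward is the template score plus the `aff` score. The check below uses three entries copied from the outputs above (indices 0, 1 and 25); that the total is always a plain sum of the logged components is an assumption drawn from these printed values.

```python
# Values copied from the printed outputs above (indices 0, 1 and 25)
template_scores = [2.0, 2.0, 1.0]
aff_scores = [-6.9207, -8.4757, -4.7687]
printed_rewards = [-4.9207, -6.4757, -3.7687]

# Assumed relationship: total reward = template score + aff score
recomputed = [t + a for t, a in zip(template_scores, aff_scores)]

for total, expected in zip(recomputed, printed_rewards):
    print(round(total, 4), expected)
```

The third entry is one of the compounds that earned only one of the two soft-filter bonuses, which is why its total is 1 above its `aff` score rather than 2.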

#### after_compute_reward

This event is used to calculate metrics on the rewards

In [None]:
env('after_compute_reward')

In [None]:
env.log.metrics

{'rewards': [-6.3607645],
 'rewards_final': [],
 'new': [1.0],
 'diversity': [1.0],
 'bs': [194],
 'template': [1.9639175257731958],
 'valid': [0.97],
 'live_diversity': [1.0],
 'live_valid': [0.94],
 'live_rewards': [-6.473961],
 'live_new': [1.0],
 'aff': [array(-8.324681, dtype=float32)],
 'novel': [],
 'PPO': [],
 'rewards_live_p90': [-2.212899112701416],
 'rewards_live_max': [3.9196339]}

#### reward_modification

The reward modification event can be thought of as a second reward that isn't logged. The reason for including this is to allow for transient, "batch context" rewards that don't affect logged values.

When we set up our callbacks earlier, we defined

`new_cb = NoveltyReward(weight=0.05)`

which adds a bonus score of 0.05 to new, never before seen samples. The point of this callback is to give the model a soft incentive to generate novel samples.

We want this score to impact our current batch. However, if we treated it the same as our actual rewards, the samples would be saved into `env.log` with their scores inflated by 0.05. Later, when our `LogSampler` samples from the log, the sampling would be influenced by a score that was only supposed to be given once.

Separating rewards from reward modifications lets us avoid this.
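
To make the separation concrete, here is a minimal sketch in plain Python. It is not the `NoveltyReward` implementation: the helper `novelty_bonus`, the `seen` set, and the toy values are all hypothetical. It only illustrates keeping a transient bonus out of the values that get stored.

```python
# Hypothetical sketch: the names below are illustrative, not the library's API.

def novelty_bonus(samples, seen, weight=0.05):
    """Return `weight` for each sample not already in `seen`, else 0.0."""
    return [weight if s not in seen else 0.0 for s in samples]

seen = {"CCO"}                          # samples already recorded in the log
samples = ["CCO", "CCN", "c1ccccc1"]
rewards = [-6.0, -7.5, -5.0]            # rewards to be stored, left untouched

bonus = novelty_bonus(samples, seen)

# The signal used for the current batch includes the transient bonus...
training_signal = [r + b for r, b in zip(rewards, bonus)]

# ...while the stored values are the unmodified rewards, so later
# sampling from the log is not skewed by a one-time bonus.
log = dict(zip(samples, rewards))
```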

In [None]:
env('reward_modification')

In [None]:
env.batch_state.novel

tensor([0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500,
        0.0500, 0.0500, 0.0500, 0.0500, 

#### after_reward_modification

Similar to `after_compute_reward`, this event can be used to compute stats on reward modifications.

In [None]:
env('after_reward_modification')

In [None]:
env.log.metrics

{'rewards': [-6.3607645],
 'rewards_final': [-6.3607645],
 'new': [1.0],
 'diversity': [1.0],
 'bs': [194],
 'template': [1.9639175257731958],
 'valid': [0.97],
 'live_diversity': [1.0],
 'live_valid': [0.94],
 'live_rewards': [-6.473961],
 'live_new': [1.0],
 'aff': [array(-8.324681, dtype=float32)],
 'novel': [array(0.05, dtype=float32)],
 'PPO': [],
 'rewards_live_p90': [-2.212899112701416],
 'rewards_live_max': [3.9196339]}

### Get Model Outputs

After computing rewards, we set up the loss calculation. The `get_model_outputs` stage generates the values that we will backpropagate through. This stage consists of the following events:

- `get_model_outputs` - generate necessary tensors from the model
- `after_get_model_outputs` - used for any processing required prior to loss calculation 

#### get_model_outputs

This is where we generate tensor values used for loss computation.

The specifics of what happens here depend on the type of model used. For autoregressive models, this step involves taking the `x` and `y` tensors we generated during the `before_compute_reward` event and doing a forward pass.

`x` is a tensor of size `(bs, sl)`. Running `x` through the model will give a set of log probabilities of size `(bs, sl, d_vocab)`. We then use `y` to gather the relevant log probs to get a gathered log prob tensor of size `(bs, sl)`.

We generate these values from both the main model and the baseline model.
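
The gather step can be sketched with nested Python lists so that the shapes stay explicit. The logits and targets below are made-up toy values with `bs=2`, `sl=3`, `d_vocab=4`, and the helper `log_softmax` is written out by hand for this example rather than taken from the library.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]

# Toy logits, shape (bs=2, sl=3, d_vocab=4).
logits = [
    [[2.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 3.0]],
    [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], [0.0, 2.0, 0.0, 0.0]],
]
y = [[0, 1, 3], [2, 0, 1]]  # target token ids, shape (bs, sl)

# Full log-probability tensor, shape (bs, sl, d_vocab).
logprobs = [[log_softmax(step) for step in seq] for seq in logits]

# Pick out the log-prob of each target token, shape (bs, sl).
gathered = [
    [logprobs[b][t][y[b][t]] for t in range(len(y[b]))]
    for b in range(len(y))
]
```

With PyTorch tensors, the same selection is typically written as `logprobs.gather(2, y.unsqueeze(-1)).squeeze(-1)`.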

In [None]:
env('get_model_outputs')

In [None]:
env.batch_state.keys()

dict_keys(['samples', 'sources', 'rewards', 'loss', 'latent_data', 'live_raw', 'base_raw', 'model_gathered_logprobs', 'base_gathered_logprobs', 'mask', 'trajectory_rewards', 'model_logprobs', 'base_logprobs', 'value_input', 'x', 'y', 'bs', 'lengths', 'sl', 'template', 'template_passes', 'aff', 'rewards_final', 'novel', 'model_output', 'model_encoded', 'model_latent', 'y_gumbel', 'base_output', 'base_encoded', 'base_latent', 'state_values', 'ref_state_values'])

In [None]:
env.batch_state.model_logprobs.shape, env.batch_state.model_gathered_logprobs.shape

(torch.Size([194, 75, 47]), torch.Size([194, 75]))

#### after_get_model_outputs

This event is not used by any of our current callbacks, but it can be used for any post-processing needed before loss computation.

### Compute Loss

Now we actually compute a loss value and do an optimizer update. See the `PPO` class for a description of the policy gradient algorithm used.

Loss computation consists of the following steps:

- `compute_loss` - compute loss values
- `zero_grad` - zero the gradients of all optimizers
- `before_step` - used for computation before the optimizer step (e.g. gradient clipping)
- `step` - step optimizer
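
The way these events reach callbacks can be sketched with a toy dispatcher. The classes below (`Callback`, `LossRecorder`, `GradClipper`) are hypothetical stand-ins written for this example, not the library's classes. They only illustrate the rule that a callback responds to an event when it defines a method with that name.

```python
# Hypothetical stand-ins for illustration; not the library's classes.

class Callback:
    def __call__(self, event_name):
        # Run the method named after the event, if this callback defines one.
        handler = getattr(self, event_name, None)
        if handler is not None:
            handler()

class LossRecorder(Callback):
    def __init__(self, trace):
        self.trace = trace

    def compute_loss(self):
        self.trace.append("compute_loss")

    def step(self):
        self.trace.append("step")

class GradClipper(Callback):
    def __init__(self, trace):
        self.trace = trace

    def before_step(self):
        self.trace.append("before_step")

trace = []
callbacks = [LossRecorder(trace), GradClipper(trace)]

for event in ["compute_loss", "zero_grad", "before_step", "step"]:
    for cb in callbacks:
        cb(event)  # callbacks without a matching method ignore the event
```

In this toy setup no callback defines `zero_grad`, so that event passes through without effect.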

#### compute_loss

When we first created our `BatchState`, there was a placeholder value for `loss`. This is the value that will ultimately be backpropagated through. This means we can run any sort of loss configuration, so long as the final values end up in `BatchState.loss`.

For example, the `PPO` policy gradient algorithm we are using involves a `ValueHead` that predicts values at every time step. This model is held in the `PolicyLoss` callback, which also holds the `PPO` class. During the `compute_loss` event, `PPO` computes an additional loss for the value head that is added to `BatchState.loss`. `PolicyLoss` also holds an optimizer for the `ValueHead` parameters.
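
The accumulation pattern can be sketched as follows. The classes `ToyBatchState`, `PolicyTerm`, and `ValueTerm` and their loss values are invented for this example, and plain floats stand in for tensors. The point is only that several loss sources can add into one shared `loss` slot.

```python
# Hypothetical sketch: class names and numbers are made up for illustration.

class ToyBatchState:
    def __init__(self):
        self.loss = 0.0  # placeholder that each loss term adds to

class PolicyTerm:
    def compute_loss(self, state):
        state.loss += 1.25  # made-up policy loss value

class ValueTerm:
    def compute_loss(self, state):
        state.loss += 0.5   # made-up value-head loss value

state = ToyBatchState()
for term in [PolicyTerm(), ValueTerm()]:
    term.compute_loss(state)

# state.loss now holds the sum of both terms and is the single
# value that a backward pass would start from.
```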

In [None]:
env.batch_state.loss

tensor(0., device='cuda:0', grad_fn=<CopyBackwards>)

In [None]:
env('compute_loss')

In [None]:
env.batch_state.loss

tensor(2.6813, device='cuda:0', grad_fn=<AddBackward0>)

#### zero_grad

This event zeroes the gradients of all optimizers in play. We currently have one optimizer in `Agent` for our generative model and one in `PolicyLoss` for the `ValueHead` of our policy gradient algorithm. After zeroing, the cell below calls `backward` on the loss to populate the gradients.

In [None]:
env('zero_grad')
env.batch_state.loss.backward()

#### before_step

This event runs just before the optimizer step. It is used for things like gradient clipping.

In [None]:
env('before_step')
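
As an illustration of what a `before_step` handler might do, here is a minimal sketch of gradient clipping by global L2 norm on a plain list of floats. The function `clip_grad_norm` is written for this example and is not the library's code.

```python
import math

def clip_grad_norm(grads, max_norm):
    """Scale `grads` in place so their global L2 norm is at most `max_norm`.

    Returns the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        for i in range(len(grads)):
            grads[i] *= scale
    return total_norm

grads = [3.0, 4.0]  # toy gradient vector with L2 norm 5.0
norm_before = clip_grad_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in grads))
```

For real model parameters, PyTorch provides `torch.nn.utils.clip_grad_norm_`, which applies the same kind of rescaling to parameter gradients.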

#### step

This is the actual optimizer step. This will step both the `Agent` and `PolicyLoss` optimizers.

In [None]:
env('step')

### After Batch

The `after_batch` stage consists of a single `after_batch` event. This is used for any updates at the end of the batch.

In particular, the `Log` will update `Log.df` and the `Agent` will update the baseline model.

In [None]:
env('after_batch')

In [None]:
env.log.df

Unnamed: 0,samples,sources,rewards,rewards_final,template,aff,novel,PPO
0,CC1CCN(C(=O)C2CCN(C(=O)Nc3nc(-c4ccccc4)cs3)CC2...,live_buffer,-4.870742,-4.920742,2.0,-6.920742,0.05,0.210345
1,COc1ccnc(O)c1C(=O)NCC1CC(NC(=O)CCC(F)F)C1,live_buffer,-6.425727,-6.475727,2.0,-8.475727,0.05,0.053149
2,C[C@@H]([C@@H](C)NC(=O)Nc1cccc(S(N)(=O)=O)c1)N...,base_buffer,-10.354889,-10.404889,2.0,-12.404889,0.05,4.652449
3,CCN1C(=O)[C@@H]2CN(C(=O)Nc3ccc(OC(F)F)cn3)CCN2...,base_buffer,-4.247975,-4.297976,2.0,-6.297976,0.05,0.487112
4,CC#CCN[C@@H]1CCN(C(=O)CN(C)S(=O)(=O)CC)C1,live_buffer,-7.668355,-7.718355,2.0,-9.718355,0.05,0.708384
...,...,...,...,...,...,...,...,...
189,CCO[C@@H]1C[C@@H]1C(=O)NC1CCN(CC(=O)Nc2ccccc2)CC1,live,-5.953860,-6.003860,2.0,-8.003860,0.05,-0.006749
190,CCn1cc(C(=O)N2C[C@@H](NC(=O)c3ccncc3)[C@H](C)C...,live,-8.985070,-9.035070,2.0,-11.035070,0.05,2.448218
191,CC[C@H](NC(=O)C(=O)N1CCC[C@H]1c1ccc(OC)c(OC)c1...,live,-6.859601,-6.909601,2.0,-8.909601,0.05,0.219688
192,C[C@H](N)[C@H]1CCCN(C(=O)c2ccccc2Br)C1,live,-7.184169,-7.234169,2.0,-9.234169,0.05,0.379862


### After Train

The `after_train` event can be used to calculate any final statistics or other values as desired.

In [None]:
env('after_train')

### Conclusions

Hopefully, walking through the training process step by step has made the process more understandable. We conclude by simply running `Environment.fit`, so we don't have to go through things step by step anymore.

In [None]:
env.fit(200, 90, 50, 2)

iterations,rewards,rewards_final,new,diversity,bs,template,valid,live_diversity,live_valid,live_rewards,live_new,aff,novel,PPO,rewards_live_p90,rewards_live_max
2,-6.198,-6.198,1.0,1.0,189,1.979,0.945,1.0,0.89,-6.663,1.0,-8.177,0.05,2.885,-2.415,1.224
4,-5.958,-5.958,1.0,1.0,185,1.951,0.925,1.0,0.85,-5.662,1.0,-7.91,0.05,2.987,-1.025,3.383
6,-6.115,-6.115,1.0,1.0,185,1.957,0.925,1.0,0.85,-6.001,1.0,-8.072,0.05,2.641,-1.548,3.829
8,-5.425,-5.425,1.0,1.0,196,1.98,0.98,1.0,0.96,-5.028,1.0,-7.404,0.05,3.626,-0.638,5.833
10,-5.735,-5.735,1.0,1.0,193,1.943,0.965,1.0,0.93,-5.702,1.0,-7.678,0.05,2.803,-0.664,2.142
12,-5.602,-5.602,1.0,1.0,194,1.959,0.97,1.0,0.94,-5.945,1.0,-7.561,0.05,2.916,-1.131,3.825
14,-6.294,-6.294,1.0,1.0,192,1.958,0.96,1.0,0.92,-6.859,1.0,-8.253,0.05,2.827,-1.196,3.223
16,-6.021,-6.021,1.0,1.0,193,1.964,0.965,1.0,0.93,-6.118,1.0,-7.985,0.05,2.55,-1.561,4.166
18,-5.939,-5.939,0.979,1.0,187,1.957,0.935,1.0,0.87,-6.474,1.0,-7.896,0.049,3.5,-0.362,4.195
20,-5.93,-5.93,0.937,1.0,189,1.958,0.945,1.0,0.89,-6.755,1.0,-7.887,0.047,5.162,-1.946,3.934
