In [None]:
# hide
# all_tutorial

# Tutorial - RL Train Cycle Overview

>Overview of the RL training cycle

## RL Train Cycle Overview

The goal of this tutorial is to walk through the RL fit cycle to familiarize ourselves with the `Events` cycle and get a better understanding of how `Callback` and `Environment` classes work.

## High Level Overview

### The Environment

At the highest level, we have the `Environment` class. The `Environment` holds together several sub-modules and orchestrates them during the fit loop. The following are contained in the `Environment`:
- `agent` - This is the actual model we're training
- `template_cb` - this holds a `Template` class that we use to define our chemical space
- `samplers` - samplers generate new samples to train on
- `buffer` - the buffer collects and distributes samples from all the `samplers`
- `rewards` - rewards score samples
- `losses` - losses generate values we can backpropagate through 
- `log` - the log holds a record of all samples in the training process

### Callbacks and the Event Cycle

Each one of the above items is a `Callback`. A `Callback` is a general class that can hook into the `Environment` fit cycle at a number of pre-defined `Events`. When the `Environment` calls a specific `Event`, the event name is passed to every callback in the `Environment`. If a given `Callback` defines a method named after the event, that method is called. This creates a very flexible system for customizing training loops.

We'll be looking more at `Events` later. For now, we'll just list them in brief. These are the events called during the RL training cycle in the order they are executed:

- `setup` - called when the `Environment` is created, used to set up values
- `before_train` - called before training is started
- `build_buffer` - draws samples from `samplers` into the `buffer`
- `filter_buffer` - filters samples in the buffer
- `after_build_buffer` - called after buffer filtering. Used for cleanup, logging, etc.
- `before_batch` - called before a batch starts, used to set up the `batch state`
- `sample_batch` - samples are drawn from `samplers` and `buffer` into the `batch state`
- `before_filter_batch` - allows preprocessing of samples before filtering
- `filter_batch` - filters samples in `batch state`
- `after_sample` - used for calculating sampling metrics
- `before_compute_reward` - used to set up any values needed for reward computation 
- `compute_reward` - used by `rewards` to compute rewards for all samples in the `batch state`
- `after_compute_reward` - used for logging reward metrics
- `reward_modification` - modify rewards in ways not tracked by the log
- `get_model_outputs` - generate necessary tensors from the model
- `after_get_model_outputs` - used for any processing required prior to loss calculation 
- `compute_loss` - compute loss values
- `zero_grad` - zero the gradients
- `before_step` - used for computation before the optimizer step (e.g. gradient clipping)
- `step` - step optimizer
- `after_batch` - compute batch stats
- `after_train` - final event after all training batches

In [None]:
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *

from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *

  return f(*args, **kwds)


In [None]:
# Size of the global worker pool. Presumably used by helpers such as
# `maybe_parallel` for parallel fingerprinting — TODO confirm.
n_pool_workers = 10
set_global_pool(n_pool_workers)

In [None]:
# Filters given without a `score` argument: validity, single compound,
# rotatable bonds, charge. The (None, x) pairs presumably read as
# (min, max) bounds — TODO confirm against the filter signatures.
constraint_filters = [ValidityFilter(),
                      SingleCompoundFilter(),
                      RotBondFilter(None, 8),
                      ChargeFilter(None, 0)]

# Filters that carry an explicit `score=1.` (QED and SA).
score_filters = [QEDFilter(0.5, None, score=1.),
                 SAFilter(None, 5, score=1.)]

template = Template(constraint_filters, score_filters,
                    fail_score=-10., log=False)

# Callback wrapper around the template; this is what gets handed to `Environment`.
template_cb = TemplateCallback(template, prefilter=True)

In [None]:
class FP_Regression_Score():
    '''Reward function that scores SMILES samples with a fingerprint regressor.

    Each sample is converted to a molecule and fingerprinted. The fingerprints
    are stacked into a 2D array and passed to ``self.model.predict``.

    Args:
        fname: path of the serialized model. It is loaded with ``torch.load``
            (i.e. unpickled) and must expose a ``predict`` method. Presumably
            a scikit-learn estimator, given the ``.sklearn`` extension used in
            this notebook.
        fp_function: fingerprint function wrapped by ``failsafe_fp``.
            ``None`` (the default) uses ``ECFP6``.
    '''
    def __init__(self, fname, fp_function=None):
        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code. Only load model files from trusted sources.
        try:
            # PyTorch >= 2.6 defaults to weights_only=True, which refuses to
            # unpickle arbitrary Python objects such as sklearn estimators.
            self.model = torch.load(fname, weights_only=False)
        except TypeError:
            # PyTorch < 1.13 has no `weights_only` argument.
            self.model = torch.load(fname)
        # Resolve the default here rather than in the signature, so that
        # defining the class does not require ECFP6 to be in scope.
        if fp_function is None:
            fp_function = ECFP6
        self.fp_function = partial(failsafe_fp, fp_function=fp_function)

    def __call__(self, samples):
        '''Return the model's predictions for ``samples``.

        An empty ``samples`` returns an empty float array. Without this
        check, ``np.stack`` would raise on an empty list of fingerprints.
        '''
        if len(samples) == 0:
            return np.zeros(0)
        mols = to_mols(samples)
        fps = maybe_parallel(self.fp_function, mols)
        x_vals = np.stack([fp_to_array(fp) for fp in fps])
        return self.model.predict(x_vals)
    
reward_model_path = 'files/erbB1_regression.sklearn'
reward_function = FP_Regression_Score(reward_model_path)

# Wrap the regressor as a Reward weighted by 10, then expose it as a callback
# named 'aff'. That name shows up as the `aff` column of the training log.
reward = Reward(reward_function, weight=10.)
aff_reward = RewardCallback(reward, 'aff')



In [None]:
# Keyword hyperparameters for PPO. The two positional arguments are kept as-is;
# 0.99 is presumably the discount factor — TODO confirm against PPO's signature.
ppo_kwargs = dict(lam=0.95,
                  v_coef=0.5,
                  cliprange=0.3,
                  v_cliprange=0.3,
                  ent_coef=0.01,
                  kl_target=0.03,
                  kl_horizon=3000,
                  scale_rewards=True)
pg = PPO(0.99, 0.5, **ppo_kwargs)

# Policy loss registered under the name 'PPO' (the `PPO` column in the log).
# ValueHead(256): 256 presumably matches the model's hidden size — TODO confirm.
loss = PolicyLoss(pg, 'PPO',
                  value_head=ValueHead(256),
                  v_update_iter=2,
                  vopt_kwargs=dict(lr=1e-3))

In [None]:
# Agent from the model zoo. `opt_kwargs` is presumably forwarded to the
# agent's optimizer — TODO confirm.
agent = LSTM_LM_Small_ZINC(
    drop_scale=0.5,
    opt_kwargs=dict(lr=5e-5),
)

In [None]:
# Generation batch size shared by both model samplers.
gen_bs = 1500

# Load the affinity dataset and keep only the rows whose `value` is below -1.
affinity_raw = pd.read_csv('untracked_files/affinity_data_set.csv')
df = affinity_raw.loc[affinity_raw['value'] < -1]

# Sample sources: the live model, the base model, the training log,
# and the SMILES of the filtered dataset.
sampler1 = ModelSampler(agent.vocab, agent.model, 'live',
                        1000, 0.5, gen_bs)
sampler2 = ModelSampler(agent.vocab, agent.base_model, 'base',
                        1000, 0., gen_bs)
sampler3 = LogSampler('samples', 'rewards', 10, 95, 100)
sampler4 = DatasetSampler(df['smiles'].values, 'erbB1_data', buffer_size=4)

samplers = [sampler1, sampler2, sampler3, sampler4]

In [None]:
# Positional settings for SupevisedCB are kept unchanged — TODO document them.
supervised_cb = SupevisedCB(agent, 200, 0.5, 97, 5e-5, 64)

# Track the max and 90th percentile of 'rewards' over the 'live' samples.
# These appear as rewards_live_max / rewards_live_p90 in the training log.
live_max = MaxCallback('rewards', 'live')
live_p90 = PercentileCallback('rewards', 'live', 90)

new_cb = NoveltyReward(weight=0.05)

cbs = [
    new_cb,
    supervised_cb,
    live_p90,
    live_max,
]

In [None]:
# Assemble the agent, template callback, samplers, reward, loss and extra
# callbacks into a single Environment.
env = Environment(
    agent,
    template_cb,
    samplers=samplers,
    rewards=[aff_reward],
    losses=[loss],
    cbs=cbs,
)

In [None]:
# Positional args are presumably (batch size, max sequence length, iterations,
# report frequency). The logged output shows ~190-sample batches and
# iterations 0-8 reported every 2 steps — TODO confirm against Environment.fit.
fit_args = (200, 90, 10, 2)
env.fit(*fit_args)

iterations,rewards,rewards_final,new,diversity,bs,template,valid,live_diversity,live_valid,live_rewards,live_new,aff,novel,PPO,rewards_live_p90,rewards_live_max
0,-6.118,-6.118,1.0,1.0,190,1.968,0.95,1.0,0.9,-6.763,1.0,-8.087,0.05,3.216,-1.674,3.496
2,-6.28,-6.28,1.0,1.0,188,1.968,0.94,1.0,0.88,-5.995,1.0,-8.249,0.05,2.319,-2.585,6.606
4,-6.139,-6.139,1.0,1.0,188,1.973,0.94,1.0,0.88,-6.146,1.0,-8.113,0.05,2.588,-2.053,1.301
6,-6.07,-6.07,1.0,1.0,192,1.984,0.96,1.0,0.92,-5.836,1.0,-8.055,0.05,2.991,-0.729,7.122
8,-6.53,-6.53,1.0,1.0,189,1.958,0.945,1.0,0.89,-5.979,1.0,-8.488,0.05,2.663,-2.031,4.096
