In [None]:
#hide
#skip
! [ -e /content ] && pip install -Uqq mrl-pypi  # upgrade mrl on colab

In [None]:
# default_exp train.reward

# Reward

> Rewards - non-differentiable scores for samples

## Overview

Rewards are non-differentiable score functions for evaluating samples. Rewards should generally follow the format `reward = reward_function(sample)`

Rewards in MRL occupy five events in the fit loop:
- `before_compute_reward` - set up necessary values prior to reward calculation (if needed)
- `compute_reward` - compute reward
- `after_compute_reward` - compute metrics (if needed)
- `reward_modification` - adjust rewards
- `after_reward_modification` - compute metrics (if needed)

### Rewards vs Reward Modifications

MRL breaks rewards up into two phases - rewards and reward modifications. The difference between the two phases is that __reward__ values are saved in the batch log, while __reward_modifications__ are not.

In this framework, rewards are absolute scores for samples that are used to evaluate the sample relative to all other samples in the log. Reward modifications are transient scores that depend on the current training context.

A reward modification might be something like adding a score bonus to compounds the first time they are created during training to encourage diversity, or penalizing compounds if they appear more than 3 times in the last 5 batches. These types of reward modifications allow us to influence the behavior of the generative model without having these scores affect the true rewards we save in the log.

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
# export

from mrl.imports import *
from mrl.core import *
from mrl.train.callback import *
from mrl.torch_imports import *
from mrl.torch_core import *

## Reward Class

As mentioned above, rewards generally follow the format `reward = reward_function(sample)`. The `Reward` class acts as a wrapper around the `reward_function` to provide some convenience functions. `Reward` maintains a lookup table of `sample : reward` values to avoid repeat computation. `Reward` handles batching novel samples (ie not in the lookup table), sending them to `reward_function`, and merging the outputs with the lookup table values.

Creating a custom reward involves creating a callable function or object that can take in a list of `samples` and return a list of reward values. For example:

```
class MyRewardFunction():
    def __call__(self, samples):
        rewards = self.do_reward_calculation(samples)
        return rewards
        
reward_function = MyRewardFunction()
reward = Reward(reward_function, weight=0.5, log=True)
```


In [None]:
# export

class Reward():
    '''
    Reward - caching, batching wrapper around a `reward_function`
    
    Inputs:
    
    - `reward_function Callable`: callable following 
    `rewards = reward_function(samples)`
    
    - `weight float`: multiplier applied to the returned rewards
    
    - `bs Optional[int]`: if given, samples that need scoring are 
    split into chunks of size `bs` before being sent to `reward_function`
    
    - `device bool`: if True, the output tensor is passed 
    through `to_device`
    
    - `log bool`: if True, a `sample : reward` lookup table is 
    kept so samples already in the table are not scored again
    '''
    def __init__(self, reward_function, weight=1, bs=None, device=False, log=True):
        # lookup table of unweighted `sample : reward` values
        self.score_log = {}
        self.log = log
        
        self.reward_function = reward_function
        self.weight = weight
        self.bs = bs
        self.device = device
        
    def load_data(self, samples, values):
        '''Insert `samples[i] : values[i]` pairs into the lookup table'''
        for idx, sample in enumerate(samples):
            self.score_log[sample] = values[idx]
            
    def __call__(self, samples, **reward_kwargs):
        '''
        Score `samples`, reusing lookup table values when `log` is set. 
        Returns a float tensor of rewards scaled by `weight`
        '''
        # unweighted scores, one slot per sample
        scores = np.zeros(len(samples))
        
        pending = []       # samples that must be sent to the reward function
        pending_idxs = []  # positions of those samples in `samples`
        
        for idx, sample in enumerate(samples):
            if self.log and sample in self.score_log:
                scores[idx] = self.score_log[sample]
            else:
                pending.append(sample)
                pending_idxs.append(idx)
                
        if pending:
            fresh = self.compute_batched_reward(pending, **reward_kwargs)
            
            for k, (idx, sample) in enumerate(zip(pending_idxs, pending)):
                value = fresh[k]
                scores[idx] = value
                
                # the table stores the unweighted value
                if self.log:
                    self.score_log[sample] = value
                    
        output = torch.tensor(scores).float().squeeze() * self.weight
        
        return to_device(output) if self.device else output
            
    def _compute_reward(self, samples, **reward_kwargs):
        '''Direct call to the wrapped `reward_function`'''
        return self.reward_function(samples, **reward_kwargs)
    
    def compute_batched_reward(self, samples, **reward_kwargs):
        '''
        Run `reward_function` on `samples`, in chunks of `bs` when `bs` 
        is set. Tensor outputs are detached and moved to the CPU
        '''
        if self.bs is None:
            result = self._compute_reward(samples, **reward_kwargs)
            if isinstance(result, torch.Tensor):
                result = result.detach().cpu()
            return result
        
        collected = []
        for chunk in chunk_list(samples, self.bs):
            chunk_result = self._compute_reward(chunk, **reward_kwargs)
            if isinstance(chunk_result, torch.Tensor):
                chunk_result = chunk_result.detach().cpu()
                
            collected.extend(chunk_result)
            
        return collected
    
    def add_data_to_log(self, samples, rewards):
        '''Insert `samples[i] : rewards[i]` pairs into the lookup table'''
        for idx, sample in enumerate(samples):
            self.score_log[sample] = rewards[idx]


## Reward Callback

`RewardCallback` handles fit loop integration and metric logging for a given `Reward`

In [None]:
# export

class RewardCallback(Callback):
    '''
    RewardCallback - runs a `Reward` during the 
    `compute_reward` event
    
    Inputs:
    
    - `reward Reward`: reward applied to the samples
    
    - `name str`: reward name, used as the log, metric 
    and `BatchState` key
    
    - `sample_name str`: key of the samples in 
    `BatchState` that are sent to `reward`
    
    - `order int`: callback order
    
    - `track bool`: if True, the mean reward is 
    recorded as a metric
    '''
    def __init__(self, reward, name, sample_name='samples',
                order=10, track=True):
        super().__init__(name=name, order=order)
        
        self.track = track
        self.sample_name = sample_name
        self.reward = reward
        
    def setup(self):
        '''Register the log entry (and metric, if tracked) under `name`'''
        logger = self.environment.log
        logger.add_log(self.name)
        if not self.track:
            return
        logger.add_metric(self.name)
            
    def compute_reward(self):
        '''Score the batch samples and add the result to `batch_state.rewards`'''
        environment = self.environment
        state = environment.batch_state
        batch_samples = state[self.sample_name]
        
        # with no samples, fall back to a scalar zero tensor
        scores = self.reward(batch_samples) if batch_samples else to_device(torch.tensor(0.))

        state.rewards += scores
        state[self.name] = scores
        
        if self.track:
            mean_score = scores.mean().detach().cpu().numpy()
            environment.log.update_metric(self.name, mean_score)

For greater flexibility, `GenericRewardCallback` will pass the entire `BatchState` to `reward`

In [None]:
# export

class GenericRewardCallback(RewardCallback):
    '''
    GenericRewardCallback - reward callback whose 
    `reward` receives the whole batch state
    
    Inputs:
    
    - `reward Callable`: reward function, called as 
    `reward(batch_state)`
    
    - `name str`: reward name
    
    - `order int`: callback order
    
    - `track bool`: if True, the mean reward is 
    recorded as a metric
    '''
    def __init__(self, reward, name, 
                order=10, track=True):
        super().__init__(reward, name, order=order, track=track)
        
    def compute_reward(self):
        '''Call `reward` on the batch state and add the result to `batch_state.rewards`'''
        environment = self.environment
        state = environment.batch_state
        scores = self.reward(state)
        
        state.rewards += scores
        state[self.name] = scores
        
        if not self.track:
            return
        
        mean_score = scores.mean().detach().cpu().numpy()
        environment.log.update_metric(self.name, mean_score)

## Reward Modification

As discussed above, reward modifications apply changes to rewards based on some sort of transient batch context. These are rewards that will influence a given batch, but not the logged rewards.

Reward modifications should update the value `BatchState.rewards_final`

In [None]:
# export

class RewardModification(Callback):
    '''
    RewardModification - callback wrapper for `Reward` 
    used during `reward_modification` event
    
    Inputs:
    
    - `reward Reward`: reward to use
    
    - `name str`: reward name
    
    - `sample_name str`: sample name to grab from 
    `BatchState` to send to `reward`
    
    - `order int`: callback order
    
    - `track bool`: if metrics should be tracked 
    from this callback
    '''
    def __init__(self, reward, name, sample_name='samples',
                order=10, track=True):
        super().__init__(name=name, order=order)
        
        self.reward = reward
        self.sample_name = sample_name
        self.track = track
        
    def setup(self):
        '''Register the log entry (and metric, if tracked) under `name`'''
        log = self.environment.log
        log.add_log(self.name)
        if self.track:
            log.add_metric(self.name)
            
    def reward_modification(self):
        '''
        Score the batch samples with `reward` and add the result 
        to `batch_state.rewards_final`
        '''
        env = self.environment
        batch_state = env.batch_state
        samples = batch_state[self.sample_name]
        
        if samples:
            rewards = self.reward(samples)
        else:
            # use a zero tensor (as `RewardCallback` does) rather than a
            # Python float, so the `.mean()` call below works on empty batches
            rewards = to_device(torch.tensor(0.))

        batch_state.rewards_final += rewards
        batch_state[self.name] = rewards
        
        if self.track:
            env.log.update_metric(self.name, rewards.mean().detach().cpu().numpy())

In [None]:
# export

class NoveltyReward(Callback):
    '''
    NoveltyReward - reward modification that adds 
    a bonus of `weight` to every sample not yet 
    present in the `samples` column of the log dataframe
    
    Inputs:
    
    - `weight float`: bonus given to a novel sample
    
    - `track bool`: if True, the mean bonus is 
    recorded as a metric
    '''
    def __init__(self, weight=1., track=True):
        super().__init__(name='novel')
        
        self.track = track
        self.weight = weight
        
    def setup(self):
        '''Register the log entry (and metric, if tracked) under `name`'''
        logger = self.environment.log
        logger.add_log(self.name)
        if not self.track:
            return
        logger.add_metric(self.name)
            
    def reward_modification(self):
        '''Add the novelty bonus to `batch_state.rewards_final`'''
        environment = self.environment
        state = environment.batch_state
        batch_samples = state.samples
        
        logged = environment.log.df
        # True where the sample is absent from the logged samples
        is_novel = (~pd.Series(batch_samples).isin(logged.samples)).values
        
        bonus = is_novel.astype(float) * self.weight
        bonus = to_device(torch.from_numpy(bonus).float())

        state.rewards_final += bonus
        state[self.name] = bonus
        
        if self.track:
            mean_bonus = bonus.mean().detach().cpu().numpy()
            environment.log.update_metric(self.name, mean_bonus)

## Contrastive Reward

Similar to `ContrastiveTemplate`, `ContrastiveReward` provides a wrapper around a `RewardCallback` to adapt it for the task of contrastive generation.

For contrastive generation, we want the model to ingest a source sample and produce a target sample that receives a higher reward than the source sample. `ContrastiveReward` takes some `base_reward` and computes the values of that base reward for both source and target samples, and returns the difference between those rewards.

Optionally, the contrastive reward will scale the relative reward based on a given `max_score` (ie `reward = (target_reward - source_reward)/(max_score - source_reward)`). This scales the contrastive reward relative to the maximum possible reward.


In [None]:
# export

class ContrastiveReward(RewardCallback):
    '''
    ContrastiveReward - contrastive wrapper for 
    reward callbacks. Each sample is a 
    `(source, target)` pair and the reward is 
    `target_reward - source_reward`
    
    Inputs:
    
    - `base_reward RewardCallback`: base reward callback
    
    - `max_score Optional[float]`: maximum possible score. 
    If given, contrastive rewards are scaled following 
    `reward = (target_reward - source_reward)/(max_score - source_reward)`
    '''
    def __init__(self, base_reward, max_score=None):
        super().__init__(reward = base_reward.reward,
                         name = base_reward.name,
                         sample_name = base_reward.sample_name,
                         order = base_reward.order,
                         track = base_reward.track)
        
        self.base_reward = base_reward
        self.max_score = max_score
    
    def setup(self):
        '''Share this callback's environment with the wrapped `base_reward`'''
        self.base_reward.environment = self.environment
        
    def __call__(self, event_name):
        '''
        Run this callback's handler for `event_name` (if it has one), 
        then forward every event except `compute_reward` to `base_reward`
        '''
        event = getattr(self, event_name, None)
        
        if event is not None:
            output = event()
        else:
            output = None
            
        # `compute_reward` is handled here only; the base callback must not
        # add its own (non-contrastive) reward to the batch
        if not event_name=='compute_reward':
            _ = self.base_reward(event_name)
            
        return output
        
    def compute_and_clean(self, samples):
        '''
        Score `samples` with the base reward and return the result 
        as a numpy array with at least one dimension
        '''
        rewards = self.base_reward.reward(samples)
        if isinstance(rewards, torch.Tensor):
            rewards = rewards.detach().cpu()
            
        # `Reward.__call__` squeezes its output, so one sample gives a 0-d
        # value. Arithmetic on 0-d arrays returns numpy scalars, which
        # `torch.from_numpy` rejects, so force at least one dimension
        rewards = np.atleast_1d(np.array(rewards))
        return rewards
        
    def _compute_reward(self, samples):
        '''
        Compute contrastive rewards for `(source, target)` pairs. 
        Returns a float tensor passed through `to_device`
        '''
        source_samples = [i[0] for i in samples]
        target_samples = [i[1] for i in samples]
        
        source_rewards = self.compute_and_clean(source_samples)
        target_rewards = self.compute_and_clean(target_samples)
        
        rewards = target_rewards - source_rewards
        if self.max_score is not None:
            # NOTE(review): a source reward equal to `max_score` divides by
            # zero here (numpy yields inf/nan) - confirm whether this needs
            # to be guarded
            rewards = rewards / (self.max_score-source_rewards)
            
        rewards = to_device(torch.from_numpy(rewards).float())
            
        return rewards
    
    def compute_reward(self):
        '''Compute contrastive rewards and add them to `batch_state.rewards`'''
        env = self.environment
        batch_state = env.batch_state
        samples = batch_state[self.sample_name]
        
        rewards = self._compute_reward(samples)
        
        batch_state.rewards += rewards
        batch_state[self.name] = rewards
        
        if self.track:
            env.log.update_metric(self.name, rewards.mean().detach().cpu().numpy())
    