In [None]:
# default_exp reward

# Reward

> Reward function

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
# export

from mrl.imports import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.chem import *
from mrl.templates import *
from mrl.agent import *
from mrl.environment import *

In [None]:
# export

class Reward(Callback):
    """Base callback that adds one named reward term to the current batch.

    Subclasses implement `_compute_reward`; `compute_reward` squeezes the
    result, adds `weight * reward` to `batch_state.rewards`, stores the
    unweighted reward under `batch_state[name]` and, when `track` is set,
    records the batch mean as a metric.

    Args:
        name: key used for the log entry, the metric and the batch-state slot.
        order: stored on the instance as `order`.
        weight: multiplier applied before adding to the running batch reward.
        track: when true, the mean reward is registered and updated as a metric.
    """
    def __init__(self, name, order=10, weight=1., track=True):
        self.name, self.order = name, order
        self.track, self.weight = track, weight

    def setup(self):
        """Register this reward's log entry (and metric, if tracked)."""
        logger = self.environment.log
        logger.add_log(self.name)
        if not self.track:
            return
        logger.add_metric(self.name)

    def _compute_reward(self):
        """Return the raw reward for the current batch; must be overridden."""
        raise NotImplementedError

    def compute_reward(self):
        """Compute the reward, accumulate it into the batch and log its mean."""
        raw = self._compute_reward().squeeze()
        self.batch_state.rewards += self.weight*raw
        # The unweighted value is what gets stored and tracked.
        self.batch_state[self.name] = raw

        if self.track:
            batch_mean = raw.mean().detach().cpu().numpy()
            self.environment.log.update_metric(self.name, batch_mean)
            
            
class FunctionReward(Reward):
    """Reward whose value comes from calling `reward_function(batch_state)`.

    Args:
        reward_function: callable that receives the whole batch state.
        name, order, weight, track: forwarded unchanged to `Reward`.
    """
    def __init__(self, reward_function, name, order=10, weight=1., track=True):
        super().__init__(name, order=order, weight=weight, track=track)
        self.reward_function = reward_function

    def _compute_reward(self):
        """Delegate to the wrapped function, handing it the batch state."""
        state = self.batch_state
        return self.reward_function(state)
        
        
class SampleReward(Reward):
    """Reward computed from the batch's samples, with optional caching and filtering.

    Args:
        reward_function: callable taking a list of samples and returning one
            score per sample, in the same order.
        template_filter: when true, only samples whose `template_passes`
            entry is truthy are scored; the rest keep a reward of 0.
        lookup: when true, scores are cached in `lookup_table` keyed by
            sample and reused on later batches instead of being recomputed.
        name, order, weight, track: forwarded unchanged to `Reward`.
    """
    def __init__(self, reward_function, template_filter, lookup,
                 name, order=10, weight=1., track=True):
        super().__init__(name, order, weight, track)
        self.reward_function = reward_function
        self.lookup = lookup
        self.template_filter = template_filter
        self.lookup_table = {}

    def _compute_reward(self):
        """Return a 1-D tensor holding one reward per sample in the batch."""
        samples = self.batch_state.samples
        hps = self.batch_state.template_passes
        outputs = to_device(torch.zeros(len(samples)))

        to_score = []
        to_score_idxs = []

        for i, sample in enumerate(samples):
            if self.lookup and sample in self.lookup_table:
                # Cached score: reuse it without calling the reward function.
                outputs[i] = self.lookup_table[sample]
            elif (not self.template_filter) or hps[i]:
                to_score.append(sample)
                to_score_idxs.append(i)

        if to_score:
            # Score only the pending samples. Scoring the full batch would
            # misalign `scores` with `to_score_idxs` whenever a sample was
            # served from the cache or filtered out, writing rewards to the
            # wrong positions and caching wrong values.
            scores = self.reward_function(to_score)

            for j, (idx, sample) in enumerate(zip(to_score_idxs, to_score)):
                score = scores[j]
                outputs[idx] = score
                if self.lookup:
                    self.lookup_table[sample] = score

        return outputs
    

In [None]:
# export

class NoveltyBonus(Reward):
    """Reward of 1.0 for every sample not found in `log.unique_samples`, else 0.0.

    Args:
        weight: multiplier applied to the bonus by `Reward.compute_reward`.
        name, order, track: forwarded unchanged to `Reward`.
    """
    def __init__(self, weight, name='novel', order=100, track=True):
        super().__init__(name, order, weight, track)

    def _compute_reward(self):
        """Return a float tensor flagging which batch samples are unseen."""
        logger = self.environment.log
        batch = self.batch_state
        seen = logger.unique_samples

        is_novel = [sample not in seen for sample in batch.samples]
        return to_device(torch.tensor(is_novel)).float()


In [None]:
class PredReward(Callback):
    """Callback that rewards samples with the negated, weighted agent prediction.

    Args:
        name: key used for the metric, the log entry and the batch-state slot.
        agent: object exposing `predict_data(samples)`.
        weight: factor applied to the negated predictions.
    """
    def __init__(self, name, agent, weight=1.):
        super().__init__(order=1)
        self.name = name
        self.weight = weight
        self.agent = agent

    def setup(self):
        """Register the metric and the log entry under `self.name`."""
        logger = self.environment.log
        logger.add_metric(self.name)
        logger.add_log(self.name)

    def compute_reward(self):
        """Predict on the batch samples and add `-pred * weight` to the rewards."""
        environment = self.environment
        batch_samples = self.batch_state.samples
        # Predictions are used as a reward signal only, so no graph is kept.
        with torch.no_grad():
            predictions = self.agent.predict_data(batch_samples).squeeze()
        scaled = -predictions * self.weight

        metric_value = scaled.mean().detach().cpu().numpy()
        environment.log.update_metric(self.name, metric_value)
        self.batch_state.rewards += scaled
        self.batch_state[self.name] = scaled

In [None]:
class Callback():
    """Minimal event-dispatching base class.

    Calling an instance with an event name runs the attribute of that name
    (with no arguments) and returns its result; if the instance has no such
    attribute, `None` is returned.

    Args:
        name: label stored on the instance.
        order: value stored on the instance as `order`.
    """
    def __init__(self, name='callback', order=10):
        self.order, self.name = order, name

    def __call__(self, event_name):
        """Invoke the handler named `event_name`, if the instance has one."""
        handler = getattr(self, event_name, None)
        return None if handler is None else handler()

In [None]:

class Event():
    """Namespace of event names: each attribute holds its own name as a string."""
    def __init__(self):
        # Listed in the order the attributes are created on the instance.
        event_names = (
            'setup',
            'before_train',
            'build_buffer',
            'after_build_buffer',
            'before_batch',
            'sample_batch',
            'after_sample',
            'get_model_outputs',
            'compute_reward',
            'after_compute_reward',
            'compute_loss',
            'zero_grad',
            'before_step',
            'step',
            'after_batch',
            'after_train',
        )
        for event_name in event_names:
            setattr(self, event_name, event_name)

In [None]:

class Reward():
    """Legacy reward aggregator combining a template with reward modules.

    Note: this definition reuses the name `Reward`, so once this cell runs it
    shadows the callback-based `Reward` class defined earlier in the notebook.

    Args:
        template: object that is callable on a list of sequences (returning
            per-sequence pass flags) and exposes `eval_mols(sequences)`
            (returning per-sequence template rewards). Defaults to
            `Template([])`.
        reward_modules: callables `rm(model_output, template_passes)` whose
            outputs are summed into the per-sample reward.
        trajectory_modules: callables with the same signature whose outputs
            are summed into the trajectory reward.
        reward_decay: smoothing factor of the running mean reward used as a
            baseline. The original code read an undefined global of this
            name; 0.9 is an assumed default -- TODO confirm the intended value.
    """
    def __init__(self, template=None, reward_modules=None, trajectory_modules=None,
                 reward_decay=0.9):

        if template is None:
            template = Template([])

        self.template = template
        # `None` sentinels avoid sharing one default list between instances.
        self.reward_modules = [] if reward_modules is None else reward_modules
        self.trajectory_modules = [] if trajectory_modules is None else trajectory_modules
        self.reward_decay = reward_decay
        self.mean_reward = None

    def __call__(self, model_output):
        """Score `model_output['sequences']` and store the results on `model_output`.

        Writes the keys `'rewards'`, `'rewards_scaled'` (rewards minus the
        running mean) and `'trajectory_rewards'`, then returns `model_output`.
        """
        sequences = model_output['sequences']
        template_passes = np.array(self.template(sequences))
        template_rewards = np.array(self.template.eval_mols(sequences))

        rewards = self.compute_rewards(model_output, template_passes)
        trajectory_rewards = self.compute_trajectory_reward(model_output, template_passes)

        rewards = template_rewards + rewards

        batch_mean = rewards.mean()
        if self.mean_reward is None:
            self.mean_reward = batch_mean
        else:
            self.mean_reward = ((1-self.reward_decay)*batch_mean
                                + self.reward_decay*self.mean_reward)

        rewards_scaled = rewards - self.mean_reward

        model_output['rewards'] = rewards
        model_output['rewards_scaled'] = rewards_scaled
        model_output['trajectory_rewards'] = trajectory_rewards

        return model_output

    def _sum_module_rewards(self, modules, model_output, template_passes):
        """Sum the outputs of `modules`; with no modules return per-sequence zeros."""
        if not modules:
            # np.stack raises on an empty list, so supply a neutral reward.
            return np.zeros(len(model_output['sequences']))

        stacked = np.stack([rm(model_output, template_passes) for rm in modules], -1)
        return stacked.sum(-1)

    def compute_trajectory_reward(self, model_output, template_passes):
        """Return the summed output of all trajectory modules.

        With no trajectory modules a zero vector with one entry per sequence
        is returned, since no per-step shape is known in that case.
        """
        return self._sum_module_rewards(self.trajectory_modules, model_output, template_passes)

    def compute_rewards(self, model_output, template_passes):
        """Return the summed output of all reward modules (zeros if there are none)."""
        return self._sum_module_rewards(self.reward_modules, model_output, template_passes)

In [None]:

def trajectory_wrapper(inputs, function):
    """Apply `function` to each element of `inputs` and return the results as an array."""
    results = list(map(function, inputs))
    return np.array(results)

In [None]:

class RewardModule():
    """Three-stage reward pipeline: prepare inputs, score them, aggregate the scores.

    The three stage methods are no-op placeholders here (each returns `None`);
    subclasses override them.
    """

    def __call__(self, model_output, template_passes=None):
        """Run the three stages in order and return the aggregated reward."""
        prepared = self.prepare_reward_inputs(model_output, template_passes)
        raw_scores = self.reward_function(prepared)
        return self.aggregate_reward(raw_scores, model_output, template_passes)

    def aggregate_reward(self, reward_outputs, model_output, template_passes=None):
        """Turn raw reward outputs into the final reward (placeholder)."""
        return None

    def prepare_reward_inputs(self, model_output, template_passes=None):
        """Extract the reward function's inputs from `model_output` (placeholder)."""
        return None

    def reward_function(self, inputs):
        """Score the prepared inputs (placeholder)."""
        return None
    
class MolReward(RewardModule):
    """Reward module that scores sequences with a per-molecule function.

    Args:
        mol_function: callable applied to a single sequence (or, in trajectory
            mode, to every element of a sequence trajectory).
        trajectory: when true, inputs come from
            `model_output['sequence_trajectories']` and the result has one row
            per sample and `model_output['sl']` columns; otherwise inputs come
            from `model_output['sequences']` and the result is one value per
            sample.
    """
    def __init__(self, mol_function, trajectory=False):
        self.mol_function = mol_function
        self.trajectory = trajectory

    def aggregate_reward(self, reward_outputs, model_output, template_passes=None):
        """Scatter the scores of template-passing samples into a zero-filled array.

        Samples that did not pass the template keep a reward of zero.
        """
        if template_passes is not None:
            # dtype=int keeps this a valid index array even when nothing passed;
            # an empty list would otherwise become a float64 array, which numpy
            # rejects as an index.
            passed_idxs = np.array(
                [i for i in range(len(template_passes)) if template_passes[i]],
                dtype=int)
            bs = len(template_passes)
        else:
            passed_idxs = np.arange(len(reward_outputs))
            bs = len(reward_outputs)

        if self.trajectory:
            outputs = np.zeros((bs, model_output['sl']))

            for i, idx in enumerate(passed_idxs):
                traj = reward_outputs[i]
                # Trajectories shorter than `sl` are right-padded with zeros.
                outputs[idx, :len(traj)] = traj

        else:
            outputs = np.zeros((bs))
            outputs[passed_idxs] = reward_outputs

        return outputs

    def prepare_reward_inputs(self, model_output, template_passes=None):
        """Select the sequences (or trajectories) to score, dropping template failures."""
        if self.trajectory:
            inputs = model_output['sequence_trajectories']
        else:
            inputs = model_output['sequences']

        if template_passes is not None:
            inputs = [inputs[i] for i in range(len(inputs)) if template_passes[i]]

        return inputs

    def reward_function(self, inputs):
        """Score `inputs` via `maybe_parallel` (defined outside this block)."""
        if self.trajectory:
            # Each input is itself a list, so map mol_function over its items.
            func = partial(trajectory_wrapper, function=self.mol_function)
        else:
            func = self.mol_function

        return maybe_parallel(func, inputs)
    
class MLReward():
    """Base class for reward modules backed by a model.

    Subclasses implement `prepare_reward_inputs`. Instances are callable with
    the same `(model_output, template_passes)` signature as `RewardModule`.

    Args:
        model: callable returning an object that supports `.detach().cpu()`
            (e.g. a torch module returning a tensor).
        trajectory: flag stored on the instance; not read inside this class.
    """
    def __init__(self, model, trajectory=False):
        self.model = model
        self.trajectory = trajectory

    def __call__(self, model_output, template_passes=None):
        """Prepare inputs, run the model and aggregate, mirroring `RewardModule.__call__`."""
        reward_inputs = self.prepare_reward_inputs(model_output, template_passes)
        reward_outputs = self.reward_function(reward_inputs)
        return self.aggregate_reward(reward_outputs, model_output, template_passes)

    def reward_function(self, inputs):
        """Run the model on `inputs` and return its output as a numpy array.

        A non-list `inputs` is treated as a single positional argument; a list
        is unpacked into positional arguments.
        """
        if not isinstance(inputs, list):
            inputs = [inputs]
        return np.array(self.model(*inputs).detach().cpu())

    def prepare_reward_inputs(self, model_output, template_passes=None):
        """Build the model inputs from `model_output`; must be overridden."""
        raise NotImplementedError

    def aggregate_reward(self, reward_outputs, model_output, template_passes=None):
        """Scatter model outputs of template-passing samples into a zero-filled array.

        Returns an array of shape `(bs, k)` for 2-D `reward_outputs` with `k`
        columns, otherwise `(bs,)`; rows of samples that failed the template
        stay zero.
        """
        if template_passes is not None:
            # dtype=int keeps an empty selection usable as an index array.
            passed_idxs = np.array(
                [i for i in range(len(template_passes)) if template_passes[i]],
                dtype=int)
            bs = len(template_passes)
        else:
            passed_idxs = np.arange(len(reward_outputs))
            bs = len(reward_outputs)

        if reward_outputs.ndim==2:
            output = np.zeros((bs, reward_outputs.shape[-1]))
        else:
            output = np.zeros((bs,))

        # The original 2-D branch assigned the undefined name `reward_output`.
        output[passed_idxs] = reward_outputs

        return output
    
class FPModelReward(MLReward):
    """Model-based reward that featurizes sequences with a fingerprint function.

    Args:
        model: forwarded to `MLReward`.
        fp_func: callable mapping one sequence to a numpy fingerprint array.
        trajectory: forwarded to `MLReward`.
    """
    def __init__(self, model, fp_func, trajectory=False):
        super().__init__(model, trajectory)
        self.fp_func = fp_func

    def prepare_reward_inputs(self, model_output, template_passes=None):
        """Return the stacked fingerprints of the template-passing sequences as a float tensor."""
        smiles = model_output['sequences']

        if template_passes is not None:
            # Score only passing sequences: the inherited aggregate_reward
            # writes the model outputs at the passing indices, so featurizing
            # the full batch would fail with a shape mismatch whenever a
            # sequence fails the template.
            smiles = [smiles[i] for i in range(len(smiles)) if template_passes[i]]

        # NOTE(review): np.stack raises on an empty list, so a batch in which
        # no sequence passes the template still fails here.
        fps = np.stack(maybe_parallel(self.fp_func, smiles))
        fps = to_device(torch.from_numpy(fps).float())
        return fps
    
class SequenceModelReward(MLReward):
    """Model-based reward that feeds `model_output['x']` straight to the model."""

    def prepare_reward_inputs(self, model_output, template_passes=None):
        """Return `model_output['x']` unchanged; `template_passes` is not used here."""
        model_inputs = model_output['x']
        return model_inputs

In [None]:
def mf(smile):
    """Return `qed(mol)` for the molecule built from `smile`, or 0. if `to_mol` returns None."""
    mol = to_mol(smile)
    return 0. if mol is None else qed(mol)

In [None]:
# Demo: trajectory-mode reward module that scores each molecule with `mf`.
r = MolReward(mf, trajectory=True)

In [None]:
# Container for the hand-written model outputs used in the demo cells below.
# `ModelOutput` is defined outside this notebook section; it is filled via
# string keys, so presumably dict-like -- TODO confirm.
mo = ModelOutput()

In [None]:
# Three final sequences and, for each, the list of partial sequences leading to it.
mo['sequences'] = ['C', 'CCC', 'CCCC']
mo['sequence_trajectories'] = [['C'], ['C', 'CC', 'CCC'], ['C', 'CC', 'CCC', 'CCCC']]
# 'sl' sets the number of columns of the trajectory reward array.
mo['sl'] = 4
# The second sample is marked as failing the template.
template_passes = np.array([True, False, True])

In [None]:
# Full pipeline: the row of the failing sample stays all-zero and shorter
# trajectories are right-padded with zeros up to 'sl' columns.
r(mo, template_passes)

array([[0.35978494, 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.35978494, 0.37278556, 0.38547066, 0.43102436]])

In [None]:
# Without template_passes no filtering happens: every trajectory is returned.
r.prepare_reward_inputs(mo)

[['C'], ['C', 'CC', 'CCC'], ['C', 'CC', 'CCC', 'CCCC']]

In [None]:
# The reward function yields one result per input trajectory.
len(r.reward_function(r.prepare_reward_inputs(mo)))

3

In [None]:
# Scratch check: a nested list of equal-length rows gives a 2-D array.
np.array([[5]]).ndim

2

In [None]:
# Scratch check: ragged trajectories stored with dtype=object form a 1-D array of lists.
np.array([['C'], ['C', 'CC', 'CCC'], ['C', 'CC', 'CCC', 'CCCC']], dtype=object).ndim

1

parallel reward
    parallel process calculation on one sequence at a time
    
batch reward
    parallel featurize
    batch
    compute