In [None]:
# default_exp templates.template

# Templates

> Templates are used to define chemical spaces using easily calculable properties


## Overview

Templates are a core concept in MRL used to define chemical spaces. Templates collect a series of molecular heuristics and validate whether a molecule meets those criteria. For example:

```
Molecular weight: 250-450
Rotatable bonds: Less than 8
PAINS Filter: Pass
```

Templates can also be used to assign a score for meeting heuristic criteria. This allows us to define different criteria for __must-have__ molecular properties versus __nice-to-have__ chemical properties. In a reinforcement learning context, this translates into giving a score bonus to molecules that fit the nice-to-have criteria. Scores can also be negative to allow for penalizing a molecule that still passes the must-have criteria.

```
Must Have:
Molecular weight: 250-450, 
Rotatable bonds: Less than 8
PAINS Filter: Pass

Nice To Have:
Molecular weight: 350-400 (+1), 
TPSA: Less than 80 (+1)
Substructure Match: '[#6]1:[#6]:[#7]:[#6]:[#6]:[#6]:1' (+3)
Substructure Match: '[#6]1:[#6]:[#7]:[#7]:[#7]:[#6]:1' (-1)
```

Based on the above criteria, a molecule that passes the must-have criteria could get a score between -1 and +5 based on meeting the nice-to-have criteria.

Templates are instantiated through the `Template` class. A template is a collection of filters, created through the `Filter` class. See `Filter` for more details on defining filter functions.

Templates contain two sets of filters - hard filters and soft filters. Hard filters contain the must-have criteria for a molecule, while the soft filters contain the nice-to-have criteria.

During model training, a generative model creates a batch of compounds. These compounds are first screened against the hard filters. Compounds that fail the hard filters can then be excluded from the training batch or assigned a default failure score. Compounds that pass the hard filters are then scored using the soft filters. Soft filters can provide a small score bonus or penalty for a molecule in addition to the main score function.

Soft filters incentivise a generative model to maximize the soft filter conditions without making them a hard requirement. This allows soft filters to be highly targeted towards narrow property ranges or highly specific substructures. If these highly targeted criteria were set as hard filters, they might invalidate too many compounds and cause the model to struggle during training.

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
# export
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.filters import *

  return f(*args, **kwds)


## Templates

The `Template` class holds a collection of hard filters and soft filters and manages screening a molecule against those filters. Templates by default log the filter results in detail, allowing you to inspect which filters a molecule failed.

Templates can be merged by adding templates together, i.e. `new_template = template1 + template2`. Adding two templates merges the hard and soft filters in each template.

Templates by default will log all molecules passed through the hard and soft filters to an internal dataframe, and build a lookup table of `{smile : final_score}`. The first time a molecule is screened, it is added to the internal dataframe and lookup table. If that molecule is seen again (defined by the smiles string), the lookup value is returned. If this behavior isn't desired, pass `log=False` to disable all logging and lookup, or `use_lookup=False` to keep logging but avoid using the lookup table.

In [None]:
# export

class Template():
    '''
    Template - class for managing hard and soft filters

    Inputs:
        `hard_filters` - (list), list of `Filter` objects used for pass/fail screening

        `soft_filters` - (list, optional), list of `Filter` objects used for soft scoring.
            If None, the template has no soft filters

        `log` - (bool), if True, template will log screened compounds

        `use_lookup` - (bool), if True, filter results are stored in a lookup table. If a compound
            is re-screened, the lookup value is returned

        `fail_score` - (float), placeholder score for compounds that fail to pass hard filters

    '''
    def __init__(self, hard_filters, soft_filters=None, log=True, use_lookup=True, fail_score=0.):
        self.hard_filters = hard_filters
        # build a fresh list per instance - a shared `[]` default would leak
        # soft filters from one template into every other default-constructed one
        self.soft_filters = [] if soft_filters is None else soft_filters
        self.log = log
        self.use_lookup = use_lookup
        self.fail_score = fail_score

        self.hard_log = self._empty_log(self.hard_filters)
        self.hard_col_names = ['smiles'] + [i.name for i in self.hard_filters] + ['final']
        self.hard_lookup = {}

        self.soft_log = self._empty_log(self.soft_filters)
        self.soft_col_names = ['smiles'] + [i.name for i in self.soft_filters] + ['final']
        self.soft_lookup = {}

    @staticmethod
    def _empty_log(filters):
        'empty log dataframe: `smiles`, one positional column per filter, `final`'
        return pd.DataFrame(columns=['smiles']+list(range(len(filters)))+['final'])

    @staticmethod
    def _extend_log(log_df, new_data):
        'return `log_df` with the rows in `new_data` added (replaces the removed `DataFrame.append`)'
        new_df = pd.DataFrame(new_data, columns=log_df.columns)
        if log_df.empty:
            # concatenating with an empty frame is deprecated in recent pandas
            return new_df
        return pd.concat([log_df, new_df])

    def __call__(self, mols, filter_type='hard'):
        '''
        screen `mols` against the hard filters (`filter_type='hard'`) or, for any other
        `filter_type`, the soft filters. New results are logged (only for `'hard'`/`'soft'`).

        Returns a list of results if `mols` is a container, otherwise a single result
        '''
        filter_func = self.hf if filter_type=='hard' else self.sf
        # presumably maps `filter_func` over `mols`, in parallel when configured - see `maybe_parallel`
        outputs = maybe_parallel(filter_func, mols)

        if is_container(mols):
            return_outputs = [i[0] for i in outputs]
            log_outputs = [i[1] for i in outputs if i[1]]
        else:
            return_outputs = outputs[0]
            log_outputs = [outputs[1]] if outputs[1] else []

        self.log_data(log_outputs, filter_type=filter_type)
        self.clean_logs()

        return return_outputs

    def eval_mols(self, mols):
        '''
        score `mols`. Compounds failing the hard filters get `fail_score`,
        compounds passing them get their soft filter score.

        Returns a list of scores in the same order as `mols`
        '''
        hardpass = self.__call__(mols, filter_type='hard')

        remaining = []
        idxs = []
        scores = []
        for i, mol in enumerate(mols):
            if hardpass[i]:
                remaining.append(mol)
                idxs.append(i)
                scores.append(0) # placeholder, overwritten with the soft score below
            else:
                scores.append(self.fail_score)

        if remaining:
            soft_scores = self.__call__(remaining, filter_type='soft')
            for idx, score in zip(idxs, soft_scores):
                scores[idx] = score

        return scores

    def hf(self, mol, agg=True):
        '''
        run hard filters

        If `agg`, returns `(passed, log_row)` where `passed` is True only if every hard
        filter passed. Otherwise returns `(per_filter_results, [])` and nothing is logged.
        `log_row` is empty when the result came from the lookup table
        '''
        mol = to_mol(mol) # future note - update for proteins
        smile = to_smile(mol)

        # the lookup table only stores aggregated results, so it can only answer `agg=True` queries
        if agg and self.use_lookup and smile in self.hard_lookup:
            return self.hard_lookup[smile], []

        filter_results = []
        for filt in self.hard_filters:
            try:
                filter_results.append(filt(mol, with_score=False))
            except Exception:
                # a filter that errors on a molecule counts as a failed filter
                filter_results.append(False)

        if agg:
            output = all(filter_results)
            log_data = [smile]+filter_results+[output]
        else:
            output = filter_results
            log_data = []

        return output, log_data

    def sf(self, mol):
        '''
        run soft filters

        Returns `(score, log_row)` where `score` is the sum of all soft filter scores.
        `log_row` is empty when the result came from the lookup table
        '''
        mol = to_mol(mol) # future note - update for proteins
        smile = to_smile(mol)

        if self.use_lookup and smile in self.soft_lookup:
            return self.soft_lookup[smile], []

        filter_results = [filt(mol, with_score=True) for filt in self.soft_filters]
        output = sum(filter_results)
        log_data = [smile]+filter_results+[output]

        return output, log_data

    def screen_mols(self, mols):
        '''
        separate `mols` into passes and failures

        Returns `[passes, fails]`. `passes` is a list of `(mol, soft_score, index_in_mols)`
        tuples for compounds passing the hard filters, `fails` is a list of the other compounds
        '''
        hardpasses = self.__call__(mols, filter_type='hard')

        fails = []
        remaining = []
        remaining_idxs = []

        for i in range(len(hardpasses)):
            if hardpasses[i]:
                remaining.append(mols[i])
                remaining_idxs.append(i)
            else:
                fails.append(mols[i])

        passes = []
        if remaining:
            softpasses = self.__call__(remaining, filter_type='soft')
            passes = list(zip(remaining, softpasses, remaining_idxs))

        return [passes, fails]

    def log_data(self, new_data, filter_type='hard'):
        '''
        add `new_data` (list of `[smile, *filter_results, final]` rows) to the log
        selected by `filter_type` and, if `use_lookup`, to the matching lookup table.
        Does nothing if logging is disabled or `new_data` is empty
        '''
        if not (self.log and new_data):
            return

        if filter_type=='hard':
            self.hard_log = self._extend_log(self.hard_log, new_data)
            lookup = self.hard_lookup
        elif filter_type=='soft':
            self.soft_log = self._extend_log(self.soft_log, new_data)
            lookup = self.soft_lookup
        else:
            return

        if self.use_lookup:
            for item in new_data:
                # first result seen for a smile wins
                lookup.setdefault(item[0], item[-1])

    def clean_logs(self):
        'de-duplicate logs'
        # `drop_duplicates` returns a new frame - the result has to be assigned back
        self.hard_log = self.hard_log.drop_duplicates(subset='smiles').reset_index(drop=True)
        self.soft_log = self.soft_log.drop_duplicates(subset='smiles').reset_index(drop=True)

    def clear_data(self):
        'delete logged data'
        self.hard_log = self._empty_log(self.hard_filters)
        self.hard_lookup = {}

        self.soft_log = self._empty_log(self.soft_filters)
        self.soft_lookup = {}

    def sample(self, n, log='hard', seed=None):
        '''
        sample logged data

        Inputs:
            `n` - (int), number of rows to sample (without replacement)

            `log` - (str), `'hard'` samples compounds that passed the hard filters,
                anything else samples the soft log

            `seed` - (int, optional), random state passed to pandas

        Returns a dataframe with `smiles` and `final` columns
        '''
        if seed is None:
            # re-seed numpy's global RNG, which pandas uses when `random_state` is None.
            # required to prevent identical sampling in multiprocessing
            np.random.seed()

        if log=='hard':
            to_sample = self.hard_log[self.hard_log.final==True]
        else:
            to_sample = self.soft_log

        sample = to_sample.sample(n, replace=False, random_state=seed).reset_index(drop=True)

        return sample[['smiles', 'final']]


    def sample_smiles(self, n, log='hard'):
        'sample `n` logged smiles strings, see `Template.sample`'
        return list(self.sample(n, log=log).smiles.values)

    def save(self, filename, with_data=True):
        '''
        save - save `Template` object

        Inputs
            'filename' - str, save filename
            `with_data` - bool, if True Template is saved with logged data
        '''
        if with_data:
            with open(filename, 'wb') as f:
                pickle.dump(self, f)
            return

        # temporarily strip the logged data, and always put it back - even if saving fails
        held_data = (self.hard_log, self.hard_lookup, self.soft_log, self.soft_lookup)
        self.clear_data()
        try:
            with open(filename, 'wb') as f:
                pickle.dump(self, f)
        finally:
            self.hard_log, self.hard_lookup, self.soft_log, self.soft_lookup = held_data

    @classmethod
    def from_file(cls, filename):
        'load template from file'
        # NOTE: unpickling can execute arbitrary code - only load files from trusted sources
        with open(filename, 'rb') as f:
            template = pickle.load(f)
        return template

    def __add__(self, other, merge_data=True):
        '''
        merge two templates. If `merge_data`, logged data from each template is rescreened.
        Only `use_lookup` is carried over from `self`; `log` and `fail_score` use the defaults
        '''
        hard_filters = self.hard_filters + other.hard_filters
        hard_filters = sorted(hard_filters, key=lambda x: x.priority, reverse=True)

        soft_filters = self.soft_filters + other.soft_filters
        soft_filters = sorted(soft_filters, key=lambda x: x.priority, reverse=True)

        # created up front so a template is returned whether or not data is merged
        new_template = Template(hard_filters, soft_filters, use_lookup=self.use_lookup)

        if merge_data:
            # `dict.fromkeys` de-duplicates while keeping a deterministic order
            soft_smiles = list(dict.fromkeys(list(self.soft_log.smiles.values) + list(other.soft_log.smiles.values)))
            hard_smiles = list(dict.fromkeys(list(self.hard_log.smiles.values) + list(other.hard_log.smiles.values)))

            if hard_smiles:
                _ = new_template(hard_smiles, filter_type='hard')
            if soft_smiles:
                _ = new_template(soft_smiles, filter_type='soft')

        return new_template

    def __repr__(self):
        hf = 'Hard Filter:\n\t\t' + '\n\t\t'.join([i.__repr__() for i in self.hard_filters])
        sf = 'Soft Filter:\n\t\t' + '\n\t\t'.join([i.__repr__() for i in self.soft_filters])
        rep_str = 'Template\n\t' + hf + '\n\t' + sf
        return rep_str

In [None]:
# export

class BlankTemplate(Template):
    "Template with an empty hard filter list and an empty soft filter list"
    def __init__(self):
        super().__init__(hard_filters=[], soft_filters=[])

class ValidMoleculeTemplate(Template):
    '''
    Template for checking if an input is a single valid chemical structure

    Inputs:
        `hard` - (bool), if True, `ValidityFilter` and `SingleCompoundFilter` are used as hard filters

        `soft` - (bool), if True, the same two filters are used as soft filters, each with `score=1`
    '''
    def __init__(self, hard=True, soft=False):

        hard_filters = [
            ValidityFilter(),
            SingleCompoundFilter()
        ] if hard else []

        # previously assigned to a misspelled name, so `soft=True` raised a NameError
        soft_filters = [
            ValidityFilter(score=1),
            SingleCompoundFilter(score=1)
        ] if soft else []

        super().__init__(hard_filters, soft_filters)

class RuleOf5Template(Template):
    '''
    Lipinski rule of 5 template (en.wikipedia.org/wiki/Lipinski%27s_rule_of_five)

    Inputs:
        `hard` - (bool), if True, the criteria are used as hard filters

        `soft` - (bool), if True, the criteria are used as soft filters, each with `score=1`
    '''
    def __init__(self, hard=True, soft=False):

        # (filter class, positional args) - presumably (min, max) bounds, see the filter docs
        criteria = [
            (HBDFilter, (None, 5)),
            (HBAFilter, (None, 10)),
            (MolWtFilter, (None, 500)),
            (LogPFilter, (None, 5)),
        ]

        hard_filters = [filt(*args) for filt, args in criteria] if hard else []
        soft_filters = [filt(*args, score=1) for filt, args in criteria] if soft else []

        super().__init__(hard_filters, soft_filters)
        
class GhoseTemplate(Template):
    '''
    Ghose filter template (doi.org/10.1021/cc9800071)

    Inputs:
        `hard` - (bool), if True, the criteria are used as hard filters

        `soft` - (bool), if True, the criteria are used as soft filters, each with `score=1`
    '''
    def __init__(self, hard=True, soft=False):

        # (filter class, positional args) - presumably (min, max) bounds, see the filter docs
        criteria = [
            (MolWtFilter, (160, 480)),
            (LogPFilter, (-0.4, 5.6)),
            (HeavyAtomsFilter, (20, 70)),
            (MRFilter, (40, 130)),
        ]

        hard_filters = [filt(*args) for filt, args in criteria] if hard else []
        soft_filters = [filt(*args, score=1) for filt, args in criteria] if soft else []

        super().__init__(hard_filters, soft_filters)
        
class VeberTemplate(Template):
    '''
    Template for Veber filters (doi.org/10.1021/jm020017n)

    Inputs:
        `hard` - (bool), if True, the criteria are used as hard filters

        `soft` - (bool), if True, the criteria are used as soft filters, each with `score=1`
    '''
    def __init__(self, hard=True, soft=False):

        # `TPSAFilter` was previously called with the undefined name `none`,
        # which raised a NameError whenever the template was built
        hard_filters = [
            RotBondFilter(None, 10),
            TPSAFilter(None, 140)
        ] if hard else []

        soft_filters = [
            RotBondFilter(None, 10, score=1),
            TPSAFilter(None, 140, score=1)
        ] if soft else []

        super().__init__(hard_filters, soft_filters)
        
class REOSTemplate(Template):
    '''
    REOS filter template (10.1016/s0169-409x(02)00003-0)

    Inputs:
        `hard` - (bool), if True, the criteria are used as hard filters

        `soft` - (bool), if True, the criteria are used as soft filters, each with `score=1`
    '''
    def __init__(self, hard=True, soft=False):

        # (filter class, positional args) - presumably (min, max) bounds, see the filter docs
        criteria = [
            (MolWtFilter, (200, 500)),
            (LogPFilter, (-5, 5)),
            (HBDFilter, (0, 5)),
            (HBAFilter, (0, 10)),
            (ChargeFilter, (-2, 2)),
            (RotBondFilter, (0, 8)),
            (HeavyAtomsFilter, (15, 50)),
        ]

        hard_filters = [filt(*args) for filt, args in criteria] if hard else []
        soft_filters = [filt(*args, score=1) for filt, args in criteria] if soft else []

        super().__init__(hard_filters, soft_filters)
        
class RuleOf3Template(Template):
    '''
    Rule of 3 template (doi.org/10.1016/S1359-6446(03)02831-9)

    Inputs:
        `hard` - (bool), if True, the criteria are used as hard filters

        `soft` - (bool), if True, the criteria are used as soft filters, each with `score=1`
    '''
    def __init__(self, hard=True, soft=False):

        # (filter class, positional args) - presumably (min, max) bounds, see the filter docs
        criteria = [
            (MolWtFilter, (None, 300)),
            (LogPFilter, (None, 3)),
            (HBDFilter, (None, 3)),
            (HBAFilter, (None, 3)),
            (RotBondFilter, (None, 3)),
        ]

        hard_filters = [filt(*args) for filt, args in criteria] if hard else []
        soft_filters = [filt(*args, score=1) for filt, args in criteria] if soft else []

        super().__init__(hard_filters, soft_filters)
        

In [None]:
# small set of example molecules, reused by the cells below
smiles = [
    'c1ccccc1',
    'Cc1cc(NC)ccc1',
    'Cc1cc(NC)cnc1',
    'Cc1cccc(NCc2ccccc2)c1',
]
mols = list(map(to_mol, smiles))

# must-have criteria
hard_filters = [
    ValidityFilter(),
    SingleCompoundFilter(),
    MolWtFilter(None, 500),
    HBDFilter(None, 5),
    HBAFilter(None, 10),
    LogPFilter(None, 5),
]

# nice-to-have criteria, each constructed with score=1
soft_filters = [
    TPSAFilter(None, 110, score=1),
    RotBondFilter(None, 8, score=1),
    StructureFilter(
        ['[*]-[#6]1:[#6]:[#6]:[#6]2:[#7]:[#6]:[#7H]:[#6]:2:[#6]:1'],
        exclude=False,
        score=1,
    ),
]

template = Template(hard_filters, soft_filters)

query_smile = 'CC1=CN=C(C(=C1OC)C)CS(=O)C2=NC3=C(N2)C=C(C=C3)OC'
# `hf`/`sf` return (result, log_row); only the result is checked here
assert template.hf(query_smile)[0]
# the three soft filters sum to 3.0 for this compound
assert template.sf(query_smile)[0]==3.0
# calling the template directly runs the hard filters
assert template(query_smile)

In [None]:
# split the filters over two templates - adding them back together
# has to give the same results as the full template
t1 = Template(hard_filters[:3], soft_filters[:2])
t2 = Template(hard_filters[3:], soft_filters[2:])

for kind in ('hard', 'soft'):
    assert (t1+t2)(mols, filter_type=kind) == template(mols, filter_type=kind)

In [None]:
# screen the SMILES in the test file with the rule of 3 template
template = RuleOf3Template()
df = pd.read_csv('files/smiles.csv')
# `screen_mols` returns the compounds passing the hard filters and the ones failing them
passes, fails = template.screen_mols(df.smiles.values)
assert len(passes) == 92
# sampling the hard log only draws from rows where `final` is True
assert all(template.sample(50, log='hard').final.values==True)

In [None]:
# saving with `with_data=False` writes the template without its logged data...
template.save('files/test_temp.template', with_data=False)
template2 = Template.from_file('files/test_temp.template')
assert template2.hard_log.shape[0]==0
# ...while the in-memory template keeps the log built by the screening cell above
# (2000 rows - presumably one per row of files/smiles.csv)
assert template.hard_log.shape[0]==2000
# remove the temporary file written by this cell
os.remove('files/test_temp.template')

In [None]:
# a template without hard filters passes every molecule
template = Template([])
assert template(mols) == [True] * len(mols)

In [None]:
# hide
from nbdev.export import notebook2script; notebook2script()

Converted 00_core.ipynb.
Converted 01_chem.ipynb.
Converted 02_template.filters.ipynb.
Converted 03_template.template.ipynb.
Converted 04_template.blocks.ipynb.
Converted 05_torch_core.ipynb.
Converted 06_layers.ipynb.
Converted 07_dataloaders.ipynb.
Converted index.ipynb.
Converted template.overview.ipynb.
Converted tutorials.ipynb.
Converted tutorials.structure_enumeration.ipynb.
Converted tutorials.template.advanced.ipynb.
Converted tutorials.template.beginner.ipynb.
Converted tutorials.template.intermediate.ipynb.
