# Filters

> Filter related functions

In [None]:
#| default_exp filter

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from chem_templates.imports import *
from chem_templates.utils import *
from chem_templates.chem import Molecule, Catalog, mol_func_wrapper
from rdkit.Chem.FilterCatalog import SmartsMatcher

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class FilterResult():
    'Holds the outcome produced by applying one filter to one molecule'
    def __init__(self,
                 filter_result: bool, # overall pass/fail outcome (True or False)
                 filter_name:   str,  # name of the filter that produced this result
                 filter_data:   dict  # extra data recorded by the filter
                ):

        # plain value container: the three inputs are stored unchanged
        (self.filter_result,
         self.filter_name,
         self.filter_data) = filter_result, filter_name, filter_data

    def __repr__(self):
        # rendered as "<filter name> result: <filter result>"
        return '{} result: {}'.format(self.filter_name, self.filter_result)

class Filter():
    'Base class for filters. Calling the base filter passes every molecule; subclasses override `__call__`'
    def __init__(self, name='filter' # filter name, reported in results and used as the repr
                ):
        self.name = name

    def __call__(self, molecule: Molecule) -> FilterResult:
        # the base implementation ignores `molecule`: always a passing result with empty data
        passing = FilterResult(filter_result=True, filter_name=self.name, filter_data={})
        return passing

    def __repr__(self):
        # a filter is represented by its name alone
        return self.name

In [None]:
#| export

class ValidityFilter(Filter):
    'Passes or fails a molecule according to its `valid` attribute'
    def __init__(self):
        # fixed filter name, stored by the base initializer
        super().__init__(name='validity_filter')

    def __call__(self, molecule: Molecule) -> FilterResult:
        # `molecule.valid` is used directly as the filter outcome; no extra data is recorded
        is_valid = molecule.valid
        return FilterResult(is_valid, self.name, {})
    
class SingleCompoundFilter(Filter):
    'Checks if molecule is a single compound, i.e. its SMILES string contains no `.`'
    def __init__(self):
        # go through the base initializer so the name is set the same way as for `Filter`
        super().__init__(name='single_compound')

    def __call__(self, molecule: Molecule) -> FilterResult:
        # a `.` in the SMILES string means more than one component (e.g. 'CCC.CCCC'), which fails the filter
        result = '.' not in molecule.smile
        return FilterResult(result, self.name, {})

In [None]:
# Three probes: a plain alkane, a SMILES expected to fail validity, and a two-component SMILES
mol1, mol2, mol3 = (Molecule(smi) for smi in ('CCCC', 'CCCc', 'CCC.CCCC'))

f1, f2 = ValidityFilter(), SingleCompoundFilter()

# validity: only the second molecule is rejected (a multi-component SMILES is still valid)
assert f1(mol1).filter_result
assert not f1(mol2).filter_result
assert f1(mol3).filter_result

# single compound: the `.`-separated SMILES is rejected
assert f2(mol1).filter_result
assert not f2(mol3).filter_result

In [None]:
#| export

class AttachmentCountFilter(Filter):
    'Passes molecules whose SMILES contains exactly `num_attachments` dummy attachment markers (`*`)'
    def __init__(self,
                 num_attachments: int # required number of `*` characters in the SMILES
                ):

        # the required count is baked into the filter name, e.g. `attachment_count_1`
        super().__init__(name=f'attachment_count_{num_attachments}')
        self.num_attachments = num_attachments

    def __call__(self, molecule: Molecule) -> FilterResult:
        # every `*` character in the SMILES string counts as one attachment point
        found = molecule.smile.count('*')

        return FilterResult(found == self.num_attachments,
                            self.name,
                            {'num_attachments' : found})


In [None]:
# The filter requires exactly one attachment point
f = AttachmentCountFilter(num_attachments=1)

mol1 = Molecule('[*:1]CC')       # one `*`  -> expected to pass
mol2 = Molecule('[*:1]CC[*:2]')  # two `*`  -> expected to fail

assert f(mol1).filter_result
assert not f(mol2).filter_result

In [None]:
#| export

class BinaryFunctionFilter(Filter):
    def __init__(self,
            func: Callable[[Molecule], bool], # predicate: takes a Molecule, returns True (pass) or False (fail)
            name: str # filter name
                ):
        'Wraps `func` as a filter: the value `func` returns for a molecule becomes the filter result'

        super().__init__(name=name)
        self.func = func

    def __call__(self, molecule: Molecule) -> FilterResult:
        # the predicate's return value is forwarded as-is; no extra data is recorded
        return FilterResult(self.func(molecule), self.name, {})
    
class DataFunctionFilter(Filter):
    def __init__(self,
            func: Callable[[Molecule], Tuple[bool, dict]], # takes a Molecule, returns (pass/fail bool, data dict)
            name: str # filter name
                ):
        "Wraps `func` as a filter. The bool returned by `func` is the filter result; the dict is attached as filter data"

        super().__init__(name=name)
        self.func = func

    def __call__(self, molecule: Molecule) -> FilterResult:
        # `func` must return exactly two items: the outcome and its accompanying data
        passed, payload = self.func(molecule)

        return FilterResult(filter_result=passed, filter_name=self.name, filter_data=payload)

In [None]:
from rdkit.Chem import rdMolDescriptors
mol1 = Molecule('Cc1nnc2n1-c1ccc(Cl)cc1C(c1ccccc1)=NC2')
mol2 = Molecule('CCCC')

def filter_func(molecule):
    'Return True when `molecule` contains at least one ring'
    n_rings = rdMolDescriptors.CalcNumRings(molecule.mol)
    # `> 0` rather than `> 1`: the filter is named `has_ring`, so a single-ring molecule must pass
    return n_rings > 0

f = BinaryFunctionFilter(filter_func, 'has_ring')

# the ring-containing molecule passes, the acyclic butane does not
assert f(mol1).filter_result
assert not f(mol2).filter_result

In [None]:
#| export

class RangeFunctionFilter(Filter):
    def __init__(self,
                 func:    Callable[[Molecule], Union[int, float]], # takes a Molecule, returns the numeric value to range-check
                 name:    str, # filter name
                 min_val: Union[int, float, None]=None, # lower bound, inclusive (if None, defaults to -inf)
                 max_val: Union[int, float, None]=None  # upper bound, inclusive (if None, defaults to inf)
                ):

        '''
        `RangeFunctionFilter` evaluates `func` on a `Molecule` and passes the molecule
        when the returned value satisfies `min_val <= value <= max_val`
        '''

        # resolve the bounds first; -inf / inf are handed to `validate_range` as the fallback bounds
        lower, upper = validate_range(min_val, max_val, float('-inf'), float('inf'))

        self.func, self.name = func, name
        self.min_val, self.max_val = lower, upper

    def __call__(self, molecule: Molecule) -> FilterResult:
        value = self.func(molecule)
        lower, upper = self.min_val, self.max_val

        # both bounds are inclusive
        in_range = lower <= value <= upper

        # the computed value and the bounds are reported alongside the outcome
        return FilterResult(in_range,
                            self.name,
                            {'computed_value' : value, 'min_val' : lower, 'max_val' : upper})

In [None]:
from rdkit.Chem import rdMolDescriptors

# wrap the RDKit descriptor with `mol_func_wrapper` so the filter can call it on a `Molecule`
molwt = mol_func_wrapper(rdMolDescriptors.CalcExactMolWt)

# accept values between 250 and 350 (inclusive)
f = RangeFunctionFilter(func=molwt, name='molwt_filter', min_val=250, max_val=350)

mol1, mol2 = Molecule('Cc1nnc2n1-c1ccc(Cl)cc1C(c1ccccc1)=NC2'), Molecule('CCCC')

# the first molecule is expected inside the window, butane outside it
assert f(mol1).filter_result
assert not f(mol2).filter_result

In [None]:
#| export

class SmartsFilter(Filter):
    def __init__(self,
                 smarts:  str, # SMARTS pattern to search for
                 name:    str, # filter name
                 exclude: bool=True, # True: matching molecules fail. False: matching molecules pass
                 min_val: Union[int, float, None]=None, # min number of occurrences
                 max_val: Union[int, float, None]=None # max number of occurrences
                ):

        '''
        `SmartsFilter` tests whether `smarts` occurs in a Molecule. When `min_val`
        and/or `max_val` are given, the number of occurrences must lie between them
        for the pattern to count as a match. With `exclude=True` a matching molecule
        fails the filter; with `exclude=False` a molecule fails when it does not match
        '''

        # fallback bounds handed to `validate_range`: 1 and 1e8 occurrences
        lower, upper = validate_range(min_val, max_val, 1, 100_000_000)

        self.smarts, self.name, self.exclude = smarts, name, exclude
        self.min_val, self.max_val = lower, upper

        # the RDKit matcher is built once here and reused for every molecule
        self.smarts_matcher = SmartsMatcher(self.name, self.smarts, self.min_val, self.max_val)

    def has_match(self, molecule: Molecule) -> bool:
        # delegate to the RDKit matcher, which works on the molecule's `mol` attribute
        matcher = self.smarts_matcher
        return matcher.HasMatch(molecule.mol)

    def __call__(self, molecule: Molecule) -> FilterResult:

        matched = self.has_match(molecule)

        # exclusion filters invert the match; inclusion filters use it directly
        if self.exclude:
            passed = not matched
        else:
            passed = matched

        # the raw match outcome is reported in the data dict
        return FilterResult(passed, self.name, {'filter_result' : matched})

In [None]:
# SMARTS for a six-membered ring of carbon atoms joined by aromatic bonds
smarts = '[#6]1:[#6]:[#6]:[#6]:[#6]:[#6]:1'

f1 = SmartsFilter(smarts, name='one_phenyl', min_val=1, max_val=1)  # exactly one occurrence
f2 = SmartsFilter(smarts, name='two_phenyl', min_val=2)             # at least two occurrences

smiles = [
    'c1ccccc1',
    'Cc1cc(NC)cnc1',
    'Cc1cccc(NCc2ccccc2)c1'
]

molecules = [Molecule(smile) for smile in smiles]

# expected matches, in the order of `smiles`
assert [bool(f1.has_match(m)) for m in molecules] == [True, False, False]
assert [bool(f2.has_match(m)) for m in molecules] == [False, False, True]

In [None]:
#| export

class CatalogFilter(Filter):
    def __init__(self,
                 catalog: Catalog, # catalog the molecule is matched against
                 name:    str, # filter name
                 exclude: bool=True # True: matching molecules fail. False: matching molecules pass
                ):

        '''
        `CatalogFilter` tests a molecule against the provided `Catalog`.
        With `exclude=True` a molecule that matches the catalog fails the filter;
        with `exclude=False` a matching molecule passes
        '''

        super().__init__(name=name)
        self.catalog, self.exclude = catalog, exclude

    def has_match(self, molecule: Molecule) -> bool:
        # matching is delegated entirely to the catalog
        catalog = self.catalog
        return catalog.has_match(molecule)

    def __call__(self, molecule: Molecule) -> FilterResult:
        matched = self.has_match(molecule)

        # exclusion filters invert the match; inclusion filters use it directly
        if self.exclude:
            passed = not matched
        else:
            passed = matched

        # the raw match outcome is reported in the data dict
        return FilterResult(passed, self.name, {'filter_result' : matched})

In [None]:
from rdkit.Chem.FilterCatalog import FilterCatalogParams

# build a catalog from RDKit's PAINS filter set and wrap it in an exclusion filter (the default)
catalog = Catalog.from_params(FilterCatalogParams.FilterCatalogs.PAINS)
f = CatalogFilter(catalog=catalog, name='pains')

molecule = Molecule('c1ccccc1N=Nc1ccccc1')

# the molecule matches the catalog, so the exclusion filter rejects it
assert f.has_match(molecule)
assert not f(molecule).filter_result

In [None]:
#| hide
# Run nbdev's export step (presumably writes the `#| export` cells to the module named by `default_exp` — confirm against the nbdev docs)
import nbdev; nbdev.nbdev_export()