# Filters

> Filter-related classes and functions

In [None]:
#| default_exp filter

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from chem_templates.imports import *
from chem_templates.utils import *
from chem_templates.chem import Molecule

In [None]:
#| export

class FilterResult:
    """Outcome of applying one filter to one molecule.

    Attributes:
        filter_result: whether the molecule passed the filter.
        filter_name: name of the filter that produced this result.
        filter_data: auxiliary values computed by the filter (may be empty).
    """
    def __init__(self, filter_result: bool, filter_name: str, filter_data: dict):
        self.filter_result = filter_result
        # Keep the producing filter's name so a result can be traced back to
        # the filter that created it (the argument was accepted but dropped).
        self.filter_name = filter_name
        self.filter_data = filter_data

    def __repr__(self):
        return (f'{type(self).__name__}(filter_result={self.filter_result!r}, '
                f'filter_name={self.filter_name!r}, '
                f'filter_data={self.filter_data!r})')

class Filter:
    """Base class for molecule filters.

    Subclasses override ``__call__`` to return a ``FilterResult`` for a given
    molecule. This base implementation accepts every molecule and records no
    extra data.

    Args:
        name: identifier of the filter; also used as its ``repr``.
    """
    def __init__(self, name='filter'):
        self.name = name

    def __call__(self, molecule: Molecule) -> FilterResult:
        # Default behaviour: pass unconditionally.
        return FilterResult(filter_result=True, filter_name=self.name, filter_data={})

    def __repr__(self):
        return self.name

In [None]:
#| export

class ValidityFilter(Filter):
    """Pass a molecule exactly when its ``valid`` attribute is truthy.

    The molecule's ``valid`` value is stored unchanged as the filter result;
    no extra data is recorded.
    """
    def __init__(self):
        super().__init__(name='validity_filter')

    def __call__(self, molecule: Molecule) -> FilterResult:
        is_valid = molecule.valid
        return FilterResult(is_valid, self.name, {})
    
class SingleCompoundFilter(Filter):
    """Reject SMILES that consist of several disconnected components.

    In SMILES, '.' separates disconnected fragments (e.g. ``'CCC.CCCC'``), so a
    molecule passes only when its SMILES string contains no '.' character.
    """
    def __init__(self):
        super().__init__(name='single_compound')

    def __call__(self, molecule: Molecule) -> FilterResult:
        is_single = '.' not in molecule.smile
        return FilterResult(is_single, self.name, {})

In [None]:
# A plain single compound, a SMILES expected to be invalid (lowercase aromatic
# 'c' in an acyclic chain), and a two-fragment mixture.
single = Molecule('CCCC')
invalid = Molecule('CCCc')
mixture = Molecule('CCC.CCCC')

validity = ValidityFilter()
single_compound = SingleCompoundFilter()

# ValidityFilter: only the invalid SMILES is rejected; a mixture is still valid
assert validity(single).filter_result
assert not validity(invalid).filter_result
assert validity(mixture).filter_result

# SingleCompoundFilter: '.'-separated fragments are rejected
assert single_compound(single).filter_result
assert not single_compound(mixture).filter_result

In [None]:
#| export

class AttachmentCountFilter(Filter):
    """Pass molecules whose SMILES has exactly ``num_attachments`` '*' atoms.

    '*' marks attachment points (e.g. ``'[*:1]CC'``). The count is done on the
    SMILES string, and the observed count is stored in ``filter_data`` under
    ``'num_attachments'``.

    Args:
        num_attachments: required number of '*' occurrences.
    """
    def __init__(self, num_attachments: int):
        self.num_attachments = num_attachments
        super().__init__(name=f'attachment_count_{num_attachments}')

    def __call__(self, molecule: Molecule) -> FilterResult:
        found = molecule.smile.count('*')
        return FilterResult(found == self.num_attachments,
                            self.name,
                            {'num_attachments': found})


In [None]:
# One labelled attachment point satisfies AttachmentCountFilter(1); two do not.
one_attachment = AttachmentCountFilter(1)
cases = [('[*:1]CC', True), ('[*:1]CC[*:2]', False)]

for smiles, expected in cases:
    assert bool(one_attachment(Molecule(smiles)).filter_result) is expected

In [None]:
#| export

class BinaryFunctionFilter(Filter):
    """Filter backed by a predicate ``func(molecule) -> bool``.

    The predicate's return value is stored as-is in
    ``FilterResult.filter_result``; no extra data is recorded.

    Args:
        func: predicate called with the molecule being filtered.
        name: name of this filter (used for the result's ``filter_name`` and
            for ``repr``).
    """
    def __init__(self, func: Callable[[Molecule], bool], name: str):
        super().__init__(name)
        self.func = func

    # Return annotation added for consistency with the other filters.
    def __call__(self, molecule: Molecule) -> FilterResult:
        result = self.func(molecule)
        return FilterResult(result, self.name, {})
    
class DataFunctionFilter(Filter):
    """Filter backed by a function that returns ``(passed, data)``.

    Args:
        func: called with the molecule; its return value is unpacked into
            exactly two items, the pass/fail flag and the dict stored as
            ``filter_data``.
        name: name of this filter.
    """
    def __init__(self, func: Callable[[Molecule], Tuple[bool, dict]], name: str):
        self.func = func
        super().__init__(name)

    def __call__(self, molecule: Molecule) -> FilterResult:
        outcome = self.func(molecule)
        passed, extra = outcome
        return FilterResult(passed, self.name, extra)

In [None]:
from rdkit.Chem import rdMolDescriptors
# A multi-ring molecule, an acyclic one, and a single-ring one.
mol1 = Molecule('Cc1nnc2n1-c1ccc(Cl)cc1C(c1ccccc1)=NC2')
mol2 = Molecule('CCCC')
mol3 = Molecule('c1ccccc1')

def filter_func(molecule):
    "Return True if `molecule` contains at least one ring."
    n_rings = rdMolDescriptors.CalcNumRings(molecule.mol)
    # `> 0`, not `> 1`: a molecule with a single ring must satisfy 'has_ring'.
    return n_rings > 0

f = BinaryFunctionFilter(filter_func, 'has_ring')

assert f(mol1).filter_result
assert not f(mol2).filter_result
assert f(mol3).filter_result

In [None]:
#| export

class RangeFunctionFilter(Filter):
    """Pass molecules whose computed value lies in ``[min_val, max_val]``.

    Both bounds are inclusive. The computed value and the bounds are recorded
    in ``filter_data`` under ``'computed_value'``, ``'min_val'`` and
    ``'max_val'``.

    Args:
        func: maps a molecule to a numeric value (e.g. molecular weight).
        min_val: inclusive lower bound (int or float).
        max_val: inclusive upper bound (int or float).
        name: name of this filter.
    """
    # `float` also accepts ints for type checkers; the previous `[int, float]`
    # was a list literal, not a type, and `func` returns a number, not a bool.
    def __init__(self, func: Callable[[Molecule], float], min_val: float,
                 max_val: float, name: str):
        self.func = func
        self.min_val = min_val
        self.max_val = max_val
        super().__init__(name)

    def __call__(self, molecule: Molecule) -> FilterResult:
        value = self.func(molecule)
        data = {'computed_value' : value, 'min_val' : self.min_val, 'max_val' : self.max_val}
        # Chained comparison: inclusive on both ends.
        result = self.min_val <= value <= self.max_val
        return FilterResult(result, self.name, data)

In [None]:
from rdkit.Chem import rdMolDescriptors

def molwt(molecule):
    "Exact molecular weight of `molecule.mol` as computed by RDKit's `CalcExactMolWt`."
    return rdMolDescriptors.CalcExactMolWt(molecule.mol)

molwt_filter = RangeFunctionFilter(molwt, 250, 350, 'molwt_filter')

# The polycyclic molecule falls inside the 250-350 window; butane is far below it.
in_range = Molecule('Cc1nnc2n1-c1ccc(Cl)cc1C(c1ccccc1)=NC2')
too_light = Molecule('CCCC')

assert molwt_filter(in_range).filter_result
assert not molwt_filter(too_light).filter_result

In [None]:
#| hide
# TODO: dummy heavy atom filter
# TODO: filter on attachment/dummy schema

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()