In [None]:
# default_exp templates.blocks

# Blocks

> Blocks are used for advanced templating, where different templates are applied to different sections of a molecule.



In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
# export
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.filters import *
from mrl.templates.template import *

  return f(*args, **kwds)


In [None]:
# export

class Block():
    '''
    Base class for a templated section of a molecule.

    A `Block` pairs a `Template` with a list of link labels (attachment
    points) and an optional list of child blocks. Subclasses implement the
    representation-specific pieces: `match_fragment`, `decompose_fragments`,
    `join_fragments`, `fuse_fragments` and `recurse_fragments`.

    Inputs:

    - `template`: template used to evaluate fragments for this block

    - `links`: link labels this block attaches through

    - `name`: block name (used as a key when logging)

    - `subblocks`: optional list of child blocks (default: no children)
    '''
    def __init__(self, template, links, name, subblocks=None):
        self.template = template
        self.links = links
        self.name = name
        # `None` sentinel instead of a mutable `[]` default, so separate
        # instances never share one list object
        self.subblocks = [] if subblocks is None else subblocks
        self.sublinks = []
        self.update_links()

    def update_links(self):
        '''
        Rebuild `self.sublinks` from the `links` of every subblock followed by
        that subblock's own `sublinks` (so links are collected transitively).
        Safe to call repeatedly; previous contents are discarded.
        '''
        self.sublinks = []
        for b in self.subblocks:
            self.sublinks.append(b.links)
            self.sublinks.extend(b.sublinks)

    def eval_mol(self, mol):
        '''
        Evaluate `mol` against this block's template.

        Hard filters (`template.hf`) run only if the molecule's SMILES passes
        `match_fragment`; soft filters (`template.sf`) run only if the hard
        filters pass. Otherwise the score is `template.fail_score` and the
        corresponding log is empty.

        Returns `[hardpass, score, hardlog, softlog]`.
        '''
        mol = to_mol(mol)
        smile = to_smile(mol)

        if self.match_fragment(smile):
            hardpass, hardlog = self.template.hf(mol)
        else:
            hardpass = False
            hardlog = []

        if hardpass:
            score, softlog = self.template.sf(mol)
        else:
            score = self.template.fail_score
            softlog = []

        return [hardpass, score, hardlog, softlog]

    def match_fragment(self, fragment):
        '''Return True if `fragment` matches this block's link pattern (subclass hook).'''
        raise NotImplementedError

    def match_fragment_recursive(self, fragment):
        '''
        Return True if `fragment` matches this block or any block below it.
        Stops at the first match.
        '''
        if self.match_fragment(fragment):
            return True
        return any(block.match_fragment_recursive(fragment) for block in self.subblocks)

    def sample(self, n, log='hard'):
        '''Delegate to `template.sample(n, log=log)`.'''
        return self.template.sample(n, log=log)

    def load_data(self, fragments, recurse=False):
        '''
        Screen the `fragments` that pass `match_fragment` with
        `template.screen_mols`. If `recurse`, every subblock first receives
        the full, unfiltered `fragments` list.
        '''
        if recurse:
            for b in self.subblocks:
                b.load_data(fragments, recurse=True)

        matches = maybe_parallel(self.match_fragment, fragments)
        fragments = [fragments[i] for i in range(len(fragments)) if matches[i]]
        self.template.screen_mols(fragments)

    def decompose_fragments(self, fragment_string):
        '''Split a string of several fragments into a list of single fragments (subclass hook).'''
        raise NotImplementedError

    def join_fragments(self, fragment_list):
        '''Join a list of fragments into a single string (subclass hook).'''
        raise NotImplementedError

    def fuse_fragments(self, fragment_string):
        '''Fuse a joined fragment string into a single output (subclass hook).'''
        raise NotImplementedError

    def join_and_fuse(self, fragment_list):
        '''Join `fragment_list` with `join_fragments`, then fuse it with `fuse_fragments`.'''
        return self.fuse_fragments(self.join_fragments(fragment_list))

    def recurse_fragments(self, fragment, add_constant=True):
        '''Break fragments down, route to subblocks, fuse and evaluate (subclass hook).'''
        raise NotImplementedError

    def __repr__(self):

        rep_str = f'Block {self.name}: {self.links}\n\t' + '\n\t'.join(self.template.__repr__().split('\n'))

        if self.subblocks:
            rep_str += '\n'
            for b in self.subblocks:
                # indent each child's multi-line repr one level deeper
                rep_str += '\n\t' + '\n\t'.join(b.__repr__().split('\n'))

        return rep_str
        

In [None]:
# export

class MolBlock(Block):
    '''
    `Block` for SMILES fragments joined through mapped dummy atoms.

    Attachment points are written as `[<isotope>*:<map number>]` (e.g.
    `[2*:1]`); `links` lists them without brackets (e.g. `['2*:1']`).
    Fragments are combined by joining with `.` and fusing with
    `fuse_on_atom_mapping`.

    Inputs:

    - `template`: template used to evaluate the fused fragment of this block

    - `links`: link labels this block attaches through (isotope 0 not allowed)

    - `name`: block name (used as a key when logging)

    - `subblocks`: optional list of child blocks
    '''
    def __init__(self, template, links, name, subblocks=None):
        # normalise here too, so a `None` default never reaches `Block`
        subblocks = [] if subblocks is None else subblocks
        super().__init__(template, links, name, subblocks=subblocks)

        # mapped dummy atom such as `[2*:1]`; multi-digit isotopes and map
        # numbers (e.g. `[12*:10]`) are matched as well
        self.pattern = re.compile(r'\[\d+\*:\d+\]')

        for link in self.links:
            # check only the isotope prefix, so e.g. '10*:1' is still allowed
            assert not link.startswith('0*'), "Do not use 0 as an isotope, RDKit automatically removes it"

    def pattern_match(self, fragment):
        '''Return the labels (brackets stripped) of all mapped attachment points in `fragment`.'''
        return [i[1:-1] for i in self.pattern.findall(fragment)]

    def is_mapped(self, fragment):
        '''True if every `*` in `fragment` belongs to a mapped attachment point.'''
        return fragment.count('*') == len(self.pattern_match(fragment))

    def add_mapping(self, fragment, links=None):
        '''
        Replace every bare `*` in `fragment` with `[link]`, in order.

        Fully mapped fragments are returned unchanged; partially mapped
        fragments are stripped first and then remapped. If `links` is None, a
        shuffled copy of `self.links` is used. Assumes unmapped attachment
        points are written as bare `*` characters.
        '''
        if self.is_mapped(fragment):
            return fragment

        if self.pattern_match(fragment):
            # partially mapped, something went wrong - remap from scratch
            fragment = self.remove_mapping(fragment)

        if links is None:
            links = list(self.links)
            random.shuffle(links)

        mapped = ''
        link_count = 0
        for s in fragment:
            if s == '*':
                s = f'[{links[link_count]}]'
                link_count += 1
            mapped += s

        return mapped

    def remove_mapping(self, fragment):
        '''Replace every mapped attachment point `[label]` with a bare `*`.'''
        for match in self.pattern_match(fragment):
            fragment = fragment.replace(f'[{match}]', '*')
        return fragment

    def _has_block_links(self, fragment):
        # True if the mapped attachments of `fragment` are exactly this
        # block's links, each used once
        matches = self.pattern_match(fragment)
        return len(matches) == len(set(matches)) and set(matches) == set(self.links)

    def match_fragment(self, fragment):
        '''
        True if `fragment` has one attachment point per link and, once
        mapped (unmapped fragments get a random assignment of this block's
        links), its labels are exactly this block's links.
        '''
        if fragment.count('*') != len(self.links):
            return False

        if not self.is_mapped(fragment):
            fragment = self.add_mapping(fragment)

        return self._has_block_links(fragment)

    def load_fragment(self, fragment):
        '''
        Map `fragment` for loading. Returns `[mapped_fragment, passed]`;
        `passed` is True only if the attachment count fits and the mapped
        labels are exactly this block's links. Fragments already mapped with
        other blocks' links are rejected.
        '''
        if fragment.count('*') != len(self.links):
            return [fragment, False]

        mapped = self.add_mapping(fragment)
        return [mapped, self._has_block_links(mapped)]

    def load_data(self, fragments, recurse=False):
        '''
        Screen the mapped fragments that pass `load_fragment` with
        `template.screen_mols`. If `recurse`, every subblock first receives
        the full, unfiltered `fragments` list.
        '''
        if recurse:
            for b in self.subblocks:
                b.load_data(fragments, recurse=True)

        fragments = maybe_parallel(self.load_fragment, fragments)
        fragments = [i[0] for i in fragments if i[1]]
        self.template.screen_mols(fragments)

    def sample_smiles(self, n, log='hard'):
        '''Delegate to `template.sample_smiles(n, log=log)`.'''
        return self.template.sample_smiles(n, log=log)

    def shuffle_mapping(self, fragment):
        '''Randomly permute which labels sit on which attachment points of `fragment`.'''
        new_mapping = list(self.pattern_match(fragment))
        random.shuffle(new_mapping)

        fragment = self.remove_mapping(fragment)
        return self.add_mapping(fragment, links=new_mapping)

    def decompose_fragments(self, fragment_string):
        '''Split a `.`-joined fragment string into single fragments.'''
        return fragment_string.split('.')

    def join_fragments(self, fragment_list):
        '''Join fragments with `.`.'''
        return '.'.join(fragment_list)

    def fuse_fragments(self, fragment_string):
        '''Fuse a `.`-joined fragment string with `fuse_on_atom_mapping`.'''
        return fuse_on_atom_mapping(fragment_string)

    def recurse_fragments(self, fragments, add_constant=True):
        '''
        Route fragments to subblocks, fuse them and evaluate every level.

        Inputs:

        - `fragments`: a fragment string (may hold several `.`-separated
          fragments) or a list of such strings

        - `add_constant`: if True, the `smile` of every `ConstantBlock`
          subblock is added before fusing. The flag is not forwarded;
          subblocks are called with the default.

        Returns `(fused, total_pass, total_score, output_dicts)`.
        `output_dicts` holds one log dict per evaluated block, children
        before this block.
        '''
        output_dicts = []
        total_pass = []
        total_score = 0.

        # isinstance also covers str subclasses such as numpy.str_
        if isinstance(fragments, str):
            fragments = [fragments]

        fragments = [item for fragment_string in fragments
                     for item in self.decompose_fragments(fragment_string)]

        if self.subblocks:
            new_fragments = []

            unrouted = list(fragments) # copy list

            for sb in self.subblocks:
                # each fragment goes to the first subblock tree that matches it
                routed = [i for i in unrouted if sb.match_fragment_recursive(i)]
                routed_set = set(routed)
                unrouted = [i for i in unrouted if i not in routed_set]

                if routed:
                    r_fused, r_pass, r_score, subdicts = sb.recurse_fragments(routed)
                    new_fragments.append(r_fused)
                    total_pass.append(r_pass)
                    total_score += r_score
                    output_dicts += subdicts

                if isinstance(sb, ConstantBlock) and add_constant:
                    new_fragments.append(sb.smile)

            fragments = new_fragments + unrouted

        joined_fragments = self.join_fragments(fragments)
        fused = self.fuse_fragments(joined_fragments)

        frag_pass, frag_score, hardlog, softlog = self.eval_mol(fused)
        total_pass.append(frag_pass)
        total_score += frag_score

        total_pass = all(total_pass)

        output_dict = {
            'block' : self.name,
            'fused' : fused,
            'fragments' : fragments,
            'block_pass' : frag_pass,
            'block_score' : frag_score,
            'all_pass' : total_pass,
            'all_score' : total_score,
            'hardlog' : hardlog,
            'softlog' : softlog
        }

        output_dicts.append(output_dict)

        return fused, total_pass, total_score, output_dicts


In [None]:
# export

class ConstantBlock():
    '''
    Leaf block holding a fixed value (`constant`) instead of a template.

    It never matches incoming fragments, ignores loaded data and has no
    links or subblocks, so fragment routing skips it.
    '''
    def __init__(self, constant, name):
        self.constant = constant
        self.name = name
        self.links = []
        self.subblocks = []
        self.sublinks = []

    def match_fragment(self, fragment):
        '''Constant blocks never match a fragment.'''
        return False

    def match_fragment_recursive(self, fragment):
        '''Constant blocks never match a fragment (no subblocks to search).'''
        return False

    def load_data(self, fragments, recurse=False):
        '''No-op: a constant block has no template to load data into.'''
        pass

    def sample_data(self, n):
        '''
        Return an `n`-row DataFrame with columns `smiles` (the constant) and
        `final` (a score of 0.).
        '''
        # one [constant, 0.] row per sample; a flat `[constant, 0.]*n` list is
        # read as a single column and fails against two column names
        return pd.DataFrame([[self.constant, 0.]]*n, columns=['smiles', 'final'])

    def __repr__(self):

        rep_str = f'Constant Block: {self.constant}'

        return rep_str
    

class ConstantMolBlock(ConstantBlock):
    '''
    `ConstantBlock` holding a fixed SMILES fragment (e.g. a scaffold).

    The SMILES is canonicalised with `canon_smile`, and dummy atoms written
    without an isotope (`[*`) are rewritten with an explicit 0 (`[0*`).
    `links` are read from the mapped attachment points (`[<iso>*:<map>]`)
    of the SMILES as passed in.
    '''
    def __init__(self, smile, name):
        super().__init__(smile, name)
        self.smile = canon_smile(smile)
        if '[*' in self.smile:
            self.smile = self.smile.replace('[*', '[0*')
        # mapped dummy atom such as `[1*:2]`; multi-digit values match too
        self.pattern = re.compile(r'\[\d+\*:\d+\]')
        self.links = [i[1:-1] for i in self.pattern.findall(smile)]

    def sample_smiles(self, n):
        '''Return `n` copies of the canonical SMILES.'''
        return [self.smile]*n

    def sample_data(self, n):
        '''
        Return an `n`-row DataFrame with columns `smiles` (the canonical
        SMILES, the same value `sample_smiles` returns) and `final` (0.).
        '''
        return pd.DataFrame([[self.smile, 0.]]*n, columns=['smiles', 'final'])

    def __repr__(self):

        rep_str = f'Constant Block: {self.smile}'

        return rep_str


In [None]:
# scheme 1 - constant scaffold with two variable R groups

# scaffold: constant fragment with two mapped attachment points, [1*:1] and [1*:2]
scaffold_smile = 'c1nc2c([1*:2])cncc2cc1[1*:1]'
scaffold_block = ConstantMolBlock(scaffold_smile, 'scaffold')

# R1 (link 2*:1): hard filters - 50-250 g/mol and exactly 1 ring; soft filter - ideally 100-200 g/mol

r1_template = Template(
                    [MolWtFilter(50, 250),
                     RingFilter(1,1)],
                    [MolWtFilter(100, 200, 1)],
                    fail_score=-1
                    )

r1_block = MolBlock(r1_template, ['2*:1'], 'r1')


# R2 (link 2*:2): hard filters - 0-200 g/mol and no rings; soft filter - ideally 50-150 g/mol

r2_template = Template(
                    [MolWtFilter(0, 200),
                     RingFilter(None,0)],
                    [MolWtFilter(50,150,1)],
                    fail_score=-1
                    )

r2_block = MolBlock(r2_template, ['2*:2'], 'r2')


# full compound: hard filter - 200-550 g/mol. It has no links of its own;
# the fragments routed to its subblocks are fused on matching atom-map numbers

full_template = Template(
                    [MolWtFilter(200, 550)],
                    fail_score=-1)

main_block = MolBlock(full_template, [], 'full_molecule', subblocks=[scaffold_block, r1_block, r2_block])

In [None]:
# presumably the worker count read by `maybe_parallel` - TODO confirm
os.environ['ncpus'] = '8'

In [None]:
df = pd.read_csv('files/smiles.csv')

In [None]:
# split molecules into attachment-point fragments (see `fragment_smiles` for the meaning of `[1]`)
fragments = fragment_smiles(df.smiles.values, [1])

In [None]:
len(fragments)

15898

In [None]:
# whole molecules contain no `*`, so only blocks without links (the head block) accept them
main_block.load_data(df.smiles.values, recurse=True)

In [None]:
# empty: r1 requires one attachment point, so no whole molecule was loaded into it
r1_block.template.soft_log

Unnamed: 0,smiles,0,final


In [None]:
# load the fragments into every block; each MolBlock keeps only fragments whose attachment count fits its links
main_block.load_data(fragments, recurse=True)

In [None]:
r1_block.template.soft_log

Unnamed: 0,smiles,0,final
0,O=C(NC(=O)[2*:1])NC1CC1,1.0,1.0
1,O=C(CNC(=O)C1CCCCC1)OCC(=O)[2*:1],0.0,0.0
2,COCc1cccc(NC(=O)N[2*:1])c1,1.0,1.0
3,O=CC(=Cc1ccc([N+](=O)[O-])o1)C[2*:1],1.0,1.0
4,O=C(CCc1ccc(F)c(F)c1)N[2*:1],1.0,1.0
...,...,...,...
4615,CCC(CC)(C[NH+]1CCCCC1)[2*:1],1.0,1.0
4616,O=C(NC[2*:1])C1CCCC1,1.0,1.0
4617,O=C(C[2*:1])NC1CCCC1,1.0,1.0
4618,Oc1cccc(C[2*:1])c1,1.0,1.0


In [None]:
# sample 5 SMILES from every subblock (the constant scaffold included) and join
# the k-th sample of each subblock into one '.'-separated fragment string
frag_strings = ['.'.join(i) for i in list(zip(*[i.sample_smiles(5) for i in main_block.subblocks]))]

In [None]:
# fuse each fragment string into a single molecule on matching atom-map numbers
[main_block.fuse_fragments(i) for i in frag_strings]

['Cc1cc(C(=O)NCc2cnc3c(N(C)CCS(C)(=O)=O)cncc3c2)sc1C#CCO',
 'Cc1ccc(C(=O)NCc2cnc3c(C(=O)N(C)CCC(C)C)cncc3c2)n1C',
 'CC(C)NC(=O)NNC(=O)c1cncc2cc(C[NH+]3CCC(CO)CC3)cnc12',
 'CCCc1noc(C[NH+](C)c2cnc3c(SCCCCCCS)cncc3c2)n1',
 'CC(C)(CBr)c1cncc2cc(CC(=O)Nc3ccc(Cl)c(C(F)(F)F)c3)cnc12']

In [None]:
# 100 rounds of 5 samples -> 500 fragment strings
frag_strings = []

for i in range(100):
    frag_strings += ['.'.join(i) for i in list(zip(*[i.sample_smiles(5) for i in main_block.subblocks]))]

In [None]:
len(frag_strings), len(set(frag_strings))

(500, 500)

In [None]:
# deduplicate; note the resulting order follows set iteration and is not stable across runs
frag_strings = list(set(frag_strings))

In [None]:
# serial equivalent: out = [main_block.recurse_fragments(i, add_constant=False) for i in frag_strings]
# add_constant=False because the sampled strings already contain the scaffold fragment
out = maybe_parallel(main_block.recurse_fragments, frag_strings, add_constant=False)

In [None]:
# per-block log dicts of the first result: children (r1, r2) before the head block
out[0][-1]

[{'block': 'r1',
  'fused': 'Cc1ccn(CCNC(=O)[2*:1])n1',
  'fragments': ['Cc1ccn(CCNC(=O)[2*:1])n1'],
  'block_pass': True,
  'block_score': 1.0,
  'all_pass': True,
  'all_score': 1.0,
  'hardlog': [],
  'softlog': []},
 {'block': 'r2',
  'fused': 'CC(C)O[2*:2]',
  'fragments': ['CC(C)O[2*:2]'],
  'block_pass': True,
  'block_score': 1.0,
  'all_pass': True,
  'all_score': 1.0,
  'hardlog': [],
  'softlog': []},
 {'block': 'full_molecule',
  'fused': 'Cc1ccn(CCNC(=O)c2cnc3c(OC(C)C)cncc3c2)n1',
  'fragments': ['Cc1ccn(CCNC(=O)[2*:1])n1',
   'CC(C)O[2*:2]',
   'c1nc2c([1*:2])cncc2cc1[1*:1]'],
  'block_pass': True,
  'block_score': 0,
  'all_pass': True,
  'all_score': 2.0,
  'hardlog': ['Cc1ccn(CCNC(=O)c2cnc3c(OC(C)C)cncc3c2)n1', True, True],
  'softlog': ['Cc1ccn(CCNC(=O)c2cnc3c(OC(C)C)cncc3c2)n1', 0]}]

In [None]:
# scheme 2 - constant scaffold with two variable R groups. R1 is itself built from
# a ring with two further attachments: one constant group and one variable group

# scaffold: constant fragment with two mapped attachment points, [1*:1] and [1*:2]
scaffold_smile = 'c1nc2c([1*:2])cncc2cc1[1*:1]'
scaffold_block = ConstantMolBlock(scaffold_smile, 'scaffold')

# R1 has 3 sub-blocks: a constant carboxylic acid group, a variable ring, and a variable ring substituent

r1_carbonyl_block = ConstantMolBlock('C(O)(=O)[1*:4]', 'carbonyl')

# ring substituent (link 1*:3): hard - 0-3 rotatable bonds, no rings; soft - ideally 0-2 rotatable bonds
# NOTE(review): `_tempate` is a typo in the variable name (kept so later references still work)
r1_ring_substitution_tempate = Template(
                                        [RotBondFilter(0,3),
                                         RingFilter(None,0)],
                                        [RotBondFilter(0,2,1)],
                                        fail_score=-1
                                        )

r1_ring_substitution_block = MolBlock(r1_ring_substitution_tempate, ['1*:3'], 'r1 ring substitution')

# ring: hard - exactly 1 ring, 0-1 rotatable bonds; soft - exactly 1 ring.
# links: 2*:3 (substituent), 2*:4 (carboxylic acid), 2*:2 (scaffold)
r1_ring_template = Template(
                            [RingFilter(1,1),
                             RotBondFilter(0,1)],
                            [RingFilter(1,1,1)],
                            fail_score=-1
                            )

r1_ring_block = MolBlock(r1_ring_template, ['2*:3', '2*:4', '2*:2'], 'r1_ring')


# full R1 group (fused from its 3 sub-blocks): hard - 50-350 g/mol; soft - ideally 100-200 g/mol
r1_full_group_template = Template(
                            [MolWtFilter(50,350)],
                            [MolWtFilter(100,200,1)],
                            fail_score=-1
                            )

r1_block = MolBlock(r1_full_group_template, ['2*:2'], 'r1_full', 
                     subblocks=[r1_carbonyl_block, r1_ring_substitution_block, r1_ring_block])


# R2 (link 2*:1): hard - 0-200 g/mol; soft - no rings
# NOTE(review): unlike scheme 1, the ring count is only a soft filter here

r2_template = Template(
                        [MolWtFilter(0,200)],
                        [RingFilter(0,0)],
                        fail_score=-1
                        )

r2_block = MolBlock(r2_template, ['2*:1'], 'r2')


# full compound: hard - 200-550 g/mol and a StructureFilter on the carboxylic acid and
# scaffold-core SMARTS (criteria='all', presumably both must be present - confirm);
# soft - ideally 250-500 g/mol

full_template = Template(
                        [MolWtFilter(200,550),
                         StructureFilter(['[#6](-[#8])(=[#8])-[*]',
                             '[#6]1:[#7]:[#6]2:[#6](-[*]):[#6]:[#7]:[#6]:[#6]:2:[#6]:[#6]:1-[*]'
                             ], criteria='all', exclude=False)],
                        [MolWtFilter(250,500,1)],
                        fail_score=-1
                        )

main_block = MolBlock(full_template, [], 'full_molecule', subblocks=[scaffold_block, r1_block, r2_block])

In [None]:
# recurse=False: only the head block (no links) is loaded, with whole molecules
main_block.load_data(df.smiles.values, recurse=False)

In [None]:
main_block.load_data(fragments, recurse=True)

In [None]:
# NOTE(review): path points outside this repository, so the notebook is not reproducible
# elsewhere - consider a configurable data directory
frag_df = pd.read_csv('../../chem_research/fragments.csv')

In [None]:
main_block.load_data(frag_df.smiles.values, recurse=True)

In [None]:
# hand-picked example fragments for the routing check below
f = ['CNC(=O)C=C[1*:3]',
 'CC(C)(C)C(C1C(=O)CC([2*:2])([2*:3])OC1=O)[2*:4]',
 'CCCCN(C)c1ccc([2*:1])cc1']

# NOTE(review): this overwrites `fragments` from the fragment_smiles cell above
fragments = [main_block.decompose_fragments(i) for i in f]
fragments = flatten_list_of_lists(fragments)

In [None]:
# debug check: which fragments each top-level subblock would claim.
# Unlike MolBlock.recurse_fragments, `unrouted` is never reduced here,
# so one fragment could be listed under several subblocks.
unrouted = list(fragments)
for sb in main_block.subblocks:
    routed = [i for i in unrouted if sb.match_fragment_recursive(i)]
    print(sb.name, routed)

scaffold []
r1_full ['CNC(=O)C=C[1*:3]', 'CC(C)(C)C(C1C(=O)CC([2*:2])([2*:3])OC1=O)[2*:4]']
r2 ['CCCCN(C)c1ccc([2*:1])cc1']


In [None]:
# export

class BlockTree():
    '''
    Wraps a head block and its whole subblock hierarchy.

    Flattens the hierarchy into `nodes` (depth-first, parents before their
    children), records the leaf nodes and the non-constant ("live") leaves,
    and routes evaluation logs back to each block's template by block name.
    '''
    def __init__(self, head_block):
        self.head_block = head_block
        self.nodes = self.nodes_to_list(self.head_block)
        self.leaf_nodes = [i for i in self.nodes if not i.subblocks]
        # leaves that can be sampled, i.e. not fixed constants
        self.live_leafs = [i for i in self.leaf_nodes if not isinstance(i, ConstantBlock)]
        # NOTE: keyed by name - blocks with duplicate names would collide here
        self.node_dict = {i.name:i for i in self.nodes}
        self.log = []

    def nodes_to_list(self, block):
        '''Return `block` followed by all of its descendants, depth-first.'''
        nodes = [block]
        for subblock in block.subblocks or []:
            nodes += self.nodes_to_list(subblock)
        return nodes

    def log_outputs(self, outputs):
        '''
        Group the non-empty `hardlog`/`softlog` entries of the output dicts
        by block name and pass each group to that block's
        `template.log_data` with `filter_type='hard'` or `'soft'`.
        '''
        log_dict = {}

        for output_dict in outputs:
            entry = log_dict.setdefault(output_dict['block'], {'hard': [], 'soft': []})

            if output_dict['hardlog']:
                entry['hard'].append(output_dict['hardlog'])

            if output_dict['softlog']:
                entry['soft'].append(output_dict['softlog'])

        for blockname, logs in log_dict.items():
            node = self.node_dict[blockname]
            if logs['hard']:
                node.template.log_data(logs['hard'], filter_type='hard')

            if logs['soft']:
                node.template.log_data(logs['soft'], filter_type='soft')

    def recurse_fragments(self, fragments, add_constant=True):
        '''
        Run `head_block.recurse_fragments` on each input (via
        `maybe_parallel`), log the per-block results to the templates, and
        return one `[input, fused, all_pass, all_score]` row per input. The
        rows are also appended to `self.log`.

        `fragments` is a single fragment string or a list of them.
        '''
        # isinstance also covers str subclasses such as numpy.str_, which
        # would otherwise be iterated character by character
        if isinstance(fragments, str):
            fragments = [fragments]

        outputs = maybe_parallel(self.head_block.recurse_fragments, fragments, add_constant=add_constant)
        output_data = []
        output_dicts = []

        for fragment, (fused, allpass, allscore, log_dicts) in zip(fragments, outputs):
            output_data.append([fragment, fused, allpass, allscore])
            output_dicts += log_dicts

        self.log_outputs(output_dicts)
        self.log += output_data
        return output_data

    def sample_leaf_nodes(self, include_constant=False):
        '''Sample from the leaf nodes (implemented by subclasses).'''
        raise NotImplementedError

    def combinatorial_sample(self, n_sample, max_n, include_constant=False):
        '''Sample combinations of leaf-node outputs (implemented by subclasses).'''
        raise NotImplementedError
    
    
class MolBlockTree(BlockTree):
    '''
    `BlockTree` over SMILES blocks, adding sampling of fragments from the
    leaf blocks via their `sample_data` method.
    '''
    def __init__(self, head_block):
        super().__init__(head_block)

    def _leaves(self, include_constant):
        # leaf nodes to sample from; constant leaves only when requested
        return self.leaf_nodes if include_constant else self.live_leafs

    def _sample_leaf_nodes(self, include_constant=False):
        '''
        Call `sample_data(1)` on each selected leaf and return a list (one
        entry per leaf, in leaf order) of lists of sampled SMILES.
        '''
        samples = [node.sample_data(1) for node in self._leaves(include_constant)]
        return [list(sample.smiles.values) for sample in samples]

    def sample_leaf_nodes(self, n, include_constant=False):
        '''Run `_sample_leaf_nodes` `n` times (via `maybe_parallel`).'''
        return maybe_parallel(self._sample_leaf_nodes, [include_constant]*n)

    def combinatorial_sample(self, n_sample, max_n, include_constant=False):
        '''
        Draw `n_sample` rows from each selected leaf via `sample_data`,
        deduplicate the SMILES per leaf, and return at most `max_n` tuples
        (one SMILES per leaf) from their Cartesian product. `max_n=None`
        returns every combination. Which combinations come first follows set
        iteration order.
        '''
        leaf_smiles = [list(set(node.sample_data(n_sample).smiles.values))
                       for node in self._leaves(include_constant)]
        combos = itertools.product(*leaf_smiles)

        # islice enforces the limit exactly (the previous append-then-check
        # loop returned up to max_n+2 items)
        return list(itertools.islice(combos, max_n))

In [None]:
# presumably disables parallel execution in `maybe_parallel` - TODO confirm
os.environ['ncpus'] = '0'

In [None]:
blocktree = MolBlockTree(main_block)

In [None]:
# NOTE(review): `frags` is not defined in any visible cell, so Restart & Run All fails here.
# The output (one fragment per live leaf, in `blocktree.live_leafs` order) suggests it was
# built from samples of `blocktree`'s leaves - TODO define it explicitly in a cell above.
frags[0]

['SCCC[1*:3]', 'O=C(N[2*:2])C1CC([2*:4])CCN1[2*:3]', 'C1CC(C[2*:1])CC[NH2+]1']

In [None]:
# NOTE(review): `_` shadows IPython's last-output variable; a descriptive name
# (e.g. `tree_results`) would be safer when cells are re-run or reordered.
# Each row is [input fragments, fused SMILES, all_pass, all_score].
_ = blocktree.recurse_fragments(frags, add_constant=True)

In [None]:
_[0]

[['SCCC[1*:3]',
  'O=C(N[2*:2])C1CC([2*:4])CCN1[2*:3]',
  'C1CC(C[2*:1])CC[NH2+]1'],
 'O=C(O)C1CCN(CCCS)C(C(=O)Nc2cncc3cc(CC4CC[NH2+]CC4)cnc23)C1',
 True,
 3.0]

In [None]:
# number of inputs whose fused molecule passed every block (index 2 = all_pass)
len([i for i in _ if i[2]])

2412

In [None]:
# a single '.'-joined string is also accepted; the result matches the list form above
blocktree.recurse_fragments('.'.join(frags[0]), add_constant=True)

[['SCCC[1*:3].O=C(N[2*:2])C1CC([2*:4])CCN1[2*:3].C1CC(C[2*:1])CC[NH2+]1',
  'O=C(O)C1CCN(CCCS)C(C(=O)Nc2cncc3cc(CC4CC[NH2+]CC4)cnc23)C1',
  True,
  3.0]]

TODO: check whether a template with no hard filters returns True

Clean up docs
add returns between lines in docstrings for better rendering
Make overview page for templates
Make tutorial for enumeration (remove existing from chem notebook)
make tutorial for basic templates
make tutorial for intermediate templates (custom filters, etc)
make tutorial for advanced templates (blocks)
figure out page links in nbdev
figure out import all 

torch core
models (lstm, vae, transformer)

score functions

training loop

policy gradients

q-network

diff-loss

exploration strategies

combichem

pharmacophore 


pages
overview

generative screening primer