In [None]:
class ConstantBlock():
    """A leaf block whose value never varies.

    It always samples the same ``constant``, has no links or subblocks of its
    own, and never claims an incoming fragment.

    Args:
        constant: the fixed value emitted whenever the block is sampled.
        name: identifier for the block.
    """

    def __init__(self, constant, name):
        self.constant = constant
        self.name = name
        # A constant block owns no attachment links and no children.
        self.links = []
        self.subblocks = []
        self.sublinks = []

    def match_fragment(self, fragment):
        """Constant blocks never match an incoming fragment."""
        return False

    def match_fragment_recursive(self, fragment):
        """Constant blocks have no subblocks, so recursive matching also fails."""
        return False

    def load_data(self, fragments, recurse=False):
        """No-op: a constant block has nothing to load."""
        pass

    def sample_data(self, n):
        """Return an ``n``-row DataFrame whose ``smiles`` column repeats the constant."""
        return pd.DataFrame([self.constant] * n, columns=['smiles'])

    def sample_smile(self, n):
        """Return ``n`` copies of the constant as a list.

        Defined on the base class so that any constant leaf supports the
        ``sample_smile`` call made when constants are included in sampling.
        """
        return [self.constant] * n

    def __repr__(self):
        return f'Constant Block: {self.constant}'


class ConstantMolBlock(ConstantBlock):
    """A constant block holding one fixed molecule given as SMILES.

    Attachment points are dummy atoms written as ``[iso*:map]``.
    ``self.smile`` holds the canonicalised SMILES, and ``self.links`` holds
    the ``(iso, map)`` string pairs found in the input.

    Args:
        smile: SMILES of the constant fragment.
        name: identifier for the block.
    """

    def __init__(self, smile, name):
        super().__init__(smile, name)
        self.smile = Chem.CanonSmiles(smile)
        # Rewrite unlabelled dummy atoms ``[*`` as isotope-0 ``[0*`` so every
        # attachment point in ``self.smile`` carries an explicit label.
        if '[*' in self.smile:
            self.smile = self.smile.replace('[*', '[0*')
        # Raw string: ``\[`` and ``\*`` are invalid escapes in a normal string
        # literal (SyntaxWarning on Python 3.12+). The pattern value is unchanged.
        self.pattern = re.compile(r'\[(.*?)\*:(.*?)]')
        # NOTE(review): links are parsed from the raw input, not from the
        # canonical ``self.smile``. An input written as ``[*:1]`` yields
        # ('', '1') even though ``self.smile`` contains ``[0*:1]``.
        self.links = self.pattern.findall(smile)

    def sample_smile(self, n):
        """Return ``n`` copies of the canonical SMILES.

        NOTE(review): the inherited ``sample_data`` repeats the raw input
        (``self.constant``), not this canonical form.
        """
        return [self.smile] * n

    def __repr__(self):
        return f'Constant Block: {self.smile}'

In [None]:
# Scheme 1: constant scaffold with two variable R groups.

# Scaffold: fixed core carrying attachment points [0*:1] and [0*:2].
scaffold_smile = 'c1nc2c([0*:2])cncc2cc1[0*:1]'
scaffold_block = ConstantMolBlock(scaffold_smile, 'scaffold')

# R1 (link ('1', '1'); presumably pairs with the scaffold's [0*:1]):
# hard limits of 50-250 g/mol and exactly one ring; soft preference 100-200 g/mol.
r1_template = Template(
    HardFilters([
        MolWtFilter(50, 250, None),
        RingFilter(1, 1, None),
    ]),
    SoftFilters([
        MolWtFilter(100, 200, 1),
    ]),
    failscore=-1)
r1_block = MolBlock(r1_template, [('1', '1')], 'r1')


# R2 (link ('1', '2'); presumably pairs with the scaffold's [0*:2]):
# hard limits of 0-200 g/mol and no rings; soft preference 50-150 g/mol.
r2_template = Template(
    HardFilters([
        MolWtFilter(0, 200, None),
        RingFilter(0, 0, None),
    ]),
    SoftFilters([
        MolWtFilter(50, 150, 1),
    ]),
    failscore=-1)
r2_block = MolBlock(r2_template, [('1', '2')], 'r2')


# Full compound: 200-550 g/mol, used both as the hard limit and the soft target.
full_template = Template(
    HardFilters([
        MolWtFilter(200, 550, None),
    ]),
    SoftFilters([
        MolWtFilter(200, 550, 1),
    ]),
    failscore=-1)

# Block name corrected from the misspelt 'full_molecue' so it matches the
# 'full_molecule' name used by scheme 2 and by the block filter on fuse_df.
main_block = MolBlock(full_template, [], 'full_molecule', subblocks=[scaffold_block, r1_block, r2_block])

In [None]:
# Feed the first 10,000 SMILES from `df` into the scheme-1 tree (recurse=True).
main_block.load_data(df['smiles'].values[:10_000], recurse=True)

In [None]:
# Feed the first 100,000 fragment SMILES from `frag_df` into the scheme-1 tree.
main_block.load_data(frag_df['smiles'].values[:100_000], recurse=True)

In [None]:
# Two R-group fragments (map numbers 2 and 1) joined into one dot-separated SMILES.
fragment_string = '.'.join([
    '[1*:2]C(=O)NC(CO)CC(C)C',
    '[1*:1]S(=O)(=O)NCc1ccc(O)c(C(=O)O)c1',
])

In [None]:
# Append the scaffold to the R-group fragments and fuse them into one molecule.
main_block.fuse_fragments('.'.join((fragment_string, scaffold_block.smile)))

'CC(C)CC(CO)NC(=O)c1cncc2cc(S(=O)(=O)NCc3ccc(O)c(C(=O)O)c3)cnc12'

In [None]:
# Test the [1*:2] fragment against the top-level block only (reuses the
# first fragment of `fragment_string` instead of repeating the literal).
main_block.match_fragment(fragment_string.split('.')[0])

False

In [None]:
# Same fragment, but also tested against the subblocks; presumably matched by
# the r2 block, whose link ('1', '2') corresponds to [1*:2].
main_block.match_fragment_recursive(fragment_string.split('.')[0])

True

In [None]:
# Fuse and score the example fragments through the whole block tree
# (reuses `fragment_string` rather than duplicating its literal).
out = main_block.recurse_fragments(fragment_string)

In [None]:
# First element: the fused SMILES (same result as the fuse_fragments call above).
out[0]

'CC(C)CC(CO)NC(=O)c1cncc2cc(S(=O)(=O)NCc3ccc(O)c(C(=O)O)c3)cnc12'

In [None]:
# Second element: a boolean. Presumably whether the fused molecule passed the
# hard filters — confirm against MolBlock.recurse_fragments.
out[1]

True

In [None]:
# Third element: a float. Presumably the accumulated soft-filter score —
# confirm against MolBlock.recurse_fragments.
out[2]

2.0

In [None]:
# Scheme 2: constant scaffold with two variable R groups. R1 is itself built
# from subblocks around a ring that has two outward attachments: one to a
# constant group and one to a variable substituent.

# Scaffold: fixed core carrying attachment points [0*:1] and [0*:2].
scaffold_smile = 'c1nc2c([0*:2])cncc2cc1[0*:1]'
scaffold_block = ConstantMolBlock(scaffold_smile, 'scaffold')

# R1 is assembled from three subblocks: a constant carboxylic-acid group
# ('C(O)(=O)[1*:4]', named 'carbonyl'), a variable ring, and a variable
# substituent on that ring.

r1_carbonyl_block = ConstantMolBlock('C(O)(=O)[1*:4]', 'carbonyl')

# Variable ring substituent (link ('1', '3')): no rings and 0-3 rotatable
# bonds as hard filters; 0-2 rotatable bonds preferred.
r1_ring_substitution_tempate = Template(
                                    HardFilters([
                                        RotBondFilter(0, 3, None),
                                        RingFilter(0,0,None)
                                    ]),
                                    SoftFilters([
                                      RotBondFilter(0, 2, 1)
                                    ]),
                                    failscore = -1)

r1_ring_substitution_block = MolBlock(r1_ring_substitution_tempate, [('1', '3')], 'r1 ring substitution')

# Variable ring: exactly one ring and at most 1 rotatable bond. Its links
# ('0', '3'), ('0', '4') and ('1', '2') presumably connect it to the
# substituent ([1*:3]), the carboxyl group ([1*:4]) and, via [1*:2], to the
# scaffold's [0*:2] site.
# NOTE(review): unlike most templates here no failscore is passed, so
# Template's default applies — confirm that is intended.
r1_ring_template = Template(
                            HardFilters([
                                RingFilter(1,1,None),
                                RotBondFilter(0,1, None)
                            ]),
                            SoftFilters([
                                RingFilter(1,1,1)
                            ]))
r1_ring_block = MolBlock(r1_ring_template, [('0', '3'), ('0', '4'), ('1', '2')], 'r1_ring')


# Whole R1 group: 50-250 g/mol as a hard limit, 100-200 g/mol preferred.
# NOTE(review): no failscore here either — confirm the default is intended.
r1_full_group_template = Template(
                            HardFilters([
                                MolWtFilter(50, 250, None)
                            ]),
                            SoftFilters([
                                MolWtFilter(100, 200, 1)
                            ]))
r1_block = MolBlock(r1_full_group_template, [('1', '2')], 'r1_full', 
                     subblocks=[r1_carbonyl_block, r1_ring_substitution_block, r1_ring_block])


# R2 (link ('1', '1')): no rings and 0-200 g/mol as hard limits;
# 50-150 g/mol preferred.

r2_template = Template(
                    HardFilters([
                        MolWtFilter(0, 200, None),
                        RingFilter(0,0,None)
                    ]),
                    SoftFilters([
                        MolWtFilter(50,150,1),
                    ]),
                    failscore = -1)
r2_block = MolBlock(r2_template, [('1', '1')], 'r2')


# Full compound: 200-550 g/mol. The StructureFilter lists two SMARTS patterns:
# a carboxyl carbon attached to another atom, and the scaffold core with its
# two substituents. Presumably (criteria='all', exclude=False) both must be present.

full_template = Template(
                    HardFilters([
                        MolWtFilter(200, 550, None),
                        StructureFilter(['[#6](-[#8])(=[#8])-[*]',
                             '[#6]1:[#7]:[#6]2:[#6](-[*]):[#6]:[#7]:[#6]:[#6]:2:[#6]:[#6]:1-[*]'
                             ], 1, criteria='all', exclude=False)
                    ]),
                    SoftFilters([
                        MolWtFilter(200,550,1),
                    ]),
                    failscore = -1)

main_block = MolBlock(full_template, [], 'full_molecule', subblocks=[scaffold_block, r1_block, r2_block])

In [None]:
# Feed the first 200,000 SMILES from `df` into the top-level block only (recurse=False).
main_block.load_data(df['smiles'].values[:200_000], recurse=False)

In [None]:
# Feed the first 200,000 fragment SMILES from `frag_df` into the tree, recursing into subblocks.
main_block.load_data(frag_df['smiles'].values[:200_000], recurse=True)

In [None]:


def sample_leaf_nodes(block, include_constant=False):
    """Draw one sample from every leaf under ``block``, depth-first.

    Args:
        block: a block; leaves are blocks whose ``subblocks`` is empty.
        include_constant: when False, ``ConstantBlock`` leaves are skipped.

    Returns:
        For a block with subblocks, a flat list of the leaves' samples in
        depth-first order, with ``None`` entries removed. For a leaf, whatever
        its ``sample_smile(1)`` returns (flattened and filtered the same way if
        it is a list), or ``None`` for an excluded constant leaf.
    """
    if block.subblocks:
        output = [sample_leaf_nodes(sub, include_constant=include_constant)
                  for sub in block.subblocks]
    elif include_constant or not isinstance(block, ConstantBlock):
        output = block.sample_smile(1)
    else:
        # Excluded constant leaf; the parent drops this None when filtering.
        output = None

    if isinstance(output, list):
        output = [item for item in flatten_recursive(output) if item is not None]

    return output


def sample_leaf_nodes_n(n, block, include_constant=False):
    """Call ``sample_leaf_nodes`` ``n`` times and return the list of samples."""
    return [sample_leaf_nodes(block, include_constant=include_constant)
            for _ in range(n)]

In [None]:
frags = maybe_parallel(
    sample_leaf_nodes_n,
    [1] * 10_000,  # 10,000 tasks, each drawing a single sample
    block=main_block,
    include_constant=False,
)

In [None]:
# Flatten the per-task batches into one list of samples.
# NOTE: this rebinds `frags`; re-running it without re-running the sampling
# cell above would flatten one level too far.
frags = [sample for batch in frags for sample in batch]

In [None]:
# One sample: a list with one SMILES per live (non-constant) leaf, in tree order.
frags[0]

['[1*:3]CCCCC#N', '[0*:3]CC1C(O)C(O)C([1*:2])NC(=O)N1[0*:4]', '[1*:1]CCSC']

In [None]:
# Fuse and score every sampled fragment set through the full block tree.
fused = maybe_parallel(main_block.recurse_fragments, frags)

In [None]:
# Each fused result is indexed positionally: [0] is the fused SMILES and
# [-1] is a list of score records (presumably one dict per block — see fuse_df).
fused_smiles = [result[0] for result in fused]
score_dicts = [record for result in fused for record in result[-1]]

In [None]:
# One row per score record.
fuse_df = pd.DataFrame.from_records(score_dicts)

In [None]:
# Fraction of rows whose fused SMILES repeats an earlier row.
fuse_df['fused'].duplicated().mean()

0.63992

In [None]:
# Keep the first row for each fused SMILES (reassignment instead of inplace=True).
fuse_df = fuse_df.drop_duplicates(subset='fused')

In [None]:
# Renumber rows 0..n-1 after de-duplication (reassignment instead of inplace=True).
fuse_df = fuse_df.reset_index(drop=True)

In [None]:
# De-duplicated per-block score table (pandas truncates the display of long frames).
fuse_df

Unnamed: 0,block,fused,fragments,block_pass,block_score,all_pass,all_score
0,r1 ring substitution,N#CCCCC[1*:3],[[1*:3]CCCCC#N],True,0.0,True,0.0
1,r1_ring,O=C1NC([1*:2])C(O)C(O)C(C[0*:3])N1[0*:4],[[0*:3]CC1C(O)C(O)C([1*:2])NC(=O)N1[0*:4]],False,0.0,False,0.0
2,r1_full,N#CCCCCCC1C(O)C(O)C([1*:2])NC(=O)N1C(=O)O,"[O=C(O)[1*:4], N#CCCCC[1*:3], O=C1NC([1*:2])C(...",False,0.0,False,0.0
3,r2,CSCC[1*:1],[[1*:1]CCSC],True,1.0,True,1.0
4,full_molecule,CSCCc1cnc2c(C3NC(=O)N(C(=O)O)C(CCCCCC#N)C(O)C3...,"[c1nc2c([0*:2])cncc2cc1[0*:1], N#CCCCCCC1C(O)C...",True,1.0,False,2.0
...,...,...,...,...,...,...,...
17999,r1_full,NC(CO)C1NC(=O)NC(C(=O)O)C1[1*:2],"[O=C(O)[1*:4], NC(CO)[1*:3], O=C1NC([0*:3])C([...",True,0.0,False,1.0
18000,full_molecule,CC(CCc1cnc2c(C3C(C(=O)O)NC(=O)NC3C(N)CO)cncc2c...,"[c1nc2c([0*:2])cncc2cc1[0*:1], NC(CO)C1NC(=O)N...",True,1.0,False,3.0
18001,r1_full,CSCCC1C(O)C(CC(=O)O)N(C)S(=O)(=O)N(C)C1[1*:2],"[O=C(O)[1*:4], CSCC[1*:3], CN1C(C[0*:4])C(O)C(...",False,0.0,False,1.0
18002,full_molecule,CSCCC1C(O)C(CC(=O)O)N(C)S(=O)(=O)N(C)C1c1cncc2...,"[c1nc2c([0*:2])cncc2cc1[0*:1], CSCCC1C(O)C(CC(...",True,1.0,False,3.0


In [None]:
# One column per live leaf; each sample holds three fragments (see frags[0] above).
subfrags = pd.DataFrame(frags, columns=[f'f{k}' for k in (1, 2, 3)])

In [None]:
# Number of distinct fragment combinations that were sampled.
(~subfrags.duplicated()).sum()

9988

In [None]:
# Results whose fused SMILES still contains a dummy atom (i.e. an unfilled attachment).
list(filter(lambda result: '*' in result[0], fused))

[]

In [None]:
# Fragments of the first full molecule whose own block filters passed.
fuse_df.loc[fuse_df['block'].eq('full_molecule') & fuse_df['block_pass'].eq(True), 'fragments'].iloc[0]

['c1nc2c([0*:2])cncc2cc1[0*:1]',
 'N#CCCCCCC1C(O)C(O)C([1*:2])NC(=O)N1C(=O)O',
 'CSCC[1*:1]']

In [None]:
# Look up a specific R1 fused SMILES in the table.
fuse_df.loc[fuse_df['fused'].eq('COC(=O)CCN1C([1*:2])C(O)C(O)C(C(=O)O)N(C)S1(=O)=O')]

Unnamed: 0,block,fused,fragments,block_pass,block_score,all_pass,all_score


In [None]:
# Evaluate the R1 full-group template directly on that SMILES. Its weight
# (next cell) is above the 250 upper limit of the template's hard MolWtFilter,
# which would explain the failing result.
r1_full_group_template('COC(=O)CCN1C([1*:2])C(O)C(O)C(C(=O)O)N(C)S1(=O)=O')

[False, 0]

In [None]:
# Molecular weight of the same SMILES.
# NOTE(review): confirm whether `molwt` returns average or exact (monoisotopic) mass.
molwt(to_mol('COC(=O)CCN1C([1*:2])C(O)C(O)C(C(=O)O)N(C)S1(=O)=O'))

325.070561504

In [None]:
class BlockTree():
    """Flattened view over a block hierarchy rooted at ``head_block``.

    Attributes:
        head_block: the root block.
        nodes: every block in the hierarchy, depth-first pre-order (root first).
        leaf_nodes: nodes with no subblocks.
        live_leafs: leaf nodes that are not ``ConstantBlock`` instances.
        node_dict: maps each block's ``name`` to the block (a later node
            overwrites an earlier one with the same name).
    """

    def __init__(self, head_block):
        self.head_block = head_block
        self.nodes = self.nodes_to_list(self.head_block)
        self.leaf_nodes = [node for node in self.nodes if not node.subblocks]
        self.live_leafs = [node for node in self.leaf_nodes
                           if not isinstance(node, ConstantBlock)]
        self.node_dict = {node.name: node for node in self.nodes}

    def nodes_to_list(self, block):
        """Return ``block`` followed by all its descendants, depth-first pre-order."""
        nodes = [block]
        for subblock in block.subblocks or []:
            nodes += self.nodes_to_list(subblock)
        return nodes

    def _select_leaves(self, include_constant=False):
        """Return every leaf if ``include_constant``, else only the non-constant ones."""
        return self.leaf_nodes if include_constant else self.live_leafs

    def sample_leaf_nodes(self, include_constant=False):
        """Placeholder; subclasses implement leaf sampling."""
        pass

    def combinatorial_sample(self, n_sample, max_n, include_constant=False):
        """Placeholder; subclasses implement combinatorial sampling."""
        pass


class MolBlockTree(BlockTree):
    """``BlockTree`` whose leaves provide ``sample_data(n)``, returning a DataFrame with a ``smiles`` column."""

    def _sample_leaf_nodes(self, include_constant=False):
        """Draw one SMILES from each selected leaf.

        Returns:
            One single-element list of SMILES per selected leaf, in leaf order.
        """
        return [list(node.sample_data(1).smiles.values)
                for node in self._select_leaves(include_constant)]

    def sample_leaf_nodes(self, n, include_constant=False):
        """Run ``_sample_leaf_nodes`` ``n`` times through ``maybe_parallel``."""
        return maybe_parallel(self._sample_leaf_nodes, [include_constant] * n)

    def combinatorial_sample(self, n_sample, max_n, include_constant=False):
        """Cross-combine the unique samples drawn from each selected leaf.

        Args:
            n_sample: rows requested from each leaf's ``sample_data``.
            max_n: maximum number of combinations to return (integer);
                ``None`` means no cap, and values <= 0 give an empty list.
            include_constant: also combine constant leaves.

        Returns:
            A list of tuples with one SMILES per selected leaf, in
            ``itertools.product`` order. The last leaf varies fastest, so a
            small cap mostly varies the later leaves.
        """
        unique_per_leaf = [
            # dict.fromkeys de-duplicates while keeping first-seen order, so the
            # product order is reproducible (set order depends on string hashing).
            list(dict.fromkeys(node.sample_data(n_sample).smiles.values))
            for node in self._select_leaves(include_constant)
        ]
        combos = itertools.product(*unique_per_leaf)
        # The previous loop appended before testing ``i > max_n`` and so
        # returned max_n + 2 combinations; islice caps at exactly max_n.
        stop = None if max_n is None else max(int(max_n), 0)
        return list(itertools.islice(combos, stop))

In [None]:
# Flattened view (nodes, leaves, name lookup) of the scheme-2 block hierarchy.
tree = MolBlockTree(main_block)