# Building Block Assembly

> Building block assembly related functions

In [None]:
#| default_exp building_block_assembly

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from __future__ import annotations
from chem_templates.imports import *
from chem_templates.utils import *
from chem_templates.building_blocks import Synthon, BuildingBlock, ReactionGroup, ReactionUniverse, REACTION_GROUPS
from chem_templates.template import Template, TemplateResult

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class AssemblyPool():
    """A collection of synthons indexed by their attachment marks.

    ``self.synthons`` holds the synthons in insertion order.
    ``self.mark_to_synthon`` maps each mark (taken from ``synthon.marks``)
    to the synthons carrying it. This lets compatible partners be looked up
    by mark instead of scanning the whole pool.
    """
    def __init__(self, synthons: Optional[list[Synthon]]=None):
        """Create a pool, optionally pre-populated with ``synthons``.

        Args:
            synthons: iterable of synthons added in order; ``None`` gives
                an empty pool.
        """
        self.synthons = []
        self.mark_to_synthon = defaultdict(list)
        if synthons is not None:
            for synthon in synthons:
                self.add_synthon(synthon)

    def __len__(self):
        """Number of synthons in the pool."""
        return len(self.synthons)

    def add_synthon(self, synthon: Synthon):
        """Append ``synthon`` and index it under every mark in ``synthon.marks``."""
        self.synthons.append(synthon)
        for mark in synthon.marks:
            self.mark_to_synthon[mark].append(synthon)

    def reaction_filter(self, rxn_universe: ReactionUniverse) -> AssemblyPool:
        """Return a new pool with only the synthons for which
        ``rxn_universe.get_matching_reactions(synthon)`` is truthy.

        Order is preserved; ``self`` is not modified.
        """
        return AssemblyPool([synthon for synthon in self.synthons
                             if rxn_universe.get_matching_reactions(synthon)])

    def get_matching(self, query_synthon: Synthon) -> list[Synthon]:
        """Return the synthons in this pool carrying any of ``query_synthon.compatible_marks``.

        The result is passed through ``deduplicate_list``, because one synthon
        may be indexed under several of the queried marks.
        """
        matching_synthons = []
        for mark in query_synthon.compatible_marks:
            # Use .get so a lookup miss does not insert an empty list into the defaultdict.
            matching_synthons += self.mark_to_synthon.get(mark, [])

        return deduplicate_list(matching_synthons)

In [None]:
#| export

class BuildingBlockNode():
    """Base node of a building-block assembly graph.

    Attributes:
        name: node identifier.
        template: optional callable run on each synthon; it must return an
            object with a ``result`` attribute.
        n_func: set of accepted attachment-point counts. Empty (or ``None``)
            means no restriction.
    """
    def __init__(self, name: str, template: Optional[Template]=None):
        self.name = name
        self.template = template
        # Empty set = accept any number of attachment points.
        self.n_func = set()

    def template_screen(self, synthon: Synthon, store_data: bool=True) -> TemplateResult:
        """Run ``self.template`` on ``synthon`` and return its result object.

        With no template, a passing ``TemplateResult(True, [], [])`` is returned.
        When ``store_data`` is true, the result object and its boolean
        ``result`` are attached to the synthon via ``synthon.add_data``.
        """
        if self.template is None:
            output = TemplateResult(True, [], [])
        else:
            output = self.template(synthon)

        if store_data:
            synthon.add_data({'template_data' : output, 'template_result' : output.result})

        return output

    def synthon_screen(self, synthon: Synthon) -> bool:
        """Return True if ``synthon`` has an accepted attachment count and passes the template.

        The attachment count is the number of ':' characters in
        ``synthon.reconstruction_smile``. Presumably this is one atom-map
        label per attachment point; confirm against ``Synthon``. The template
        is only run (and its data stored) when the count check passes.
        """
        n_func = synthon.reconstruction_smile.count(':')
        # Check "no restriction" first so that n_func=None (allowed by the
        # subclass signatures) does not raise TypeError on `in`.
        if self.n_func and n_func not in self.n_func:
            return False
        return self.template_screen(synthon).result

In [None]:
#| export

class SynthonNode(BuildingBlockNode):
    """Leaf node that screens synthons drawn from a named input pool.

    Args:
        name: key used to look up this node's pool in ``assembly_dict``.
        n_func: accepted attachment-point counts; ``None`` or empty means
            no restriction.
        template: optional template each synthon must pass.
    """
    def __init__(self, 
                 name: str, 
                 n_func: Optional[set[int]]=None,
                 template: Optional[Template]=None):
        super().__init__(name, template)
        # Normalize None to an empty set ("no restriction") so the
        # membership test in synthon_screen never sees None.
        self.n_func = n_func if n_func is not None else set()

    def assemble(self, assembly_dict: dict) -> AssemblyPool:
        """Return a new pool of the synthons in ``assembly_dict[self.name]`` that pass ``synthon_screen``.

        Raises:
            KeyError: if ``assembly_dict`` has no pool for this node's name.
        """
        pool = assembly_dict.get(self.name)
        if pool is None:
            raise KeyError(f"No assembly pool provided for node '{self.name}'")
        return AssemblyPool([i for i in pool.synthons if self.synthon_screen(i)])

In [None]:
#| export

class ReactionNode(BuildingBlockNode):
    """Node that holds the reaction universe joining two child pools.

    It carries no template. Its role is to restrict pools to synthons for
    which the reaction universe reports matching reactions.
    """
    def __init__(self, 
                 name: str, 
                 reaction_universe: ReactionUniverse):
        super().__init__(name, template=None)
        self.reaction_universe = reaction_universe

    def filter_pool(self, pool: AssemblyPool) -> AssemblyPool:
        """Return ``pool.reaction_filter`` applied with this node's reaction universe."""
        universe = self.reaction_universe
        filtered = pool.reaction_filter(universe)
        return filtered

In [None]:
#| export

def create_assemblies(incoming_pool, next_pool, rxn_universe):
    """Enumerate every reactable (incoming, partner) synthon pair.

    Both pools are first narrowed with ``reaction_filter``. Then each
    remaining incoming synthon is paired with the partners returned by
    ``next_pool.get_matching``. Only pairs for which
    ``rxn_universe.get_matching_reactions(synthon, partner)`` is non-empty
    are kept.

    Returns:
        list of ``(incoming_synthon, partner_synthon, valid_rxns)`` tuples,
        ordered by incoming synthon, then partner.
    """
    left = incoming_pool.reaction_filter(rxn_universe)
    right = next_pool.reaction_filter(rxn_universe)

    pairs = []
    for synthon in left.synthons:
        for partner in right.get_matching(synthon):
            rxns = rxn_universe.get_matching_reactions(synthon, partner)
            if not rxns:
                continue
            pairs.append((synthon, partner, rxns))

    return pairs

def fuse_assembly(inputs):
    """React one ``(s1, s2, valid_rxns)`` assembly and return its distinct products.

    Each reaction's ``react_to_dict(s1, s2)`` output is collected, then the
    product dicts are grouped by ``'synthon_smile'`` in first-seen order.
    Each group becomes one ``Synthon``, built from the first dict's SMILES
    fields. Its parents are ``[s1, s2]`` and it carries the
    ``'reaction_tags'`` of every dict in the group.
    """
    s1, s2, valid_rxns = inputs

    product_dicts = [prod for rxn in valid_rxns for prod in rxn.react_to_dict(s1, s2)]

    by_smile = defaultdict(list)
    for prod in product_dicts:
        by_smile[prod['synthon_smile']].append(prod)

    return [
        Synthon(group[0]['synthon_smile'], group[0]['reconstruction_smile'],
                [s1, s2], [entry['reaction_tags'] for entry in group])
        for group in by_smile.values()
    ]

class ProductNode(BuildingBlockNode):
    """Internal node that reacts the outputs of two child nodes.

    Args:
        name: node identifier.
        incoming_node: first child (a ``SynthonNode`` or another ``ProductNode``).
        reaction_node: supplies the reaction universe used for filtering and pairing.
        next_node: second child (a ``SynthonNode`` or another ``ProductNode``).
        n_func: accepted attachment-point counts for the products; ``None``
            or empty means no restriction.
        template: optional template each product must pass.
    """
    def __init__(self, 
                 name: str, 
                 incoming_node: Union[SynthonNode, ProductNode], 
                 reaction_node: ReactionNode,
                 next_node: Union[SynthonNode, ProductNode],
                 n_func: Optional[set[int]]=None,
                 template: Optional[Template]=None
                ):
        super().__init__(name, template)

        self.incoming_node = incoming_node
        self.reaction_node = reaction_node
        self.next_node = next_node
        # Store the accepted attachment counts so synthon_screen actually
        # filters products by them; None -> empty set ("no restriction").
        self.n_func = n_func if n_func is not None else set()

    def assemble(self, assembly_dict: dict) -> AssemblyPool:
        """Assemble both children, react every compatible pair, and screen the products.

        Children are assembled recursively from ``assembly_dict``. Pairs are
        fused in parallel with ``multiprocessing.Pool``, and each product
        must pass ``synthon_screen`` (attachment count + template).
        """
        incoming_pool = self.reaction_node.filter_pool(self.incoming_node.assemble(assembly_dict))
        next_pool = self.reaction_node.filter_pool(self.next_node.assemble(assembly_dict))

        # Note: create_assemblies re-applies reaction_filter to both pools.
        assemblies = create_assemblies(incoming_pool, next_pool, self.reaction_node.reaction_universe)

        if assemblies:
            with Pool(processes=os.cpu_count()) as p:
                products = flatten_list(p.map(fuse_assembly, assemblies))
        else:
            # Nothing to react: avoid spawning worker processes.
            products = []

        return AssemblyPool([i for i in products if self.synthon_screen(i)])

In [None]:
#| hide
# Regenerate the library module(s) from the `#| export` cells using nbdev's exporter.
import nbdev; nbdev.nbdev_export()

In [None]:
# rxn_universe = ReactionUniverse('all_reactions', REACTION_GROUPS)

# dataset = datasets.load_from_disk('../dev/synthon_testset.hf/')
# dataset



# with Pool(processes=64) as p:
#     bbs = p.map(BuildingBlock, [i['smiles'] for i in dataset.shard(50,0)])
# len(bbs)

# synthons = flatten_list([i.synthons for i in bbs])
# len(synthons)

# monos = []
# bis = []

# for synthon in synthons:
#     if synthon.reconstruction_smile.count(':')==1:
#         monos.append(synthon)
#     elif synthon.reconstruction_smile.count(':')==2:
#         bis.append(synthon)
        
# len(monos), len(bis)



# from chem_templates.filter import RangeFunctionFilter, ValidityFilter, SingleCompoundFilter

# from rdkit.Chem import rdMolDescriptors, Descriptors
# from chem_templates.chem import mol_func_wrapper

# rings = mol_func_wrapper(rdMolDescriptors.CalcNumRings)
# hbd = mol_func_wrapper(rdMolDescriptors.CalcNumHBD)
# hba = mol_func_wrapper(rdMolDescriptors.CalcNumHBA)
# molwt = mol_func_wrapper(rdMolDescriptors.CalcExactMolWt)
# logp = mol_func_wrapper(Descriptors.MolLogP)
# rotb = mol_func_wrapper(rdMolDescriptors.CalcNumRotatableBonds)

# # bb1
# bb1_filters = [
#     RangeFunctionFilter(rings, 'rings', 1, 1),
# ]
# bb1_template = Template(bb1_filters)

# bb1 = SynthonNode('bb1', set([1]), bb1_template)


# # bb2
# bb2_filters = [
#     RangeFunctionFilter(rings, 'rings', 0, 0),
#     RangeFunctionFilter(rotb, 'rotatable_bonds', None, 3),
# ]
# bb2_template = Template(bb2_filters)

# bb2 = SynthonNode('bb2', set([2]), bb2_template)

# # bb3
# bb3_filters = [
#     RangeFunctionFilter(rings, 'rings', 1, 1),
# ]
# bb3_template = Template(bb3_filters)

# bb3 = SynthonNode('bb3', set([1]), bb3_template)

# # rxn1
# rxn1 = ReactionNode('r1', rxn_universe)

# # p1
# prod1 = ProductNode('prod1', bb1, rxn1, bb2, set([1]))

# # rxn2
# rxn2 = ReactionNode('r2', rxn_universe)

# # p2

# # full
# full_filters = [
#     ValidityFilter(),
#     SingleCompoundFilter(),
#     RangeFunctionFilter(hbd, 'hydrogen_bond_donors', None, 3),
#     RangeFunctionFilter(hba, 'hydrogen_bond_acceptors', None, 3),
#     RangeFunctionFilter(molwt, 'molecular_weight', None, 500),
#     RangeFunctionFilter(logp, 'CLogP', None, 3),
#     RangeFunctionFilter(rotb, 'rotatable_bonds', None, 5)
# ]

# full_template = Template(full_filters)

# prod2 = ProductNode('prod2', prod1, rxn2, bb3, set([0]), full_template)

# p1 = AssemblyPool([i for i in monos[:500] if bb1.synthon_screen(i)])
# p2 = AssemblyPool([i for i in bis if bb2.synthon_screen(i)])
# p3 = AssemblyPool([i for i in monos[500:700] if bb3.synthon_screen(i)])

# len(p1), len(p2), len(p3)

# assembly_dict = {
#     'bb1' : p1,
#     'bb2' : p2,
#     'bb3' : p3
# }

# test = prod1.assemble(assembly_dict)

# len(test)

# test2 = prod2.assemble(assembly_dict)

# len(test2)

# test2.synthons[3].mol

Do we take building blocks (BBs) or synthons as inputs?
Need to be able to trace back to the initial BBs/synthons and reconstruct the assembly inputs.
An assembly class to track construction?

synthon nodes - take in synthon. have template

constant nodes - constant synthon

reaction node - holds allowed reactions, incoming node (synthon or product), next synthon node

perhaps
assign pool of synthons to each node based on node template
synthon nodes can have a parent attr to grab reaction node for rxn screening?

synthon node:
    to setup:
        assign pool of possible BBs based on attachments / template
    during assembly:
        pushes pool upward
        
reaction node
    assembly:
        filter incoming pool by reactant to create (incoming, rxn) pairs
        
        
        
synthon pool
    list of current synthons
    dict mapping each mark to a list of synthons
        i.e. grab a synthon from pool A, get its `compatible_marks`, and look up those `compatible_marks` in pool B's mapping dict
        
        
assembly modes
    product (full combi)
    random
        randomly grab synthon from pool A
        randomly grab synthons from pool B until a pair is found that is compatible under the current rxn universe
        react
    both methods:
        build `chunksize` preassemblies in parallel
        go until exhausted or limit reached
        
        
        
        
generic steps for fragment and building-block nodes
    input is assembly pool
    step 1: send pool to child nodes (skip for leaf)
    step 2: build assembly inputs from child nodes (skip for leaf)
    step 3: fuse result of child nodes (skip for leaf)
        steps 2 + 3 should have a switch between "full product" and random sampling of up to n assemblies
    step 4: filter fusion results on template and mapping/num active groups
    step 5: send upward
    parallel processing as much as possible