# Assembly

> Assembly-related functions and classes: building-block pools, assembly inputs, and the base `Node` class

In [None]:
#| default_exp assembly

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export

from __future__ import annotations
from chem_templates.imports import *
from chem_templates.utils import *
from chem_templates.chem import Molecule
from chem_templates.template import Template, TemplateResult

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class AssemblyPool():
    """A list-like container of `Molecule` objects used as building blocks for assembly.

    Supports `len()`, indexing, and `filter`, which returns a new pool instead
    of modifying this one.
    """
    def __init__(self, items: list[Molecule]):
        # Stored as given (no copy is made).
        self.items = items

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> Molecule:
        return self.items[idx]

    def filter(self, filter_func: Callable, worker_pool: Optional[Pool]=None) -> AssemblyPool:
        """Return a new `AssemblyPool` holding the items for which `filter_func` is truthy.

        Args:
            filter_func: called once per item; its return value is used as a keep/drop flag.
            worker_pool: optional object exposing `.map(func, iterable)` (e.g. a
                `multiprocessing.Pool`). When given, `filter_func` is evaluated through it,
                so for a process pool `filter_func` must be picklable.

        Returns:
            A new pool with the kept items in their original order. The items are the
            same objects as in this pool, not copies.
        """
        if worker_pool is not None:
            keep_flags = worker_pool.map(filter_func, self.items)
        else:
            keep_flags = [filter_func(item) for item in self.items]

        # Pair each item with its flag rather than indexing two lists in lockstep.
        return AssemblyPool([item for item, keep in zip(self.items, keep_flags) if keep])

    def __repr__(self) -> str:
        return f'AssemblyPool: {len(self.items)} items'

In [None]:
def filter_func(molecule):
    # Keep only molecules whose SMILES string is longer than one character.
    return len(molecule.smile)>1

pool = AssemblyPool([Molecule('C'), Molecule('CCCCC')])
assert len(pool)==2

pool2 = pool.filter(filter_func)
assert len(pool2)==1
# filter keeps the original objects (not copies): the survivor is the 'CCCCC' molecule.
assert pool2[0] is pool[1]
# Filtering an empty pool yields an empty pool.
assert len(AssemblyPool([]).filter(filter_func))==0

In [None]:
#| export

class AssemblyInputs():
    """Settings bundle shared by an assembly run.

    Groups the named building-block pools with chunking/limit settings, an
    optional worker pool, a logging flag, and an initially empty log dict.
    """
    def __init__(self,
                 pool_dict: dict[str, AssemblyPool],
                 assembly_chunksize: int,
                 max_assemblies_per_node: int,
                 worker_pool: Optional[Pool]=None,
                 log: bool=True):

        # Logging: the on/off flag plus a fresh, per-instance log store.
        self.log = log
        self.assembly_log = {}

        # Optional parallel executor; presumably None means serial execution -- confirm with callers.
        self.worker_pool = worker_pool

        # Size/limit settings; their exact semantics are defined by the code that reads them.
        self.max_assemblies_per_node = max_assemblies_per_node
        self.assembly_chunksize = assembly_chunksize

        # Mapping of string keys to AssemblyPool objects.
        self.pool_dict = pool_dict

In [None]:
#| export

class Node():
    """Base class for a named assembly step.

    A node may carry a `Template` used by `template_screen`. `fuse` maps `_fuse`
    over a batch of inputs and collects the results; subclasses must implement `_fuse`.
    """
    def __init__(self,
                 name: str,
                 template: Optional[Template]=None):
        self.name = name
        # When None, `template_screen` falls back to `TemplateResult(True, [], [])`.
        self.template = template

    def template_screen(self, molecule: Molecule) -> bool:
        """Screen `molecule` with this node's template and record the outcome on it.

        Attaches `'template_data'` (the full result object) and `'template_result'`
        (its `.result`) to the molecule via `molecule.add_data`, then returns `.result`.
        With no template, `TemplateResult(True, [], [])` is used (presumably a pass).
        """
        if self.template is not None:
            output = self.template(molecule)
        else:
            output = TemplateResult(True, [], [])

        molecule.add_data({'template_data' : output, 'template_result' : output.result})

        return output.result

    def _fuse(self, fusion_input):
        """Produce one output item from one fusion input. Subclasses must override this."""
        raise NotImplementedError(f'{type(self).__name__} must implement _fuse')

    def fuse(self, fusion_inputs, worker_pool: Optional[Pool]=None) -> AssemblyPool:
        """Apply `_fuse` to every input and wrap the outputs, in input order, in an `AssemblyPool`.

        Args:
            fusion_inputs: iterable of inputs, each passed to `_fuse`.
            worker_pool: optional object exposing `.map(func, iterable)` (e.g. a
                `multiprocessing.Pool`); when given, `_fuse` is evaluated through it.
        """
        if worker_pool is not None:
            outputs = worker_pool.map(self._fuse, fusion_inputs)
        else:
            outputs = [self._fuse(i) for i in fusion_inputs]
        return AssemblyPool(outputs)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()