# core

> core functions and classes

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.utils import batch_list, unbatch_list

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class Executor():
    """Serial executor that maps a function over a collection of inputs.

    When `batched` is true the inputs are first grouped with `batch_list`
    (using `batch_size`) and the per-batch results are recombined with
    `unbatch_list`; otherwise the function is applied to each input as-is.
    Subclasses customise how the mapping is performed by overriding `execute`.
    """
    def __init__(self,
                 batched: bool,
                 batch_size: int=1
                ):
        """Store the batching configuration.

        Args:
            batched: whether inputs are grouped before calling the function.
            batch_size: size argument forwarded to `batch_list` when batching.
        """
        self.batched, self.batch_size = batched, batch_size

    def batch_inputs(self, executor_inputs):
        """Return `executor_inputs` grouped by `batch_list`, or unchanged when not batching."""
        if not self.batched:
            return executor_inputs
        return batch_list(executor_inputs, self.batch_size)

    def unbatch_inputs(self, results):
        """Undo batching on `results` with `unbatch_list`; a no-op when not batching.

        Despite the name, `__call__` applies this to the results of `execute`,
        not to the inputs.
        """
        if not self.batched:
            return results
        return unbatch_list(results)

    def execute(self, executor_function, executor_inputs):
        """Call `executor_function` on each element in order and return the list of results."""
        collected = []
        for item in executor_inputs:
            collected.append(executor_function(item))
        return collected

    def __call__(self,
                 executor_function: Callable,
                 executor_inputs: BaseModel,
                ) -> BaseModel:
        """Run the full pipeline: batch the inputs, execute, then unbatch the results."""
        prepared = self.batch_inputs(executor_inputs)
        raw_results = self.execute(executor_function, prepared)
        return self.unbatch_inputs(raw_results)

In [None]:
#| export

class ProcessExecutor(Executor):
    """Executor that can spread work across a `ProcessPoolExecutor`.

    Batching behaves as in `Executor`. With `concurrency` of `None` or 1 the
    function is applied serially in the current process; otherwise the
    (possibly batched) inputs are mapped over a process pool. The
    batch -> execute -> unbatch pipeline is inherited from `Executor.__call__`.
    """
    def __init__(self,
                 batched: bool,
                 batch_size: int=1,
                 concurrency: Optional[int]=1,
                ):
        """Store batching and concurrency settings.

        Args:
            batched: whether inputs are grouped before calling the function.
            batch_size: size argument used when batching.
            concurrency: upper bound on pool workers; `None` or 1 runs serially.
        """
        super().__init__(batched=batched, batch_size=batch_size)
        self.concurrency = concurrency

    def execute(self, executor_function, executor_inputs):
        """Apply `executor_function` to every input, in a process pool when concurrent.

        Results are returned as a list in input order. On the pooled path
        `executor_inputs` must support `len()`.
        """
        if (self.concurrency is None) or (self.concurrency == 1):
            return [executor_function(i) for i in executor_inputs]

        n_items = len(executor_inputs)
        if n_items == 0:
            # A pool sized min(concurrency, 0) == 0 is rejected with
            # ValueError, so there is nothing to do but return no results.
            return []

        # Never start more workers than there are items to process.
        n_workers = min(self.concurrency, n_items)
        with ProcessPoolExecutor(n_workers) as pool:
            return list(pool.map(executor_function, executor_inputs))

In [None]:
#| export

class ThreadExecutor(Executor):
    """Executor that can spread work across a `ThreadPoolExecutor`.

    Batching behaves as in `Executor`. With `concurrency` of `None` or 1 the
    function is applied serially; otherwise the (possibly batched) inputs are
    mapped over a thread pool.
    """
    def __init__(self,
                 batched: bool,
                 concurrency: Optional[int]=1,
                 batch_size: int=1,
                ):
        """Store batching and concurrency settings.

        Args:
            batched: whether inputs are grouped before calling the function.
            concurrency: upper bound on pool threads; `None` or 1 runs serially.
            batch_size: size argument used when batching.
        """
        super().__init__(batched=batched, batch_size=batch_size)
        self.concurrency = concurrency

    def execute(self, executor_function, executor_inputs):
        """Apply `executor_function` to every input, in a thread pool when concurrent.

        Results are returned as a list in input order. On the pooled path
        `executor_inputs` must support `len()`.
        """
        if (self.concurrency is None) or (self.concurrency == 1):
            return [executor_function(i) for i in executor_inputs]

        n_items = len(executor_inputs)
        if n_items == 0:
            # A pool sized min(concurrency, 0) == 0 is rejected with
            # ValueError, so there is nothing to do but return no results.
            return []

        # Never start more threads than there are items to process.
        n_workers = min(self.concurrency, n_items)
        with ThreadPoolExecutor(n_workers) as pool:
            return list(pool.map(executor_function, executor_inputs))

In [None]:
#| export

class DatasetExecutor(Executor):
    """Executor that runs the function through `datasets.Dataset.map`.

    Inputs are dumped to dicts with `model_dump()`, loaded into a `Dataset`,
    mapped with the supplied function, and each resulting row is passed as
    keyword arguments to `output_schema`. Batching and parallelism are handed
    to `Dataset.map` (`batched`, `batch_size`, `num_proc`), so the
    `batch_inputs` / `unbatch_inputs` hooks of `Executor` are not used.
    """
    def __init__(self,
                 output_schema: BaseModel,
                 batched: bool,
                 concurrency: Optional[int]=1,
                 batch_size: int=1,
                 map_kwargs: Optional[dict]=None
                ):
        """Store the output schema and the settings forwarded to `Dataset.map`.

        Args:
            output_schema: callable invoked as `output_schema(**row)` for each mapped row.
            batched: forwarded to `Dataset.map` as `batched`.
            concurrency: forwarded to `Dataset.map` as `num_proc`.
            batch_size: forwarded to `Dataset.map` as `batch_size`.
            map_kwargs: extra keyword arguments for `Dataset.map`; a falsy
                value is replaced with an empty dict.
        """
        self.output_schema = output_schema
        self.batched = batched
        self.concurrency = concurrency
        self.batch_size = batch_size
        self.map_kwargs = map_kwargs or {}

    def execute(self, executor_function, executor_inputs):
        """Map `executor_function` over the inputs via a `Dataset` and rebuild schema objects."""
        rows = [item.model_dump() for item in executor_inputs]
        mapped = datasets.Dataset.from_list(rows).map(
            lambda row: executor_function(row),
            batched=self.batched,
            batch_size=self.batch_size,
            num_proc=self.concurrency,
            **self.map_kwargs,
        )
        return [self.output_schema(**record) for record in mapped.to_list()]

    def __call__(self,
                 executor_function: Callable,
                 executor_inputs: BaseModel
                ) -> BaseModel:
        """Run `execute` directly, skipping the base class batch/unbatch steps."""
        return self.execute(executor_function, executor_inputs)

In [None]:
# Input model for the executor checks: a single float field.
class TestInput(BaseModel):
    value: float

# Output model for the executor checks: a single boolean field.
class TestOutput(BaseModel):
    result: bool

def test_function(input: TestInput) -> TestOutput:
    "Return a `TestOutput` whose `result` is True when `input.value` exceeds 0.5."
    above_half = input.value > 0.5
    return TestOutput(result=above_half)

def test_function_batched(inputs: list[TestInput]) -> list[TestOutput]:
    "Batched variant: one `TestOutput` per element of `inputs`, in order."
    outputs = []
    for item in inputs:
        outputs.append(TestOutput(result=item.value > 0.5))
    return outputs

def test_function_hf(input: dict) -> dict:
    "Dict-in / dict-out variant operating on a single row with a 'value' key."
    score = input['value']
    return {'result' : score > 0.5}

def test_function_hf_batched(input: dict) -> dict:
    "Dict-in / dict-out variant where 'value' holds a sequence of numbers."
    flags = []
    for score in input['value']:
        flags.append(score > 0.5)
    return {'result' : flags}


# Seed NumPy's global RNG before sampling so `values` (and therefore the
# expected outputs below) are the same on every run.
np.random.seed(42)
values = np.random.uniform(size=100).tolist()

# Build the model inputs and the reference outputs every executor must
# reproduce: result is True exactly when the sampled value is > 0.5.
inputs = [TestInput(value=i) for i in values]
expected_outputs = [TestOutput(result=i>0.5) for i in values]

# standard: base Executor, first unbatched, then batched with batch_size=5

executor = Executor(batched=False)
res1 = executor(test_function, inputs)
assert res1 == expected_outputs

executor = Executor(batched=True, batch_size=5)
res2 = executor(test_function_batched, inputs)
assert res2 == expected_outputs

# process: ProcessExecutor on the serial path (concurrency=1 or the default)
# and on the pooled path (concurrency=2), unbatched and batched

executor = ProcessExecutor(batched=False, concurrency=1)
res3 = executor(test_function, inputs)
assert res3 == expected_outputs

executor = ProcessExecutor(batched=False, concurrency=2)
res4 = executor(test_function, inputs)
assert res4 == expected_outputs

executor = ProcessExecutor(batched=True, batch_size=5)
res5 = executor(test_function_batched, inputs)
assert res5 == expected_outputs

executor = ProcessExecutor(batched=True, batch_size=5, concurrency=2)
res6 = executor(test_function_batched, inputs)
assert res6 == expected_outputs

# thread: same four configurations as above, using ThreadExecutor

executor = ThreadExecutor(batched=False, concurrency=1)
res7 = executor(test_function, inputs)
assert res7 == expected_outputs

executor = ThreadExecutor(batched=False, concurrency=2)
res8 = executor(test_function, inputs)
assert res8 == expected_outputs

executor = ThreadExecutor(batched=True, batch_size=5)
res9 = executor(test_function_batched, inputs)
assert res9 == expected_outputs

executor = ThreadExecutor(batched=True, batch_size=5, concurrency=2)
res10 = executor(test_function_batched, inputs)
assert res10 == expected_outputs

# dataset: DatasetExecutor takes the dict-based test functions and rebuilds
# TestOutput objects from the mapped rows; `concurrency` is forwarded to
# Dataset.map as num_proc (None and 2 are both exercised)

executor = DatasetExecutor(TestOutput, batched=False, concurrency=None, batch_size=1)
res11 = executor(test_function_hf, inputs)
assert res11 == expected_outputs

executor = DatasetExecutor(TestOutput, batched=False, concurrency=2, batch_size=1)
res12 = executor(test_function_hf, inputs)
assert res12 == expected_outputs

executor = DatasetExecutor(TestOutput, batched=True, concurrency=2, batch_size=5)
res13 = executor(test_function_hf_batched, inputs)
assert res13 == expected_outputs

executor = DatasetExecutor(TestOutput, batched=True, concurrency=None, batch_size=5)
res14 = executor(test_function_hf_batched, inputs)
assert res14 == expected_outputs

                                                                                

In [None]:
#| export

class Plugin():
    """Base class pairing a function with the executor that runs it.

    Calling a plugin gathers function inputs from `inputs`, runs
    `self.function` over them through `self.executor`, and scatters the
    results back. Subclasses must implement `gather_inputs` and
    `scatter_results`; the base versions raise `NotImplementedError`.
    """
    def __init__(self, function: Callable, executor: Executor):
        """Keep the function to run and the executor used to run it."""
        self.function, self.executor = function, executor

    def gather_inputs(self, inputs: BaseModel) -> BaseModel:
        """Extract what the executor should feed to the function; must be overridden."""
        raise NotImplementedError

    def scatter_results(self, inputs: BaseModel, results: BaseModel) -> BaseModel:
        """Combine the original `inputs` with the executor `results`; must be overridden."""
        raise NotImplementedError

    def __call__(self, inputs: BaseModel) -> BaseModel:
        """Gather inputs, execute the function over them, and scatter the results."""
        gathered = self.gather_inputs(inputs)
        raw_results = self.executor(self.function, gathered)
        return self.scatter_results(inputs, raw_results)