# Module

> Module functions and classes.

In [None]:
#| default_exp module

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from typing import Optional, Type

from emb_opt.imports import *
from emb_opt.schemas import Batch

In [None]:
#| export

class Module():
    '''
    Module - module base class

    Given an input `Batch`, the `Module`:
    1. gathers inputs to the `function` (`gather_inputs`)
    2. executes the `function`
    3. validates the results of the `function` with `output_schema` (`validate_schema`)
    4. scatters results back into the `Batch` (`scatter_results`)

    Subclasses must implement `gather_inputs` and `scatter_results`.
    If `function` is `None`, calling the module is a no-op that returns the
    batch unchanged.
    '''
    def __init__(self,
                 output_schema: Type[BaseModel],                                     # expected output schema (a pydantic model class)
                 function: Optional[Callable[[List[BaseModel]], List[BaseModel]]],  # function to be called; `None` disables the module
                ):
        self.output_schema = output_schema
        self.function = function

    def gather_inputs(self, batch: Batch) -> Tuple[List[Tuple], List[BaseModel]]:
        '''
        Collect the inputs for `function` from `batch`.

        Returns `(idxs, inputs)`. `inputs` is passed to `function`; `idxs` is
        passed unchanged to `scatter_results` alongside the validated results.
        Must be implemented by subclasses.
        '''
        raise NotImplementedError(f'{type(self).__name__} must implement gather_inputs')

    def validate_schema(self, results: List[BaseModel]) -> List[BaseModel]:
        '''
        Validate each item of `results` against `output_schema`.

        Returns a new list of `output_schema` instances (via pydantic's
        `model_validate`), which raises if any item does not conform.
        '''
        return [self.output_schema.model_validate(result) for result in results]

    def scatter_results(self, batch: Batch, idxs: List[Tuple], results: List[BaseModel]) -> None:
        '''
        Write validated `results` back into `batch` in place, using the `idxs`
        produced by `gather_inputs`. Must be implemented by subclasses.
        '''
        raise NotImplementedError(f'{type(self).__name__} must implement scatter_results')

    def __call__(self, batch: Batch) -> Batch:
        '''
        Run gather -> function -> validate -> scatter on `batch`.

        `batch` is updated in place by `scatter_results` and also returned.
        When `function` is `None` the batch is returned untouched.
        '''
        if self.function is None:
            return batch

        idxs, inputs = self.gather_inputs(batch)
        results = self.function(inputs)
        results = self.validate_schema(results)
        self.scatter_results(batch, idxs, results)
        return batch