# Runner

> runner functions and classes

In [None]:
#| default_exp runner

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.schemas import (
                                Query, 
                                Batch,
                                DataSourceFunction,
                                FilterFunction,
                                ScoreFunction,
                                PruneFunction,
                                UpdateFunction
                            )
from emb_opt.data_source import DataSourceModule
from emb_opt.filter import FilterModule
from emb_opt.score import ScoreModule
from emb_opt.prune import PruneModule
from emb_opt.update import UpdateModule
from emb_opt.log import Log

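The `Runner` class holds one plugin function for each step of the search (data source, filter, score, prune, update) and executes the embedding search by running those steps in order on every iteration.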

In [None]:
#| export

class Runner():
    '''
    Wraps one plugin function per search step in its module class and
    runs the steps in order (data source -> filter -> score -> prune ->
    update) for a number of iterations.
    '''
    def __init__(self,
                 data_plugin: DataSourceFunction,          # data source function
                 filter_plugin: Optional[FilterFunction],  # optional filter function
                 score_plugin: ScoreFunction,              # score function
                 prune_plugin: Optional[PruneFunction],    # optional prune function
                 update_plugin: UpdateFunction             # update function
                ):
        self.data_module = DataSourceModule(data_plugin)
        self.filter_module = FilterModule(filter_plugin)
        self.score_module = ScoreModule(score_plugin)
        self.prune_module = PruneModule(prune_plugin)
        self.update_module = UpdateModule(update_plugin)

    def prepare_batch(self,
                      batch: Batch,    # batch whose queries are tagged
                      iteration: int   # iteration number stored on each query
                     ):
        'Tags every query in `batch.queries` with the current iteration (in place)'
        for query in batch.queries:
            query.update_internal(iteration=iteration)

    def step(self,
             batch: Batch,       # input batch for this iteration
             log: Log,           # log that receives the processed batch
             iteration: int,     # current iteration number
             verbose: bool=True  # if True, print mean scores for this iteration
            ) -> Optional[Batch]:
        '''
        Runs a single iteration: data source, filter, score and prune,
        then logs and optionally reports the batch, then updates it.

        Returns the batch produced by the update module, or `None` when
        `batch.valid_queries()` yields nothing (nothing left to update).
        '''
        self.prepare_batch(batch, iteration)

        batch = self.data_module(batch)
        batch = self.filter_module(batch)
        batch = self.score_module(batch)
        batch = self.prune_module(batch)

        # log and report before updating so the final batch is recorded
        # even when no valid queries remain
        log.add_batch(batch)
        self.report_scores(batch, iteration, verbose)

        # only existence matters, so stop at the first valid query
        # instead of materializing all of them
        if any(True for _ in batch.valid_queries()):
            return self.update_module(batch)
        return None

    def search(self,
               batch: Batch,             # starting batch
               iterations: int,          # maximum number of iterations to run
               log: Optional[Log]=None,  # existing log to continue from; a new `Log` is created if None
               verbose: bool=True        # if True, print mean scores each iteration
              ) -> (Batch, Log):
        '''
        Runs up to `iterations` steps starting from `batch`.

        Iteration numbering starts at `len(log.batch_log)`, so passing a
        previously used log continues its numbering. The search stops
        early when a step returns `None`; in that case the returned
        batch is `None` and the last processed batch is only in `log`.
        '''
        if log is None:
            log = Log()

        i_start = len(log.batch_log)

        for i in range(i_start, i_start+iterations):
            batch = self.step(batch, log, i, verbose)
            if batch is None:
                break

        return batch, log

    def report_scores(self,
                      batch: Batch,    # batch to summarize
                      iteration: int,  # iteration number printed first
                      report: bool     # if False, nothing is printed
                     ):
        '''
        Prints `iteration` followed by the mean score of the valid
        results of each query in `batch.flatten_queries()[1]`. A query
        with no valid results is printed as `nan`.
        '''
        if not report:
            return

        mean_scores = []
        for query in batch.flatten_queries()[1]:
            scores = [result.score for result in query.valid_results()]
            # the mean of an empty array is nan but emits numpy
            # RuntimeWarnings; produce the same nan without the noise
            mean_scores.append(np.array(scores).mean() if scores else float('nan'))

        print(iteration, ' '.join([f'{score:.2f}' for score in mean_scores]))