# Runner

> Runner class that drives the search loop: data source → filter → score → prune → update

In [None]:
#| default_exp runner

[autoreload of emb_opt.log failed: Traceback (most recent call last):
  File "/home/dmai/miniconda3/envs/emb_opt/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 261, in check
    superreload(m, reload, self.old_objects)
  File "/home/dmai/miniconda3/envs/emb_opt/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 459, in superreload
    module = reload(module)
  File "/home/dmai/miniconda3/envs/emb_opt/lib/python3.9/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 613, in _exec
  File "<frozen importlib._bootstrap_external>", line 850, in exec_module
  File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
  File "/home/dmai/emb_opt/emb_opt/log.py", line 22, in <module>
    from .log import Log
ImportError: cannot import name 'Log' from 'emb_opt.log' (/home/dmai/emb_opt/emb_opt/log.py)
]


In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.schemas import (
                                Query, 
                                Batch,
                                DataSourceFunction,
                                FilterFunction,
                                ScoreFunction,
                                PruneFunction,
                                UpdateFunction
                            )
from emb_opt.data_source import DataSourceModule
from emb_opt.filter import FilterModule
from emb_opt.score import ScoreModule
from emb_opt.prune import PruneModule
from emb_opt.update import UpdateModule
from emb_opt.log import Log

In [None]:
#| export

class Runner:
    """Drive the iterative search loop over a chain of plugin modules.

    Each call to `step` pushes a `Batch` through the data source, filter,
    score and prune modules, records the result in a `Log`, and, if any
    queries are still valid, produces the next batch via the update module.
    `search` repeats `step` for a given number of iterations.
    """

    def __init__(self,
                 data_plugin: DataSourceFunction,
                 filter_plugin: Optional[FilterFunction],
                 score_plugin: ScoreFunction,
                 prune_plugin: Optional[PruneFunction],
                 update_plugin: UpdateFunction
                ):
        """Wrap each plugin function in its corresponding module.

        Args:
            data_plugin: function wrapped by `DataSourceModule`.
            filter_plugin: function wrapped by `FilterModule`; may be None.
            score_plugin: function wrapped by `ScoreModule`.
            prune_plugin: function wrapped by `PruneModule`; may be None.
            update_plugin: function wrapped by `UpdateModule`.
        """
        self.data_module = DataSourceModule(data_plugin)
        # Build the filter/prune modules even when their plugin is None:
        # `step` calls every module unconditionally, so storing None here
        # would fail at call time. NOTE(review): presumably FilterModule and
        # PruneModule treat a None plugin as a pass-through — confirm.
        self.filter_module = FilterModule(filter_plugin)
        self.score_module = ScoreModule(score_plugin)
        self.prune_module = PruneModule(prune_plugin)
        self.update_module = UpdateModule(update_plugin)

    def prepare_batch(self, batch: Batch, iteration: int):
        """Record the current iteration number on every query in `batch`."""
        for query in batch.queries:
            query.update_internal(iteration=iteration)

    def step(self, batch: Batch, log: Log, iteration: int, verbose: bool = True):
        """Run a single iteration of the search loop.

        Args:
            batch: queries to process in this iteration.
            log: log that receives the processed batch via `add_batch`.
            iteration: iteration index stamped onto each query.
            verbose: if True, print the mean score of each query.

        Returns:
            The next batch produced by the update module, or None when no
            valid queries remain after pruning (which tells `search` to stop).
        """
        self.prepare_batch(batch, iteration)

        batch = self.data_module(batch)
        batch = self.filter_module(batch)
        batch = self.score_module(batch)
        batch = self.prune_module(batch)

        log.add_batch(batch)
        self.report_scores(batch, iteration, verbose)

        # We only need to know whether at least one valid query exists, so
        # stop at the first one instead of materialising the full sequence.
        has_valid_queries = any(True for _ in batch.valid_queries())
        return self.update_module(batch) if has_valid_queries else None

    def search(self, batch: Batch, iterations: int, log: Optional[Log] = None, verbose: bool = True):
        """Run up to `iterations` steps, stopping early if no valid queries remain.

        Iteration numbers start at `len(log.batch_log)`, so passing in an
        existing log continues the numbering from a previous search.

        Args:
            batch: initial batch of queries.
            iterations: maximum number of steps to run.
            log: log to append to; a new `Log()` is created when None.
            verbose: forwarded to `step` to control score printing.

        Returns:
            A `(batch, log)` tuple. `batch` is the output of the last step,
            which is None if the search stopped because no valid queries remained.
        """
        if log is None:
            log = Log()

        i_start = len(log.batch_log)

        for i in range(i_start, i_start + iterations):
            batch = self.step(batch, log, i, verbose)
            if batch is None:
                break

        return batch, log

    def report_scores(self, batch, iteration, report):
        """Print `iteration` followed by each query's mean valid-result score.

        Does nothing unless `report` is truthy. Scores come from the queries
        in the second element of `batch.flatten_queries()`. A query with no
        valid results is printed as `nan`, without triggering numpy's
        empty-slice RuntimeWarning.
        """
        if not report:
            return

        mean_scores = []
        for query in batch.flatten_queries()[1]:
            scores = [result.score for result in query.valid_results()]
            # np.array([]).mean() returns nan but also warns; short-circuit it.
            mean_scores.append(np.array(scores).mean() if scores else float('nan'))

        print(iteration, ' '.join(f'{score:.2f}' for score in mean_scores))