# API

> This module defines the API of *findmycells*

In [None]:
#| default_exp api

In [None]:
import os
import pickle
import random
from datetime import datetime
from pathlib import Path, PosixPath
from typing import Any, Dict, List, Optional, Tuple

from tqdm.notebook import tqdm

from findmycells.preprocessing.specs import PreprocessingStrategy, PreprocessingObject

In [None]:
class API:
    """User-facing entry point of *findmycells*.

    Bundles the project configurations (``ProjectConfigs``) with the project ``Database``
    and exposes the processing steps (currently preprocessing) as well as saving and
    loading of the project status.
    """
    # NOTE(review): ``ProjectConfigs`` and ``Database`` are not imported in the visible
    # source - confirm they are in scope wherever this class is used.
    # Strategy annotations are strings, so defining this class never fails on them.

    _CONFIGS_FILEPATH_ERROR = '"project_configs_filepath" must be pathlib.Path object referring to a .config file.'
    _DATABASE_FILEPATH_ERROR = '"database_filepath" must be pathlib.Path object referring to a .obj file'

    def __init__(self, project_root_dir: PosixPath,
                 project_configs_filepath: Optional[PosixPath]=None,
                 database_filepath: Optional[PosixPath]=None) -> None:
        """
        Args:
            project_root_dir: Existing root directory of the project.
            project_configs_filepath: Optional ``.config`` file passed on to ``ProjectConfigs``.
            database_filepath: Optional ``.obj`` file passed on to ``Database``.

        Raises:
            AssertionError: If any of the paths does not fulfill the requirements above.
        """
        # isinstance(..., Path) accepts every pathlib path, not only exact PosixPath instances.
        assert isinstance(project_root_dir, Path), '"project_root_dir" must be pathlib.Path object referring to an existing directory.'
        assert project_root_dir.is_dir(), '"project_root_dir" must be pathlib.Path object referring to an existing directory.'
        self._assert_valid_status_filepath(project_configs_filepath, '.config', self._CONFIGS_FILEPATH_ERROR)
        self._assert_valid_status_filepath(database_filepath, '.obj', self._DATABASE_FILEPATH_ERROR)
        # Stored so that load_status() can search for the latest status files.
        self.project_root_dir = project_root_dir
        self.project_configs = ProjectConfigs(project_root_dir = project_root_dir, project_configs_filepath = project_configs_filepath)
        # Bug fix: use the attribute created above - a bare "project_configs" is undefined here.
        self.database = Database(project_configs = self.project_configs, database_filepath = database_filepath)


    @staticmethod
    def _assert_valid_status_filepath(filepath: Optional[PosixPath], expected_suffix: str, error_message: str) -> None:
        """Assert that ``filepath`` is either None or a pathlib path ending with ``expected_suffix``."""
        if filepath is not None:
            assert isinstance(filepath, Path), error_message
            assert filepath.suffix == expected_suffix, error_message


    def update_database_with_current_source_files(self, skip_checking: bool=False) -> None:
        """Let the database (re-)compute its file infos from the current source files."""
        self.database.compute_file_infos(skip_checking = skip_checking)


    def preprocess(self,
                   strategies: List['PreprocessingStrategy'],
                   strategy_configs: Optional[List[Dict]]=None,
                   processing_configs: Optional[Dict]=None,
                   file_ids: Optional[List[str]]=None
                  ) -> None:
        """Run the given preprocessing strategies file by file.

        Args:
            strategies: Preprocessing strategy classes, run in the given order.
            strategy_configs: Optional list of config dicts, one per strategy. Missing
                values are filled with each strategy's defaults.
            processing_configs: Optional configs of the preprocessing step. ``None`` uses
                the configs stored in the project configs. The keys ``'show_progress'``,
                ``'autosave'`` and ``'overwrite'`` are read.
            file_ids: File IDs to process. This is passed on to
                ``Database.get_file_ids_to_process``.
        """
        processing_step_id = 'preprocessing'
        strategy_configs, processing_configs, file_ids = self._assert_and_update_input(processing_step_id = processing_step_id,
                                                                                       strategies = strategies,
                                                                                       strategy_configs = strategy_configs,
                                                                                       processing_configs = processing_configs,
                                                                                       file_ids = file_ids)
        for file_id in tqdm(file_ids, display = processing_configs['show_progress']):
            preprocessing_object = PreprocessingObject()
            preprocessing_object.prepare_for_processing(file_ids = [file_id], database = self.database)
            preprocessing_object.load_image_and_rois()
            preprocessing_object.run_all_strategies(strategies = strategies, strategy_configs = strategy_configs)
            preprocessing_object.save_preprocessed_images_on_disk()
            preprocessing_object.save_preprocessed_rois_in_database()
            preprocessing_object.update_database()
            del preprocessing_object
            if processing_configs['autosave']:
                self.save_status()


    def save_status(self) -> None:
        """Write both the project configs and the database to disk."""
        self.project_configs.save_to_disk()
        self.database.save_to_disk()


    def load_status(self,
                    project_configs_filepath: Optional[PosixPath]=None,
                    database_filepath: Optional[PosixPath]=None
                   ) -> None:
        """Load the project configs and the database from disk.

        Args:
            project_configs_filepath: ``.config`` file to load. If None, the most recent
                dated ``.config`` file in the project root directory is used.
            database_filepath: ``.obj`` file to load. If None, the most recent dated
                ``.obj`` file in ``<project root>/results`` is used.

        Raises:
            AssertionError: If a given path is not a pathlib path with the right suffix.
            FileNotFoundError: If no status file could be found automatically.
        """
        self._assert_valid_status_filepath(project_configs_filepath, '.config', self._CONFIGS_FILEPATH_ERROR)
        self._assert_valid_status_filepath(database_filepath, '.obj', self._DATABASE_FILEPATH_ERROR)
        if project_configs_filepath is None:
            project_configs_filepath = self._look_for_latest_status_file_in_dir(suffix = '.config', dir_path = self.project_root_dir)
        if database_filepath is None:
            database_filepath = self._look_for_latest_status_file_in_dir(suffix = '.obj', dir_path = self.project_root_dir.joinpath('results'))
        self.project_configs.load_from_disk(project_configs_filepath = project_configs_filepath)
        self.database.load_from_disk(database_filepath = database_filepath, project_configs = self.project_configs)


    def split_file_ids_into_batches(self, file_ids: List[str], batch_size: int) -> List[List[str]]:
        """Randomly distribute ``file_ids`` into batches of ``batch_size`` elements.

        Every batch holds exactly ``batch_size`` file IDs, except the last one, which holds
        the remainder. ``file_ids`` itself is left unmodified.

        Raises:
            ValueError: If ``batch_size`` is not greater than 0.
        """
        if batch_size <= 0:
            raise ValueError('"batch_size" must be greater than 0!')
        # Chunking a random permutation gives the same kind of random partition as repeated
        # sampling without replacement. It runs in O(n) and does not mutate the caller's list.
        shuffled_file_ids = random.sample(file_ids, len(file_ids))
        return [shuffled_file_ids[start:start + batch_size] for start in range(0, len(shuffled_file_ids), batch_size)]


    def _look_for_latest_status_file_in_dir(self, suffix: str, dir_path: PosixPath) -> PosixPath:
        """Return the most recent status file with ``suffix`` in ``dir_path``.

        Status files are expected to start with a ``YYYY_MM_DD`` date prefix. Files with
        the right suffix but no parsable date prefix are ignored. If several files share
        the latest date, the one with the lexicographically largest name is returned.

        Raises:
            FileNotFoundError: If no matching dated file exists in ``dir_path``.
        """
        dated_filepaths = []
        for filepath in dir_path.iterdir():
            if not filepath.is_file() or filepath.suffix != suffix:
                continue
            try:
                file_date = datetime.strptime(filepath.name[:10], '%Y_%m_%d')
            except ValueError:
                continue
            dated_filepaths.append((file_date, filepath.name, filepath))
        if not dated_filepaths:
            raise FileNotFoundError(f'Could not find a "{suffix}" file in {dir_path}. Consider specifying the exact filepath!')
        return max(dated_filepaths)[2]


    def _assert_and_update_input(self,
                                 processing_step_id: str,
                                 strategies: List['PreprocessingStrategy'],
                                 strategy_configs: Optional[List[Dict]],
                                 processing_configs: Optional[Dict],
                                 file_ids: Optional[List[str]]
                                ) -> Tuple[List[Dict], Dict, List[str]]:
        """Validate the input of a processing step and complete it with defaults.

        The completed processing configs are also registered in the project configs.

        Returns:
            A tuple of the completed strategy configs, the completed processing configs
            and the file IDs to process.
        """
        self._assert_processing_step_input(processing_step_id = processing_step_id,
                                           strategies = strategies,
                                           strategy_configs = strategy_configs,
                                           processing_configs = processing_configs,
                                           file_ids = file_ids)
        strategy_configs = self._fill_strategy_configs_with_defaults_where_needed(strategies, strategy_configs)
        if processing_configs is None:
            processing_configs = getattr(self.project_configs, processing_step_id)
        processing_configs = self._fill_processing_configs_with_defaults_where_needed(processing_step_id, processing_configs)
        self.project_configs.add_processing_step_configs(processing_step_id, configs = processing_configs)
        file_ids = self.database.get_file_ids_to_process(input_file_ids = file_ids,
                                                         processing_step_id = processing_step_id,
                                                         overwrite = processing_configs['overwrite'])
        return strategy_configs, processing_configs, file_ids


    def _assert_processing_step_input(self,
                                      processing_step_id: str,
                                      strategies: List['PreprocessingStrategy'],
                                      strategy_configs: Optional[List[Dict]],
                                      processing_configs: Optional[Dict],
                                      file_ids: Optional[List[str]]
                                     ) -> None:
        """Assert that all user input of a processing step is valid.

        Raises:
            AssertionError: On the first invalid input. The default-config objects may
                raise on invalid config dicts as well.
        """
        assert isinstance(strategies, list), '"strategies" has to be a list of ProcessingStrategy classes of the respective processing step!'
        if strategy_configs is not None:
            assert isinstance(strategy_configs, list), '"strategy_configs" has to be None or a list of the same length as "strategies"!'
            assert len(strategy_configs) == len(strategies), '"strategy_configs" has to be None or a list of the same length as "strategies"!'
        else:
            strategy_configs = [None] * len(strategies)
        available_strategies = self.project_configs.available_processing_strategies[processing_step_id]
        for strat, config in zip(strategies, strategy_configs):
            assert strat in available_strategies, f'{strat} is not an available strategy for {processing_step_id}!'
            if config is not None:
                strat().default_configs.assert_user_input(user_input = config)
        if processing_configs is not None:
            processing_obj = self.project_configs.available_processing_objects[processing_step_id]()
            processing_obj.default_configs.assert_user_input(user_input = processing_configs)
        if file_ids is not None:
            assert isinstance(file_ids, list), '"file_ids" has to be a list of strings referring to file_ids in the database!'
            known_file_ids = self.database.file_infos['file_id']
            for elem in file_ids:
                assert elem in known_file_ids, f'{elem} is not a valid file_id!'


    def _fill_processing_configs_with_defaults_where_needed(self,
                                                            processing_step_id: str,
                                                            processing_configs: Dict
                                                           ) -> Dict:
        """Complete ``processing_configs`` with the defaults of the step's processing object."""
        processing_obj = self.project_configs.available_processing_objects[processing_step_id]()
        return processing_obj.default_configs.fill_user_input_with_defaults_where_needed(user_input = processing_configs)


    def _fill_strategy_configs_with_defaults_where_needed(self,
                                                          strategies: List['PreprocessingStrategy'],
                                                          strategy_configs: Optional[List[Optional[Dict]]]
                                                         ) -> List[Dict]:
        """Complete the configs of each strategy with that strategy's defaults.

        ``strategy_configs`` may be None, and individual entries may be None. Both are
        treated as "no user configs" for the respective strategy.
        """
        if strategy_configs is None:
            strategy_configs = [None] * len(strategies)
        # A fresh dict per strategy: "[{}] * n" would share one dict between all strategies.
        return [strat().default_configs.fill_user_input_with_defaults_where_needed(user_input = configs if configs is not None else {})
                for strat, configs in zip(strategies, strategy_configs)]
            

In [None]:
class refactored_Project:
    """Work-in-progress successor of ``Project``.

    Combines ``ProjectConfigs`` and ``Database`` and validates all user input of a
    processing method before any file is processed. Several methods are still stubs.
    """
    # NOTE(review): ``ProjectConfigs``, ``Database``, ``ProcessingStrategy`` and
    # ``ProcessingObject`` are not imported in the visible source - confirm they are in scope.
    # Annotations naming them are strings, so defining this class never fails on them.

    def __init__(self, project_root_dir: Path, project_configs_filepath: Optional[Path]=None) -> None:
        self.project_configs = ProjectConfigs(project_root_dir = project_root_dir, project_configs_filepath = project_configs_filepath)
        self.project_root_dir = project_root_dir
        # Bug fix: use the attribute created above - a bare "project_configs" is undefined here.
        self.database = Database(project_configs = self.project_configs)


    def load_microscopy_image_filepaths(self) -> None:
        # TODO: If microscopy images are already present --> retrieve filepaths & group / subject / condition metadata.
        pass


    def load_roi_filepaths(self) -> None:
        # TODO: If ROI files are already present --> retrieve filepaths & ensure that they match the microscopy image filepaths.
        # ROIs should be optional (whole-image analysis should also be possible), so don't break if they are missing.
        pass


    def load_status(self, project_configs_filepath: Path) -> None:
        """Not implemented yet.

        Planned implementation::

            self.database.load_all()
            self.project_configs.attempt_loading_from_configs_filepath(project_configs_filepath = project_configs_filepath)
        """
        raise NotImplementedError()


    def preprocess(self,
                   strategies: List['PreprocessingStrategy'],
                   strategy_configs: Optional[List[Dict]]=None,
                   processing_configs: Optional[Dict]=None,
                   file_ids: Optional[List[str]]=None
                  ) -> None:
        """Validate the input and run the preprocessing strategies file by file."""
        # Bug fix: the keyword must match the parameter name "processing_object_class".
        updated_inputs = self._validate_input_and_prepare_processing_step(strategies = strategies,
                                                                          strategy_configs = strategy_configs,
                                                                          processing_configs = processing_configs,
                                                                          file_ids = file_ids,
                                                                          processing_object_class = PreprocessingObject,
                                                                          processing_type = 'preprocessing')
        _, updated_strategy_configs, updated_file_ids = updated_inputs
        for file_id in updated_file_ids:
            preprocessing_object = PreprocessingObject(file_ids = [file_id], database = self.database, project_configs = self.project_configs)
            preprocessing_object.load_image_and_rois()
            preprocessing_object.run_all_strategies(strategies = strategies, strategy_configs = updated_strategy_configs)
            preprocessing_object.save_preprocessed_images_on_disk()
            preprocessing_object.save_preprocessed_rois_in_database()
            preprocessing_object.update_database()
            del preprocessing_object


    def save_status(self) -> None:
        """Not implemented yet.

        Planned implementation::

            self.database.save_all()
            self.project_configs.export_as_yml()
        """
        raise NotImplementedError()


    def _assert_valid_input_for_processing_methods(self,
                                                   strategies: List['ProcessingStrategy'],
                                                   strategy_configs: Optional[List[Dict]],
                                                   processing_configs: Optional[Dict],
                                                   file_ids: Optional[List[str]],
                                                   processing_type: str
                                                  ) -> None:
        """Assert that the input of a processing method is well-formed.

        Raises:
            AssertionError: If ``strategies`` is not a list of classes intended for
                ``processing_type``. Also if ``file_ids`` is neither None nor a list of
                str, if ``strategy_configs`` is neither None nor a list of dicts matching
                ``strategies`` in length, or if ``processing_configs`` is neither a dict
                nor None.
        """
        # Strategies are passed as classes (they get instantiated below), so check for "type".
        self._assert_types_for_all_elements_in_list(list_to_check = strategies,
                                                    allow_none_instead_of_list = False,
                                                    allowed_types = [type],
                                                    id_for_error = 'strategies')
        self._assert_types_for_all_elements_in_list(list_to_check = file_ids,
                                                    allow_none_instead_of_list = True,
                                                    allowed_types = [str],
                                                    id_for_error = 'file IDs')
        # Bug fix: "dict", not "typing.Dict" - type({}) is never the typing alias.
        self._assert_types_for_all_elements_in_list(list_to_check = strategy_configs,
                                                    allow_none_instead_of_list = True,
                                                    allowed_types = [dict],
                                                    id_for_error = 'strategy configs')
        if isinstance(strategy_configs, list):
            # Bug fix: the lengths have to be equal (the original check was inverted).
            assert len(strategies) == len(strategy_configs), 'You must provide exactly as many strategy config dictionaries as ProcessingStrategy objects!'
        for strategy in strategies:
            assert strategy().processing_type == processing_type, 'Not all processing strategies are intended to be used for this step!'
        # Bug fix: the original condition was inverted and rejected exactly the valid inputs.
        assert isinstance(processing_configs, (dict, type(None))), '"processing_configs" was neither a dictionary nor None.'
        # TODO: assert that all file_ids exist in the database.


    def _assert_types_for_all_elements_in_list(self,
                                               list_to_check: Optional[List[Any]],
                                               allow_none_instead_of_list: bool,
                                               allowed_types: List[type],
                                               id_for_error: str
                                              ) -> None:
        """Assert that ``list_to_check`` is a list (or None, if allowed) of ``allowed_types`` instances.

        Args:
            list_to_check: The list to validate.
            allow_none_instead_of_list: Whether ``None`` is accepted in place of a list.
            allowed_types: Types every element must be an instance of.
            id_for_error: Identifier included in the assertion messages.
        """
        if allow_none_instead_of_list:
            assert isinstance(list_to_check, (list, type(None))), f'Neither a list nor "None" was provided, but a {type(list_to_check)} instead. Identifier for Traceback: {id_for_error}.'
        else:
            assert isinstance(list_to_check, list), f'"list_to_check" is actually not a list! Identifier for Traceback: {id_for_error}.'
        if list_to_check is not None:
            for elem in list_to_check:
                assert isinstance(elem, tuple(allowed_types)), f'Found {type(elem)} in a list where {allowed_types} are allowed. Identifier for Traceback: {id_for_error}.'


    def _validate_input_and_prepare_processing_step(self,
                                                     strategies: List['ProcessingStrategy'],
                                                     strategy_configs: Optional[List[Dict]],
                                                     processing_configs: Optional[Dict],
                                                     file_ids: Optional[List[str]],
                                                     processing_object_class: 'ProcessingObject',
                                                     processing_type: str
                                                    ) -> Tuple[Dict, List[Dict], List[str]]:
        """Validate the input, register the configs and determine the files to process.

        Returns:
            A tuple of the validated processing configs, the validated configs per
            strategy and the file IDs to process.
        """
        self._assert_valid_input_for_processing_methods(strategies = strategies,
                                                        strategy_configs = strategy_configs,
                                                        processing_configs = processing_configs,
                                                        file_ids = file_ids,
                                                        processing_type = processing_type)
        updated_processing_configs = self.project_configs.set_processing_type_specific_configs(processing_object_class = processing_object_class,
                                                                                               processing_type_specific_configs = processing_configs)
        configs_per_strategy = strategy_configs if strategy_configs is not None else [None] * len(strategies)
        updated_strategy_configs = [self.project_configs.set_strategy_specific_configs(strategy = strategy, strategy_configs = configs)
                                    for strategy, configs in zip(strategies, configs_per_strategy)]
        updated_file_ids = self.database.get_file_ids_to_process(requested_file_ids = file_ids,
                                                                 processing_type = processing_type,
                                                                 overwrite = updated_processing_configs['overwrite'])
        return updated_processing_configs, updated_strategy_configs, updated_file_ids
    
    

    

In [None]:
class Project:
    """Legacy project interface of *findmycells*.

    Wraps a ``Database`` created from a ``user_input`` dictionary. It runs the processing
    steps (preprocessing, segmentation, postprocessing, quantification) on the files
    tracked in that database.
    """
    # NOTE(review): ``Database``, ``SegmentationObject``, ``PostprocessingObject``,
    # ``QuantificationObject`` and the strategy base classes named in the annotations are
    # not imported in the visible source - confirm they are in scope before use.
    # The annotations are strings, so defining this class never fails on them.

    def __init__(self, user_input: Dict):
        self.project_root_dir = user_input['project_root_dir']
        self.database = Database(user_input)


    def save_status(self) -> None:
        self.database.save_all()


    def load_status(self) -> None:
        self.database.load_all()


    def preprocess(self, strategies: List['PreprocessingStrategy'], file_ids: Optional[List]=None, overwrite: bool=False) -> None:
        """Run the preprocessing strategies on every file that still needs preprocessing."""
        file_ids = self.database.get_file_ids_to_process(input_file_ids = file_ids, process_tracker_key = 'preprocessing_completed', overwrite = overwrite)
        for file_id in file_ids:
            preprocessing_object = PreprocessingObject(database = self.database, file_ids = [file_id], strategies = strategies)
            preprocessing_object.run_all_strategies()
            preprocessing_object.save_preprocessed_images_on_disk()
            preprocessing_object.save_preprocessed_rois_in_database()
            preprocessing_object.update_database()
            del preprocessing_object


    def segment(self, strategies: List['SegmentationStrategy'], file_ids: Optional[List]=None, batch_size: Optional[int]=None,
                run_strategies_individually: bool=True, overwrite: bool=False, autosave: bool=True, clear_tmp_data: bool=False) -> None:
        """Run the segmentation strategies on the requested files.

        Args:
            strategies: Segmentation strategy classes to run.
            file_ids: File IDs to segment. This is passed on to
                ``Database.get_file_ids_to_process``.
            batch_size: If an int, files are processed in randomly composed batches of
                this size.
            run_strategies_individually: If True, each strategy is run on all batches
                before the next strategy starts, and the database is updated once at the
                end. Otherwise all strategies run together per batch.
            overwrite: Re-process files that were already processed.
            autosave: Save the database after each processing run.
            clear_tmp_data: Afterwards, clear the temporary segmentation data of all files
                (as selected with overwrite=True).
        """
        # A new strategy invalidates earlier results: "segmentation_completed" is reset to None for all files.
        self.database.file_infos['segmentation_completed'] = self.reset_file_infos_if_new_strategy(strategies = strategies)

        if type(batch_size) == int:
            file_ids_per_batch = self.create_batches(batch_size = batch_size, file_ids = file_ids, process_tracker_key = 'segmentation_completed', overwrite = overwrite)
            if file_ids_per_batch is None:
                # Nothing left to segment.
                return None
        else:
            file_ids_per_batch = [file_ids]

        if run_strategies_individually:
            for segmentation_strategy in strategies:
                tracker = f'{segmentation_strategy().segmentation_type}_segmentations_done'
                for batch_file_ids in file_ids_per_batch:
                    tmp_file_ids = self.database.get_file_ids_to_process(input_file_ids = batch_file_ids, process_tracker_key = tracker, overwrite = overwrite)
                    if len(tmp_file_ids) > 0:
                        segmentation_object = SegmentationObject(database = self.database, file_ids = tmp_file_ids, strategies = [segmentation_strategy])
                        segmentation_object.run_all_strategies()
                        del segmentation_object
                        if autosave:
                            self.database.save_all()
            if file_ids_per_batch[0] is None:
                all_file_ids = None
            else:
                all_file_ids = [file_id for batch_file_ids in file_ids_per_batch for file_id in batch_file_ids]
            all_file_ids = self.database.get_file_ids_to_process(input_file_ids = all_file_ids, process_tracker_key = 'segmentation_completed', overwrite = overwrite)
            if len(all_file_ids) > 0:
                segmentation_object = SegmentationObject(database = self.database, file_ids = all_file_ids, strategies = strategies)
                segmentation_object.update_database()
                del segmentation_object
                if autosave:
                    self.database.save_all()
        else:
            for batch_file_ids in file_ids_per_batch:
                batch_file_ids = self.database.get_file_ids_to_process(input_file_ids = batch_file_ids, process_tracker_key = 'segmentation_completed', overwrite = overwrite)
                if len(batch_file_ids) == 0:
                    # Consistent with the branch above: no processing object for an empty batch.
                    continue
                segmentation_object = SegmentationObject(database = self.database, file_ids = batch_file_ids, strategies = strategies)
                segmentation_object.run_all_strategies()
                segmentation_object.update_database()
                del segmentation_object
                if autosave:
                    self.database.save_all()

        if clear_tmp_data:
            file_ids = self.database.get_file_ids_to_process(input_file_ids = None, process_tracker_key = 'segmentation_completed', overwrite = True)
            segmentation_object = SegmentationObject(database = self.database, file_ids = file_ids, strategies = strategies)
            segmentation_object.clear_all_tmp_data()


    def reset_file_infos_if_new_strategy(self, strategies: List['ProcessingStrategy']) -> List:
        """Return the "<processing_type>_completed" column, reset to None if a strategy is new.

        A strategy counts as new if no file-info column whose key contains
        "<processing_type>_step" lists its ``strategy_name``.

        Raises:
            ValueError: If ``strategies`` is empty.
        """
        if len(strategies) == 0:
            raise ValueError('"strategies" must contain at least one ProcessingStrategy!')
        new_strategy = False
        for strategy in strategies:
            strategy_instance = strategy()
            processing_type = strategy_instance.processing_type
            strategy_name = strategy_instance.strategy_name
            matching_keys = [key for key, column in self.database.file_infos.items() if f'{processing_type}_step' in key and strategy_name in column]
            if len(matching_keys) == 0:
                new_strategy = True
                break
        completed_key = f'{processing_type}_completed'
        if new_strategy or completed_key not in self.database.file_infos:
            return [None] * len(self.database.file_infos['file_id'])
        return self.database.file_infos[completed_key]


    def create_batches(self, batch_size: int, file_ids: List[str], process_tracker_key: str, overwrite: bool) -> Optional[List]:
        """Randomly split the file IDs that still need processing into batches.

        Every batch holds ``batch_size`` file IDs, except the last one, which holds the
        remainder.

        Returns:
            The list of batches, or None if no file needs to be processed.

        Raises:
            ValueError: If ``batch_size`` is not greater than 0.
        """
        if batch_size <= 0:
            raise ValueError('"batch_size" must be greater than 0!')
        all_file_ids = self.database.get_file_ids_to_process(input_file_ids = file_ids, process_tracker_key = process_tracker_key, overwrite = overwrite)
        if len(all_file_ids) == 0:
            return None
        # Chunking a random permutation yields the same kind of random batches as repeated
        # sampling without replacement, in O(n) and without mutating the returned list.
        shuffled_file_ids = random.sample(all_file_ids, len(all_file_ids))
        return [shuffled_file_ids[start:start + batch_size] for start in range(0, len(shuffled_file_ids), batch_size)]


    def postprocess(self, strategies: List['PostprocessingStrategy'], segmentations_to_use: str, file_ids: Optional[List]=None, overwrite: bool=False) -> None:
        """Postprocess the "semantic" or "instance" segmentations file by file.

        Raises:
            ValueError: If ``segmentations_to_use`` is invalid, or if no segmentation files
                are present in the corresponding directory.
        """
        if segmentations_to_use not in ['semantic', 'instance']:
            raise ValueError('"segmentations_to_use" must be either "semantic" or "instance"')
        segmentations_to_use_dir = getattr(self.database, f'{segmentations_to_use}_segmentations_dir')
        segmentations_present = any(elem.is_file() for elem in segmentations_to_use_dir.iterdir())
        if not segmentations_present:
            error_message_line0 = f'It seems like there are no {segmentations_to_use} segmentations present in the corresponding directory.\n'
            if segmentations_to_use == 'semantic':
                error_message_line1 = 'You need to run segmentations first, before you can postprocess them.'
            else: # has to be instance then
                error_message_line1 = 'Did you mean to use "semantic" instead? Otherwise, please run the respective instance segmentations first.'
            raise ValueError(error_message_line0 + error_message_line1)

        file_ids = self.database.get_file_ids_to_process(input_file_ids = file_ids, process_tracker_key = 'postprocessing_completed', overwrite = overwrite)
        # enumerate instead of file_ids.index(): O(n) and correct even with duplicate IDs.
        for position, file_id in enumerate(file_ids, start = 1):
            print(f'Postprocessing segmentations of file ID: {file_id} ({position}/{len(file_ids)})')
            postprocessing_object = PostprocessingObject(database = self.database,
                                                         file_ids = [file_id],
                                                         strategies = strategies,
                                                         segmentations_to_use = segmentations_to_use)
            postprocessing_object.run_all_strategies()
            postprocessing_object.save_postprocessed_segmentations()
            postprocessing_object.update_database()
            del postprocessing_object


    def quantify(self, strategies: List['QuantificationStrategy'], file_ids: Optional[List]=None, overwrite: bool=False) -> None:
        """Run the quantification strategies on every file that still needs quantification."""
        file_ids = self.database.get_file_ids_to_process(input_file_ids = file_ids, process_tracker_key = 'quantification_completed', overwrite = overwrite)
        for position, file_id in enumerate(file_ids, start = 1):
            print(f'Quantification of file ID: {file_id} ({position}/{len(file_ids)})')
            quantification_object = QuantificationObject(database = self.database, file_ids = [file_id], strategies = strategies)
            quantification_object.run_all_strategies()
            quantification_object.update_database()
            del quantification_object

    # Disabled inspection methods, kept as a (never executed) string literal.
    """        
    def inspect(self, quantification_strategy_index: int=0, file_ids: Optional[List]=None, 
                area_roi_ids: Optional[List]=None, label_indices: Optional[List]=None, show: bool=True, save: bool=False) -> None:
        from .inspection import InspectionObject
        quantification_strategy_str = list(self.database.quantification_results.keys())[quantification_strategy_index]
        file_ids_not_quantified = self.database.get_file_ids_to_process(input_file_ids = file_ids, process_tracker_key = 'quantification_completed', overwrite = False)
        if file_ids == None:
            file_ids = self.database.file_infos['file_id']
        file_ids_quantified = [elem for elem in file_ids if elem not in file_ids_not_quantified]
        for file_id in file_ids_quantified:
            valid_area_roi_ids = self.database.area_rois_for_quantification[file_id]['all_planes'].keys()
            if area_roi_ids == None:
                tmp_area_roi_ids = valid_area_roi_ids
            else:
                tmp_area_roi_ids = [elem for elem in area_roi_ids if elem in valid_area_roi_ids]
            for area_roi_id in tmp_area_roi_ids:
                total_labels = self.database.quantification_results[quantification_strategy_str][file_id][area_roi_id]
                if label_indices == None:
                    tmp_label_indices = [elem for elem in range(total_labels)]
                else:
                    tmp_label_indices = [elem for elem in label_indices if elem < total_labels]
                for label_index in tmp_label_indices:
                    inspection_object = InspectionObject(database = self.database, file_id = file_id, area_roi_id = area_roi_id, label_index = label_index, show = show, save = save)
                    inspection_object.run_all_inspection_steps()

                    
    def run_inspection(self, file_id: str, inspection_strategy):
        from .inspection import InspectionStrategy
        inspection_strategy.run(self.database, file_id)
    """


    def remove_file_id_from_project(self, file_id: str):
        self.database.remove_file_id_from_project(file_id = file_id)

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
# The exported cell must import what it uses: only "#| export" cells end up in the module.
import os
from pathlib import Path
from typing import List, Union


def listdir_nohidden(path: Union[str, Path]) -> List[str]:
    """Return the names of all entries in ``path`` that do not start with a dot.

    Dot-prefixed entries (hidden files on Unix-like systems) are skipped. The order is
    the one returned by ``os.listdir``.
    """
    return [name for name in os.listdir(path) if not name.startswith('.')]