# Specifications of preprocessing-specific subclasses

> This module defines the specifications of the preprocessing-related subclasses of `ProcessingStrategy` and `ProcessingObject`.

In [None]:
#| default_exp preprocessing/specs

In [None]:
#| export

import numpy as np
from shapely.geometry import Polygon
from typing import List, Dict
from skimage.io import imsave

from findmycells.core import ProcessingObject, ProcessingStrategy, DataLoader
from findmycells.database import Database
from findmycells import readers

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

class PreprocessingStrategy(ProcessingStrategy):

    """
    Extend the `ProcessingStrategy` base class with "preprocessing" as the processing subtype.

    Preprocessing strategies derived from this class report `'preprocessing'`
    as their `processing_type`.
    """

    @property
    def processing_type(self) -> str:
        # Constant tag shared by every preprocessing strategy.
        processing_subtype = 'preprocessing'
        return processing_subtype

In [None]:
#| export

class PreprocessingObject(ProcessingObject):
    
    """
    Extend the `ProcessingObject` base class with "preprocessing" as the processing subtype.

    Responsible for loading the microscopy image(s) and corresponding ROI(s) of each file,
    running the specified preprocessing strategies, updating the database, and finally
    saving the preprocessed images to disk for the subsequent processing steps.

    Note: Even though the `file_ids` argument accepts (and actually expects & requires) a 
          list as input, only a single file_id will be passed to a `PreprocessingObject`
          upon initialization. This is handled in the api module of findmycells.
    """


    @property
    def processing_type(self) -> str:
        # Constant tag identifying the processing step this object belongs to.
        return 'preprocessing'
    
    
    def __init__(self, database: Database, file_ids: List[str], strategies: List[ProcessingStrategy]) -> None:
        """
        Initialize the object and immediately load the image and ROI data of the file.

        Args:
            database: The findmycells `Database` providing the file infos and output directories.
            file_ids: List holding the single file ID to process (see class docstring);
                `file_ids[0]` is stored as `self.file_id`.
            strategies: The preprocessing strategies to run on this file.

        Raises:
            NotImplementedError: If the file info reports that no ROI file is present.
        """
        super().__init__(database = database, file_ids = file_ids, strategies = strategies)
        self.file_id = file_ids[0]
        self.file_info = self.database.get_file_infos(identifier = self.file_id)
        # Both attributes start out holding the freshly loaded (raw) data.
        self.preprocessed_image = self._load_microscopy_image()
        self.preprocessed_rois = self._load_rois()


    def _load_with_matching_reader(self, filetype_key: str, filepath_key: str, data_reader_module):
        """
        Load data of this file with the reader from `data_reader_module` matching its file extension.

        Args:
            filetype_key: Key in `self.file_info` holding the file extension.
            filepath_key: Key in `self.file_info` holding the path of the file to load.
            data_reader_module: Module that `DataLoader.determine_reader` searches for a suitable reader.

        Returns:
            Whatever `DataLoader.load` returns for the selected reader.
        """
        data_loader = DataLoader()
        reader_class = data_loader.determine_reader(file_extension = self.file_info[filetype_key],
                                                    data_reader_module = data_reader_module)
        return data_loader.load(data_reader_class = reader_class,
                                filepath = self.file_info[filepath_key],
                                database = self.database)


    def _load_microscopy_image(self) -> np.ndarray:
        """Load the microscopy image of this file using a reader from `readers.microscopy_images`."""
        return self._load_with_matching_reader(filetype_key = 'microscopy_filetype',
                                               filepath_key = 'microscopy_filepath',
                                               data_reader_module = readers.microscopy_images)
    

    def _load_rois(self) -> Dict[str, Dict[str, Polygon]]:
        """
        Load the ROIs of this file using a reader from `readers.rois`.

        Raises:
            NotImplementedError: If the file info reports that no ROI file is present.
        """
        if not self.file_info['rois_present']:
            raise NotImplementedError('As of now, it is not supported to not provide a ROI file for each image. If you would like to '
                                      'quantify the image features in the entire image, please provide a ROI that covers the entire image. '
                                      'Warning: However, please be aware that this feature was not fully tested yet and will probably cause '
                                      'problems, specifically if any cropping is used as preprocessing method.')
        # The ROI reader must be looked up among the ROI readers, mirroring the
        # microscopy image loader (which searches `readers.microscopy_images`).
        return self._load_with_matching_reader(filetype_key = 'rois_filetype',
                                               filepath_key = 'rois_filepath',
                                               data_reader_module = readers.rois)


    def add_processing_specific_infos_to_updates(self, updates: Dict) -> Dict:
        """
        Add image-derived infos to `updates` (modified in place and returned).

        Sets:
            'RGB': True if axis 3 of the preprocessed image has length 3.
            'total_planes': Length of axis 0 (the number of planes).

        NOTE(review): assumes the image is a 4D (planes, rows, columns, channels) array — confirm.
        """
        updates['RGB'] = self.preprocessed_image.shape[3] == 3
        updates['total_planes'] = self.preprocessed_image.shape[0]
        return updates
    

    def save_preprocessed_images_on_disk(self) -> None:
        """
        Write every plane of the preprocessed image to `database.preprocessed_images_dir`
        as '<file_id>-<plane index, zero-padded to 3 digits>.png'.
        """
        for plane_index, plane in enumerate(self.preprocessed_image):
            filepath_out = self.database.preprocessed_images_dir.joinpath(f'{self.file_id}-{plane_index:03d}.png')
            # NOTE(review): astype('uint8') does not clip values outside 0-255; this assumes
            # the preprocessing strategies already produced 8-bit-range data — confirm.
            imsave(filepath_out, plane.astype('uint8'))


    def save_preprocessed_rois_in_database(self) -> None:
        """Store the (possibly preprocessed) ROIs of this file in the database."""
        self.database.import_rois_dict(file_id = self.file_id, rois_dict = self.preprocessed_rois)

In [None]:
#| hide
# Write the `#| export` cells out to the Python library via nbdev.
import nbdev
nbdev.nbdev_export()