# Build a pipeline for labels preprocessing, statistics collection and cells meshing

In [None]:
import concurrent
import os
import pickle
from time import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import trimesh as tm
from skimage.io import imread, imsave
from skimage.measure import regionprops, regionprops_table

In [None]:
# Move into the `src` folder; presumably this is what makes the project modules
# imported in the next cell importable — confirm against the repository layout.
# NOTE: the path is relative, so running this cell a second time looks for `src/src`.
os.chdir('src')
# Last expression of the cell: the notebook displays the new working directory.
os.getcwd()

In [None]:
import misc 
import SegmentationStatisticsCollector 
import LabelPreprocessing
import GenMeshes
from StatsCollection import *
from ExtendedTrimesh import ExtendedTrimesh

## 1. Load and preprocess curated segmentation images

In [None]:
#Load curated segmentation
# Folder (absolute path) and file name of the curated label image to process.
PATH_TO_LABELED_IMG = '/nas/groups/iber/Users/Federico_Carrara/create_meshes/data/curated_labels/'
FILE_NAME = 'MBC19_S5_St1_Crop_GFP_clean_bottom.tif'

# Read the image through the project helper `misc.load_labeled_img`.
labeled_img = misc.load_labeled_img(os.path.join(PATH_TO_LABELED_IMG, FILE_NAME))   

In [None]:
#Create output folder
# Preprocessing parameters. They are defined once here and reused both for the
# output folder and for the preprocessing call in the next cell, so that the two
# can never disagree.
smoothing_iters = 10
erosion_iters = 6
dilation_iters = 8

PATH_TO_OUTPUT = './tests/output'
# NOTE(review): the helper presumably encodes the three parameters in the folder
# name (later cells read from 'tests/output_s_10_e_6_d_8') — confirm in `misc`.
output_dir = misc.create_output_directory(
    output_folder=PATH_TO_OUTPUT,
    input_img_path=os.path.join(PATH_TO_LABELED_IMG, FILE_NAME),
    smoothing_iterations=smoothing_iters,
    erosion_iterations=erosion_iters,
    # previously hard-coded to 8, which silently ignored `dilation_iters`
    dilation_iterations=dilation_iters
)

In [None]:
#Preprocess labels
# Switch: set to False to skip preprocessing and use the curated labels as they are.
process = True

if process:
    # Run the project preprocessing with the erosion/dilation settings chosen above.
    # `overwrite=False` presumably keeps an already existing result in
    # `output_dir` instead of recomputing it — TODO confirm in LabelPreprocessing.
    preprocessed_labeled_img = LabelPreprocessing.process_labels(
        labeled_img=labeled_img, 
        erosion_iterations=erosion_iters,
        dilation_iterations=dilation_iters,
        output_directory=output_dir,
        overwrite=False
    )
else:
    # No preprocessing: downstream cells work directly on the loaded labels.
    preprocessed_labeled_img = labeled_img

Now we want to filter cells which are not good for computing morphological statistics.

We want to get the following:
- A list of indexes of cut cells, i.e., cells which are touching the edges of the image --> LabelPreprocessing.get_labels_touching_edges
- A list of indexes of cells which touch the background --> LabelPreprocessing.get_labels_touching_background

In [None]:
# Keep only the first 100 entries along each of the three axes. Every cell below
# therefore works on this sub-volume (presumably to keep test runs fast); remove
# this line to process the whole image.
preprocessed_labeled_img = preprocessed_labeled_img[:100, :100, :100]

In [None]:
#Filter cells for different statistics
# Labels of the cells touching the image edges ("cut" cells).
cut_cells_idxs = LabelPreprocessing.get_labels_touching_edges(preprocessed_labeled_img, output_dir)
# Labels of the cells touching the background, together with a second return value
# (presumably the per-label number of background contacts — confirm in LabelPreprocessing).
# NOTE(review): the exact meaning of `threshold=0` is defined by the helper.
touching_background_idxs, background_touch_counts = LabelPreprocessing.get_labels_touching_background(preprocessed_labeled_img, output_dir, threshold=0)

In [None]:
# Display the labels that appear in both lists (cut cells that also touch the background).
np.intersect1d(cut_cells_idxs, touching_background_idxs)

In [None]:
# Build two filtered copies of the label image, one per exclusion list, so that
# the effect of each filter can be inspected separately below.
cut_filtered_labeled_img = LabelPreprocessing.filter_labels(preprocessed_labeled_img, cut_cells_idxs)
bg_filtered_labeled_img = LabelPreprocessing.filter_labels(preprocessed_labeled_img, touching_background_idxs)

In [None]:
# Sanity check: number of distinct values left in each filtered image ...
print(len(np.unique(cut_filtered_labeled_img)), len(np.unique(bg_filtered_labeled_img)))
# ... compared with the number of distinct values before filtering ...
print(len(np.unique(preprocessed_labeled_img)))
# ... and with the number of labels each filter was asked to handle.
print(len(cut_cells_idxs), len(touching_background_idxs))

In [None]:
# Manual check on a copy: zero out the first background-touching label and count
# the distinct values that remain.
# NOTE: indexing with [0] raises IndexError when no label touches the background.
try_img = preprocessed_labeled_img.copy()
binary_mask = try_img == touching_background_idxs[0]
try_img[binary_mask] = 0
print(len(np.unique(try_img)))

In [None]:
# Visual inspection: open a napari viewer with the unfiltered labels and the two
# filtered versions as separate label layers.
import napari

viewer = napari.Viewer()
viewer.add_labels(preprocessed_labeled_img)
viewer.add_labels(cut_filtered_labeled_img)
viewer.add_labels(bg_filtered_labeled_img)

## 2. Mesh Generation

In [None]:
# Convert the labeled cells into meshes; `meshes` is later used as a dictionary
# keyed by cell id (see StatsCollector below).
# NOTE(review): all arguments are positional, so their meaning is defined by
# GenMeshes.convert_labels_to_meshes. Presumably, in order: label image, voxel
# size per axis, labels to exclude, smoothing iterations, output directory,
# overwrite flag, a second integer setting, mesh file format — TODO confirm and
# switch to keyword arguments.
meshes = GenMeshes.convert_labels_to_meshes(
    preprocessed_labeled_img,
    [0.1625, 0.1625, 0.25],
    cut_cells_idxs,
    10,
    output_dir,
    False,
    10,
    'stl'
)

## 3. Statistics collection

In [None]:
# Compute the mesh-based statistics one by one with the functions imported from
# StatsCollection; the cut cells are passed as second argument to every call.
cell_areas = compute_cell_surface_areas(meshes, cut_cells_idxs)
cell_volumes = compute_cell_volumes(meshes, cut_cells_idxs)
# NOTE(review): here the result is unpacked into two objects, whereas
# StatsCollector below stores the result of the same function as one dictionary
# — confirm the actual return type.
cell_axes, cell_elong = compute_cell_principal_axis_and_elongation(meshes, cut_cells_idxs)
# Neighbor computation is disabled here; the result is loaded from disk further down.
# cell_neighbors = compute_cell_neighbors(preprocessed_labeled_img, cut_cells_idxs)

In [None]:
# cell_contact_area = compute_cell_contact_area(meshes, cell_neighbors, 0.1, 4)

In [None]:
# # temporary: cache the statistics that take long to compute as pickle files

# def save_dictionary(dictionary, filename):
#     with open(filename, 'wb') as file:
#         pickle.dump(dictionary, file)

# # save neighbors
# if not os.path.exists('tests/output_s_10_e_6_d_8/temp_stats/'):
#     os.makedirs('tests/output_s_10_e_6_d_8/temp_stats/')
# save_dictionary(cell_neighbors, 'tests/output_s_10_e_6_d_8/temp_stats/cell_neighbors.pickle')
# save_dictionary(cell_contact_area, 'tests/output_s_10_e_6_d_8/temp_stats/cell_contact_area.pickle')

In [None]:
# Load the neighbor and contact-area dictionaries from pickle files instead of
# recomputing them. The files must already exist: the code that writes them is
# commented out in the cell above.
# NOTE: the folder is hard-coded rather than derived from `output_dir`.
# NOTE: pickle.load executes arbitrary code; only load files you created yourself.
with open('tests/output_s_10_e_6_d_8/temp_stats/cell_neighbors.pickle', 'rb') as f:
    cell_neighbors = pickle.load(f)

with open('tests/output_s_10_e_6_d_8/temp_stats/cell_contact_area.pickle', 'rb') as f:
    cell_contact_area = pickle.load(f)

In [None]:
class StatsCollector:
    '''
    Compute, cache and tabulate per-cell statistics.

    On construction the collector determines the labels to exclude (see
    `filter_cells`), builds a DataFrame with one row per mesh and writes it to
    `<output_directory>/cell_stats/`. `collect_statistics` then adds one column
    per requested feature and pickles every feature dictionary in
    `<output_directory>/cell_stats/cached_stats/`.

    Parameters:
    -----------
        meshes: (Dict[int, tm.base.Trimesh])
            The cell meshes, keyed by cell id. The keys define the rows of the table.

        labels: (np.ndarray)
            The labeled image the meshes were generated from.

        features: (List[str])
            Names of the statistics to collect. Supported names are the keys of
            `_feat_to_func_dict()`.

        output_directory: (str)
            Root folder for the table and the cache.

        path_to_img: (str)
            Path of the processed image; stored in the table and used to name
            the output file.

        tissue: (str)
            Tissue name, one of the keys of `_tissue_to_type_dict()`.

        num_workers: (int)
            Forwarded to the contact area computation.

    Raises:
    -------
        KeyError
            If a feature or the tissue is not supported.
        NotImplementedError
            If the tissue is of a 'stratified' type (filtering not implemented yet).
    '''

    def __init__(
            self,
            meshes: "Dict[int, tm.base.Trimesh]",
            labels: np.ndarray,
            # original_ids: List[int],
            features: List[str],
            output_directory: str,
            path_to_img: str,
            tissue: str,
            num_workers: int
        ) -> None:

        #internal attributes
        self._features_to_functions = StatsCollector._feat_to_func_dict()
        self._tissues_to_types = StatsCollector._tissue_to_type_dict()

        #fail early and explicitly on unsupported inputs (same exception type
        #the plain dictionary lookups would raise, but with a readable message)
        unknown_features = [
            feature for feature in features
            if feature not in self._features_to_functions
        ]
        if unknown_features:
            raise KeyError(
                f'Unknown feature(s) {unknown_features}. '
                f'Supported features: {list(self._features_to_functions)}'
            )
        if tissue not in self._tissues_to_types:
            raise KeyError(
                f'Unknown tissue {tissue!r}. '
                f'Supported tissues: {list(self._tissues_to_types)}'
            )

        #public attributes
        self.meshes = meshes
        self.labels = labels
        self.ids = list(self.meshes)
        # self.original_ids = original_ids
        self.features = features
        self.functions = [
            self._features_to_functions[feature]
            for feature in self.features
        ]
        self.tissue = tissue
        self.tissue_type = self._tissues_to_types[tissue]
        self.output_dir = output_directory
        self.df_output_dir = os.path.join(self.output_dir, 'cell_stats')
        self.path_to_img = path_to_img
        #splitext removes only the trailing extension; str.replace would also
        #remove any other occurrence of the extension inside the name
        self.file_name, self.file_ext = os.path.splitext(
            os.path.basename(self.path_to_img)
        )
        self.num_workers = num_workers
        self.cache_dir = os.path.join(self.df_output_dir, 'cached_stats')
        os.makedirs(self.cache_dir, exist_ok=True)

        #apply filtering to get labels to be excluded from computation
        self.excluded_idxs = self.filter_cells()

        #initialize and save the dataframe to store statistics
        self.df = self._init_dataframe()
        #save the newly created data structure
        self._save_dataframe()

    @staticmethod
    def _feat_to_func_dict() -> Dict[str, Callable]:
        '''
        Map every supported feature name to the function computing it.
        '''
        return {
            'area': compute_cell_surface_areas,
            'volume': compute_cell_volumes,
            'elongation_and_axes': compute_cell_principal_axis_and_elongation,
            'neighbors': compute_cell_neighbors,
            'contact_area': compute_cell_contact_area,
        }

    @staticmethod
    def _tissue_to_type_dict() -> Dict[str, str]:
        '''
        Map every supported tissue name to its epithelium type.
        '''
        return {
            'bladder': 'stratified_transitional',
            'intestine_villus': 'simple_columnar',
            'lung_bronchiole': 'simple_cuboidal',
            'esophagus': 'stratified_squamous',
        }

    def filter_cells(self) -> List[int]:
        '''
        Return the labels to exclude from the statistics computation.

        For 'simple' tissue types these are the labels touching the image edges.

        Raises:
        -------
            NotImplementedError
                For 'stratified' tissue types.
            ValueError
                For a tissue type that is neither 'simple' nor 'stratified'.
        '''
        if 'simple' in self.tissue_type:
            return LabelPreprocessing.get_labels_touching_edges(
                self.labels, self.output_dir
            )
        if 'stratified' in self.tissue_type:
            raise NotImplementedError()
            # idxs_to_filter = LabelPreprocessing.get_labels_touching_edges(
            #     self.labels, self.output_dir
            # )

        #previously this case ended in an UnboundLocalError
        raise ValueError(f'Unsupported tissue type: {self.tissue_type!r}')

    def _save_dataframe(
            self,
            overwrite: bool = True
    ) -> None:
        '''
        Write the table as CSV into `self.df_output_dir`.

        Parameters:
        -----------
            overwrite: (bool, default=True)
                If False, an already existing file is left untouched.
        '''
        os.makedirs(self.df_output_dir, exist_ok=True)

        #NOTE(review): the file is named after the image WITHOUT a '.csv'
        #extension; kept as is because readers of the file may rely on the name.
        path_to_file = os.path.join(self.df_output_dir, self.file_name)
        if overwrite or not os.path.isfile(path_to_file):
            self.df.to_csv(path_to_file)

    def _init_dataframe(self) -> pd.DataFrame:
        '''
        Build the initial table: one row per mesh with cell id, tissue, image
        path, expected mesh path and the exclusion flag.
        '''
        #initialize the data structure
        df = pd.DataFrame(
            data={
                'cell_ID': self.ids,
                'tissue': self.tissue,
                'file_name': self.path_to_img
                # 'original_cell_ID': self.original_ids
            }
        )
        df['mesh_dir'] = [
            os.path.join(self.output_dir, 'cell_meshes', f'cell_{cell_id}.stl')
            for cell_id in self.ids
        ]
        df['exclude_cell'] = [
            cell_id in self.excluded_idxs for cell_id in self.ids
        ]

        return df

    @staticmethod
    def _unpack_feature_dict(
            feature_dict: Dict[int, Any]
        ) -> pd.Series:
        '''
        Unpack the dictionary associated to each feature in a pd.Series.

        The values are returned in dictionary order with a positional index,
        i.e. the cell ids are dropped.

        Parameters:
        -----------
            feature_dict: (Dict[int, Any])
                A dict whose keys are cell ids and values are the associated statistics value.

        Returns:
        --------
            feature_unpacked: (pd.Series[Any])
                A pd.Series of the statistics values.
        '''
        feature_unpacked = pd.Series(list(feature_dict.values()))
        return feature_unpacked

    @staticmethod
    def _is_missing(value: Any) -> bool:
        '''
        Tell whether a table entry holds no statistic (None or NaN).
        '''
        return value is None or (isinstance(value, float) and np.isnan(value))

    @staticmethod
    def _apply_skipping_missing(
            column: pd.Series,
            func: Callable[[Any], Any]
    ) -> pd.Series:
        '''
        Apply `func` to every entry of `column`, propagating missing entries as NaN.
        '''
        return column.apply(
            lambda value: np.nan if StatsCollector._is_missing(value) else func(value)
        )

    def _add_to_dataframe(
        self,
        feature_dict: Dict[int, Any],
        feature_name: str,
    ) -> None:
        '''
        Add a statistic to the table as column `feature_name`.

        Values are matched to the rows through the cell id. Matching by position
        (the previous behavior) assigns values to the wrong cells as soon as the
        dictionary skips some cells or is ordered differently from the table.
        Cells without an entry in `feature_dict` get NaN.

        Parameters:
        -----------
            feature_dict: (Dict[int, Any])
                A dict whose keys are cell ids and values are the associated statistics value.

            feature_name: (str)
                Name of the new column.
        '''
        values = [
            feature_dict.get(cell_id, np.nan)
            for cell_id in self.df['cell_ID']
        ]

        #add column to df
        self.df[feature_name] = pd.Series(values, index=self.df.index)

    def _to_cache(
            self,
            feature_dict: Dict[int, Any],
            feature_name: str
    ) -> None:
        '''
        Pickle `feature_dict` as `cell_<feature_name>.pickle` in the cache folder.
        '''
        save_name = f'cell_{feature_name}.pickle'
        with open(os.path.join(self.cache_dir, save_name), 'wb') as file:
            pickle.dump(feature_dict, file)

    def _from_cache(
            self,
            feature_name: str
    ) -> Dict[int, Any]:
        '''
        Load the dictionary cached for `feature_name` by `_to_cache`.

        Raises:
        -------
            FileNotFoundError
                If nothing has been cached for this feature.
        '''
        #same path construction as in `_to_cache`
        path_to_file = os.path.join(self.cache_dir, f'cell_{feature_name}.pickle')
        if not os.path.isfile(path_to_file):
            #explicit check instead of `assert`, which is stripped under -O
            raise FileNotFoundError(
                f'No cached statistics for feature {feature_name!r}: '
                f'{path_to_file} does not exist.'
            )

        #NOTE: pickle can execute arbitrary code on load; the cache folder must
        #only contain files written by this class.
        with open(path_to_file, 'rb') as file:
            feature_dict = pickle.load(file)

        return feature_dict

    def _get_args(
            self,
            feature_name: str,
    ) -> Tuple[Any, ...]:
        '''
        Return the positional arguments for the function computing `feature_name`.

        Raises:
        -------
            OSError
                For 'contact_area', if the neighbors dictionary cannot be loaded
                from the cache.
        '''
        if feature_name == 'neighbors':
            return (self.labels, self.excluded_idxs)

        if feature_name == 'contact_area':
            #contact areas are computed between neighbors, which must have been
            #computed (and therefore cached) before
            try:
                neighbors_dict = self._from_cache('neighbors')
            except Exception as e:
                raise OSError('Neighbors dictionary not found in the cache.') from e
            return (self.meshes, neighbors_dict, self.num_workers)

        return (self.meshes, self.excluded_idxs)

    def _process_df(
        self,
    ) -> None:
        '''
        Split the composite columns into plain ones.

        Only the columns present in the table are processed, so collecting a
        subset of the features no longer raises a KeyError. Missing entries
        stay missing in the derived columns.
        '''
        derive = StatsCollector._apply_skipping_missing

        #count neighbors
        if 'neighbors' in self.df.columns:
            self.df['num_neighbors'] = derive(self.df['neighbors'], len)

        #split elongation and principal axes
        if 'elongation_and_axes' in self.df.columns:
            packed = self.df['elongation_and_axes']
            self.df['elongation'] = derive(packed, lambda x: x[0])
            self.df['principal_axes'] = derive(packed, lambda x: x[1])
            self.df.drop(columns=['elongation_and_axes'], inplace=True)

        #split and extract statistics from contact area
        if 'contact_area' in self.df.columns:
            packed = self.df['contact_area']
            self.df['contact_area_fraction'] = derive(packed, lambda x: x[0])
            self.df['contact_area_distribution'] = derive(packed, lambda x: x[1])
            distribution = self.df['contact_area_distribution']
            self.df['mean_contact_area'] = derive(distribution, np.mean)
            self.df['total_contact_area'] = derive(distribution, np.sum)
            self.df.drop(columns='contact_area', inplace=True)

    def collect_statistics(
            self,
            load_from_cache: Optional[bool] = False
    ) -> None:
        '''
        Compute (or load) every requested feature and store it in the table.

        Features are handled in the order given at construction, and the table
        is saved after each of them. When computing 'contact_area', 'neighbors'
        must already be in the cache, so list it first.

        Parameters:
        -----------
            load_from_cache: (Optional[bool], default=False)
                If True, every feature is read from the cache instead of being computed.
        '''
        for func, feat in zip(self.functions, self.features):
            if load_from_cache:
                print(f'Loading cached cell {feat} ...')
                feat_dict = self._from_cache(feat)
            else:
                feat_dict = func(*self._get_args(feat))

                #cache the dict
                self._to_cache(feat_dict, feat)

            #add the new stat to the dataframe
            self._add_to_dataframe(feat_dict, feat)

            #save the dataframe
            self._save_dataframe(overwrite=True)

        #postprocess dataframe
        self._process_df()
        self._save_dataframe(overwrite=True)
            

In [None]:
# Create the collector for all five supported features. Creating it already
# filters the labels and writes the initial table to
# 'tests/output_s_10_e_6_d_8/cell_stats/'.
# NOTE: the output folder is hard-coded here instead of reusing `output_dir`.
# NOTE: 'lung_bronchiole' maps to the 'simple_cuboidal' type, for which cells
# touching the image edges are excluded.
stats_collector = StatsCollector(
    meshes=meshes,
    labels=preprocessed_labeled_img,
    # original_ids=list(meshes.keys()),
    features=['area', 'volume', 'elongation_and_axes', 'neighbors', 'contact_area'],
    output_directory='tests/output_s_10_e_6_d_8',
    path_to_img='tests/output_s_10_e_6_d_8/processed_labels.npy',
    tissue='lung_bronchiole',
    num_workers=4
)

In [None]:
# Fill the table from the pickled feature dictionaries instead of recomputing
# them; every requested feature must already be present in the cache folder.
stats_collector.collect_statistics(load_from_cache=True)

In [None]:
# Display the collected statistics table.
stats_collector.df

In [None]:
# Isoperimetric ratio of every cell: area^3 / volume^2.
# Computed on whole columns instead of row by row with `apply(axis=1)`: same
# values, but vectorised and well defined on an empty table too.
stats_collector.df["isoper_ratio"] = (
    stats_collector.df['area'] ** 3 / stats_collector.df['volume'] ** 2
)

In [None]:
# Display the table again, now including the `isoper_ratio` column.
stats_collector.df