# Quality checks, morphological profile assembly and feature selection
Written by: Amanda Ng R.H.
<br>Overall status: Clean
<br>Language: `python3`
<br>Created on: 18 Apr 2023
<br>Last updated on: 06 Sep 2023
<br>Prior data processing: None (using image-level and object-level measurements directly)
<br>Documentation status: Good ([Sphinx documentation style](https://sphinx-rtd-tutorial.readthedocs.io/en/latest/docstrings.html))

## Import packages

In [None]:
import os
import warnings

import pandas as pd

import seaborn as sns
from matplotlib import pylab as plt

import statistics
import numpy as np
import scipy

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import IsolationForest

import imageio
import cv2

## Import self-written functions from `post_feature_extraction_modules.py` and `general_modules.py`

In [None]:
# Add the module directories to the paths that Python searches for modules
import sys
root = "/research/lab_winter/users/ang"
module_paths = [
    f"{root}/isogenicCPA_repo/2_feature_analysis"
]
for module_path in module_paths:
    sys.path.insert(0, module_path)

from post_feature_extraction_modules import *
from general_modules import *

## Setting up

In [None]:
#######
# Paths
#######
# Cell line directories
cell_dir_list = ["c662_rko_wt", "c1141_rko_ko", "c1327_rko_oe"]

# Path to the parent directory for the CPA run
parent_dir = r"/research/lab_winter/users/ang/projects/GW015_SuFEX_IMiDs/GW015_006__full_cpa"

# Path to the output directory
post_feature_extraction_output_dir = f"{parent_dir}/cleanPostFeatureExtraction/output_dir"
makeDirectory(post_feature_extraction_output_dir)

# 1| Flag problems
This section:
- adds the "UpdatedImageNumber" to the individual batch CSV files (Cells, Nuclei, Cytoplasm and Image), which is a unique identifier for each set of images associated with a site across all plates
- merges the Image.csv files across all batches as {cell}__Image.csv
- checks for uneven staining across the plate
- checks the distribution of cell counts and channel intensities
- flags low-quality images (`IsolationForest` can be used for this, but in this case I found that setting cut-offs on the PercentMaximal and PowerLogLogSlope measurements was more useful)
- flags treatments that induce the supernumerary nuclei phenotype
- and checks the segmentation performance.
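
The following is a minimal sketch of how a batch-wide identifier like "UpdatedImageNumber" could be built. It assumes that each batch's `ImageNumber` restarts at 1 and that batches are always processed in the same order. The helper names (`compute_batch_offsets`, `add_updated_image_number`) and the toy tables are hypothetical. The notebook's own logic lives in `modify_featureExtractionCSV` in `post_feature_extraction_modules.py`, which may differ. The key idea is that the same per-batch offset is applied to the Image table and to every object table of that batch, so each object still points at the right image.

```python
import pandas as pd


def compute_batch_offsets(image_dfs):
    """Return one ImageNumber offset per batch so that numbering continues across batches.

    Assumption (hypothetical): ImageNumber restarts at 1 in every batch.
    """
    offsets = []
    running_total = 0
    for image_df in image_dfs:
        offsets.append(running_total)
        running_total += int(image_df["ImageNumber"].max())
    return offsets


def add_updated_image_number(dfs, offsets):
    """Add UpdatedImageNumber = ImageNumber + batch offset to each batch's dataframe."""
    updated = []
    for df, offset in zip(dfs, offsets):
        df = df.copy()
        df["UpdatedImageNumber"] = df["ImageNumber"] + offset
        updated.append(df)
    return updated


# Two toy batches: Image tables and the matching Nuclei object tables
image_batches = [
    pd.DataFrame({"ImageNumber": [1, 2, 3]}),
    pd.DataFrame({"ImageNumber": [1, 2]}),
]
nuclei_batches = [
    pd.DataFrame({"ImageNumber": [1, 1, 3], "ObjectNumber": [1, 2, 1]}),
    pd.DataFrame({"ImageNumber": [1, 2, 2], "ObjectNumber": [1, 1, 2]}),
]

offsets = compute_batch_offsets(image_batches)
images = add_updated_image_number(image_batches, offsets)
nuclei = add_updated_image_number(nuclei_batches, offsets)  # same offsets as the Image tables

# Only the Image tables are merged across batches; object tables stay per batch
merged_images = pd.concat(images, ignore_index=True)
```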

In [None]:
# Path to the output directory specific for section 1
output_dir = f"{post_feature_extraction_output_dir}/1_FlagProblems_output"
makeDirectory(output_dir)

In [None]:
###############################################################
# Merge the image CSV files for all batches
# and add the UpdatedImageNumber for image and object CSV files
# (do NOT merge the object CSV files across batches)
###############################################################
def modify_featureExtractionCSV_byCell(
    post_feature_extraction_output_dir = post_feature_extraction_output_dir,
    cell_dir_list = cell_dir_list,
    parent_dir = parent_dir
):
    """
    Function that does the following for each cell line dataset:
    
    1. adds the "UpdatedImageNumber" column to the individual image and object CSV files from
    the feature extraction pipeline. The "UpdatedImageNumber" is a unique identifier for
    each set of images (one for each channel) associated with a unique site in a unique plate
    for a cell line.
    
    2. merges the Image.csv files from each batch into a single {cell}__Image.csv.
    
    :param post_feature_extraction_output_dir: Path to the parent directory for all outputs
        from the post-feature extraction pipeline.
    :type post_feature_extraction_output_dir: str
    :param cell_dir_list: List of names of the directories where the data for each cell line
        are stored. Each name is also used for naming the merged Image.csv file
        (e.g. "c662_rko_wt").
    :type cell_dir_list: list
    :param parent_dir: Path to the directory containing all the data associated to the
        Cell Painting Assay run. This is used to define the path to the module_7b_output
        directory.
    :type parent_dir: str
    :return: Dictionary mapping {cell} to the path to the {cell}__Image.csv.
    :rtype: dict
    """
    output_dir = f"{post_feature_extraction_output_dir}/Merged_Image_CSV"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    cell2imagecsv_dict = dict()
    for cell in cell_dir_list:
        print(cell)
        # Define the path to the module_7b_output directory
        module_7b_output_dir = f"{parent_dir}/{cell}/output_dir2/module_7b_output"
        cell2imagecsv_dict[cell] = modify_featureExtractionCSV(
            cell = cell,
            module_7b_output_dir = module_7b_output_dir,
            output_dir = output_dir
        )
        print("-"*10)
    return(cell2imagecsv_dict)

cell2imagecsv_dict = modify_featureExtractionCSV_byCell()
print("Path to Image_merged.csv for each cell line has been retrieved.")
print(cell2imagecsv_dict)

In [None]:
###################################################################
# Define the dictionaries for mapping the cell line to:
# 1. the {cell}__Image.csv
# 2. dataframe for logging the problematic images for data analysis
###################################################################
# The first dictionary is redefined here because this part of the analysis was continued in a separate session
merged_image_csv_dir = f"{post_feature_extraction_output_dir}/Merged_Image_CSV"
cell2imagecsv_dict = {cell: f"{merged_image_csv_dir}/{cell}__Image.csv" for cell in cell_dir_list}
print("Dictionary mapping the path to the image.csv for each cell line has been defined.")
print(cell2imagecsv_dict)

# Make dataframes for storing information on flagged images/treatments
cell2flagged_dict = dict()
for cell in cell_dir_list:
    cell2flagged_dict[cell] = pd.DataFrame(
        columns = [
            "Reference_Column",
            "Value",
            "Description"
        ]
    )
print("\nDataframe for storing flagged images has been made for each cell line.")
print(cell2flagged_dict)

## Supervised quality checks to flag problematic images

### Plate level: Check for any staining issues
This check is done first to remove any images that could skew measurements in general. Images flagged at this stage should be removed before proceeding to the other checks, which themselves depend on image quality measurements.
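
Flagged records are logged as (`Reference_Column`, `Value`, `Description`) rows. This lets one dataframe hold both image-level flags (`UpdatedImageNumber`) and treatment-level flags (`Metadata_Treatment`). The sketch below is a hypothetical stand-in for `remove_flaggedImages`, whose implementation is not shown in this notebook. It assumes that a row should be dropped if its value in any referenced column has been flagged. The function name and toy data are illustrative only.

```python
import pandas as pd


def drop_flagged_rows(image_df, flagged_df):
    """Hypothetical stand-in for remove_flaggedImages: drop rows matching any flagged value.

    flagged_df uses the notebook's logging layout: the column to match is in
    "Reference_Column" and the flagged value is in "Value".
    """
    keep = pd.Series(True, index=image_df.index)
    for column, values in flagged_df.groupby("Reference_Column")["Value"]:
        # Exclude rows whose value in this column has been flagged
        keep &= ~image_df[column].isin(values.tolist())
    return image_df[keep]


image_df = pd.DataFrame({
    "UpdatedImageNumber": [1, 2, 3, 4],
    "Metadata_Treatment": ["DMSO", "AMG-900", "DMSO", "Compound_X"],
})
flagged_df = pd.DataFrame({
    "Reference_Column": ["UpdatedImageNumber", "Metadata_Treatment"],
    "Value": [3, "AMG-900"],
    "Description": [
        "ImageQuality_PercentMaximal_OrigDNA more than 0.15",
        "Treatment induces supernumerary nuclei phenotype.",
    ],
})
clean_df = drop_flagged_rows(image_df, flagged_df)
```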

In [None]:
def visualize_plate(
    image_df,
    cell,
    output_dir,
    normalization_mode = None,
    features = [
        "Count_RelatedUnfilteredCells",
        "Intensity_MedianIntensity_AGP",
        "Intensity_MedianIntensity_DNA",
        "Intensity_MedianIntensity_ER",
        "Intensity_MedianIntensity_Mito"
    ],
    cmap = "Blues"
):
    """
    Function for visualizing a few features across a plate.
    
    :param image_df: Dataframe containing general measurements of Image object.
    :type image_df: pd.DataFrame
    :param cell: Name of cell line in use (e.g. "rko_wt").
    :type cell: str
    :param output_dir: Path to the directory for outputs.
    :type output_dir: str
    :param normalization_mode: Option for normalization mode used, defaults to None.
    :type normalization_mode: str, optional. Can be: "z-score", "min-max", None.
    :param features: List of features to visualize, defaults to [
            "Count_RelatedUnfilteredCells",
            "Intensity_MedianIntensity_AGP",
            "Intensity_MedianIntensity_DNA",
            "Intensity_MedianIntensity_ER",
            "Intensity_MedianIntensity_Mito"
        ]
    :type features: list, optional
    :param cmap: Matplotlib color palette, defaults to "Blues".
    :type cmap: str, optional
    :raises ValueError: normalization_mode provided is not accepted.
    :return:    
    """
    
    #############################################################################
    # Aggregate the feature measurements in the image_df to the median of each well
    # (i.e. sites per well --> well)
    #############################################################################
    # Trim down the image_df to just the columns you need
    columns_of_interest = features + ["Metadata_Plate", "Metadata_Well"]
    image_df = image_df[columns_of_interest]
    
    # Aggregate the features by "Metadata_Well"
    agg_image_df = image_df.groupby(["Metadata_Plate", "Metadata_Well"]).median()
    agg_image_df = agg_image_df.reset_index(drop = False)
    
    # Free up RAM
    del image_df
    
    ###########################################################
    # Split the "Metadata_Well" into row and column information
    ###########################################################
    row_list = []
    col_list = []
    for index in agg_image_df["Metadata_Well"].tolist():
        row = index[0]
        col = int(index[1:3])
        row_list.append(row)
        col_list.append(col)

    agg_image_df["Metadata_Row"] = row_list
    agg_image_df["Metadata_Col"] = col_list
    
    ################################################
    # Visualize the feature across a plate at a time
    ################################################
    # Retrieve the list of unique plates
    plates = set(agg_image_df["Metadata_Plate"].tolist())
    plates = sorted(plates)
    
    # Loop through data of each plate
    for plate in plates:
        
        # Trim down the agg_image_df to the plate-specific information
        # (.copy() so that adding normalized columns does not raise a SettingWithCopyWarning)
        plate_df = agg_image_df[agg_image_df["Metadata_Plate"] == plate].copy()
        
        # Plot the heatmaps for each feature using matplotlib subplots
        fig, axes = plt.subplots(1, len(features), figsize = (7*len(features), 5), sharey = False)
        fig.suptitle(
            f"{cell} {plate}: Overview across the plate",
            fontsize = 15
        )
        for i, feature in enumerate(features):
            
            # Normalize the feature values if the feature is not about the number of cells
            if not feature.startswith("Count") and normalization_mode is not None:
                x = plate_df[feature]
                feature = f"{feature} ({normalization_mode})"
                if normalization_mode == "z-score":
                    plate_df[feature] = (x - x.mean())/x.std()
                elif normalization_mode == "min-max":
                    plate_df[feature] = (x - min(x))/(max(x) - min(x))
                else:
                    normalization_mode_options = ["z-score", "min-max", None]
                    raise ValueError(f"{normalization_mode} is not an acceptable normalization_mode option. Did you mean any of the following options instead?\n{normalization_mode_options}")
                    
            # Re-organize the dataframe into the plate layout (rows x columns) for the seaborn heatmap;
            # pivot leaves any missing wells as NaN instead of failing on rows of unequal length
            heatmap_df = plate_df.pivot(index = "Metadata_Row", columns = "Metadata_Col", values = feature)
            
            # Plot the heatmap into the specific subplot
            sns.heatmap(
                ax = axes[i],
                data = heatmap_df,
                annot = False,
                linewidth = 0.5,
                cmap = cmap
            )
            axes[i].set_title(feature)
        
        # Export the plot
        output_path = f"{output_dir}/{cell}__{plate}.png"
        savePlot(output_path)
        
        # Show the plot
        plt.show()
        
    return

def plate_check(
    cell2imagecsv_dict,
    output_dir,
    normalization_mode = None
):
    """
    Function for visualizing features across all plates for all cell lines and
    saves the plate plots.
    
    :param cell2imagecsv_dict: Dictionary mapping cell name to the path to
        the relevant Image_merged.csv.
    :type cell2imagecsv_dict: dict
    :param output_dir: Path to the directory for outputs.
    :type output_dir: str
    :param normalization_mode: Option for normalization mode used, defaults to None.
    :type normalization_mode: str, optional. Can be: "z-score", "min-max", None.
    :return:
    """
    print(f"normalization_mode: {normalization_mode}")
    if normalization_mode is None:
        output_dir = f"{output_dir}/PlateHeatmaps/raw"
    else:
        output_dir = f"{output_dir}/PlateHeatmaps/{normalization_mode}"
    makeDirectory(output_dir)
    
    for cell, imagecsv in cell2imagecsv_dict.items():
        visualize_plate(
            image_df = pd.read_csv(imagecsv),
            cell = cell,
            output_dir = output_dir,
            normalization_mode = normalization_mode
        )
        
    print(f"EXPORTED: Heatmaps of count and channel intensities across the plate exported to\n{output_dir}")
    return

plate_check(cell2imagecsv_dict, output_dir, normalization_mode = None)

### Site level: Check if there are any outlier sites that should be flagged
The sites (or fields) are the sections of the well that have been imaged.
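
The red dashed line in `kdeplot_byCell` is the `upper_percentile` quantile of each feature, computed separately for each cell line. The following is a minimal numpy sketch of the same cutoff on toy values. The helper name `sites_above_percentile` is hypothetical.

```python
import numpy as np


def sites_above_percentile(values, upper_percentile=0.975):
    """Return the quantile cutoff and the indices of sites whose value exceeds it."""
    values = np.asarray(values, dtype=float)
    # np.quantile defaults to linear interpolation, like pandas' Series.quantile
    cutoff = np.quantile(values, upper_percentile)
    return cutoff, np.flatnonzero(values > cutoff)


# Toy per-site median intensities for one cell line: 1, 2, ..., 100
cutoff, outlier_sites = sites_above_percentile(np.arange(1, 101))
```

Because both libraries default to linear interpolation, the cutoff should match the line drawn in the plots for the same values. With 100 sites, roughly the top 2 to 3 sites fall above the 97.5th percentile.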

In [None]:
def kdeplot_byCell(
    cell2imagecsv_dict,
    cell2flagged_dict,
    remove_flagged = True,
    upper_percentile = 0.975,
    sharex = True,
    sharey = True,
    features = [
        "Count_RelatedUnfilteredCells",
        "Intensity_MedianIntensity_AGP",
        "Intensity_MedianIntensity_DNA",
        "Intensity_MedianIntensity_ER",
        "Intensity_MedianIntensity_Mito"
    ]
):
    """
    Function for plotting a KDE plot with a rug plot of each feature, faceted by cell line.
    
    :param cell2imagecsv_dict: Dictionary mapping cell name to the path to
        the relevant Image_merged.csv.
    :type cell2imagecsv_dict: dict
    :param cell2flagged_dict: Dictionary mapping cell name to a dataframe for
        storing flagged values.
    :type cell2flagged_dict: dict
    :param remove_flagged: Option for removing images that were flagged previously,
        defaults to True.
    :type remove_flagged: bool, optional.
    :param upper_percentile: Draws a red line marking the tentative upper limit
        for a feature (aside from Count_RelatedUnfilteredCells) for visualisation purposes,
        defaults to 0.975.
    :type upper_percentile: float, optional. Can be from 0 to 1.
    :param sharex: Option for sharing the x-axis, defaults to True.
    :type sharex: bool, optional.
    :param sharey: Option for sharing the y-axis, defaults to True.
    :type sharey: bool, optional.
    :param features: List of features to visualize, defaults to [
            "Count_RelatedUnfilteredCells",
            "Intensity_MedianIntensity_AGP",
            "Intensity_MedianIntensity_DNA",
            "Intensity_MedianIntensity_ER",
            "Intensity_MedianIntensity_Mito"
        ]
    :type features: list, optional.
    :return: upper_percentile
    :rtype: float
    """    
    df_list = []
    for cell in cell2imagecsv_dict:
        df = pd.read_csv(cell2imagecsv_dict[cell])
        if remove_flagged == True:
            flagged_df = cell2flagged_dict[cell]
            df = remove_flaggedImages(cell, df, flagged_df)
            del flagged_df
        df = df[features].copy()
        df["Cell"] = [cell] * len(df)
        df_list.append(df)
    df = pd.concat(df_list, ignore_index = True)
    
    for feature in features:
        p = sns.displot(
            data = df,
            col = "Cell",
            x = feature,
            kind = "kde",
            fill = True,
            rug = True,
            color = "grey",
            height = 4,
            facet_kws = {
                "sharey": sharey,
                "sharex": sharex
            }
        )
        p.fig.subplots_adjust(top = 0.8)
        p.fig.suptitle(f"Distribution of {feature} by site per cell line")
        
        if feature.split("_")[0] != "Count":
            for i, cell in enumerate(list(cell2imagecsv_dict.keys())):
                temp_df = df[df["Cell"] == cell]
                upper_cutoff = temp_df[feature].quantile(q = upper_percentile)
                p.axes[0][i].axvline(upper_cutoff, ls = "--", color = "red")
        
        plt.show()
        print("\n"*3)
    return(upper_percentile)
upper_percentile = kdeplot_byCell(cell2imagecsv_dict, cell2flagged_dict)

Next, I check the images above the upper percentile to see whether they are indeed outliers that should be excluded from downstream analysis. Each single-channel image also has an accompanying histogram, which allows me to check for any over-saturation or intensity peaking.
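
`retrieve_images` samples a handful of the images above the cutoff for inspection. The following is a minimal standard-library sketch of sampling that is reproducible (fixed seed) and capped at the number of available outliers. The helper name and seed are hypothetical. With pandas, the equivalent would be `outlier_df.sample(n = min(number_of_images, len(outlier_df)), random_state = seed)`.

```python
import random


def sample_outlier_images(updated_image_numbers, number_of_images=5, seed=0):
    """Pick up to number_of_images outlier images reproducibly.

    Capping the sample size avoids the ValueError that sampling without
    replacement raises when more items are requested than exist.
    """
    rng = random.Random(seed)
    n = min(number_of_images, len(updated_image_numbers))
    return sorted(rng.sample(list(updated_image_numbers), n))


# Only three outliers are available, so all three are returned
picked_all = sample_outlier_images([11, 42, 7])
```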

In [None]:
# Retrieve sample images that have measurements above the upper_percentile
def retrieve_images(
    cell2imagecsv_dict,
    cell2flagged_dict,
    upper_percentile,
    cell,
    remove_flagged = True,
    number_of_images = 5,
    features = [
        "Intensity_MedianIntensity_AGP",
        "Intensity_MedianIntensity_DNA",
        "Intensity_MedianIntensity_ER",
        "Intensity_MedianIntensity_Mito"
    ]
):
    """
    Retrieves images above a tentative threshold to determine if a threshold is necessary
    for certain features as part of quality control.
    
    :param cell2imagecsv_dict: Dictionary mapping cell name to the path to the
        relevant Image_merged.csv.
    :type cell2imagecsv_dict: dict
    :param cell2flagged_dict: Dictionary mapping cell name to a dataframe for
        storing flagged values.
    :type cell2flagged_dict: dict
    :param upper_percentile: Percentile (as a fraction) above which images are treated as
        tentative outliers for a feature and sampled for inspection.
    :type upper_percentile: float. Can be from 0 to 1.
    :param cell: Name of cell line in use (e.g. "rko_wt").
    :type cell: str
    :param remove_flagged: If True, images that have been flagged previously are excluded
        from the sampling. Defaults to True.
    :type remove_flagged: bool, optional.
    :param number_of_images: Number of images to retrieve per feature, defaults to 5.
    :type number_of_images: int, optional.
    :param features: List of features to visualize, defaults to [
            "Intensity_MedianIntensity_AGP",
            "Intensity_MedianIntensity_DNA",
            "Intensity_MedianIntensity_ER",
            "Intensity_MedianIntensity_Mito"
        ]
    :type features: list, optional.
    :return:
    """
    # Load the Image_merged.csv
    image_df = pd.read_csv(cell2imagecsv_dict[cell])
    
    # If remove_flagged == True,
    # discard the flagged images prior to sampling for example images
    if remove_flagged == True:
        flagged_df = cell2flagged_dict[cell]
        image_df = remove_flaggedImages(cell, image_df, flagged_df)
        del flagged_df
    
    for feature in features:
        channel = feature.split("_")[2]
        upper_cutoff = image_df[feature].quantile(q = upper_percentile)
        
        # Trim down the image_df to just what images would be outliers
        # based on the median channel intensity
        outlier_df = image_df[image_df[feature] > upper_cutoff]
        
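        # Retrieve a list of UpdatedImageNumber corresponding to these supposed outlier images
        # (capped at the number of outliers, since DataFrame.sample raises an error if n exceeds the available rows)
        updatedImageNumbers = outlier_df.sample(n = min(number_of_images, len(outlier_df)))["UpdatedImageNumber"].tolist()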
        for updatedImageNumber in updatedImageNumbers:
            ImageManipulation.find_ImageByNumber(
                cell,
                image_df,
                updatedImageNumber,
                channel = channel
            )
    return

In [None]:
retrieve_images(
    cell2imagecsv_dict,
    cell2flagged_dict,
    upper_percentile,
    cell = "c662_rko_wt"
)

In [None]:
retrieve_images(
    cell2imagecsv_dict,
    cell2flagged_dict,
    upper_percentile,
    cell = "c1141_rko_ko"
)

In [None]:
retrieve_images(
    cell2imagecsv_dict,
    cell2flagged_dict,
    upper_percentile,
    cell = "c1327_rko_oe"
)

## Flag images that are blurry or contain saturation artefacts using PLLS and PM only
Power log-log slope (PLLS) and percent maximal (PM) are parameters that were found to be effective in detecting images that are blurry or contain saturation artefacts (e.g. fibres) (see [Bray _et al._](https://www.sciencedirect.com/science/article/pii/S2472555222075943?via%3Dihub)).
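
To make the two metrics concrete:
- PM is the percentage of pixels at the image's maximum intensity, so saturated fibres or debris raise it.
- PLLS is the slope of the image's power spectrum on a log-log scale. It becomes more negative as blur removes high-frequency detail.

The numpy sketch below illustrates these definitions in simplified form. It is not CellProfiler's `MeasureImageQuality` implementation, whose details (e.g. frequency range, scaling) may differ. The helper `box_blur` and the toy images are hypothetical.

```python
import numpy as np


def percent_maximal(image):
    """Percentage of pixels at the image's maximum intensity."""
    image = np.asarray(image)
    return 100.0 * np.mean(image == image.max())


def power_log_log_slope(image):
    """Slope of a line fitted to log(radially averaged power) vs log(spatial frequency)."""
    image = np.asarray(image, dtype=float)
    image = image - image.mean()  # remove the DC component
    power = np.abs(np.fft.fftshift(np.fft.fft2(image))) ** 2
    h, w = image.shape
    yy, xx = np.indices((h, w))
    radius = np.hypot(yy - h // 2, xx - w // 2).astype(int)
    # Average the power over rings of equal integer radius (i.e. equal spatial frequency)
    ring_sum = np.bincount(radius.ravel(), weights=power.ravel())
    ring_count = np.bincount(radius.ravel())
    ring_power = ring_sum / np.maximum(ring_count, 1)
    freqs = np.arange(1, min(h, w) // 2)  # skip DC and stay inside the Nyquist circle
    ring_power = ring_power[freqs]
    valid = ring_power > 0
    slope, _intercept = np.polyfit(np.log(freqs[valid]), np.log(ring_power[valid]), 1)
    return slope


def box_blur(image, size=5):
    """Blur by averaging circularly shifted copies (a simple size x size box filter)."""
    half = size // 2
    shifts = range(-half, half + 1)
    return sum(np.roll(image, (dy, dx), axis=(0, 1)) for dy in shifts for dx in shifts) / size**2


rng = np.random.default_rng(0)
sharp = rng.normal(size=(128, 128))  # white noise: roughly flat power spectrum
blurred = box_blur(sharp)            # same image with high frequencies suppressed
saturated = np.zeros((10, 10))
saturated[0, :] = 1.0                # 10 of 100 pixels at the maximum

slope_sharp = power_log_log_slope(sharp)
slope_blurred = power_log_log_slope(blurred)
```

On these toy images, the blurred copy gives a clearly more negative slope than the unblurred noise. This is the direction that the "less than" PLLS cut-offs used in this section rely on.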

In [None]:
#########################################################
# Plot a scatterplot for PM against PLLS for all channels
#########################################################
def scatterplot_byCell(
    cell2imagecsv_dict = cell2imagecsv_dict,
    cell2flagged_dict = cell2flagged_dict,
    remove_flagged = True
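):
    """
    Function for plotting PM against PLLS by site for each fluorescent channel,
    with one figure per cell line.
    
    :param cell2imagecsv_dict: Dictionary mapping cell name to the path to
        the relevant Image_merged.csv.
    :type cell2imagecsv_dict: dict
    :param cell2flagged_dict: Dictionary mapping cell name to a dataframe for
        storing flagged values.
    :type cell2flagged_dict: dict
    :param remove_flagged: Option for removing images that were flagged previously,
        defaults to True.
    :type remove_flagged: bool, optional.
    :return:
    """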
    # Retrieve the image.csv information and exclude any flagged images
    for cell, imagecsv in cell2imagecsv_dict.items():
        image_df = pd.read_csv(imagecsv)
        if remove_flagged == True:
            flagged_df = cell2flagged_dict[cell]
            image_df = remove_flaggedImages(cell, image_df, flagged_df)
            del flagged_df
        
        # Prepare the data required for the scatterplot
        df_list = []
        for channel in ["AGP", "DNA", "ER", "Mito"]:
            plls = f"ImageQuality_PowerLogLogSlope_Orig{channel}"
            pm = f"ImageQuality_PercentMaximal_Orig{channel}"
            df = image_df[[plls, pm]]
            df = df.rename(columns = {plls: "PLLS", pm: "PM"})
            df["Channel"] = [channel] * len(df)
            df_list.append(df)
        df = pd.concat(df_list, ignore_index = True)

        # Plot the scatter plots
        p = sns.relplot(
            data = df,
            x = "PLLS",
            y = "PM",
            col = "Channel",
            kind = "scatter",
            height = 4
        )
        p.fig.subplots_adjust(top = 0.8)
        p.fig.suptitle(f"{cell}: PM against PLLS by site for each fluorescent channel")
        
        plt.show()
        print("\n"*3)
    return
scatterplot_byCell()

In [None]:
###################################################################################
# Retrieve the UpdatedImageNumbers corresponding to images outside a certain cutoff
# for the user-defined feature
###################################################################################
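def retrieve_updatedImageNumbers(
    image_df,
    feature,
    cutoff_parameters
):
    """
    Retrieves the UpdatedImageNumbers of images beyond a cutoff for a feature.
    
    :param image_df: Dataframe containing general measurements of Image object.
    :type image_df: pd.DataFrame
    :param feature: Name of the column to apply the cutoff to.
    :type feature: str
    :param cutoff_parameters: [operator, limit], where operator is "less than" or "more than".
    :type cutoff_parameters: list
    :raises ValueError: operator provided is not accepted.
    :return: List of UpdatedImageNumbers of the outlier images.
    :rtype: list
    """
    operator, limit = cutoff_parameters
    if operator == "less than":
        outlier_df = image_df[image_df[feature] < limit]
    elif operator == "more than":
        outlier_df = image_df[image_df[feature] > limit]
    else:
        raise ValueError(f'{operator} is not an acceptable operator. Use "less than" or "more than".')
    return(outlier_df["UpdatedImageNumber"].tolist())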

########################################################################
# Retrieve the UpdatedImageNumbers for a set of cutoff parameters
# and display the corresponding images
# to check if blurry images or images with saturation artefacts show up
# and then add the UpdatedImageNumbers to the flagged_df
# to discard these images before data aggregation/profile assembly
########################################################################
def retrieve_images(
    cell2imagecsv_dict = cell2imagecsv_dict,
    cell2flagged_dict = cell2flagged_dict,
    edit_flagged = True
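):
    """
    Flags images beyond the PLLS and PM cut-offs for each channel, displays them for inspection
    and, if edit_flagged is True, logs their UpdatedImageNumbers in the flagged dataframes.
    Note that this redefines the earlier retrieve_images function.
    
    :param cell2imagecsv_dict: Dictionary mapping cell name to the path to
        the relevant Image_merged.csv.
    :type cell2imagecsv_dict: dict
    :param cell2flagged_dict: Dictionary mapping cell name to a dataframe for
        storing flagged values.
    :type cell2flagged_dict: dict
    :param edit_flagged: Option for adding the outlier images to cell2flagged_dict,
        defaults to True.
    :type edit_flagged: bool, optional.
    :return: cell2flagged_dict, updated with the outlier images if edit_flagged is True.
    :rtype: dict
    """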
    # Templates
    plls_template = "ImageQuality_PowerLogLogSlope_Orig%s"
    pm_template = "ImageQuality_PercentMaximal_Orig%s"
    
    # Cutoffs
    feature2cutoff_dict = dict()
    feature2cutoff_dict[plls_template % "AGP"] = ["less than", -1.9]
    feature2cutoff_dict[plls_template % "DNA"] = ["less than", -2.2]
    feature2cutoff_dict[plls_template % "ER"] = ["less than", -1.8]
    feature2cutoff_dict[plls_template % "Mito"] = ["less than", -2.0]
    for channel in ["AGP", "DNA", "ER", "Mito"]:
        feature2cutoff_dict[pm_template % channel] = ["more than", 0.15]
    
    # Retrieve the outlier images with the feature2cutoff
    for feature, cutoff_parameters in feature2cutoff_dict.items():
        description = f"{feature} {cutoff_parameters[0]} {cutoff_parameters[1]}"
        print(description)
        for cell, imagecsv in cell2imagecsv_dict.items():
            image_df = pd.read_csv(imagecsv)
            updatedImageNumbers = retrieve_updatedImageNumbers(image_df, feature, cutoff_parameters)
            
            # Update the flagged_df if edit_flagged = True
            if edit_flagged == True:
                flagged_df = pd.DataFrame()
                flagged_df["Reference_Column"] = ["UpdatedImageNumber"] * len(updatedImageNumbers)
                flagged_df["Value"] = updatedImageNumbers
                flagged_df["Description"] = [description] * len(updatedImageNumbers)
                cell2flagged_dict[cell] = pd.concat([cell2flagged_dict[cell], flagged_df], ignore_index = True).drop_duplicates()
            
            # Retrieve the outlier images
            for updatedImageNumber in updatedImageNumbers:
                ImageManipulation.find_ImageByNumber(
                    cell,
                    image_df,
                    updatedImageNumber,
                    channel = "all"
                )
        print("-"*20)
        print("\n"*3)
    return(cell2flagged_dict)

cell2flagged_dict = retrieve_images()

In [None]:
# Check if the flagged_df was updated accordingly
def check_flagged(cell2flagged_dict = cell2flagged_dict):
    for cell, flagged_df in cell2flagged_dict.items():
        print(f"{cell}: {len(flagged_df)} flagged values")
        display(flagged_df.tail())
        print("\n"*3)
    return
check_flagged()

## Flag treatments that induce the supernumerary nuclei phenotype
In the downstream analysis, I will prepare morphological profiles that assume each cell has only one nucleus. As such, the downstream analysis cannot handle cells with supernumerary nuclei appropriately. Treatments that cause this phenotype need to be flagged and analyzed using a different approach.
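
The flagging rule in `flag_treatment` uses AMG-900 as a positive control. The limit is the median per-site `Ratio_NucleiToCells` of AMG-900 minus `number_of_stdev_from_control` population standard deviations of those ratios. Any treatment whose median ratio exceeds this limit is flagged. The following is a pure-Python sketch of that rule on toy ratios. The helper name, the treatment names other than AMG-900 and all values are hypothetical.

```python
import statistics


def flag_supernumerary_treatments(treatment2ratios, control="AMG-900", number_of_stdev_from_control=1.0):
    """Flag treatments whose median per-site nuclei-to-cell ratio exceeds a control-derived limit.

    upper_limit = median(control ratios) - k * population stdev(control ratios)
    """
    control_ratios = treatment2ratios[control]
    upper_limit = (
        statistics.median(control_ratios)
        - number_of_stdev_from_control * statistics.pstdev(control_ratios)
    )
    flagged = sorted(
        treatment
        for treatment, ratios in treatment2ratios.items()
        if statistics.median(ratios) > upper_limit
    )
    return upper_limit, flagged


# Toy per-site Ratio_NucleiToCells values for each treatment
treatment2ratios = {
    "DMSO": [1.0, 1.02, 0.98],
    "AMG-900": [1.8, 2.0, 2.2],
    "Compound_X": [1.85, 1.9, 1.95],
}
upper_limit, flagged = flag_supernumerary_treatments(treatment2ratios)
```

Here the limit is 2.0 minus one population standard deviation (about 0.163), so the control itself and `Compound_X` (median 1.9) are flagged, while `DMSO` is not.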

In [None]:
def flag_treatment(
    cell2imagecsv_dict,
    cell2flagged_dict,
    remove_flagged = True,
    number_of_stdev_from_control = 1.0
):
    """
    Flags treatments that induce the supernumerary nuclei phenotype,
    using AMG-900 (a pan-Aurora kinase inhibitor that induced the supernumerary nuclei phenotype)
    as a reference.
    
    :param cell2imagecsv_dict: Dictionary mapping cell name to the path to
        the relevant Image_merged.csv.
    :type cell2imagecsv_dict: dict
    :param cell2flagged_dict: Dictionary mapping cell name to a dataframe for
        storing flagged values.
    :type cell2flagged_dict: dict
    :param remove_flagged: Option for removing images that were flagged previously,
        defaults to True.
    :type remove_flagged: bool, optional.
    :param number_of_stdev_from_control: Number of (population) standard deviations below the
        control's median ratio of nuclei to cells that sets the upper limit. Treatments whose
        median ratio exceeds this limit are flagged as inducers of the supernumerary nuclei phenotype.
        The control being used is AMG-900, which is a pan-Aurora kinase inhibitor. Defaults to 1.0.
        I chose the default based on whether other treatments that I know do not cause the supernumerary
        nuclei phenotype end up being flagged or not (which can also be verified by looking at the images).
    :type number_of_stdev_from_control: float, optional.
    :return: cell2flagged_dict updated with the treatments that induce the
        supernumerary nuclei phenotype.
    :rtype: dict
    """
    for cell, imagecsv in cell2imagecsv_dict.items():
        print(cell)
        image_df = pd.read_csv(imagecsv)
        
        # Discard images that have been flagged before if the user has set remove_flagged = True
        if remove_flagged == True:
            flagged_df = cell2flagged_dict[cell]
            image_df = remove_flaggedImages(cell, image_df, flagged_df)
            del flagged_df
        
        # Calculate the ratio of matched nuclei to matched cells
        image_df["Ratio_NucleiToCells"] = image_df["Count_Nuclei"] / image_df["Count_RelatedUnfilteredCells"]
        
        # Trim down the image_df to the relevant columns and drop images with an undefined ratio
        # (assigning the result avoids modifying a slice in place)
        image_df = image_df[["Metadata_Treatment", "Ratio_NucleiToCells"]]
        image_df = image_df.replace([np.inf, -np.inf], np.nan).dropna()
        
        # Aggregate the ratios calculated for each image on a per treatment basis
        agg_image_df = image_df.groupby("Metadata_Treatment").median()
        
        # Re-organize the agg_image_df as a dictionary
        # to map the treatment to the median ratio
        treatment2medianRatio_dict = dict()
        for treatment in agg_image_df.index:
            treatment2medianRatio_dict[treatment] = agg_image_df.loc[treatment, "Ratio_NucleiToCells"]
        del agg_image_df
        
        # Retrieve the upper limit for the Ratio_NucleiToCells
        amg900_df = image_df[image_df["Metadata_Treatment"] == "AMG-900"]
        amg900_stdev = statistics.pstdev(amg900_df["Ratio_NucleiToCells"].tolist())
        upper_limit = treatment2medianRatio_dict["AMG-900"] - number_of_stdev_from_control * amg900_stdev
        print(f"Upper limit in use: {upper_limit}")
        
        # Flag treatments which have a median Ratio_NucleiToCells higher than the upper limit
        flagged_treatments = []
        for treatment, medianRatio in treatment2medianRatio_dict.items():
            if medianRatio > upper_limit:
                flagged_treatments.append(treatment)

        # Update the flagged_df
        flagged_df = pd.DataFrame()
        flagged_df["Reference_Column"] = ["Metadata_Treatment"] * len(flagged_treatments)
        flagged_df["Value"] = flagged_treatments
        flagged_df["Description"] = ["Treatment induces supernumerary nuclei phenotype."] * len(flagged_treatments)
        cell2flagged_dict[cell] = pd.concat([cell2flagged_dict[cell], flagged_df], ignore_index = True).drop_duplicates()
        
        # Check on the flagged_df
        display(cell2flagged_dict[cell].tail())
        
    return(cell2flagged_dict)

cell2flagged_dict = flag_treatment(cell2imagecsv_dict, cell2flagged_dict)

## Check segmentation performance
The whole cells and nuclei were segmented separately, so there will be some discrepancy between the number of nuclei and whole cells segmented. I need to check how well the segmentation performs.

In [None]:
def segmentation_kdeplot_byCell(
    cell2imagecsv_dict,
    cell2flagged_dict,
    remove_flagged = True,
    upper_percentile = 0.975,
    sharex = True,
    sharey = True
):
    """
    Function for plotting a KDE plot with a rug plot for the discrepancy and ratio between
    the number of nuclei and whole cells segmented.
    
    :param cell2imagecsv_dict: Dictionary mapping cell name to the path to
        the relevant Image_merged.csv.
    :type cell2imagecsv_dict: dict
    :param cell2flagged_dict: Dictionary mapping cell name to a dataframe for
        storing flagged values.
    :type cell2flagged_dict: dict
    :param remove_flagged: Option for removing images that were flagged previously,
        defaults to True.
    :type remove_flagged: bool, optional.
    :param upper_percentile: Draws a red line marking the tentative upper limit
        for a feature for visualisation purposes, defaults to 0.975.
    :type upper_percentile: float, optional. Can be from 0 to 1.
    :param sharex: Option for sharing the x-axis, defaults to True.
    :type sharex: bool, optional.
    :param sharey: Option for sharing the y-axis, defaults to True.
    :type sharey: bool, optional.
    :return: upper_percentile
    :rtype: float
    """
    image_df_list = []
    for cell, imagecsv in cell2imagecsv_dict.items():
        image_df = pd.read_csv(imagecsv)
        if remove_flagged:
            flagged_df = cell2flagged_dict[cell]
            image_df = remove_flaggedImages(cell, image_df, flagged_df)
            del flagged_df
        # Keep sites with at least one nucleus (.copy() avoids a SettingWithCopyWarning)
        image_df = image_df[image_df["Count_Nuclei"] > 0].copy()
        image_df["Count_Discrepancy"] = image_df["Count_Nuclei"] - image_df["Count_RelatedUnfilteredCells"]
        image_df["Ratio_NucleiToCells"] = image_df["Count_Nuclei"] / image_df["Count_RelatedUnfilteredCells"]
        image_df["Cell"] = cell
        image_df = image_df[["Cell", "Count_Discrepancy", "Ratio_NucleiToCells"]]
        # Sites with nuclei but no related cells give an infinite ratio; drop them
        image_df = image_df.replace([np.inf, -np.inf], np.nan).dropna()
        image_df_list.append(image_df)
    image_df = pd.concat(image_df_list, ignore_index = True)
    
    for feature in ["Count_Discrepancy", "Ratio_NucleiToCells"]:
        p = sns.displot(
            data = image_df,
            col = "Cell",
            x = feature,
            kind = "kde",
            fill = True,
            rug = True,
            color = "grey",
            height = 4,
            facet_kws = {
                "sharey": sharey,
                "sharex": sharex
            }
        )
        p.fig.subplots_adjust(top = 0.8)
        p.fig.suptitle(f"Distribution of {feature} by site per cell line")
        
        for i, cell in enumerate(list(cell2imagecsv_dict.keys())):
            temp_df = image_df[image_df["Cell"] == cell]
            upper_cutoff = temp_df[feature].quantile(q = upper_percentile)
            p.axes[0][i].axvline(upper_cutoff, ls = "--", color = "red")
        
        plt.show()
    
    return(upper_percentile)

upper_percentile = segmentation_kdeplot_byCell(
    cell2imagecsv_dict,
    cell2flagged_dict,
    upper_percentile = 0.99,
    sharex = False
)

Note:
- `Count_Nuclei`: Number of nuclei segmented by Cellpose.
- `Count_RelatedUnfilteredCells`: Number of cells segmented by Cellpose with matching nuclei.
- `Count_Discrepancy` = `Count_Nuclei` - `Count_RelatedUnfilteredCells`
- `Ratio_NucleiToCells` = `Count_Nuclei` ÷ `Count_RelatedUnfilteredCells`
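The two metrics above can be reproduced on a toy table to see how sites without matched cells are handled. This is a minimal sketch that assumes only the column names listed in the note; the counts are made-up illustration values, not data from this run. As in `segmentation_kdeplot_byCell(...)`, sites without nuclei are removed first, and a site with nuclei but zero related cells gives an infinite ratio, which is converted to `NaN` and dropped.

```python
import numpy as np
import pandas as pd

# Toy per-site counts (illustrative values only)
image_df = pd.DataFrame({
    "Count_Nuclei": [120, 95, 40, 0, 30],
    "Count_RelatedUnfilteredCells": [118, 90, 20, 0, 0],
})

# Keep sites with at least one nucleus, then derive both metrics
image_df = image_df[image_df["Count_Nuclei"] > 0].copy()
image_df["Count_Discrepancy"] = image_df["Count_Nuclei"] - image_df["Count_RelatedUnfilteredCells"]
image_df["Ratio_NucleiToCells"] = image_df["Count_Nuclei"] / image_df["Count_RelatedUnfilteredCells"]

# A site with nuclei but no related cells yields inf; treat it as missing and drop it
image_df = image_df.replace([np.inf, -np.inf], np.nan).dropna()

# Quantile-based cutoff on the ratio, analogous to the red dashed line in the KDE plots
ratio_cutoff = image_df["Ratio_NucleiToCells"].quantile(q=0.99)
print(image_df)
print(ratio_cutoff)
```

The site with 40 nuclei but only 20 related cells stands out with a ratio of 2.0, which is the kind of image `check_segmentation(...)` is meant to surface.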

In [None]:
def check_segmentation(
    cell2imagecsv_dict,
    cell2flagged_dict,
    upper_percentile,
    feature_extraction_output_dir,
    remove_flagged = True,
    cell = "c662_rko_wt",
    max_number_of_images = 3,
    batches = 10,
    figsize = (7, 7)
):
    """
    Function for visualizing segmented masks on top of the dual channel images used
    for segmentation with the highest positive or negative difference in number of
    nuclei segmented compared to whole cells.
    
    :param cell2imagecsv_dict: Dictionary mapping cell name to the path to
        the relevant Image_merged.csv.
    :type cell2imagecsv_dict: dict
    :param cell2flagged_dict: Dictionary mapping cell name to a dataframe for
        storing flagged values.
    :type cell2flagged_dict: dict
    :param parent_dir: Path to the directory containing all the data pertaining
        to the Cell Painting Assay run.
    :type parent_dir: str
    :param upper_percentile: Tentative upper limit on a feature (ratio of nuclei to cells in this case).
    :type upper_percentile: float. Can be from 0 to 1.
    :param remove_flagged: Option for removing images that were flagged previously,
        defaults to True.
    :type remove_flagged: bool
    :param cell: Name of cell line in use (e.g. "c662_rko_wt"),
        defaults to "c662_rko_wt".
    :type cell: str, optional.
    :param max_number_of_images: Maximum number of images to retrieve, defaults to 3.
    :type max_number_of_images: int, optional.
    :param batches: Maximum number of batches used during the feature extraction,
        defaults to 10.
    :type batches: int, optional.
    :param figsize: 2-element tuple for setting the size of the images shown,
        defaults to (7, 7).
    :type figsize: tuple
    :return:
    """
    # Retrieve the Image_merged.csv
    image_df = pd.read_csv(cell2imagecsv_dict[cell])
    
    # Discard images that were flagged previously
    if remove_flagged == True:
        flagged_df = cell2flagged_dict[cell]
        image_df = remove_flaggedImages(cell, image_df, flagged_df)
        del flagged_df
    
    # Calculate the ratio of the numbers of
    # nuclei detected and whole cells detected
    image_df["Ratio_NucleiToCells"] = image_df["Count_Nuclei"] / image_df["Count_RelatedUnfilteredCells"]
    image_df.replace([np.inf, -np.inf], np.nan, inplace = True)
    image_df = image_df.dropna()
    
    print("Case: Nuclei >> Cells detected")
        
    # Trim down the image_df to the images with a ratio higher than the upper_percentile
    temp_df = image_df[image_df["Ratio_NucleiToCells"] > upper_percentile]
    temp_df = image_df.sort_values(by = "Ratio_NucleiToCells", ascending = False).reset_index(drop = True)

    # Visualize the segmentation
    if len(temp_df) < max_number_of_images:
        max_number_of_images = len(temp_df)
    for i, updatedImageNumber in enumerate(temp_df["UpdatedImageNumber"].tolist()):
        if i < max_number_of_images + 1:
            ratio = temp_df.loc[i, "Ratio_NucleiToCells"]
            print(f"Image {i + 1} of {max_number_of_images}: {ratio}")
            visualize_segmentation(
                updatedImageNumber,
                temp_df,
                feature_extraction_output_dir,
                cell,
                batches = batches,
                figsize = figsize
            )
    return

In [None]:
cell = "c662_rko_wt"
check_segmentation(
    cell2imagecsv_dict,
    cell2flagged_dict,
    upper_percentile,
    batches = 20,
    feature_extraction_output_dir = f"{parent_dir}/{cell}/output_dir2",
    cell = cell,
    max_number_of_images = 5
)

In [None]:
cell = "c1141_rko_ko"
check_segmentation(
    cell2imagecsv_dict,
    cell2flagged_dict,
    upper_percentile,
    batches = 20,
    feature_extraction_output_dir = f"{parent_dir}/{cell}/output_dir2",
    cell = cell,
    max_number_of_images = 5
)

In [None]:
cell = "c1327_rko_oe"
check_segmentation(
    cell2imagecsv_dict,
    cell2flagged_dict,
    upper_percentile,
    batches = 20,
    feature_extraction_output_dir = f"{parent_dir}/{cell}/output_dir2",
    cell = cell,
    max_number_of_images = 5
)

**Comments**



## Quality check using known CRBN-dependent treatments that induce cytotoxicity
In this experiment, some treatments induce cytotoxicity in a CRBN-dependent manner. In theory, RKO WT and RKO CRBN OE should have low cell counts under these treatments (compared with DMSO-treated conditions), while RKO CRBN KO should have cell counts similar to DMSO-treated conditions. I can therefore use the cell counts for these treatments as a check on the quality of the data obtained from this experiment.

The treatments that are cytotoxic in RKO WT and RKO CRBN OE at the concentrations used (see AN-B-183 for IC50 values quantified via CellTiter-Glo assay) are:
- CC-885 (also used in GW015_003 to verify that the thawed cell lines are the correct ones)
- CC-90009
- dBET1
- dBET6
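The check above relies on expressing each site's cell count relative to the DMSO sites of the same cell line, as the robust Z-score (count - median of DMSO) / MAD of DMSO. The sketch below reproduces that calculation on made-up counts; the numbers are illustrative, not results from this experiment. `robust_z_to_dmso(...)` is a hypothetical helper, and its guard against a zero MAD is an addition that `quality_check(...)` does not have. A CRBN-dependent cytotoxic treatment should give strongly negative scores in WT and OE and scores near zero in KO.

```python
import pandas as pd
from scipy.stats import median_abs_deviation


def robust_z_to_dmso(df, count_col="Count_RelatedUnfilteredCells", control="DMSO"):
    """Hypothetical helper: robust Z-score of counts relative to the control sites in df."""
    control_counts = df.loc[df["Metadata_Treatment"] == control, count_col]
    median = control_counts.median()
    # Unscaled MAD (scale=1.0 is the SciPy default), as in quality_check(...)
    mad = median_abs_deviation(control_counts)
    if mad == 0:
        raise ValueError("MAD of the control counts is zero; Z-scores are undefined")
    return (df[count_col] - median) / mad


# Toy counts for one cell line (illustrative values only)
toy = pd.DataFrame({
    "Metadata_Treatment": ["DMSO"] * 5 + ["CC-885"] * 3,
    "Count_RelatedUnfilteredCells": [100, 110, 90, 105, 95, 20, 25, 15],
})
toy["NormalizedCount_RelatedUnfilteredCells"] = robust_z_to_dmso(toy)
print(toy)
```

Here the DMSO median is 100 and the MAD is 5, so the CC-885 sites land at -15 to -17, far below the DMSO sites, which scatter between -2 and 2.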

In [None]:
def quality_check(
    cell2imagecsv_dict,
    cell2flagged_dict,
    treatments = (
        "CC-885",
        "CC-90009",
        "dBET1",
        "dBET6",
        "DMSO" # <-- for normalization purposes
    )
):
    """
    Function for plotting the cell counts per site for the known CRBN-dependent
    cytotoxic treatments, normalized to the DMSO control of each cell line.
    
    :param cell2imagecsv_dict: Dictionary mapping cell name to the path to
        the relevant Image_merged.csv.
    :type cell2imagecsv_dict: dict
    :param cell2flagged_dict: Dictionary mapping cell name to a dataframe for
        storing flagged values.
    :type cell2flagged_dict: dict
    :param treatments: Treatments to plot; must include "DMSO", which is used for
        normalization, defaults to CC-885, CC-90009, dBET1, dBET6 and DMSO.
    :type treatments: tuple, optional.
    :return:
    """
    # Initialize a list for storing the cell count details across cell lines
    df_list = []
    
    for cell in cell2imagecsv_dict:
        
        # Retrieve the image_df
        image_df = pd.read_csv(cell2imagecsv_dict[cell])
        
        # Discard the images that have been flagged
        flagged_df = cell2flagged_dict[cell]
        image_df = remove_flaggedImages(cell, image_df, flagged_df)
        del flagged_df
        
        # Trim down the image_df to the treatments and the relevant columns
        # (.copy() avoids a SettingWithCopyWarning when adding columns below)
        image_df = image_df.loc[
            image_df["Metadata_Treatment"].isin(treatments),
            ["Metadata_Treatment", "Count_RelatedUnfilteredCells"]
        ].copy()
        
        # Normalize the cell count to the DMSO control (robust Z score calculation)
        dmso_values = image_df.loc[image_df["Metadata_Treatment"] == "DMSO", "Count_RelatedUnfilteredCells"].tolist()
        dmso_median = statistics.median(dmso_values)
        dmso_mad = scipy.stats.median_abs_deviation(dmso_values)
        image_df["NormalizedCount_RelatedUnfilteredCells"] = (image_df["Count_RelatedUnfilteredCells"] - dmso_median) / dmso_mad
        
        # Append the normalized cell counts for each site and treatment to the list
        image_df["Metadata_Cell"] = cell
        df_list.append(image_df)
        
    # Concatenate the normalized cell counts for each cell line together
    df = pd.concat(df_list, ignore_index = True)
    
    # One strip plot per treatment, comparing the cell lines
    for treatment in treatments:
        treatment_df = df[df["Metadata_Treatment"] == treatment]
        sns.stripplot(
            data = treatment_df,
            x = "Metadata_Cell",
            y = "NormalizedCount_RelatedUnfilteredCells"
        )
        plt.title(treatment)
        plt.show()
    
    return

quality_check(cell2imagecsv_dict, cell2flagged_dict)

### Export flagged problems

In [None]:
def export_flagged(
    cell2flagged_dict,
    output_dir
):
    """
    Function for exporting the flagged images (recorded in the dataframes) for all
    cell lines as {cell}__flagged.csv files.
    
    :param cell2flagged_dict: Dictionary mapping cell name to a dataframe for
        storing flagged values.
    :type cell2flagged_dict: dict
    :param output_dir: Path to the directory for outputs.
    :type output_dir: str
    :return:
    """
    # Create the output directory if it does not exist yet
    os.makedirs(output_dir, exist_ok = True)
    for cell, flagged_df in cell2flagged_dict.items():
        output_path = f"{output_dir}/{cell}__flagged.csv"
        flagged_df.to_csv(output_path, index = False)
    print(f"EXPORTED: Flagged problems exported to\n{output_dir}")
    return

# output_dir must already be defined earlier in this notebook
export_flagged(cell2flagged_dict, output_dir)

# 2| Profile assembly and feature selection
This section:
- assembles the morphological profiles for each cell line,
- concatenates them together into one profile (for all cell lines),
- and selects features in a global and treatment-centric manner.

In [None]:
def profileAssembly(
    cell_dir_list,
    parent_dir,
    post_feature_extraction_output_dir,
    show_max_memory_usage = True,
    verbose = False
):
    """
    Function that assembles the profile for each cell line and merges the profiles of each cell line into
    a single cell line with the "Metadata_Cell" column.
    
    :param cell_dir_list: List of cell line names. It is used for exporting/retrieving the
        batch object CSV files and the merged image CSV file. (e.g. ["rko_wt", "rko_ko", "rko_oe"])
    :type cell_dir_list: list
    :param parent_dir: Path to the directory containing all the data associated to the
        Cell Painting Assay run. This is used to define the path to the module_7b_output
        directory.
    :type parent_dir: str
    :param post_feature_extraction_output_dir: Path to the parent directory for all outputs
        from the post-feature extraction pipeline.
    :type post_feature_extraction_output_dir: str
    :param show_max_memory_usage: Option for printing the details on the object CSV files,
        defaults to True. This parameter is quite useful to check on RAM usage during profile assembly.
        The maximum RAM usage for the assembled profile for ALL cell lines can be estimated
        using the details by:
            memory usage * number of batches per cell line * number of plates in total / 200
        The denominator could be 300 as well, but 200 is used for a high estimate on RAM consumption.
    :type show_max_memory_usage: bool, optional.
    :param verbose: Option for printing out all output messages or not, defaults to True.
    :type verbose: bool, optional.
    :return concatProfile_path: Path to the merged profile for all cell lines.
    :rtype concatProfile_path: str
    """
    # Initialize PrepProfile(...)
    pp = PrepProfile(show_max_memory_usage, verbose)
    print("Prepping the morphological profiles for each cell line...")
    
    # Carry out the profile assembly for each cell line
    # and generate a dictionary mapping the cell line to
    # the path to its profile
    cell2profile_dict = dict()
    flagProblems_output_dir = f"{post_feature_extraction_output_dir}/1_FlagProblems_output"
    profile_output_dir = f"{post_feature_extraction_output_dir}/2_ProfileAssemblyAndFeatureSelection_output"
    if not os.path.exists(profile_output_dir):
        os.makedirs(profile_output_dir)
        print(f"The following directory has been made:\n{profile_output_dir}")
    for cell in cell_dir_list:
        print(f"\n### Cell line context in use: {cell} ###")
        module_7b_output_dir = f"{parent_dir}/{cell}/output_dir2/module_7b_output"
        image_merged_path = f"{post_feature_extraction_output_dir}/Merged_Image_CSV/{cell}__Image.csv"
        cell2profile_dict[cell] = pp.profile(
            cell,
            module_7b_output_dir,
            image_merged_path,
            flagProblems_output_dir,
            profile_output_dir
        )
        print("-"*10)
    print("COMPLETED: The morphological profile for each cell line has been assembled.\n")
        
    ### Merge the profiles per cell line into a single profile ###
    print("Merging the profiles per cell line into a single dataframe...")
    # The concatenated profile is written next to the first cell line's profile
    output_dir = os.path.dirname(next(iter(cell2profile_dict.values())))
    df_list = []
    for cell, profile_path in cell2profile_dict.items():
        df = pd.read_csv(profile_path)
        df["Metadata_Cell"] = cell
        df_list.append(df)
    df = pd.concat(df_list, ignore_index = True)
    concatProfile_path = f"{output_dir}/concatenated__profile.csv"
    df.to_csv(concatProfile_path, index = False)
    print("Preview of five random rows from three columns of the concatenated profile:")
    display(df[["Metadata_Cell", "Metadata_Treatment", "AreaShape_FormFactor"]].sample(n = 5))
    print(f"COMPLETED: Profiles for each cell line concatenated into a single profile with the Metadata_Cell column added at\n{concatProfile_path}\n")
    return(concatProfile_path)

concatProfile_path = profileAssembly(
    cell_dir_list,
    parent_dir,
    post_feature_extraction_output_dir,
    show_max_memory_usage = True,
    verbose = False
)

**Comments**

As expected, the number of treatments retained is one less than the total number of treatments used, since AMG-900 was removed. 👍

#### Explanation on feature selection
For feature selection, I used two strategies, both implemented in the `SelectFeatures` class:
1. Global feature selection using only the RKO WT profile.
2. Treatment-centric feature selection, which prioritizes features that correlate with the CRBN status of the cell lines for each treatment.
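The treatment-centric strategy can be illustrated with a rank correlation between each feature and the CRBN order of the cell lines (KO -> WT -> OE). The sketch below assumes, based only on the `ordered_cells`, `alpha` and `kendall_alternative` parameters passed to `SelectFeatures`, that a Kendall's tau test is applied per feature within a treatment; it is not the `SelectFeatures` implementation, and the `vote_threshold` aggregation is not modelled. `crbn_ordered_features(...)` and the toy feature names are hypothetical. The `alternative` keyword of `scipy.stats.kendalltau` requires SciPy 1.7 or later.

```python
import numpy as np
import pandas as pd
from scipy.stats import kendalltau


def crbn_ordered_features(df, features, ordered_cells, alpha=0.05, alternative="two-sided"):
    """Hypothetical helper: keep the features whose values trend monotonically
    with the order of ordered_cells (Kendall's tau test, p < alpha)."""
    rank = {cell: i for i, cell in enumerate(ordered_cells)}
    sub = df[df["Metadata_Cell"].isin(ordered_cells)]
    # Encode each site's cell line by its position in the CRBN order (KO=0, WT=1, OE=2)
    x = sub["Metadata_Cell"].map(rank).to_numpy()
    kept = []
    for feature in features:
        result = kendalltau(x, sub[feature].to_numpy(), alternative=alternative)
        if result.pvalue < alpha:
            kept.append(feature)
    return kept


# Toy sites for a single treatment (illustrative values only)
cells = ["c1141_rko_ko", "c662_rko_wt", "c1327_rko_oe"]
toy = pd.DataFrame({
    "Metadata_Cell": np.repeat(cells, 4),
    # Rises with CRBN expression: KO < WT < OE
    "Feature_Trend": np.repeat([1.0, 2.0, 3.0], 4) + np.tile([0.0, 0.01, 0.02, 0.03], 3),
    # Same spread of values in every cell line, so no trend
    "Feature_Flat": np.tile([1.0, 2.0, 3.0, 4.0], 3),
})
kept = crbn_ordered_features(toy, ["Feature_Trend", "Feature_Flat"], cells)
print(kept)  # → ['Feature_Trend']
```

A rank-based test is a natural fit here because it only asks whether a feature moves in the same direction as CRBN expression, without assuming a linear dose-response across the three cell lines.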

In [None]:
concatProfile_path = "/research/lab_winter/users/ang/projects/GW015_SuFEX_IMiDs/GW015_006__full_cpa/cleanPostFeatureExtraction/output_dir/2_ProfileAssemblyAndFeatureSelection_output/concatenated__profile.csv"

In [None]:
SelectFeatures(
    concatProfile_path,
    prefix = "",
    corr_threshold = 0.8,
    verbose = True,
    output_dir = "default",
    param_global = {
        "skip": False,
        "get_consensus": False,
        "global_cell": "c662_rko_wt",
        "vote_threshold": 0.5
    },
    param_treatment = {
        "skip": False,
        "keep_byProduct": True,
        "description": "Cells ordered by CRBN expression (KO -> WT -> OE)",
        "ordered_cells": ["c1141_rko_ko", "c662_rko_wt", "c1327_rko_oe"],
        "alpha": 0.05,
        "kendall_alternative": "two-sided",
        "vote_threshold": 0.5
    }
)
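The `corr_threshold = 0.8` argument above suggests a redundancy filter on highly correlated features. A common way to implement such a filter is to walk through the features and drop any feature whose absolute Pearson correlation with an already-kept feature exceeds the threshold. The sketch below shows that generic technique under this assumption only; it is not the `SelectFeatures` code, whose exact rule (for example, which member of a correlated pair is kept) may differ. `drop_correlated(...)` and the toy columns are hypothetical.

```python
import pandas as pd


def drop_correlated(df, threshold=0.8):
    """Hypothetical helper: greedily keep features whose |Pearson r| with every
    previously kept feature is at most threshold."""
    corr = df.corr().abs()
    kept = []
    for feature in corr.columns:
        # Keep the feature only if it is not highly correlated with any feature already kept
        if all(corr.loc[feature, k] <= threshold for k in kept):
            kept.append(feature)
    return kept


toy = pd.DataFrame({
    "A": [1.0, 2.0, 3.0, 4.0, 5.0],
    "B": [2.0, 4.1, 6.0, 8.2, 10.0],  # almost a rescaled copy of A
    "C": [5.0, 1.0, 4.0, 2.0, 3.0],   # weakly related to A (r = -0.3)
})
print(drop_correlated(toy))  # → ['A', 'C']
```

Because the walk is greedy, the result depends on column order: whichever of two correlated features comes first is the one retained.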

## Completion of QC, morphological profile assembly and feature selection
<span style="color:red">Note: Use **📜 baseline_output.csv** as the morphological profile for section 3.</span> The 📜 baseline_output.csv contains the features that have been normalized/standardized to the DMSO controls.
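The note above states that baseline_output.csv holds features normalized/standardized to the DMSO controls. The exact procedure is not shown in this notebook; the sketch below assumes one common variant, a per-plate z-score of each feature against the DMSO wells, purely to illustrate the idea. `standardize_to_dmso(...)`, the `Metadata_Plate` column and the feature `F` are hypothetical names, and the values are made up.

```python
import pandas as pd


def standardize_to_dmso(df, features, group_col="Metadata_Plate", control="DMSO"):
    """Hypothetical helper: per-plate z-score of each feature against the DMSO wells."""
    out = df.copy()
    out[features] = out[features].astype(float)
    for _, idx in df.groupby(group_col).groups.items():
        plate = out.loc[idx, features]
        dmso = plate[df.loc[idx, "Metadata_Treatment"] == control]
        # Centre on the DMSO mean and scale by the DMSO standard deviation (ddof=1)
        out.loc[idx, features] = (plate - dmso.mean()) / dmso.std()
    return out


toy = pd.DataFrame({
    "Metadata_Plate": ["P1", "P1", "P1", "P2", "P2", "P2"],
    "Metadata_Treatment": ["DMSO", "DMSO", "CC-885", "DMSO", "DMSO", "CC-885"],
    "F": [1.0, 3.0, 5.0, 10.0, 14.0, 12.0],
})
print(standardize_to_dmso(toy, ["F"]))
```

After this step the DMSO wells of every plate have mean 0 and standard deviation 1 for each feature, so treated wells read directly as deviations from the plate's own control.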

Proceed to <span style="color:blue">**Profile Calculations**</span>.