## Perform cell clustering using pixel metaclusters on Spain and Stanford cohorts

In [None]:
import json
import os
from datetime import datetime as dt
import feather
import matplotlib.pyplot as plt
import pandas as pd
from alpineer import load_utils
from matplotlib import rc_file_defaults
from ark.phenotyping import (cell_cluster_utils, cell_meta_clustering,
                             cell_som_clustering, weighted_channel_comp)
from ark.utils import data_utils, example_dataset, plot_utils
from ark.utils.metacluster_remap_gui import (MetaClusterGui,
                                             colormap_helper,
                                             metaclusterdata_from_files)

In [None]:
base_dir = "Z:\\Noah Greenwald\\TNBC_Cohorts\\SPAIN"

In [None]:
pixel_output_dir = '20231031_full_cohort_pixel_output_dir'
cell_clustering_params_name = 'cell_clustering_params.json'

In [None]:
with open(os.path.join(base_dir, "pixie", pixel_output_dir, cell_clustering_params_name)) as fh:
    cell_clustering_params = json.load(fh)
    
fovs = cell_clustering_params['fovs']
channels = cell_clustering_params['channels']
tiff_dir = cell_clustering_params['tiff_dir']
img_sub_folder = cell_clustering_params['img_sub_folder']
segmentation_dir = cell_clustering_params['segmentation_dir']
seg_suffix = cell_clustering_params['seg_suffix']
pixel_data_dir = cell_clustering_params['pixel_data_dir']+'_temp'
pc_chan_avg_som_cluster_name = cell_clustering_params['pc_chan_avg_som_cluster_name']
pc_chan_avg_meta_cluster_name = cell_clustering_params['pc_chan_avg_meta_cluster_name']

cell_table_path = os.path.join(base_dir, 'segmentation', 'samples', 'cell_table', 'cell_table_size_normalized_samples.csv')

In [None]:
cell_cluster_prefix = "20231031_full_cohort"

if cell_cluster_prefix is None:
    cell_cluster_prefix = dt.now().strftime('%Y-%m-%dT%H:%M:%S')

In [None]:
cell_output_dir = '%s_cell_output_dir' % cell_cluster_prefix
if not os.path.exists(os.path.join(base_dir, "pixie", cell_output_dir)):
    os.mkdir(os.path.join(base_dir, "pixie", cell_output_dir))
    
cell_som_weights_name = os.path.join("pixie", cell_output_dir, 'cell_som_weights.feather')
cluster_counts_name = os.path.join("pixie", cell_output_dir, 'cluster_counts.feather')
cluster_counts_size_norm_name = os.path.join("pixie", cell_output_dir, 'cluster_counts_size_norm.feather')
weighted_cell_channel_name = os.path.join("pixie", cell_output_dir, 'weighted_cell_channel.feather')
cell_som_cluster_count_avg_name = os.path.join("pixie", cell_output_dir, 'cell_som_cluster_count_avg.csv')
cell_meta_cluster_count_avg_name = os.path.join("pixie", cell_output_dir, 'cell_meta_cluster_count_avg.csv')
cell_som_cluster_channel_avg_name = os.path.join("pixie", cell_output_dir, 'cell_som_cluster_channel_avg.csv')
cell_meta_cluster_channel_avg_name = os.path.join("pixie", cell_output_dir, 'cell_meta_cluster_channel_avg.csv')
cell_meta_cluster_remap_name = os.path.join("pixie", cell_output_dir, 'cell_meta_cluster_mapping.csv')

In [None]:
pixel_cluster_col = 'pixel_meta_cluster_rename'

if pixel_cluster_col == 'pixel_som_cluster':
    pc_chan_avg_name = pc_chan_avg_som_cluster_name
elif pixel_cluster_col == 'pixel_meta_cluster_rename':
    pc_chan_avg_name = pc_chan_avg_meta_cluster_name

In [None]:
import pandas as pd
meta_avg = pd.read_csv(os.path.join(base_dir, pc_chan_avg_meta_cluster_name))
meta_avg

In [None]:
if os.path.exists(os.path.join(base_dir, cluster_counts_name)) and os.path.exists(os.path.join(base_dir, cluster_counts_size_norm_name)):
    # load the data if it exists
    cluster_counts = feather.read_dataframe(os.path.join(base_dir, cluster_counts_name))
    cluster_counts_size_norm = feather.read_dataframe(os.path.join(base_dir, cluster_counts_size_norm_name))
else:
    # generate the preprocessed data 
    cluster_counts, cluster_counts_size_norm = cell_cluster_utils.create_c2pc_data(
        fovs, os.path.join(base_dir, pixel_data_dir), cell_table_path, pixel_cluster_col
    )

    # write both unnormalized and normalized input data for reference
    feather.write_dataframe(
        cluster_counts,
        os.path.join(base_dir, cluster_counts_name),
        compression='uncompressed'
    )
    feather.write_dataframe(
        cluster_counts_size_norm,
        os.path.join(base_dir, cluster_counts_size_norm_name),
        compression='uncompressed'
    )
    
# define the count columns found in cluster_counts_norm
cell_som_cluster_cols = cluster_counts_size_norm.filter(
    regex=f'{pixel_cluster_col}.*'
).columns.values

In [None]:
if pixel_cluster_col == 'pixel_som_cluster':
    pc_chan_avg_name = pc_chan_avg_som_cluster_name
elif pixel_cluster_col == 'pixel_meta_cluster_rename':
    pc_chan_avg_name = pc_chan_avg_meta_cluster_name

if not os.path.exists(os.path.join(base_dir, weighted_cell_channel_name)):
    pixel_channel_avg = pd.read_csv(os.path.join(base_dir, pc_chan_avg_name))
    weighted_cell_channel = weighted_channel_comp.compute_p2c_weighted_channel_avg(
        pixel_channel_avg,
        channels,
        cluster_counts,
        fovs=fovs,
        pixel_cluster_col=pixel_cluster_col
    )

    feather.write_dataframe(
        weighted_cell_channel,
        os.path.join(base_dir, weighted_cell_channel_name),
        compression='uncompressed'
    )

In [None]:
TONIC_directory = r'Z:\Noah Greenwald\TONIC_Cohort\pixel_clustering\20220715_full_cohort_cell_output_dir'

In [None]:
from ark.phenotyping.cluster_helpers import CellSOMCluster
cell_som_weights_name = os.path.join(TONIC_directory, '20220715_full_cohort_cell_weights.feather')
cell_pysom = CellSOMCluster(
    cluster_counts_size_norm, cell_som_weights_name, fovs, cell_som_cluster_cols,
    num_passes=1, xdim=17, ydim=17, lr_start=0.05, lr_end=0.01,
    seed=42, normalize=True
)

In [None]:
cluster_counts_size_norm = cell_som_clustering.cluster_cells(
    base_dir,
    cell_pysom,
    cell_som_cluster_cols=cell_som_cluster_cols
)

feather.write_dataframe(
    cluster_counts_size_norm,
    os.path.join(base_dir, cluster_counts_size_norm_name),
    compression='uncompressed'
)

cell_som_clustering.generate_som_avg_files(
    base_dir,
    cluster_counts_size_norm,
    cell_som_cluster_cols=cell_som_cluster_cols,
    cell_som_expr_col_avg_name=cell_som_cluster_count_avg_name
)

In [None]:
import pandas as pd
remap_data = pd.read_csv(os.path.join(TONIC_directory, "20220715_full_cohort_cell_meta_cluster_mapping.csv"))
remap_data

In [None]:
import feather
som_to_meta = dict(remap_data[['cluster', 'metacluster']].values)
cluster_counts_size_norm["cell_meta_cluster"] = cluster_counts_size_norm["cell_som_cluster"].map(som_to_meta)
cluster_counts_size_norm

In [None]:
feather.write_dataframe(
    cluster_counts_size_norm,
    os.path.join(base_dir, cluster_counts_size_norm_name),
    compression='uncompressed'
)

In [None]:
from alpineer import io_utils
import numpy as np
def generate_meta_avg_files(base_dir, mapping, cell_som_cluster_cols,
                            cell_som_input_data,
                            cell_som_expr_col_avg_name,
                            cell_meta_expr_col_avg_name, overwrite=False):
    """Computes and saves the average cluster column expression across cell meta clusters.

    Also attaches each cell SOM cluster's meta cluster label (taken from `mapping`) to the
    per-SOM-cluster averages stored in `cell_som_expr_col_avg_name`, and re-saves that file.

    Args:
        base_dir (str):
            The path to the data directory
        mapping (pandas.DataFrame):
            The mapping from SOM to meta clusters, joined on its `cell_som_cluster` column.
            Every column is cast to `np.int32`, so all of them must be integer-valued.
        cell_som_cluster_cols (list):
            The list of columns used for SOM training
        cell_som_input_data (pandas.DataFrame):
            The input data used for SOM training. Must already contain a
            `cell_meta_cluster` column; it is not modified.
        cell_som_expr_col_avg_name (str):
            File name (relative to `base_dir`) of the CSV with the average values of
            `cell_som_cluster_cols` per cell SOM cluster. It is read, given meta cluster
            labels, and overwritten.
        cell_meta_expr_col_avg_name (str):
            File name (relative to `base_dir`) to write the average values of
            `cell_som_cluster_cols` per cell meta cluster to
        overwrite (bool):
            If set, regenerate the averages of `cell_som_cluster_cols` per meta cluster

    Raises:
        ValueError:
            If `cell_som_input_data` has no `cell_meta_cluster` column
    """
    # define the paths to the data
    som_expr_col_avg_path = os.path.join(base_dir, cell_som_expr_col_avg_name)
    meta_expr_col_avg_path = os.path.join(base_dir, cell_meta_expr_col_avg_name)

    # check paths
    io_utils.validate_paths([som_expr_col_avg_path])

    # raise error if cell_som_input_data doesn't contain meta labels
    if 'cell_meta_cluster' not in cell_som_input_data.columns.values:
        raise ValueError('cell_som_input_data does not have meta labels assigned')

    # if the column average file for cell meta clusters already exists, skip
    if os.path.exists(meta_expr_col_avg_path):
        if not overwrite:
            print("Already generated average expression file for cell meta clusters, skipping")
            return

        print(
            "Overwrite flag set, regenerating average expression file for cell meta clusters"
        )

    # compute the average value of each expression column per cell meta cluster
    print("Computing the average value of each training column specified per cell meta cluster")
    cell_meta_cluster_avgs = cell_cluster_utils.compute_cell_som_cluster_cols_avg(
        cell_som_input_data,
        cell_som_cluster_cols,
        'cell_meta_cluster',
        keep_count=True
    )

    # save the average expression values of cell_som_cluster_cols per cell meta cluster
    cell_meta_cluster_avgs.to_csv(
        meta_expr_col_avg_path,
        index=False
    )

    print(
        "Mapping meta cluster values onto average expression values across cell SOM clusters"
    )

    # read in the average number of pixel/SOM clusters across all cell SOM clusters
    cell_som_cluster_avgs = pd.read_csv(som_expr_col_avg_path)
    cell_som_cluster_avgs['cell_som_cluster'] = cell_som_cluster_avgs['cell_som_cluster'].astype(
        int)

    # this happens if the overwrite flag is set with previously generated data, need to overwrite
    if 'cell_meta_cluster' in cell_som_cluster_avgs.columns.values:
        cell_som_cluster_avgs = cell_som_cluster_avgs.drop(columns='cell_meta_cluster')

    # merge metacluster assignments in with an exact left join on the SOM cluster ID. Unlike
    # merge_asof this needs no sorted keys, accepts int64 vs int32 keys (astype(int) is int64
    # on POSIX), keeps the row order of cell_som_cluster_avgs, and leaves NaN for SOM clusters
    # missing from the mapping instead of copying the nearest smaller cluster's label
    cell_som_cluster_avgs = pd.merge(
        cell_som_cluster_avgs, mapping.astype(np.int32), on='cell_som_cluster', how='left'
    )

    # resave average number of pixel/SOM clusters across all cell SOM clusters
    # with metacluster assignments
    cell_som_cluster_avgs.to_csv(
        som_expr_col_avg_path,
        index=False
    )

In [None]:
# per-meta-cluster count averages, using the TONIC SOM -> meta mapping with ark column names
generate_meta_avg_files(
    base_dir=base_dir,
    mapping=remap_data.rename(
        columns={'cluster': 'cell_som_cluster', 'metacluster': 'cell_meta_cluster'}
    ).drop(columns='mc_name'),
    cell_som_cluster_cols=cell_som_cluster_cols,
    cell_som_input_data=cluster_counts_size_norm,
    cell_som_expr_col_avg_name=cell_som_cluster_count_avg_name,
    cell_meta_expr_col_avg_name=cell_meta_cluster_count_avg_name,
)

In [None]:
from ark.phenotyping.weighted_channel_comp import compute_cell_cluster_weighted_channel_avg
def generate_wc_avg_files(fovs, channels, base_dir, mapping, cell_som_input_data,
                          weighted_cell_channel_name='weighted_cell_channel.feather',
                          cell_som_cluster_channel_avg_name='cell_som_cluster_channel_avg.csv',
                          cell_meta_cluster_channel_avg_name='cell_meta_cluster_channel_avg.csv',
                          overwrite=False):
    """Generate the weighted channel average files per cell SOM and meta clusters.

    When running cell clustering with pixel clusters generated from Pixie, the counts of each
    pixel cluster per cell is computed. These are multiplied by the average expression profile of
    each pixel cluster to determine weighted channel average. This computation is averaged by both
    cell SOM and meta cluster.

    Args:
        fovs (list):
            The list of fovs to subset on
        channels (list):
            The list of channels to subset on
        base_dir (str):
            The path to the data directory
        mapping (pandas.DataFrame):
            The mapping from SOM to meta clusters, joined on its `cell_som_cluster` column.
            Every column is cast to `np.int32`, so all of them must be integer-valued.
        cell_som_input_data (pandas.DataFrame):
            The input data used for SOM training. For weighted channel averaging, it should
            contain the number of pixel SOM/meta cluster counts of each cell,
            normalized by `cell_size`. It is averaged by both 'cell_som_cluster' and
            'cell_meta_cluster', so it presumably needs both columns.
        weighted_cell_channel_name (str):
            The name of the file containing the weighted channel expression table
        cell_som_cluster_channel_avg_name (str):
            The name of the file to save the average weighted channel expression
            per cell SOM cluster
        cell_meta_cluster_channel_avg_name (str):
            Same as above except for cell meta clusters
        overwrite (bool):
            If set, regenerate average weighted channel expression for SOM and meta clusters
    """
    # define the paths to the data
    weighted_channel_path = os.path.join(base_dir, weighted_cell_channel_name)
    som_cluster_channel_avg_path = os.path.join(base_dir, cell_som_cluster_channel_avg_name)
    meta_cluster_channel_avg_path = os.path.join(base_dir, cell_meta_cluster_channel_avg_name)

    # check paths
    io_utils.validate_paths([weighted_channel_path])

    # if the weighted channel average files exist, skip
    if os.path.exists(som_cluster_channel_avg_path) and \
       os.path.exists(meta_cluster_channel_avg_path):
        if not overwrite:
            print("Already generated average weighted channel expression files, skipping")
            return

        print("Overwrite flag set, regenerating average weighted channel expression files")

    print("Compute average weighted channel expression across cell SOM clusters")
    cell_som_cluster_channel_avg = compute_cell_cluster_weighted_channel_avg(
        fovs,
        channels,
        base_dir,
        weighted_cell_channel_name,
        cell_som_input_data,
        'cell_som_cluster'
    )

    # merge metacluster assignments into cell_som_cluster_channel_avg
    print(
        "Mapping meta cluster values onto average weighted channel expression "
        "across cell SOM clusters"
    )
    # exact left join on the SOM cluster ID: unlike merge_asof it needs no sorted keys,
    # tolerates int64 vs int32 keys, and never borrows a neighbouring cluster's meta label
    cell_som_cluster_channel_avg = pd.merge(
        cell_som_cluster_channel_avg,
        mapping.astype(np.int32),
        on='cell_som_cluster',
        how='left'
    )

    # save the weighted channel average expression per cell cluster
    cell_som_cluster_channel_avg.to_csv(
        som_cluster_channel_avg_path,
        index=False
    )

    # compute the weighted channel average expression per cell meta cluster
    print("Compute average weighted channel expression across cell meta clusters")
    cell_meta_cluster_channel_avg = compute_cell_cluster_weighted_channel_avg(
        fovs,
        channels,
        base_dir,
        weighted_cell_channel_name,
        cell_som_input_data,
        'cell_meta_cluster'
    )

    # save the weighted channel average expression per cell cluster
    cell_meta_cluster_channel_avg.to_csv(
        meta_cluster_channel_avg_path,
        index=False
    )

In [None]:
generate_wc_avg_files(
    fovs,
    channels,
    base_dir,
    remap_data.rename({'cluster': 'cell_som_cluster', 'metacluster': 'cell_meta_cluster'}, axis=1).drop(columns='mc_name'),
    cell_som_input_data=cluster_counts_size_norm,
    weighted_cell_channel_name=weighted_cell_channel_name,
    cell_som_cluster_channel_avg_name=cell_som_cluster_channel_avg_name,
    cell_meta_cluster_channel_avg_name=cell_meta_cluster_channel_avg_name
)

In [None]:
remap_data.rename({'cluster': 'cell_som_cluster', 'metacluster': 'cell_meta_cluster', 'mc_name':'cell_meta_cluster_rename'}, axis=1).to_csv(os.path.join(base_dir, "pixie", cell_output_dir, 'cell_meta_cluster_mapping_renamed.csv'), index = False)

In [None]:
cell_meta_cluster_remap_name = os.path.join(base_dir, "pixie", cell_output_dir, 'cell_meta_cluster_mapping_renamed.csv')

In [None]:
%matplotlib widget
rc_file_defaults()
plt.ion()

cell_mcd = metaclusterdata_from_files(
    os.path.join(base_dir, cell_som_cluster_count_avg_name),
    cluster_type='cell',
    prefix_trim=pixel_cluster_col + '_'
)
cell_mcd.output_mapping_filename = os.path.join(base_dir, cell_meta_cluster_remap_name)
cell_mcg = MetaClusterGui(cell_mcd, width=17)

In [None]:
# apply the renamed mapping and regenerate every average file with it
renamed_mapping_path = os.path.join(
    base_dir, "pixie", cell_output_dir, 'cell_meta_cluster_mapping_renamed.csv'
)

cluster_counts_size_norm = cell_meta_clustering.apply_cell_meta_cluster_remapping(
    base_dir, cluster_counts_size_norm, renamed_mapping_path
)

feather.write_dataframe(
    cluster_counts_size_norm,
    os.path.join(base_dir, cluster_counts_size_norm_name),
    compression='uncompressed',
)

cell_meta_clustering.generate_remap_avg_count_files(
    base_dir,
    cluster_counts_size_norm,
    renamed_mapping_path,
    cell_som_cluster_cols,
    cell_som_cluster_count_avg_name,
    cell_meta_cluster_count_avg_name,
)

weighted_channel_comp.generate_remap_avg_wc_files(
    fovs,
    channels,
    base_dir,
    cluster_counts_size_norm,
    renamed_mapping_path,
    weighted_cell_channel_name,
    cell_som_cluster_channel_avg_name,
    cell_meta_cluster_channel_avg_name,
)

In [None]:
raw_cmap, renamed_cmap = colormap_helper.generate_meta_cluster_colormap_dict(
    cell_mcd.output_mapping_filename,
    cell_mcg.im_cl.cmap,
    cluster_type='cell'
)

In [None]:
weighted_channel_comp.generate_weighted_channel_avg_heatmap(
    os.path.join(base_dir, cell_som_cluster_channel_avg_name),
    'cell_som_cluster',
    channels,
    raw_cmap,
    renamed_cmap
)

In [None]:
weighted_channel_comp.generate_weighted_channel_avg_heatmap(
    os.path.join(base_dir, cell_meta_cluster_channel_avg_name),
    'cell_meta_cluster_rename',
    channels,
    raw_cmap,
    renamed_cmap
)

In [None]:
# select fovs to display
subset_cell_fovs = fovs

In [None]:
import numba as nb
import itertools
import os
import pathlib
import re
from typing import List, Union
from numpy.typing import ArrayLike, DTypeLike
from numpy import ma
import feather
import natsort as ns
import numpy as np
import pandas as pd
import skimage.io as io
from alpineer import data_utils, image_utils, io_utils, load_utils, misc_utils
from alpineer.settings import EXTENSION_TYPES
from tqdm.notebook import tqdm_notebook as tqdm
import xarray as xr
from ark import settings
from skimage.segmentation import find_boundaries


def save_fov_mask(fov, data_dir, mask_data, sub_dir=None, name_suffix=''):
    """Saves a provided cluster label mask overlay for a FOV.

    The image is written to ``<data_dir>/<sub_dir>/<fov><name_suffix>.tiff``; `sub_dir` is
    created if it does not exist yet.

    Args:
        fov (str):
            The FOV to save
        data_dir (str):
            The directory to save the cluster mask
        mask_data (numpy.ndarray):
            The cluster mask data for the FOV
        sub_dir (Optional[str]):
            The subdirectory to save the masks in. If specified images are saved to
            "data_dir/sub_dir". If `sub_dir = None` the images are saved to `"data_dir"`.
            Defaults to `None`.
        name_suffix (str):
            Specify what to append at the end of every fov.
    """
    # data_dir validation (presumably raises if data_dir does not exist - see alpineer)
    io_utils.validate_paths(data_dir)

    # ensure None is handled correctly in file path generation
    if sub_dir is None:
        sub_dir = ''

    save_dir = os.path.join(data_dir, sub_dir)

    # exist_ok removes the race between an existence check and the directory creation
    os.makedirs(save_dir, exist_ok=True)

    # define the file name as the fov name with the name suffix appended
    fov_file = fov + name_suffix + '.tiff'

    # save the image to data_dir
    image_utils.save_image(os.path.join(save_dir, fov_file), mask_data)


def erode_mask(seg_mask: np.ndarray, **kwargs) -> np.ndarray:
    """
    Zero out the boundary pixels of every label in a segmentation mask.

    Boundaries are located with `skimage.segmentation.find_boundaries`; any extra keyword
    arguments are forwarded to it unchanged.

    Args:
        seg_mask (np.ndarray): The segmentation mask to erode.

    Returns:
        np.ndarray: A new array equal to `seg_mask` except that boundary pixels are 0.
    """
    boundary = find_boundaries(label_img=seg_mask, **kwargs)
    # keep the original label everywhere except on the detected boundaries
    return np.where(boundary, 0, seg_mask)


class ClusterMaskData:
    """
    Cohort-wide lookup from (FOV, segmentation label) to cluster label.

    For every cell it stores the FOV, the segmentation label, the cluster label and an
    auto-generated integer cluster ID. It also adds one background row per FOV that maps
    segmentation label 0 to cluster 0 / cluster ID 0.
    """

    fov_column: str
    label_column: str
    cluster_column: str
    unique_fovs: List[str]
    cluster_id_column: str
    unassigned_id: int
    n_clusters: int
    mapping: pd.DataFrame
    cluster_name_id: pd.DataFrame

    def __init__(
        self, data: pd.DataFrame, fov_col: str, label_col: str, cluster_col: str
    ) -> None:
        """
        Build the cohort mapping from a cell table.

        Args:
            data (pd.DataFrame):
                Cell table containing at least the FOV, cell label and cluster columns.
            fov_col (str):
                Name of the column holding the FOV ID.
            label_col (str):
                Name of the column holding the segmentation (cell) label.
            cluster_col (str):
                Name of the column holding the cluster label (numeric or string).
        """
        self.fov_column: str = fov_col
        self.label_column: str = label_col
        self.cluster_column: str = cluster_col
        self.cluster_id_column: str = "cluster_id"

        fov_c = self.fov_column
        label_c = self.label_column
        cluster_c = self.cluster_column
        id_c = self.cluster_id_column
        # dtypes enforced on the mapping before and after the background rows are added
        key_dtypes = {fov_c: str, label_c: np.int32, id_c: np.int32}

        # keep only what the mapping needs: FOV, segmentation label, cluster label
        cells = data[[fov_c, label_c, cluster_c]].copy()
        print("The mapping data initialized is:")
        print(cells)

        # give each distinct cluster label an integer ID 1..n (in order of first appearance),
        # so that non-numeric cluster labels can still be represented by integers
        ids = pd.DataFrame({cluster_c: cells[cluster_c].unique()})
        print("The cluster_name_id initialized is:")
        print(ids)

        ids[id_c] = np.arange(1, len(ids) + 1, dtype=np.int32)
        print("The cluster_name_id after assigning index is:")
        print(ids)

        self.cluster_name_id = ids

        # attach the integer ID to every cell
        cells = cells.merge(right=self.cluster_name_id, on=cluster_c).astype(key_dtypes)
        print("The mapping data after processing is:")
        print(cells)

        self.unique_fovs: List[str] = ns.natsorted(cells[fov_c].unique().tolist())

        highest_id = cells[id_c].max()
        # one past the largest cluster ID
        self.unassigned_id: np.int32 = np.int32(highest_id + 1)
        self.n_clusters: int = highest_id

        # one background row per FOV: segmentation label 0 -> cluster 0 and cluster ID 0
        n_fovs = len(self.unique_fovs)
        background = pd.DataFrame(
            data={
                fov_c: self.unique_fovs,
                label_c: np.repeat(0, repeats=n_fovs),
                cluster_c: np.repeat(0, repeats=n_fovs),
                id_c: np.repeat(0, repeats=n_fovs),
            }
        )

        combined = pd.concat(objs=[cells, background]).astype(key_dtypes)

        # order rows by FOV, then by segmentation label
        self.mapping = combined.sort_values(by=[fov_c, label_c])

    def fov_mapping(self, fov: str) -> pd.DataFrame:
        """Return the mapping rows of a single FOV, re-indexed from 0.

        Args:
            fov (str):
                The FOV to select; must be one of `unique_fovs`.
        Returns:
            pd.DataFrame:
                The rows of `mapping` belonging to that FOV.
        """
        misc_utils.verify_in_list(requested_fov=[fov], all_fovs=self.unique_fovs)
        in_fov = self.mapping[self.fov_column] == fov
        return self.mapping[in_fov].reset_index(drop=True)

    @property
    def cluster_names(self) -> List[str]:
        """The distinct cluster labels, in the order of their assigned cluster IDs.

        Returns:
            List[str]:
                The cluster names.
        """
        names_column = self.cluster_name_id[self.cluster_column]
        return names_column.tolist()


def label_cells_by_cluster(
        fov: str,
        cmd: ClusterMaskData,
        label_map: Union[np.ndarray, xr.DataArray],
) -> np.ndarray:
    """Translates a cell-ID labeled image into a cluster-labeled image for one FOV.

    Each segmentation label of the FOV is replaced by that cell's value in
    ``cmd.cluster_column``; the background label 0 maps to 0 through the background rows of
    ``cmd.mapping``. The raw cluster column value is written, not the auto-generated
    ``cmd.cluster_id_column``, so ``cmd.cluster_column`` must hold integer-castable values.
    Pixels whose label has no mapping entry receive an "unassigned" value: normally
    ``cmd.unassigned_id``, or one past the largest cluster value if ``cmd.unassigned_id``
    coincides with a real cluster value.

    Args:
        fov (str):
            The FOV to relabel
        cmd (ClusterMaskData):
            A dataclass containing the cell data, cell label column, cluster column and the
            mapping from the segmentation label to the cluster label for a given FOV.
        label_map (Union[numpy.ndarray, xarray.DataArray]):
            Label map for a single FOV; singleton axes are squeezed away.

    Returns:
        numpy.ndarray:
            The image with new designated label assignments, cast to int16
    """
    # verify that fov found in all_data
    misc_utils.verify_in_list(
        fov_name=[fov],
        all_data_fovs=cmd.unique_fovs
    )

    # condense extraneous axes if label_map is a DataArray
    if isinstance(label_map, xr.DataArray):
        labeled_image = label_map.squeeze().values.astype(np.int32)
    else:
        labeled_image: np.ndarray = label_map.squeeze().astype(np.int32)

    fov_clusters: pd.DataFrame = cmd.fov_mapping(fov=fov)

    # numba-typed dict so the jitted relabel_segmentation can look labels up
    mapping: nb.typed.typeddict = nb.typed.Dict.empty(
        key_type=nb.types.int32,
        value_type=nb.types.int32,
    )
    for label, cluster in fov_clusters[[cmd.label_column, cmd.cluster_column]].itertuples(
            index=False):
        mapping[np.int32(label)] = np.int32(cluster)

    # cmd.unassigned_id is one past the number of distinct clusters, but raw cluster values are
    # written here, so it may equal a real cluster value (clusters {1, 2, 4} -> unassigned 4).
    # In that case move it past the largest value. The cohort-wide mapping is used so every FOV
    # gets the same unassigned value.
    cohort_values = pd.to_numeric(cmd.mapping[cmd.cluster_column], errors='coerce')
    unassigned_id = np.int32(cmd.unassigned_id)
    if (cohort_values == unassigned_id).any():
        unassigned_id = np.int32(cohort_values.max() + 1)

    relabeled_image: np.ndarray = relabel_segmentation(
        mapping=mapping,
        unassigned_id=unassigned_id,
        labeled_image=labeled_image,
        _dtype=np.int32)

    return relabeled_image.astype(np.int16)


def map_segmentation_labels(
    labels: Union[pd.Series, np.ndarray],
    values: Union[pd.Series, np.ndarray],
    label_map: ArrayLike,
    unassigned_id: float = 0,
) -> np.ndarray:
    """
    Maps an image consisting of segmentation labels to an image consisting of a particular type of
    statistic, metric, or value of interest.

    Args:
        labels (Union[pd.Series, np.ndarray]): The segmentation labels; cast to int32.
        values (Union[pd.Series, np.ndarray]): The value for each entry of `labels`, in the same
            order; cast to float64. NaN and +/-inf are replaced with 0 for both Series and
            ndarray input.
        label_map (ArrayLike): The segmentation labels as an image to map to. Singleton axes are
            squeezed away; the result must be 2D.
        unassigned_id (int | float, optional): A default value to assign when there exists no
            1-to-1 mapping from a label in the label_map to a label in the `labels` argument.
            Defaults to 0.

    Returns:
        np.ndarray: The mapped float64 image.
    """
    # condense extraneous axes; np.asarray handles ndarrays, DataArrays and other ArrayLikes
    labeled_image: np.ndarray = np.asarray(label_map).squeeze().astype(np.int32)

    # normalize labels and values the same way for Series and ndarray input, so the typed
    # dict below always receives int32 keys and finite float64 values
    if isinstance(labels, pd.Series):
        labels = labels.to_numpy(dtype=np.int32)
    else:
        labels = np.asarray(labels, dtype=np.int32)

    if isinstance(values, pd.Series):
        values = values.to_numpy(dtype=np.float64)
    else:
        values = np.asarray(values, dtype=np.float64)
    # handle NaNs (and infs), replace with 0
    values = ma.fix_invalid(values, fill_value=0).data

    mapping: nb.typed.typeddict = nb.typed.Dict.empty(
        key_type=nb.types.int32, value_type=nb.types.float64
    )

    for label, value in zip(labels, values):
        mapping[label] = value

    relabeled_image: np.ndarray = relabel_segmentation(
        mapping=mapping,
        unassigned_id=unassigned_id,
        labeled_image=labeled_image,
        _dtype=np.float64,
    )

    return relabeled_image


@nb.njit(parallel=True)
def relabel_segmentation(
    mapping: nb.typed.typeddict,
    unassigned_id: np.int32,
    labeled_image: np.ndarray,
    _dtype: DTypeLike = np.float64,
) -> np.ndarray:
    """
    Relabels a labeled segmentation image according to the provided mapping.

    Every pixel value of `labeled_image` is looked up in `mapping`; pixels whose label has no
    entry receive `unassigned_id`. The image is indexed as [row, column], so it must be 2D.

    Compiled with numba in parallel mode, with the row loop as a `prange`. Per the numba docs,
    a `prange` nested inside another `prange` is executed as an ordinary serial loop, so only
    rows are distributed across threads.

    Args:
        mapping (nb.typed.typeddict):
            A Numba typed dictionary mapping segmentation labels to cluster labels. Its key
            type should match the element type of `labeled_image` (callers in this file use
            int32 for both).
        unassigned_id (np.int32):
            The label given to a pixel with no associated cluster.
        labeled_image (np.ndarray):
            The 2D labeled segmentation image.
        _dtype (DTypeLike, optional):
            The data type of the relabeled image. Defaults to `np.float64`.

    Returns:
        np.ndarray: The relabeled image, with the same shape as `labeled_image` and dtype `_dtype`.
    """
    # uninitialized output; every element is written by the loops below
    relabeled_image: np.ndarray = np.empty(shape=labeled_image.shape, dtype=_dtype)
    for i in nb.prange(labeled_image.shape[0]):
        for j in nb.prange(labeled_image.shape[1]):
            # labels missing from the mapping fall back to unassigned_id
            relabeled_image[i, j] = mapping.get(labeled_image[i, j], unassigned_id)
    return relabeled_image


def generate_cluster_mask(
        fov: str,
        seg_dir: Union[str, pathlib.Path],
        cmd: ClusterMaskData,
        seg_suffix: str = "_whole_cell.tiff",
        erode: bool = True,
        **kwargs) -> np.ndarray:
    """Build a mask for one FOV in which every cell is painted with its cluster label.

    Args:
        fov (str):
            The fov to relabel
        seg_dir (str):
            The path to the segmentation data
        cmd (ClusterMaskData):
            A dataclass containing the cell data, cell label column, cluster column and the
            mapping from the segmentation label to the cluster label for a given FOV.
        seg_suffix (str):
            The suffix that the segmentation images use. Defaults to `'_whole_cell.tiff'`.
        erode (bool):
            Whether to erode the edges of the segmentation mask. Defaults to `True`.
        **kwargs:
            Accepted for interface compatibility but not used.

    Returns:
        numpy.ndarray:
            The image where values represent cell cluster labels.
    """
    io_utils.validate_paths([seg_dir])

    # load only this FOV's whole-cell segmentation; the suffix (up to its first '.') is passed
    # as trim_suffix so the loaded image can be selected by the bare fov name below
    seg_images = load_utils.load_imgs_from_dir(
        data_dir=seg_dir,
        files=[fov + seg_suffix],
        xr_dim_name='compartments',
        xr_channel_names=['whole_cell'],
        trim_suffix=seg_suffix.split('.')[0],
    )
    label_map = seg_images.loc[fov, ...]

    if erode:
        # zero out cell borders so adjacent cells stay visually separated
        label_map = erode_mask(label_map, connectivity=2, mode="thick", background=0)

    return label_cells_by_cluster(fov=fov, cmd=cmd, label_map=label_map)


def generate_and_save_cell_cluster_masks(
    fovs: List[str],
    save_dir: Union[pathlib.Path, str],
    seg_dir: Union[pathlib.Path, str],
    cell_data: pd.DataFrame,
    fov_col: str = settings.FOV_ID,
    label_col: str = settings.CELL_LABEL,
    cell_cluster_col: str = settings.CELL_TYPE,
    seg_suffix: str = "_whole_cell.tiff",
    sub_dir: Union[str, None] = None,
    name_suffix: str = "",
    erode: bool = True,
):
    """Generates cell cluster masks and saves them for downstream analysis.

    Args:
        fovs (List[str]):
            A list of fovs to generate and save cell cluster masks for.
        save_dir (Union[pathlib.Path, str]):
            The directory to save the generated cell cluster masks.
        seg_dir (Union[pathlib.Path, str]):
            The path to the segmentation data.
        cell_data (pd.DataFrame):
            The cell data with both cell SOM and meta cluster assignments.
        fov_col (str, optional):
            The column name containing the FOV IDs. Defaults to `settings.FOV_ID` (`"fov"`).
        label_col (str, optional):
            The column name containing the cell label. Defaults to
            `settings.CELL_LABEL` (`"label"`).
        cell_cluster_col (str, optional):
            Whether to assign SOM or meta clusters. Needs to be `"cell_som_cluster"` or
            `"cell_meta_cluster"`. Defaults to `settings.CELL_TYPE` (`"cell_meta_cluster"`).
        seg_suffix (str, optional):
            The suffix that the segmentation images use. Defaults to `"_whole_cell.tiff"`.
        sub_dir (str, optional):
            The subdirectory to save the images in. If specified images are saved to
            `"save_dir/sub_dir"`. If `sub_dir = None` the images are saved to `"save_dir"`.
            Defaults to `None`.
        name_suffix (str, optional):
            Specify what to append at the end of every cell mask. Defaults to `""`.
        erode (bool, optional):
            Whether to erode the edges of the segmentation mask before labeling cells.
            Forwarded to `generate_cluster_mask`. Defaults to `True`, which matches the
            previous (hard-wired) behavior.
    """
    # Build the cell -> cluster mapping container once; it is reused for every FOV below.
    # (Leftover debug prints of label_col / cell_cluster_col were removed.)
    cmd = ClusterMaskData(
        data=cell_data,
        fov_col=fov_col,
        label_col=label_col,
        cluster_col=cell_cluster_col,
    )

    # create the cell cluster masks across each fov
    with tqdm(total=len(fovs), desc="Cell Cluster Mask Generation", unit="FOVs") as pbar:
        for fov in fovs:
            pbar.set_postfix(FOV=fov)

            # generate the cell mask for the FOV
            cell_mask: np.ndarray = generate_cluster_mask(
                fov=fov, seg_dir=seg_dir, cmd=cmd, seg_suffix=seg_suffix, erode=erode
            )

            # save the cell mask generated
            save_fov_mask(
                fov,
                data_dir=save_dir,
                mask_data=cell_mask,
                sub_dir=sub_dir,
                name_suffix=name_suffix,
            )

            pbar.update(1)

In [None]:
# Generate one cell-cluster mask per FOV in subset_cell_fovs and write it to
# <base_dir>/pixie/<cell_output_dir>/cell_masks as "<fov>_cell_mask..." (the plotting cell
# below reads these back as "<fov>_cell_mask.tiff").
# The cluster column is the function default (settings.CELL_TYPE), and segmentation edges are
# eroded because generate_cluster_mask defaults to erode=True.
generate_and_save_cell_cluster_masks(
    fovs=subset_cell_fovs,
    save_dir=os.path.join(base_dir, "pixie", cell_output_dir),
    seg_dir=os.path.join(base_dir, segmentation_dir),
    cell_data=cluster_counts_size_norm,
    seg_suffix=seg_suffix,
    sub_dir='cell_masks',
    name_suffix='_cell_mask'
)

Save the colored cell masks for each FOV in `subset_cell_fovs`.

In [None]:
# Read the masks written by the previous cell (pixie/<cell_output_dir>/cell_masks) and save
# colored versions to pixie/<cell_output_dir>/cell_mask_colored, using the meta-cluster remap
# file for cluster names and raw_cmap for colors.
plot_utils.save_colored_masks(
    fovs=subset_cell_fovs,
    mask_dir=os.path.join(base_dir, "pixie",cell_output_dir, "cell_masks"),
    save_dir=os.path.join(base_dir, "pixie",cell_output_dir, "cell_mask_colored"),
    cluster_id_to_name_path=os.path.join(base_dir, cell_meta_cluster_remap_name),
    metacluster_colors=raw_cmap,
    cluster_type="cell"
)

In [None]:
# For each FOV, load its saved "<fov>_cell_mask.tiff" and plot it with the meta-cluster
# names from the remap file and the raw_cmap colors.
for cell_fov in subset_cell_fovs:
    # trim_suffix strips "_cell_mask" so the loaded image is keyed by the bare FOV name
    cell_cluster_mask = load_utils.load_imgs_from_dir(
        data_dir = os.path.join(base_dir, "pixie", cell_output_dir, "cell_masks"),
        files=[cell_fov + "_cell_mask.tiff"],
        trim_suffix="_cell_mask",
        match_substring="_cell_mask",
        xr_dim_name="cell_mask",
        xr_channel_names=None,
    )
    
    # NOTE(review): the masks were already eroded when generated (generate_cluster_mask
    # defaults to erode=True); confirm whether erode=True here is meant to erode again
    # for display.
    plot_utils.plot_pixel_cell_cluster(
        cell_cluster_mask,
        [cell_fov],
        os.path.join(base_dir, cell_meta_cluster_remap_name),
        metacluster_colors=raw_cmap,
        cluster_type='cell',
        erode=True
    )

In [None]:
# Add the cell cluster assignments in cluster_counts_size_norm to the cell table at
# cell_table_path. NOTE(review): this presumably writes a new CSV next to the original
# rather than overwriting it; confirm against cell_cluster_utils.
cell_cluster_utils.add_consensus_labels_cell_table(
    base_dir, cell_table_path, cluster_counts_size_norm
)

In [None]:
# Build a Mantis viewer project for subset_cell_fovs from the raw images (tiff_dir), the
# generated cell masks, and the segmentation, using the meta-cluster remap file for names.
# NOTE(review): the project directory name is hard-coded and does not follow
# cell_cluster_prefix; update it if this notebook is rerun with a different prefix.
plot_utils.create_mantis_dir(
    fovs=subset_cell_fovs,
    mantis_project_path=os.path.join(base_dir, "2023-10-31_mantis_cell_calprotectin_0_015"),
    img_data_path=tiff_dir,
    mask_output_dir=os.path.join(base_dir, "pixie", cell_output_dir, "cell_masks"),
    mapping = os.path.join(base_dir, cell_meta_cluster_remap_name),
    seg_dir=os.path.join(base_dir, segmentation_dir),
    cluster_type='cell',
    mask_suffix="_cell_mask",
    seg_suffix_name=seg_suffix,
    img_sub_folder=img_sub_folder
)