# Pixie: cell clustering notebook

NOTE: this notebook should be run after `1_Pixie_Cluster_Pixels.ipynb`

In [None]:
# Add directory above current directory to path
import sys; sys.path.append('..')

In [None]:
import json
import os

import feather
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn as sns
import xarray as xr
from matplotlib import rc_file_defaults
from alpineer import io_utils, load_utils

from ark.analysis import visualize
from ark.phenotyping import (cell_cluster_utils,
                             cell_som_clustering,
                             cell_meta_clustering,
                             weighted_channel_comp)
from ark.utils import data_utils, example_dataset, plot_utils
from ark.utils.metacluster_remap_gui import (MetaClusterData, MetaClusterGui,
                                             colormap_helper,
                                             metaclusterdata_from_files)

In [None]:
# Experiment identifiers: DATASET and RESULTS_DIR select which Pixie run's config.json
# is loaded; SEED drives the validation FOV draw and is passed as `seed` to the cell
# SOM training and consensus clustering steps.
DATASET = "IMMUcan_2022_CancerExample"
RESULTS_DIR = 'Blur=2'
SEED = 42

In [None]:
# Path to this experiment's Pixie config. Set the PIXIE_CONFIG_PATH environment variable
# to use a different location without editing the notebook; the default is the
# original per-dataset / per-run location.
CONFIG_PATH = os.environ.get(
    "PIXIE_CONFIG_PATH",
    f"/home/dani/Documents/Thesis/Methods/IMCBenchmark/output/{DATASET}/pixie/{RESULTS_DIR}/config.json",
)

# number of FOVs held out for visually validating the clustering
NUM_VALIDATION_FOVS = 4

# load the params
with open(CONFIG_PATH) as f:
    pixie_config = json.load(f)

# assign the params to variables
input_dir = pixie_config['input_dir']
output_dir = pixie_config['output_dir']
fovs = pixie_config['fovs']

channels = pixie_config['channels']
type_channels = pixie_config['type_channels']

# Seed the module-level RNG so the same validation FOVs are drawn on every run.
# random.sample raises ValueError if asked for more items than exist, so cap the
# sample size for datasets with fewer than NUM_VALIDATION_FOVS FOVs.
random.seed(SEED)
validation_fovs = random.sample(fovs, min(NUM_VALIDATION_FOVS, len(fovs)))

print(f'Data Folder: {input_dir}\n')
print(f'Output Folder: {output_dir}\n')
print(f'FOVS: {fovs}\n')
print(f'FOVS for validation: {validation_fovs}\n')
print(f'Channels to use: {channels}\n')

The following params are loaded from general config:
* `input_dir`: directory containing raw data for PIXIE
* `output_dir`: directory containing current experiment with PIXIE
* `fovs`: subset of fovs used for pixel/cell clustering
* `validation_fovs`: random subset of `fovs` (drawn with `SEED`, not read from the config) used for validating pixel/cell clustering
* `channels`: subset of channels used for pixel/cell clustering

## 1: Load parameters for cell clustering (saved during `1_Pixie_Cluster_Pixels.ipynb`)

In [None]:
# define the output directory of the pixel clustering (relative to output_dir)
pixel_output_dir = 'pixel_output'

# define the base output cell folder (relative to output_dir) and make sure it exists;
# makedirs(exist_ok=True) replaces the non-atomic exists-then-mkdir check and also
# creates output_dir itself if it is missing
cell_output_dir = 'cell_output'
os.makedirs(os.path.join(output_dir, cell_output_dir), exist_ok=True)

In [None]:
# raw inputs under input_dir: single-channel tiff images ("images") and
# segmentation masks ("masks")
tiff_dir, masks_dir = [os.path.join(input_dir, sub_dir) for sub_dir in ("images", "masks")]

# per-cell table, one row per cell
cell_table_path = os.path.join(input_dir, 'cells.csv')

# suffix plus file extension of each FOV's segmentation mask file
seg_suffix = '_whole_cell.tiff'

The following parameters are defined:
* `tiff_dir`: path to the directory containing your single channel tiff images
* `masks_dir`: path to the directory containing your segmented images 
* `cell_table_path`: path to the cell table where each row in the table is one cell, must contain `fov`, `label`, and `cell_size` columns.
* `seg_suffix`: suffix plus the file extension of the segmented images for each FOV

In [None]:
# parameters written by the pixel clustering notebook for use here
cell_clustering_params_name = 'cell_clustering_params.json'

with open(os.path.join(output_dir, pixel_output_dir, cell_clustering_params_name)) as params_file:
    cell_clustering_params = json.load(params_file)

# pixel data directory plus the pixel SOM / meta cluster channel-average file names
pixel_data_dir, pc_chan_avg_som_cluster_name, pc_chan_avg_meta_cluster_name = (
    cell_clustering_params[param_key]
    for param_key in (
        'pixel_data_dir',
        'pc_chan_avg_som_cluster_name',
        'pc_chan_avg_meta_cluster_name',
    )
)

The following params are loaded from previous pixel clustering workflow:

* `pixel_data_dir`: name of the directory containing pixel data with the pixel SOM and meta cluster assignments
* `pc_chan_avg_som_cluster_name`: name of the file containing the average channel expression per pixel SOM cluster, used for the visualization of weighted channel average per cell
* `pc_chan_avg_meta_cluster_name`: name of the file containing the average channel expression per pixel meta cluster, used for the visualization of weighted channel average per cell

## 2: Preprocess

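All data files created during cell clustering are placed in a method-specific subdirectory of `cell_output` inside `output_dir`. That subdirectory is `from_som` when cell clustering is based on the pixel SOM clusters, and `from_meta` when it is based on the (renamed) pixel meta clusters.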

The following folders/files will be created:

* `cell_som_weights_name`: file name to store the cell SOM weights
* `cluster_counts_name`: file name to store the counts of each pixel cluster per cell
* `cluster_counts_size_norm_name`: same as above, except with each value normalized by the respective cell's size. The data will also contain the cell SOM and meta cluster labels assigned.
* `weighted_cell_channel_name`: file name to store the weighted cell channel expression for each cell. Refer to <a href=https://ark-analysis.readthedocs.io/en/latest/_markdown/ark.phenotyping.html#ark.phenotyping.weighted_channel_comp.compute_p2c_weighted_channel_avg>cell channel weighting docs</a> for how the weighting is computed.
* `cell_som_cluster_count_avg_name`: file name to store the average number of pixel clusters per cell SOM cluster
* `cell_meta_cluster_count_avg_name`: same as above for cell meta clusters
* `cell_som_cluster_channel_avg_name`: file name to store the average weighted channel expression per cell SOM cluster
* `cell_meta_cluster_channel_avg_name`: same as above for cell meta clusters
* `cell_meta_cluster_remap_name`: file name to store the SOM cluster to meta cluster manual mappings created using the GUI below

In [None]:
# define the paths to cell clustering files, explicitly set the variables to use custom names

def get_method_config(pixel_cluster_col):
    """Build the file-path configuration for one cell clustering method.

    Cell clustering can be driven either by the pixel SOM clusters or by the
    (renamed) pixel meta clusters. Each choice gets its own subdirectory of
    ``<output_dir>/<cell_output_dir>`` so the two methods' outputs never collide.

    Args:
        pixel_cluster_col (str):
            ``'pixel_som_cluster'`` or ``'pixel_meta_cluster_rename'``.

    Returns:
        dict: ``pixel_cluster_col``; ``pc_chan_avg_name`` (the pixel channel
        average file matching that column); ``method_dir`` (``'from_som'`` or
        ``'from_meta'``); plus one ``*_name`` entry per cell clustering output
        file, each a path relative to ``output_dir``.

    Raises:
        ValueError: if ``pixel_cluster_col`` is not one of the two supported columns.

    Side effects:
        Creates ``<output_dir>/<cell_output_dir>/<method_dir>`` if it is missing.
        Reads the notebook globals ``output_dir``, ``cell_output_dir``,
        ``pc_chan_avg_som_cluster_name`` and ``pc_chan_avg_meta_cluster_name``.
    """
    # depending on which pixel_cluster_col is selected, choose the pixel channel average table accordingly
    if pixel_cluster_col == 'pixel_som_cluster':
        pc_chan_avg_name = pc_chan_avg_som_cluster_name
        method_dir = 'from_som'
    elif pixel_cluster_col == 'pixel_meta_cluster_rename':
        pc_chan_avg_name = pc_chan_avg_meta_cluster_name
        method_dir = 'from_meta'
    else:
        # fail fast with a clear message instead of a KeyError on 'method_dir' further down
        raise ValueError(
            "pixel_cluster_col must be 'pixel_som_cluster' or "
            f"'pixel_meta_cluster_rename', got {pixel_cluster_col!r}"
        )

    os.makedirs(os.path.join(output_dir, cell_output_dir, method_dir), exist_ok=True)

    config = {
        'pixel_cluster_col': pixel_cluster_col,
        'pc_chan_avg_name': pc_chan_avg_name,
        'method_dir': method_dir,
    }

    # output file for each config key, all stored under <cell_output_dir>/<method_dir>
    output_files = {
        'cell_som_weights_name': 'cell_som_weights.feather',
        'cluster_counts_name': 'cluster_counts.feather',
        'cluster_counts_size_norm_name': 'cluster_counts_size_norm.feather',
        'weighted_cell_channel_name': 'weighted_cell_channel.feather',
        'cell_som_cluster_count_avg_name': 'cell_som_cluster_count_avg.csv',
        'cell_meta_cluster_count_avg_name': 'cell_meta_cluster_count_avg.csv',
        'cell_som_cluster_channel_avg_name': 'cell_som_cluster_channel_avg.csv',
        'cell_meta_cluster_channel_avg_name': 'cell_meta_cluster_channel_avg.csv',
        'cell_meta_cluster_remap_name': 'cell_meta_cluster_mapping.csv',
    }
    for key, file_name in output_files.items():
        config[key] = os.path.join(cell_output_dir, method_dir, file_name)

    return config


# one config per pixel cluster column to build cell clusters from; the SOM-based path is
# disabled - uncomment it (and add it to `configs`) to also run it
# from_som_config = get_method_config('pixel_som_cluster')
from_meta_config = get_method_config('pixel_meta_cluster_rename')

# the preprocessing / clustering loops below iterate over this list
configs = [from_meta_config]

In [None]:
for config in configs:
    # count the pixels of each pixel cluster inside every cell: the raw counts plus the
    # same counts normalized by cell size (the latter is the cell SOM input)
    cluster_counts, cluster_counts_size_norm = cell_cluster_utils.create_c2pc_data(
        fovs,
        os.path.join(output_dir, pixel_data_dir),
        cell_table_path,
        config['pixel_cluster_col'],
    )

    # the cell SOM trains on the count columns named after the pixel cluster column
    cell_som_cluster_cols = cluster_counts_size_norm.filter(
        regex=f'{config["pixel_cluster_col"]}.*'
    ).columns.values

    config.update(
        cluster_counts=cluster_counts,
        cluster_counts_size_norm=cluster_counts_size_norm,
        cell_som_cluster_cols=cell_som_cluster_cols,
    )

    # keep an on-disk copy of the unnormalized counts for reference
    feather.write_dataframe(
        cluster_counts,
        os.path.join(output_dir, config['cluster_counts_name']),
        compression='uncompressed',
    )

Generate the weighted cell channel expression file. This data is needed to compute the weighted average channel expression per cell cluster (the data stored in `cell_som_cluster_channel_avg_name` and `cell_meta_cluster_channel_avg_name`). See the documentation of `compute_p2c_weighted_channel_avg` for how the weighted cell channel average is computed: <a href=https://ark-analysis.readthedocs.io/en/latest/_markdown/ark.phenotyping.html#ark.phenotyping.weighted_channel_comp.compute_p2c_weighted_channel_avg>cell channel weighting docs</a>.

In [None]:
for config in configs:
    # average channel expression per pixel cluster, produced by the pixel clustering notebook
    pixel_channel_avg = pd.read_csv(os.path.join(output_dir, config['pc_chan_avg_name']))

    # per-cell channel expression weighted by the cell's pixel cluster counts
    # (see the linked compute_p2c_weighted_channel_avg docs for the exact weighting)
    weighted_cell_channel = weighted_channel_comp.compute_p2c_weighted_channel_avg(
        pixel_channel_avg,
        channels,
        config['cluster_counts'],
        pixel_cluster_col=config['pixel_cluster_col'],
        fovs=fovs,
    )
    config['weighted_cell_channel'] = weighted_cell_channel

    # persist it so the cluster-level channel averages can be computed from disk later
    feather.write_dataframe(
        weighted_cell_channel,
        os.path.join(output_dir, config['weighted_cell_channel_name']),
        compression='uncompressed',
    )

## 3: Cell clustering

### 3.1: Train cell SOM

Train the cell SOM on the size-normalized number of pixel clusters per cell (the data stored in `cluster_counts_size_norm_name`). Training uses the self-organizing map (SOM) algorithm. Note that each pixel SOM/meta cluster column is normalized by its 99.9% value prior to training.

For a full set of parameters you can customize for `train_cell_som`, please consult <a href=https://ark-analysis.readthedocs.io/en/latest/_markdown/ark.phenotyping.html#ark.phenotyping.cell_som_clustering.train_cell_som>cell training docs</a>.

In [None]:
for config in configs:
    # fit the cell SOM on the size-normalized pixel cluster counts; the trained object
    # is kept on the config so step 3.2 can assign cell SOM clusters with it
    config['cell_pysom'] = cell_som_clustering.train_cell_som(
        fovs,
        output_dir,
        cell_table_path=cell_table_path,
        cell_som_input_data=config['cluster_counts_size_norm'],
        cell_som_cluster_cols=config['cell_som_cluster_cols'],
        som_weights_name=config['cell_som_weights_name'],
        seed=SEED,
        num_passes=1,
    )

### 3.2: Assign cell SOM clusters

Use the weights learned from `train_cell_som` to assign cell clusters to the dataset. Note that this is done on the size-normalized pixel cluster counts table. As with `train_cell_som`, each column is normalized by its 99.9% value prior to assigning a cell SOM cluster label.

`generate_som_avg_files` will compute the average number of pixel clusters per cell SOM cluster, as well as the number of cells in each cell SOM cluster (the data placed in `cell_som_cluster_count_avg_name`). This is needed for cell consensus clustering.

In [None]:
for config in configs:
    # label each cell with a cell SOM cluster using the trained cell SOM
    cluster_counts_size_norm = cell_som_clustering.cluster_cells(
        output_dir,
        config['cell_pysom'],
        cell_som_cluster_cols=config['cell_som_cluster_cols'],
    )
    config['cluster_counts_size_norm'] = cluster_counts_size_norm

    # per-SOM-cluster summary (average pixel cluster counts and cell counts),
    # which consensus clustering consumes next
    cell_som_clustering.generate_som_avg_files(
        output_dir,
        cluster_counts_size_norm,
        cell_som_cluster_cols=config['cell_som_cluster_cols'],
        cell_som_expr_col_avg_name=config['cell_som_cluster_count_avg_name'],
    )

### 3.3: Run cell consensus clustering

Use consensus hierarchical clustering to cluster cell SOM clusters into a user-defined number of meta clusters. The consensus clusters are trained on the average number of pixel clusters across all cell SOM clusters (the data stored in `cell_som_cluster_count_avg_name`). These values are z-scored and capped at the value specified in the `cap` argument prior to consensus clustering. This helps improve meta clustering performance.

After consensus clustering, the following are computed by `generate_meta_avg_files`:

* The average number of pixel clusters across all cell meta clusters, and the number of cells per meta cluster (the data placed in `cell_meta_cluster_count_avg_name`)
* The meta cluster mapping for each cell SOM cluster in `cell_som_cluster_count_avg_name` (data is resaved, same data except with an associated meta cluster column)

`generate_wc_avg_files` also creates the following:

* The weighted channel average across all cell clusters (the data placed in `cell_som_cluster_channel_avg_name` and `cell_meta_cluster_channel_avg_name`). This will be done for both `'cell_som_cluster'` and `'cell_meta_cluster'`.

For a full set of parameters you can customize for `cell_consensus_cluster`, please consult <a href=https://ark-analysis.readthedocs.io/en/latest/_markdown/ark.phenotyping.html#ark.phenotyping.cell_meta_clustering.cell_consensus_cluster>cell consensus clustering docs</a>. The two parameters set here are read from the `cells` section of the Pixie config:

* `max_k` (`meta_max_k`): the number of consensus clusters desired
* `cap` (`meta_cap`): used to clip z-scored values prior to consensus clustering (in the range `[-cap, cap]`)

In [None]:
# consensus clustering hyperparameters from the 'cells' section of the Pixie config
max_k, cap = pixie_config['cells']['meta_max_k'], pixie_config['cells']['meta_cap']
print(f'For metaclustering using max_k: {max_k} and z-score cap: [-{cap}, +{cap}].\n')

for config in configs:
    # hierarchically group the cell SOM clusters into max_k meta clusters, using the
    # average pixel cluster counts per cell SOM cluster
    cell_cc, cluster_counts_size_norm = cell_meta_clustering.cell_consensus_cluster(
        output_dir,
        cell_som_cluster_cols=config['cell_som_cluster_cols'],
        cell_som_input_data=config['cluster_counts_size_norm'],
        cell_som_expr_col_avg_name=config['cell_som_cluster_count_avg_name'],
        max_k=max_k,
        cap=cap,
        seed=SEED,
    )
    config.update(cell_cc=cell_cc, cluster_counts_size_norm=cluster_counts_size_norm)

    # meta cluster summary: average pixel cluster counts per meta cluster, plus the
    # SOM -> meta mapping added to the SOM summary file
    cell_meta_clustering.generate_meta_avg_files(
        output_dir,
        cell_cc,
        cell_som_cluster_cols=config['cell_som_cluster_cols'],
        cell_som_input_data=cluster_counts_size_norm,
        cell_som_expr_col_avg_name=config['cell_som_cluster_count_avg_name'],
        cell_meta_expr_col_avg_name=config['cell_meta_cluster_count_avg_name'],
    )

    # weighted channel averages per cell SOM cluster and per cell meta cluster
    weighted_channel_comp.generate_wc_avg_files(
        fovs,
        channels,
        output_dir,
        cell_cc,
        cell_som_input_data=cluster_counts_size_norm,
        weighted_cell_channel_name=config['weighted_cell_channel_name'],
        cell_som_cluster_channel_avg_name=config['cell_som_cluster_channel_avg_name'],
        cell_meta_cluster_channel_avg_name=config['cell_meta_cluster_channel_avg_name'],
    )

## 4: Visualize results

In [None]:
# the visualization cells below read from the pixel-meta-based config
config = from_meta_config

### 4.1: Interactive adjustments to relabel cell meta clusters

The visualization shows the z-scored average pixel cluster count expression per cell SOM and meta cluster. The heatmaps are faceted by cell SOM clusters on the left and cell meta clusters on the right.

## Usage

### Quickstart
- **Select**: Left Click
- **Remap**: **New metacluster button** or Right Click
- **Edit Metacluster Name**: Textbox at bottom right of the heatmaps.

### Selection and Remapping details
- To select a SOM cluster, click on its respective position in the **selected** bar. Click on it again to deselect.
- To select a meta cluster, click on its corresponding color in the **metacluster** bar. Click on it again to deselect.
- To remap the selected clusters, click the **New metacluster** button (alternatively, right click anywhere). Note that remapping an entire metacluster deletes it.
- To clear the selected SOM/meta clusters, use the **Clear Selection** button.
- **After remapping a meta cluster, make sure to deselect the newly created one to prevent unwanted combinations.**

### Other features and notes
- You will likely need to zoom out to see the entire visualization. To zoom out/in, use Ctrl -/Ctrl + on Windows or ⌘ -/⌘ + on Mac.
- The bars at the top show the number of cells in each SOM cluster.
- The text box at the bottom right allows you to rename a particular meta cluster. This can be useful as remapping may cause inconsistent numbering.
- Adjust the z-score limit using the slider on the bottom left to adjust your dynamic range.
- When meta clusters are combined or a meta cluster is renamed, the change is immediately saved to `cell_meta_cluster_remap_name`.
- **You won't be able to advance until you've clicked `New metacluster` or renamed a meta cluster at least once. If you do not want to make changes, just click `New metacluster` to trigger a save before continuing.**

In [None]:
# render the remapping GUI with matplotlib's interactive widget backend and default styling
%matplotlib widget
rc_file_defaults()
plt.ion()

# load the per-cell-SOM-cluster average pixel cluster counts for the GUI heatmaps;
# prefix_trim strips the pixel cluster column prefix from the count column names
# (presumably to shorten the heatmap labels - confirm against metaclusterdata_from_files)
cell_mcd = metaclusterdata_from_files(
    os.path.join(output_dir, config['cell_som_cluster_count_avg_name']),
    cluster_type='cell',
    prefix_trim=config['pixel_cluster_col'] + '_'
)
# file the GUI saves SOM -> meta cluster remappings to as edits are made
cell_mcd.output_mapping_filename = os.path.join(output_dir, config['cell_meta_cluster_remap_name'])
cell_mcg = MetaClusterGui(cell_mcd, width=9)

In [None]:
# rename the meta cluster values in the cell dataset
cluster_counts_size_norm = cell_meta_clustering.apply_cell_meta_cluster_remapping(
    output_dir,
    from_meta_config['cluster_counts_size_norm'],
    from_meta_config['cell_meta_cluster_remap_name']
)

from_meta_config['cluster_counts_size_norm'] = cluster_counts_size_norm

# recompute the mean column expression per meta cluster and apply these new names to the SOM cluster average data
cell_meta_clustering.generate_remap_avg_count_files(
    output_dir,
    from_meta_config['cluster_counts_size_norm'],
    from_meta_config['cell_meta_cluster_remap_name'],
    from_meta_config['cell_som_cluster_cols'],
    from_meta_config['cell_som_cluster_count_avg_name'],
    from_meta_config['cell_meta_cluster_count_avg_name'],
)

# recompute the mean weighted channel expression per meta cluster and apply these new names to the SOM channel average data
weighted_channel_comp.generate_remap_avg_wc_files(
    fovs,
    channels,
    output_dir,
    from_meta_config['cluster_counts_size_norm'],
    from_meta_config['cell_meta_cluster_remap_name'],
    from_meta_config['weighted_cell_channel_name'],
    from_meta_config['cell_som_cluster_channel_avg_name'],
    from_meta_config['cell_meta_cluster_channel_avg_name']
)

In [None]:
# build meta cluster colormaps (one keyed by raw meta cluster ids, one by renamed labels,
# judging by the names) from the GUI's saved mapping and the colormap of its heatmap
raw_cmap, renamed_cmap = colormap_helper.generate_meta_cluster_colormap_dict(
    cell_mcd.output_mapping_filename,
    cell_mcg.im_cl.cmap,
    cluster_type='cell'
)

### 4.2: Weighted cell SOM cluster average heatmap over channels (z-scored) - from pixel META config

In [None]:
# heatmap of the weighted channel average per cell SOM cluster, restricted to type_channels,
# colored with the meta cluster colormaps built from the GUI mapping
weighted_channel_comp.generate_weighted_channel_avg_heatmap(
    os.path.join(output_dir, from_meta_config['cell_som_cluster_channel_avg_name']),
    'cell_som_cluster',
    type_channels,
    raw_cmap,
    renamed_cmap
)

### 4.3: Weighted cell meta cluster average heatmap over channels (z-scored) - from pixel META config

In [None]:
# heatmap of the weighted channel average per (renamed) cell meta cluster, restricted to
# type_channels, colored with the meta cluster colormaps built from the GUI mapping
weighted_channel_comp.generate_weighted_channel_avg_heatmap(
    os.path.join(output_dir, from_meta_config['cell_meta_cluster_channel_avg_name']),
    'cell_meta_cluster_rename',
    type_channels,
    raw_cmap,
    renamed_cmap
)

### 4.4: Generate cell phenotype maps - from pixel META config

Generate cell phenotype maps for the validation FOVs, in which each pixel in the image corresponds to its cell meta cluster. Run this cell if you wish to create cell cluster mask images for downstream analysis. Note that because each pixel value corresponds to a meta cluster number, masks will likely not render with colors in image viewer software.

In [None]:
# generate and save the cell cluster masks for each fov in validation_fovs; they are written
# under <method_dir>/cell_masks with a '_cell_mask' name suffix, which is where the next
# cell loads them from
data_utils.generate_and_save_cell_cluster_masks(
    fovs=validation_fovs,
    base_dir=output_dir,
    save_dir=os.path.join(output_dir, cell_output_dir, from_meta_config['method_dir']),
    seg_dir=masks_dir,
    cell_data=from_meta_config['cluster_counts_size_norm'],
    seg_suffix=seg_suffix,
    sub_dir='cell_masks',
    name_suffix='_cell_mask'
)

In [None]:
# the mask directory and remapping file are the same for every FOV: resolve them once
cell_mask_dir = os.path.join(output_dir, cell_output_dir, from_meta_config['method_dir'], "cell_masks")
cell_remap_path = os.path.join(output_dir, from_meta_config['cell_meta_cluster_remap_name'])

for cell_fov in validation_fovs:
    # load just this FOV's cell cluster mask
    cell_cluster_mask = load_utils.load_imgs_from_dir(
        data_dir=cell_mask_dir,
        files=[cell_fov + "_cell_mask.tiff"],
        trim_suffix="_cell_mask",
        match_substring="_cell_mask",
        xr_dim_name="cell_mask",
        xr_channel_names=None,
    )

    # draw the mask colored by meta cluster using the GUI-derived colormap
    plot_utils.plot_pixel_cell_cluster_overlay(
        cell_cluster_mask,
        [cell_fov],
        cell_remap_path,
        metacluster_colors=raw_cmap,
        cluster_type='cell',
    )

### 4.5: Create phenotype prediction file

In [None]:
def merge_with_results(cell_table, cluster_counts_size_norm, appendix=''):
    """Attach cell SOM / meta cluster predictions to the cell table.

    Args:
        cell_table (pandas.DataFrame): one row per cell, with ``fov`` and ``label`` columns.
        cluster_counts_size_norm (pandas.DataFrame): clustering output keyed by
            ``fov`` and ``segmentation_label``, carrying ``cell_som_cluster`` and
            ``cell_meta_cluster_rename``.
        appendix (str): suffix appended to the prediction column names.

    Returns:
        pandas.DataFrame: ``cell_table`` left-merged on (``fov``, ``label``), so cells
        without a prediction are kept with NaN. The prediction columns are renamed to
        ``pred_som_cluster{appendix}`` and ``pred_meta_cluster{appendix}``.
    """
    prediction_names = {
        'cell_som_cluster': f'pred_som_cluster{appendix}',
        'cell_meta_cluster_rename': f'pred_meta_cluster{appendix}',
    }
    # align the clustering output's id column with the cell table before joining
    predictions = cluster_counts_size_norm.rename(columns={'segmentation_label': 'label'})
    return (
        cell_table
        .merge(predictions, how='left', on=['fov', 'label'])
        .rename(columns=prediction_names)
    )

In [None]:
cell_table = pd.read_csv(cell_table_path, dtype={'fov': str})

# merge results from PIXEL SOM clusters (disabled along with from_som_config)
# cell_table = merge_with_results(cell_table, from_som_config['cluster_counts_size_norm'], '_from_pixel_som')

# merge results from PIXEL META clusters
cell_table = merge_with_results(cell_table, from_meta_config['cluster_counts_size_norm'], '_from_pixel_meta')

# cell table column -> name in the results file; the key order is also the output column order
results_columns = {
    'fov': 'sample_id',
    'label': 'object_id',
    'cell_type': 'label',
    'pred_som_cluster_from_pixel_meta': 'pred_som_cluster',
    'pred_meta_cluster_from_pixel_meta': 'pred_meta_cluster',
}
cell_table = cell_table[list(results_columns)].rename(results_columns, axis=1)

cell_table.to_csv(os.path.join(output_dir, 'pixie_results.csv'), index=False)

### 4.6: Save the full results of Pixie cell clustering

`cluster_counts_size_norm` with the SOM, meta, and renamed meta cluster labels, is saved to `cluster_counts_size_norm_name` as a `.feather` file.

In [None]:
# persist each method's final per-cell table (SOM, meta and renamed meta cluster labels)
for config in configs:
    # write this config's own table: the bare global `cluster_counts_size_norm` only holds
    # whichever table was assigned last, so it would be written for every config
    feather.write_dataframe(
        config['cluster_counts_size_norm'],
        os.path.join(output_dir, config['cluster_counts_size_norm_name']),
        compression='uncompressed'
    )

### 4.7: Save images for Mantis Viewer

Mantis Viewer is a visualization tool for multi-dimensional imaging in pathology.

In [None]:
for config in configs:
    # Mantis project and cell masks both live under output_dir/cell_output_dir/<method_dir>
    method_output_dir = os.path.join(output_dir, cell_output_dir, config['method_dir'])
    plot_utils.create_mantis_dir(
        fovs=validation_fovs,
        mantis_project_path=os.path.join(method_output_dir, "mantis"),
        img_data_path=tiff_dir,
        mask_output_dir=os.path.join(method_output_dir, "cell_masks"),
        mapping=os.path.join(output_dir, config['cell_meta_cluster_remap_name']),
        seg_dir=masks_dir,
        cluster_type='cell',
        mask_suffix="_cell_mask",
        seg_suffix_name=seg_suffix,
    )