## Perform pixel clustering on Spain and Stanford cohorts

In [None]:
import json
import os
from datetime import datetime as dt
import numpy as np
import matplotlib.pyplot as plt
from alpineer import io_utils, load_utils
from matplotlib import rc_file_defaults
from alpineer.io_utils import list_folders
from ark.phenotyping import (pixel_cluster_utils, pixel_meta_clustering,
                             pixel_som_clustering, pixie_preprocessing)
from ark.utils import data_utils, example_dataset, plot_utils
from ark.utils.metacluster_remap_gui import (MetaClusterGui,
                                             colormap_helper,
                                             metaclusterdata_from_files)

In [None]:
from alpineer import image_utils
def threshold_channels(fovs, tiff_dir, img_sub_folder, channels, threshold):
    """Zero out low-intensity pixels in selected channels as a preprocessing step.

    For every FOV and every selected channel, the channel image is loaded, all
    pixels strictly below that channel's threshold are set to 0, and the result
    is saved as `<tiff_dir>/<fov>/<img_sub_folder>/<channel>_thresholded.tiff`.

    Args:
        fovs (list):
            List of fovs to process
        tiff_dir (str):
            Name of the directory containing the tiff files
        img_sub_folder (str):
            sub-folder within each FOV containing image data,
            `None` is treated as no sub-folder
        channels (list):
            list of channels to apply thresholding to
        threshold (list, tuple, int or float):
            amount to threshold. If a single number, applies
            to all channels. Otherwise, a custom value per channel can be supplied
            (same length and order as `channels`)

    Raises:
        ValueError:
            If `threshold` is a list/tuple whose length differs from `channels`
    """

    # no output if no channels specified
    if channels is None or len(channels) == 0:
        return

    # expand the threshold into exactly one value per channel
    if isinstance(threshold, (list, tuple)):
        if len(threshold) != len(channels):
            raise ValueError(
                "A list was provided for variable threshold, but it does not "
                "have the same length as the list of channels provided"
            )
        chan_thresholds = list(threshold)
    else:
        chan_thresholds = [threshold] * len(channels)

    # convert to path-compatible format
    if img_sub_folder is None:
        img_sub_folder = ''

    for fov in fovs:
        for chan, chan_thresh in zip(channels, chan_thresholds):
            # load a single (fov, channel) image and take its 2D pixel array
            img = load_utils.load_imgs_from_tree(data_dir=tiff_dir, img_sub_folder=img_sub_folder,
                                                 fovs=[fov], channels=[chan]).values[0, :, :, 0]

            # work on a copy so the loaded array is left untouched
            chan_out = img.copy()
            chan_out[chan_out < chan_thresh] = 0
            image_utils.save_image(
                os.path.join(tiff_dir, fov, img_sub_folder, chan + '_thresholded.tiff'),
                chan_out
            )

In [None]:
# root folder of the cohort; every relative path in this notebook is joined onto it
base_dir = r"Z:\Noah Greenwald\TNBC_Cohorts\SPAIN"

In [None]:
# image data: one folder per FOV, no per-FOV sub-folder
tiff_dir = os.path.join(base_dir, "image_data", "samples")
img_sub_folder = None

# segmentation output, given relative to base_dir (None means no segmentation)
segmentation_dir = os.path.join("segmentation", "samples", "deepcell_output")
seg_suffix = '_whole_cell.tiff'

# absolute segmentation path, or None when no segmentation directory is set
pixie_seg_dir = None if segmentation_dir is None else os.path.join(base_dir, segmentation_dir)

In [None]:
# every folder directly under `tiff_dir` is treated as a FOV;
# `fovs` is the list used by all later cells (here: the full set)
all_fovs = io_utils.list_folders(tiff_dir)
fovs = all_fovs

In [None]:
# processing settings forwarded to the preprocessing and SOM clustering calls
batch_size = 5
multiprocess = False

In [None]:
# prefix used to name the pixel clustering output directory
pixel_cluster_prefix = "20231031_full_cohort"

# fall back to a timestamp when no prefix is given; '-' separates the time
# fields because ':' is not allowed in Windows file names and this prefix
# becomes part of a directory name
if pixel_cluster_prefix is None:
    pixel_cluster_prefix = dt.now().strftime('%Y-%m-%dT%H-%M-%S')

In [None]:
# output directory (relative to base_dir), named after the pixel cluster prefix
pixel_output_dir = os.path.join("pixie", f"{pixel_cluster_prefix}_pixel_output_dir")

# create it on disk the first time this runs
full_pixel_output_dir = os.path.join(base_dir, pixel_output_dir)
if not os.path.exists(full_pixel_output_dir):
    os.makedirs(full_pixel_output_dir)

# preprocessed pixel data locations, all relative to base_dir
pixel_data_dir = os.path.join(pixel_output_dir, 'pixel_mat_data')
pixel_subset_dir = os.path.join(pixel_output_dir, 'pixel_mat_subset')
# pull Noah's post row normalized channel norm values instead of the default created one
norm_vals_name = os.path.join(pixel_output_dir, 'post_rowsum_chan_norm.feather')

In [None]:
# set an optional list of markers for additional blurring
# (the clustering channel list defined later refers to `ECAD_smoothed` and
# `CK17_smoothed`; presumably those images are produced by this call — confirm)
blurred_channels = ["ECAD", "CK17"]
# a single value is supplied for both channels
smooth_vals = 6

pixel_cluster_utils.smooth_channels(
    fovs=fovs,
    tiff_dir=tiff_dir,
    img_sub_folder=img_sub_folder,
    channels=blurred_channels,
    smooth_vals=smooth_vals,
)

In [None]:
# zero out `Calprotectin_old` pixels below 0.015; each FOV gets a
# `Calprotectin_old_thresholded.tiff` written by threshold_channels
threshold_channels(fovs, tiff_dir, img_sub_folder, ['Calprotectin_old'], 0.015)

In [None]:
# Promote the thresholded Calprotectin image to be the `Calprotectin.tiff` used
# for clustering. Safe to re-run: a FOV whose thresholded file was already
# moved is skipped instead of having its `Calprotectin.tiff` deleted.
calprotectin_sub_folder = '' if img_sub_folder is None else img_sub_folder
for fov in fovs:
    # same directory layout threshold_channels writes to
    fov_img_dir = os.path.join(tiff_dir, fov, calprotectin_sub_folder)
    thresholded_path = os.path.join(fov_img_dir, 'Calprotectin_old_thresholded.tiff')
    final_path = os.path.join(fov_img_dir, 'Calprotectin.tiff')

    if not os.path.exists(thresholded_path):
        if os.path.exists(final_path):
            # nothing to move; keep the existing image rather than destroying it
            print("No thresholded Calprotectin image for %s, keeping existing Calprotectin.tiff" % fov)
            continue
        raise FileNotFoundError(thresholded_path)

    # os.replace overwrites an existing destination in a single step
    os.replace(thresholded_path, final_path)

In [None]:
# filter the CD11c channel with the nuclear segmentation mask, with exclude=True
# (the clustering channel list later uses `CD11c_nuc_exclude`; presumably that
# image is the output of this call — confirm)
filter_channel = 'CD11c'
nuclear_exclude = True

pixel_cluster_utils.filter_with_nuclear_mask(
    fovs=fovs,
    tiff_dir=tiff_dir,
    seg_dir=os.path.join(base_dir, segmentation_dir),
    channel=filter_channel,
    nuc_seg_suffix="_nuclear.tiff",
    img_sub_folder=img_sub_folder,
    exclude=nuclear_exclude
)

In [None]:
# filter the FOXP3 channel with the nuclear segmentation mask, with exclude=False
# (the clustering channel list later uses `FOXP3_nuc_include`; presumably that
# image is the output of this call — confirm)
filter_channel = 'FOXP3'
nuclear_exclude = False

pixel_cluster_utils.filter_with_nuclear_mask(
    fovs=fovs,
    tiff_dir=tiff_dir,
    seg_dir=os.path.join(base_dir, segmentation_dir),
    channel=filter_channel,
    nuc_seg_suffix="_nuclear.tiff",
    img_sub_folder=img_sub_folder,
    exclude=nuclear_exclude
)

In [None]:
# channels used for pixel clustering, in the order passed to every later call;
# suffixed names refer to the smoothed / nuclear-filtered image variants
channels = [
    "CD45", "SMA", "Vim", "FAP", "Fibronectin", "Collagen1",
    "CK17_smoothed", "ECAD_smoothed", "ChyTr", "Calprotectin",
    "CD3", "CD4", "CD8", "CD11c_nuc_exclude",
    "CD14", "CD20", "CD31", "CD56", "CD68", "CD163",
    "HLADR", "FOXP3_nuc_include",
]

# preprocessing settings forwarded to create_pixel_matrix
subset_proportion = 0.01
blur_factor = 2

In [None]:
# pixel clustering output directory of the TONIC cohort run; normalization
# values and SOM weights are read from here in later cells
TONIC_directory = r'Z:\Noah Greenwald\TONIC_Cohort\pixel_clustering\20220707_full_cohort_pixel_output_dir'

In [None]:
# run pixel data preprocessing
# the pre-row-normalization values and the pixel threshold file are taken from
# the TONIC output directory instead of this cohort's own output directory
pixie_preprocessing.create_pixel_matrix(
    fovs,
    channels,
    base_dir,
    tiff_dir,
    pixie_seg_dir,
    img_sub_folder=img_sub_folder,
    seg_suffix=seg_suffix,
    pixel_output_dir=pixel_output_dir,
    data_dir=pixel_data_dir,
    subset_dir=pixel_subset_dir,
    norm_vals_name_pre_rownorm=os.path.join(TONIC_directory, "20220707_full_cohort_channel_norm_flattened.feather"),  # pre-row normalized channel normalization values from TONIC
    pixel_thresh_name=os.path.join(TONIC_directory, "20220707_full_cohort_pixel_norm_renamed.feather"), # pixel thresholded values from TONIC
    norm_vals_name_post_rownorm=norm_vals_name, # post-row normalized channel normalization values (ignore the one that gets outputted, use Noah's TONIC version post_rowsum_chan_norm.feather)
    blur_factor=blur_factor,
    subset_proportion=subset_proportion,
    multiprocess=multiprocess,
    batch_size=batch_size
)

In [None]:
# output file names, relative to base_dir (later cells join them onto base_dir):
# average channel expression per SOM cluster, per meta cluster, and the meta cluster mapping
pc_chan_avg_som_cluster_name = os.path.join(pixel_output_dir, 'pixel_channel_avg_som_cluster.csv')
pc_chan_avg_meta_cluster_name = os.path.join(pixel_output_dir, 'pixel_channel_avg_meta_cluster.csv')
pixel_meta_cluster_remap_name = os.path.join(pixel_output_dir, 'pixel_meta_cluster_mapping.csv')

In [None]:
# SOM weights file stored in the TONIC output directory (no SOM training call
# appears in the visible notebook, so these weights are presumably reused as-is — confirm)
pixel_som_weights_name = os.path.join(TONIC_directory, "20220707_full_cohort_pixel_weights.feather")

In [None]:
from ark.phenotyping.cluster_helpers import PixelSOMCluster
# SOM cluster object built from this cohort's subsetted pixel data, the TONIC
# post-row-sum normalization file and the TONIC weights file
# NOTE(review): the meaning of the positional arguments is inferred from the
# variable names passed in — confirm against the PixelSOMCluster signature
pixel_pysom = PixelSOMCluster(
    os.path.join(base_dir, pixel_subset_dir),
    os.path.join(TONIC_directory, 'post_rowsum_chan_norm.feather'),
    pixel_som_weights_name,
    fovs,
    channels,
    num_passes=1,
    xdim=17,
    ydim=17,
    lr_start=0.05,
    lr_end=0.01,
    seed=42
)

In [None]:
# assign SOM cluster labels to the pixel data in `pixel_data_dir` using `pixel_pysom`
pixel_som_clustering.cluster_pixels(
    fovs,
    channels,
    base_dir,
    pixel_pysom,
    data_dir=pixel_data_dir,
    multiprocess=multiprocess,
    batch_size=batch_size
)

In [None]:
# write the per-SOM-cluster average channel expression file
# (`pc_chan_avg_som_cluster_name`, read back from base_dir in a later cell)
pixel_som_clustering.generate_som_avg_files(
    fovs,
    channels,
    base_dir,
    pixel_pysom,
    data_dir=pixel_data_dir,
    pc_chan_avg_som_cluster_name=pc_chan_avg_som_cluster_name
 )

In [None]:
import pandas as pd
# SOM cluster -> meta cluster mapping CSV, read from the BELLINI cohort output directory;
# later cells use its `pixel_som_cluster`, `pixel_meta_cluster` and `pixel_meta_cluster_rename` columns
remap_data_file = os.path.join(r'Z:\Noah Greenwald\TNBC_Cohorts\BELLINI\pixie\20231009_test_cohort_pixel_output_dir', "20220707_full_cohort_pixel_meta_cluster_mapping.csv")
remap_data = pd.read_csv(remap_data_file)

In [None]:
import feather
from shutil import rmtree
def assign_meta_clusters(pixel_data_dir, pixel_data_dir_temp, meta_remapping):
    """Copy the per-FOV pixel tables into a fresh directory with meta cluster labels added.

    Args:
        pixel_data_dir (str):
            Directory whose files are each read as a feather table containing
            a `pixel_som_cluster` column
        pixel_data_dir_temp (str):
            Output directory; deleted first if it already exists, then recreated
        meta_remapping (pandas.DataFrame):
            Table with `pixel_som_cluster`, `pixel_meta_cluster` and
            `pixel_meta_cluster_rename` columns
    """
    # always start from an empty output directory
    if os.path.exists(pixel_data_dir_temp):
        rmtree(pixel_data_dir_temp)
    os.mkdir(pixel_data_dir_temp)

    table_names = os.listdir(pixel_data_dir)

    # lookup tables: SOM cluster -> meta cluster, meta cluster -> display name
    som_meta_pairs = meta_remapping[["pixel_som_cluster", "pixel_meta_cluster"]].values
    meta_name_pairs = meta_remapping[["pixel_meta_cluster", "pixel_meta_cluster_rename"]].values
    som_lookup = dict(som_meta_pairs)
    name_lookup = dict(meta_name_pairs)

    for table_name in table_names:
        pixel_table = feather.read_dataframe(os.path.join(pixel_data_dir, table_name))

        meta_labels = pixel_table["pixel_som_cluster"].map(som_lookup)
        pixel_table["pixel_meta_cluster"] = meta_labels
        pixel_table["pixel_meta_cluster_rename"] = meta_labels.map(name_lookup)

        # same file name, new directory
        feather.write_dataframe(pixel_table, os.path.join(pixel_data_dir_temp, table_name))

In [None]:
# load the per-SOM-cluster average channel expression table
som_avg_data = pd.read_csv(os.path.join(base_dir, pc_chan_avg_som_cluster_name))

# build the lookups from the mapping table's columns:
# SOM cluster -> meta cluster, and meta cluster -> display name
som_to_meta = dict(remap_data[["pixel_som_cluster", "pixel_meta_cluster"]].values)
meta_to_rename = dict(remap_data[["pixel_meta_cluster", "pixel_meta_cluster_rename"]].values)
# attach both labels to every SOM cluster row; the rename column is cast to
# string, so a meta cluster without a name ends up as a string rather than a missing value
som_avg_data["pixel_meta_cluster"] = som_avg_data["pixel_som_cluster"].map(som_to_meta)
som_avg_data["pixel_meta_cluster_rename"] = som_avg_data["pixel_meta_cluster"].map(meta_to_rename).astype('str')

In [None]:
# overwrite the SOM cluster average file with the meta cluster columns added
som_avg_data.to_csv(os.path.join(base_dir, pc_chan_avg_som_cluster_name), index=False)

In [None]:
# write meta-cluster-labelled copies of the per-FOV pixel tables into `<pixel_data_dir>_temp`;
# the tables in `pixel_data_dir` itself are left unchanged
assign_meta_clusters(os.path.join(base_dir, pixel_data_dir), os.path.join(base_dir, pixel_data_dir + "_temp"),  remap_data)

In [None]:
def generate_meta_avg_files(fovs, channels, base_dir, num_clusters, mapping, data_dir='pixel_mat_data',
                            pc_chan_avg_som_cluster_name='pixel_channel_avg_som_cluster.csv',
                            pc_chan_avg_meta_cluster_name='pixel_channel_avg_meta_cluster.csv',
                            num_fovs_subset=100, seed=42, overwrite=False):
    """Computes and saves the average channel expression across pixel meta clusters.
    Assigns meta cluster labels to the data stored in `pc_chan_avg_som_cluster_name`.

    Args:
        fovs (list):
            The list of fovs to subset on
        channels (list):
            The list of channels to subset on
        base_dir (str):
            The path to the data directory
        num_clusters (int):
            The number of clusters to use
        mapping (pandas.DataFrame):
            The mapping from SOM to meta clusters. Must contain a `pixel_som_cluster`
            column; every other column is merged onto the SOM cluster averages.
            The caller's frame is not modified.
        data_dir (str):
            Name of the directory which contains the full preprocessed pixel data.
            This data should also have the SOM cluster labels appended from `cluster_pixels`.
        pc_chan_avg_som_cluster_name (str):
            Name of file to save the channel-averaged results across all SOM clusters to
        pc_chan_avg_meta_cluster_name (str):
            Name of file to save the channel-averaged results across all meta clusters to
        num_fovs_subset (int):
            The number of FOVs to subset on for meta cluster channel averaging
        seed (int):
            The random seed to use for subsetting FOVs
        overwrite (bool):
            If set, force overwrites the existing average channel expression file if it exists
    """

    # define the paths to the data
    som_cluster_avg_path = os.path.join(base_dir, pc_chan_avg_som_cluster_name)
    meta_cluster_avg_path = os.path.join(base_dir, pc_chan_avg_meta_cluster_name)

    # path validation
    io_utils.validate_paths([som_cluster_avg_path])

    # if the channel meta average file already exists and the overwrite flag isn't set, skip
    if os.path.exists(meta_cluster_avg_path):
        if not overwrite:
            print("Already generated meta cluster channel average file, skipping")
            return

        print("Overwrite flag set, regenerating meta cluster channel average file")

    # compute average channel expression for each pixel meta cluster
    # and the number of pixels per meta cluster
    print("Computing average channel expression across pixel meta clusters")
    pixel_channel_avg_meta_cluster = pixel_cluster_utils.compute_pixel_cluster_channel_avg(
        fovs,
        channels,
        base_dir,
        'pixel_meta_cluster',
        num_clusters,
        data_dir,
        num_fovs_subset=num_fovs_subset,
        seed=seed,
        keep_count=True
    )

    # save pixel_channel_avg_meta_cluster
    pixel_channel_avg_meta_cluster.to_csv(
        meta_cluster_avg_path,
        index=False
    )

    # merge metacluster assignments in
    print("Mapping meta cluster values onto average channel expression across pixel SOM clusters")
    pixel_channel_avg_som_cluster = pd.read_csv(som_cluster_avg_path)

    som_key = "pixel_som_cluster"

    # Drop every column the mapping is about to supply (plus 'pixel_meta_cluster',
    # which is always replaced). Leaving a previously written column such as
    # 'pixel_meta_cluster_rename' in place would make the merge emit
    # '<name>_x' / '<name>_y' duplicates instead of a single column.
    stale_columns = [
        col for col in pixel_channel_avg_som_cluster.columns
        if col != som_key and (col == 'pixel_meta_cluster' or col in mapping.columns)
    ]
    if stale_columns:
        pixel_channel_avg_som_cluster = pixel_channel_avg_som_cluster.drop(
            columns=stale_columns
        )

    pixel_channel_avg_som_cluster[som_key] =\
        pixel_channel_avg_som_cluster[som_key].astype(int)

    # merge_asof needs the right-hand keys sorted and of exactly the same dtype
    # as the left-hand keys; normalize a copy so the caller's frame is untouched
    mapping_sorted = mapping.assign(
        **{som_key: mapping[som_key].astype(pixel_channel_avg_som_cluster[som_key].dtype)}
    ).sort_values(som_key)

    pixel_channel_avg_som_cluster = pd.merge_asof(
        pixel_channel_avg_som_cluster, mapping_sorted, on=som_key
    )

    # resave channel-averaged results across all pixel SOM clusters with metacluster assignments
    pixel_channel_avg_som_cluster.to_csv(
        som_cluster_avg_path,
        index=False
    )

In [None]:
# cast both cluster id columns of the mapping to 32-bit integers
# NOTE(review): presumably done so the key dtype matches the SOM average table
# inside generate_meta_avg_files (pd.merge_asof rejects mismatched key dtypes);
# the dtype produced by `astype(int)` there is platform dependent — confirm
remap_data['pixel_som_cluster'] = remap_data['pixel_som_cluster'].astype(np.int32)
remap_data['pixel_meta_cluster'] = remap_data['pixel_meta_cluster'].astype(np.int32)

In [None]:
# compute meta cluster averages from the labelled copies in `<pixel_data_dir>_temp`;
# the number of clusters is the count of distinct meta cluster ids in the mapping,
# and overwrite=True regenerates the file even if it already exists
generate_meta_avg_files(
    fovs,
    channels,
    base_dir,
    len(remap_data["pixel_meta_cluster"].unique()),
    remap_data,
    data_dir=pixel_data_dir+'_temp',
    pc_chan_avg_som_cluster_name=pc_chan_avg_som_cluster_name,
    pc_chan_avg_meta_cluster_name=pc_chan_avg_meta_cluster_name,
    overwrite=True
)

In [None]:
# add the display name of each meta cluster to the meta cluster average file, in place
meta_avg_path = os.path.join(base_dir, pc_chan_avg_meta_cluster_name)
meta_avg_data = pd.read_csv(meta_avg_path)

# meta cluster id -> display name
meta_to_rename = dict(remap_data[["pixel_meta_cluster", "pixel_meta_cluster_rename"]].values)
meta_cluster_names = meta_avg_data["pixel_meta_cluster"].map(meta_to_rename)
meta_avg_data["pixel_meta_cluster_rename"] = meta_cluster_names

meta_avg_data.to_csv(meta_avg_path, index=False)

In [None]:
# mapping file name (relative to base_dir) handed to the plotting / Mantis cells below
# NOTE(review): this points inside this cohort's output directory, while `remap_data`
# is read from the BELLINI directory; no visible cell copies the file here — confirm it exists
pixel_meta_cluster_remap_name = os.path.join(pixel_output_dir, "20220707_full_cohort_pixel_meta_cluster_mapping.csv")

In [None]:
import pandas as pd
# reload the mapping from the BELLINI directory; this replaces the in-memory
# `remap_data`, including the int32 casts applied to it earlier
remap_data_file = os.path.join(r'Z:\Noah Greenwald\TNBC_Cohorts\BELLINI\pixie\20231009_test_cohort_pixel_output_dir', "20220707_full_cohort_pixel_meta_cluster_mapping.csv")
remap_data = pd.read_csv(remap_data_file)

In [None]:
# FOVs to generate masks, plots and the Mantis project for (here: all of them)
subset_pixel_fovs = fovs

In [None]:
# pick the first .tiff of the first FOV as the reference channel file,
# expressed relative to the FOV folder
if img_sub_folder is None:
    first_fov_img_dir = os.path.join(tiff_dir, fovs[0])
else:
    first_fov_img_dir = os.path.join(tiff_dir, fovs[0], img_sub_folder)
first_tiff_name = io_utils.list_files(first_fov_img_dir, substrs=['.tiff'])[0]
chan_file = first_tiff_name if img_sub_folder is None else os.path.join(img_sub_folder, first_tiff_name)

# write one meta cluster mask per FOV into `<pixel_output_dir>/pixel_masks`,
# reading the labelled pixel tables from `<pixel_data_dir>_temp`
pixel_mask_save_dir = os.path.join(base_dir, pixel_output_dir)
data_utils.generate_and_save_pixel_cluster_masks(
    fovs=subset_pixel_fovs,
    base_dir=base_dir,
    save_dir=pixel_mask_save_dir,
    tiff_dir=tiff_dir,
    chan_file=chan_file,
    pixel_data_dir=pixel_data_dir + '_temp',
    pixel_cluster_col='pixel_meta_cluster',
    sub_dir='pixel_masks',
    name_suffix='_pixel_mask',
)

Save the colored pixel masks for each FOV in `subset_pixel_fovs`.

In [None]:
# NOTE(review): `raw_cmap` is not defined in any visible cell; unless it already
# exists in the kernel this call raises NameError — confirm where it comes from
plot_utils.save_colored_masks(
    fovs=subset_pixel_fovs,
    mask_dir=os.path.join(base_dir, pixel_output_dir, "pixel_masks"),
    save_dir=os.path.join(base_dir, pixel_output_dir, "pixel_mask_colored"),
    cluster_id_to_name_path=os.path.join(base_dir, pixel_meta_cluster_remap_name),
    metacluster_colors=raw_cmap,
    cluster_type="pixel"
)

In [None]:
# load each FOV's saved pixel mask and plot it with the meta cluster mapping
# NOTE(review): `raw_cmap` is not defined in any visible cell; unless it already
# exists in the kernel this loop raises NameError — confirm where it comes from
for pixel_fov in subset_pixel_fovs:
    # one mask file per FOV, named `<fov>_pixel_mask.tiff`
    pixel_cluster_mask = load_utils.load_imgs_from_dir(
        data_dir=os.path.join(base_dir, pixel_output_dir, "pixel_masks"),
        files=[pixel_fov + "_pixel_mask.tiff"],
        trim_suffix="_pixel_mask",
        match_substring="_pixel_mask",
        xr_dim_name="pixel_mask",
        xr_channel_names=None,
    )

    plot_utils.plot_pixel_cell_cluster(
        pixel_cluster_mask,
        [pixel_fov],
        os.path.join(base_dir, pixel_meta_cluster_remap_name),
        metacluster_colors=raw_cmap
    )

In [None]:
# FOVs that have a .feather pixel table in `pixel_data_dir`, without the extension
clustered_fov_files = io_utils.list_files(os.path.join(base_dir, pixel_data_dir), substrs='.feather')

# settings recorded for the cell clustering step
cell_clustering_params = dict(
    fovs=io_utils.remove_file_extensions(clustered_fov_files),
    channels=channels,
    tiff_dir=tiff_dir,
    img_sub_folder=img_sub_folder,
    segmentation_dir=segmentation_dir,
    seg_suffix=seg_suffix,
    pixel_data_dir=pixel_data_dir,
    pc_chan_avg_som_cluster_name=pc_chan_avg_som_cluster_name,
    pc_chan_avg_meta_cluster_name=pc_chan_avg_meta_cluster_name,
)

# save the params dict
cell_clustering_params_path = os.path.join(base_dir, pixel_output_dir, 'cell_clustering_params.json')
with open(cell_clustering_params_path, 'w') as fh:
    json.dump(cell_clustering_params, fh)

In [None]:
# build a Mantis project folder under base_dir from the images, the pixel masks
# in `<pixel_output_dir>/pixel_masks`, the meta cluster mapping and the segmentation output
plot_utils.create_mantis_dir(
    fovs=subset_pixel_fovs,
    mantis_project_path=os.path.join(base_dir, "2023-10-31_pixel_mantis_calprotectin_thresholded_0_015"), # viz
    img_data_path=tiff_dir,
    mask_output_dir=os.path.join(base_dir, pixel_output_dir, "pixel_masks"),
    mapping = os.path.join(base_dir, pixel_meta_cluster_remap_name),
    seg_dir=pixie_seg_dir,
    mask_suffix="_pixel_mask",
    seg_suffix_name=seg_suffix,
    img_sub_folder=img_sub_folder
)