In [None]:
# Pixel classification using Pixie
# authors: Pacome Prompsy
# contact: pacome.prompsy@chuv.ch
# Guenova Lab
# CHUV (Centre Hospitalier Universitaire Vaudois), Lausanne, Suisse

# Scripts adapted from ARK analysis pipeline:
# https://github.com/angelolab/ark-analysis
#
# Greenwald, N.F., Miller, G., Moen, E. et al. Whole-cell segmentation of tissue 
# images with human-level performance using large-scale data annotation and deep learning.
# Nat Biotechnol 40, 555–565 (2022). https://doi.org/10.1038/s41587-021-01094-0
#
# Liu, C.C., Greenwald, N.F., Kong, A. et al. Robust phenotyping of highly multiplexed
# tissue imaging data using pixel-level clustering. Nat Commun 14, 4618 (2023).
# https://doi.org/10.1038/s41467-023-40068-5 

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import json, os, re, argparse, feather
from typing import List
from skimage.io import imread
from alpineer import image_utils, io_utils, load_utils, misc_utils
import glob
import pandas as pd
import os
import pathlib
import re
from typing import List, Union
from PIL import Image

import feather
import natsort as ns
import numpy as np
from tqdm.notebook import tqdm_notebook as tqdm
from ark.utils import data_utils
from ark import settings

###############################################################################
## Multiplexing util functions
###############################################################################

def read_markers(marker_file):
    """Read marker names from a text file, one name per line.

    Args:
        marker_file (str):
            Path to the text file listing the markers.

    Returns:
        list:
            One string per line of the file, with newline characters removed.
            Other whitespace is kept as-is.
    """
    with open(marker_file, "r") as handle:
        return [raw_line.replace("\n", "") for raw_line in handle.readlines()]
    
def qiSettings_to_markers(qiSettings_file):
    """Extract the marker (dye) names listed in a qiSettings JSON file.

    The file is expected to hold a top-level ``"data"`` list whose items each
    carry ``item["comparisonData"]["dye"]``.

    Args:
        qiSettings_file (str):
            Path to the JSON settings file.

    Returns:
        list:
            The dye names in file order, with all spaces removed.

    Raises:
        KeyError: if an expected key is missing from the JSON structure.
    """
    # use a context manager so the file handle is always released
    with open(qiSettings_file) as fp:
        data = json.load(fp)
    return [item['comparisonData']['dye'].replace(" ", "") for item in data['data']]

def initialize_input(input_dir, output_dir, sample_name, markers):
    """Copy the per-marker TIFFs of one sample into ``output_dir/input/sample_name``.

    For each marker the matching file in ``input_dir/sample_name`` is located by
    file name, all images are cropped to the smallest common width/height and
    saved as ``<marker>.tiff``. Existing output files are not overwritten.

    Args:
        input_dir (str):
            Directory containing one sub-directory per sample.
        output_dir (str):
            Root output directory; images are written below ``output_dir/input``.
        sample_name (str):
            Name of the sample sub-directory (also expected in the file names).
        markers (list):
            Marker names to look for.

    Returns:
        str:
            ``output_dir/input``, the directory holding the per-sample folders.

    Raises:
        FileNotFoundError: if no TIFF file matches one of the markers.
    """
    sample_dir = os.path.join(input_dir, sample_name)
    tiffdir = os.path.join(output_dir, "input", sample_name)
    os.makedirs(tiffdir, exist_ok=True)

    TIFFs = os.listdir(sample_dir)
    links = dict.fromkeys(markers, "")

    for marker in links:
        # DAPI and TRBC1 file names are not required to have "_" after the marker
        if marker == "DAPI" or marker == "TRBC1":
            regex = re.compile(".*" + sample_name + ".*-"  + marker + ".*.tif.*")
        else:
            regex = re.compile(".*" + sample_name + ".*-"  + marker + "_.*.tif.*")

        # if several files match, the last one in listing order is kept
        for file in TIFFs:
            if regex.match(file):
                links[marker] = file

    # fail early with a clear message instead of trying to open the directory
    missing = [str(marker) for marker, file in links.items() if not file]
    if missing:
        raise FileNotFoundError(
            "No TIFF file found in %s for marker(s): %s" % (sample_dir, ", ".join(missing))
        )

    widthsX = dict()
    heigthY = dict()
    for marker, file in links.items():
        with Image.open(os.path.join(sample_dir, file)) as img:
            widthsX[marker] = int(img.size[0])
            heigthY[marker] = int(img.size[1])

    final_width = min(widthsX.values())
    final_heigth = min(heigthY.values())
    print("Final Width = ")
    print(final_width)

    print("Final Heigth = ")
    print(final_heigth)

    for marker, file in links.items():
        target = os.path.join(tiffdir, marker + '.tiff')
        if not os.path.exists(target):
            print(marker)
            with Image.open(os.path.join(sample_dir, file)) as img:
                # keep the right-most `final_width` columns and the top `final_heigth` rows
                startx = img.size[0] - final_width
                starty = 0
                stopx = img.size[0]
                stopy = final_heigth

                img_cropped = img.crop((startx, starty, stopx, stopy))
                img_cropped.save(target)

    return os.path.join(output_dir, "input")

###############################################################################
## General util functions
###############################################################################

def dir_path(string):
    """Return ``string`` unchanged if it is an existing directory.

    Raises:
        NotADirectoryError: if ``string`` does not point to a directory.
    """
    if not os.path.isdir(string):
        raise NotADirectoryError(string)
    return string


def parse_var(s):
    """
    Parse a key, value pair, separated by '='
    That's the reverse of ShellArgs.

    On the command line (argparse) a declaration will typically look like:
        foo=hello
    or
        foo="hello world"

    Only the first '=' separates key and value; later ones stay in the value.

    Args:
        s (str):
            The ``key=value`` declaration.

    Returns:
        tuple:
            ``(key, value)`` with blanks stripped around the key only.

    Raises:
        ValueError: if ``s`` contains no '='.
    """
    key, separator, value = s.partition('=')
    if not separator:
        # previously this fell through to an unbound `value` (UnboundLocalError)
        raise ValueError("Expected a 'key=value' declaration, got %r" % (s,))
    # we remove blanks around keys, as is logical
    return (key.strip(), value)



def parse_vars(items, to = "int"):
    """
    Turn a sequence of ``key=value`` declarations into a dictionary.

    Args:
        items (list):
            The declarations; ``None`` or an empty list gives an empty dict.
        to (str):
            Target type of the values: "int", "float", "bool" (True only for
            the exact text 'True') or "str". Any other name keeps the raw value.

    Returns:
        dict:
            Mapping of each key to its converted value.
    """
    converters = {
        "int": int,
        "float": float,
        "bool": lambda text: text == 'True',
        "str": str,
    }
    convert = converters.get(to, lambda text: text)

    if not items:
        return {}

    parsed = {}
    for declaration in items:
        name, raw_value = parse_var(declaration)
        parsed[name] = convert(raw_value)
    return parsed


def filter_with_nuclear_mask(fovs: List, tiff_dir: str, seg_dir: str, channel: str,
                             nuc_seg_suffix: str = "_nuclear.tiff", img_sub_folder: str = None,
                             exclude: bool = True):
    """
    Filters out background staining using subcellular marker localization.

    For every fov the channel image is loaded, pixels selected by the nuclear
    segmentation mask are set to 0, and the result is saved next to the original
    as `<channel>_nuc_exclude.tiff` or `<channel>_nuc_include.tiff`.

    Args:
        fovs (list):
            The list of fovs to filter
        tiff_dir (str):
            Name of the directory containing the tiff files
        seg_dir (str):
            Name of the directory containing the segmented files.
            If `None`, a message is printed and nothing is filtered.
        channel (str):
            Channel to apply filtering to
        nuc_seg_suffix (str):
            The suffix for the nuclear channel.
            (i.e. for "fov1", a suffix of "_nuclear.tiff" would make a file named
            "fov1_nuclear.tiff")
        img_sub_folder (str):
            Name of the subdirectory inside `tiff_dir` containing the tiff files.
            Set to `None` if there isn't any.
        exclude (bool):
            If `True`, zero the pixels inside the nuclear mask (mask > 0);
            if `False`, zero the pixels outside of it (mask == 0).
    """
    # if seg_dir is None, the user cannot run filtering
    if seg_dir is None:
        print('No seg_dir provided, you must provide one to run nuclear filtering')
        return

    # raise an error if the provided seg_dir does not exist
    io_utils.validate_paths(seg_dir)

    # convert to path-compatible format
    if img_sub_folder is None:
        img_sub_folder = ''

    for fov in fovs:
        # load the channel image in
        img = load_utils.load_imgs_from_tree(data_dir=tiff_dir, img_sub_folder=img_sub_folder,
                                             fovs=[fov], channels=[channel]).values[0, :, :, 0]

        # load the segmented image in
        seg_img_name: str = f"{fov}{nuc_seg_suffix}"
        seg_img = imread(os.path.join(seg_dir, seg_img_name))

        # zero the signal inside the nuclear mask
        if exclude:
            suffix = "_nuc_exclude.tiff"
            seg_mask = seg_img > 0
        # zero the signal outside the nuclear mask
        else:
            suffix = "_nuc_include.tiff"
            seg_mask = seg_img == 0

        img[seg_mask] = 0

        # save filtered image
        image_utils.save_image(os.path.join(tiff_dir, fov, img_sub_folder, channel + suffix), img)
        # no `return` here: every fov in the list has to be processed

# define the cell table path
def combine_cell_tables(cell_table_dir, samples, cell_table_prefix = "cell_table_size_normalized"):
    """
    Stack the per-sample cell tables into one table and write it to disk.

    Each input is read from `<cell_table_dir>/<sample>_<cell_table_prefix>.csv.gz`;
    the combined table is written to `<cell_table_dir>/<cell_table_prefix>.csv.gz`
    (the row index is written as well).

    Args:
        cell_table_dir (str):
            Directory holding the per-sample tables; also receives the output.
        samples (list):
            Sample names, in the order they should be stacked.
        cell_table_prefix (str):
            Common part of the table file names.

    Returns:
        str:
            Path of the combined table.
    """
    per_sample_tables = [
        pd.read_csv(os.path.join(cell_table_dir, sample + "_" + cell_table_prefix + ".csv.gz"))
        for sample in samples
    ]

    combined = pd.concat(per_sample_tables, axis=0)
    combined_path = os.path.join(cell_table_dir, cell_table_prefix + ".csv.gz")
    combined.to_csv(combined_path)
    return combined_path
    

def generate_and_save_pixel_cluster_masks(fovs: List[str],
                                          base_dir: Union[pathlib.Path, str],
                                          save_dir: Union[pathlib.Path, str],
                                          tiff_dir: Union[pathlib.Path, str],
                                          chan_file: Union[pathlib.Path, str],
                                          pixel_data_dir: Union[pathlib.Path, str],
                                          pixel_cluster_col: str = 'pixel_meta_cluster',
                                          sub_dir: str = None,
                                          name_suffix: str = ''):
    """Builds the pixel cluster mask of every fov and saves it to disk.

    Args:
        fovs (List[str]):
            The fovs to generate and save pixel masks for.
        base_dir (Union[pathlib.Path, str]):
            The path to the data directory.
        save_dir (Union[pathlib.Path, str]):
            The directory the masks are saved to (passed to `save_fov_mask`).
        tiff_dir (Union[pathlib.Path, str]):
            The path to the directory with the tiff data.
        chan_file (Union[pathlib.Path, str]):
            The channel file inside each FOV folder (FOV folder as root).
            Its image dimensions define the dimensions of the mask.
        pixel_data_dir (Union[pathlib.Path, str]):
            The directory (relative to `base_dir`) with the per-fov pixel data,
            including the cluster label columns.
        pixel_cluster_col (str, optional):
            The column holding the labels to draw: `'pixel_som_cluster'` or
            `'pixel_meta_cluster'`. Defaults to `'pixel_meta_cluster'`.
        sub_dir (str, optional):
            Sub-directory of `save_dir` to save into, forwarded to `save_fov_mask`.
            Defaults to `None`.
        name_suffix (str, optional):
            Text appended to the name of every saved mask. Defaults to `''`.
    """

    with tqdm(total=len(fovs), desc="Pixel Cluster Mask Generation") as progress_bar:
        for fov_name in fovs:
            # the reference channel image lives inside the fov folder
            fov_mask: np.ndarray = generate_pixel_cluster_mask(
                fov=fov_name,
                base_dir=base_dir,
                tiff_dir=tiff_dir,
                chan_file_path=os.path.join(fov_name, chan_file),
                pixel_data_dir=pixel_data_dir,
                pixel_cluster_col=pixel_cluster_col,
            )

            data_utils.save_fov_mask(
                fov_name,
                data_dir=save_dir,
                mask_data=fov_mask,
                sub_dir=sub_dir,
                name_suffix=name_suffix,
            )

            progress_bar.update(1)


def generate_pixel_cluster_mask(fov, base_dir, tiff_dir, chan_file_path,
                                pixel_data_dir, pixel_cluster_col='pixel_meta_cluster'):
    """Build, for one fov, an image whose pixels carry their SOM or meta cluster label

    Args:
        fov (str):
            The name of the fov; `<fov>.feather` must exist in the pixel data directory
        base_dir (str):
            The path to the data directory
        tiff_dir (str):
            The path to the tiff data
        chan_file_path (str):
            The path to the sample channel file to load (`tiff_dir` as root).
            Only its dimensions are used, to size the mask.
        pixel_data_dir (str):
            The directory (relative to `base_dir`) with the per-fov pixel data,
            which needs `row_index`, `column_index` and the cluster label column.
        pixel_cluster_col (str):
            The label column to draw,
            needs to be `'pixel_som_cluster'` or `'pixel_meta_cluster'`

    Returns:
        numpy.ndarray:
            A 2D `int16` array; pixels absent from the pixel data stay 0
    """

    pixel_data_path = os.path.join(base_dir, pixel_data_dir)
    reference_img_path = os.path.join(tiff_dir, chan_file_path)

    # path checking
    io_utils.validate_paths([tiff_dir, reference_img_path, pixel_data_path])

    # the cluster column has to be one of the two supported ones
    misc_utils.verify_in_list(
        provided_cluster_col=[pixel_cluster_col],
        valid_cluster_cols=['pixel_som_cluster', 'pixel_meta_cluster']
    )

    # the fov needs a matching feather file
    misc_utils.verify_in_list(
        provided_fov_file=[fov + '.feather'],
        consensus_fov_files=os.listdir(pixel_data_path)
    )

    # the reference channel image only provides the mask dimensions
    reference_shape = np.squeeze(imread(reference_img_path)).shape
    n_rows = reference_shape[0]
    n_cols = reference_shape[1]

    # int16 keeps the mask loadable in Photoshop
    flat_mask = np.zeros(n_rows * n_cols, dtype='int16')

    pixel_table = feather.read_dataframe(os.path.join(pixel_data_path, fov + '.feather'))

    # labels as integers, not floats
    label_values = pixel_table[pixel_cluster_col].astype(int).tolist()

    # row-major position of every pixel in the flattened mask
    flat_index = pixel_table['row_index'].values * n_cols + pixel_table['column_index'].values

    flat_mask[flat_index] = label_values

    return flat_mask.reshape((n_rows, n_cols))

In [None]:
class C:
    """Bare attribute container holding the run parameters of this notebook."""
    pass

# run parameters, read by the following cells as `args.<name>`
args = C()
vars(args).update(
    output_dir="../output/",
    name="region",
    proportion=0.05,
    smoothing_channels=None,
    smoothing_factor=6,
    blur_factor=2,
    filter_channels=None,
    max_k=20,
)

In [None]:
import os
print()
print("Running Pixel Classification...")
print()
print("######################################################################################")
print("Output directory = %s" % args.output_dir)
print("Name = %s" % args.name)
# One FOV per sub-directory of <output_dir>/input. os.listdir returns entries in
# arbitrary order and also lists stray files, so keep directories only and sort
# them to make the FOV order (and everything derived from it) reproducible.
samples = sorted(
    entry for entry in os.listdir(os.path.join(args.output_dir, "input"))
    if os.path.isdir(os.path.join(args.output_dir, "input", entry))
)
# to debug on a single FOV, subset here, e.g.: samples = [samples[0]]
print("Samples = %s" % samples)
print("Proportion = %s" % args.proportion)
print("Blur factor = %s" % args.blur_factor)
print("Smoothing factor = %s" % args.smoothing_factor)
print("Maximum K = %s" % args.max_k)


segmentation_dir = os.path.join(args.output_dir, "segmentation")
cell_table_dir = os.path.join(args.output_dir, "cell_table")

# args.smoothing_channels is a comma-separated string of channel names (or None)
if args.smoothing_channels is not None:
    smoothing_channels = [str(s.strip()) for s in args.smoothing_channels.split(",")]
else:
    smoothing_channels = None

# args.filter_channels is a list of "channel=True|False" declarations (or None);
# the parsed values are booleans
if args.filter_channels is not None:
    filter_channels = parse_vars(args.filter_channels, "bool")
else:
    filter_channels = None

print("Smoothing channels = %s" % smoothing_channels)
print("Filter background for channels = %s" % filter_channels)
print("######################################################################################")
print()

In [None]:
import json
import os
import subprocess
from datetime import datetime as dt

import feather
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from matplotlib import rc_file_defaults
from alpineer import io_utils, load_utils
from ark.analysis import visualize
from ark.phenotyping import pixel_cluster_utils
from ark.phenotyping.pixel_cluster_utils import find_fovs_missing_col
from ark.utils import data_utils, example_dataset, plot_utils
from ark.utils.metacluster_remap_gui import (MetaClusterData, MetaClusterGui,
                                             colormap_helper,
                                             metaclusterdata_from_files)
import time

In [None]:
from ark.phenotyping import pixel_som_clustering
from ark.phenotyping.pixie_preprocessing import  create_pixel_matrix
import ark.phenotyping.pixel_som_clustering as psc
import ark.phenotyping.pixel_cluster_utils as pcu
import ark.phenotyping.pixel_meta_clustering as pmc


In [None]:
from ark.phenotyping.pixel_meta_clustering import apply_pixel_meta_cluster_remapping

In [None]:
from ark.phenotyping.pixel_meta_clustering import generate_remap_avg_files

In [None]:

# wall-clock start, used by the last cell to report the total run time
start_time = time.time()
# directory holding one folder of channel tiffs per sample
tiff_dir = os.path.join(args.output_dir, "input")
img_sub_folder = None
seg_suffix = '_whole_cell.tiff'

# set to True to turn on multiprocessing
multiprocess = False

# define the number of samples to process in parallel, ignored if multiprocessing is set to False
batch_size = 5

pixel_cluster_prefix = args.name

# define the base output pixel folder using the specified pixel cluster prefix
# NOTE: this path and all paths derived from it below are relative to args.output_dir
pixel_output_dir = os.path.join("pixie", "%s_pixel_output_dir" % pixel_cluster_prefix)
if not os.path.exists(os.path.join(args.output_dir, pixel_output_dir)):
    os.makedirs(os.path.join(args.output_dir, pixel_output_dir))
 
# define the preprocessed pixel data folders and the names of the files written by the pipeline
pixel_data_dir = os.path.join(pixel_output_dir, 'pixel_mat_data')
pixel_subset_dir = os.path.join(pixel_output_dir, 'pixel_mat_subset')
norm_vals_name = os.path.join(pixel_output_dir, 'channel_norm_post_rowsum.feather')
pixel_som_weights_name = os.path.join(pixel_output_dir, 'pixel_som_weights.feather')
pc_chan_avg_som_cluster_name = os.path.join(pixel_output_dir, 'pixel_channel_avg_som_cluster.csv')
pc_chan_avg_meta_cluster_name = os.path.join(pixel_output_dir, 'pixel_channel_avg_meta_cluster.csv')
pixel_meta_cluster_remap_name = os.path.join(pixel_output_dir, 'pixel_meta_cluster_mapping.csv')

In [None]:
# channels = io_utils.list_files(os.path.join(tiff_dir, samples[0]), substrs=['.tiff'])
# channels = [s.replace('.tiff', '') for s in channels]
# channels = np.array(channels, dtype='<U32')
# progress message for the smoothing step carried out in the following cells
print("Smoothing markers...")
# the list of markers to smooth (`smoothing_channels`) is set in a following cell


In [None]:
# marker annotation table; the next cell reads its `Marker` and
# `UseForRegionPixelClustering` columns
channels_metadata = pd.read_csv(os.path.join(args.output_dir, "..", "annotation", "marker_metadata.csv"))

In [None]:
# keep the markers flagged in `UseForRegionPixelClustering`
# (used as a row mask, so the column is presumably boolean — TODO confirm)
channels = channels_metadata.Marker[channels_metadata.UseForRegionPixelClustering].tolist()
print("Channels:")
print(channels)
# fixed-width unicode array: names longer than 32 characters are truncated,
# including after the suffixes appended by later cells
channels = np.array(channels, dtype='<U32')

In [None]:
# smooth every clustering channel.
# NOTE(review): this binds the SAME array object as `channels` (no copy), so the
# in-place renaming done on `channels` by later cells is visible through
# `smoothing_channels` as well.
smoothing_channels = channels

In [None]:
# smooth the selected channels, then point `channels` at the smoothed versions
if smoothing_channels is not None:
    pixel_cluster_utils.smooth_channels(
        fovs=samples,
        tiff_dir=tiff_dir,
        img_sub_folder=img_sub_folder,
        channels=smoothing_channels,
        smooth_vals=args.smoothing_factor
    )
    # append "_smoothed" in place to every entry of `channels` that was smoothed
    # NOTE(review): not idempotent — running this cell twice appends the suffix
    # twice; re-run the cell that defines `channels` first.
    mask = np.isin(channels, smoothing_channels)
    channels[mask] = np.char.add(channels[mask], "_smoothed")



In [None]:
# inspect the smoothing channel list
smoothing_channels

In [None]:
#filter_channels_list = channels_metadata.KeepOutsideNucleus[channels_metadata.UseForPixelClustering].tolist()
#n = 0
#filter_channels = dict()
#for i in channels:
#    if filter_channels_list[n] != "None":
#        filter_channels[channels[n]] = filter_channels_list[n]
#    n = n+1 
#filter_channels

In [None]:
# inspect the channel names that will be used for clustering
channels

In [None]:
print("Filtering background for markers...")
# filter_channels maps a channel name to a flag: True zeroes the signal inside
# the nuclear mask ("_nuc_exclude"), False zeroes the signal outside of it
# ("_nuc_include")
if filter_channels is not None:
    for channel in filter_channels:
        # parse_vars(..., "bool") yields real booleans, so comparing against the
        # string 'True' alone was always False; accept both forms
        nuclear_exclude = filter_channels[channel] in (True, 'True')

        for samp in samples:
            filter_with_nuclear_mask(
                fovs=[samp],
                tiff_dir=tiff_dir,
                seg_dir = segmentation_dir,
                channel = channel,
                nuc_seg_suffix="_nuclear.tiff",
                img_sub_folder = img_sub_folder,
                exclude=nuclear_exclude
            )

        # point `channels` at the filtered image (renamed once per channel)
        nuc_suffix = "_nuc_exclude" if nuclear_exclude else "_nuc_include"
        mask = np.isin(channels, channel)
        channels[mask] = np.char.add(channels[mask], nuc_suffix)

In [None]:
print("Creating Pixel Matrix...")

# run pixel data preprocessing
# (all *_dir / *_name arguments are paths relative to args.output_dir)
create_pixel_matrix(
    samples,
    channels,
    args.output_dir,
    tiff_dir,
    segmentation_dir,
    img_sub_folder=img_sub_folder,
    seg_suffix=seg_suffix,
    pixel_output_dir=pixel_output_dir,
    data_dir=pixel_data_dir,
    subset_dir=pixel_subset_dir,
    norm_vals_name=norm_vals_name,
    blur_factor=args.blur_factor,
    subset_proportion=args.proportion,
    multiprocess=multiprocess,
    batch_size=batch_size
)

print("Training SOM Pixel clustering")

# create the pixel-level SOM weights
# (fixed seed; trained from the data in pixel_subset_dir)
pixel_pysom = pixel_som_clustering.train_pixel_som(
    samples,
    channels.tolist(),
    args.output_dir,
    subset_dir=pixel_subset_dir,
    norm_vals_name=norm_vals_name,
    som_weights_name=pixel_som_weights_name,
    num_passes=1,
    seed=42
)

# use pixel SOM weights to assign pixel clusters
pixel_som_clustering.cluster_pixels(
    samples,
    channels.tolist(),
    args.output_dir,
    pixel_pysom,
    data_dir=pixel_data_dir,
    multiprocess=multiprocess,
    batch_size=batch_size
)

# generate the SOM cluster summary files
pixel_som_clustering.generate_som_avg_files(
    samples,
    channels.tolist(),
    args.output_dir,
    pixel_pysom,
    data_dir=pixel_data_dir,
    pc_chan_avg_som_cluster_name=pc_chan_avg_som_cluster_name
)

In [None]:
print("Run pixel consensus clustering")
# passed as `cap` to pmc.pixel_consensus_cluster in the next cell
cap = 3

# Force consensus clustering to be re-run: the following cells need the
# `pixel_cc` object, which only exists once pixel_consensus_cluster has been
# called in this kernel session.
fovs_list = find_fovs_missing_col(args.output_dir, pixel_data_dir, 'pixel_meta_cluster')

# An empty list is taken to mean that every FOV file already carries a
# 'pixel_meta_cluster' column (presumably from a previous run — TODO confirm
# against find_fovs_missing_col); drop that column from every file so the
# clustering starts from a clean state.
# NOTE(review): 'pixel_meta_cluster_rename', added by the next cell, is not dropped here.
if len(fovs_list) == 0:
    data_path = os.path.join(args.output_dir, pixel_data_dir)
    fov_files = io_utils.list_files(data_path, substrs='.feather')
    for file in fov_files:
        fov_data = feather.read_dataframe(os.path.join(data_path, file))
        fov_data = fov_data.drop('pixel_meta_cluster', axis=1)
        feather.write_dataframe(df= fov_data, dest= os.path.join(data_path, file))


In [None]:
# run hierarchical clustering based on pixel SOM cluster assignments
# (`cap` comes from the previous cell, `max_k` from the run parameters)
pixel_cc = pmc.pixel_consensus_cluster(
    samples,
    channels.tolist(),
    args.output_dir,
    max_k=args.max_k,
    cap=cap,
    data_dir=pixel_data_dir,
    pc_chan_avg_som_cluster_name=pc_chan_avg_som_cluster_name,
    multiprocess=multiprocess,
    batch_size=batch_size
)


# generate the meta cluster summary files
pmc.generate_meta_avg_files(
    samples,
    channels.tolist(),
    args.output_dir,
    pixel_cc,
    data_dir=pixel_data_dir,
    pc_chan_avg_som_cluster_name=pc_chan_avg_som_cluster_name,
    pc_chan_avg_meta_cluster_name=pc_chan_avg_meta_cluster_name
)

# Add a 'pixel_meta_cluster_rename' column (initially a copy of
# 'pixel_meta_cluster') to every per-FOV feather file, rewriting each file in place
data_path = os.path.join(args.output_dir, pixel_data_dir)
fov_files = io_utils.list_files(data_path, substrs='.feather')
for file in fov_files:
    fov_data = feather.read_dataframe(os.path.join(data_path, file))
    fov_data["pixel_meta_cluster_rename"] = fov_data["pixel_meta_cluster"]
    feather.write_dataframe(df= fov_data, dest= os.path.join(data_path, file))



In [None]:
# inspect the (args.output_dir-relative) path of the SOM cluster average file
pc_chan_avg_som_cluster_name

In [None]:
%matplotlib widget

print("Plotting Heatmap")
# Heatmap
# reset matplotlib rc parameters and switch to interactive mode before building the GUI
rc_file_defaults()
plt.ion()

# load the SOM cluster average file written by the clustering cells
pixel_mcd = metaclusterdata_from_files(
    os.path.join(args.output_dir, pc_chan_avg_som_cluster_name),
    cluster_type='pixel'
)
# file that the remapping cells below read back as `pixel_meta_cluster_remap_name`
pixel_mcd.output_mapping_filename = os.path.join(args.output_dir, pixel_meta_cluster_remap_name)
pixel_mcg = MetaClusterGui(pixel_mcd, width=10,  enable_throttle=True)

In [None]:
# inspect the pixel meta cluster GUI object created in the previous cell
# (`cell_mcg` is never defined in this notebook and raised a NameError)
pixel_mcg

In [None]:
# inspect the (args.output_dir-relative) path of the meta cluster mapping file
pixel_meta_cluster_remap_name

In [None]:
# leftover inspection cell: evaluates to the imported function object only,
# nothing is computed here
generate_remap_avg_files

In [None]:
# IF not running the remapping:
#df = pd.DataFrame(pixel_mcd.mapping)
#df["mc_name"] = df['metacluster']

#df2 = df.copy()
#df2.columns = ['pixel_meta_cluster', 'pixel_meta_cluster_rename']
#df2.index.name = "pixel_som_cluster"
#df2.reset_index(inplace=True)
#df2.to_csv( os.path.join(args.output_dir, pixel_meta_cluster_remap_name), index=False)




In [None]:
# rename the meta cluster values in the pixel dataset, using the mapping file
# saved from the meta cluster GUI
apply_pixel_meta_cluster_remapping(
    samples,
    # plain list of names, as passed to every other ark call in this notebook
    channels.tolist(),
    args.output_dir,
    pixel_data_dir,
    pixel_meta_cluster_remap_name,
    multiprocess=multiprocess,
    batch_size=batch_size
)


In [None]:
# recompute the mean channel expression per meta cluster and apply these new names to the SOM cluster average data
generate_remap_avg_files(
    samples,
    # the clustered channel names (with their "_smoothed"/"_nuc_*" suffixes);
    # `smoothing_channels` only held these names because it aliased `channels`
    channels.tolist(),
    args.output_dir,
    pixel_data_dir,
    pixel_meta_cluster_remap_name,
    pc_chan_avg_som_cluster_name,
    pc_chan_avg_meta_cluster_name
)

In [None]:
# build the meta cluster colormap from the saved mapping file and the colormap
# currently shown in the GUI; only the first returned value is used below
raw_cmap, _ = colormap_helper.generate_meta_cluster_colormap_dict(
    os.path.join(args.output_dir, pixel_meta_cluster_remap_name),
    pixel_mcg.im_cl.cmap
)

In [None]:
# generate masks for every sample
subset_pixel_fovs = samples

print("Generating pixel masks...")
# define the path to the channel file (relative to a FOV folder): the first
# '.tiff' file listed for the first sample. It is only used to size the masks.
if img_sub_folder is None:
    chan_file = os.path.join(
        io_utils.list_files(os.path.join(tiff_dir, samples[0]), substrs=['.tiff'])[0]
    )
else:
    chan_file = os.path.join(
        img_sub_folder, io_utils.list_files(os.path.join(tiff_dir, samples[0], img_sub_folder), substrs=['.tiff'])[0]
    )

In [None]:
# write one meta cluster mask per FOV to
# <output_dir>/<pixel_output_dir>/pixel_masks, named with the "_pixel_mask" suffix
generate_and_save_pixel_cluster_masks(
    fovs=subset_pixel_fovs,
    base_dir=args.output_dir,
    save_dir=os.path.join(args.output_dir, pixel_output_dir),
    tiff_dir=tiff_dir,
    chan_file=chan_file,
    pixel_data_dir=pixel_data_dir,
    pixel_cluster_col='pixel_meta_cluster',
    sub_dir='pixel_masks',
    name_suffix='_pixel_mask',
)

In [None]:
# plot the meta cluster overlay of every FOV and save it as a PDF next to its mask
for pixel_fov in subset_pixel_fovs:
    # load the mask written by the previous cell
    pixel_cluster_mask = load_utils.load_imgs_from_dir(
        data_dir=os.path.join(args.output_dir, pixel_output_dir, "pixel_masks"),
        files=[pixel_fov + "_pixel_mask.tiff"],
        trim_suffix="_pixel_mask",
        match_substring="_pixel_mask",
        xr_dim_name="pixel_mask",
        xr_channel_names=None,
    )

    plot_utils.plot_pixel_cell_cluster_overlay(
        pixel_cluster_mask,
        [pixel_fov],
        os.path.join(args.output_dir, pixel_meta_cluster_remap_name),
        metacluster_colors = raw_cmap
    )
    # saves the current matplotlib figure
    # NOTE(review): figures are not closed, so they accumulate across FOVs
    plt.savefig(os.path.join(args.output_dir, pixel_output_dir, "pixel_masks", pixel_fov + "_overlay_meta_clusters" +'.pdf'))

In [None]:
# define the params dict: records the FOVs, channels and paths of this run
# (the file name suggests it is meant for a later cell clustering step)
cell_clustering_params = {
    # FOV names are taken from the feather files present in the pixel data directory
    'fovs': io_utils.remove_file_extensions(io_utils.list_files(os.path.join(args.output_dir, pixel_data_dir), substrs='.feather')),
    'channels': channels.tolist(),
    'segmentation_dir': segmentation_dir,
    'seg_suffix': seg_suffix,
    'pixel_data_dir': pixel_data_dir,
    'pc_chan_avg_som_cluster_name': pc_chan_avg_som_cluster_name,
    'pc_chan_avg_meta_cluster_name': pc_chan_avg_meta_cluster_name
}

# save the params dict
with open(os.path.join(args.output_dir, pixel_output_dir, 'cell_clustering_params.json'), 'w') as fh:
    json.dump(cell_clustering_params, fh)
    
# report the total run time; the value has to be interpolated with `%`,
# print() does not apply %-formatting to its extra arguments
print('Finished Running Pixel Classification in %s seconds' % round(time.time() - start_time, 2))