# Automating initial thymic region segmentation 



In [None]:
import numpy as np
from scipy import ndimage
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
from sklearn.mixture import GaussianMixture
import tifffile
from pathlib import Path
import os
import cv2
import re
import typing
str_or_path = typing.Union[str, os.PathLike]
from deco import synchronized, concurrent

## Functions

In [None]:
def remove_tif(img_dir, image_to_remove= "image.tif") -> None:
    """Delete the single non-hidden file in ``img_dir`` whose name matches ``image_to_remove``.

    Parameters
    ----------
    img_dir : str or os.PathLike
        Directory to search (not recursive; names starting with "." are ignored).
    image_to_remove : str
        Regular expression passed to ``re.search`` against each file name. Note that
        ``.`` is a regex wildcard, so the default also matches e.g. ``"imageXtif"``.

    Raises
    ------
    AssertionError
        If zero files or more than one file match; nothing is deleted in that case.
    """
    img_paths = [
        os.path.join(img_dir, img_file)
        for img_file in os.listdir(img_dir)
        if re.search(image_to_remove, img_file) and not img_file.startswith(".")
    ]
    ## Explicit raise (not ``assert``) so the guard survives ``python -O`` and we never
    ## delete an arbitrary file on an ambiguous match. AssertionError is kept for
    ## backward compatibility with the original ``assert``.
    if len(img_paths) != 1:
        raise AssertionError(
            f"Expected exactly one file matching {image_to_remove!r} in {img_dir}, "
            f"found {len(img_paths)}: {img_paths}"
        )
    os.remove(img_paths[0])
    return None
    
def setup_image(image_dir, glob_str= "image.tif", channel_loc= 2, median_ksize= 11, mean_ksize = (11,11), normalize= False):
    """Load one channel of a multi-channel TIFF and smooth it for thresholding.

    Parameters
    ----------
    image_dir : str or os.PathLike
        Directory searched recursively (``Path.rglob``) for ``glob_str``.
    glob_str : str
        Glob pattern for the image file. If several files match, the first in sorted
        order is used and a message is printed.
    channel_loc : int
        Index along the first axis of the loaded array to keep (assumes a channel-first
        layout — TODO confirm for all images).
    median_ksize : int
        Odd aperture of the median blur (OpenCV only accepts uint8 input for sizes > 5).
    mean_ksize : tuple[int, int]
        Kernel size of the box (mean) blur applied after the median blur.
    normalize : bool
        If True, z-score the channel and apply ``arcsinh`` before blurring.

    Returns
    -------
    numpy.ndarray
        The median- then mean-blurred 2-D channel.

    Raises
    ------
    FileNotFoundError
        If no file under ``image_dir`` matches ``glob_str``.
    """
    ## rglob yields paths that already include ``image_dir``, so use them directly
    ## (joining them onto image_dir again double-prefixes relative directories).
    image_paths = sorted(Path(image_dir).rglob(glob_str))
    if not image_paths:
        raise FileNotFoundError(f"No file matching {glob_str!r} under {image_dir}")
    if len(image_paths) > 1:
        print(f"{image_dir} has more than one tif; using {image_paths[0]}")

    full_img = tifffile.imread(image_paths[0])
    img = full_img[channel_loc, ...]

    if normalize:
        ## Normalize imaging data
        scaled_img = (img - np.mean(img)) / np.std(img)
        norm_img = np.arcsinh(scaled_img)
        ## NOTE(review): arcsinh compresses z-scores to a small range (arcsinh(1000) ~ 7.6),
        ## and casting straight to uint8 truncates fractions while negative values have
        ## implementation-defined results. Most dynamic range is lost; consider rescaling
        ## to 0-255 before the cast.
        norm_img = norm_img.astype("uint8")
        med_blur = cv2.medianBlur(norm_img, median_ksize)
    else:
        med_blur = cv2.medianBlur(img, median_ksize)

    ## Mean blur to get an evenly blurred medulla, using the caller's ``mean_ksize``.
    mean_blur = cv2.blur(med_blur, tuple(mean_ksize))

    return mean_blur

def run_image_GMM(img, n_gaussians= 4, downsample_divisor= 4, random_state= None):
    """Fit a 1-D Gaussian mixture to pixel intensities and return its sorted component means.

    Pixels at or above the 97.5th percentile are discarded first (bright medullary spots
    skewed the fit), then a random ``1/downsample_divisor`` subsample is drawn without
    replacement to keep fitting time manageable.

    Parameters
    ----------
    img : numpy.ndarray
        Image of any shape; all pixels are pooled.
    n_gaussians : int
        Number of mixture components.
    downsample_divisor : int
        Keep ``n_pixels // downsample_divisor`` pixels for the fit.
    random_state : int or None
        Seed for both the pixel subsample and ``GaussianMixture``. ``None`` (default)
        keeps the unseeded behaviour; pass an int to make the thresholds reproducible.

    Returns
    -------
    numpy.ndarray
        Component means in ascending order, shape ``(n_gaussians,)``.
    """
    pixels = np.ravel(img)

    ## Remove top outliers; the bottom tail is left alone because those values are always 0.
    top_cutoff = np.percentile(pixels, 97.5)
    pixels = pixels[pixels < top_cutoff]

    ## Downsample to decrease computation time.
    n_samples = pixels.size // downsample_divisor
    if random_state is None:
        sampled = np.random.choice(a= pixels, size= n_samples, replace= False)
    else:
        sampled = np.random.default_rng(random_state).choice(pixels, size= n_samples, replace= False)

    ## sklearn expects (n_samples, n_features); this fit is the slow step.
    gm = GaussianMixture(n_components= n_gaussians, random_state= random_state)
    gm.fit(sampled[:, np.newaxis])

    return np.sort(gm.means_.ravel())


@concurrent
def GMM_wrapper(img, n):
    """Return a GaussianMixture with ``n`` components fitted to ``img``.

    Exists only so the fit can carry deco's ``@concurrent`` decorator. ``img`` is handed
    unchanged to ``GaussianMixture.fit`` (presumably shaped (n_samples, n_features) — confirm).
    """
    return GaussianMixture(n_components= n).fit(img)


@synchronized # And we add this for the function which calls the concurrent function
def parallelize_GMMs(img, max_components):
    ## Fit one GaussianMixture per component count n = 2..max_components (inclusive)
    ## through the @concurrent GMM_wrapper, returning the models in increasing-n order.
    ## NOTE(review): deco rewrites this function's AST; its docs describe collecting
    ## concurrent results by assignment into a list/dict inside a loop. Confirm the
    ## list-comprehension form yields fitted models rather than pending results.
    models = [GMM_wrapper(img= img, n= n) for n in np.arange(2, max_components+1)]
    return(models)


def find_n_gaussians(img, max_components, downsample_divisor= 4):
    """Plot BIC and AIC for Gaussian mixtures with 2..``max_components`` components.

    A random ``1/downsample_divisor`` subsample of the pixels is used for all fits and
    for scoring. Use the elbow/minimum of the curves to choose ``n_gaussians``.

    Parameters
    ----------
    img : numpy.ndarray
        Image of any shape; all pixels are pooled.
    max_components : int
        Largest component count to fit (inclusive).
    downsample_divisor : int
        Keep ``n_pixels // downsample_divisor`` pixels.
    """
    pixels = np.ravel(img)
    ## sklearn's fit/bic/aic require a 2-D (n_samples, n_features) array.
    downsampled_img = np.random.choice(a       = pixels,
                                       size    = pixels.size // downsample_divisor,
                                       replace = False)[:, np.newaxis]

    GMM_models = parallelize_GMMs(img= downsampled_img, max_components= max_components)

    ## Must match the component counts fitted in parallelize_GMMs (2..max_components inclusive).
    n_components = np.arange(2, max_components + 1)
    plt.plot(n_components, [m.bic(downsampled_img) for m in GMM_models], label='BIC')
    plt.plot(n_components, [m.aic(downsampled_img) for m in GMM_models], label='AIC')
    plt.legend(loc='best')
    plt.xlabel('n_components')
    plt.ylabel('Information criterion')


def make_threshold_masks(img, sorted_thresholds, morph_ksize= (75, 75), morph_median_ksize= 251, min_spot_area= 50_000, max_hole_area= 250_000):
    """Turn each intensity threshold into a cleaned binary tissue mask.

    For every threshold the image is binarised (``img > threshold``), morphologically
    closed with an elliptical kernel, median-filtered to drop salt-and-pepper noise,
    stripped of connected components smaller than ``min_spot_area`` pixels, and then
    every contour whose ``cv2.contourArea`` is at most ``max_hole_area`` is filled.

    Parameters
    ----------
    img : numpy.ndarray
        2-D single-channel image.
    sorted_thresholds : array-like of numbers
        One mask is produced per threshold, in the given order.
    morph_ksize : tuple[int, int]
        Size of the elliptical structuring element used for the closing.
    morph_median_ksize : int
        Odd aperture for ``cv2.medianBlur``.
    min_spot_area : int
        Connected components with fewer pixels are removed.
    max_hole_area : float
        Contours with area at most this value are filled with 255.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(len(sorted_thresholds), H, W)`` holding 0 or 255.
    """
    thresholds = np.asarray(sorted_thresholds).ravel()
    threshold_masks = np.zeros((thresholds.size, img.shape[0], img.shape[1]))

    ## The structuring element does not depend on the threshold; build it once.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, tuple(morph_ksize))

    for i, threshold in enumerate(thresholds):
        binary_img = (img > threshold).astype(np.uint8)

        ## Generate tissue shapes
        binary_img = cv2.morphologyEx(binary_img, cv2.MORPH_CLOSE, kernel)

        ## Remove salt-and-pepper noise
        binary_img = cv2.medianBlur(binary_img, morph_median_ksize)

        ## Remove small islands: keep labels whose area >= min_spot_area (label 0 is background).
        _, labels, stats, _ = cv2.connectedComponentsWithStats(binary_img)
        keep = stats[:, cv2.CC_STAT_AREA] >= min_spot_area
        keep[0] = False
        filtered_img = np.where(keep[labels], 255, 0).astype(np.uint8)

        ## Fill small holes (RETR_CCOMP returns both outer boundaries and hole boundaries).
        contours, _ = cv2.findContours(filtered_img, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
        for cnt in contours:
            if cv2.contourArea(cnt) <= max_hole_area:
                cv2.drawContours(filtered_img, [cnt], 0, 255, -1)

        threshold_masks[i, ...] = filtered_img
    return threshold_masks

def overlay_cortex_and_medulla(image_dir, medulla, cortex, out_name= "tissue_segmented_image.ome.tif", image_name= "reordered_image.ome.tif") -> None:
    """Append cortex and medulla masks as extra channels of an image and save it.

    Reads ``image_dir/image_name``, concatenates the cortex mask and then the medulla
    mask after its existing channels along axis 0, and writes the stack to
    ``image_dir/out_name``. Both masks are 2-D and gain a leading channel axis.
    """
    source_path = os.path.join(image_dir, image_name)
    target_path = os.path.join(image_dir, out_name)

    stacked = np.concatenate((
        tifffile.imread(source_path),
        cortex[np.newaxis, ...],
        medulla[np.newaxis, ...],
    ))
    tifffile.imwrite(target_path, stacked)
    return None

@concurrent
def histocytometry_wrapper(image_dir, whole_lobe_threshold_rank, medullary_threshold_rank, save_masks= True, plot_thresholds= True, plot_tissues= True, setup_kwargs= None, run_image_GMM_kwargs= None, make_threshold_masks_kwargs= None):
    """Segment one image into whole-lobe, medulla and cortex masks using GMM thresholds.

    The marker channel is loaded and blurred (``setup_image``), a Gaussian mixture is fit
    to its intensities (``run_image_GMM``), and each ascending component mean becomes a
    mask (``make_threshold_masks``). The masks at the requested 1-based ranks are the
    whole-lobe and medulla masks; the cortex is the whole lobe with the medulla removed.

    Parameters
    ----------
    image_dir : str or os.PathLike
        Image directory; masks are written here when ``save_masks`` is True.
    whole_lobe_threshold_rank, medullary_threshold_rank : int
        1-based indices into the ascending GMM thresholds.
    save_masks : bool
        Write ``medulla_mask.tif``, ``cortex_mask.tif`` and the overlaid image.
    plot_thresholds, plot_tissues : bool
        Show diagnostic figures.
    setup_kwargs, run_image_GMM_kwargs, make_threshold_masks_kwargs : dict or None
        Extra keyword arguments forwarded to the respective helper.

    Raises
    ------
    ValueError
        If either rank is outside ``1..number of thresholds``.
    """
    ## None defaults instead of shared mutable ``{}`` defaults.
    setup_kwargs = {} if setup_kwargs is None else setup_kwargs
    run_image_GMM_kwargs = {} if run_image_GMM_kwargs is None else run_image_GMM_kwargs
    make_threshold_masks_kwargs = {} if make_threshold_masks_kwargs is None else make_threshold_masks_kwargs

    print(f"Starting segmentation of {image_dir}\n")

    img = setup_image(image_dir = image_dir, **setup_kwargs)

    sorted_thresholds = run_image_GMM(img= img, **run_image_GMM_kwargs)

    ## Masks hold 0/255, so uint8 suffices (img.dtype would also work).
    threshold_masks = make_threshold_masks(img= img, sorted_thresholds= sorted_thresholds, **make_threshold_masks_kwargs).astype("uint8")
    n_masks = threshold_masks.shape[0]

    ## Ranks are 1-based; 0 or negatives would otherwise silently index from the end.
    for rank_name, rank in (("whole_lobe_threshold_rank", whole_lobe_threshold_rank),
                            ("medullary_threshold_rank", medullary_threshold_rank)):
        if not 1 <= rank <= n_masks:
            raise ValueError(f"{rank_name}={rank} is outside 1..{n_masks} for {image_dir}")

    medulla_mask = threshold_masks[medullary_threshold_rank-1, ...]
    whole_lobe_mask = threshold_masks[whole_lobe_threshold_rank-1, ...]
    ## Zero out medulla pixels instead of subtracting: uint8 ``whole - medulla`` wraps
    ## 0 - 255 to 1 wherever the medulla mask extends beyond the whole-lobe mask.
    cortex_mask = np.where(medulla_mask > 0, 0, whole_lobe_mask).astype(whole_lobe_mask.dtype)

    if plot_tissues:
        _, ax = plt.subplots(1, 4)
        ax[0].imshow(img)
        ax[0].set_title('Marker')
        ax[1].imshow(whole_lobe_mask, cmap= plt.cm.gray)
        ax[1].set_title('Whole lobe')
        ax[2].imshow(medulla_mask, cmap= plt.cm.gray)
        ax[2].set_title('Medulla')
        ax[3].imshow(cortex_mask, cmap= plt.cm.gray)
        ax[3].set_title("Cortex")
        plt.suptitle(image_dir)
        plt.show()

    if plot_thresholds:
        ## One panel per GMM threshold; squeeze=False keeps ``ax`` indexable for a single mask.
        _, ax = plt.subplots(1, n_masks, squeeze= False)
        for i, panel in enumerate(ax[0]):
            panel.imshow(threshold_masks[i, ...], cmap= plt.cm.gray)
            panel.set_title(f'Gaussian {i+1}')
            panel.set_axis_off()
        plt.suptitle(image_dir)
        plt.show()

    if save_masks:
        tifffile.imwrite(os.path.join(image_dir, "medulla_mask.tif"), medulla_mask)
        tifffile.imwrite(os.path.join(image_dir, "cortex_mask.tif"),  cortex_mask)

        overlay_cortex_and_medulla(image_dir = image_dir,
                                   medulla   = medulla_mask,
                                   cortex    = cortex_mask)


@synchronized
def histocytometry_wrapper_parallelization(image_dirs, whole_lobe_threshold_rank, medullary_threshold_rank, n_gaussians= 4, save_masks= True, plot_tissues= False, plot_thresholds= True):
    ## Dispatch one @concurrent histocytometry_wrapper call per directory in image_dirs;
    ## deco's @synchronized collects those concurrent calls before the function returns.
    ## Every image uses the same two 1-based threshold ranks, loads the file matched by
    ## "reordered_image.ome.tif" and fits n_gaussians mixture components. Returns None.
    for image_dir in image_dirs:
        histocytometry_wrapper(
            image_dir                 = image_dir,
            setup_kwargs              = {"glob_str" : "reordered_image.ome.tif"},
            save_masks                = save_masks,
            run_image_GMM_kwargs      = {"n_gaussians" : n_gaussians},
            plot_tissues              = plot_tissues,
            plot_thresholds           = plot_thresholds,
            whole_lobe_threshold_rank = whole_lobe_threshold_rank,
            medullary_threshold_rank  = medullary_threshold_rank)
    return None


def tissue_segmentation_plot(img_dir, img, thresholds, threshold_masks):
    """Show the marker image followed by one mask panel per threshold.

    Parameters
    ----------
    img_dir : str or os.PathLike
        Used as the figure title.
    img : numpy.ndarray
        The marker image drawn in the first panel.
    thresholds : sequence of numbers
        Threshold values; each panel title shows one rounded to 1 decimal.
    threshold_masks : numpy.ndarray
        Masks stacked along axis 0, in the same order as ``thresholds``.
    """
    print("Starting plot")
    ## squeeze=False keeps ``axes`` indexable even with no thresholds (one panel only),
    ## where plt.subplots would otherwise return a bare Axes.
    _, axes = plt.subplots(1, len(thresholds) + 1, figsize= (12, 6), squeeze= False)
    axes = axes[0]
    axes[0].imshow(img)
    axes[0].set_title("Marker")
    for ax, threshold, mask in zip(axes[1:], thresholds, threshold_masks):
        ax.imshow(mask, cmap= plt.cm.gray)
        ax.set_title(f'Thresh: {round(threshold, 1)}')
    plt.suptitle(img_dir)
    plt.show()

@concurrent
def tissue_segmentation(img_dir, thresholds, morph_ksize, channel_loc, out_name= "", glob_str= "reordered_image.ome.tif", save_mask= False, normalize= False) -> None:
    """Threshold one channel at fixed values, plot the masks and optionally save one.

    Parameters
    ----------
    img_dir : str or os.PathLike
        Image directory; the mask is written here as ``out_name``.
    thresholds : array-like of numbers
        Fixed intensity thresholds, one mask each.
    morph_ksize : tuple[int, int]
        Closing kernel size passed to ``make_threshold_masks``.
    channel_loc : int
        Channel index passed to ``setup_image``.
    out_name : str
        File name of the saved mask; required when ``save_mask`` is True.
    glob_str : str
        Pattern used by ``setup_image`` to find the image.
    save_mask : bool
        Save the mask; only allowed with exactly one threshold.
    normalize : bool
        Forwarded to ``setup_image``.

    Raises
    ------
    ValueError
        If ``save_mask`` is True and there is not exactly one threshold, or ``out_name`` is empty.
    """
    ## Validate the save request before the slow image work.
    if save_mask:
        if np.size(thresholds) != 1:
            raise ValueError("There are multiple threshold masks. Make sure there is only one threshold in the thresholds argument.")
        if not out_name:
            raise ValueError("out_name must be given when save_mask=True.")

    print(f"Starting {img_dir}\n")

    ## Take stain and apply thresholds
    img = setup_image(image_dir= img_dir, glob_str= glob_str, channel_loc= channel_loc, normalize= normalize)
    threshold_masks = make_threshold_masks(img= img, sorted_thresholds= thresholds, morph_ksize= morph_ksize).astype("uint8")

    tissue_segmentation_plot(
        img_dir         = img_dir,
        img             = img,
        thresholds      = thresholds,
        threshold_masks = threshold_masks
    )

    if save_mask:
        ## Index the single mask rather than squeeze(), which could also drop a size-1 spatial axis.
        tifffile.imwrite(os.path.join(img_dir, out_name), threshold_masks[0])

    return None


@synchronized
def parallelize_segmentation(img_dirs, thresholds, morph_ksize, channel_loc, save_mask, out_name= "", glob_str= "reordered_image.ome.tif", normalize= False):
    ## Use this function to speed up plotting of whole lobe masks and medulla masks
    ## I'll visually inspect the results to pick the best thresholds
    ## Runs one @concurrent tissue_segmentation call per directory with identical settings;
    ## deco's @synchronized gathers the concurrent calls. Returns None.
    for img_dir in img_dirs:
        tissue_segmentation(img_dir     = img_dir,
                            thresholds  = thresholds,
                            channel_loc = channel_loc,
                            glob_str    = glob_str,
                            save_mask   = save_mask,
                            out_name    = out_name,
                            normalize   = normalize,
                            morph_ksize = morph_ksize)
        

@concurrent
def medulla_gmm_wrapper(img_dir, medullary_threshold_rank= 4, glob_str= "reordered_image.ome.tif", channel_loc= 2, downsample_divisor= 8, n_gaussians= 4, save_mask= False, out_name= "medulla_mask.tif", normalize= False) -> None:
    """Derive a medulla mask from one stain channel with GMM-based thresholds.

    Fits ``n_gaussians`` components to the blurred channel, builds one mask per sorted
    component mean, plots them all, and keeps the mask at ``medullary_threshold_rank``.

    Parameters
    ----------
    img_dir : str or os.PathLike
        Image directory; the mask is written here as ``out_name``.
    medullary_threshold_rank : int
        1-based index into the ascending GMM means (1..n_gaussians).
    glob_str, channel_loc, normalize :
        Forwarded to ``setup_image``.
    downsample_divisor, n_gaussians :
        Forwarded to ``run_image_GMM``.
    save_mask : bool
        Write the selected mask to ``img_dir/out_name``.
    out_name : str
        File name of the saved mask.

    Raises
    ------
    ValueError
        If ``medullary_threshold_rank`` is not within ``1..n_gaussians``.
    """
    ## Validate before the slow GMM fit; 0 or negative ranks would silently index from the end.
    if not 1 <= medullary_threshold_rank <= n_gaussians:
        raise ValueError(f"medullary_threshold_rank={medullary_threshold_rank} must be in 1..{n_gaussians}")

    ## Medullary stain channel (presumably CD11c or CD63, selected by channel_loc).
    medulla_stain = setup_image(image_dir= img_dir, glob_str= glob_str, channel_loc= channel_loc, normalize= normalize)

    sorted_thresholds = run_image_GMM(img= medulla_stain, n_gaussians= n_gaussians, downsample_divisor= downsample_divisor)

    threshold_masks = make_threshold_masks(img= medulla_stain, sorted_thresholds= sorted_thresholds).astype("uint8")

    tissue_segmentation_plot(
        img_dir         = img_dir,
        img             = medulla_stain,
        thresholds      = sorted_thresholds,
        threshold_masks = threshold_masks
    )

    medulla_mask = threshold_masks[medullary_threshold_rank-1, ...]

    if save_mask:
        tifffile.imwrite(os.path.join(img_dir, out_name), medulla_mask)

    return None


@synchronized
def parallelize_medulla(img_dirs, medullary_threshold_rank, save_mask= False, channel_loc= 2, n_gaussians= 4, normalize= False):
    ## Run one @concurrent medulla_gmm_wrapper per directory with the same rank/channel;
    ## deco's @synchronized gathers the concurrent calls.
    ## glob_str, downsample_divisor and out_name are not forwarded, so medulla_gmm_wrapper's
    ## own defaults apply to them.
    for img_dir in img_dirs:
        medulla_gmm_wrapper(
            img_dir                  = img_dir,
            save_mask                = save_mask,
            channel_loc              = channel_loc,
            medullary_threshold_rank = medullary_threshold_rank,
            n_gaussians              = n_gaussians,
            normalize                = normalize
        )

def make_cortex(img_dir, medulla_mask= "medulla_mask.tif", whole_lobe_mask= "whole_lobe_mask.tif", out_name= "cortex_mask.tif") -> None:
    """Write a cortex mask: the whole-lobe mask with every medulla pixel set to 0.

    Parameters
    ----------
    img_dir : str or os.PathLike
        Directory holding both input masks; the output is written here too.
    medulla_mask, whole_lobe_mask : str
        File names of the input masks (any non-zero medulla pixel counts as medulla).
    out_name : str
        File name of the written uint8 cortex mask.
    """
    medulla = tifffile.imread(os.path.join(img_dir, medulla_mask))
    whole_lobe = tifffile.imread(os.path.join(img_dir, whole_lobe_mask))

    ## Mask instead of subtracting: uint8 ``whole - medulla`` wraps (0 - 255 -> 1) where the
    ## medulla extends past the lobe, and boolean masks cannot be subtracted at all.
    cortex_mask = np.where(medulla > 0, 0, whole_lobe).astype("uint8")

    tifffile.imwrite(os.path.join(img_dir, out_name), cortex_mask)
    return None

In [None]:
## Pulling out the cleaned Medulla masks
raw_dir= "/stor/scratch/Ehrlich/Users/John/histocytometry/raw_images/images_2023-08-10"
img_dirs = sorted(
    os.path.join(raw_dir, name)
    for name in os.listdir(raw_dir)
    if re.search("_[A-D]$", name) and not re.search("_Sirpa_C$", name)
)

for img_dir in img_dirs:
    ## Channel 4 of img_w_medulla.ome.tif is read as the (manually cleaned) medulla mask.
    ## NOTE(review): the medulla is appended as the last channel when that file is built, so
    ## index 4 assumes reordered_image has exactly 4 channels — confirm.
    img_w_medulla = tifffile.imread(os.path.join(img_dir, "img_w_medulla.ome.tif"))
    medulla_mask = img_w_medulla[4, ...]
    binary_medulla_mask = medulla_mask > 0   ## any non-zero pixel counts as medulla

    tifffile.imwrite(os.path.join(img_dir, "cleaned_medulla_mask.tif"), binary_medulla_mask)
    print(f"Done with {img_dir}")

In [None]:
## Saving masked cortex and medulla images
raw_dir= "/stor/scratch/Ehrlich/Users/John/histocytometry/raw_images/images_2023-08-10"
img_dirs = [os.path.join(raw_dir, img_dir) for img_dir in os.listdir(raw_dir) if re.search("_[A-D]$", img_dir) and not re.search("Sirpa_C$", img_dir)]
img_dirs.sort()

for img_dir in img_dirs:
    print(img_dir)
    img= tifffile.imread(os.path.join(img_dir, "reordered_image.ome.tif"))
    ## The whole-lobe mask is stored as 0/255; reduce it to a boolean lobe mask.
    whole_lobe_mask = tifffile.imread(os.path.join(img_dir, "whole_lobe_mask.tif")).astype("uint8") > 0
    cleaned_medulla_mask = tifffile.imread(os.path.join(img_dir, "cleaned_medulla_mask.tif")).astype("uint8")
    ## Force 0/1 so multiplying the image never scales intensities (e.g. if a mask holds 0/255).
    medulla_binary = (cleaned_medulla_mask > 0).astype("uint8")

    ## Cortex = lobe pixels that are not medulla. A logical AND avoids the ``bool - uint8``
    ## wrap-around (0 - 1 -> 255) where the medulla extends past the lobe, which would
    ## otherwise multiply those image pixels by 255.
    cortex_mask = (whole_lobe_mask & (medulla_binary == 0)).astype("uint8")

    medulla        = img * medulla_binary
    cortex         = img * cortex_mask

    tifffile.imwrite(os.path.join(img_dir, "cleaned_medulla_img.ome.tif"), medulla)
    tifffile.imwrite(os.path.join(img_dir, "cleaned_cortex_img.ome.tif"), cortex)
    

## Whole lobe segmentation


In [None]:
raw_dir= "/stor/scratch/Ehrlich/Users/John/histocytometry/raw_images/images_2023-08-10"
## Every lobe directory (names ending in _A.._D), in sorted order.
img_dirs = sorted(os.path.join(raw_dir, name) for name in os.listdir(raw_dir) if re.search("_[A-D]$", name))

In [None]:
## This is what I used to generate the whole_lobe_mask.tif
## I followed this up with some manual filling of holes and removing artifacts.
## WARNING: running this with save_mask=True overwrites whole_lobe_mask.tif, including those
## manual edits. Set the toggle to False to skip it.
RUN_WHOLE_LOBE_SEGMENTATION = True
if RUN_WHOLE_LOBE_SEGMENTATION:
    raw_dir= "/stor/scratch/Ehrlich/Users/John/histocytometry/raw_images/images_2023-08-10"
    img_dirs = sorted(os.path.join(raw_dir, name) for name in os.listdir(raw_dir) if re.search("_[A-D]$", name))

    ## A single fixed DAPI threshold (13).
    thresholds = np.arange(13, 14)

    parallelize_segmentation(
        img_dirs,
        thresholds,
        morph_ksize = (201, 201),
        channel_loc = 0,  ## DAPI
        save_mask   = True,
        out_name    = "whole_lobe_mask.tif",
        glob_str    = "reordered_image.ome.tif",
    )

## Medullary segmentation

In [None]:
raw_dir= "/stor/scratch/Ehrlich/Users/John/histocytometry/raw_images/images_2023-08-10"
img_dirs= [os.path.join(raw_dir, img_dir) for img_dir in os.listdir(raw_dir) if re.search("_[A-D]$", img_dir)]
## Lobes whose medulla is taken from the third Gaussian; CD14 lobes are listed separately.
M3_dirs= ["20x_pan_DAPI_CD207_CD11c_XCR1_C", "20x_pan_DAPI_CD207_CD11c_XCR1_D", "20x_pan_DAPI_B220_CD11c_SiglecH_C", "20x_pan_DAPI_CD63_CD11c_Sirpa_B"]
CD14_dirs = [img_dir for img_dir in os.listdir(raw_dir) if re.search("CD14_[A-D]$", img_dir)]
CD14_dirs.sort()
## Every other lobe uses the fourth Gaussian. The pattern is anchored with "$" like every other
## directory filter here: unanchored "_[A-D]" also matches "_DAPI", "_CD11c", "_B220", ...,
## i.e. practically every entry in raw_dir, not just lobe directories.
M4_dirs = [os.path.join(raw_dir, img_dir) for img_dir in os.listdir(raw_dir) if re.search("_[A-D]$", img_dir) and img_dir not in M3_dirs and img_dir not in CD14_dirs]
M4_dirs.sort()

CD14_dirs = [os.path.join(raw_dir, img_dir) for img_dir in CD14_dirs]
M3_dirs= [os.path.join(raw_dir, img_dir) for img_dir in M3_dirs]

img_dict = {
    "M4" : M4_dirs,
    "M3" : M3_dirs
}

In [None]:
raw_dir= "/stor/scratch/Ehrlich/Users/John/histocytometry/raw_images/images_2023-08-10"

## Lobes whose medulla is best captured by the third of four Gaussian thresholds.
## NOTE(review): "B220_CD11c_SiglecH_C" also appears in the fourth-Gaussian list of the next
## cell. Both cells save medulla_mask.tif, so whichever runs last wins — confirm the intended rank.
assigned_M3_dirs = [
    "CD63_CD11c_Sirpa_A",
    "B220_CD11c_SiglecH_A",
    "Sirpa_CD11c_CD14_B",
    "B220_CD11c_SiglecH_C",
]
M3_dirs = sorted(os.path.join(raw_dir, "20x_pan_DAPI_" + name) for name in assigned_M3_dirs)

for img_dir in M3_dirs:
    print(img_dir)

parallelize_medulla(
    M3_dirs,
    3,  ## medullary_threshold_rank
    save_mask   = True,
    channel_loc = 2,
    n_gaussians = 4,
    normalize   = False,
)

In [None]:
raw_dir= "/stor/scratch/Ehrlich/Users/John/histocytometry/raw_images/images_2023-08-10"

## Lobes whose medulla is best captured by the fourth (brightest) of four Gaussian thresholds.
## NOTE(review): "B220_CD11c_SiglecH_C" is also in the third-Gaussian list of the previous cell;
## both write medulla_mask.tif, so the later run overwrites the earlier — confirm which is intended.
assigned_M4_dirs = [
    "B220_CD11c_SiglecH_B", "B220_CD11c_SiglecH_C",
    "CD207_CD11c_XCR1_A", "CD207_CD11c_XCR1_B", "CD207_CD11c_XCR1_C", "CD207_CD11c_XCR1_D",
    "CD63_CD11c_XCR1_A", "CD63_CD11c_XCR1_B", "CD63_CD11c_XCR1_C", "CD63_CD11c_XCR1_D",
    "CD63_CD11c_Sirpa_B", "CD63_CD11c_Sirpa_D",
    "Sirpa_CD63_MerTK_A", "Sirpa_CD63_MerTK_B", "Sirpa_CD63_MerTK_C", "Sirpa_CD63_MerTK_D",
    "Sirpa_CD11c_CD14_A", "Sirpa_CD11c_CD14_C",
]
M4_dirs = sorted(os.path.join(raw_dir, "20x_pan_DAPI_" + name) for name in assigned_M4_dirs)

for img_dir in M4_dirs:
    print(img_dir)

parallelize_medulla(
    M4_dirs,
    4,  ## medullary_threshold_rank
    save_mask   = True,
    channel_loc = 2,
    n_gaussians = 4,
    normalize   = False,
)

In [None]:
raw_dir= "/stor/scratch/Ehrlich/Users/John/histocytometry/raw_images/images_2023-08-10"
img_dirs = sorted(
    os.path.join(raw_dir, name)
    for name in os.listdir(raw_dir)
    if re.search("_[A-D]$", name) and not re.search("Sirpa_C$", name)
)
for img_dir in img_dirs:
    print(img_dir)

## Append each medulla mask as the last channel of its image.
out_name = "img_w_medulla.ome.tif"
for img_dir in img_dirs:
    medulla = tifffile.imread(os.path.join(img_dir, "medulla_mask.tif"))[np.newaxis, ...]
    img     = tifffile.imread(os.path.join(img_dir, "reordered_image.ome.tif"))

    img_w_medulla = np.concatenate((img, medulla))
    tifffile.imwrite(os.path.join(img_dir, out_name), img_w_medulla)
    print(f"Done with {img_dir}")

## Making cortex out of medulla and whole lobe masks

In [None]:
## Write cortex_mask.tif (whole lobe with the medulla removed) for every lobe directory.
for img_dir in img_dirs:
    make_cortex(
        img_dir,
        whole_lobe_mask = "whole_lobe_mask.tif",
        medulla_mask    = "medulla_mask.tif",
    )
    print(f"Done with {img_dir}")

## Overlay tissue masks on original image

In [None]:
## Save each image with its cortex and medulla masks appended as extra channels.
for img_dir in img_dirs:
    medulla_mask, cortex_mask = (
        tifffile.imread(os.path.join(img_dir, mask_file))
        for mask_file in ("medulla_mask.tif", "cortex_mask.tif")
    )

    overlay_cortex_and_medulla(
        img_dir,
        medulla_mask,
        cortex_mask,
        image_name = "reordered_image.ome.tif",
        out_name   = "tissue_segmented_image.ome.tif",
    )

<!--
## I wonder if I'm overfitting for my training image. 
## This is very slow. I should consider rewriting it in Julia or Cython, and using a gamma mixture model instead of a Gaussian one.

## The parallelization messes up the print statements. There is probably a way to hold the prints until everything is done on one image, but I don't want to mess with that.

## I need to go through the images and decide which ones look good and which ones look bad. 
## Several images have what really looks like clear batch effects from the tiles. 
## The biggest issues with the images are the technical issues. 
    ## I'll have to go through and manually remove the bright spots. (I have to document which ones have blacked out regions.)
## I think picking the medulla and cortex images will have to be done manually. 

## Why didn't any of the images get printed? I think this was a one off jupyter issue, but I removed the image figure size to see if that did anything. 
## I don't think this was the issue, but it worked after that. 

## After this run, I'll try it with five gaussians. 

## I need to break this apart into images that are using medullary or cortical for each gaussian. 
-->